Conversational RAG using Memory
Last Updated: September 24, 2024
In this notebook, we’ll explore how to incorporate memory into a RAG pipeline to enable conversations with our documents, using an InMemoryChatMessageStore, a ChatMessageRetriever, and a ChatMessageWriter.
Installation
Install Haystack, haystack-experimental and datasets with pip:
!pip install -U haystack-ai git+https://github.com/deepset-ai/haystack-experimental.git datasets
Enter OpenAI API key
import os
from getpass import getpass
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = getpass("Enter OpenAI API key:")
Create DocumentStore and Index Documents
Create an index with the seven-wonders dataset:
from haystack import Document
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.document_stores.in_memory import InMemoryDocumentStore
from datasets import load_dataset
dataset = load_dataset("bilgeyucel/seven-wonders", split="train")
docs = [Document(content=doc["content"], meta=doc["meta"]) for doc in dataset]
document_store = InMemoryDocumentStore()
document_store.write_documents(documents=docs)
151
Create Memory
Memory, that is, the conversation history, is saved as ChatMessage objects in an InMemoryChatMessageStore. When required, you can retrieve the conversation history from the chat message store using a ChatMessageRetriever.
To store memory, initialize an InMemoryChatMessageStore, a ChatMessageRetriever and a ChatMessageWriter. Import these components from the haystack-experimental package:
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
# Memory components
memory_store = InMemoryChatMessageStore()
memory_retriever = ChatMessageRetriever(memory_store)
memory_writer = ChatMessageWriter(memory_store)
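All three memory components are built around the same memory_store instance, which is what lets the retriever read back whatever the writer stored on earlier turns. The toy model below sketches that shared-state design in plain Python. The Message, ToyChatMessageStore, ToyMessageWriter and ToyMessageRetriever classes, their run() return values and the last_k parameter are hypothetical stand-ins chosen for illustration; they are not the haystack-experimental API.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Message:
    # Hypothetical stand-in for haystack.dataclasses.ChatMessage
    role: str
    content: str


@dataclass
class ToyChatMessageStore:
    # Hypothetical store: a single list shared by the writer and the retriever
    messages: List[Message] = field(default_factory=list)


class ToyMessageWriter:
    def __init__(self, store: ToyChatMessageStore) -> None:
        self.store = store

    def run(self, messages: List[Message]) -> Dict[str, int]:
        # Append in arrival order so the history stays chronological
        self.store.messages.extend(messages)
        return {"messages_written": len(messages)}


class ToyMessageRetriever:
    def __init__(self, store: ToyChatMessageStore) -> None:
        self.store = store

    def run(self, last_k: Optional[int] = None) -> Dict[str, List[Message]]:
        # Return a copy of the full history, or only the most recent last_k messages
        history = list(self.store.messages)
        if last_k is not None:
            history = history[-last_k:] if last_k > 0 else []
        return {"messages": history}


store = ToyChatMessageStore()
writer = ToyMessageWriter(store)
retriever = ToyMessageRetriever(store)

writer.run([Message("user", "What does Rhodes Statue look like?")])
writer.run([Message("assistant", "It had curly hair with spikes of bronze or silver flame.")])

print([m.role for m in retriever.run()["messages"]])  # ['user', 'assistant']
print(retriever.run(last_k=1)["messages"][0].content)
```

Because the writer and the retriever hold a reference to one store object rather than copies, no extra wiring is needed for the history to carry over between pipeline runs.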
Prompt Template for RAG with Memory
Prepare a prompt template for RAG and add a separate section for memory. The conversation history is retrieved by the ChatMessageRetriever from the InMemoryChatMessageStore and injected into the prompt through the memories prompt variable.
from haystack.dataclasses import ChatMessage
system_message = ChatMessage.from_system("You are a helpful AI assistant using provided supporting documents and conversation history to assist humans")
user_message_template ="""Given the conversation history and the provided supporting documents, give a brief answer to the question.
Note that supporting documents are not part of the conversation. If question can't be answered from supporting documents, say so.
Conversation history:
{% for memory in memories %}
{{ memory.content }}
{% endfor %}
Supporting documents:
{% for doc in documents %}
{{ doc.content }}
{% endfor %}
\nQuestion: {{query}}
\nAnswer:
"""
user_message = ChatMessage.from_user(user_message_template)
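To see what the LLM actually receives, the function below approximates the rendering of user_message_template in plain Python for some made-up history and documents. render_user_prompt is a hypothetical helper, not how ChatPromptBuilder works internally, and its whitespace differs from the Jinja output. One detail it makes visible: the template renders only memory.content, so the speaker of each history message (user or assistant) is not shown to the LLM.

```python
from typing import List


def render_user_prompt(memories: List[str], documents: List[str], query: str) -> str:
    """Hypothetical plain-Python approximation of user_message_template."""
    lines = [
        "Given the conversation history and the provided supporting documents, "
        "give a brief answer to the question.",
        "Note that supporting documents are not part of the conversation. "
        "If question can't be answered from supporting documents, say so.",
        "Conversation history:",
        *memories,  # only the message text, as in {{ memory.content }}
        "Supporting documents:",
        *documents,  # only the document text, as in {{ doc.content }}
        "",
        f"Question: {query}",
        "",
        "Answer:",
    ]
    return "\n".join(lines)


# Made-up example inputs, not taken from the dataset
prompt = render_user_prompt(
    memories=["What does Rhodes Statue look like?", "It had curly hair with spikes of flame."],
    documents=["The Colossus of Rhodes was a statue of the Greek sun god Helios."],
    query="Who built it?",
)
print(prompt)
```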
Build the Pipeline
Add the RAG and memory components to build your pipeline. Use a BranchJoiner to collect messages from both the user and the LLM and pass them to the ChatMessageWriter, which writes them to the memory store.
from typing import List
from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder, PromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators import OpenAIGenerator
from haystack.components.joiners import BranchJoiner
from haystack.components.converters import OutputAdapter
pipeline = Pipeline()
# components for RAG
pipeline.add_component("retriever", InMemoryBM25Retriever(document_store=document_store, top_k=3))
pipeline.add_component("prompt_builder", ChatPromptBuilder(variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"]))
pipeline.add_component("llm", OpenAIChatGenerator())
# components for memory
pipeline.add_component("memory_retriever", memory_retriever)
pipeline.add_component("memory_writer", memory_writer)
pipeline.add_component("memory_joiner", BranchJoiner(List[ChatMessage]))
# connections for RAG
pipeline.connect("retriever.documents", "prompt_builder.documents")
pipeline.connect("prompt_builder.prompt", "llm.messages")
pipeline.connect("llm.replies", "memory_joiner")
# connections for memory
pipeline.connect("memory_joiner", "memory_writer")
pipeline.connect("memory_retriever", "prompt_builder.memories")
<haystack.core.pipeline.pipeline.Pipeline object at 0x784bd4ed8bb0>
🚅 Components
- retriever: InMemoryBM25Retriever
- prompt_builder: ChatPromptBuilder
- llm: OpenAIChatGenerator
- memory_retriever: ChatMessageRetriever
- memory_writer: ChatMessageWriter
- memory_joiner: BranchJoiner
🛤️ Connections
- retriever.documents -> prompt_builder.documents (List[Document])
- prompt_builder.prompt -> llm.messages (List[ChatMessage])
- llm.replies -> memory_joiner.value (List[ChatMessage])
- memory_retriever.messages -> prompt_builder.memories (List[ChatMessage])
- memory_joiner.value -> memory_writer.messages (List[ChatMessage])
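The memory_joiner has two possible sources for its value input: the user message passed in the run data and llm.replies. The toy function below models the contract assumed for BranchJoiner here: pass a single incoming value through unchanged and treat several simultaneous values as an error. branch_join is a hypothetical illustration, not Haystack’s implementation, and it says nothing about how the pipeline schedules the two writes.

```python
from typing import List, TypeVar

T = TypeVar("T")


def branch_join(*values: T) -> T:
    """Hypothetical toy joiner: forward exactly one incoming value unchanged."""
    if len(values) != 1:
        raise ValueError(f"expected exactly one input, got {len(values)}")
    return values[0]


memory: List[str] = []
# The user's message arrives through one branch...
memory.extend(branch_join(["user: What does Rhodes Statue look like?"]))
# ...and the LLM's reply arrives through the other; both end up in the same history.
memory.extend(branch_join(["assistant: It had curly hair with spikes of flame."]))
print(memory)
```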
Visualize the pipeline
Visualize the pipeline with the show() method, which renders the graph inline in a notebook, to confirm the connections are correct.
pipeline.show()
Run the Pipeline
Test the pipeline with some queries. Ensure that every user query is also sent to the memory_joiner so that both the user queries and the LLM responses are stored together in the memory store.
Here are example queries you can try:
- What does Rhodes Statue look like?
- Who built it?
while True:
    messages = [system_message, user_message]
    question = input("Enter your question or Q to exit.\n🧑 ")
    if question == "Q":
        break
    # The user's message goes straight to memory_joiner so it is stored alongside the LLM reply
    res = pipeline.run(data={"retriever": {"query": question},
                             "prompt_builder": {"template": messages, "query": question},
                             "memory_joiner": {"value": [ChatMessage.from_user(question)]}},
                       include_outputs_from=["llm"])
    # include_outputs_from exposes llm's replies even though they also feed memory_joiner
    assistant_resp = res['llm']['replies'][0]
    print(f"🤖 {assistant_resp.content}")
Enter your question or Q to exit.
🧑 What does Rhodes Statue look like?
🤖 While scholars do not know exactly what the Rhodes Statue looked like, they have a good idea of what the head and face looked like. It would have had curly hair with evenly spaced spikes of bronze or silver flame radiating, similar to the images found on contemporary Rhodian coins. The anecdotal depictions of the Colossus straddling the harbor's entry point have no historic or scientific basis.
Enter your question or Q to exit.
🧑 Who built it?
🤖 The Hanging Gardens of Babylon were believed to have been built by the Neo-Babylonian King Nebuchadnezzar II for his Median wife, Queen Amytis. There is no specific information provided in the supporting documents about who built the Rhodes Statue.
Enter your question or Q to exit.
🧑 Q
⚠️ If you followed the example queries, you’ll notice that the second question was answered incorrectly. The retrieval was based on the raw query “Who built it?”, which lacks the context needed to find relevant documents: nothing in it says that “it” refers to the Rhodes Statue. As a result, the retrieved documents were not about the statue, and the LLM answered about the Hanging Gardens of Babylon instead. Let’s fix this by rephrasing the query for search.
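To see why the short follow-up fails, consider how a keyword retriever scores it. The sketch below is a plain-Python Okapi BM25 (k1=1.5, b=0.75, no stopword removal) over two hypothetical one-sentence documents, not the seven-wonders dataset. Haystack’s InMemoryBM25Retriever uses its own BM25 implementation, so exact scores will differ; the point is which document wins.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def tokenize(text: str) -> List[str]:
    # Lowercase and keep runs of word characters
    return re.findall(r"\w+", text.lower())


def bm25_scores(query: str, docs: Dict[str, str], k1: float = 1.5, b: float = 0.75) -> Dict[str, float]:
    """Okapi BM25 score of every document for the query (toy implementation)."""
    tokenized = {name: tokenize(text) for name, text in docs.items()}
    n_docs = len(tokenized)
    avgdl = sum(len(toks) for toks in tokenized.values()) / n_docs
    # Document frequency: in how many documents each term appears
    df = Counter(term for toks in tokenized.values() for term in set(toks))
    scores: Dict[str, float] = {}
    for name, toks in tokenized.items():
        tf = Counter(toks)
        norm = k1 * (1 - b + b * len(toks) / avgdl)  # document length normalization
        score = 0.0
        for term in tokenize(query):
            if tf[term] == 0:
                continue  # terms missing from the document contribute nothing
            idf = math.log((n_docs - df[term] + 0.5) / (df[term] + 0.5) + 1)
            score += idf * tf[term] * (k1 + 1) / (tf[term] + norm)
        scores[name] = score
    return scores


# Hypothetical toy documents
docs = {
    "rhodes": "The Colossus of Rhodes was a statue of the Greek sun god Helios, "
              "erected in the city of Rhodes by Chares of Lindos.",
    "babylon": "The Hanging Gardens of Babylon were built by King Nebuchadnezzar II for his wife.",
}

for query in ["Who built it?", "Who built the Rhodes Statue?"]:
    scores = bm25_scores(query, docs)
    best = max(scores, key=scores.get)
    print(f"{query!r} -> {best}")
```

With the raw follow-up, the only overlapping word is “built”, which appears in the Babylon sentence, so the Rhodes document scores exactly zero. The rephrased query adds “rhodes” and “statue”, which flips the ranking.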
Prompt Template for Rephrasing User Query
In conversational systems, simply injecting memory into the prompt is not enough to perform RAG effectively. There needs to be a mechanism to rephrase the user’s query based on the conversation history to ensure relevant documents are retrieved. For instance, if the first user query is “What’s the first name of Einstein?” and the second query is “Where was he born?”, the system should understand that “he” refers to Einstein. The rephrasing mechanism should then modify the second query to “Where was Einstein born?” to retrieve the correct documents.
We can use an LLM to rephrase the user’s query. Let’s create a prompt that instructs the LLM to rephrase the query, incorporating the conversation history, to make it suitable for retrieving relevant documents.
query_rephrase_template = """
Rewrite the question for search while keeping its meaning and key terms intact.
If the conversation history is empty, DO NOT change the query.
Use conversation history only if necessary, and avoid extending the query with your own knowledge.
If no changes are needed, output the current question as is.
Conversation history:
{% for memory in memories %}
{{ memory.content }}
{% endfor %}
User Query: {{query}}
Rewritten Query:
"""
Build the Conversational RAG Pipeline
Now, let’s incorporate query rephrasing into our pipeline by adding a new PromptBuilder with the prompt above, an OpenAIGenerator, and an OutputAdapter. The OpenAIGenerator will rephrase the user’s query for search, and the OutputAdapter will convert the OpenAIGenerator output (replies, a list of strings) into the single string query that the InMemoryBM25Retriever expects by taking the first reply. The rest of the pipeline stays the same.
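The new PromptBuilder fills query_rephrase_template with the conversation history and the latest question. For the Einstein example, the sketch below approximates the text the rephrasing LLM would receive on the second turn. render_rephrase_prompt is a hypothetical plain-Python helper (Haystack renders the real template with Jinja2, so whitespace differs), and the one-word assistant reply in the history is made up.

```python
from typing import List

# Instruction lines copied from query_rephrase_template
INSTRUCTIONS = (
    "Rewrite the question for search while keeping its meaning and key terms intact.\n"
    "If the conversation history is empty, DO NOT change the query.\n"
    "Use conversation history only if necessary, and avoid extending the query with your own knowledge.\n"
    "If no changes are needed, output the current question as is."
)


def render_rephrase_prompt(memories: List[str], query: str) -> str:
    """Hypothetical plain-Python approximation of query_rephrase_template."""
    history = "\n".join(memories)
    return f"{INSTRUCTIONS}\nConversation history:\n{history}\nUser Query: {query}\nRewritten Query:"


prompt = render_rephrase_prompt(
    memories=["What's the first name of Einstein?", "Albert."],
    query="Where was he born?",
)
print(prompt)
```

The LLM sees the earlier mention of Einstein right above the ambiguous “he”, which is what allows it to produce a self-contained search query.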
from typing import List
from haystack import Pipeline
from haystack.components.builders import ChatPromptBuilder, PromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.generators import OpenAIGenerator
from haystack.components.joiners import BranchJoiner
from haystack.components.converters import OutputAdapter
conversational_rag = Pipeline()
# components for query rephrasing
conversational_rag.add_component("query_rephrase_prompt_builder", PromptBuilder(query_rephrase_template))
conversational_rag.add_component("query_rephrase_llm", OpenAIGenerator())
conversational_rag.add_component("list_to_str_adapter", OutputAdapter(template="{{ replies[0] }}", output_type=str))
# components for RAG
conversational_rag.add_component("retriever", InMemoryBM25Retriever(document_store=document_store, top_k=3))
conversational_rag.add_component("prompt_builder", ChatPromptBuilder(variables=["query", "documents", "memories"], required_variables=["query", "documents", "memories"]))
conversational_rag.add_component("llm", OpenAIChatGenerator())
# components for memory
conversational_rag.add_component("memory_retriever", ChatMessageRetriever(memory_store))
conversational_rag.add_component("memory_writer", ChatMessageWriter(memory_store))
conversational_rag.add_component("memory_joiner", BranchJoiner(List[ChatMessage]))
# connections for query rephrasing
conversational_rag.connect("memory_retriever", "query_rephrase_prompt_builder.memories")
conversational_rag.connect("query_rephrase_prompt_builder.prompt", "query_rephrase_llm")
conversational_rag.connect("query_rephrase_llm.replies", "list_to_str_adapter")
conversational_rag.connect("list_to_str_adapter", "retriever.query")
# connections for RAG
conversational_rag.connect("retriever.documents", "prompt_builder.documents")
conversational_rag.connect("prompt_builder.prompt", "llm.messages")
conversational_rag.connect("llm.replies", "memory_joiner")
# connections for memory
conversational_rag.connect("memory_joiner", "memory_writer")
conversational_rag.connect("memory_retriever", "prompt_builder.memories")
<haystack.core.pipeline.pipeline.Pipeline object at 0x784b7429fa00>
🚅 Components
- query_rephrase_prompt_builder: PromptBuilder
- query_rephrase_llm: OpenAIGenerator
- list_to_str_adapter: OutputAdapter
- retriever: InMemoryBM25Retriever
- prompt_builder: ChatPromptBuilder
- llm: OpenAIChatGenerator
- memory_retriever: ChatMessageRetriever
- memory_writer: ChatMessageWriter
- memory_joiner: BranchJoiner
🛤️ Connections
- query_rephrase_prompt_builder.prompt -> query_rephrase_llm.prompt (str)
- query_rephrase_llm.replies -> list_to_str_adapter.replies (List[str])
- list_to_str_adapter.output -> retriever.query (str)
- retriever.documents -> prompt_builder.documents (List[Document])
- prompt_builder.prompt -> llm.messages (List[ChatMessage])
- llm.replies -> memory_joiner.value (List[ChatMessage])
- memory_retriever.messages -> query_rephrase_prompt_builder.memories (List[ChatMessage])
- memory_retriever.messages -> prompt_builder.memories (List[ChatMessage])
- memory_joiner.value -> memory_writer.messages (List[ChatMessage])
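The data flow of one conversational turn can be sketched end to end with stubs in place of the LLMs and the retriever. Everything below is hypothetical: fake_rephrase_llm returns a canned rewrite instead of calling OpenAI, fake_retriever ranks two made-up documents by word overlap rather than BM25, and run_turn writes the question and answer to history at the end of the turn, which simplifies how the real pipeline orders its memory writes. The step that mirrors list_to_str_adapter is taking replies[0].

```python
import re
from typing import List, Set, Tuple

# Hypothetical toy corpus, not the seven-wonders dataset
DOCS = [
    "The Colossus of Rhodes was built by Chares of Lindos.",
    "The Hanging Gardens of Babylon were built by Nebuchadnezzar II.",
]

# Canned rewrites standing in for a real LLM call (hypothetical)
REWRITES = {"Who built it?": "Who built the Rhodes Statue?"}


def words(text: str) -> Set[str]:
    return set(re.findall(r"\w+", text.lower()))


def fake_rephrase_llm(history: List[str], query: str) -> List[str]:
    """Stub for query_rephrase_llm: returns a list of replies, like a generator."""
    if not history:
        # Mirrors the template rule: with empty history, keep the query unchanged
        return [query]
    return [REWRITES.get(query, query)]


def fake_retriever(query: str, top_k: int = 1) -> List[str]:
    """Stub for the retriever: rank documents by shared words (not BM25)."""
    ranked = sorted(DOCS, key=lambda doc: len(words(query) & words(doc)), reverse=True)
    return ranked[:top_k]


def fake_answer_llm(question: str, documents: List[str]) -> str:
    """Stub for the chat LLM: echoes the top supporting document."""
    return f"Based on the documents: {documents[0]}"


def run_turn(question: str, history: List[str]) -> Tuple[str, List[str]]:
    replies = fake_rephrase_llm(history, question)  # query_rephrase_llm.replies
    search_query = replies[0]                       # list_to_str_adapter: "{{ replies[0] }}"
    documents = fake_retriever(search_query)        # retriever.documents
    answer = fake_answer_llm(question, documents)   # llm.replies[0]
    history.extend([question, answer])              # memory_joiner -> memory_writer
    return search_query, documents


history: List[str] = []
print(run_turn("What does Rhodes Statue look like?", history))
print(run_turn("Who built it?", history))
```

Note that the original question, not the rewritten one, is what the answering LLM and the memory receive; only the retriever sees the rephrased query, just as prompt_builder and memory_joiner get question directly in the run data.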
Let’s have a conversation
Now, run the pipeline with the relevant inputs. Instead of sending the query directly to the retriever, this time, pass it to the query_rephrase_prompt_builder to rephrase it.
Here are some example queries and follow ups you can try:
- What does Rhodes Statue look like?
  - Who built it?
  - Did he destroy it?
- Where is Gardens of Babylon?
  - When was it built?
while True:
    messages = [system_message, user_message]
    question = input("Enter your question or Q to exit.\n🧑 ")
    if question == "Q":
        break
    # The raw question now goes to the rephrasing prompt instead of the retriever
    res = conversational_rag.run(data={"query_rephrase_prompt_builder": {"query": question},
                                       "prompt_builder": {"template": messages, "query": question},
                                       "memory_joiner": {"value": [ChatMessage.from_user(question)]}},
                                 include_outputs_from=["llm", "query_rephrase_llm"])
    # Show the rewritten query that was actually used for retrieval
    search_query = res['query_rephrase_llm']['replies'][0]
    print(f" 🔎 Search Query: {search_query}")
    assistant_resp = res['llm']['replies'][0]
    print(f"🤖 {assistant_resp.content}")
Enter your question or Q to exit.
🧑 What does Rhodes Statue look like?
 🔎 Search Query: What does the Rhodes Statue look like?
🤖 While scholars do not know exactly what the Rhodes Statue looked like, they have a good idea of what the head and face looked like. It would have had curly hair with evenly spaced spikes of bronze or silver flame radiating, similar to the images found on contemporary Rhodian coins.
Enter your question or Q to exit.
🧑 Who built it?
 🔎 Search Query: Who built the Rhodes Statue?
🤖 The Colossus of Rhodes was built by Chares of Lindos, a native of Rhodes, directed to do so by the Rhodians to celebrate their victory in defending the city against an attack by Demetrius Poliorcetes in 280 BC.
Enter your question or Q to exit.
🧑 Q
✅ Notice that this time, with the help of query rephrasing, the conversational RAG pipeline handles the follow-up query and retrieves the relevant documents.