What is RAG (Retrieval-Augmented Generation)

12 February 2025
AI


Retrieval-Augmented Generation (RAG) is a technique for optimizing the output of a large language model (LLM) by giving it access to an additional knowledge base, without retraining the model. This makes it more cost-effective than training an LLM on a specific dataset.

How Does RAG Work?

An LLM on its own takes the user's input and returns a response based only on the information it was trained on. RAG introduces an information retrieval component that uses the user input to pull in additional, relevant information. The input and this additional information are then passed to the LLM together, which improves the quality of its responses.

1. Create External Data

First, the external data has to be collected. It can come from various sources, such as files, database records, or other APIs. This information is then converted into numerical vectors (a process called embedding) by an embedding language model, and the vectors are generally stored in a vector database.
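As a minimal sketch of this step, the snippet below turns a few made-up records into vectors and keeps them in an in-memory dictionary that stands in for a vector database. The `toy_embed` function is a hypothetical placeholder (a hashed bag-of-words), not an embedding language model; a real system would call an actual embedding model instead.

```python
import math
import re
import zlib

DIM = 64  # Length of the toy embedding vectors (arbitrary for this sketch)


def toy_embed(text: str) -> list[float]:
    """Hypothetical stand-in for an embedding language model.

    Hashes each lowercase word into one of DIM buckets, counts occurrences,
    and L2-normalises the result so vectors can be compared by cosine similarity.
    """
    vec = [0.0] * DIM
    for word in re.findall(r"[a-z0-9]+", text.lower()):
        vec[zlib.crc32(word.encode("utf-8")) % DIM] += 1.0
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm else vec


# External data from different sources (made-up records for illustration)
records = {
    "plan-basic": "Basic plan: 5 GB data allowance per month.",
    "plan-plus": "Plus plan: 20 GB data allowance per month.",
    "faq-roaming": "Roaming is included in all EU countries.",
}

# Minimal in-memory "vector database": document id -> (text, embedding)
vector_store = {doc_id: (text, toy_embed(text)) for doc_id, text in records.items()}
print(len(vector_store), len(vector_store["plan-basic"][1]))  # 3 64
```

Storing the original text next to each vector matters: the vector is only used to find a document, while the text is what later gets passed to the LLM.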

2. Retrieve Relevant Information

When the user interacts with the model, the user query is converted into a vector representation and matched against the vectors in the database to find the most relevant records. For example, if a customer of a telecommunications company asks a chatbot “What is the data allowance of my current plan and what other plan options do I have?”, the system might retrieve the documents about the latest plans as well as the customer’s current plan details.
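The matching step can be sketched as a top-k search by cosine similarity. The three-dimensional vectors and document names below are made up for illustration (real embeddings have hundreds of dimensions), and in practice the vector database performs this search, often approximately, instead of scoring every document as this sketch does.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query_vec: list[float], store: dict[str, list[float]], k: int = 2) -> list[str]:
    """Return the ids of the k stored vectors most similar to query_vec."""
    ranked = sorted(store, key=lambda doc_id: cosine(query_vec, store[doc_id]), reverse=True)
    return ranked[:k]


# Made-up 3-dimensional embeddings of three documents
store = {
    "current-plan": [0.9, 0.1, 0.0],
    "other-plans": [0.8, 0.3, 0.1],
    "roaming-faq": [0.0, 0.2, 0.9],
}
query_vec = [0.85, 0.2, 0.05]  # Pretend embedding of the customer's question

print(retrieve(query_vec, store, k=2))  # ['current-plan', 'other-plans']
```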

3. Augment the LLM Prompt

The retrieved information is then added to the prompt that is fed to the LLM, giving it additional context. In our example, the LLM can then generate a tailored response that explains the customer’s data allowance and how it differs from the other available plans.
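A minimal sketch of the augmentation step, assuming a plain-text prompt template in which the retrieved documents are listed as context above the user's question. The template wording, the `build_prompt` helper, and the plan details are hypothetical; how the finished prompt is sent to an LLM depends on the model or API in use.

```python
def build_prompt(question: str, retrieved_docs: list[str]) -> str:
    """Combine retrieved documents and the user's question into one prompt."""
    context = "\n".join(f"- {doc}" for doc in retrieved_docs)
    return (
        "Answer the customer's question using only the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n"
        "Answer:"
    )


# Hypothetical retrieval results for the telecommunications example
docs = [
    "Customer's current plan: Basic, 5 GB data allowance per month.",
    "Available plans: Plus (20 GB per month), Unlimited (no data cap).",
]
question = "What is the data allowance of my current plan and what other plan options do I have?"

prompt = build_prompt(question, docs)
print(prompt)
```

Instructing the model to rely only on the supplied context is a common way to keep answers grounded in the retrieved data, although it does not guarantee it.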

4. Update External Data

The external data, and with it the stored embeddings, needs to be updated asynchronously (for example, through periodic batch updates) so that the RAG system keeps retrieving current information.
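One way to keep the store current is to re-embed only documents whose content has changed, detected via a content hash, and to drop documents that no longer exist in the source. The sketch below assumes a simple dict-based store; `embed` is a hypothetical placeholder for a real embedding model call, and `sync_store` would typically run in the background, for example as a periodic batch job.

```python
import hashlib


def embed(text: str) -> list[float]:
    # Hypothetical placeholder for a real (and usually expensive) embedding model call
    return [float(len(text)), float(text.count(" "))]


def sync_store(store: dict, source: dict) -> dict:
    """Bring the vector store in line with the current source data.

    store maps doc_id -> (content_hash, embedding); source maps doc_id -> text.
    Only new or changed documents are re-embedded; deleted ones are dropped.
    Returns counts of what changed.
    """
    stats = {"added": 0, "updated": 0, "removed": 0}
    for doc_id in list(store):
        if doc_id not in source:
            del store[doc_id]
            stats["removed"] += 1
    for doc_id, text in source.items():
        digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
        if doc_id not in store:
            stats["added"] += 1
        elif store[doc_id][0] != digest:
            stats["updated"] += 1
        else:
            continue  # Unchanged: skip the embedding call
        store[doc_id] = (digest, embed(text))
    return stats


store = {}
print(sync_store(store, {"a": "Basic plan: 5 GB", "b": "Plus plan: 20 GB"}))
# {'added': 2, 'updated': 0, 'removed': 0}
print(sync_store(store, {"a": "Basic plan: 10 GB", "c": "New plan"}))
# {'added': 1, 'updated': 1, 'removed': 1}
```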

Challenges of Implementing a RAG Workflow

  • RAG workflows are relatively new, first proposed in 2020, and developers are still figuring out best practices.
  • Implementing RAG is more expensive than using an LLM on its own (though probably less expensive than retraining an LLM).
  • Determining the best structure for the external data store can be challenging.
  • It is crucial to keep the external data up to date for the system to be effective.

Simple Implementation Example

The following code shows a simple implementation of Retrieval-Augmented Generation, based on a dataset of movie plots retrieved from Hugging Face.

Packages

  • transformers: Hugging Face Transformers package
    • Provides pretrained transformer models, here the question answering model, from the Hugging Face Hub
  • sentence-transformers: Sentence Transformers package
    • Provides an easy way to compute dense vector representations (embeddings) for sentences, paragraphs, and images, based on transformer networks such as BERT
  • annoy: Annoy package
    • Approximate nearest neighbor search between vectors (also see [[Semantic Search]])
  • torch: PyTorch
    • Used by transformers to run the models and to hold the model inputs as tensors

Workflow

1. Transform Contextual Data and Store in Vector Database

First, the contextual data needs to be transformed into a vector representation and then stored in a vector database. The transformation (also called embedding) is done with the sentence-transformers package, and the resulting vectors are stored in an Annoy index, which serves as a simple vector database in this case.

2. Embed the Question and Run a Similarity Search

Next, the question is embedded with the same model, and a similarity search identifies the documents in the contextual data that are most relevant to it.

3. Query the LLM

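Finally, the model is queried with the question and the retrieved context. In this example, an extractive question answering model based on DistilBERT plays the role of the LLM: instead of generating free text, it selects the span of the context that most likely answers the question.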

Implementation

```python
# Import all required packages
import csv
from itertools import islice

import annoy
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Load models
embedding_model = SentenceTransformer("all-mpnet-base-v2")  # Embedding model used to encode text data
# DistilBERT checkpoint fine-tuned for extractive QA on SQuAD. The plain
# "distilbert-base-uncased" checkpoint has an untrained QA head and returns random spans.
qa_model_name = "distilbert-base-uncased-distilled-squad"
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)  # Question answering model
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)  # Matching tokenizer for the QA model

# Prepare text data
PLOT_COLUMN_INDEX = 7  # Index of the "Plot" column in the CSV file
MAX_DOCUMENTS = 200  # Only import the first 200 entries for speed
document_data = []
with open("wiki_movie_plots_deduped.csv", newline="", encoding="utf-8") as f:
    reader = csv.reader(f)
    next(reader)  # Skip the header row
    for row in islice(reader, MAX_DOCUMENTS):
        if len(row) > PLOT_COLUMN_INDEX:
            document_data.append(row[PLOT_COLUMN_INDEX])

# Encode the document data
document_embeddings = embedding_model.encode(document_data)

# Create Annoy index
dimension = document_embeddings.shape[1]  # Length of the embedding vector
index = annoy.AnnoyIndex(dimension, "angular")  # New read-write index using angular (cosine-based) distance
for i, embedding in enumerate(document_embeddings):
    index.add_item(i, embedding)  # Add document embeddings to the index

# Build a forest of 10 trees (more trees give higher precision when querying)
index.build(10)

# Set and encode the question
question = "What does Alice find in the rabbit hole?"
question_embedding = embedding_model.encode(question)

# Similarity search
k = 5  # Number of similar documents to retrieve
ids = index.get_nns_by_vector(question_embedding, k)  # Returns the ids of the k closest items

# Create relevant context from the most similar documents
context = " ".join(document_data[doc_id] for doc_id in ids)

# Tokenize the question and context so they can be used as input to the QA model
## truncation="only_second" shortens only the context (never the question) if the input is too long
## max_length=512 is the maximum input length DistilBERT supports
## return_tensors="pt" returns PyTorch tensors
inputs = qa_tokenizer(question, context, truncation="only_second", max_length=512, return_tensors="pt")

# Get the answer from the QA model (no gradients are needed for inference)
with torch.no_grad():
    outputs = qa_model(**inputs)

# Get the answer span
## A question answering model predicts the start and end of the answer span in the context
## Find the tokens that the model thinks are most likely to be the start and end of the answer
answer_start_index = int(outputs.start_logits.argmax())
answer_end_index = int(outputs.end_logits.argmax())

# Decode the answer span (index 0 is the [CLS] token, used to signal "no answer")
if answer_end_index < answer_start_index or (answer_start_index == 0 and answer_end_index == 0):
    predict_answer = "No answer found in context."
else:
    # Get the tokens between the start and end index
    predict_answer_tokens = inputs.input_ids[0][answer_start_index : answer_end_index + 1]
    # Decode the tokens
    predict_answer = qa_tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)

print(f"Question: {question}")
# > Question: What does Alice find in the rabbit hole?
print(f"Answer: {predict_answer}")
# > Answer: A tiny door.  (example output; the exact wording may vary)
```