Biomedical RAG on your local machine - ft. LLaMA + Qdrant

Implementing a RAG pipeline for scientific conversational QA on your local machine

TL;DR

Retrieval Augmented Generation (RAG) [1,2,3] is a popular LLM-application pattern that incorporates external domain knowledge from retrieval databases to condition the generative LLM's output. This post implements a RAG pipeline with Llama 2 [4] on one's local machine for long-form QA on biomedical text. Three reasons for this post:

  1. This is Part 1 of a two-part blog series. It provides the necessary RAG pipeline apparatus, as well as baseline performance for the task of long-form medical QA, which will later be needed in Part 2, where I experiment with fine-tuning the LLM.

  2. For democratised access, I wanted to implement RAG entirely on my local machine with smaller LLMs. No expensive Google Colab compute or LLMs behind remote APIs.

  3. Instead of viewing all these concepts in isolation, I wanted to integrate the many exciting ideas (LLM quantisation, Qdrant, FastEmbed, llama.cpp, etc.) within a single project.

Why you may want to read this: You are interested in conversational knowledge applications such as long-form QA, chatbots, or search in specialised domains (e.g. medicine or finance); you want to compare implementation notes for RAG, vector search, etc.; or you want to build generative LLM applications locally with small models and are interested in enabling tools such as llama.cpp and LLM quantisation.

Deep Dive

Motivation: When building a conversational interface for a knowledge-retrieval application (be it search, question answering, or a chat agent), RAG has become the standard pattern for integrating external domain knowledge into generative LLM workflows. I like to think of this task as an open-book exam: the retrieval mechanism selects the book chapters, and the LLM is prompted to perform reading comprehension on this retrieved context. This goes a long way towards overcoming the limitations of foundation LLMs (hallucinations, static and outdated training data, lack of domain knowledge, etc.) by augmenting the LLM's intrinsic parametric memory with external vector databases and retrieval mechanisms.

Note that RAG workflows commonly adapt off-the-shelf LLMs with sophisticated prompt-engineering techniques to coax the LLM into performing the desired generative task, rather than fine-tuning the model. Relying solely on prompts at inference time for in-context learning [5] to perform complex reasoning tasks is brittle: it has implications for the correctness of the LLM's answers, the consistency of the output format, inference costs due to lengthy multi-step prompts, and so on. (A more advanced read on this topic is the Stanford AI Blog's 'How does in-context learning work?': https://ai.stanford.edu/blog/understanding-incontext/.) The challenges are amplified when operating in domains that have very stringent safety concerns and near-zero tolerance for misinformation. Fine-tuning is known to produce more reliable and consistent LLM output than relying solely on in-context learning through clever prompts. That experiment is left to Part 2 of this series.

Task Details

The premise of this tutorial is applying generative LLMs within a RAG framework to (1) a conversational QA task, (2) focussed on niche domains that are expected to require more specialised knowledge than the general-purpose training corpora used for foundation LLMs.

We select the task of long-form question answering based on biomedical research content from PubMed journal articles.

Figure 1: Sample from PubMedQA dataset converted for instruction-tuning
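
To complement Figure 1, here is a minimal sketch for peeking at one instruction-formatted sample. It assumes the 'text' and 'id' fields that the indexing code later in this post relies on; check data.column_names if your copy of the dataset differs.

    from datasets import load_dataset

    # Load a small slice of the instruction-formatted PubMedQA dataset and inspect one sample
    data = load_dataset("FedML/PubMedQA_instruction", split="train[0:2000]")
    print(data.column_names)       # expected to include 'id' and 'text' (see the indexing step below)
    print(data[0]['text'][:500])   # first few hundred characters of one instruction-formatted sample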

Implementation

Here is an overview of the steps in the implementation workflow, in sequence: first, as prerequisites, prepare a quantised LLaMA model for local inference with llama.cpp and index the PubMedQA documents in a Qdrant vector database; then, at inference time for each user query, retrieve the relevant documents, construct a RAG prompt from the question and the retrieved context, and generate an answer with the local LLM.

Let's elaborate on each of these steps.

Implementation - Prerequisites

  1. Prepare the LLaMA model for local inference using the llama.cpp project. This project enables running LLaMA 2 model variants in quantised form on a regular, general-purpose local machine. llama.cpp can be deployed as a CLI tool or as a server (see example here).
    • Download the Llama model (weights, tokenizer, etc.) using the Meta request form and instructions provided here
    • Install the llama.cpp project: follow the comprehensive setup instructions provided here
    • Quantise the desired LLaMA model (following the instructions from llama.cpp, as seen here). I used the 7B and 7B-chat models. In a nutshell:
      • After obtaining the original LLaMA weights, place them in ./models
      • Install the Python dependencies from the requirements.txt file
      • Convert the model to GGUF FP16 format: python3 convert.py models/7B-chat/
      • Perform model quantisation. This is a technique to reduce the memory footprint of the LLM weights by storing the entries of the weight matrices in a lower-precision data type. Read more about it here. I perform 8-bit model quantisation, and the resulting model is used in the subsequent RAG pipeline; a rough estimate of the memory savings follows the command below.
          ./quantize ./models/7B-chat/ggml-model-f16.gguf ./models/7B-chat/ggml-model-q8_0.gguf q8_0
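
        As a rough back-of-the-envelope estimate of what quantisation buys (weights only, ignoring runtime overhead such as the KV cache; the byte counts below are approximations):

          # Approximate weight memory for a 7B-parameter model at different precisions
          n_params = 7e9
          bytes_per_param = {"f32": 4, "f16": 2, "q8_0": 1}  # q8_0 is ~1 byte/weight plus a small per-block scale
          for dtype, nbytes in bytes_per_param.items():
              print(f"{dtype}: ~{n_params * nbytes / 1e9:.0f} GB")
          # f32: ~28 GB, f16: ~14 GB, q8_0: ~7 GB -- the 8-bit model fits comfortably in laptop RAM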
        
  2. Vectorise the text documents into a vector database, and create a searchable index

    a. Obtain PubMedQA_instruction from the Hugging Face datasets hub. This is the PubMedQA dataset pre-formatted for instruction-tuning.

        from datasets import load_dataset
        dataset_name = "FedML/PubMedQA_instruction"
        data = load_dataset(dataset_name, split="train[0:2000]")
    

    b. Vectorise data using qdrant-fastembed and create a searchable vector index in Qdrant

    • Install Qdrant’s FastEmbed - ‘a lightweight, fast, Python library built for embedding generation’, following instructions here.
    • Install qdrant-client, a python client to interface with the Qdrant vector database and search engine following instructions here.
    • Qdrant allows the developer to customise the search index config: you can annotate a single document with multiple vectors, or with an additional payload (e.g. dates, titles, full text bodies) that can be used for document filters or keyword search to augment vector search. A sketch of payload-filtered search follows the indexing snippet below.
     from qdrant_client import models, QdrantClient
    
     COLLECTION_NAME = "pubmedqa"
     EMBEDDING_MODEL_NAME = "BAAI/bge-base-en"
     qdrant_client = QdrantClient(":memory:")
     qdrant_client.set_model(EMBEDDING_MODEL_NAME)
        
     qdrant_client.create_collection(
         collection_name=COLLECTION_NAME,
         vectors_config=qdrant_client.get_fastembed_vector_params()
     )
        
     qdrant_client.add(
         collection_name=COLLECTION_NAME,
         documents=[doc['text'] for doc in data],
         ids=[doc['id'] for doc in data]
     )
        
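    As noted above, documents can also carry a payload that augments pure vector search. Here is a minimal sketch of that idea; the 'year' metadata field and the example question are hypothetical, and the PubMedQA_instruction split used here does not ship such metadata out of the box.

     # Index documents together with an illustrative payload (hypothetical 'year' field)
     qdrant_client.add(
         collection_name=COLLECTION_NAME,
         documents=[doc['text'] for doc in data],
         metadata=[{"year": 2019} for _ in data],  # replace with real per-document metadata
         ids=[doc['id'] for doc in data]
     )

     # Combine semantic search with a payload filter on that field
     search_hits = qdrant_client.query(
         collection_name=COLLECTION_NAME,
         query_text="Do statins reduce postoperative atrial fibrillation?",  # example question
         query_filter=models.Filter(
             must=[models.FieldCondition(key="year", match=models.MatchValue(value=2019))]
         ),
         limit=3
     )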
    
    

Implementation - RAG inference workflow on each user query

  1. Retrieve relevant documents by querying the vector DB using Qdrant’s semantic search engine API

    The Qdrant search engine client makes it very easy to run a query against the database and retrieve the top-k documents. Let's write a simple function that the RAG pipeline can use to retrieve relevant context from the Qdrant vector database at inference time.

     # Retriever: fetch the top-k most relevant document texts for a query from the Qdrant collection
     def retrieve_relevant_doc_context(query, qdrant_client, collection_name, top_k=3, verbose=False):
         rel_docs = []
         search_hits = qdrant_client.query(
             collection_name=collection_name,
             query_text=query,
             limit=top_k)
         for hit in search_hits:
             hit_dict = {'text': hit.metadata['document'],
                         'score': hit.score}
             rel_docs.append(hit_dict['text'])
             if verbose:
                 print(hit_dict)
         return rel_docs

     # illustrative user question (placeholder -- substitute the actual query from your application)
     query = "Do preoperative statins reduce atrial fibrillation after coronary artery bypass grafting?"
     rel_docs = retrieve_relevant_doc_context(query=query,
                                              qdrant_client=qdrant_client,
                                              collection_name=COLLECTION_NAME,
                                              top_k=3)
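
     If you want to guard against weakly related context, the same query call also exposes a similarity score per hit. A minimal variant (the 0.5 threshold is an arbitrary placeholder, not a tuned value) could filter on it:

     # Variant: drop weakly related hits using the similarity score returned by Qdrant
     SCORE_THRESHOLD = 0.5  # placeholder -- tune on your own data
     search_hits = qdrant_client.query(
         collection_name=COLLECTION_NAME,
         query_text=query,
         limit=3)
     rel_docs = [hit.metadata['document'] for hit in search_hits if hit.score >= SCORE_THRESHOLD]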
    
    
  2. Construct an LLM prompt from the user question and the relevant documents, instructing the LLM to answer from within the provided context

     from langchain.prompts import PromptTemplate
        
     RAG_PROMPT_string = ("""\
     Human: Here is a question from a medical professional: 
    <question> 
    {user_query}
    </question>
        
     Here are some search results from a medical encyclopedia that you must reference to answer the question: 
     {extracts}
        
     Once again, here is the question:
     <question>
     {user_query}
     </question>
        
     Your objective is to write a high quality, concise answer
     for the medical professional within <answer> </answer> tags. If the answer cannot be found in the search results, write ANSWER NOT FOUND.
        
     Assistant: <answer>\n\n """ 
     )
        
     rag_prompt_template = PromptTemplate.from_template(RAG_PROMPT_string)
    
     from typing import List
     def prep_rag_prompt(query: str,
                         rel_search_extracts: List,
                         prompt_template,
                        ) -> str:
         prompt = prompt_template.format(extracts='\n\n'.join(rel_search_extracts),
                                         user_query=query,
                                         )
         return prompt
     rag_prompt = prep_rag_prompt(query=query, rel_search_extracts = rel_docs, prompt_template=rag_prompt_template)
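
     One practical caveat: the quantised model is loaded below with a 2048-token context window (n_ctx=2048), and three full PubMed abstracts plus the template can exceed it. A crude character-budget guard like the following sketch (the ~4-characters-per-token heuristic and the budget are assumptions, not part of the original pipeline) keeps the prompt within bounds:

     # Crude guard: keep the concatenated extracts under a rough character budget
     # (~4 characters per token is a common rule of thumb; tune for your model and context size)
     def truncate_extracts(extracts, max_chars=4 * 1500):
         kept, used = [], 0
         for ext in extracts:
             if used + len(ext) > max_chars:
                 break
             kept.append(ext)
             used += len(ext)
         return kept

     rag_prompt = prep_rag_prompt(query=query,
                                  rel_search_extracts=truncate_extracts(rel_docs),
                                  prompt_template=rag_prompt_template)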
    
  3. Prompting the LLaMA model within the RAG pipeline. Calling the local quantised LLaMA model served by llama.cpp is made very easy by the llama-cpp-python project, which provides Python bindings for llama.cpp.

     from langchain.llms import LlamaCpp
     from langchain.callbacks.manager import CallbackManager
     from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
        
     n_gpu_layers = 1  # Metal set to 1 is enough.
     n_batch = 512  # Should be between 1 and n_ctx, consider the amount of RAM of your Apple Silicon Chip.
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
     LLAMA_CPP_Q8_PATH = "/Users/mitrap/PycharmProjects/llama.cpp/models/7B-chat/ggml-model-q8_0.gguf"
        
     # Make sure the model path is correct for your system!
     llm = LlamaCpp(
         model_path=LLAMA_CPP_Q8_PATH,
         n_gpu_layers=n_gpu_layers,
         n_batch=n_batch,
         n_ctx=2048,
         f16_kv=True,  # MUST be set to True, otherwise you will run into problems after a couple of calls
         callback_manager=callback_manager,
         verbose=True,
     )
    
     rag_prompt = prep_rag_prompt(query=query,
                                  rel_search_extracts = rel_docs,
                                  prompt_template=rag_prompt_template)
        
     llm(rag_prompt)
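
     Since the prompt asks the model to answer inside <answer> tags (and the template already opens the tag), a small post-processing step is useful before showing the output to a user. This is a sketch under that assumption; the helper name is mine, not part of the pipeline above.

     import re

     def extract_answer(raw_output: str) -> str:
         # The prompt already opens <answer>, so the completion usually ends with the closing tag
         match = re.search(r"(.*?)</answer>", raw_output, flags=re.DOTALL)
         if match:
             answer = match.group(1).strip()
             return answer if answer else "ANSWER NOT FOUND"
         return raw_output.strip()  # fall back to the raw completion if the tag format was not followed

     print(extract_answer(llm(rag_prompt)))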
    

The code illustrating the above workflow is made available on GitHub.

References:

  1. IBM - What is Retrieval Augmented Generation? [Link]
  2. Prompt Engineering Guide - Retrieval Augmented Generation (RAG) [Link]
  3. The Complete Overview to Retrieval Augmented Generation (RAG) [Link]
  4. Llama 2: Open Foundation and Fine-Tuned Chat Models [Link]
  5. How does in-context learning work? A framework for understanding the differences from traditional supervised learning [Link]
  6. FastEmbed: Fast and Lightweight Embedding Generation for Text [Link]