
RAG LLM

  • Setting up retrieval and using LangChain APIs

Modify LLM Chain

  • At the moment the LLM chain is a retriever; if you want to add functionality, modify the LLMChainInitializer class.
  • To change how the vector store is used, modify the QASetup class.
  • To change how Ollama is called or cached, or to add generation steps, modify the LLMChainCreator class. A wiring sketch of how these classes fit together follows this list.
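
A minimal wiring sketch of how these classes fit together is below. The import path, the Chroma persistence path, the data_type value, and every config value shown are illustrative assumptions; only the key names match what the source below actually reads.

import chromadb

from backend.modules.rag_llm import LLMChainCreator, QASetup  # assumed import path

config = {
    "llm_model": "llama3",  # placeholder model name
    "llm_prompt_template": "Context: {docs}\nQuestion: {query}\nAnswer:",  # placeholder template
    "data_dir": "./data",
    "search_type": "similarity",
    "num_return_documents": 10,
}

# Chroma client handed to QASetup (the ClientAPI parameter); the path is an assumption.
client = chromadb.PersistentClient(path="./data/chroma")

# QASetup builds the metadata vector store and returns it wrapped as a retriever
# (via LLMChainInitializer), plus the collected metadata.
retriever, all_metadata = QASetup(
    config=config, data_type="dataset", client=client  # data_type value is assumed
).setup_vector_db_and_qa()

# LLMChainCreator owns everything Ollama-related: the query cache and the
# prompt | llm | parser chain.
creator = LLMChainCreator(config=config, local=True)
creator.enable_cache()
chain = creator.get_llm_chain()

# Retrieval and generation stay separate; the chain's input keys depend on the
# prompt template, so the dict below matches the placeholder template above.
query = "datasets about weather forecasting"
docs = retriever.invoke(query)
answer = chain.invoke({"docs": docs, "query": query})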

LLMChainCreator

Description: Gets Ollama, sends query, enables query caching

Source code in backend/modules/rag_llm.py
class LLMChainCreator:
    """
    Description: Gets Ollama, sends query, enables query caching
    """

    def __init__(self, config: dict, local: bool = False):
        self.config = config
        self.local = local

    def get_llm_chain(self) -> LLMChain | bool:
        """
        Description: Send a query to Ollama using the paths.
        """
        base_url = "http://127.0.0.1:11434" if self.local else "http://ollama:11434"
        llm = Ollama(model=self.config["llm_model"], base_url=base_url)
        map_template = self.config["llm_prompt_template"]
        map_prompt = PromptTemplate.from_template(map_template)
        return map_prompt | llm | StrOutputParser()

    def enable_cache(self):
        """
        Description: Enable a cache for queries to prevent running the same query again for no reason.
        """
        set_llm_cache(
            SQLiteCache(
                database_path=os.path.join(self.config["data_dir"], ".langchain.db")
            )
        )
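
The list under "Modify LLM Chain" points here for adding generation behaviour. One possible extension, sketched under the assumption that the prompt template exposes {docs} and {query} variables, is to let a retriever feed the prompt directly through standard LCEL runnables; this is an illustration, not the project's current behaviour.

from langchain_core.runnables import RunnablePassthrough

from backend.modules.rag_llm import LLMChainCreator  # assumed import path


class RetrievalLLMChainCreator(LLMChainCreator):
    def get_retrieval_chain(self, retriever):
        chain = self.get_llm_chain()
        # Map the incoming question onto the template variables before the chain
        # runs; the dict is coerced into a RunnableParallel by the pipe operator.
        return {"docs": retriever, "query": RunnablePassthrough()} | chain

The resulting runnable can then be invoked with a plain question string, e.g. get_retrieval_chain(retriever).invoke("datasets about weather forecasting").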

enable_cache()

Description: Enable a cache for queries to prevent running the same query again for no reason.

Source code in backend/modules/rag_llm.py
def enable_cache(self):
    """
    Description: Enable a cache for queries to prevent running the same query again for no reason.
    """
    set_llm_cache(
        SQLiteCache(
            database_path=os.path.join(self.config["data_dir"], ".langchain.db")
        )
    )
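
For reference, the same cache can be enabled outside the class. The import locations below assume the post-split LangChain packages (langchain plus langchain_community); once set, the cache is process-wide.

import os

from langchain.globals import set_llm_cache
from langchain_community.cache import SQLiteCache

data_dir = "./data"  # stands in for config["data_dir"]
set_llm_cache(SQLiteCache(database_path=os.path.join(data_dir, ".langchain.db")))
# From here on, identical LLM calls in this process are served from the SQLite file;
# calling set_llm_cache(None) turns caching off again.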

get_llm_chain()

Description: Send a query to Ollama using the paths.

Source code in backend/modules/rag_llm.py
def get_llm_chain(self) -> LLMChain | bool:
    """
    Description: Send a query to Ollama using the paths.
    """
    base_url = "http://127.0.0.1:11434" if self.local else "http://ollama:11434"
    llm = Ollama(model=self.config["llm_model"], base_url=base_url)
    map_template = self.config["llm_prompt_template"]
    map_prompt = PromptTemplate.from_template(map_template)
    return map_prompt | llm | StrOutputParser()
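
The excerpt above omits its imports. Under a recent LangChain package layout the composed pieces would most likely come from the locations below; treat the exact module paths, the model name, and the template as assumptions tied to the installed version.

from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

llm = Ollama(model="llama3", base_url="http://127.0.0.1:11434")  # placeholder model
prompt = PromptTemplate.from_template("Rewrite for retrieval: {query}")  # placeholder template
chain = prompt | llm | StrOutputParser()  # a RunnableSequence

# The chain is invoked with the template's variables and returns a plain string.
result = chain.invoke({"query": "weather datasets"})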

LLMChainInitializer

Description: Setup the vectordb (Chroma) as a retriever with parameters

Source code in backend/modules/rag_llm.py
class LLMChainInitializer:
    """
    Description: Setup the vectordb (Chroma) as a retriever with parameters
    """

    @staticmethod
    def initialize_llm_chain(
        vectordb: Chroma, config: dict
    ) -> langchain.chains.retrieval_qa.base.RetrievalQA:
        if config["search_type"] == "similarity_score_threshold":
            return vectordb.as_retriever(
                search_type=config["search_type"],
                search_kwargs={
                    "k": config["num_return_documents"],
                    "score_threshold": 0.5,
                },
            )
        else:
            return vectordb.as_retriever(
                search_type=config["search_type"],
                search_kwargs={"k": config["num_return_documents"]},
            )
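
A short usage sketch of the two branches, assuming an already populated Chroma store (for example the one QASetup builds) and placeholder config values:

from backend.modules.rag_llm import LLMChainInitializer  # assumed import path

threshold_config = {"search_type": "similarity_score_threshold", "num_return_documents": 5}
plain_config = {"search_type": "similarity", "num_return_documents": 5}

# vectordb is an existing Chroma vector store.
retriever = LLMChainInitializer.initialize_llm_chain(vectordb, threshold_config)

# With the threshold branch, results scoring below 0.5 are dropped, so fewer than
# k documents may come back; the plain branch always returns the top k.
docs = retriever.invoke("tabular datasets about housing prices")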

QASetup

Description: Setup the VectorDB, QA and initialize the LLM for each type of data

Source code in backend/modules/rag_llm.py
class QASetup:
    """
    Description: Setup the VectorDB, QA and initialize the LLM for each type of data
    """

    def __init__(
        self, config: dict, data_type: str, client: ClientAPI, subset_ids: list = None
    ):
        self.config = config
        self.data_type = data_type
        self.client = client
        self.subset_ids = subset_ids

    def setup_vector_db_and_qa(self):
        self.config["type_of_data"] = self.data_type

        metadata_processor = OpenMLMetadataProcessor(config=self.config)
        openml_data_object, data_id, all_metadata, handler = (
            metadata_processor.get_all_metadata_from_openml()
        )
        metadata_df, all_metadata = metadata_processor.create_metadata_dataframe(
            handler,
            openml_data_object,
            data_id,
            all_metadata,
            subset_ids=self.subset_ids,
        )

        vector_store_manager = VectorStoreManager(self.client, self.config)
        vectordb = vector_store_manager.create_vector_store(metadata_df)
        qa = LLMChainInitializer.initialize_llm_chain(vectordb, self.config)

        return qa, all_metadata
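
QASetup expects an already constructed Chroma client for its ClientAPI parameter. A minimal sketch of supplying one follows; the persistence path, the data_type value, and the role of subset_ids (presumably a list of OpenML ids to restrict the metadata to) are assumptions, and config stands for a dict like the one in the earlier sketch.

import chromadb

from backend.modules.rag_llm import QASetup  # assumed import path

# A persistent on-disk client; an HttpClient pointing at a Chroma server would
# also satisfy the ClientAPI parameter.
client = chromadb.PersistentClient(path="./data/chroma")

qa_setup = QASetup(config=config, data_type="dataset", client=client, subset_ids=None)
retriever, all_metadata = qa_setup.setup_vector_db_and_qa()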