Incrementally Migrating to LangChain

Nimo Beeren
2024-05-25

Frameworks like LangChain provide a useful base for building LLM applications, but migrating an existing app can be daunting. When my team and I were tasked with bringing a prototype RAG app to production, we felt the need for more structure. We wanted a system that allows:

But we couldn’t just replace all of our code with LangChain’s pre-built integrations. We had custom logic and dependencies, and we wanted full control over every component. We also just didn’t have the time to migrate the entire application at once.

When faced with this problem, we came up with a migration strategy that let us migrate components to LangChain one-by-one, without changing their underlying behavior or rewriting large parts of our code. In order of preference:

  1. If there exists a LangChain integration that we can use as a drop-in replacement for our component, do that
  2. If the component is very simple, refactor it to conform to one of LangChain’s base interfaces
  3. If the component is more complex, wrap it in a new component that conforms to one of LangChain’s base interfaces

The key is that we don’t force ourselves to use pre-built LangChain integrations that aren’t a perfect fit for us. Instead, we ensure that each component conforms to one of the interfaces which these pre-built integrations are built on. This allows our own components to slot in seamlessly with the rest of the LangChain ecosystem.

Let’s take a look at how this worked for our RAG application.

Migrating the Retriever

At its core, a retriever is a function that takes a query and returns a list of relevant documents. But in practice, you might also be rewriting/decomposing the query, using multiple search backends or reranking the results. In our case, there wasn’t a ready-to-use integration that could do all that. Instead, we built a new class which conforms to the BaseRetriever interface. This is actually what all LangChain retrievers do. All you have to do is implement the _get_relevant_documents method, which takes in a query and returns a list of documents:

from typing import List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

# For example, this might be how you retrieved documents before
def legacy_get_docs(query: str):
    return [
        {"title": "All About Bananas", "text": "Bananas are cool!"},
        {"title": "The Berry Book", "text": "Blueberries rock."},
    ]

class MyRetriever(BaseRetriever):
    def _get_relevant_documents(self, query: str) -> List[Document]:
        # Get your documents as you normally would
        docs = legacy_get_docs(query)
        # Map the resulting documents
        return [self.map_document(doc) for doc in docs]

    def map_document(self, old) -> Document:
        # Map your own document object to LangChain's Document
        return Document(
            page_content=old["text"],
            metadata={"title": old["title"]}
        )

We can now call it using the invoke method, which is already implemented by BaseRetriever:

retriever = MyRetriever()
docs = retriever.invoke("What's long and yellow?")
print(docs)  # prints [Document(...)]

This is all it takes to build a retriever. We can now include it in chains using LangChain Expression Language (LCEL), or call it with invoke. Just don’t call the _get_relevant_documents method directly because then you won’t be able to use things like callbacks (which are used by observability integrations).

The example above is contrived, but adding your own parameters is easy enough. Just be aware that LangChain interfaces are Pydantic models. That means all constructor arguments are automatically validated at runtime. This helps to prevent passing arguments with the wrong type, but you must specify the name and type of every field in the class definition:

class MyRetriever(BaseRetriever):
    number_of_docs: int  # add a field here

    def _get_relevant_documents(self, query: str) -> List[Document]:
        docs = legacy_get_docs(
            query,
            # You can now access the field on `self`
            number_of_docs=self.number_of_docs
        )
        return ...

# In the calling code
retriever = MyRetriever(number_of_docs=5)

We’ve added a field number_of_docs with type int, which is passed as a keyword argument to the MyRetriever constructor. Keep in mind that if your field is not a primitive Python type (int, str, dict, etc.) it must also be a Pydantic model. The Pydantic docs show how you can easily adapt your Python classes to Pydantic models, or you can just use Python’s Any type if you want to skip validation.

In case your old retriever is also a class or you want to run some initialization code like creating an API client, you can override the __init__ method:

from typing import Optional
from elasticsearch import Elasticsearch

class MyRetriever(BaseRetriever):
    # Make the field optional and default it to None, otherwise you
    # have to provide it as a constructor argument
    elastic: Optional[Elasticsearch] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # important!
        self.elastic = Elasticsearch(...)

    def _get_relevant_documents(self, query: str) -> List[Document]:
        docs = self.elastic.search(query)
        return ...

Keep in mind that everything you want to get or set on self must be defined as a field on the model, so in this case we have to add the elastic field to our model before assigning something to self.elastic. When overriding __init__(), make sure to call super().__init__() at the very start of your constructor, otherwise you’ll get an error like 'MyRetriever' object has no attribute '__fields_set__'.

Migrating the Prompt Template and LLM

Because our retriever now uses LangChain’s Document structure, we need to adjust the other components too. Our prompt template didn’t have complex logic or dependencies, so we refactored it to implement the BaseChatPromptTemplate interface directly:

from typing import List
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompts import BaseChatPromptTemplate

class MyPromptTemplate(BaseChatPromptTemplate):
    input_variables: List[str] = ["query", "docs"]  # important!

    def format_messages(
        self, query: str, docs: List[Document]
    ) -> List[BaseMessage]:
        return [
            SystemMessage(
                "Using the documents, answer the user's question.\n\n"
                "Documents:\n"
                + "\n\n".join([doc.page_content for doc in docs])
            ),
            HumanMessage(query)
        ]

# In the calling code
MyPromptTemplate().invoke({
    "query": "Who dis?",
    "docs": [Document(page_content="It's Mr. Worldwide")]
})

An advantage of this approach over something like ChatPromptTemplate.from_messages is that you can take arguments with types other than str, such as List[Document]. This way, all of your prompt formatting code can be kept in one place. Just be sure to set the input_variables field (and use a List[str] type annotation), otherwise validation fails.

With our LLM, we weren’t doing anything special so we just used the pre-built ChatOpenAI integration as a drop-in replacement for the OpenAI client we were using before:

# Using OpenAI library
from openai import OpenAI

client = OpenAI(api_key=...)
client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": "What is the best fruit?",
        }
    ],
    model="gpt-4o",
)

# Using LangChain
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage

chat = ChatOpenAI(api_key=..., model="gpt-4o")
chat.invoke([HumanMessage(content="What is the best fruit?")])

The code is very similar, but because all LangChain chat models conform to the same BaseChatModel interface, we can easily swap out OpenAI for any other LLM.

Putting It All Together

Once we have the three basic elements in place, they can be neatly wrapped into a chain using LCEL:

from langchain_core.runnables import RunnablePassthrough

retriever = MyRetriever()
prompt = MyPromptTemplate()
chat = ChatOpenAI(api_key=..., model="gpt-4o")

chain = (
    {"query": RunnablePassthrough(), "docs": retriever}
    | prompt
    | chat
)

output = chain.invoke("What is the best fruit?")
print(output.content)  # Blueberries!

Of course, you can still use all of your components separately by calling invoke() on them.

Trade-offs

A few things to consider before taking this approach:

Conclusion

While rewriting all of your code at once is hard, it’s much easier to incrementally migrate your components one-by-one. These are the most important interfaces you’ll want to keep in mind:

class BaseRetriever:
    def _get_relevant_documents(self, query: str) -> List[Document]: ...
class BaseChatPromptTemplate:
    def format_messages(self, **kwargs) -> List[BaseMessage]: ...
class Document:
    page_content: str
    metadata: dict
class BaseMessage:
    type: str  # "system", "human" or "ai"
    content: str

When your components use these interfaces, it becomes trivial to swap them out.

In addition, when you use LCEL, observability is basically free.

Finally, your components have a single, clear purpose which makes your code easier to understand and change.