Skip to content

Commit

Permalink
add configuration (#10)
Browse files Browse the repository at this point in the history
* add configuration

* Update deps

---------

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
hwchase17 and jacoblee93 authored Oct 17, 2023
1 parent 0eb24c5 commit 2165215
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 34 deletions.
58 changes: 40 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Main entrypoint for the app."""
import os
import asyncio
import os
from operator import itemgetter
from typing import Dict, List, Optional, Sequence

import langsmith
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import ChatAnthropic, ChatOpenAI
from langchain.document_loaders import AsyncHtmlLoader
from langchain.document_transformers import Html2TextTransformer
from langchain.embeddings import OpenAIEmbeddings
Expand All @@ -18,14 +18,16 @@
TavilySearchAPIRetriever)
from langchain.retrievers.document_compressors import (
DocumentCompressorPipeline, EmbeddingsFilter)
from langchain.retrievers.you import YouRetriever
from langchain.schema import Document
from langchain.schema.document import Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (Runnable, RunnableBranch,
RunnableLambda, RunnableMap)
from langchain.schema.runnable import (ConfigurableField, Runnable,
RunnableBranch, RunnableLambda,
RunnableMap)
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Backup
from langchain.utilities import GoogleSearchAPIWrapper
Expand Down Expand Up @@ -94,10 +96,9 @@
class ChatRequest(TypedDict):
question: str
chat_history: Optional[List[Dict[str, str]]]
# conversation_id: Optional[str]


class BackupRetriever(BaseRetriever):
class GoogleCustomSearchRetriever(BaseRetriever):
search: GoogleSearchAPIWrapper = GoogleSearchAPIWrapper()
num_search_results = 6

Expand Down Expand Up @@ -151,22 +152,34 @@ def _get_relevant_documents(
return docs


def get_base_retriever():
if (os.environ.get("USE_BACKUP", "false") == "true"):
return BackupRetriever()
return TavilySearchAPIRetriever(k=6, include_raw_content=True, include_images=True)


def _get_retriever():
def get_retriever():
embeddings = OpenAIEmbeddings()
splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=20)
relevance_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.8)
pipeline_compressor = DocumentCompressorPipeline(
transformers=[splitter, relevance_filter]
)
base_retriever = get_base_retriever()
return ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=base_retriever
base_tavily_retriever = TavilySearchAPIRetriever(
k=6, include_raw_content=True, include_images=True
)
tavily_retriever = ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=base_tavily_retriever
)
base_google_retriever = GoogleCustomSearchRetriever()
google_retriever = ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=base_google_retriever
)
base_you_retriever = YouRetriever()
you_retriever = ContextualCompressionRetriever(
base_compressor=pipeline_compressor, base_retriever=base_you_retriever
)
return tavily_retriever.configurable_alternatives(
# This gives this field an id
# When configuring the end runnable, we can then use this id to configure this field
ConfigurableField(id="retriever"),
default_key="tavily",
google=google_retriever,
you=you_retriever,
).with_config(run_name="FinalSourceRetriever")


Expand Down Expand Up @@ -263,13 +276,22 @@ def create_chain(
# model="gpt-4",
streaming=True,
temperature=0,
).configurable_alternatives(
# This gives this field an id
# When configuring the end runnable, we can then use this id to configure this field
ConfigurableField(id="llm"),
default_key="openai",
anthropic=ChatAnthropic(model="claude-2", max_tokens=16384),
)

retriever = _get_retriever()

retriever = get_retriever()

chain = create_chain(llm, retriever)

add_routes(app, chain, path="/chat", input_type=ChatRequest)
add_routes(
app, chain, path="/chat", input_type=ChatRequest, config_keys=["configurable"]
)


@app.post("/feedback")
Expand Down
Loading

0 comments on commit 2165215

Please sign in to comment.