From b016a84b97a7f33ea84202000558ba20795990eb Mon Sep 17 00:00:00 2001 From: "Tuan Anh Nguyen Dang (Tadashi_Cin)" Date: Wed, 4 Dec 2024 11:09:57 +0700 Subject: [PATCH] feat: add Google embedding support & update setup (#550) bump:patch --- flowsettings.py | 18 +++-- libs/kotaemon/kotaemon/embeddings/__init__.py | 2 + .../kotaemon/embeddings/langchain_based.py | 35 ++++++++++ libs/ktem/ktem/embeddings/manager.py | 2 + libs/ktem/ktem/pages/setup.py | 67 ++++++++++++++++--- 5 files changed, 109 insertions(+), 15 deletions(-) diff --git a/flowsettings.py b/flowsettings.py index e8b78f3af..e248fac94 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -26,6 +26,7 @@ KH_ENABLE_FIRST_SETUP = True KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool) +KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/") # App can be ran from anywhere and it's not trivial to decide where to store app data. # So let's use the same directory as the flowsetting.py file. @@ -162,7 +163,7 @@ KH_LLMS["ollama"] = { "spec": { "__type__": "kotaemon.llms.ChatOpenAI", - "base_url": "http://localhost:11434/v1/", + "base_url": KH_OLLAMA_URL, "model": config("LOCAL_MODEL", default="llama3.1:8b"), "api_key": "ollama", }, @@ -171,7 +172,7 @@ KH_EMBEDDINGS["ollama"] = { "spec": { "__type__": "kotaemon.embeddings.OpenAIEmbeddings", - "base_url": "http://localhost:11434/v1/", + "base_url": KH_OLLAMA_URL, "model": config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"), "api_key": "ollama", }, @@ -195,11 +196,11 @@ }, "default": False, } -KH_LLMS["gemini"] = { +KH_LLMS["google"] = { "spec": { "__type__": "kotaemon.llms.chats.LCGeminiChat", - "model_name": "gemini-1.5-pro", - "api_key": "your-key", + "model_name": "gemini-1.5-flash", + "api_key": config("GOOGLE_API_KEY", default="your-key"), }, "default": False, } @@ -231,6 +232,13 @@ }, "default": False, } +KH_EMBEDDINGS["google"] = { + "spec": { + "__type__": "kotaemon.embeddings.LCGoogleEmbeddings", + "model": 
"models/text-embedding-004", + "google_api_key": config("GOOGLE_API_KEY", default="your-key"), + } +} # KH_EMBEDDINGS["huggingface"] = { # "spec": { # "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings", diff --git a/libs/kotaemon/kotaemon/embeddings/__init__.py b/libs/kotaemon/kotaemon/embeddings/__init__.py index 92b3d1f4b..0ff777428 100644 --- a/libs/kotaemon/kotaemon/embeddings/__init__.py +++ b/libs/kotaemon/kotaemon/embeddings/__init__.py @@ -4,6 +4,7 @@ from .langchain_based import ( LCAzureOpenAIEmbeddings, LCCohereEmbeddings, + LCGoogleEmbeddings, LCHuggingFaceEmbeddings, LCOpenAIEmbeddings, ) @@ -18,6 +19,7 @@ "LCAzureOpenAIEmbeddings", "LCCohereEmbeddings", "LCHuggingFaceEmbeddings", + "LCGoogleEmbeddings", "OpenAIEmbeddings", "AzureOpenAIEmbeddings", "FastEmbedEmbeddings", diff --git a/libs/kotaemon/kotaemon/embeddings/langchain_based.py b/libs/kotaemon/kotaemon/embeddings/langchain_based.py index 03ff9c670..9e8422a04 100644 --- a/libs/kotaemon/kotaemon/embeddings/langchain_based.py +++ b/libs/kotaemon/kotaemon/embeddings/langchain_based.py @@ -219,3 +219,38 @@ def _get_lc_class(self): from langchain.embeddings import HuggingFaceBgeEmbeddings return HuggingFaceBgeEmbeddings + + +class LCGoogleEmbeddings(LCEmbeddingMixin, BaseEmbeddings): + """Wrapper around Langchain's Google GenAI embedding, focusing on key parameters""" + + google_api_key: str = Param( + help="API key (https://aistudio.google.com/app/apikey)", + default=None, + required=True, + ) + model: str = Param( + help="Model name to use (https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding-and-embedding)", # noqa + default="models/text-embedding-004", + required=True, + ) + + def __init__( + self, + model: str = "models/text-embedding-004", + google_api_key: Optional[str] = None, + **params, + ): + super().__init__( + model=model, + google_api_key=google_api_key, + **params, + ) + + def _get_lc_class(self): + try: + from langchain_google_genai import 
GoogleGenerativeAIEmbeddings + except ImportError: + raise ImportError("Please install langchain-google-genai") + + return GoogleGenerativeAIEmbeddings diff --git a/libs/ktem/ktem/embeddings/manager.py b/libs/ktem/ktem/embeddings/manager.py index c33d151db..1c1c47027 100644 --- a/libs/ktem/ktem/embeddings/manager.py +++ b/libs/ktem/ktem/embeddings/manager.py @@ -57,6 +57,7 @@ def load_vendors(self): AzureOpenAIEmbeddings, FastEmbedEmbeddings, LCCohereEmbeddings, + LCGoogleEmbeddings, LCHuggingFaceEmbeddings, OpenAIEmbeddings, TeiEndpointEmbeddings, @@ -68,6 +69,7 @@ def load_vendors(self): FastEmbedEmbeddings, LCCohereEmbeddings, LCHuggingFaceEmbeddings, + LCGoogleEmbeddings, TeiEndpointEmbeddings, ] diff --git a/libs/ktem/ktem/pages/setup.py b/libs/ktem/ktem/pages/setup.py index f7e70a118..21efa5d9a 100644 --- a/libs/ktem/ktem/pages/setup.py +++ b/libs/ktem/ktem/pages/setup.py @@ -9,7 +9,10 @@ from theflow.settings import settings as flowsettings KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False) -DEFAULT_OLLAMA_URL = "http://localhost:11434/api" +KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/") +DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api") +if DEFAULT_OLLAMA_URL.endswith("/"): + DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1] DEMO_MESSAGE = ( @@ -55,8 +58,9 @@ def on_building_ui(self): gr.Markdown(f"# Welcome to {self._app.app_name} first setup!") self.radio_model = gr.Radio( [ - ("Cohere API (*free registration* available) - recommended", "cohere"), - ("OpenAI API (for more advance models)", "openai"), + ("Cohere API (*free registration*) - recommended", "cohere"), + ("Google API (*free registration*)", "google"), + ("OpenAI API (for GPT-based models)", "openai"), ("Local LLM (for completely *private RAG*)", "ollama"), ], label="Select your model provider", @@ -92,6 +96,18 @@ def on_building_ui(self): show_label=False, placeholder="Cohere API Key" ) + with gr.Column(visible=False) as self.google_option: + 
gr.Markdown( + ( + "#### Google API Key\n\n" + "(register your free API key " + "at https://aistudio.google.com/app/apikey)" + ) + ) + self.google_api_key = gr.Textbox( + show_label=False, placeholder="Google API Key" + ) + with gr.Column(visible=False) as self.ollama_option: gr.Markdown( ( @@ -119,7 +135,12 @@ def on_register_events(self): self.openai_api_key.submit, ], fn=self.update_model, - inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model], + inputs=[ + self.cohere_api_key, + self.openai_api_key, + self.google_api_key, + self.radio_model, + ], outputs=[self.setup_log], show_progress="hidden", ) @@ -147,13 +168,19 @@ def on_register_events(self): fn=self.switch_options_view, inputs=[self.radio_model], show_progress="hidden", - outputs=[self.cohere_option, self.openai_option, self.ollama_option], + outputs=[ + self.cohere_option, + self.openai_option, + self.ollama_option, + self.google_option, + ], ) def update_model( self, cohere_api_key, openai_api_key, + google_api_key, radio_model_value, ): # skip if KH_DEMO_MODE @@ -221,12 +248,32 @@ def update_model( }, default=True, ) + elif radio_model_value == "google": + if google_api_key: + llms.update( + name="google", + spec={ + "__type__": "kotaemon.llms.chats.LCGeminiChat", + "model_name": "gemini-1.5-flash", + "api_key": google_api_key, + }, + default=True, + ) + embeddings.update( + name="google", + spec={ + "__type__": "kotaemon.embeddings.LCGoogleEmbeddings", + "model": "models/text-embedding-004", + "google_api_key": google_api_key, + }, + default=True, + ) elif radio_model_value == "ollama": llms.update( name="ollama", spec={ "__type__": "kotaemon.llms.ChatOpenAI", - "base_url": "http://localhost:11434/v1/", + "base_url": KH_OLLAMA_URL, "model": "llama3.1:8b", "api_key": "ollama", }, @@ -236,7 +283,7 @@ def update_model( name="ollama", spec={ "__type__": "kotaemon.embeddings.OpenAIEmbeddings", - "base_url": "http://localhost:11434/v1/", + "base_url": KH_OLLAMA_URL, "model": 
"nomic-embed-text",
                 "api_key": "ollama",
             },
             default=True,
         )
@@ -270,7 +317,7 @@ def update_model(
                 yield log_content
             except Exception as e:
                 log_content += (
-                    "Make sure you have download and installed Ollama correctly."
+                    "Make sure you have downloaded and installed Ollama correctly. "
                     f"Got error: {str(e)}"
                 )
                 yield log_content
@@ -345,9 +392,9 @@ def update_default_settings(self, radio_model_value, default_settings):
         return default_settings
 
     def switch_options_view(self, radio_model_value):
-        components_visible = [gr.update(visible=False) for _ in range(3)]
+        components_visible = [gr.update(visible=False) for _ in range(4)]
 
-        values = ["cohere", "openai", "ollama", None]
+        values = ["cohere", "openai", "ollama", "google", None]
         assert radio_model_value in values, f"Invalid value {radio_model_value}"
 
         if radio_model_value is not None: