Skip to content

Commit

Permalink
make DatabricksRM compatible with Mosaic agent framework
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub committed Nov 15, 2024
1 parent 8ae8254 commit 115dc49
Showing 1 changed file with 53 additions and 8 deletions.
61 changes: 53 additions & 8 deletions dspy/retrieve/databricks_rm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Any, Dict, List, Optional, Union

Expand All @@ -11,6 +12,19 @@
_databricks_sdk_installed = find_spec("databricks.sdk") is not None


@dataclass
class Document:
page_content: str
metadata: Dict[str, Any]
type: str

def to_dict(self) -> Dict[str, Any]:
return {
"page_content": self.page_content,
"metadata": self.metadata,
"type": self.type,
}

class DatabricksRM(dspy.Retrieve):
"""
A retriever module that uses a Databricks Mosaic AI Vector Search Index to return the top-k
Expand Down Expand Up @@ -76,6 +90,7 @@ def __init__(
k: int = 3,
docs_id_column_name: str = "id",
text_column_name: str = "text",
use_with_databricks_agent_framework: bool = False,
):
"""
Args:
Expand All @@ -100,6 +115,8 @@ def __init__(
containing document IDs.
text_column_name (str): The name of the column in the Databricks Vector Search Index
containing document text to retrieve.
use_with_databricks_agent_framework (bool): Whether to use the `DatabricksRM` in a way that is
compatible with the Databricks Mosaic Agent Framework.
"""
super().__init__(k=k)
self.databricks_token = databricks_token if databricks_token is not None else os.environ.get("DATABRICKS_TOKEN")
Expand All @@ -119,6 +136,20 @@ def __init__(
self.k = k
self.docs_id_column_name = docs_id_column_name
self.text_column_name = text_column_name
self.use_with_databricks_agent_framework = use_with_databricks_agent_framework
if self.use_with_databricks_agent_framework:
try:
import mlflow
mlflow.models.set_retriever_schema(
primary_key="doc_id",
text_column="page_content",
doc_uri="doc_uri",
)
except ImportError:
raise ValueError(
"To use the `DatabricksRM` retriever module with the Databricks Mosaic Agent Framework, "
"you must install the mlflow Python library. Please install mlflow via `pip install mlflow`."
)

def _extract_doc_ids(self, item: Dict[str, Any]) -> str:
"""Extracts the document id from a search result
Expand Down Expand Up @@ -154,7 +185,7 @@ def forward(
query: Union[str, List[float]],
query_type: str = "ANN",
filters_json: Optional[str] = None,
) -> dspy.Prediction:
) -> Union[dspy.Prediction, List[Dict[str, Any]]]:
"""
Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
specified query.
Expand All @@ -172,7 +203,9 @@ def forward(
parameter overrides the `filters_json` parameter passed to the constructor.
Returns:
dspy.Prediction: An object containing the retrieved results.
A list of dictionaries when ``use_with_databricks_agent_framework`` is ``True``,
or a ``dspy.Prediction`` object when ``use_with_databricks_agent_framework`` is
``False``.
"""
if query_type in ["vector", "text"]:
# Older versions of DSPy used a `query_type` argument to disambiguate between text
Expand Down Expand Up @@ -239,12 +272,24 @@ def forward(
# Sorting results by score in descending order
sorted_docs = sorted(items, key=lambda x: x["score"], reverse=True)[: self.k]

# Returning the prediction
return Prediction(
docs=[doc[self.text_column_name] for doc in sorted_docs],
doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
)
if self.use_with_databricks_agent_framework:
return [Document(
page_content=doc[self.text_column_name],
metadata={
"doc_id": self._extract_doc_ids(doc),
"doc_uri": f"index/{self.databricks_index_name}/id/{self._extract_doc_ids(doc)}",
}
| self._get_extra_columns(doc),
type="Document",
).to_dict() for doc in sorted_docs]
else:
# Returning the prediction
return Prediction(
docs=[doc[self.text_column_name] for doc in sorted_docs],
doc_ids=[self._extract_doc_ids(doc) for doc in sorted_docs],
extra_columns=[self._get_extra_columns(item) for item in sorted_docs],
)


@staticmethod
def _query_via_databricks_sdk(
Expand Down

0 comments on commit 115dc49

Please sign in to comment.