-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Integrate cachetools for in-memory LM caching, including unhashable types & pydantic #1896
Changes from 6 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,13 @@ | |
import threading | ||
import uuid | ||
from datetime import datetime | ||
from hashlib import sha256 | ||
from typing import Any, Dict, List, Literal, Optional | ||
|
||
import litellm | ||
import pydantic | ||
import ujson | ||
from cachetools import LRUCache, cached | ||
|
||
from dspy.adapters.base import Adapter | ||
from dspy.clients.openai import OpenAIProvider | ||
|
@@ -92,7 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): | |
completion = cached_litellm_text_completion if cache else litellm_text_completion | ||
|
||
response = completion( | ||
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)), | ||
request=dict(model=self.model, messages=messages, **kwargs), | ||
num_retries=self.num_retries, | ||
) | ||
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]] | ||
|
@@ -153,7 +156,11 @@ def thread_function_wrapper(): | |
thread = threading.Thread(target=thread_function_wrapper) | ||
model_to_finetune = self.finetuning_model or self.model | ||
job = self.provider.TrainingJob( | ||
thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format | ||
thread=thread, | ||
model=model_to_finetune, | ||
train_data=train_data, | ||
train_kwargs=train_kwargs, | ||
data_format=data_format, | ||
) | ||
thread.start() | ||
|
||
|
@@ -212,47 +219,81 @@ def copy(self, **kwargs): | |
return new_instance | ||
|
||
|
||
@functools.lru_cache(maxsize=None) | ||
def cached_litellm_completion(request, num_retries: int): | ||
def request_cache(maxsize: Optional[int] = None): | ||
""" | ||
A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept | ||
a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring | ||
good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow). | ||
|
||
Args: | ||
maxsize: The maximum size of the cache. | ||
|
||
Returns: | ||
A decorator that wraps the target function with caching. | ||
""" | ||
|
||
def cache_key(request: Dict[str, Any]) -> str: | ||
# Transform Pydantic models into JSON-convertible format and exclude unhashable objects | ||
params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()} | ||
params = {k: v for k, v in params.items() if not callable(v)} | ||
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest() | ||
|
||
def decorator(func): | ||
@cached( | ||
# NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead | ||
cache=LRUCache(maxsize=maxsize or float("inf")), | ||
key=lambda request, *args, **kwargs: cache_key(request), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the key advantage of cachetools. Unlike There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you do a global There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With this naming, it could be confused with the disk cache though, right? It seems like we'd want some unified way to refer to both caches, or more distinctive naming. Thoughts? |
||
# Use a lock to ensure thread safety for the cache when DSPy LMs are queried | ||
# concurrently, e.g. during optimization and evaluation | ||
lock=threading.Lock(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cachetools provides thread safety natively. alternatively, we could try to implement our own cache with the required thread safety functionality, but I suspect there might be bugs (best to reuse something that is known to work) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a blocker for merge, but I'm slightly uneasy about Python-level locking (compared to whatever functools normally does?). Maybe it's required for thread safety, but since it's happening for every single LM call it's a bit worrisome. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @okhat ! Functools uses a Python lock as well (Rlock). Ill follow up with a small PR to use Rlock instead of Lock. |
||
) | ||
@functools.wraps(func) | ||
def wrapper(request: dict, *args, **kwargs): | ||
return func(request, *args, **kwargs) | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
@request_cache(maxsize=None) | ||
def cached_litellm_completion(request: Dict[str, Any], num_retries: int): | ||
return litellm_completion( | ||
request, | ||
cache={"no-cache": False, "no-store": False}, | ||
num_retries=num_retries, | ||
) | ||
|
||
|
||
def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): | ||
kwargs = ujson.loads(request) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We no longer have to serialize / deserialize |
||
def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): | ||
return litellm.completion( | ||
num_retries=num_retries, | ||
cache=cache, | ||
**kwargs, | ||
**request, | ||
) | ||
|
||
|
||
@functools.lru_cache(maxsize=None) | ||
def cached_litellm_text_completion(request, num_retries: int): | ||
@request_cache(maxsize=None) | ||
def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int): | ||
return litellm_text_completion( | ||
request, | ||
num_retries=num_retries, | ||
cache={"no-cache": False, "no-store": False}, | ||
) | ||
|
||
|
||
def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): | ||
kwargs = ujson.loads(request) | ||
|
||
def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}): | ||
# Extract the provider and model from the model string. | ||
# TODO: Not all the models are in the format of "provider/model" | ||
model = kwargs.pop("model").split("/", 1) | ||
model = request.pop("model").split("/", 1) | ||
provider, model = model[0] if len(model) > 1 else "openai", model[-1] | ||
|
||
# Use the API key and base from the kwargs, or from the environment. | ||
api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") | ||
api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") | ||
# Use the API key and base from the request, or from the environment. | ||
api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") | ||
api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") | ||
|
||
# Build the prompt from the messages. | ||
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]) | ||
prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"]) | ||
|
||
return litellm.text_completion( | ||
cache=cache, | ||
|
@@ -261,5 +302,5 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, | |
api_base=api_base, | ||
prompt=prompt, | ||
num_retries=num_retries, | ||
**kwargs, | ||
**request, | ||
) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,77 @@ | ||
from unittest import mock | ||
|
||
import pydantic | ||
import pytest | ||
|
||
import dspy | ||
from tests.test_utils.server import litellm_test_server | ||
|
||
|
||
def test_lms_can_be_queried(litellm_test_server): | ||
def test_chat_lms_can_be_queried(litellm_test_server): | ||
api_base, _ = litellm_test_server | ||
expected_response = ["Hi!"] | ||
|
||
openai_lm = dspy.LM( | ||
model="openai/dspy-test-model", | ||
api_base=api_base, | ||
api_key="fakekey", | ||
model_type="chat", | ||
) | ||
openai_lm("openai query") | ||
assert openai_lm("openai query") == expected_response | ||
|
||
azure_openai_lm = dspy.LM( | ||
model="azure/dspy-test-model", | ||
api_base=api_base, | ||
api_key="fakekey", | ||
model_type="chat", | ||
) | ||
azure_openai_lm("azure openai query") | ||
assert azure_openai_lm("azure openai query") == expected_response | ||
|
||
|
||
def test_text_lms_can_be_queried(litellm_test_server): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we're making changes to |
||
api_base, _ = litellm_test_server | ||
expected_response = ["Hi!"] | ||
|
||
openai_lm = dspy.LM( | ||
model="openai/dspy-test-model", | ||
api_base=api_base, | ||
api_key="fakekey", | ||
model_type="text", | ||
) | ||
assert openai_lm("openai query") == expected_response | ||
|
||
azure_openai_lm = dspy.LM( | ||
model="azure/dspy-test-model", | ||
api_base=api_base, | ||
api_key="fakekey", | ||
model_type="text", | ||
) | ||
assert azure_openai_lm("azure openai query") == expected_response | ||
|
||
|
||
def test_lm_calls_support_unhashable_types(litellm_test_server): | ||
api_base, server_log_file_path = litellm_test_server | ||
|
||
lm_with_unhashable_callable = dspy.LM( | ||
model="openai/dspy-test-model", | ||
api_base=api_base, | ||
api_key="fakekey", | ||
# Define a callable kwarg for the LM to use during inference | ||
azure_ad_token_provider=lambda *args, **kwargs: None, | ||
) | ||
lm_with_unhashable_callable("Query") | ||
Comment on lines
+52
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fails on
|
||
|
||
|
||
def test_lm_calls_support_pydantic_models(litellm_test_server): | ||
api_base, server_log_file_path = litellm_test_server | ||
|
||
class ResponseFormat(pydantic.BaseModel): | ||
response: str | ||
|
||
lm = dspy.LM( | ||
model="openai/dspy-test-model", | ||
api_base=api_base, | ||
api_key="fakekey", | ||
response_format=ResponseFormat, | ||
) | ||
lm("Query") | ||
Comment on lines
+65
to
+77
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fails on
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@okhat @bahtman @CyrusNuevoDia Thoughts on this approach? See inline comments discussing advantages below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks cool!
Could set default
maxsize = float("inf")
here