Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate cachetools for in-memory LM caching, including unhashable types & pydantic #1896

Merged
merged 7 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 59 additions & 18 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import threading
import uuid
from datetime import datetime
from hashlib import sha256
from typing import Any, Dict, List, Literal, Optional

import litellm
import pydantic
import ujson
from cachetools import LRUCache, cached

from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
Expand Down Expand Up @@ -92,7 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
completion = cached_litellm_text_completion if cache else litellm_text_completion

response = completion(
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
)
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
Expand Down Expand Up @@ -153,7 +156,11 @@ def thread_function_wrapper():
thread = threading.Thread(target=thread_function_wrapper)
model_to_finetune = self.finetuning_model or self.model
job = self.provider.TrainingJob(
thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format
thread=thread,
model=model_to_finetune,
train_data=train_data,
train_kwargs=train_kwargs,
data_format=data_format,
)
thread.start()

Expand Down Expand Up @@ -212,47 +219,81 @@ def copy(self, **kwargs):
return new_instance


@functools.lru_cache(maxsize=None)
def cached_litellm_completion(request, num_retries: int):
def request_cache(maxsize: Optional[int] = None):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@okhat @bahtman @CyrusNuevoDia Thoughts on this approach? See inline comments discussing advantages below

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks cool!

Could set default maxsize = float("inf") here

"""
A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).

Args:
maxsize: The maximum size of the cache.

Returns:
A decorator that wraps the target function with caching.
"""

def cache_key(request: Dict[str, Any]) -> str:
# Transform Pydantic models into JSON-convertible format and exclude unhashable objects
params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
params = {k: v for k, v in params.items() if not callable(v)}
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

def decorator(func):
@cached(
# NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
cache=LRUCache(maxsize=maxsize or float("inf")),
key=lambda request, *args, **kwargs: cache_key(request),
Copy link
Collaborator Author

@dbczumar dbczumar Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key advantage of cachetools. Unlike lru_cache, it allows us to define a cache key by applying a custom function to one or more arguments, rather than forcing all arguments to be hashed / JSON-encoded, passed to the function, and then decoded afterwards. Encoding / decoding is infeasible for callables.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do a global dspy.settings.request_cache default = LRUCache(maxsize=10_000_000) and then have this function pull from that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this naming, it could be confused with the disk cache though, right? It seems like we'd want some unified way to refer to both caches, or more distinctive naming. Thoughts?

# Use a lock to ensure thread safety for the cache when DSPy LMs are queried
# concurrently, e.g. during optimization and evaluation
lock=threading.Lock(),
Copy link
Collaborator Author

@dbczumar dbczumar Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cachetools provides thread safety natively. alternatively, we could try to implement our own cache with the required thread safety functionality, but I suspect there might be bugs (best to reuse something that is known to work)

Copy link
Collaborator

@okhat okhat Dec 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a blocker for merge, but I'm slightly uneasy about Python-level locking (compared to whatever functools normally does?). Maybe it's required for thread safety, but since it's happening for every single LM call it's a bit worrisome.

Copy link
Collaborator Author

@dbczumar dbczumar Dec 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @okhat ! Functools uses a Python lock as well (Rlock). Ill follow up with a small PR to use Rlock instead of Lock.

)
@functools.wraps(func)
def wrapper(request: dict, *args, **kwargs):
return func(request, *args, **kwargs)

return wrapper

return decorator


@request_cache(maxsize=None)
def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
return litellm_completion(
request,
cache={"no-cache": False, "no-store": False},
num_retries=num_retries,
)


def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We no longer have to serialize / deserialize request within the litellm_completion and litellm_text_completion calls

def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
return litellm.completion(
num_retries=num_retries,
cache=cache,
**kwargs,
**request,
)


@functools.lru_cache(maxsize=None)
def cached_litellm_text_completion(request, num_retries: int):
@request_cache(maxsize=None)
def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
return litellm_text_completion(
request,
num_retries=num_retries,
cache={"no-cache": False, "no-store": False},
)


def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)

def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
# Extract the provider and model from the model string.
# TODO: Not all the models are in the format of "provider/model"
model = kwargs.pop("model").split("/", 1)
model = request.pop("model").split("/", 1)
provider, model = model[0] if len(model) > 1 else "openai", model[-1]

# Use the API key and base from the kwargs, or from the environment.
api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
# Use the API key and base from the request, or from the environment.
api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

# Build the prompt from the messages.
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

return litellm.text_completion(
cache=cache,
Expand All @@ -261,5 +302,5 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
**kwargs,
**request,
)
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
"tenacity>=8.2.3",
"anyio",
"asyncer==0.0.8",
"cachetools",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -138,6 +139,7 @@ falkordb = "^1.0.9"
json-repair = "^0.30.0"
tenacity = ">=8.2.3"
asyncer = "0.0.8"
cachetools = "^5.5.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.3"
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
anyio
asyncer==0.0.8
backoff
cachetools
datasets
diskcache
httpx
Expand All @@ -15,5 +18,3 @@ requests
tenacity>=8.2.3
tqdm
ujson
anyio
asyncer==0.0.8
20 changes: 20 additions & 0 deletions tests/caching/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,23 @@ def test_lm_calls_are_cached_across_interpreter_sessions(litellm_test_server, te

request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == 0


def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, temporary_blank_cache_dir):
api_base, server_log_file_path = litellm_test_server

lm1 = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
)
lm1("Example query")
# Remove the disk cache, after which the LM must rely on in-memory caching
shutil.rmtree(temporary_blank_cache_dir)
lm1("Example query2")
lm1("Example query2")
lm1("Example query2")
lm1("Example query2")

request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == 2
59 changes: 56 additions & 3 deletions tests/clients/test_lm.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,77 @@
from unittest import mock

import pydantic
import pytest

import dspy
from tests.test_utils.server import litellm_test_server


def test_lms_can_be_queried(litellm_test_server):
def test_chat_lms_can_be_queried(litellm_test_server):
api_base, _ = litellm_test_server
expected_response = ["Hi!"]

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="chat",
)
openai_lm("openai query")
assert openai_lm("openai query") == expected_response

azure_openai_lm = dspy.LM(
model="azure/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="chat",
)
azure_openai_lm("azure openai query")
assert azure_openai_lm("azure openai query") == expected_response


def test_text_lms_can_be_queried(litellm_test_server):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're making changes to litellm_text_completion as well, we should have some coverage for LM queries with model_type="text"

api_base, _ = litellm_test_server
expected_response = ["Hi!"]

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="text",
)
assert openai_lm("openai query") == expected_response

azure_openai_lm = dspy.LM(
model="azure/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="text",
)
assert azure_openai_lm("azure openai query") == expected_response


def test_lm_calls_support_unhashable_types(litellm_test_server):
api_base, server_log_file_path = litellm_test_server

lm_with_unhashable_callable = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
# Define a callable kwarg for the LM to use during inference
azure_ad_token_provider=lambda *args, **kwargs: None,
)
lm_with_unhashable_callable("Query")
Comment on lines +52 to +62
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails on main with:

E       TypeError: <function test_lm_calls_support_unhashable_types.<locals>.<lambda> at 0x31204d5a0> is not JSON serializable



def test_lm_calls_support_pydantic_models(litellm_test_server):
api_base, server_log_file_path = litellm_test_server

class ResponseFormat(pydantic.BaseModel):
response: str

lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
response_format=ResponseFormat,
)
lm("Query")
Comment on lines +65 to +77
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fails on main with:

TypeError: <class 'tests.caching.test_caching.test_lm_calls_support_pydantic_models.<locals>.ResponseFormat'> is not JSON serializable

Loading