
Commit

Caching: Use reentrant locks, don't discard callables (or any other unhashable object in the future) from the cache key (#1905)

* fix (×16)
* Fallback
* Test
* Debug (×3)
* progress
* Test no cache (×2)

(Each of the squashed commits above carried the trailer: Signed-off-by: dbczumar <[email protected]>.)
---------

Signed-off-by: dbczumar <[email protected]>
dbczumar authored Dec 8, 2024
1 parent a798ebc commit 4c4bc29
Showing 3 changed files with 101 additions and 17 deletions.
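On the locking side, the commit replaces the cache's threading.Lock with threading.RLock. A plain lock is not owner-aware: if the thread that already holds it acquires it again (as can happen when a cached call re-enters the cache on the same thread), it deadlocks, whereas a reentrant lock lets the owning thread re-acquire it. The sketch below is illustrative only and is not part of the commit; it just demonstrates the primitive difference.

import threading

# Illustrative sketch (not from the commit): a plain Lock refuses a second
# acquisition even by the thread that already holds it, while an RLock tracks
# its owner and a recursion count, so the owning thread may re-acquire it.
plain_lock = threading.Lock()
reentrant_lock = threading.RLock()

plain_lock.acquire()
print(plain_lock.acquire(blocking=False))      # False: already held, even by this same thread
plain_lock.release()

reentrant_lock.acquire()
print(reentrant_lock.acquire(blocking=False))  # True: the owning thread may re-enter
reentrant_lock.release()
reentrant_lock.release()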
45 changes: 39 additions & 6 deletions dspy/clients/lm.py
@@ -233,23 +233,56 @@ def request_cache(maxsize: Optional[int] = None):
     """
 
     def cache_key(request: Dict[str, Any]) -> str:
-        # Transform Pydantic models into JSON-convertible format and exclude unhashable objects
-        params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
-        params = {k: v for k, v in params.items() if not callable(v)}
+        """
+        Obtain a unique cache key for the given request dictionary by hashing its JSON
+        representation. For request fields having types that are known to be JSON-incompatible,
+        convert them to a JSON-serializable format before hashing.
+
+        Note: Values that cannot be converted to JSON should *not* be ignored / discarded, since
+        that would potentially lead to cache collisions. For example, consider request A
+        containing only JSON-convertible values and request B containing the same JSON-convertible
+        values in addition to one unconvertible value. Discarding the unconvertible value would
+        lead to a cache collision between requests A and B, even though they are semantically
+        different.
+        """
+
+        def transform_value(value):
+            if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
+                return value.schema()
+            elif isinstance(value, pydantic.BaseModel):
+                return value.dict()
+            elif callable(value) and hasattr(value, "__code__") and hasattr(value.__code__, "co_code"):
+                return value.__code__.co_code.decode("utf-8")
+            else:
+                # Note: We don't attempt to compute a hash of the value, since the default
+                # implementation of hash() is id(), which may collide if the same memory address
+                # is reused for different objects at different times
+                return value
+
+        params = {k: transform_value(v) for k, v in request.items()}
         return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
 
     def decorator(func):
         @cached(
             # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
             cache=LRUCache(maxsize=maxsize or float("inf")),
-            key=lambda request, *args, **kwargs: cache_key(request),
+            key=lambda key, request, *args, **kwargs: key,
             # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
             # concurrently, e.g. during optimization and evaluation
-            lock=threading.Lock(),
+            lock=threading.RLock(),
         )
+        def func_cached(key: str, request: Dict[str, Any], *args, **kwargs):
+            return func(request, *args, **kwargs)
+
         @functools.wraps(func)
         def wrapper(request: dict, *args, **kwargs):
-            return func(request, *args, **kwargs)
+            try:
+                key = cache_key(request)
+                return func_cached(key, request, *args, **kwargs)
+            except Exception:
+                # If the cache key cannot be computed (e.g. because it contains a value that cannot
+                # be converted to JSON), bypass the cache and call the target function directly
+                return func(request, *args, **kwargs)
 
         return wrapper
 
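The collision argument in the new cache_key docstring can be made concrete with a small sketch. It is not part of the commit: it reuses the same ujson/sha256 utilities the module already imports, but substitutes a simplified, hypothetical stand-in (a hex encoding of the callable's bytecode) for the commit's transform_value. Discarding a callable from the key makes a request that carries a callable kwarg hash identically to one that does not, while keeping any serializable representation of the callable keeps the two keys distinct.

from hashlib import sha256

import ujson

request_a = {"model": "openai/dspy-test-model", "prompt": "Query"}
request_b = {**request_a, "azure_ad_token_provider": lambda *args, **kwargs: None}

def key_discarding_callables(request):
    # Old behavior: silently drop callables before hashing
    params = {k: v for k, v in request.items() if not callable(v)}
    return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

def key_keeping_callables(request):
    # Hypothetical stand-in for the commit's transform_value: keep callables by
    # representing them with a JSON-serializable view of their bytecode
    params = {k: (v.__code__.co_code.hex() if callable(v) else v) for k, v in request.items()}
    return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

# Requests A and B are semantically different, but discarding the callable collides them...
assert key_discarding_callables(request_a) == key_discarding_callables(request_b)
# ...whereas keeping a serializable representation of the callable does not
assert key_keeping_callables(request_a) != key_keeping_callables(request_b)

For values that can be neither transformed nor serialized, the rewritten wrapper's try/except bypasses the cache entirely instead of risking such a collision.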
42 changes: 42 additions & 0 deletions tests/caching/test_caching.py
Expand Up @@ -2,6 +2,7 @@
import os
import shutil
import tempfile
from unittest.mock import patch

import pytest

@@ -108,3 +109,44 @@ def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, tempor
 
     request_logs = read_litellm_test_server_request_logs(server_log_file_path)
     assert len(request_logs) == 2
+
+
+def test_lm_calls_skip_in_memory_cache_if_key_not_computable():
+    with patch("litellm.completion") as mock_litellm_completion:
+
+        class NonJsonSerializable:
+            pass
+
+        lm = dspy.LM(
+            model="fakemodel/fakemodel",
+            non_json_serializable=NonJsonSerializable(),
+        )
+        lm("Example query")
+        lm("Example query")
+
+        assert mock_litellm_completion.call_count == 2
+
+
+def test_lm_calls_with_callables_are_cached_as_expected():
+    with patch("litellm.completion") as mock_completion:
+        lm_with_callable = dspy.LM(
+            model="openai/dspy-test-model",
+            api_base="fakebase",
+            api_key="fakekey",
+            # Define a callable kwarg for the LM to use during inference
+            azure_ad_token_provider=lambda *args, **kwargs: None,
+        )
+        # Invoke the LM twice; the second call should be cached in memory
+        lm_with_callable("Query")
+        lm_with_callable("Query")
+
+        # Define and invoke a nearly-identical LM that lacks the callable kwarg,
+        # which should not hit the in-memory cache
+        lm_without_callable = dspy.LM(
+            model="openai/dspy-test-model",
+            api_base="fakebase",
+            api_key="fakekey",
+        )
+        lm_without_callable("Query")
+
+        assert mock_completion.call_count == 2
31 changes: 20 additions & 11 deletions tests/clients/test_lm.py
@@ -1,5 +1,6 @@
 from unittest import mock
 
+import litellm
 import pydantic
 import pytest
 
@@ -49,21 +50,29 @@ def test_text_lms_can_be_queried(litellm_test_server):
     assert azure_openai_lm("azure openai query") == expected_response
 
 
-def test_lm_calls_support_unhashable_types(litellm_test_server):
-    api_base, server_log_file_path = litellm_test_server
+def test_lm_calls_support_callables(litellm_test_server):
+    api_base, _ = litellm_test_server
 
-    lm_with_unhashable_callable = dspy.LM(
-        model="openai/dspy-test-model",
-        api_base=api_base,
-        api_key="fakekey",
-        # Define a callable kwarg for the LM to use during inference
-        azure_ad_token_provider=lambda *args, **kwargs: None,
-    )
-    lm_with_unhashable_callable("Query")
+    with mock.patch("litellm.completion", autospec=True, wraps=litellm.completion) as spy_completion:
+        azure_ad_token_provider = lambda *args, **kwargs: None
+        lm_with_callable = dspy.LM(
+            model="openai/dspy-test-model",
+            api_base=api_base,
+            api_key="fakekey",
+            azure_ad_token_provider=azure_ad_token_provider,
+        )
+        lm_with_callable("Query")
+
+        spy_completion.assert_called_once()
+        call_args = spy_completion.call_args.kwargs
+        assert call_args["model"] == "openai/dspy-test-model"
+        assert call_args["api_base"] == api_base
+        assert call_args["api_key"] == "fakekey"
+        assert call_args["azure_ad_token_provider"] is azure_ad_token_provider
 
 
 def test_lm_calls_support_pydantic_models(litellm_test_server):
-    api_base, server_log_file_path = litellm_test_server
+    api_base, _ = litellm_test_server
 
     class ResponseFormat(pydantic.BaseModel):
         response: str
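The rewritten test_lm_calls_support_callables relies on the mock "spy" pattern: patching litellm.completion with wraps=litellm.completion records every call for later assertions while still delegating to the real function. Below is a minimal stand-alone illustration of the same pattern, using json.dumps as a hypothetical target rather than anything from this commit.

import json
from unittest import mock

# Spy pattern: the patched attribute records calls, and because of wraps= it
# still delegates to the real function and returns its result.
with mock.patch("json.dumps", wraps=json.dumps) as spy_dumps:
    assert json.dumps({"q": 1}) == '{"q": 1}'
    spy_dumps.assert_called_once_with({"q": 1})

The test additionally passes autospec=True so that the spy also enforces the real function's call signature.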
