
Caching: Use reentrant locks, don't discard callables (or any other unhashable object in the future) from the cache key #1905

Merged (24 commits) on Dec 8, 2024
29 changes: 25 additions & 4 deletions dspy/clients/lm.py
@@ -233,9 +233,30 @@ def request_cache(maxsize: Optional[int] = None):
"""

def cache_key(request: Dict[str, Any]) -> str:
# Transform Pydantic models into JSON-convertible format and exclude unhashable objects
params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
params = {k: v for k, v in params.items() if not callable(v)}
"""
Obtain a unique cache key for the given request dictionary by hashing its JSON
representation. For request fields having types that are known to be JSON-incompatible,
convert them to a JSON-serializable format before hashing.

Note: Unhashable values should *not* be ignored / discarded, since that would
potentially lead to cache collisions. For example, consider request A containing only
hashable values and request B containing the same hashable values in addition to one
unhashable value. Discarding the unhashable value would lead to a cache collision between
requests A and B, even though they are semantically different.
"""

def transform_value(value):
if isinstance(value, type) and issubclass(value, pydantic.BaseModel):
return value.schema()
elif isinstance(value, pydantic.BaseModel):
return value.dict()
else:
try:
return hash(value)
Collaborator (PR author), inline comment on the `return hash(value)` line: Most functions appear to be hashable. This is tested in test_lm.py.

+                except Exception:
+                    return value
+
+        params = {k: transform_value(v) for k, v in request.items()}
         return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()
 
     def decorator(func):
@@ -245,7 +266,7 @@ def decorator(func):
             key=lambda request, *args, **kwargs: cache_key(request),
             # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
             # concurrently, e.g. during optimization and evaluation
-            lock=threading.Lock(),
+            lock=threading.RLock(),
         )
         @functools.wraps(func)
         def wrapper(request: dict, *args, **kwargs):
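A note on the `Lock` → `RLock` change above. A plain `threading.Lock` deadlocks if the thread that already holds it tries to acquire it again, which is exactly the reentrant situation a cache-guarded call can hit when it re-enters cache-guarded code on the same thread; `threading.RLock` lets the owning thread re-acquire freely. A minimal standalone sketch of the general mechanism (not code from this PR):

```python
import threading

# With lock = threading.Lock(), the nested acquire below would block forever.
lock = threading.RLock()

def cached_call():
    with lock:
        nested_call()  # re-enters lock-protected code on the same thread

def nested_call():
    with lock:  # second acquisition by the owning thread: fine for an RLock
        pass

cached_call()  # completes with RLock; deadlocks with a non-reentrant Lock
```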
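To make the docstring's collision argument concrete, here is a minimal standalone sketch; `naive_key` is a hypothetical stand-in for the old discard-callables behavior, not a function from the codebase:

```python
import ujson
from hashlib import sha256

def naive_key(request: dict) -> str:
    # Old behavior: silently drop callables before hashing the request
    params = {k: v for k, v in request.items() if not callable(v)}
    return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

request_a = {"model": "dspy-test-model", "temperature": 0.0}
request_b = {**request_a, "azure_ad_token_provider": lambda *args, **kwargs: None}

# A and B are semantically different requests, yet the naive scheme collides:
assert naive_key(request_a) == naive_key(request_b)
```

Under the new `transform_value`, the callable contributes `hash(value)` to the key, so request B hashes differently from request A; the test below verifies exactly this against the LiteLLM server logs.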
24 changes: 19 additions & 5 deletions tests/clients/test_lm.py
@@ -4,7 +4,7 @@
 import pytest
 
 import dspy
-from tests.test_utils.server import litellm_test_server
+from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs
 
 
 def test_chat_lms_can_be_queried(litellm_test_server):
@@ -49,21 +49,35 @@ def test_text_lms_can_be_queried(litellm_test_server):
     assert azure_openai_lm("azure openai query") == expected_response
 
 
-def test_lm_calls_support_unhashable_types(litellm_test_server):
+def test_lm_calls_support_callables(litellm_test_server):
     api_base, server_log_file_path = litellm_test_server
 
-    lm_with_unhashable_callable = dspy.LM(
+    lm_with_callable = dspy.LM(
         model="openai/dspy-test-model",
         api_base=api_base,
         api_key="fakekey",
         # Define a callable kwarg for the LM to use during inference
         azure_ad_token_provider=lambda *args, **kwargs: None,
     )
-    lm_with_unhashable_callable("Query")
+    lm_with_callable("Query")
+
+    # Define and invoke a nearly-identical LM that lacks the callable kwarg
+    lm_without_callable = dspy.LM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+    )
+    lm_without_callable("Query")
+
+    # Verify that 2 requests were made to the LiteLLM server - one for each LM.
+    # This verifies that there wasn't a cache collision between the LMs due to
+    # the callable
+    request_logs = read_litellm_test_server_request_logs(server_log_file_path)
+    assert len(request_logs) == 2
 

 def test_lm_calls_support_pydantic_models(litellm_test_server):
-    api_base, server_log_file_path = litellm_test_server
+    api_base, _ = litellm_test_server
 
     class ResponseFormat(pydantic.BaseModel):
         response: str
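One detail worth noting from `transform_value`: a Pydantic model *class* and a model *instance* are serialized differently before hashing. A quick sketch (Pydantic v1-style APIs, matching the diff; the printed values are illustrative):

```python
import pydantic

class ResponseFormat(pydantic.BaseModel):
    response: str

# A model class (e.g., one passed as a response format) becomes its JSON schema...
print(ResponseFormat.schema())
# ...while a model instance becomes a plain dict of its field values.
print(ResponseFormat(response="ok").dict())
```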