Retries via LiteLLM RetryPolicy (#1866)
* Retr

Signed-off-by: dbczumar <[email protected]>

* works

Signed-off-by: dbczumar <[email protected]>

* retry

Signed-off-by: dbczumar <[email protected]>

* Retry

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* Rename

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* Update test_lm.py

* Update test_lm.py

* Update test_lm.py

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* Make tests more robust

Signed-off-by: dbczumar <[email protected]>

* Update pyproject.toml

---------

Signed-off-by: dbczumar <[email protected]>
dbczumar authored Dec 17, 2024
1 parent 7c2f604 commit 1bae71c
Showing 7 changed files with 138 additions and 18 deletions.
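For context, a minimal sketch (not part of this commit) of how the new retry behavior lands in user code; the model name is a placeholder, and num_retries now defaults to 8:

import dspy

# Sketch only: "openai/gpt-4o-mini" is a placeholder model name.
# num_retries is translated into a LiteLLM RetryPolicy, so transient failures
# (rate limits, timeouts) are retried with exponential backoff, while bad
# requests and authentication errors fail immediately.
lm = dspy.LM(model="openai/gpt-4o-mini", num_retries=4)
dspy.configure(lm=lm)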
40 changes: 35 additions & 5 deletions dspy/clients/lm.py
@@ -11,6 +11,7 @@
import pydantic
import ujson
from cachetools import LRUCache, cached
from litellm import RetryPolicy

from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
@@ -36,7 +37,7 @@ def __init__(
max_tokens: int = 1000,
cache: bool = True,
callbacks: Optional[List[BaseCallback]] = None,
num_retries: int = 3,
num_retries: int = 8,
provider=None,
finetuning_model: Optional[str] = None,
launch_kwargs: Optional[dict[str, Any]] = None,
@@ -102,14 +103,13 @@ def __call__(self, prompt=None, messages=None, **kwargs):
outputs = [
{
"text": c.message.content if hasattr(c, "message") else c["text"],
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
}
for c in response["choices"]
]
else:
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]


# Logging, with removed api key & where `cost` is None on cache hit.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -310,8 +310,12 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):

def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
return litellm.completion(
num_retries=num_retries,
cache=cache,
retry_policy=_get_litellm_retry_policy(num_retries),
# In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
# to completion()), the default value of max_retries is non-zero for certain providers, and
# max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
max_retries=0,
**request,
)

@@ -344,6 +348,32 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n
api_key=api_key,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
retry_policy=_get_litellm_retry_policy(num_retries),
# In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
# to completion()), the default value of max_retries is non-zero for certain providers, and
# max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
max_retries=0,
**request,
)


def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
"""
Get a LiteLLM retry policy for retrying requests when transient API errors occur.
Args:
num_retries: The number of times to retry a request if it fails transiently due to
network error, rate limiting, etc. Requests are retried with exponential
backoff.
Returns:
A LiteLLM RetryPolicy instance.
"""
return RetryPolicy(
TimeoutErrorRetries=num_retries,
RateLimitErrorRetries=num_retries,
InternalServerErrorRetries=num_retries,
ContentPolicyViolationErrorRetries=num_retries,
# We don't retry on errors that are unlikely to be transient
# (e.g. bad request, invalid auth credentials)
BadRequestErrorRetries=0,
AuthenticationErrorRetries=0,
)
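As a quick illustration (not part of the commit), the policy built above maps the requested retry count onto the transient error categories and zeroes out the non-transient ones:

from dspy.clients.lm import _get_litellm_retry_policy

policy = _get_litellm_retry_policy(num_retries=3)
assert policy.RateLimitErrorRetries == 3       # transient: retried
assert policy.TimeoutErrorRetries == 3         # transient: retried
assert policy.BadRequestErrorRetries == 0      # user error: never retried
assert policy.AuthenticationErrorRetries == 0  # user error: never retried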
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -38,7 +38,7 @@ dependencies = [
"pydantic~=2.0",
"jinja2",
"magicattr~=0.1.6",
"litellm",
"litellm==1.55.3",
"diskcache",
"json-repair",
"tenacity>=8.2.3",
@@ -132,7 +132,7 @@ pgvector = { version = "^0.2.5", optional = true }
llama-index = { version = "^0.10.30", optional = true }
jinja2 = "^3.1.3"
magicattr = "^0.1.6"
litellm = { version = "==1.53.7", extras = ["proxy"] }
litellm = { version = "==1.55.3", extras = ["proxy"] }
diskcache = "^5.6.0"
json-repair = "^0.30.0"
tenacity = ">=8.2.3"
6 changes: 3 additions & 3 deletions requirements.txt
@@ -2,12 +2,14 @@ anyio
asyncer==0.0.8
backoff
cachetools
cloudpickle
datasets
diskcache
httpx
jinja2
joblib~=1.3
json-repair
litellm[proxy]==1.53.7
litellm[proxy]==1.55.3
magicattr~=0.1.6
openai
optuna
@@ -18,5 +20,3 @@ requests
tenacity>=8.2.3
tqdm
ujson
cloudpickle
jinja2
74 changes: 73 additions & 1 deletion tests/clients/test_lm.py
@@ -5,7 +5,7 @@
import pytest

import dspy
from tests.test_utils.server import litellm_test_server
from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs


def test_chat_lms_can_be_queried(litellm_test_server):
@@ -84,3 +84,75 @@ class ResponseFormat(pydantic.BaseModel):
response_format=ResponseFormat,
)
lm("Query")


@pytest.mark.parametrize(
("error_code", "expected_exception", "expected_num_retries"),
[
("429", litellm.RateLimitError, 2),
("504", litellm.Timeout, 3),
# Don't retry on user errors
("400", litellm.BadRequestError, 0),
("401", litellm.AuthenticationError, 0),
# TODO: LiteLLM retry logic isn't implemented properly for internal server errors
# and content policy violations, both of which may be transient and should be retried
# ("content-policy-violation, litellm.BadRequestError, 1),
# ("500", litellm.InternalServerError, 0, 1),
],
)
def test_lm_chat_calls_are_retried_for_expected_failures(
litellm_test_server,
error_code,
expected_exception,
expected_num_retries,
):
api_base, server_log_file_path = litellm_test_server

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
num_retries=expected_num_retries,
model_type="chat",
)
with pytest.raises(expected_exception):
openai_lm(error_code)

request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == expected_num_retries + 1 # 1 initial request + expected_num_retries retries


@pytest.mark.parametrize(
("error_code", "expected_exception", "expected_num_retries"),
[
("429", litellm.RateLimitError, 2),
("504", litellm.Timeout, 3),
# Don't retry on user errors
("400", litellm.BadRequestError, 0),
("401", litellm.AuthenticationError, 0),
# TODO: LiteLLM retry logic isn't implemented properly for internal server errors
# and content policy violations, both of which may be transient and should be retried
# ("content-policy-violation, litellm.BadRequestError, 2),
# ("500", litellm.InternalServerError, 0, 2),
],
)
def test_lm_text_calls_are_retried_for_expected_failures(
litellm_test_server,
error_code,
expected_exception,
expected_num_retries,
):
api_base, server_log_file_path = litellm_test_server

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
num_retries=expected_num_retries,
model_type="text",
)
with pytest.raises(expected_exception):
openai_lm(error_code)

request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == expected_num_retries + 1 # 1 initial request + expected_num_retries retries
23 changes: 20 additions & 3 deletions tests/test_utils/server/litellm_server.py
@@ -10,21 +10,38 @@
class DSPyTestModel(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
_append_request_to_log_file(kwargs)
return _get_mock_llm_response()
return _get_mock_llm_response(kwargs)

async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
_append_request_to_log_file(kwargs)
return _get_mock_llm_response()
return _get_mock_llm_response(kwargs)


def _get_mock_llm_response():
def _get_mock_llm_response(request_kwargs):
_throw_exception_based_on_content_if_applicable(request_kwargs)
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
)


def _throw_exception_based_on_content_if_applicable(request_kwargs):
"""
Throws an exception, for testing purposes, based on the content of the request message.
"""
model = request_kwargs["model"]
content = request_kwargs["messages"][0]["content"]
if "429" in content:
raise litellm.RateLimitError(message="Rate limit exceeded", llm_provider=None, model=model)
elif "504" in content:
raise litellm.Timeout("Request timed out!", llm_provider=None, model=model)
elif "400" in content:
raise litellm.BadRequestError(message="Bad request", llm_provider=None, model=model)
elif "401" in content:
raise litellm.AuthenticationError(message="Authentication error", llm_provider=None, model=model)


def _append_request_to_log_file(completion_kwargs):
log_file_path = os.environ.get(LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR)
if log_file_path is None:
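For reference, a condensed sketch (not part of the commit) of how the retry tests above drive this error injection; it mirrors tests/clients/test_lm.py:

import dspy
import litellm
import pytest

def test_retry_on_rate_limit_sketch(litellm_test_server):
    api_base, _ = litellm_test_server
    lm = dspy.LM(
        model="openai/dspy-test-model",  # registered in litellm_server_config.yaml
        api_base=api_base,
        api_key="fakekey",
        num_retries=2,
    )
    # Prompting with "429" makes the mock server raise RateLimitError on every attempt,
    # so the client retries twice before the error finally surfaces.
    with pytest.raises(litellm.RateLimitError):
        lm("429")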
1 change: 1 addition & 0 deletions tests/test_utils/server/litellm_server_config.yaml
@@ -7,6 +7,7 @@ model_list:
model: "dspy-test-provider/dspy-test-model"

litellm_settings:
num_retries: 0
custom_provider_map:
- {
"provider": "dspy-test-provider",
