
Retries via LiteLLM RetryPolicy #1866

Merged
merged 17 commits on Dec 17, 2024
40 changes: 35 additions & 5 deletions dspy/clients/lm.py
@@ -11,6 +11,7 @@
import pydantic
import ujson
from cachetools import LRUCache, cached
from litellm import RetryPolicy

from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
@@ -36,7 +37,7 @@ def __init__(
max_tokens: int = 1000,
cache: bool = True,
callbacks: Optional[List[BaseCallback]] = None,
num_retries: int = 3,
num_retries: int = 8,
@dbczumar (Collaborator, Author) commented on Dec 17, 2024:

Empirically, this provides roughly 1 minute of retries, which is typically necessary to overcome rate limit errors (providers like Azure OpenAI and Databricks enforce RPM rate limits).
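To make the "roughly 1 minute" figure concrete, here is a back-of-the-envelope sketch of how eight retries with exponential backoff can add up to about a minute of waiting. The base delay and per-attempt cap are illustrative assumptions, not values taken from LiteLLM's implementation:

# Rough sketch only: LiteLLM's actual backoff parameters may differ.
base_delay = 0.5  # assumed initial delay (seconds)
max_delay = 20.0  # assumed per-attempt cap (seconds)
total_wait = sum(min(base_delay * (2 ** attempt), max_delay) for attempt in range(8))
print(f"Approximate cumulative backoff: {total_wait:.1f}s")
# 0.5 + 1 + 2 + 4 + 8 + 16 + 20 + 20 = 71.5 seconds, i.e. roughly a minute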

provider=None,
finetuning_model: Optional[str] = None,
launch_kwargs: Optional[dict[str, Any]] = None,
@@ -102,14 +103,13 @@ def __call__(self, prompt=None, messages=None, **kwargs):
outputs = [
{
"text": c.message.content if hasattr(c, "message") else c["text"],
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
@dbczumar (Collaborator, Author) commented:

Just the linter being itself...

}
for c in response["choices"]
]
else:
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]


# Logging, with removed api key & where `cost` is None on cache hit.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -310,8 +310,12 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):

def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
return litellm.completion(
num_retries=num_retries,
cache=cache,
retry_policy=_get_litellm_retry_policy(num_retries),
# In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
# to completion()), the default value of max_retries is non-zero for certain providers, and
# max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0.
max_retries=0,
**request,
)

@@ -344,6 +348,32 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n
api_key=api_key,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
retry_policy=_get_litellm_retry_policy(num_retries),
# In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
# to completion()), the default value of max_retries is non-zero for certain providers, and
# max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0.
max_retries=0,
**request,
)


def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
"""
Get a LiteLLM retry policy for retrying requests when transient API errors occur.
Args:
num_retries: The number of times to retry a request if it fails transiently due to
network error, rate limiting, etc. Requests are retried with exponential
backoff.
Returns:
A LiteLLM RetryPolicy instance.
"""
return RetryPolicy(
TimeoutErrorRetries=num_retries,
RateLimitErrorRetries=num_retries,
InternalServerErrorRetries=num_retries,
ContentPolicyViolationErrorRetries=num_retries,
# We don't retry on errors that are unlikely to be transient
# (e.g. bad request, invalid auth credentials)
BadRequestErrorRetries=0,
AuthenticationErrorRetries=0,
)
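
For reference, a minimal usage sketch of passing a policy like this to LiteLLM directly, assuming the litellm==1.55.3 completion() signature described above (the model name and message are placeholders):

import litellm
from litellm import RetryPolicy

# Minimal sketch, assuming litellm==1.55.3 where completion() accepts retry_policy.
policy = RetryPolicy(
    TimeoutErrorRetries=8,
    RateLimitErrorRetries=8,
    InternalServerErrorRetries=8,
    ContentPolicyViolationErrorRetries=8,
    BadRequestErrorRetries=0,
    AuthenticationErrorRetries=0,
)
response = litellm.completion(
    model="openai/gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "Hello"}],
    retry_policy=policy,
    max_retries=0,  # avoid stacking provider-default retries on top of the policy
)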
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -38,7 +38,7 @@ dependencies = [
"pydantic~=2.0",
"jinja2",
"magicattr~=0.1.6",
"litellm",
"litellm==1.55.3",
@dbczumar (Collaborator, Author) commented:

LiteLLM version 1.55.3 is the only version that correctly supports passing retry_policy to completion() while respecting the number of retries specified by the policy.

"diskcache",
"json-repair",
"tenacity>=8.2.3",
@@ -132,7 +132,7 @@ pgvector = { version = "^0.2.5", optional = true }
llama-index = { version = "^0.10.30", optional = true }
jinja2 = "^3.1.3"
magicattr = "^0.1.6"
litellm = { version = "==1.53.7", extras = ["proxy"] }
litellm = "1.55.3"
diskcache = "^5.6.0"
json-repair = "^0.30.0"
tenacity = ">=8.2.3"
6 changes: 3 additions & 3 deletions requirements.txt
@@ -2,12 +2,14 @@ anyio
asyncer==0.0.8
backoff
cachetools
cloudpickle
datasets
diskcache
httpx
jinja2
joblib~=1.3
json-repair
litellm[proxy]==1.53.7
litellm[proxy]==1.55.3
magicattr~=0.1.6
openai
optuna
Expand All @@ -18,5 +20,3 @@ requests
tenacity>=8.2.3
tqdm
ujson
cloudpickle
jinja2
@dbczumar (Collaborator, Author) commented on lines -21 to -22:

Linter reordering

74 changes: 73 additions & 1 deletion tests/clients/test_lm.py
@@ -5,7 +5,7 @@
import pytest

import dspy
from tests.test_utils.server import litellm_test_server
from tests.test_utils.server import litellm_test_server, read_litellm_test_server_request_logs


def test_chat_lms_can_be_queried(litellm_test_server):
@@ -84,3 +84,75 @@ class ResponseFormat(pydantic.BaseModel):
response_format=ResponseFormat,
)
lm("Query")


@pytest.mark.parametrize(
("error_code", "expected_exception", "expected_num_retries"),
[
("429", litellm.RateLimitError, 2),
("504", litellm.Timeout, 2),
# Don't retry on user errors
("400", litellm.BadRequestError, 0),
("401", litellm.AuthenticationError, 0),
# TODO: LiteLLM retry logic isn't implemented properly for internal server errors
# and content policy violations, both of which may be transient and should be retried
# ("content-policy-violation, litellm.BadRequestError, 1),
# ("500", litellm.InternalServerError, 0, 1),
],
)
def test_lm_chat_calls_are_retried_for_expected_failures(
litellm_test_server,
error_code,
expected_exception,
expected_num_retries,
):
api_base, server_log_file_path = litellm_test_server

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
num_retries=2,
model_type="chat",
)
with pytest.raises(expected_exception):
openai_lm(error_code)

request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == expected_num_retries + 1 # 1 initial request + expected_num_retries retries


@pytest.mark.parametrize(
("error_code", "expected_exception", "expected_num_retries"),
[
("429", litellm.RateLimitError, 2),
("504", litellm.Timeout, 2),
# Don't retry on user errors
("400", litellm.BadRequestError, 0),
("401", litellm.AuthenticationError, 0),
# TODO: LiteLLM retry logic isn't implemented properly for internal server errors
# and content policy violations, both of which may be transient and should be retried
# ("content-policy-violation, litellm.BadRequestError, 2),
# ("500", litellm.InternalServerError, 0, 2),
],
)
def test_lm_text_calls_are_retried_for_expected_failures(
litellm_test_server,
error_code,
expected_exception,
expected_num_retries,
):
api_base, server_log_file_path = litellm_test_server

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
num_retries=2,
model_type="text",
)
with pytest.raises(expected_exception):
openai_lm(error_code)

request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == expected_num_retries + 1 # 1 initial request + expected_num_retries retries
23 changes: 20 additions & 3 deletions tests/test_utils/server/litellm_server.py
@@ -10,21 +10,38 @@
class DSPyTestModel(CustomLLM):
def completion(self, *args, **kwargs) -> litellm.ModelResponse:
_append_request_to_log_file(kwargs)
return _get_mock_llm_response()
return _get_mock_llm_response(kwargs)

async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
_append_request_to_log_file(kwargs)
return _get_mock_llm_response()
return _get_mock_llm_response(kwargs)


def _get_mock_llm_response():
def _get_mock_llm_response(request_kwargs):
_throw_exception_based_on_content_if_applicable(request_kwargs)
return litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world"}],
mock_response="Hi!",
)


def _throw_exception_based_on_content_if_applicable(request_kwargs):
"""
Throws an exception, for testing purposes, based on the content of the request message.
"""
model = request_kwargs["model"]
content = request_kwargs["messages"][0]["content"]
if "429" in content:
raise litellm.RateLimitError(message="Rate limit exceeded", llm_provider=None, model=model)
elif "504" in content:
raise litellm.Timeout("Request timed out!", llm_provider=None, model=model)
elif "400" in content:
raise litellm.BadRequestError(message="Bad request", llm_provider=None, model=model)
elif "401" in content:
raise litellm.AuthenticationError(message="Authentication error", llm_provider=None, model=model)


def _append_request_to_log_file(completion_kwargs):
log_file_path = os.environ.get(LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR)
if log_file_path is None:
1 change: 1 addition & 0 deletions tests/test_utils/server/litellm_server_config.yaml
@@ -7,6 +7,7 @@ model_list:
model: "dspy-test-provider/dspy-test-model"

litellm_settings:
num_retries: 0
@dbczumar (Collaborator, Author) commented:

Disable retries on the server side to ensure that server retries don't stack atop client retries, which can result in test failures due to a mismatch between the expected and actual number of retries.

custom_provider_map:
- {
"provider": "dspy-test-provider",