Retries via LiteLLM RetryPolicy #1866
dspy/clients/lm.py

@@ -11,6 +11,7 @@
 import pydantic
 import ujson
 from cachetools import LRUCache, cached
+from litellm import RetryPolicy

 from dspy.adapters.base import Adapter
 from dspy.clients.openai import OpenAIProvider
@@ -36,7 +37,7 @@ def __init__(
         max_tokens: int = 1000,
         cache: bool = True,
         callbacks: Optional[List[BaseCallback]] = None,
-        num_retries: int = 3,
+        num_retries: int = 8,
         provider=None,
         finetuning_model: Optional[str] = None,
         launch_kwargs: Optional[dict[str, Any]] = None,
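For context, a minimal usage sketch (not part of this PR's diff; the model name is a placeholder): the new default of 8 retries applies to every dspy.LM instance, and callers can still override it per instance.

```python
# Hypothetical usage sketch, not part of this PR: the new default of
# num_retries=8 is picked up automatically, and remains overridable.
import dspy

lm = dspy.LM("openai/gpt-4o-mini", num_retries=3)  # restore the previous default of 3
dspy.configure(lm=lm)
```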
@@ -102,14 +103,13 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             outputs = [
                 {
                     "text": c.message.content if hasattr(c, "message") else c["text"],
-                    "logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
+                    "logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
                 }
                 for c in response["choices"]
             ]
         else:
             outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]

-
         # Logging, with removed api key & where `cost` is None on cache hit.
         kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
         entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)

Review comment (on the added trailing comma): Just the linter being itself...
@@ -310,8 +310,12 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):


 def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
     return litellm.completion(
-        num_retries=num_retries,
         cache=cache,
+        retry_policy=_get_litellm_retry_policy(num_retries),
+        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
+        # to completion()), the default value of max_retries is non-zero for certain providers, and
+        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
+        max_retries=0,
         **request,
     )
@@ -344,6 +348,32 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n
         api_key=api_key,
         api_base=api_base,
         prompt=prompt,
-        num_retries=num_retries,
+        retry_policy=_get_litellm_retry_policy(num_retries),
+        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
+        # to completion()), the default value of max_retries is non-zero for certain providers, and
+        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
+        max_retries=0,
         **request,
     )


+def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
+    """
+    Get a LiteLLM retry policy for retrying requests when transient API errors occur.
+
+    Args:
+        num_retries: The number of times to retry a request if it fails transiently due to
+            network error, rate limiting, etc. Requests are retried with exponential backoff.
+
+    Returns:
+        A LiteLLM RetryPolicy instance.
+    """
+    return RetryPolicy(
+        TimeoutErrorRetries=num_retries,
+        RateLimitErrorRetries=num_retries,
+        InternalServerErrorRetries=num_retries,
+        ContentPolicyViolationErrorRetries=num_retries,
+        # We don't retry on errors that are unlikely to be transient
+        # (e.g. bad request, invalid auth credentials)
+        BadRequestErrorRetries=0,
+        AuthenticationErrorRetries=0,
+    )
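For readers unfamiliar with the LiteLLM API, here is a minimal standalone sketch of the retry pattern the PR adopts: all retry behavior flows through a RetryPolicy, and max_retries=0 keeps provider-default retries from stacking on top of it. The model name and message are placeholders.

```python
# Standalone sketch of the pattern above; the RetryPolicy fields mirror
# _get_litellm_retry_policy, and the model/message are placeholders.
import litellm
from litellm import RetryPolicy

policy = RetryPolicy(
    TimeoutErrorRetries=8,
    RateLimitErrorRetries=8,
    InternalServerErrorRetries=8,
    ContentPolicyViolationErrorRetries=8,
    BadRequestErrorRetries=0,      # a malformed request won't succeed on retry
    AuthenticationErrorRetries=0,  # neither will bad credentials
)

response = litellm.completion(
    model="openai/gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello!"}],
    retry_policy=policy,
    max_retries=0,  # keep provider-default retries from stacking on the policy
)
print(response.choices[0].message.content)
```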
pyproject.toml

@@ -38,7 +38,7 @@ dependencies = [
     "pydantic~=2.0",
     "jinja2",
     "magicattr~=0.1.6",
-    "litellm",
+    "litellm==1.55.3",
     "diskcache",
     "json-repair",
     "tenacity>=8.2.3",

Review comment (on the pin): LiteLLM version 1.55.3 is the only version that correctly supports passing a retry_policy to completion().
@@ -132,7 +132,7 @@ pgvector = { version = "^0.2.5", optional = true }
 llama-index = { version = "^0.10.30", optional = true }
 jinja2 = "^3.1.3"
 magicattr = "^0.1.6"
-litellm = { version = "==1.53.7", extras = ["proxy"] }
+litellm = "1.55.3"
 diskcache = "^5.6.0"
 json-repair = "^0.30.0"
 tenacity = ">=8.2.3"
requirements.txt

@@ -2,12 +2,14 @@ anyio
 asyncer==0.0.8
 backoff
 cachetools
+cloudpickle
 datasets
 diskcache
 httpx
+jinja2
 joblib~=1.3
 json-repair
-litellm[proxy]==1.53.7
+litellm[proxy]==1.55.3
 magicattr~=0.1.6
 openai
 optuna
@@ -18,5 +20,3 @@ requests
 tenacity>=8.2.3
 tqdm
 ujson
-cloudpickle
-jinja2

Review comment (on lines -21 to -22): Linter reordering; cloudpickle and jinja2 moved up into alphabetical order.
Test LiteLLM proxy config

@@ -7,6 +7,7 @@ model_list:
       model: "dspy-test-provider/dspy-test-model"

+litellm_settings:
+  num_retries: 0

 custom_provider_map:
   - {
       "provider": "dspy-test-provider",

Review comment (on litellm_settings): Disable retries on the server side to ensure that server retries don't stack atop client retries, which can result in test failures due to a mismatch between the expected and actual number of retries.
Review comment (on the new num_retries default of 8): Empirically, this provides roughly 1 minute of retries, which is typically necessary to overcome rate limit errors (providers like Azure OpenAI & Databricks enforce RPM rate limits).
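A back-of-envelope check of that estimate; LiteLLM's exact backoff schedule is internal, so the initial delay and cap below are illustrative assumptions rather than documented values.

```python
# Rough arithmetic behind "num_retries=8 gives roughly a minute of retrying".
# Assumes exponential backoff starting at ~0.5s, doubling, capped at ~8s;
# these constants are assumptions, not LiteLLM's documented schedule.
def total_backoff_seconds(num_retries: int, initial: float = 0.5, cap: float = 8.0) -> float:
    return sum(min(initial * 2**attempt, cap) for attempt in range(num_retries))

print(total_backoff_seconds(8))  # 39.5s of sleep; with per-attempt latency, about a minute
```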