Streaming #1874

Merged 23 commits on Dec 17, 2024

Changes from 20 commits
41 changes: 37 additions & 4 deletions docs/docs/tutorials/deployment/index.md
@@ -41,9 +41,11 @@ class Question(BaseModel):
# Configure your language model and 'asyncify' your DSPy program.
lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm, async_max_workers=4) # default is 8
dspy_program = dspy.ChainOfThought("question -> answer")
dspy_program = dspy.asyncify(dspy_program)

dspy_program = dspy.asyncify(dspy.ChainOfThought("question -> answer"))
streaming_dspy_program = dspy.streamify(dspy_program)

# Define an endpoint (no streaming)
@app.post("/predict")
async def predict(question: Question):
try:
@@ -54,14 +56,45 @@ async def predict(question: Question):
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Define an endpoint (streaming)
from fastapi.responses import StreamingResponse

@app.post("/predict/stream")
async def stream(question: Question):
async def generate():
async for value in streaming_dspy_program(question=question.text):
if isinstance(value, dspy.Prediction):
data = {"prediction": value.labels().toDict()}
elif isinstance(value, litellm.ModelResponse):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
yield "data: [DONE]\n\n"

return StreamingResponse(generate(), media_type="text/event-stream")

# Since you're often going to want to stream the result of a DSPy program as server-sent events,
# we've included a helper function for that, which is equivalent to the code above.

from dspy.utils.streaming import streaming_response

@app.post("/predict/stream")
async def stream(question: Question):
stream = streaming_dspy_program(question=question.text)
return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
```

In the code above, we call `dspy.asyncify` to convert the dspy program to run in async mode for high-throughput FastAPI
deployments. Currently, this runs the dspy program in a
separate thread and awaits its result. By default, the limit of spawned threads is 8. Think of this like a worker pool.
deployments. Currently, this runs the dspy program in a separate thread and awaits its result.

By default, the limit of spawned threads is 8. Think of this like a worker pool.
If you have 8 in-flight programs and call it once more, the 9th call will wait until one of the 8 returns.
You can configure the async capacity using the new `async_max_workers` setting.
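
As a rough sketch of the worker-pool behavior described above (the concurrency value of 16 and the toy questions are illustrative, not part of this PR), you could raise the capacity and fire many calls at once:

```python
import asyncio

import dspy

lm = dspy.LM("openai/gpt-4o-mini")
# Raise the worker pool from the default of 8 to 16 in-flight calls.
dspy.settings.configure(lm=lm, async_max_workers=16)

async_program = dspy.asyncify(dspy.ChainOfThought("question -> answer"))

async def main():
    # With 16 workers, the 17th concurrent call waits until one of them frees up.
    questions = [f"What is {i} + {i}?" for i in range(20)]
    results = await asyncio.gather(*(async_program(question=q) for q in questions))
    print(results[0].answer)

asyncio.run(main())
```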

We also use `dspy.streamify` to convert the dspy program into a streaming mode. This is useful when you want to stream
the intermediate outputs (i.e., O1-style reasoning) to the client before the final prediction is ready. This uses
`asyncify` under the hood and inherits its execution semantics.
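
Outside of a web framework, you can also consume the streamified program directly. Here is a minimal sketch (assuming the same `question -> answer` program as above); intermediate chunks arrive as `litellm.ModelResponse` objects and the final value is a `dspy.Prediction`, mirroring the FastAPI example:

```python
import asyncio

import dspy
import litellm

lm = dspy.LM("openai/gpt-4o-mini")
dspy.settings.configure(lm=lm)

streaming_program = dspy.streamify(dspy.ChainOfThought("question -> answer"))

async def main():
    async for value in streaming_program(question="Why is the sky blue?"):
        if isinstance(value, litellm.ModelResponse):
            # Intermediate chunk from the underlying LM call.
            print("chunk:", value.json())
        elif isinstance(value, dspy.Prediction):
            # Final prediction, yielded once the program finishes.
            print("answer:", value.answer)

asyncio.run(main())
```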

Write your code to a file, e.g., `fastapi_dspy.py`. Then you can serve the app with:

```bash
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -13,6 +13,7 @@
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
from dspy.utils.asyncify import asyncify
from dspy.utils.saving import load
from dspy.utils.streaming import streamify

from dspy.dsp.utils.settings import settings

39 changes: 31 additions & 8 deletions dspy/clients/lm.py
@@ -5,13 +5,16 @@
import uuid
from datetime import datetime
from hashlib import sha256
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Dict, List, Literal, Optional, cast

import litellm
import pydantic
import ujson
from anyio.streams.memory import MemoryObjectSendStream
from asyncer import syncify
from cachetools import LRUCache, cached

import dspy
from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
from dspy.clients.provider import Provider, TrainingJob
@@ -102,14 +105,13 @@ def __call__(self, prompt=None, messages=None, **kwargs):
outputs = [
{
"text": c.message.content if hasattr(c, "message") else c["text"],
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
}
for c in response["choices"]
]
else:
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]


# Logging, with removed api key & where `cost` is None on cache hit.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -309,11 +311,32 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):


def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
return litellm.completion(
num_retries=num_retries,
cache=cache,
**request,
)
stream = dspy.settings.send_stream
if stream is None:
return litellm.completion(
num_retries=num_retries,
cache=cache,
**request,
)

# The stream is already opened, and will be closed by the caller.
stream = cast(MemoryObjectSendStream, stream)

@syncify
async def stream_completion():
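        # Stream the completion from litellm, forwarding each chunk to the open
        # send stream (read on the `streamify` side), then rebuild the full
        # ModelResponse from the collected chunks. `syncify` lets the synchronous
        # caller run this coroutine.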
response = await litellm.acompletion(
num_retries=num_retries,
cache=cache,
stream=True,
**request,
)
chunks = []
async for chunk in response:
chunks.append(chunk)
await stream.send(chunk)
return litellm.stream_chunk_builder(chunks)

return stream_completion()


@request_cache(maxsize=None)
20 changes: 11 additions & 9 deletions dspy/dsp/utils/settings.py
@@ -1,6 +1,7 @@
import copy
import threading
from contextlib import contextmanager

from dspy.dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
@@ -17,6 +18,7 @@
backoff_time=10,
callbacks=[],
async_max_workers=8,
send_stream=None,
Reviewer comment (Collaborator): This is the only meaningful change in the file - everything else is a linter adjustment.

)

# Global base configuration
@@ -50,7 +52,7 @@ def __new__(cls):
return cls._instance

def __getattr__(self, name):
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
overrides = getattr(thread_local_overrides, "overrides", dotdict())
if name in overrides:
return overrides[name]
elif name in main_thread_config:
@@ -59,7 +61,7 @@ def __getattr__(self, name):
raise AttributeError(f"'Settings' object has no attribute '{name}'")

def __setattr__(self, name, value):
if name in ('_instance',):
if name in ("_instance",):
super().__setattr__(name, value)
else:
self.configure(**{name: value})
@@ -73,7 +75,7 @@ def __setitem__(self, key, value):
self.__setattr__(key, value)

def __contains__(self, key):
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
overrides = getattr(thread_local_overrides, "overrides", dotdict())
return key in overrides or key in main_thread_config

def get(self, key, default=None):
@@ -83,14 +85,14 @@ def get(self, key, default=None):
return default

def copy(self):
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
overrides = getattr(thread_local_overrides, "overrides", dotdict())
return dotdict({**main_thread_config, **overrides})

@property
def config(self):
config = self.copy()
if 'lock' in config:
del config['lock']
if "lock" in config:
del config["lock"]
return config

# Configuration methods
@@ -99,7 +101,7 @@ def configure(self, **kwargs):
global main_thread_config

# Get or initialize thread-local overrides
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
overrides = getattr(thread_local_overrides, "overrides", dotdict())
thread_local_overrides.overrides = dotdict(
{**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}
)
@@ -112,7 +114,7 @@ def context(self, **kwargs):
def context(self, **kwargs):
"""Context manager for temporary configuration changes."""
global main_thread_config
original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy()
original_main_thread_config = main_thread_config.copy()

self.configure(**kwargs)
@@ -125,7 +127,7 @@ def __repr__(self):
main_thread_config = original_main_thread_config

def __repr__(self):
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
overrides = getattr(thread_local_overrides, "overrides", dotdict())
combined_config = {**main_thread_config, **overrides}
return repr(combined_config)

87 changes: 87 additions & 0 deletions dspy/utils/streaming.py
@@ -0,0 +1,87 @@
from asyncio import iscoroutinefunction
from typing import Any, AsyncGenerator, Awaitable, Callable

import litellm
import ujson
from anyio import create_memory_object_stream, create_task_group
from anyio.streams.memory import MemoryObjectSendStream

from dspy.primitives.prediction import Prediction
from dspy.primitives.program import Module
from dspy.utils.asyncify import asyncify


def streamify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]:
"""
Wrap a DSPy program so that it streams its outputs incrementally, rather than returning them
all at once.

Args:
program: The DSPy program to wrap with streaming functionality.
Returns:
A function that takes the same arguments as the original program, but returns an async
generator that yields the program's outputs incrementally.

Example:
>>> class TestSignature(dspy.Signature):
>>> input_text: str = dspy.InputField()
>>> output_text: str = dspy.OutputField()
>>>
>>> # Create the program and wrap it with streaming functionality
>>> program = dspy.streamify(dspy.Predict(TestSignature))
>>>
>>> # Use the program with streaming output
>>> async def use_streaming():
>>> output_stream = program(input_text="Test")
>>> async for value in output_stream:
>>> print(value) # Print each streamed value incrementally
"""
import dspy

if not iscoroutinefunction(program):
program = asyncify(program)

async def generator(args, kwargs, stream: MemoryObjectSendStream):
with dspy.settings.context(send_stream=stream):
prediction = await program(*args, **kwargs)

await stream.send(prediction)

async def streamer(*args, **kwargs):
send_stream, receive_stream = create_memory_object_stream(16)
async with create_task_group() as tg, send_stream, receive_stream:
tg.start_soon(generator, args, kwargs, send_stream)

async for value in receive_stream:
yield value
if isinstance(value, Prediction):
return

return streamer


async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
"""
Convert a DSPy program output stream to an OpenAI-compatible output stream that can be
used by a service as an API response to a streaming request.

Args:
streamer: An async generator that yields values from a DSPy program output stream.
Returns:
An async generator that yields OpenAI-compatible streaming response chunks.
"""
async for value in streamer:
if isinstance(value, Prediction):
data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}}
yield f"data: {ujson.dumps(data)}\n\n"
elif isinstance(value, litellm.ModelResponse):
data = {"chunk": value.json()}
yield f"data: {ujson.dumps(data)}\n\n"
elif isinstance(value, str) and value.startswith("data:"):
# The chunk value is an OpenAI-compatible streaming chunk value,
# e.g. "data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}",
# so yield it directly
yield value
else:
raise ValueError(f"Unknown chunk value type: {value}")
yield "data: [DONE]\n\n"
25 changes: 25 additions & 0 deletions tests/test_utils/server/litellm_server.py
@@ -1,8 +1,11 @@
import json
import os
import time
from typing import AsyncIterator, Iterator

import litellm
from litellm import CustomLLM
from litellm.types.utils import GenericStreamingChunk

LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR = "LITELLM_TEST_SERVER_LOG_FILE_PATH"

@@ -16,6 +19,28 @@ async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
_append_request_to_log_file(kwargs)
return _get_mock_llm_response()

def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": '{"output_text": "Hello!"}',
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
return generic_streaming_chunk # type: ignore

async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": '{"output_text": "Hello!"}',
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk


def _get_mock_llm_response():
return litellm.completion(