From f1fc6bc01e0d35704a2a5ca65ee81acef8dccc22 Mon Sep 17 00:00:00 2001
From: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Date: Tue, 10 Dec 2024 06:13:53 -0800
Subject: [PATCH] Support structured outputs response format based on signature in JSON adapter (#1881)

* Fix

Signed-off-by: dbczumar

* Fix

Signed-off-by: dbczumar

* fix

Signed-off-by: dbczumar

* Debug

Signed-off-by: dbczumar

* Fix

Signed-off-by: dbczumar

* Here

Signed-off-by: dbczumar

* Here

Signed-off-by: dbczumar

* Update json_adapter.py

* Update json_adapter.py

* Update json_adapter.py

* Update json_adapter.py

---------

Signed-off-by: dbczumar
---
 dspy/adapters/json_adapter.py             | 64 ++++++++++++++-
 tests/reliability/test_pydantic_models.py | 16 +++-
 tests/reliability/utils.py                |  1 -
 tests/test_json_adapter.py                | 95 +++++++++++++++++++++++
 4 files changed, 169 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_json_adapter.py

diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py
index afe2e8c4e..281df5cb4 100644
--- a/dspy/adapters/json_adapter.py
+++ b/dspy/adapters/json_adapter.py
@@ -2,13 +2,15 @@
 import enum
 import inspect
 import json
+import logging
 import textwrap
+from copy import deepcopy
 from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin
 
 import json_repair
 import litellm
 import pydantic
-from pydantic import TypeAdapter
+from pydantic import TypeAdapter, create_model
 from pydantic.fields import FieldInfo
 
 from dspy.adapters.base import Adapter
@@ -18,6 +20,8 @@
 from ..signatures.signature import SignatureMeta
 from ..signatures.utils import get_dspy_field_type
 
+_logger = logging.getLogger(__name__)
+
 
 class FieldInfoWithName(NamedTuple):
     name: str
@@ -35,7 +39,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
         try:
             provider = lm.model.split("/", 1)[0] or "openai"
             if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
-                outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
+                try:
+                    response_format = _get_structured_outputs_response_format(signature)
+                    outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
+                except Exception as e:
+                    _logger.debug(
+                        "Failed to obtain response using signature-based structured outputs"
+                        " response format: Falling back to default 'json_object' response format."
+                        f" Exception: {e}"
+                    )
+                    outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
             else:
                 outputs = lm(**inputs, **lm_kwargs)
 
@@ -303,3 +316,50 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
     #              ", and then ending with the marker for `completed`.")
 
     return "\n\n".join(parts).strip()
+
+
+def _get_structured_outputs_response_format(signature: SignatureMeta) -> pydantic.BaseModel:
+    """
+    Obtains the LiteLLM / OpenAI `response_format` parameter for generating structured outputs from
+    an LM request, based on the output fields of the specified DSPy signature.
+
+    Args:
+        signature: The DSPy signature for which to obtain the `response_format` request parameter.
+    Returns:
+        A Pydantic model class to pass as the `response_format` parameter of the LM request.
+    """
+
+    def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo:
+        """
+        Recursively filter the `json_schema_extra` of a FieldInfo to exclude DSPy internal attributes
+        (e.g. `__dspy_field_type`) and remove descriptions that are placeholders for the field name.
+        """
+        field_copy = deepcopy(field_info)  # Make a copy to avoid mutating the original
+
+        # Update `json_schema_extra` for the copied field
+        if field_copy.json_schema_extra:
+            field_copy.json_schema_extra = {
+                key: value
+                for key, value in field_info.json_schema_extra.items()
+                if key not in ("desc", "__dspy_field_type")
+            }
+            field_desc = field_info.json_schema_extra.get("desc")
+            if field_desc is not None and field_desc != f"${{{field_name}}}":
+                field_copy.json_schema_extra["desc"] = field_desc
+
+        # Handle nested fields
+        if hasattr(field_copy.annotation, "__pydantic_model__"):
+            # Recursively update fields of the nested model
+            nested_model = field_copy.annotation.__pydantic_model__
+            updated_fields = {
+                key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items()
+            }
+            # Create a new model with the same name and updated fields
+            field_copy.annotation = create_model(nested_model.__name__, **updated_fields)
+
+        return field_copy
+
+    output_pydantic_fields = {
+        key: (value.annotation, filter_json_schema_extra(key, value)) for key, value in signature.output_fields.items()
+    }
+    return create_model("DSPyProgramOutputs", **output_pydantic_fields)
diff --git a/tests/reliability/test_pydantic_models.py b/tests/reliability/test_pydantic_models.py
index 823351441..bb6991541 100644
--- a/tests/reliability/test_pydantic_models.py
+++ b/tests/reliability/test_pydantic_models.py
@@ -2,6 +2,7 @@
 from typing import List
 
 import pydantic
+import pytest
 
 import dspy
 from tests.reliability.utils import assert_program_output_correct, known_failing_models
@@ -33,22 +34,29 @@ class QA(dspy.Signature):
     assert_program_output_correct(
         program_input=question,
         program_output=answer.comments,
-        grading_guidelines="The comments should be relevant to the answer",
+        grading_guidelines=(
+            "The comments should be relevant to the answer. They don't need to restate the answer explicitly."
+        ),
     )
     assert answer.certainty >= 0
     assert answer.certainty <= 1
     assert len(answer.comments) >= 2
 
 
-def test_color_classification_using_enum():
+@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought])
+def test_color_classification_using_enum(module):
     Color = Enum("Color", ["RED", "GREEN", "BLUE"])
 
     class Colorful(dspy.Signature):
         text: str = dspy.InputField()
         color: Color = dspy.OutputField()
 
-    program = dspy.Predict(Colorful)
-    color = program(text="The sky is blue").color
+    program = module(Colorful)
+    # Note: The precise text, including the trailing period, is important here for ensuring that
+    # the program is correctly extracting the color from the text; previous implementations have
+    # produced invalid enum responses for "The sky is blue.", but they have produced valid enum
+    # responses for "The sky is blue" (without the period).
+    color = program(text="The sky is blue.").color
 
     assert color == Color.BLUE
 
diff --git a/tests/reliability/utils.py b/tests/reliability/utils.py
index 37c368760..349192fe3 100644
--- a/tests/reliability/utils.py
+++ b/tests/reliability/utils.py
@@ -31,7 +31,6 @@ def assert_program_output_correct(
         grading_guidelines = [grading_guidelines]
 
     with judge_dspy_configuration():
-        print("GUIDELINES", grading_guidelines)
         for guideline_entry in grading_guidelines:
             judge_response = _get_judge_program()(
                 program_input=str(program_input),
diff --git a/tests/test_json_adapter.py b/tests/test_json_adapter.py
new file mode 100644
index 000000000..2009e8936
--- /dev/null
+++ b/tests/test_json_adapter.py
@@ -0,0 +1,95 @@
+from unittest import mock
+
+import pydantic
+import pytest
+from pydantic import create_model
+
+import dspy
+
+
+def test_json_adapter_passes_structured_output_when_supported_by_model():
+    class OutputField3(pydantic.BaseModel):
+        subfield1: int = pydantic.Field(description="Int subfield 1", ge=0, le=10)
+        subfield2: float = pydantic.Field(description="Float subfield 2")
+
+    class TestSignature(dspy.Signature):
+        input1: str = dspy.InputField()
+        output1: str = dspy.OutputField()  # Description intentionally left blank
+        output2: bool = dspy.OutputField(desc="Boolean output field")
+        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
+        output4_unannotated = dspy.OutputField(desc="Unannotated output field")
+
+    program = dspy.Predict(TestSignature)
+
+    # Configure DSPy to use an OpenAI LM that supports structured outputs
+    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
+    with mock.patch("litellm.completion") as mock_completion:
+        program(input1="Test input")
+
+    def clean_schema_extra(field_name, field_info):
+        attrs = dict(field_info.__repr_args__())
+        if "json_schema_extra" in attrs:
+            attrs["json_schema_extra"] = {
+                k: v
+                for k, v in attrs["json_schema_extra"].items()
+                if k != "__dspy_field_type" and not (k == "desc" and v == f"${{{field_name}}}")
+            }
+        return attrs
+
+    mock_completion.assert_called_once()
+    _, call_kwargs = mock_completion.call_args
+    response_format = call_kwargs.get("response_format")
+    assert response_format is not None
+    assert issubclass(response_format, pydantic.BaseModel)
+    assert response_format.model_fields.keys() == {"output1", "output2", "output3", "output4_unannotated"}
+    for field_name in response_format.model_fields:
+        assert dict(response_format.model_fields[field_name].__repr_args__()) == clean_schema_extra(
+            field_name=field_name,
+            field_info=TestSignature.output_fields[field_name],
+        )
+
+    # Configure DSPy to use a model from a fake provider that doesn't support structured outputs
+    dspy.configure(lm=dspy.LM(model="fakeprovider/fakemodel"), adapter=dspy.JSONAdapter())
+    with mock.patch("litellm.completion") as mock_completion:
+        program(input1="Test input")
+
+    mock_completion.assert_called_once()
+    _, call_kwargs = mock_completion.call_args
+    assert "response_format" not in call_kwargs
+
+
+def test_json_adapter_falls_back_when_structured_outputs_fails():
+    class TestSignature(dspy.Signature):
+        input1: str = dspy.InputField()
+        output1: str = dspy.OutputField(desc="String output field")
+
+    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
+    program = dspy.Predict(TestSignature)
+    with mock.patch("litellm.completion") as mock_completion:
+        mock_completion.side_effect = [Exception("Bad structured outputs!"), mock_completion.return_value]
+        program(input1="Test input")
+        assert mock_completion.call_count == 2
+        _, first_call_kwargs = mock_completion.call_args_list[0]
+        assert issubclass(first_call_kwargs.get("response_format"), pydantic.BaseModel)
+        _, second_call_kwargs = mock_completion.call_args_list[1]
+        assert second_call_kwargs.get("response_format") == {"type": "json_object"}
+
+
+def test_json_adapter_with_structured_outputs_does_not_mutate_original_signature():
+    class OutputField3(pydantic.BaseModel):
+        subfield1: int = pydantic.Field(description="Int subfield 1")
+        subfield2: float = pydantic.Field(description="Float subfield 2")
+
+    class TestSignature(dspy.Signature):
+        input1: str = dspy.InputField()
+        output1: str = dspy.OutputField()  # Description intentionally left blank
+        output2: bool = dspy.OutputField(desc="Boolean output field")
+        output3: OutputField3 = dspy.OutputField(desc="Nested output field")
+        output4_unannotated = dspy.OutputField(desc="Unannotated output field")
+
+    dspy.configure(lm=dspy.LM(model="openai/gpt4o"), adapter=dspy.JSONAdapter())
+    program = dspy.Predict(TestSignature)
+    with mock.patch("litellm.completion"):
+        program(input1="Test input")
+
+    assert program.signature.output_fields == TestSignature.output_fields