Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapters: Support JSON serialization of all pydantic types (e.g. datetimes, enums, etc.) #1853

Merged
merged 6 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 3 additions & 83 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from dsp.adapters.base_template import Field
from dspy.adapters.base import Adapter
from dspy.adapters.image_utils import Image, encode_image
from dspy.adapters.utils import find_enum_member, format_field_value
from dspy.signatures.field import OutputField
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.signatures.utils import get_dspy_field_type
Expand Down Expand Up @@ -115,86 +115,6 @@ def format_fields(self, signature, values, role):
return format_fields(fields_with_values)


def format_blob(blob):
Copy link
Collaborator Author

@dbczumar dbczumar Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved these utilities into a shared adapters/utils.py module, since they're identical but repeated across JSONAdapter and ChatAdapter.

(Except that JSONAdapter doesn't supported images yet and throws an "unsupported" exception. This behavior is preserved by wrapping utils.format_field_value() with aif field_info.annotation is Image: throw block)

if "\n" not in blob and "«" not in blob and "»" not in blob:
return f"«{blob}»"

modified_blob = blob.replace("\n", "\n ")
return f"«««\n {modified_blob}\n»»»"


def format_input_list_field_value(value: List[Any]) -> str:
"""
Formats the value of an input field of type List[Any].

Args:
value: The value of the list-type input field.
Returns:
A string representation of the input field's list value.
"""
if len(value) == 0:
return "N/A"
if len(value) == 1:
return format_blob(value[0])

return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])


def _serialize_for_json(value):
if isinstance(value, pydantic.BaseModel):
return value.model_dump()
elif isinstance(value, list):
return [_serialize_for_json(item) for item in value]
elif isinstance(value, dict):
return {key: _serialize_for_json(val) for key, val in value.items()}
else:
return value
Comment on lines -143 to -151
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The core change in this PR is the replacement of this method with the new version in utils.py:

def serialize_for_json(value: Any) -> Any:
    """
    Formats the specified value so that it can be serialized as a JSON string.
    Args:
        value: The value to format as a JSON string.
    Returns:
        The formatted value, which is serializable as a JSON string.
    """
    # Attempt to format the value as a JSON-compatible object using pydantic, falling back to
    # a string representation of the value if that fails (e.g. if the value contains an object
    # that pydantic doesn't recognize or can't serialize)
    try:
        return TypeAdapter(type(value)).dump_python(value, mode="json")
    except Exception:
        return str(value)

This reuses pydantic's JSON serialization, ensuring we don't miss certain types.



def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
annotation (e.g. str, int, etc.), and the type of the value itself.

Args:
field_info: Information about the field, including its DSPy field type and annotation.
value: The value of the field.
Returns:
The formatted value of the field, represented as a string.
"""
string_value = None
if isinstance(value, list) and field_info.annotation is str:
# If the field has no special type requirements, format it as a nice numbered list for the LM.
string_value = format_input_list_field_value(value)
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
string_value = json.dumps(_serialize_for_json(value), ensure_ascii=False)
else:
string_value = str(value)

if assume_text:
return string_value
elif isinstance(value, Image) or field_info.annotation == Image:
# This validation should happen somewhere else
# Safe to import PIL here because it's only imported when an image is actually being formatted
try:
import PIL
except ImportError:
raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.")
image_value = value
if not isinstance(image_value, Image):
if isinstance(image_value, dict) and "url" in image_value:
image_value = image_value["url"]
elif isinstance(image_value, str):
image_value = encode_image(image_value)
elif isinstance(image_value, PIL.Image.Image):
image_value = encode_image(image_value)
assert isinstance(image_value, str)
image_value = Image(url=image_value)
return {"type": "image_url", "image_url": image_value.model_dump()}
else:
return {"type": "text", "text": string_value}


def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
Expand All @@ -209,7 +129,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
"""
output = []
for field, field_value in fields_with_values.items():
formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
if assume_text:
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
else:
Expand All @@ -231,7 +151,7 @@ def parse_value(value, annotation):
parsed_value = value

if isinstance(annotation, enum.EnumMeta):
parsed_value = annotation[value]
Copy link
Collaborator Author

@dbczumar dbczumar Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enum handling logic was incorrect.

Given an enum like

from enum import Enum

class Status(Enum):
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"

Serializing this enum to JSON with Pydantic produces in_progress (an enum field value), not IN_PROGRESS (an enum field *name). Status('in_progress') is the correct way to restore the enum field from its value, while Status['in_progress'] throws:

python3.10/enum.py:440, in EnumMeta.__getitem__(cls, name)
    439 def __getitem__(cls, name):
--> 440     return cls._member_map_[name]

KeyError: 'in_progress'

I've added test coverage to confirm that the new behavior is correct (see test_predict and reliability/complex_types/generated/test_many_types_1)

return find_enum_member(annotation, value)
elif isinstance(value, str):
try:
parsed_value = json.loads(value)
Expand Down
118 changes: 41 additions & 77 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
import ast
import json
import enum
import inspect
import litellm
import pydantic
import json
import textwrap
import json_repair

from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin

import json_repair
import litellm
import pydantic
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

Comment on lines +4 to 13
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is all just the linter reordering imports

from dspy.adapters.base import Adapter
from dspy.adapters.utils import find_enum_member, format_field_value, serialize_for_json

from ..adapters.image_utils import Image
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type


class FieldInfoWithName(NamedTuple):
name: str
info: FieldInfo


class JSONAdapter(Adapter):
def __init__(self):
pass

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)



try:
provider = lm.model.split('/', 1)[0] or "openai"
if 'response_format' in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
outputs = lm(**inputs, **lm_kwargs, response_format={ "type": "json_object" })
Comment on lines -31 to -36
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the linter - no material change here

provider = lm.model.split("/", 1)[0] or "openai"
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
else:
outputs = lm(**inputs, **lm_kwargs)

Expand All @@ -44,11 +46,12 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):

for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
assert set(value.keys()) == set(
signature.output_fields.keys()
), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
Comment on lines +49 to +51
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just the linter - no material change here

values.append(value)

return values

return values

def format(self, signature, demos, inputs):
messages = []
Expand All @@ -71,7 +74,7 @@ def format(self, signature, demos, inputs):
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))

messages.append(format_turn(signature, inputs, role="user"))

return messages

def parse(self, signature, completion, _parse_values=True):
Expand All @@ -90,7 +93,7 @@ def parse(self, signature, completion, _parse_values=True):

def format_turn(self, signature, values, role, incomplete=False):
return format_turn(signature, values, role, incomplete)

def format_fields(self, signature, values, role):
fields_with_values = {
FieldInfoWithName(name=field_name, info=field_info): values.get(
Expand All @@ -101,16 +104,16 @@ def format_fields(self, signature, values, role):
}

return format_fields(role=role, fields_with_values=fields_with_values)


def parse_value(value, annotation):
if annotation is str:
return str(value)

parsed_value = value

if isinstance(annotation, enum.EnumMeta):
parsed_value = annotation[value]
parsed_value = find_enum_member(annotation, value)
elif isinstance(value, str):
try:
parsed_value = json.loads(value)
Expand All @@ -119,45 +122,10 @@ def parse_value(value, annotation):
parsed_value = ast.literal_eval(value)
except (ValueError, SyntaxError):
parsed_value = value

return TypeAdapter(annotation).validate_python(parsed_value)


def format_blob(blob):
if "\n" not in blob and "«" not in blob and "»" not in blob:
return f"«{blob}»"

modified_blob = blob.replace("\n", "\n ")
return f"«««\n {modified_blob}\n»»»"


def format_input_list_field_value(value: List[Any]) -> str:
"""
Formats the value of an input field of type List[Any].
Args:
value: The value of the list-type input field.
Returns:
A string representation of the input field's list value.
"""
if len(value) == 0:
return "N/A"
if len(value) == 1:
return format_blob(value[0])

return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])
return TypeAdapter(annotation).validate_python(parsed_value)


def _serialize_for_json(value):
if isinstance(value, pydantic.BaseModel):
return value.model_dump()
elif isinstance(value, list):
return [_serialize_for_json(item) for item in value]
elif isinstance(value, dict):
return {key: _serialize_for_json(val) for key, val in value.items()}
else:
return value

Comment on lines -126 to -160
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deduplicate code with chat_adapter by moving these into a utils file - see https://github.com/stanfordnlp/dspy/pull/1853/files#r1856167223

def _format_field_value(field_info: FieldInfo, value: Any) -> str:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
Expand All @@ -169,17 +137,10 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
Returns:
The formatted value of the field, represented as a string.
"""

if isinstance(value, list) and field_info.annotation is str:
# If the field has no special type requirements, format it as a nice numbere list for the LM.
return format_input_list_field_value(value)
if field_info.annotation is Image:
raise NotImplementedError("Images are not yet supported in JSON mode.")
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
return json.dumps(_serialize_for_json(value))
else:
return str(value)

return format_field_value(field_info=field_info, value=value, assume_text=True)


def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
Expand All @@ -197,9 +158,8 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -

if role == "assistant":
d = fields_with_values.items()
d = {k.name: _serialize_for_json(v) for k, v in d}

return json.dumps(_serialize_for_json(d), indent=2)
d = {k.name: v for k, v in d}
return json.dumps(serialize_for_json(d), indent=2)

output = []
for field, field_value in fields_with_values.items():
Expand Down Expand Up @@ -246,15 +206,19 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
field_name, "Not supplied for this particular example."
)
for field_name, field_info in fields.items()
}
},
)
content.append(formatted_fields)

if role == "user":

def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""

return (
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
if v.annotation is not str
else ""
)
Comment on lines +216 to +220
Copy link
Collaborator Author

@dbczumar dbczumar Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (and everything else below it in this file) is just a linter change


# TODO: Consider if not incomplete:
content.append(
"Respond with a JSON object in the following order of fields: "
Expand Down Expand Up @@ -297,15 +261,15 @@ def prepare_instructions(signature: SignatureMeta):
def field_metadata(field_name, field_info):
type_ = field_info.annotation

if get_dspy_field_type(field_info) == 'input' or type_ is str:
if get_dspy_field_type(field_info) == "input" or type_ is str:
desc = ""
elif type_ is bool:
desc = "must be True or False"
elif type_ in (int, float):
desc = f"must be a single {type_.__name__} value"
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
desc= f"must be one of: {'; '.join(type_.__members__)}"
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal:
desc = f"must be one of: {'; '.join(type_.__members__)}"
elif hasattr(type_, "__origin__") and type_.__origin__ is Literal:
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
else:
desc = "must be pareseable according to the following JSON schema: "
Expand All @@ -320,13 +284,13 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
for field_name, field_info in fields.items()
}
},
)

parts.append("Inputs will have the following structure:")
parts.append(format_signature_fields_for_instructions('user', signature.input_fields))
parts.append(format_signature_fields_for_instructions("user", signature.input_fields))
parts.append("Outputs will be a JSON object with the following fields.")
parts.append(format_signature_fields_for_instructions('assistant', signature.output_fields))
parts.append(format_signature_fields_for_instructions("assistant", signature.output_fields))
# parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))

instructions = textwrap.dedent(signature.instructions)
Expand Down
Loading
Loading