-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adapters: Support JSON serialization of all pydantic types (e.g. datetimes, enums, etc.) #1853
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,7 +14,7 @@ | |
|
||
from dsp.adapters.base_template import Field | ||
from dspy.adapters.base import Adapter | ||
from dspy.adapters.image_utils import Image, encode_image | ||
from dspy.adapters.utils import find_enum_member, format_field_value | ||
from dspy.signatures.field import OutputField | ||
from dspy.signatures.signature import Signature, SignatureMeta | ||
from dspy.signatures.utils import get_dspy_field_type | ||
|
@@ -115,86 +115,6 @@ def format_fields(self, signature, values, role): | |
return format_fields(fields_with_values) | ||
|
||
|
||
def format_blob(blob): | ||
if "\n" not in blob and "«" not in blob and "»" not in blob: | ||
return f"«{blob}»" | ||
|
||
modified_blob = blob.replace("\n", "\n ") | ||
return f"«««\n {modified_blob}\n»»»" | ||
|
||
|
||
def format_input_list_field_value(value: List[Any]) -> str: | ||
""" | ||
Formats the value of an input field of type List[Any]. | ||
|
||
Args: | ||
value: The value of the list-type input field. | ||
Returns: | ||
A string representation of the input field's list value. | ||
""" | ||
if len(value) == 0: | ||
return "N/A" | ||
if len(value) == 1: | ||
return format_blob(value[0]) | ||
|
||
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)]) | ||
|
||
|
||
def _serialize_for_json(value): | ||
if isinstance(value, pydantic.BaseModel): | ||
return value.model_dump() | ||
elif isinstance(value, list): | ||
return [_serialize_for_json(item) for item in value] | ||
elif isinstance(value, dict): | ||
return {key: _serialize_for_json(val) for key, val in value.items()} | ||
else: | ||
return value | ||
Comment on lines
-143
to
-151
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The core change in this PR is the replacement of this method with the new version, `serialize_for_json`, in `dspy/adapters/utils.py`.
This reuses pydantic's JSON serialization, ensuring we don't miss certain types. |
||
|
||
|
||
def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]: | ||
""" | ||
Formats the value of the specified field according to the field's DSPy type (input or output), | ||
annotation (e.g. str, int, etc.), and the type of the value itself. | ||
|
||
Args: | ||
field_info: Information about the field, including its DSPy field type and annotation. | ||
value: The value of the field. | ||
Returns: | ||
The formatted value of the field, represented as a string. | ||
""" | ||
string_value = None | ||
if isinstance(value, list) and field_info.annotation is str: | ||
# If the field has no special type requirements, format it as a nice numbered list for the LM. | ||
string_value = format_input_list_field_value(value) | ||
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list): | ||
string_value = json.dumps(_serialize_for_json(value), ensure_ascii=False) | ||
else: | ||
string_value = str(value) | ||
|
||
if assume_text: | ||
return string_value | ||
elif isinstance(value, Image) or field_info.annotation == Image: | ||
# This validation should happen somewhere else | ||
# Safe to import PIL here because it's only imported when an image is actually being formatted | ||
try: | ||
import PIL | ||
except ImportError: | ||
raise ImportError("PIL is required to format images; Run `pip install pillow` to install it.") | ||
image_value = value | ||
if not isinstance(image_value, Image): | ||
if isinstance(image_value, dict) and "url" in image_value: | ||
image_value = image_value["url"] | ||
elif isinstance(image_value, str): | ||
image_value = encode_image(image_value) | ||
elif isinstance(image_value, PIL.Image.Image): | ||
image_value = encode_image(image_value) | ||
assert isinstance(image_value, str) | ||
image_value = Image(url=image_value) | ||
return {"type": "image_url", "image_url": image_value.model_dump()} | ||
else: | ||
return {"type": "text", "text": string_value} | ||
|
||
|
||
def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]: | ||
""" | ||
Formats the values of the specified fields according to the field's DSPy type (input or output), | ||
|
@@ -209,7 +129,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= | |
""" | ||
output = [] | ||
for field, field_value in fields_with_values.items(): | ||
formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text) | ||
formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text) | ||
if assume_text: | ||
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}") | ||
else: | ||
|
@@ -231,7 +151,7 @@ def parse_value(value, annotation): | |
parsed_value = value | ||
|
||
if isinstance(annotation, enum.EnumMeta): | ||
parsed_value = annotation[value] | ||
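There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This enum handling logic was incorrect. Given an enum whose member names differ from their values, `annotation[value]` looks a member up by its name.
Serializing this enum to JSON with Pydantic produces the member's value rather than its name, so that name-based lookup fails on the serialized form.
I've added test coverage to confirm that the new behavior is correct (see the tests added in this PR). |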
||
return find_enum_member(annotation, value) | ||
elif isinstance(value, str): | ||
try: | ||
parsed_value = json.loads(value) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,41 @@ | ||
import ast | ||
import json | ||
import enum | ||
import inspect | ||
import litellm | ||
import pydantic | ||
import json | ||
import textwrap | ||
import json_repair | ||
|
||
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin | ||
|
||
import json_repair | ||
import litellm | ||
import pydantic | ||
from pydantic import TypeAdapter | ||
from pydantic.fields import FieldInfo | ||
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin | ||
|
||
Comment on lines
+4
to
13
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is all just the linter reordering imports |
||
from dspy.adapters.base import Adapter | ||
from dspy.adapters.utils import find_enum_member, format_field_value, serialize_for_json | ||
|
||
from ..adapters.image_utils import Image | ||
from ..signatures.signature import SignatureMeta | ||
from ..signatures.utils import get_dspy_field_type | ||
|
||
|
||
class FieldInfoWithName(NamedTuple): | ||
name: str | ||
info: FieldInfo | ||
|
||
|
||
class JSONAdapter(Adapter): | ||
def __init__(self): | ||
pass | ||
|
||
def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): | ||
inputs = self.format(signature, demos, inputs) | ||
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs) | ||
|
||
|
||
|
||
try: | ||
provider = lm.model.split('/', 1)[0] or "openai" | ||
if 'response_format' in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider): | ||
outputs = lm(**inputs, **lm_kwargs, response_format={ "type": "json_object" }) | ||
Comment on lines
-31
to
-36
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just the linter - no material change here |
||
provider = lm.model.split("/", 1)[0] or "openai" | ||
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider): | ||
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"}) | ||
else: | ||
outputs = lm(**inputs, **lm_kwargs) | ||
|
||
|
@@ -44,11 +46,12 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): | |
|
||
for output in outputs: | ||
value = self.parse(signature, output, _parse_values=_parse_values) | ||
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}" | ||
assert set(value.keys()) == set( | ||
signature.output_fields.keys() | ||
), f"Expected {signature.output_fields.keys()} but got {value.keys()}" | ||
Comment on lines
+49
to
+51
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just the linter - no material change here |
||
values.append(value) | ||
|
||
return values | ||
|
||
return values | ||
|
||
def format(self, signature, demos, inputs): | ||
messages = [] | ||
|
@@ -71,7 +74,7 @@ def format(self, signature, demos, inputs): | |
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos)) | ||
|
||
messages.append(format_turn(signature, inputs, role="user")) | ||
|
||
return messages | ||
|
||
def parse(self, signature, completion, _parse_values=True): | ||
|
@@ -90,7 +93,7 @@ def parse(self, signature, completion, _parse_values=True): | |
|
||
def format_turn(self, signature, values, role, incomplete=False): | ||
return format_turn(signature, values, role, incomplete) | ||
|
||
def format_fields(self, signature, values, role): | ||
fields_with_values = { | ||
FieldInfoWithName(name=field_name, info=field_info): values.get( | ||
|
@@ -101,16 +104,16 @@ def format_fields(self, signature, values, role): | |
} | ||
|
||
return format_fields(role=role, fields_with_values=fields_with_values) | ||
|
||
|
||
def parse_value(value, annotation): | ||
if annotation is str: | ||
return str(value) | ||
|
||
parsed_value = value | ||
|
||
if isinstance(annotation, enum.EnumMeta): | ||
parsed_value = annotation[value] | ||
parsed_value = find_enum_member(annotation, value) | ||
elif isinstance(value, str): | ||
try: | ||
parsed_value = json.loads(value) | ||
|
@@ -119,45 +122,10 @@ def parse_value(value, annotation): | |
parsed_value = ast.literal_eval(value) | ||
except (ValueError, SyntaxError): | ||
parsed_value = value | ||
|
||
return TypeAdapter(annotation).validate_python(parsed_value) | ||
|
||
|
||
def format_blob(blob): | ||
if "\n" not in blob and "«" not in blob and "»" not in blob: | ||
return f"«{blob}»" | ||
|
||
modified_blob = blob.replace("\n", "\n ") | ||
return f"«««\n {modified_blob}\n»»»" | ||
|
||
|
||
def format_input_list_field_value(value: List[Any]) -> str: | ||
""" | ||
Formats the value of an input field of type List[Any]. | ||
Args: | ||
value: The value of the list-type input field. | ||
Returns: | ||
A string representation of the input field's list value. | ||
""" | ||
if len(value) == 0: | ||
return "N/A" | ||
if len(value) == 1: | ||
return format_blob(value[0]) | ||
|
||
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)]) | ||
return TypeAdapter(annotation).validate_python(parsed_value) | ||
|
||
|
||
def _serialize_for_json(value): | ||
if isinstance(value, pydantic.BaseModel): | ||
return value.model_dump() | ||
elif isinstance(value, list): | ||
return [_serialize_for_json(item) for item in value] | ||
elif isinstance(value, dict): | ||
return {key: _serialize_for_json(val) for key, val in value.items()} | ||
else: | ||
return value | ||
|
||
Comment on lines
-126
to
-160
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Deduplicate code with chat_adapter by moving these into a shared `adapters/utils.py` module |
||
def _format_field_value(field_info: FieldInfo, value: Any) -> str: | ||
""" | ||
Formats the value of the specified field according to the field's DSPy type (input or output), | ||
|
@@ -169,17 +137,10 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str: | |
Returns: | ||
The formatted value of the field, represented as a string. | ||
""" | ||
|
||
if isinstance(value, list) and field_info.annotation is str: | ||
# If the field has no special type requirements, format it as a nice numbere list for the LM. | ||
return format_input_list_field_value(value) | ||
if field_info.annotation is Image: | ||
raise NotImplementedError("Images are not yet supported in JSON mode.") | ||
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list): | ||
return json.dumps(_serialize_for_json(value)) | ||
else: | ||
return str(value) | ||
|
||
return format_field_value(field_info=field_info, value=value, assume_text=True) | ||
|
||
|
||
def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str: | ||
|
@@ -197,9 +158,8 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) - | |
|
||
if role == "assistant": | ||
d = fields_with_values.items() | ||
d = {k.name: _serialize_for_json(v) for k, v in d} | ||
|
||
return json.dumps(_serialize_for_json(d), indent=2) | ||
d = {k.name: v for k, v in d} | ||
return json.dumps(serialize_for_json(d), indent=2) | ||
|
||
output = [] | ||
for field, field_value in fields_with_values.items(): | ||
|
@@ -246,15 +206,19 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple | |
field_name, "Not supplied for this particular example." | ||
) | ||
for field_name, field_info in fields.items() | ||
} | ||
}, | ||
) | ||
content.append(formatted_fields) | ||
|
||
if role == "user": | ||
|
||
def type_info(v): | ||
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \ | ||
if v.annotation is not str else "" | ||
|
||
return ( | ||
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" | ||
if v.annotation is not str | ||
else "" | ||
) | ||
Comment on lines
+216
to
+220
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This (and everything else below it in this file) is just a linter change |
||
|
||
# TODO: Consider if not incomplete: | ||
content.append( | ||
"Respond with a JSON object in the following order of fields: " | ||
|
@@ -297,15 +261,15 @@ def prepare_instructions(signature: SignatureMeta): | |
def field_metadata(field_name, field_info): | ||
type_ = field_info.annotation | ||
|
||
if get_dspy_field_type(field_info) == 'input' or type_ is str: | ||
if get_dspy_field_type(field_info) == "input" or type_ is str: | ||
desc = "" | ||
elif type_ is bool: | ||
desc = "must be True or False" | ||
elif type_ in (int, float): | ||
desc = f"must be a single {type_.__name__} value" | ||
elif inspect.isclass(type_) and issubclass(type_, enum.Enum): | ||
desc= f"must be one of: {'; '.join(type_.__members__)}" | ||
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal: | ||
desc = f"must be one of: {'; '.join(type_.__members__)}" | ||
elif hasattr(type_, "__origin__") and type_.__origin__ is Literal: | ||
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}" | ||
else: | ||
desc = "must be pareseable according to the following JSON schema: " | ||
|
@@ -320,13 +284,13 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo]) | |
fields_with_values={ | ||
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info) | ||
for field_name, field_info in fields.items() | ||
} | ||
}, | ||
) | ||
|
||
parts.append("Inputs will have the following structure:") | ||
parts.append(format_signature_fields_for_instructions('user', signature.input_fields)) | ||
parts.append(format_signature_fields_for_instructions("user", signature.input_fields)) | ||
parts.append("Outputs will be a JSON object with the following fields.") | ||
parts.append(format_signature_fields_for_instructions('assistant', signature.output_fields)) | ||
parts.append(format_signature_fields_for_instructions("assistant", signature.output_fields)) | ||
# parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""})) | ||
|
||
instructions = textwrap.dedent(signature.instructions) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved these utilities into a shared `adapters/utils.py` module, since they're identical but repeated across JSONAdapter and ChatAdapter. (Except that JSONAdapter doesn't support images yet and throws an "unsupported" exception. This behavior is preserved by wrapping `utils.format_field_value()` with an `if field_info.annotation is Image: throw` block.)