Skip to content

Commit

Permalink
Adapters: Support JSON serialization of all pydantic types (e.g. date…
Browse files Browse the repository at this point in the history
…times, enums, etc.) (#1853)

* Add

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

---------

Signed-off-by: dbczumar <[email protected]>
  • Loading branch information
dbczumar authored and isaacbmiller committed Dec 11, 2024
1 parent 4699adf commit 6ec4a76
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 183 deletions.
99 changes: 3 additions & 96 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from dsp.adapters.base_template import Field
from dspy.adapters.base import Adapter
from dspy.adapters.image_utils import Image, encode_image
from dspy.adapters.utils import find_enum_member, format_field_value
from dspy.signatures.field import OutputField
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.signatures.utils import get_dspy_field_type
Expand Down Expand Up @@ -114,99 +114,6 @@ def format_fields(self, signature, values, role):
return format_fields(fields_with_values)


def format_blob(blob):
    """
    Wraps a value in guillemet delimiters so it reads as one literal blob of text.

    Args:
        blob: The text to wrap. Non-string values are converted with ``str()`` first.

    Returns:
        ``«blob»`` when the text has no newline and no guillemet characters; otherwise a
        block fenced by ``«««`` and ``»»»`` lines in which every line of the text is
        prefixed with a space.
    """
    # Callers hand over arbitrary list items (ints, dicts, ...); the substring checks and
    # ``.replace`` below are only defined for strings, so coerce up front instead of
    # failing with a TypeError.
    text = blob if isinstance(blob, str) else str(blob)

    if not any(marker in text for marker in ("\n", "«", "»")):
        return f"«{text}»"

    indented = text.replace("\n", "\n ")
    return f"«««\n {indented}\n»»»"


def format_input_list_field_value(value: List[Any]) -> str:
    """
    Renders a list-valued input field as text.

    Args:
        value: The list held by the input field.

    Returns:
        ``"N/A"`` for an empty list, the single item passed through ``format_blob`` for a
        one-element list, and otherwise one ``[i] <blob>`` line per item, numbered from 1
        and joined with newlines.
    """
    count = len(value)
    if count == 0:
        return "N/A"
    if count == 1:
        return format_blob(value[0])

    numbered_lines = []
    for position, item in enumerate(value, start=1):
        numbered_lines.append(f"[{position}] {format_blob(item)}")
    return "\n".join(numbered_lines)


def _serialize_for_json(value):
    """
    Recursively converts a value into a structure that ``json.dumps`` can encode.

    Args:
        value: A pydantic model, a list, a dict, or any other value.

    Returns:
        A pydantic model is dumped to a plain dict in JSON mode; a list or dict is rebuilt
        with each element / dict value converted recursively; any other value is returned
        unchanged.
    """
    if isinstance(value, pydantic.BaseModel):
        # mode="json" makes pydantic emit JSON-compatible primitives for fields such as
        # datetimes and enums; the default python mode returns them as objects that
        # json.dumps cannot encode.
        return value.model_dump(mode="json")
    if isinstance(value, list):
        return [_serialize_for_json(item) for item in value]
    if isinstance(value, dict):
        return {key: _serialize_for_json(val) for key, val in value.items()}
    return value


def _format_field_value(field_info: FieldInfo, value: Any, assume_text=True) -> Union[str, dict]:
    """
    Formats the value of the specified field according to the field's annotation
    (e.g. str, Image, etc.) and the type of the value itself.

    Args:
        field_info: Information about the field, including its annotation.
        value: The value of the field.
        assume_text: If True, return the formatted value as a plain string. If False, wrap
            it in a ``{"type": "text", "text": ...}`` content dictionary.

    Returns:
        The formatted value as a string, or as a text content dictionary when
        ``assume_text`` is False.
    """
    if field_info.annotation == Image and is_image(value):
        # Wrap the encoded image in the Image model so it is serialized through the same
        # pydantic branch below as any other model.
        value = Image(url=encode_image(value))

    if isinstance(value, list) and field_info.annotation is str:
        # If the field has no special type requirements, format it as a nice numbered list for the LM.
        string_value = format_input_list_field_value(value)
    elif isinstance(value, (pydantic.BaseModel, dict, list)):
        string_value = json.dumps(_serialize_for_json(value), ensure_ascii=False)
    else:
        string_value = str(value)

    if assume_text:
        return string_value

    # NOTE: What we actually want is that any image found inside an arbitrary python or
    # pydantic object triggers some sort of escape sequence that is combined at the end into
    # one cohesive request to send to OAI. Hooking too deep into the serialization process
    # is a bad idea, but an escape hatch is needed somewhere.
    return {"type": "text", "text": string_value}


def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
Expand All @@ -221,7 +128,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
"""
output = []
for field, field_value in fields_with_values.items():
formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
formatted_field_value = format_field_value(field_info=field.info, value=field_value, assume_text=assume_text)
if assume_text:
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")
else:
Expand All @@ -243,7 +150,7 @@ def parse_value(value, annotation):
parsed_value = value

if isinstance(annotation, enum.EnumMeta):
parsed_value = annotation[value]
return find_enum_member(annotation, value)
elif isinstance(value, str):
try:
parsed_value = json.loads(value)
Expand Down
118 changes: 41 additions & 77 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,41 @@
import ast
import json
import enum
import inspect
import litellm
import pydantic
import json
import textwrap
import json_repair

from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin

import json_repair
import litellm
import pydantic
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

from dspy.adapters.base import Adapter
from dspy.adapters.utils import find_enum_member, format_field_value, serialize_for_json

from ..adapters.image_utils import Image
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type


class FieldInfoWithName(NamedTuple):
    """Pairs a field's name with its pydantic ``FieldInfo`` so both travel together as one key."""

    name: str  # the field's name
    info: FieldInfo  # the field's pydantic FieldInfo


class JSONAdapter(Adapter):
def __init__(self):
pass

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)



try:
provider = lm.model.split('/', 1)[0] or "openai"
if 'response_format' in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
outputs = lm(**inputs, **lm_kwargs, response_format={ "type": "json_object" })
provider = lm.model.split("/", 1)[0] or "openai"
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
else:
outputs = lm(**inputs, **lm_kwargs)

Expand All @@ -44,11 +46,12 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):

for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
assert set(value.keys()) == set(
signature.output_fields.keys()
), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
values.append(value)

return values

return values

def format(self, signature, demos, inputs):
messages = []
Expand All @@ -71,7 +74,7 @@ def format(self, signature, demos, inputs):
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))

messages.append(format_turn(signature, inputs, role="user"))

return messages

def parse(self, signature, completion, _parse_values=True):
Expand All @@ -90,7 +93,7 @@ def parse(self, signature, completion, _parse_values=True):

def format_turn(self, signature, values, role, incomplete=False):
return format_turn(signature, values, role, incomplete)

def format_fields(self, signature, values, role):
fields_with_values = {
FieldInfoWithName(name=field_name, info=field_info): values.get(
Expand All @@ -101,16 +104,16 @@ def format_fields(self, signature, values, role):
}

return format_fields(role=role, fields_with_values=fields_with_values)


def parse_value(value, annotation):
if annotation is str:
return str(value)

parsed_value = value

if isinstance(annotation, enum.EnumMeta):
parsed_value = annotation[value]
parsed_value = find_enum_member(annotation, value)
elif isinstance(value, str):
try:
parsed_value = json.loads(value)
Expand All @@ -119,45 +122,10 @@ def parse_value(value, annotation):
parsed_value = ast.literal_eval(value)
except (ValueError, SyntaxError):
parsed_value = value

return TypeAdapter(annotation).validate_python(parsed_value)


def format_blob(blob):
    """
    Wraps a value in guillemet delimiters so it reads as one literal blob of text.

    Args:
        blob: The text to wrap. Non-string values are converted with ``str()`` first.

    Returns:
        ``«blob»`` when the text has no newline and no guillemet characters; otherwise a
        block fenced by ``«««`` and ``»»»`` lines in which every line of the text is
        prefixed with a space.
    """
    # Callers hand over arbitrary list items (ints, dicts, ...); the substring checks and
    # ``.replace`` below are only defined for strings, so coerce up front instead of
    # failing with a TypeError.
    text = blob if isinstance(blob, str) else str(blob)

    if not any(marker in text for marker in ("\n", "«", "»")):
        return f"«{text}»"

    indented = text.replace("\n", "\n ")
    return f"«««\n {indented}\n»»»"


def format_input_list_field_value(value: List[Any]) -> str:
    """
    Renders a list-valued input field as text.

    Args:
        value: The list held by the input field.

    Returns:
        ``"N/A"`` for an empty list, the single item passed through ``format_blob`` for a
        one-element list, and otherwise one ``[i] <blob>`` line per item, numbered from 1
        and joined with newlines.
    """
    count = len(value)
    if count == 0:
        return "N/A"
    if count == 1:
        return format_blob(value[0])

    numbered_lines = []
    for position, item in enumerate(value, start=1):
        numbered_lines.append(f"[{position}] {format_blob(item)}")
    return "\n".join(numbered_lines)
return TypeAdapter(annotation).validate_python(parsed_value)


def _serialize_for_json(value):
    """
    Recursively converts a value into a structure that ``json.dumps`` can encode.

    Args:
        value: A pydantic model, a list, a dict, or any other value.

    Returns:
        A pydantic model is dumped to a plain dict in JSON mode; a list or dict is rebuilt
        with each element / dict value converted recursively; any other value is returned
        unchanged.
    """
    if isinstance(value, pydantic.BaseModel):
        # mode="json" makes pydantic emit JSON-compatible primitives for fields such as
        # datetimes and enums; the default python mode returns them as objects that
        # json.dumps cannot encode.
        return value.model_dump(mode="json")
    if isinstance(value, list):
        return [_serialize_for_json(item) for item in value]
    if isinstance(value, dict):
        return {key: _serialize_for_json(val) for key, val in value.items()}
    return value

def _format_field_value(field_info: FieldInfo, value: Any) -> str:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
Expand All @@ -169,17 +137,10 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
Returns:
The formatted value of the field, represented as a string.
"""

if isinstance(value, list) and field_info.annotation is str:
# If the field has no special type requirements, format it as a nice numbered list for the LM.
return format_input_list_field_value(value)
if field_info.annotation is Image:
raise NotImplementedError("Images are not yet supported in JSON mode.")
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
return json.dumps(_serialize_for_json(value))
else:
return str(value)

return format_field_value(field_info=field_info, value=value, assume_text=True)


def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
Expand All @@ -197,9 +158,8 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -

if role == "assistant":
d = fields_with_values.items()
d = {k.name: _serialize_for_json(v) for k, v in d}

return json.dumps(_serialize_for_json(d), indent=2)
d = {k.name: v for k, v in d}
return json.dumps(serialize_for_json(d), indent=2)

output = []
for field, field_value in fields_with_values.items():
Expand Down Expand Up @@ -246,15 +206,19 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
field_name, "Not supplied for this particular example."
)
for field_name, field_info in fields.items()
}
},
)
content.append(formatted_fields)

if role == "user":

def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""

return (
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
if v.annotation is not str
else ""
)

# TODO: Consider if not incomplete:
content.append(
"Respond with a JSON object in the following order of fields: "
Expand Down Expand Up @@ -297,15 +261,15 @@ def prepare_instructions(signature: SignatureMeta):
def field_metadata(field_name, field_info):
type_ = field_info.annotation

if get_dspy_field_type(field_info) == 'input' or type_ is str:
if get_dspy_field_type(field_info) == "input" or type_ is str:
desc = ""
elif type_ is bool:
desc = "must be True or False"
elif type_ in (int, float):
desc = f"must be a single {type_.__name__} value"
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
desc= f"must be one of: {'; '.join(type_.__members__)}"
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal:
desc = f"must be one of: {'; '.join(type_.__members__)}"
elif hasattr(type_, "__origin__") and type_.__origin__ is Literal:
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
else:
desc = "must be pareseable according to the following JSON schema: "
Expand All @@ -320,13 +284,13 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
for field_name, field_info in fields.items()
}
},
)

parts.append("Inputs will have the following structure:")
parts.append(format_signature_fields_for_instructions('user', signature.input_fields))
parts.append(format_signature_fields_for_instructions("user", signature.input_fields))
parts.append("Outputs will be a JSON object with the following fields.")
parts.append(format_signature_fields_for_instructions('assistant', signature.output_fields))
parts.append(format_signature_fields_for_instructions("assistant", signature.output_fields))
# parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))

instructions = textwrap.dedent(signature.instructions)
Expand Down
Loading

0 comments on commit 6ec4a76

Please sign in to comment.