Skip to content

Commit

Permalink
Added support for more types (#1900)
Browse files Browse the repository at this point in the history
* Added support for more types
  • Loading branch information
thomasahle authored and isaacbmiller committed Dec 11, 2024
1 parent 0cd8494 commit 6064015
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 32 deletions.
116 changes: 85 additions & 31 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from contextlib import ExitStack, contextmanager
from copy import deepcopy
from typing import Any, Dict, Tuple, Type, Union # noqa: UP035
import importlib

from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo

import dsp
from dspy.adapters.image_utils import Image
from dspy.adapters.image_utils import Image # noqa: F401
from dspy.signatures.field import InputField, OutputField, new_to_old_field


Expand Down Expand Up @@ -355,7 +356,7 @@ def make_signature(
if type_ is None:
type_ = str
# if not isinstance(type_, type) and not isinstance(typing.get_origin(type_), type):
if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias)):
if not isinstance(type_, (type, typing._GenericAlias, types.GenericAlias, typing._SpecialForm)):
raise ValueError(f"Field types must be types, not {type(type_)}")
if not isinstance(field, FieldInfo):
raise ValueError(f"Field values must be Field instances, not {type(field)}")
Expand Down Expand Up @@ -400,53 +401,106 @@ def _parse_arg_string(string: str, names=None) -> Dict[str, str]:


def _parse_type_node(node, names=None) -> Any:
"""Recursively parse an AST node representing a type annotation.
without using structural pattern matching introduced in Python 3.10.
"""
"""Recursively parse an AST node representing a type annotation."""

if names is None:
names = typing.__dict__
names = dict(typing.__dict__)
names['NoneType'] = type(None)

def resolve_name(id_: str):
# Check if it's a built-in known type or in the provided names
if id_ in names:
return names[id_]

# Common built-in types
builtin_types = [int, str, float, bool, list, tuple, dict, set, frozenset, complex, bytes, bytearray]

# Try PIL Image if 'Image' encountered
if 'Image' not in names:
try:
from PIL import Image
names['Image'] = Image
except ImportError:
pass

# If we have PIL Image and id_ is 'Image', return it
if 'Image' in names and id_ == 'Image':
return names['Image']

# Check if it matches any known built-in type by name
for t in builtin_types:
if t.__name__ == id_:
return t

# Attempt to import a module with this name dynamically
# This allows handling of module-based annotations like `dspy.Image`.
try:
mod = importlib.import_module(id_)
names[id_] = mod
return mod
except ImportError:
pass

# If we don't know the type or module, raise an error
raise ValueError(f"Unknown name: {id_}")

if isinstance(node, ast.Module):
body = node.body
if len(body) != 1:
raise ValueError(f"Code is not syntactically valid: {node}")
return _parse_type_node(body[0], names)
if len(node.body) != 1:
raise ValueError(f"Code is not syntactically valid: {ast.dump(node)}")
return _parse_type_node(node.body[0], names)

if isinstance(node, ast.Expr):
value = node.value
return _parse_type_node(value, names)
return _parse_type_node(node.value, names)

if isinstance(node, ast.Name):
id_ = node.id
if id_ in names:
return names[id_]
return resolve_name(node.id)

for type_ in [int, str, float, bool, list, tuple, dict, Image]:
if type_.__name__ == id_:
return type_
raise ValueError(f"Unknown name: {id_}")
if isinstance(node, ast.Attribute):
base = _parse_type_node(node.value, names)
attr_name = node.attr
if hasattr(base, attr_name):
return getattr(base, attr_name)
else:
raise ValueError(f"Unknown attribute: {attr_name} on {base}")

if isinstance(node, ast.Subscript):
base_type = _parse_type_node(node.value, names)
arg_type = _parse_type_node(node.slice, names)
return base_type[arg_type]
slice_node = node.slice
if isinstance(slice_node, ast.Index): # For older Python versions
slice_node = slice_node.value

if isinstance(slice_node, ast.Tuple):
arg_types = tuple(_parse_type_node(elt, names) for elt in slice_node.elts)
else:
arg_types = (_parse_type_node(slice_node, names),)

# Special handling for Union, Optional
if base_type is typing.Union:
return typing.Union[arg_types]
if base_type is typing.Optional:
if len(arg_types) != 1:
raise ValueError("Optional must have exactly one type argument")
return typing.Optional[arg_types[0]]

return base_type[arg_types]

if isinstance(node, ast.Tuple):
elts = node.elts
return tuple(_parse_type_node(elt, names) for elt in elts)
return tuple(_parse_type_node(elt, names) for elt in node.elts)

if isinstance(node, ast.Constant):
return node.value

if isinstance(node, ast.Call) and node.func.id == "Field":
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "Field":
keys = [kw.arg for kw in node.keywords]
values = [kw.value.value for kw in node.keywords]
values = []
for kw in node.keywords:
if isinstance(kw.value, ast.Constant):
values.append(kw.value.value)
else:
values.append(_parse_type_node(kw.value, names))
return Field(**dict(zip(keys, values)))

if isinstance(node, ast.Attribute) and node.attr == "Image":
return Image

raise ValueError(f"Code is not syntactically valid: {node}")

raise ValueError(f"Unhandled AST node type in annotation: {ast.dump(node)}")

def infer_prefix(attribute_name: str) -> str:
"""Infer a prefix from an attribute name."""
Expand Down
124 changes: 123 additions & 1 deletion tests/signatures/test_signature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import textwrap
from typing import List
from typing import Any, Dict, List, Optional, Tuple, Union

import pydantic
import pytest
Expand Down Expand Up @@ -279,3 +279,125 @@ class CustomSignature2(dspy.Signature):
assert CustomSignature2.instructions == "I am a malicious instruction."
assert CustomSignature2.fields["sentence"].json_schema_extra["desc"] == "I am an malicious input!"
assert CustomSignature2.fields["sentiment"].json_schema_extra["prefix"] == "Sentiment:"


def test_typed_signatures_basic_types():
# Simple built-in types
sig = Signature("input1: int, input2: str -> output: float")
assert "input1" in sig.input_fields
assert sig.input_fields["input1"].annotation == int
assert "input2" in sig.input_fields
assert sig.input_fields["input2"].annotation == str
assert "output" in sig.output_fields
assert sig.output_fields["output"].annotation == float


def test_typed_signatures_generics():
# More complex generic types
sig = Signature("input_list: List[int], input_dict: Dict[str, float] -> output_tuple: Tuple[str, int]")
assert "input_list" in sig.input_fields
assert sig.input_fields["input_list"].annotation == List[int]
assert "input_dict" in sig.input_fields
assert sig.input_fields["input_dict"].annotation == Dict[str, float]
assert "output_tuple" in sig.output_fields
assert sig.output_fields["output_tuple"].annotation == Tuple[str, int]


def test_typed_signatures_unions_and_optionals():
sig = Signature("input_opt: Optional[str], input_union: Union[int, None] -> output_union: Union[int, str]")
assert "input_opt" in sig.input_fields
# Optional[str] is actually Union[str, None]
# Depending on the environment, it might resolve to Union[str, None] or Optional[str], either is correct.
# We'll just check for a Union containing str and NoneType:
input_opt_annotation = sig.input_fields["input_opt"].annotation
assert (input_opt_annotation == Optional[str] or
(getattr(input_opt_annotation, '__origin__', None) is Union and str in input_opt_annotation.__args__ and type(None) in input_opt_annotation.__args__))

assert "input_union" in sig.input_fields
input_union_annotation = sig.input_fields["input_union"].annotation
assert (getattr(input_union_annotation, '__origin__', None) is Union and
int in input_union_annotation.__args__ and type(None) in input_union_annotation.__args__)

assert "output_union" in sig.output_fields
output_union_annotation = sig.output_fields["output_union"].annotation
assert (getattr(output_union_annotation, '__origin__', None) is Union and
int in output_union_annotation.__args__ and str in output_union_annotation.__args__)


def test_typed_signatures_any():
sig = Signature("input_any: Any -> output_any: Any")
assert "input_any" in sig.input_fields
assert sig.input_fields["input_any"].annotation == Any
assert "output_any" in sig.output_fields
assert sig.output_fields["output_any"].annotation == Any


def test_typed_signatures_nested():
# Nested generics and unions
sig = Signature("input_nested: List[Union[str, int]] -> output_nested: Tuple[int, Optional[float], List[str]]")
input_nested_ann = sig.input_fields["input_nested"].annotation
assert getattr(input_nested_ann, '__origin__', None) is list
assert len(input_nested_ann.__args__) == 1
union_arg = input_nested_ann.__args__[0]
assert getattr(union_arg, '__origin__', None) is Union
assert str in union_arg.__args__ and int in union_arg.__args__

output_nested_ann = sig.output_fields["output_nested"].annotation
assert getattr(output_nested_ann, '__origin__', None) is tuple
assert output_nested_ann.__args__[0] == int
# The second arg is Optional[float], which is Union[float, None]
second_arg = output_nested_ann.__args__[1]
assert getattr(second_arg, '__origin__', None) is Union
assert float in second_arg.__args__ and type(None) in second_arg.__args__
# The third arg is List[str]
third_arg = output_nested_ann.__args__[2]
assert getattr(third_arg, '__origin__', None) is list
assert third_arg.__args__[0] == str


def test_typed_signatures_from_dict():
# Creating a Signature directly from a dictionary with types
fields = {
"input_str_list": (List[str], InputField()),
"input_dict_int": (Dict[str, int], InputField()),
"output_tup": (Tuple[int, float], OutputField()),
}
sig = Signature(fields)
assert "input_str_list" in sig.input_fields
assert sig.input_fields["input_str_list"].annotation == List[str]
assert "input_dict_int" in sig.input_fields
assert sig.input_fields["input_dict_int"].annotation == Dict[str, int]
assert "output_tup" in sig.output_fields
assert sig.output_fields["output_tup"].annotation == Tuple[int, float]


def test_typed_signatures_complex_combinations():
# Test a very complex signature with multiple nested constructs
# input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]
sig = Signature("input_complex: Dict[str, List[Optional[Tuple[int, str]]]] -> output_complex: Union[List[str], Dict[str, Any]]")
input_complex_ann = sig.input_fields["input_complex"].annotation
assert getattr(input_complex_ann, '__origin__', None) is dict
key_arg, value_arg = input_complex_ann.__args__
assert key_arg == str
# value_arg: List[Optional[Tuple[int, str]]]
assert getattr(value_arg, '__origin__', None) is list
inner_union = value_arg.__args__[0]
# inner_union should be Optional[Tuple[int, str]]
# which is Union[Tuple[int, str], None]
assert getattr(inner_union, '__origin__', None) is Union
tuple_type = [t for t in inner_union.__args__ if t != type(None)][0]
assert getattr(tuple_type, '__origin__', None) is tuple
assert tuple_type.__args__ == (int, str)

output_complex_ann = sig.output_fields["output_complex"].annotation
assert getattr(output_complex_ann, '__origin__', None) is Union
assert len(output_complex_ann.__args__) == 2
possible_args = set(output_complex_ann.__args__)
# Expecting List[str] and Dict[str, Any]
# Because sets don't preserve order, just check membership.
# Find the List[str] arg
list_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is list)
dict_arg = next(a for a in possible_args if getattr(a, '__origin__', None) is dict)
assert list_arg.__args__ == (str,)
k, v = dict_arg.__args__
assert k == str and v == Any

0 comments on commit 6064015

Please sign in to comment.