Skip to content

Commit

Permalink
Allow react.Tool to wrap methods
Browse files Browse the repository at this point in the history
The big reason for this is to pass parameters out-of-band, e.g. a
user_id to ensure the LLM doesn't get the wrong data.

The unit test includes a usage, you can't use it as a decorator this
way, but it works.

The alternative, of course, is to have a very long function and have all
the tools be nested functions. It works, but can lead to some very long
functions. I prefer long classes over long functions.
  • Loading branch information
tkellogg committed Nov 25, 2024
1 parent 2aa6f01 commit 2629084
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class Tool:
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):
annotations_func = func if inspect.isfunction(func) else func.__call__
annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
self.func = func
self.name = name or getattr(func, '__name__', type(func).__name__)
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "")
Expand Down
27 changes: 26 additions & 1 deletion tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dspy
from dspy.utils.dummies import DummyLM, dummy_rm
from dspy.predict import react


# def test_example_no_tools():
Expand Down Expand Up @@ -121,4 +122,28 @@
# react = dspy.ReAct(ExampleSignature)

# assert react.react[0].signature.instructions is not None
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")
# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.")

def test_tool_from_function():
def foo(a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = react.Tool(foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": "int", "b": "int"}

def test_tool_from_class():
class Foo:
def __init__(self, user_id: str):
self.user_id = user_id

def foo(self, a: int, b: int) -> int:
"""Add two numbers."""
return a + b

tool = react.Tool(Foo("123").foo)
assert tool.name == "foo"
assert tool.desc == "Add two numbers."
assert tool.args == {"a": "int", "b": "int"}

0 comments on commit 2629084

Please sign in to comment.