Skip to content

Commit

Permalink
Add on_tool_start/end callbacks (#1879)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <[email protected]>
  • Loading branch information
B-Step62 authored and isaacbmiller committed Dec 11, 2024
1 parent 8c93942 commit e9e728f
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
2 changes: 1 addition & 1 deletion dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
from .multi_chain_comparison import MultiChainComparison
from .predict import Predict
from .program_of_thought import ProgramOfThought
from .react import ReAct
from .react import ReAct, Tool
from .retry import Retry
from .parallel import Parallel
2 changes: 2 additions & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature
from dspy.adapters.json_adapter import get_annotation_name
from dspy.utils.callback import with_callbacks
from typing import Callable, Any, get_type_hints, get_origin, Literal

class Tool:
Expand All @@ -19,6 +20,7 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic
for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return'
}

@with_callbacks
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

Expand Down
39 changes: 39 additions & 0 deletions dspy/utils/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,38 @@ def on_adapter_parse_end(
"""
pass

def on_tool_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
):
"""A handler triggered when a tool is called.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
instance: The Tool instance.
inputs: The inputs to the Tool's __call__ method. Each arguments is stored as
a key-value pair in a dictionary.
"""
pass

def on_tool_end(
self,
call_id: str,
outputs: Optional[Dict[str, Any]],
exception: Optional[Exception] = None,
):
"""A handler triggered after a tool is executed.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
outputs: The outputs of the Tool's __call__ method. If the method is interrupted by
an exception, this will be None.
exception: If an exception is raised during the execution, it will be stored here.
"""
pass


def with_callbacks(fn):
@functools.wraps(fn)
Expand Down Expand Up @@ -256,6 +288,9 @@ def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -
else:
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")

if isinstance(instance, dspy.Tool):
return callback.on_tool_start

# We treat everything else as a module.
return callback.on_module_start

Expand All @@ -272,5 +307,9 @@ def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) ->
return callback.on_adapter_parse_end
else:
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")

if isinstance(instance, dspy.Tool):
return callback.on_tool_end

# We treat everything else as a module.
return callback.on_module_end
41 changes: 41 additions & 0 deletions tests/callback/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def on_adapter_parse_start(self, call_id, instance, inputs):
def on_adapter_parse_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_adapter_parse_end", "outputs": outputs, "exception": exception})

def on_tool_start(self, call_id, instance, inputs):
self.calls.append({"handler": "on_tool_start", "instance": instance, "inputs": inputs})

def on_tool_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_tool_end", "outputs": outputs, "exception": exception})


@pytest.mark.parametrize(
("args", "kwargs"),
Expand Down Expand Up @@ -181,6 +187,41 @@ def test_callback_complex_module():
]


def test_tool_calls():
callback = MyCallback()
dspy.settings.configure(callbacks=[callback])

def tool_1(query: str) -> str:
"""A dummy tool function."""
return "result 1"

def tool_2(query: str) -> str:
"""Another dummy tool function."""
return "result 2"

class MyModule(dspy.Module):
def __init__(self):
self.tools = [dspy.Tool(tool_1), dspy.Tool(tool_2)]

def forward(self, query: str) -> str:
query = self.tools[0](query)
return self.tools[1](query)

module = MyModule()
result = module("query")

assert result == "result 2"
assert len(callback.calls) == 6
assert [call["handler"] for call in callback.calls] == [
"on_module_start",
"on_tool_start",
"on_tool_end",
"on_tool_start",
"on_tool_end",
"on_module_end",
]


def test_active_id():
# Test the call ID is generated and handled properly
class CustomCallback(BaseCallback):
Expand Down

0 comments on commit e9e728f

Please sign in to comment.