-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revamp ReAct, adjust Bootstrap, adjust ChatAdapter (#1713)
* JsonAdapter: Handle JSON formatting in demo's outputs
* Adjustments for JsonAdapter
* Revamp ReAct, adjust Bootstrap (handle repeat calls to a module; transpose order for max_rounds), adjust ChatAdapter (handle incomplete demos better)
* Remove ReAct tests (outdated format)
* Remove react tests (outdated)
- Loading branch information
Showing
9 changed files
with
444 additions
and
429 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .base import Adapter | ||
from .chat_adapter import ChatAdapter | ||
from .json_adapter import JsonAdapter | ||
from dspy.adapters.base import Adapter | ||
from dspy.adapters.chat_adapter import ChatAdapter | ||
from dspy.adapters.json_adapter import JsonAdapter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,130 +1,101 @@ | ||
import dsp | ||
import dspy | ||
from dspy.signatures.signature import ensure_signature | ||
|
||
from ..primitives.program import Module | ||
from .predict import Predict | ||
import inspect | ||
|
||
# TODO: Simplify a lot. | ||
# TODO: Divide Action and Action Input like langchain does for ReAct. | ||
from pydantic import BaseModel | ||
from dspy.primitives.program import Module | ||
from dspy.signatures.signature import ensure_signature | ||
from dspy.adapters.json_adapter import get_annotation_name | ||
from typing import Callable, Any, get_type_hints, get_origin, Literal | ||
|
||
# TODO: There's a lot of value in having a stopping condition in the LM calls at `\n\nObservation:` | ||
class Tool: | ||
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None): | ||
annotations_func = func if inspect.isfunction(func) else func.__call__ | ||
self.func = func | ||
self.name = name or getattr(func, '__name__', type(func).__name__) | ||
self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "No description") | ||
self.args = { | ||
k: v.schema() if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel) | ||
else get_annotation_name(v) | ||
for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return' | ||
} | ||
|
||
# TODO [NEW]: When max_iters is about to be reached, reduce the set of available actions to only the Finish action. | ||
def __call__(self, *args, **kwargs): | ||
return self.func(*args, **kwargs) | ||
|
||
|
||
class ReAct(Module): | ||
def __init__(self, signature, max_iters=5, num_results=3, tools=None): | ||
super().__init__() | ||
def __init__(self, signature, tools: list[Callable], max_iters=5): | ||
""" | ||
Tools is either a list of functions, callable classes, or dspy.Tool instances. | ||
""" | ||
|
||
self.signature = signature = ensure_signature(signature) | ||
self.max_iters = max_iters | ||
|
||
self.tools = tools or [dspy.Retrieve(k=num_results)] | ||
self.tools = {tool.name: tool for tool in self.tools} | ||
|
||
self.input_fields = self.signature.input_fields | ||
self.output_fields = self.signature.output_fields | ||
|
||
assert len(self.output_fields) == 1, "ReAct only supports one output field." | ||
tools = [t if isinstance(t, Tool) or hasattr(t, 'input_variable') else Tool(t) for t in tools] | ||
tools = {tool.name: tool for tool in tools} | ||
|
||
inputs_ = ", ".join([f"`{k}`" for k in self.input_fields.keys()]) | ||
outputs_ = ", ".join([f"`{k}`" for k in self.output_fields.keys()]) | ||
inputs_ = ", ".join([f"`{k}`" for k in signature.input_fields.keys()]) | ||
outputs_ = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) | ||
instr = [f"{signature.instructions}\n"] if signature.instructions else [] | ||
|
||
instr = [] | ||
|
||
if self.signature.instructions is not None: | ||
instr.append(f"{self.signature.instructions}\n") | ||
|
||
instr.extend([ | ||
f"You will be given {inputs_} and you will respond with {outputs_}.\n", | ||
"To do this, you will interleave Thought, Action, and Observation steps.\n", | ||
"Thought can reason about the current situation, and Action can be the following types:\n", | ||
f"You will be given {inputs_} and your goal is to finish with {outputs_}.\n", | ||
"To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting Observation.\n", | ||
"Thought can reason about the current situation, and Tool Name can be the following types:\n", | ||
]) | ||
|
||
self.tools["Finish"] = dspy.Example( | ||
name="Finish", | ||
input_variable=outputs_.strip("`"), | ||
desc=f"returns the final {outputs_} and finishes the task", | ||
finish_desc = f"Signals that the final outputs, i.e. {outputs_}, are now available and marks the task as complete." | ||
finish_args = {} #k: v.annotation for k, v in signature.output_fields.items()} | ||
tools["finish"] = Tool(func=lambda **kwargs: kwargs, name="finish", desc=finish_desc, args=finish_args) | ||
|
||
for idx, tool in enumerate(tools.values()): | ||
desc = tool.desc.replace("\n", " ") | ||
args = tool.args if hasattr(tool, 'args') else str({tool.input_variable: str}) | ||
desc = f"whose description is <desc>{desc}</desc>. It takes arguments {args} in JSON format." | ||
instr.append(f"({idx+1}) {tool.name}, {desc}") | ||
|
||
signature_ = ( | ||
dspy.Signature({**signature.input_fields}, "\n".join(instr)) | ||
.append("trajectory", dspy.InputField(), type_=str) | ||
.append("next_thought", dspy.OutputField(), type_=str) | ||
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())]) | ||
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) | ||
) | ||
|
||
for idx, tool in enumerate(self.tools): | ||
tool = self.tools[tool] | ||
instr.append( | ||
f"({idx+1}) {tool.name}[{tool.input_variable}], which {tool.desc}", | ||
) | ||
|
||
instr = "\n".join(instr) | ||
self.react = [ | ||
Predict(dspy.Signature(self._generate_signature(i), instr)) | ||
for i in range(1, max_iters + 1) | ||
] | ||
|
||
def _generate_signature(self, iters): | ||
signature_dict = {} | ||
for key, val in self.input_fields.items(): | ||
signature_dict[key] = val | ||
|
||
for j in range(1, iters + 1): | ||
IOField = dspy.OutputField if j == iters else dspy.InputField | ||
|
||
signature_dict[f"Thought_{j}"] = IOField( | ||
prefix=f"Thought {j}:", | ||
desc="next steps to take based on last observation", | ||
) | ||
|
||
tool_list = " or ".join( | ||
[ | ||
f"{tool.name}[{tool.input_variable}]" | ||
for tool in self.tools.values() | ||
if tool.name != "Finish" | ||
], | ||
) | ||
signature_dict[f"Action_{j}"] = IOField( | ||
prefix=f"Action {j}:", | ||
desc=f"always either {tool_list} or, when done, Finish[<answer>], where <answer> is the answer to the question itself.", | ||
) | ||
|
||
if j < iters: | ||
signature_dict[f"Observation_{j}"] = IOField( | ||
prefix=f"Observation {j}:", | ||
desc="observations based on action", | ||
format=dsp.passages2text, | ||
) | ||
|
||
return signature_dict | ||
|
||
def act(self, output, hop): | ||
try: | ||
action = output[f"Action_{hop+1}"] | ||
action_name, action_val = action.strip().split("\n")[0].split("[", 1) | ||
action_val = action_val.rsplit("]", 1)[0] | ||
|
||
if action_name == "Finish": | ||
return action_val | ||
|
||
result = self.tools[action_name](action_val) #result must be a str, list, or tuple | ||
# Handle the case where 'passages' attribute is missing | ||
output[f"Observation_{hop+1}"] = getattr(result, "passages", result) | ||
|
||
except Exception: | ||
output[f"Observation_{hop+1}"] = ( | ||
"Failed to parse action. Bad formatting or incorrect action name." | ||
) | ||
# raise e | ||
|
||
def forward(self, **kwargs): | ||
args = {key: kwargs[key] for key in self.input_fields.keys() if key in kwargs} | ||
|
||
for hop in range(self.max_iters): | ||
# with dspy.settings.context(show_guidelines=(i <= 2)): | ||
output = self.react[hop](**args) | ||
output[f'Action_{hop + 1}'] = output[f'Action_{hop + 1}'].split('\n')[0] | ||
|
||
if action_val := self.act(output, hop): | ||
break | ||
args.update(output) | ||
fallback_signature = ( | ||
dspy.Signature({**signature.input_fields, **signature.output_fields}) | ||
.append("trajectory", dspy.InputField(), type_=str) | ||
) | ||
|
||
observations = [args[key] for key in args if key.startswith("Observation")] | ||
self.tools = tools | ||
self.react = dspy.Predict(signature_) | ||
self.extract = dspy.ChainOfThought(fallback_signature) | ||
|
||
def forward(self, **input_args): | ||
trajectory = {} | ||
|
||
def format(trajectory_: dict[str, Any], last_iteration: bool): | ||
adapter = dspy.settings.adapter or dspy.ChatAdapter() | ||
blob = adapter.format_fields(dspy.Signature(f"{', '.join(trajectory_.keys())} -> x"), trajectory_) | ||
warning = f"\n\nWarning: The maximum number of iterations ({self.max_iters}) has been reached." | ||
warning += " You must now produce the finish action." | ||
return blob + (warning if last_iteration else "") | ||
|
||
for idx in range(self.max_iters): | ||
pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters-1))) | ||
|
||
trajectory[f"thought_{idx}"] = pred.next_thought | ||
trajectory[f"tool_name_{idx}"] = pred.next_tool_name | ||
trajectory[f"tool_args_{idx}"] = pred.next_tool_args | ||
|
||
try: | ||
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args) | ||
except Exception as e: | ||
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}" | ||
|
||
if pred.next_tool_name == "finish": | ||
break | ||
|
||
# assumes only 1 output field for now - TODO: handling for multiple output fields | ||
return dspy.Prediction(observations=observations, **{list(self.output_fields.keys())[0]: action_val or ""}) | ||
extract = self.extract(**input_args, trajectory=format(trajectory, last_iteration=False)) | ||
return dspy.Prediction(trajectory=trajectory, **extract) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.