Skip to content

Commit

Permalink
Simplify built-in modules (remove new_* & extended_signature). Prepar…
Browse files Browse the repository at this point in the history
…e for assertions v2. (#1943)

* Simplify built-in modules (remove extended_signature and new_signature) and remove assertions temporarily

* Update test_retry.py

* Fully turn CoT into a typical Module
  • Loading branch information
okhat authored Dec 16, 2024
1 parent c117dac commit ae86009
Show file tree
Hide file tree
Showing 13 changed files with 523 additions and 577 deletions.
4 changes: 2 additions & 2 deletions dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
from .predict import Predict
from .program_of_thought import ProgramOfThought
from .react import ReAct, Tool
from .retry import Retry
from .parallel import Parallel
from .parallel import Parallel
# from .retry import Retry
51 changes: 6 additions & 45 deletions dspy/predict/chain_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,19 @@
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature

# TODO: This shouldn't inherit from Predict. It should be a module that has one or two predictors.
# Let's focus on the activated case. It's a predictor with the expanded signature.
# Now, when deactivated, it's a predictor with the original signature.
# When activate is None, though, we need the expanded one but during forward we need to pass the right signature.


class ChainOfThought(Module):
def __init__(self, signature, rationale_type=None, activated=True, **config):
def __init__(self, signature, rationale_type=None, **config):
super().__init__()

self.activated = activated

self.signature = signature = ensure_signature(signature)
*_keys, last_key = signature.output_fields.keys()
signature = ensure_signature(signature)

prefix = "Reasoning: Let's think step by step in order to"

if isinstance(dspy.settings.lm, dspy.LM):
desc = "${reasoning}"
elif dspy.settings.experimental:
desc = "${produce the output fields}. We ..."
else:
desc = f"${{produce the {last_key}}}. We ..."

desc = "${reasoning}"
rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc)

# Add "rationale" field to the output signature.
if isinstance(dspy.settings.lm, dspy.LM) or dspy.settings.experimental:
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)
else:
extended_signature = signature.prepend("rationale", rationale_type, type_=str)
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)

self._predict = dspy.Predict(extended_signature, **config)
self._predict.extended_signature = extended_signature
self.predict = dspy.Predict(extended_signature, **config)

def forward(self, **kwargs):
assert self.activated in [True, False]

signature = kwargs.pop("new_signature", self._predict.extended_signature if self.activated else self.signature)
return self._predict(signature=signature, **kwargs)

@property
def demos(self):
return self._predict.demos

@property
def extended_signature(self):
return self._predict.extended_signature


"""
TODO: In principle, we can update the field's prefix during forward too to fill any thing based on the input args.
IF the user didn't overwrite our default rationale_type.
"""
return self.predict(**kwargs)
15 changes: 4 additions & 11 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ def dump_state(self):
state["demos"].append(demo)

state["signature"] = self.signature.dump_state()
# `extended_signature` is a special field for `Predict`s like CoT.
if hasattr(self, "extended_signature"):
state["extended_signature"] = self.extended_signature.dump_state()
return state

def load_state(self, state):
Expand Down Expand Up @@ -82,8 +79,8 @@ def load_state(self, state):

self.signature = self.signature.load_state(state["signature"])

if "extended_signature" in state:
self.extended_signature = self.extended_signature.load_state(state["extended_signature"])
if "extended_signature" in state: # legacy, up to and including 2.5, for CoT.
self.signature = self.signature.load_state(state["extended_signature"])

return self

Expand All @@ -96,14 +93,14 @@ def forward(self, **kwargs):
assert not dspy.settings.compiling, "It's no longer ever the case that .compiling is True"

# Extract the three privileged keyword arguments.
new_signature = ensure_signature(kwargs.pop("new_signature", None))
assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument."
signature = ensure_signature(kwargs.pop("signature", self.signature))
demos = kwargs.pop("demos", self.demos)
config = dict(**self.config, **kwargs.pop("config", {}))

# Get the right LM to use.
lm = kwargs.pop("lm", self.lm) or dspy.settings.lm
assert lm is not None, "No LM is loaded."
assert isinstance(lm, dspy.LM), "No LM is loaded."

# If temperature is 0.0 but its n > 1, set temperature to 0.7.
temperature = config.get("temperature")
Expand All @@ -113,15 +110,11 @@ def forward(self, **kwargs):
if (temperature is None or temperature <= 0.15) and num_generations > 1:
config["temperature"] = 0.7

if new_signature is not None:
signature = new_signature

if not all(k in kwargs for k in signature.input_fields):
present = [k for k in signature.input_fields if k in kwargs]
missing = [k for k in signature.input_fields if k not in kwargs]
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")

assert isinstance(lm, dspy.LM)
completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values)

pred = Prediction.from_completions(completions, signature=signature)
Expand Down
12 changes: 12 additions & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
Another potential fix is to more natively support a "variadic" input field, where the input is a list of dictionaries,
or a big dictionary, and have each adatper format it accordingly.
Trajectories also affect meta-programming modules that view the trace later. It's inefficient O(n^2) to view the
trace of every module repeating the prefix.
TOPIC 02: Handling default arguments in the Tool class.
Expand All @@ -140,4 +143,13 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
TOPIC 05: Adding more structure around how the instruction is formatted.
* Concretely, it's now a string, so an optimizer can and does rewrite it freely.
* An alternative would be to add more structure, such that a certain template is fixed but values are variable?
TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls.
* So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations.
* This is pretty useful for allowing the agent to keep notes or count certain things, etc.
TOPIC 07: Make max_iters a bit more expressive.
* Allow passing `max_iters` in forward to overwrite the default.
* Get rid of `last_iteration: bool` in the format function. It's not necessary now.
"""
122 changes: 61 additions & 61 deletions dspy/predict/retry.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,74 @@
import copy
# import copy

import dspy
# import dspy

from .predict import Predict
# from .predict import Predict


class Retry(Predict):
def __init__(self, module):
super().__init__(module.signature)
self.module = module
self.original_signature = module.extended_signature if isinstance(module, dspy.ChainOfThought) else module.signature
self.original_forward = module.forward
self.new_signature = self._create_new_signature(self.original_signature)
# class Retry(Predict):
# def __init__(self, module):
# super().__init__(module.signature)
# self.module = module
# self.original_signature = module.signature
# self.original_forward = module.forward
# self.new_signature = self._create_new_signature(self.original_signature)

def _create_new_signature(self, signature):
# Add "Past" input fields for each output field
for key, value in signature.output_fields.items():
actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":"
signature = signature.append(f"past_{key}", dspy.InputField(
prefix="Previous " + actual_prefix,
desc=f"past {actual_prefix[:-1]} with errors",
format=value.json_schema_extra.get("format"),
))
# def _create_new_signature(self, signature):
# # Add "Past" input fields for each output field
# for key, value in signature.output_fields.items():
# actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":"
# signature = signature.append(f"past_{key}", dspy.InputField(
# prefix="Previous " + actual_prefix,
# desc=f"past {actual_prefix[:-1]} with errors",
# format=value.json_schema_extra.get("format"),
# ))

signature = signature.append("feedback", dspy.InputField(
prefix="Instructions:",
desc="Some instructions you must satisfy",
format=str,
))
# signature = signature.append("feedback", dspy.InputField(
# prefix="Instructions:",
# desc="Some instructions you must satisfy",
# format=str,
# ))

return signature
# return signature

def forward(self, *, past_outputs, **kwargs):
# Take into account the possible new signature, as in TypedPredictor
new_signature = kwargs.pop("new_signature", None)
if new_signature:
self.original_signature = new_signature
self.new_signature = self._create_new_signature(self.original_signature)
# def forward(self, *, past_outputs, **kwargs):
# # Take into account the possible new signature, as in TypedPredictor
# new_signature = kwargs.pop("new_signature", None)
# if new_signature:
# self.original_signature = new_signature
# self.new_signature = self._create_new_signature(self.original_signature)

# Convert the dict past_outputs={"answer": ...} to kwargs
# {past_answer=..., ...}
for key, value in past_outputs.items():
past_key = f"past_{key}"
if past_key in self.new_signature.input_fields:
kwargs[past_key] = value
# Tell the wrapped module to use the new signature.
# Note: This only works if the wrapped module is a Predict or ChainOfThought.
kwargs["new_signature"] = self.new_signature
return self.original_forward(**kwargs)
# # Convert the dict past_outputs={"answer": ...} to kwargs
# # {past_answer=..., ...}
# for key, value in past_outputs.items():
# past_key = f"past_{key}"
# if past_key in self.new_signature.input_fields:
# kwargs[past_key] = value
# # Tell the wrapped module to use the new signature.
# # Note: This only works if the wrapped module is a Predict or ChainOfThought.
# kwargs["new_signature"] = self.new_signature
# return self.original_forward(**kwargs)

def __call__(self, **kwargs):
copy.deepcopy(kwargs)
kwargs["_trace"] = False
kwargs.setdefault("demos", self.demos if self.demos is not None else [])
# def __call__(self, **kwargs):
# copy.deepcopy(kwargs)
# kwargs["_trace"] = False
# kwargs.setdefault("demos", self.demos if self.demos is not None else [])

# perform backtracking
if dspy.settings.backtrack_to == self:
for key, value in dspy.settings.backtrack_to_args.items():
kwargs.setdefault(key, value)
pred = self.forward(**kwargs)
else:
pred = self.module(**kwargs)
# # perform backtracking
# if dspy.settings.backtrack_to == self:
# for key, value in dspy.settings.backtrack_to_args.items():
# kwargs.setdefault(key, value)
# pred = self.forward(**kwargs)
# else:
# pred = self.module(**kwargs)

# now pop multiple reserved keys
# NOTE(shangyin) past_outputs seems not useful to include in demos,
# therefore dropped
for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]:
kwargs.pop(key, None)
# # now pop multiple reserved keys
# # NOTE(shangyin) past_outputs seems not useful to include in demos,
# # therefore dropped
# for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]:
# kwargs.pop(key, None)

if dspy.settings.trace is not None:
trace = dspy.settings.trace
trace.append((self, {**kwargs}, pred))
return pred
# if dspy.settings.trace is not None:
# trace = dspy.settings.trace
# trace.append((self, {**kwargs}, pred))
# return pred
Loading

0 comments on commit ae86009

Please sign in to comment.