Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify built-in modules (remove new_* & extended_signature). Prepare for assertions v2. #1943

Merged
merged 3 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
from .predict import Predict
from .program_of_thought import ProgramOfThought
from .react import ReAct, Tool
from .retry import Retry
from .parallel import Parallel
from .parallel import Parallel
# from .retry import Retry
51 changes: 6 additions & 45 deletions dspy/predict/chain_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,19 @@
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature

# TODO: This shouldn't inherit from Predict. It should be a module that has one or two predictors.
# Let's focus on the activated case. It's a predictor with the expanded signature.
# Now, when deactivated, it's a predictor with the original signature.
# When activate is None, though, we need the expanded one but during forward we need to pass the right signature.


class ChainOfThought(Module):
def __init__(self, signature, rationale_type=None, activated=True, **config):
def __init__(self, signature, rationale_type=None, **config):
super().__init__()

self.activated = activated

self.signature = signature = ensure_signature(signature)
*_keys, last_key = signature.output_fields.keys()
signature = ensure_signature(signature)

prefix = "Reasoning: Let's think step by step in order to"

if isinstance(dspy.settings.lm, dspy.LM):
desc = "${reasoning}"
elif dspy.settings.experimental:
desc = "${produce the output fields}. We ..."
else:
desc = f"${{produce the {last_key}}}. We ..."

desc = "${reasoning}"
rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc)

# Add "rationale" field to the output signature.
if isinstance(dspy.settings.lm, dspy.LM) or dspy.settings.experimental:
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)
else:
extended_signature = signature.prepend("rationale", rationale_type, type_=str)
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)

self._predict = dspy.Predict(extended_signature, **config)
self._predict.extended_signature = extended_signature
self.predict = dspy.Predict(extended_signature, **config)

def forward(self, **kwargs):
assert self.activated in [True, False]

signature = kwargs.pop("new_signature", self._predict.extended_signature if self.activated else self.signature)
return self._predict(signature=signature, **kwargs)

@property
def demos(self):
return self._predict.demos

@property
def extended_signature(self):
return self._predict.extended_signature


"""
TODO: In principle, we can also update the field's prefix during forward to fill in anything based on the input args.

IF the user didn't overwrite our default rationale_type.
"""
return self.predict(**kwargs)
15 changes: 4 additions & 11 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ def dump_state(self):
state["demos"].append(demo)

state["signature"] = self.signature.dump_state()
# `extended_signature` is a special field for `Predict`s like CoT.
if hasattr(self, "extended_signature"):
state["extended_signature"] = self.extended_signature.dump_state()
return state

def load_state(self, state):
Expand Down Expand Up @@ -82,8 +79,8 @@ def load_state(self, state):

self.signature = self.signature.load_state(state["signature"])

if "extended_signature" in state:
self.extended_signature = self.extended_signature.load_state(state["extended_signature"])
if "extended_signature" in state: # legacy, up to and including 2.5, for CoT.
self.signature = self.signature.load_state(state["extended_signature"])

return self

Expand All @@ -96,14 +93,14 @@ def forward(self, **kwargs):
assert not dspy.settings.compiling, "It's no longer ever the case that .compiling is True"

# Extract the three privileged keyword arguments.
new_signature = ensure_signature(kwargs.pop("new_signature", None))
assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument."
signature = ensure_signature(kwargs.pop("signature", self.signature))
demos = kwargs.pop("demos", self.demos)
config = dict(**self.config, **kwargs.pop("config", {}))

# Get the right LM to use.
lm = kwargs.pop("lm", self.lm) or dspy.settings.lm
assert lm is not None, "No LM is loaded."
assert isinstance(lm, dspy.LM), "No LM is loaded."

# If temperature is 0.0 but its n > 1, set temperature to 0.7.
temperature = config.get("temperature")
Expand All @@ -113,15 +110,11 @@ def forward(self, **kwargs):
if (temperature is None or temperature <= 0.15) and num_generations > 1:
config["temperature"] = 0.7

if new_signature is not None:
signature = new_signature

if not all(k in kwargs for k in signature.input_fields):
present = [k for k in signature.input_fields if k in kwargs]
missing = [k for k in signature.input_fields if k not in kwargs]
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")

assert isinstance(lm, dspy.LM)
completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values)

pred = Prediction.from_completions(completions, signature=signature)
Expand Down
12 changes: 12 additions & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
Another potential fix is to more natively support a "variadic" input field, where the input is a list of dictionaries,
or a big dictionary, and have each adapter format it accordingly.
Trajectories also affect meta-programming modules that view the trace later. It's inefficient (O(n^2)) to view the
trace of every module when each one repeats the prefix.
TOPIC 02: Handling default arguments in the Tool class.
Expand All @@ -140,4 +143,13 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
TOPIC 05: Adding more structure around how the instruction is formatted.
* Concretely, it's now a string, so an optimizer can and does rewrite it freely.
* An alternative would be to add more structure, such that a certain template is fixed but values are variable?
TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls.
* So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations.
* This is pretty useful for allowing the agent to keep notes or count certain things, etc.
TOPIC 07: Make max_iters a bit more expressive.
* Allow passing `max_iters` in forward to overwrite the default.
* Get rid of `last_iteration: bool` in the format function. It's not necessary now.
"""
122 changes: 61 additions & 61 deletions dspy/predict/retry.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,74 @@
import copy
# import copy

import dspy
# import dspy

from .predict import Predict
# from .predict import Predict


class Retry(Predict):
def __init__(self, module):
super().__init__(module.signature)
self.module = module
self.original_signature = module.extended_signature if isinstance(module, dspy.ChainOfThought) else module.signature
self.original_forward = module.forward
self.new_signature = self._create_new_signature(self.original_signature)
# class Retry(Predict):
# def __init__(self, module):
# super().__init__(module.signature)
# self.module = module
# self.original_signature = module.signature
# self.original_forward = module.forward
# self.new_signature = self._create_new_signature(self.original_signature)

def _create_new_signature(self, signature):
# Add "Past" input fields for each output field
for key, value in signature.output_fields.items():
actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":"
signature = signature.append(f"past_{key}", dspy.InputField(
prefix="Previous " + actual_prefix,
desc=f"past {actual_prefix[:-1]} with errors",
format=value.json_schema_extra.get("format"),
))
# def _create_new_signature(self, signature):
# # Add "Past" input fields for each output field
# for key, value in signature.output_fields.items():
# actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":"
# signature = signature.append(f"past_{key}", dspy.InputField(
# prefix="Previous " + actual_prefix,
# desc=f"past {actual_prefix[:-1]} with errors",
# format=value.json_schema_extra.get("format"),
# ))

signature = signature.append("feedback", dspy.InputField(
prefix="Instructions:",
desc="Some instructions you must satisfy",
format=str,
))
# signature = signature.append("feedback", dspy.InputField(
# prefix="Instructions:",
# desc="Some instructions you must satisfy",
# format=str,
# ))

return signature
# return signature

def forward(self, *, past_outputs, **kwargs):
# Take into account the possible new signature, as in TypedPredictor
new_signature = kwargs.pop("new_signature", None)
if new_signature:
self.original_signature = new_signature
self.new_signature = self._create_new_signature(self.original_signature)
# def forward(self, *, past_outputs, **kwargs):
# # Take into account the possible new signature, as in TypedPredictor
# new_signature = kwargs.pop("new_signature", None)
# if new_signature:
# self.original_signature = new_signature
# self.new_signature = self._create_new_signature(self.original_signature)

# Convert the dict past_outputs={"answer": ...} to kwargs
# {past_answer=..., ...}
for key, value in past_outputs.items():
past_key = f"past_{key}"
if past_key in self.new_signature.input_fields:
kwargs[past_key] = value
# Tell the wrapped module to use the new signature.
# Note: This only works if the wrapped module is a Predict or ChainOfThought.
kwargs["new_signature"] = self.new_signature
return self.original_forward(**kwargs)
# # Convert the dict past_outputs={"answer": ...} to kwargs
# # {past_answer=..., ...}
# for key, value in past_outputs.items():
# past_key = f"past_{key}"
# if past_key in self.new_signature.input_fields:
# kwargs[past_key] = value
# # Tell the wrapped module to use the new signature.
# # Note: This only works if the wrapped module is a Predict or ChainOfThought.
# kwargs["new_signature"] = self.new_signature
# return self.original_forward(**kwargs)

def __call__(self, **kwargs):
copy.deepcopy(kwargs)
kwargs["_trace"] = False
kwargs.setdefault("demos", self.demos if self.demos is not None else [])
# def __call__(self, **kwargs):
# copy.deepcopy(kwargs)
# kwargs["_trace"] = False
# kwargs.setdefault("demos", self.demos if self.demos is not None else [])

# perform backtracking
if dspy.settings.backtrack_to == self:
for key, value in dspy.settings.backtrack_to_args.items():
kwargs.setdefault(key, value)
pred = self.forward(**kwargs)
else:
pred = self.module(**kwargs)
# # perform backtracking
# if dspy.settings.backtrack_to == self:
# for key, value in dspy.settings.backtrack_to_args.items():
# kwargs.setdefault(key, value)
# pred = self.forward(**kwargs)
# else:
# pred = self.module(**kwargs)

# now pop multiple reserved keys
# NOTE(shangyin) past_outputs seems not useful to include in demos,
# therefore dropped
for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]:
kwargs.pop(key, None)
# # now pop multiple reserved keys
# # NOTE(shangyin) past_outputs seems not useful to include in demos,
# # therefore dropped
# for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]:
# kwargs.pop(key, None)

if dspy.settings.trace is not None:
trace = dspy.settings.trace
trace.append((self, {**kwargs}, pred))
return pred
# if dspy.settings.trace is not None:
# trace = dspy.settings.trace
# trace.append((self, {**kwargs}, pred))
# return pred
Loading
Loading