-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simplify built-in modules (remove new_* & extended_signature). Prepar…
…e for assertions v2. (#1943) * Simplify built-in modules (remove extended_signature and new_signature) and remove assertions temporarily * Update test_retry.py * Fully turn CoT into a typical Module
- Loading branch information
Showing
13 changed files
with
523 additions
and
577 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,74 +1,74 @@ | ||
import copy | ||
# import copy | ||
|
||
import dspy | ||
# import dspy | ||
|
||
from .predict import Predict | ||
# from .predict import Predict | ||
|
||
|
||
class Retry(Predict): | ||
def __init__(self, module): | ||
super().__init__(module.signature) | ||
self.module = module | ||
self.original_signature = module.extended_signature if isinstance(module, dspy.ChainOfThought) else module.signature | ||
self.original_forward = module.forward | ||
self.new_signature = self._create_new_signature(self.original_signature) | ||
# class Retry(Predict): | ||
# def __init__(self, module): | ||
# super().__init__(module.signature) | ||
# self.module = module | ||
# self.original_signature = module.signature | ||
# self.original_forward = module.forward | ||
# self.new_signature = self._create_new_signature(self.original_signature) | ||
|
||
def _create_new_signature(self, signature): | ||
# Add "Past" input fields for each output field | ||
for key, value in signature.output_fields.items(): | ||
actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":" | ||
signature = signature.append(f"past_{key}", dspy.InputField( | ||
prefix="Previous " + actual_prefix, | ||
desc=f"past {actual_prefix[:-1]} with errors", | ||
format=value.json_schema_extra.get("format"), | ||
)) | ||
# def _create_new_signature(self, signature): | ||
# # Add "Past" input fields for each output field | ||
# for key, value in signature.output_fields.items(): | ||
# actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":" | ||
# signature = signature.append(f"past_{key}", dspy.InputField( | ||
# prefix="Previous " + actual_prefix, | ||
# desc=f"past {actual_prefix[:-1]} with errors", | ||
# format=value.json_schema_extra.get("format"), | ||
# )) | ||
|
||
signature = signature.append("feedback", dspy.InputField( | ||
prefix="Instructions:", | ||
desc="Some instructions you must satisfy", | ||
format=str, | ||
)) | ||
# signature = signature.append("feedback", dspy.InputField( | ||
# prefix="Instructions:", | ||
# desc="Some instructions you must satisfy", | ||
# format=str, | ||
# )) | ||
|
||
return signature | ||
# return signature | ||
|
||
def forward(self, *, past_outputs, **kwargs): | ||
# Take into account the possible new signature, as in TypedPredictor | ||
new_signature = kwargs.pop("new_signature", None) | ||
if new_signature: | ||
self.original_signature = new_signature | ||
self.new_signature = self._create_new_signature(self.original_signature) | ||
# def forward(self, *, past_outputs, **kwargs): | ||
# # Take into account the possible new signature, as in TypedPredictor | ||
# new_signature = kwargs.pop("new_signature", None) | ||
# if new_signature: | ||
# self.original_signature = new_signature | ||
# self.new_signature = self._create_new_signature(self.original_signature) | ||
|
||
# Convert the dict past_outputs={"answer": ...} to kwargs | ||
# {past_answer=..., ...} | ||
for key, value in past_outputs.items(): | ||
past_key = f"past_{key}" | ||
if past_key in self.new_signature.input_fields: | ||
kwargs[past_key] = value | ||
# Tell the wrapped module to use the new signature. | ||
# Note: This only works if the wrapped module is a Predict or ChainOfThought. | ||
kwargs["new_signature"] = self.new_signature | ||
return self.original_forward(**kwargs) | ||
# # Convert the dict past_outputs={"answer": ...} to kwargs | ||
# # {past_answer=..., ...} | ||
# for key, value in past_outputs.items(): | ||
# past_key = f"past_{key}" | ||
# if past_key in self.new_signature.input_fields: | ||
# kwargs[past_key] = value | ||
# # Tell the wrapped module to use the new signature. | ||
# # Note: This only works if the wrapped module is a Predict or ChainOfThought. | ||
# kwargs["new_signature"] = self.new_signature | ||
# return self.original_forward(**kwargs) | ||
|
||
def __call__(self, **kwargs): | ||
copy.deepcopy(kwargs) | ||
kwargs["_trace"] = False | ||
kwargs.setdefault("demos", self.demos if self.demos is not None else []) | ||
# def __call__(self, **kwargs): | ||
# copy.deepcopy(kwargs) | ||
# kwargs["_trace"] = False | ||
# kwargs.setdefault("demos", self.demos if self.demos is not None else []) | ||
|
||
# perform backtracking | ||
if dspy.settings.backtrack_to == self: | ||
for key, value in dspy.settings.backtrack_to_args.items(): | ||
kwargs.setdefault(key, value) | ||
pred = self.forward(**kwargs) | ||
else: | ||
pred = self.module(**kwargs) | ||
# # perform backtracking | ||
# if dspy.settings.backtrack_to == self: | ||
# for key, value in dspy.settings.backtrack_to_args.items(): | ||
# kwargs.setdefault(key, value) | ||
# pred = self.forward(**kwargs) | ||
# else: | ||
# pred = self.module(**kwargs) | ||
|
||
# now pop multiple reserved keys | ||
# NOTE(shangyin) past_outputs seems not useful to include in demos, | ||
# therefore dropped | ||
for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]: | ||
kwargs.pop(key, None) | ||
# # now pop multiple reserved keys | ||
# # NOTE(shangyin) past_outputs seems not useful to include in demos, | ||
# # therefore dropped | ||
# for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]: | ||
# kwargs.pop(key, None) | ||
|
||
if dspy.settings.trace is not None: | ||
trace = dspy.settings.trace | ||
trace.append((self, {**kwargs}, pred)) | ||
return pred | ||
# if dspy.settings.trace is not None: | ||
# trace = dspy.settings.trace | ||
# trace.append((self, {**kwargs}, pred)) | ||
# return pred |
Oops, something went wrong.