Skip to content

Commit

Permalink
Add SemanticF1, fix LM.copy, fix CoT.save
Browse files Browse the repository at this point in the history
  • Loading branch information
okhat committed Oct 16, 2024
1 parent f6b9d8b commit f1544b8
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 15 deletions.
17 changes: 14 additions & 3 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,20 @@ def inspect_history(self, n: int = 1):
_inspect_history(self, n)

def copy(self, **kwargs):
    """Return a deep copy of this language model, optionally with overrides.

    The copy always starts with an empty ``history``. Each keyword override
    is routed as follows:

    * if the original has an attribute with that name, the attribute is set
      on the copy;
    * if the name is already a key of ``self.kwargs``, or is not an attribute
      of the original at all, it is stored in the copy's ``kwargs`` dict.

    A name can satisfy both rules, in which case both places are updated.

    Args:
        **kwargs: Attribute values and/or ``kwargs`` entries to override.

    Returns:
        A new instance that shares no mutable state with the original.
    """
    # Imported locally so the method does not depend on a module-level import.
    from copy import deepcopy

    # Deep-copy instead of re-invoking the constructor with ``__dict__``:
    # instance attributes are not guaranteed to be constructor arguments.
    new_instance = deepcopy(self)
    new_instance.history = []

    for key, value in kwargs.items():
        is_attribute = hasattr(self, key)
        if is_attribute:
            setattr(new_instance, key, value)
        if key in self.kwargs or not is_attribute:
            new_instance.kwargs[key] = value

    return new_instance



@functools.lru_cache(maxsize=None)
Expand Down
46 changes: 41 additions & 5 deletions dspy/evaluate/auto_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,71 @@
import dspy


class SemanticRecallPrecision(dspy.Signature):
    """
    Compare a system's response to the ground truth to compute its recall and precision.
    If asked to reason, enumerate key ideas in each response, and whether they are present in the other response.
    """

    # Inputs: the question plus the two responses being compared.
    question: str = dspy.InputField()
    ground_truth: str = dspy.InputField()
    system_response: str = dspy.InputField()

    # Outputs: two fractions, each described to the model as "out of 1.0".
    recall: float = dspy.OutputField(
        desc="fraction (out of 1.0) of ground truth covered by the system response",
    )
    precision: float = dspy.OutputField(
        desc="fraction (out of 1.0) of system response covered by the ground truth",
    )


def f1_score(precision, recall):
    """Return the harmonic mean of ``precision`` and ``recall``.

    Yields ``0.0`` when the two values sum to zero, which avoids dividing by
    zero.
    """
    if precision + recall == 0:
        return 0.0

    numerator = 2 * (precision * recall)
    return numerator / (precision + recall)


class SemanticF1(dspy.Module):
    """Metric that scores a prediction by the F1 of judged recall and precision.

    A ``dspy.ChainOfThought`` over ``SemanticRecallPrecision`` produces the
    recall and precision values; they are combined with ``f1_score``.
    """

    def __init__(self, threshold=0.66):
        """Create the metric.

        Args:
            threshold: Minimum F1 for ``forward`` to return ``True`` when a
                ``trace`` is supplied.
        """
        # Initialise the base module first, as the other modules in this file do.
        super().__init__()
        self.threshold = threshold
        self.module = dspy.ChainOfThought(SemanticRecallPrecision)

    def forward(self, example, pred, trace=None):
        """Score ``pred`` against ``example``.

        Args:
            example: Object exposing ``question`` and ``response`` (the ground truth).
            pred: Object exposing ``response`` (the system output).
            trace: When ``None`` the raw F1 is returned; otherwise the result
                is the boolean ``F1 >= threshold``.

        Returns:
            The F1 score, or a pass/fail boolean when ``trace`` is not ``None``.
        """
        scores = self.module(question=example.question, ground_truth=example.response, system_response=pred.response)
        score = f1_score(scores.precision, scores.recall)

        return score if trace is None else score >= self.threshold


"""
Soon-to-be deprecated Signatures & Modules Below.
"""


class AnswerCorrectnessSignature(dspy.Signature):
    """Verify that the predicted answer matches the gold answer."""

    # Inputs: the question and the two answers to compare.
    question = dspy.InputField()
    gold_answer = dspy.InputField(desc="correct answer for question")
    predicted_answer = dspy.InputField(desc="predicted answer for question")

    # Single verdict field; it is defined once only (the duplicate assignment was redundant).
    is_correct = dspy.OutputField(desc="True or False")


class AnswerCorrectness(dspy.Module):
    """Module wrapping a chain-of-thought check over ``AnswerCorrectnessSignature``."""

    def __init__(self):
        super().__init__()
        self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature)

    def forward(self, question, gold_answer, predicted_answer):
        """Run the correctness check and return the predictor's result unchanged."""
        judge_inputs = {
            "question": question,
            "gold_answer": gold_answer,
            "predicted_answer": predicted_answer,
        }
        return self.evaluate_correctness(**judge_inputs)


class AnswerFaithfulnessSignature(dspy.Signature):
    """Verify that the predicted answer is based on the provided context."""

    # Inputs: supporting context, the question, and the answer under review.
    context = dspy.InputField(desc="relevant facts for producing answer")
    question = dspy.InputField()
    answer = dspy.InputField(desc="often between 1 and 5 words")

    # Single verdict field; it is defined once only (the duplicate assignment was redundant).
    is_faithful = dspy.OutputField(desc="True or False")


class AnswerFaithfulness(dspy.Module):
    """Module wrapping a chain-of-thought check over ``AnswerFaithfulnessSignature``."""

    def __init__(self):
        super().__init__()
        self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature)

    def forward(self, context, question, answer):
        """Run the faithfulness check and return the predictor's result unchanged."""
        judge_inputs = {
            "context": context,
            "question": question,
            "answer": answer,
        }
        return self.evaluate_faithfulness(**judge_inputs)
16 changes: 11 additions & 5 deletions dspy/evaluate/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,16 @@ def wrapped_program(example_idx, example):
current_error_count = self.error_count
if current_error_count >= self.max_errors:
raise e

if self.provide_traceback:
dspy.logger.error(f"Error for example in dev set: \t\t {e}\n\twith inputs:\n\t\t{example.inputs()}\n\nStack trace:\n\t{traceback.format_exc()}")
dspy.logger.error(
f"Error for example in dev set: \t\t {e}\n\twith inputs:\n\t\t{example.inputs()}\n\nStack trace:\n\t{traceback.format_exc()}"
)
else:
dspy.logger.error(f"Error for example in dev set: \t\t {e}. Set `provide_traceback=True` to see the stack trace.")

dspy.logger.error(
f"Error for example in dev set: \t\t {e}. Set `provide_traceback=True` to see the stack trace."
)

return example_idx, example, {}, 0.0
finally:
if creating_new_thread:
Expand Down Expand Up @@ -303,7 +307,9 @@ def stylize_metric_name(df: pd.DataFrame, metric_name: str) -> pd.DataFrame:
:param df: The pandas DataFrame for which to stylize cell contents.
:param metric_name: The name of the metric for which to stylize DataFrame cell contents.
"""
df[metric_name] = df[metric_name].apply(lambda x: f"✔️ [{x}]" if x else str)
df[metric_name] = df[metric_name].apply(
lambda x: f"✔️ [{x:.3f}]" if x and isinstance(x, float) else f"✔️ [{x}]" if x else ""
)
return df


Expand Down
2 changes: 1 addition & 1 deletion dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def load_state(self, state, use_legacy_loading=False):
self.signature = self.signature.load_state(state["signature"])

if "extended_signature" in state:
self.extended_signature.load_state(state["extended_signature"])
self.extended_signature = self.extended_signature.load_state(state["extended_signature"])

def _load_state_legacy(self, state):
"""Legacy state loading for backwards compatibility.
Expand Down
1 change: 0 additions & 1 deletion dspy/primitives/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def reset_copy(self):
return new_instance

def dump_state(self, save_verbose):
    """Collect the dumped state of every named parameter.

    Args:
        save_verbose: Passed through unchanged to each parameter's
            ``dump_state``.

    Returns:
        A dict mapping each name yielded by ``self.named_parameters()`` to
        the state returned by that parameter's ``dump_state``.
    """
    # No debug printing here: dumping state must not write to stdout.
    return {name: param.dump_state(save_verbose) for name, param in self.named_parameters()}

def load_state(self, state, use_legacy_loading=False):
Expand Down

0 comments on commit f1544b8

Please sign in to comment.