Skip to content

Commit

Permalink
Support returning logprobs in Predictor (#1895)
Browse files: browse the repository at this point in the history
* support returning logprobs in Predictor

* allow each output to be either a str or a dict

* return outputs as a list of strings if the user doesn't set logprobs

* Update lm.py

---------

Co-authored-by: Omar Khattab <[email protected]>
  • Loading branch information
veronicalyu320 and okhat authored Dec 10, 2024
1 parent f1fc6bc commit e690743
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
9 changes: 8 additions & 1 deletion dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,15 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):

try:
for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
if type(output) is str:
output_text, output_logprobs = output, None
elif type(output) is dict:
output_text, output_logprobs = output["text"], output["logprobs"]
else:
raise ValueError(f"Expected str or dict but got {type(output)}")
value = self.parse(signature, output_text, _parse_values=_parse_values)
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
value["logprobs"] = output_logprobs
values.append(value)
return values

Expand Down
12 changes: 11 additions & 1 deletion dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,18 @@ def __call__(self, prompt=None, messages=None, **kwargs):
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
)
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
if kwargs.get("logprobs"):
outputs = [
{
"text": c.message.content if hasattr(c, "message") else c["text"],
"logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
}
for c in response["choices"]
]
else:
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]


# Logging, with removed api key & where `cost` is None on cache hit.
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
Expand Down

0 comments on commit e690743

Please sign in to comment.