-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial DSPy reliability tests (#1773)
* fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * format Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Update run_tests.yml --------- Signed-off-by: dbczumar <[email protected]>
- Loading branch information
Showing
7 changed files
with
425 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# DSPy Reliability Tests | ||
|
||
This directory contains reliability tests for DSPy programs. The purpose of these tests is to verify that DSPy programs reliabily produce expected outputs across multiple large language models (LLMs), regardless of model size or capability. These tests are designed to ensure that DSPy programs maintain robustness and accuracy across diverse LLM configurations. | ||
|
||
### Overview | ||
|
||
Each test in this directory executes a DSPy program using various LLMs. By running the same tests across different models, these tests help validate that DSPy programs handle a wide range of inputs effectively and produce reliable outputs, even in cases where the model might struggle with the input or task. | ||
|
||
### Key Features | ||
|
||
- **Diverse LLMs**: Each DSPy program is tested with multiple LLMs, ranging from smaller models to more advanced, high-performance models. This approach allows us to assess the consistency and generality of DSPy program outputs across different model capabilities. | ||
- **Challenging and Adversarial Tests**: Some of the tests are intentionally challenging or adversarial, crafted to push the boundaries of DSPy. These challenging cases allow us to gauge the robustness of DSPy and identify areas for potential improvement. | ||
- **Cross-Model Compatibility**: By testing with different LLMs, we aim to ensure that DSPy programs perform well across model types and configurations, reducing model-specific edge cases and enhancing program versatility. | ||
|
||
### Running the Tests | ||
|
||
- First, populate the configuration file `reliability_tests_conf.yaml` (located in this directory) with the necessary LiteLLM model/provider names and access credentials for 1. each LLM you want to test and 2. the LLM judge that you want to use for assessing the correctness of outputs in certain test cases. These should be placed in the `litellm_params` section for each model in the defined `model_list`. You can also use `litellm_params` to specify values for LLM hyperparameters like `temperature`. Any model that lacks configured `litellm_params` in the configuration file will be ignored during testing. | ||
|
||
The configuration must also specify a DSPy adapter to use when testing, e.g. `"chat"` (for `dspy.ChatAdapter`) or `"json"` (for `dspy.JSONAdapter`). | ||
|
||
An example of `reliability_tests_conf.yaml`: | ||
|
||
```yaml | ||
adapter: chat | ||
model_list: | ||
# The model to use for judging the correctness of program | ||
# outputs throughout reliability test suites. We recommend using | ||
# a high quality model as the judge, such as OpenAI GPT-4o | ||
- model_name: "judge" | ||
litellm_params: | ||
model: "openai/gpt-4o" | ||
api_key: "<my_openai_api_key>" | ||
- model_name: "gpt-4o" | ||
litellm_params: | ||
model: "openai/gpt-4o" | ||
api_key: "<my_openai_api_key>" | ||
- model_name: "claude-3.5-sonnet" | ||
litellm_params: | ||
model: "anthropic/claude-3.5" | ||
api_key: "<my_anthropic_api_key>" | ||
|
||
- Second, to run the tests, run the following command from this directory: | ||
|
||
```bash | ||
pytest . | ||
``` | ||
|
||
This will execute all tests for the configured models and display detailed results for each model configuration. Tests are set up to mark expected failures for known challenging cases where a specific model might struggle, while actual (unexpected) DSPy reliability issues are flagged as failures (see below). | ||
|
||
### Known Failing Models | ||
|
||
Some tests may be expected to fail with certain models, especially in challenging cases. These known failures are logged but do not affect the overall test result. This setup allows us to keep track of model-specific limitations without obstructing general test outcomes. Models that are known to fail a particular test case are specified using the `@known_failing_models` decorator. For example: | ||
|
||
``` | ||
@known_failing_models(["llama-3.2-3b-instruct"]) | ||
def test_program_with_complex_deeply_nested_output_structure(): | ||
... | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
import dspy | ||
from tests.conftest import clear_settings | ||
from tests.reliability.utils import parse_reliability_conf_yaml | ||
|
||
# Standard list of models that should be used for periodic DSPy reliability testing | ||
MODEL_LIST = [ | ||
"gpt-4o", | ||
"gpt-4o-mini", | ||
"gpt-4-turbo", | ||
"gpt-o1-preview", | ||
"gpt-o1-mini", | ||
"claude-3.5-sonnet", | ||
"claude-3.5-haiku", | ||
"gemini-1.5-pro", | ||
"gemini-1.5-flash", | ||
"llama-3.1-405b-instruct", | ||
"llama-3.1-70b-instruct", | ||
"llama-3.1-8b-instruct", | ||
"llama-3.2-3b-instruct", | ||
] | ||
|
||
|
||
def pytest_generate_tests(metafunc): | ||
""" | ||
Hook to parameterize reliability test cases with each model defined in the | ||
reliability tests YAML configuration | ||
""" | ||
known_failing_models = getattr(metafunc.function, "_known_failing_models", []) | ||
|
||
if "configure_model" in metafunc.fixturenames: | ||
params = [(model, model in known_failing_models) for model in MODEL_LIST] | ||
ids = [f"{model}" for model, _ in params] # Custom IDs for display | ||
metafunc.parametrize("configure_model", params, indirect=True, ids=ids) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def configure_model(request): | ||
""" | ||
Fixture to configure the DSPy library with a particular configured model and adapter | ||
before executing a test case. | ||
""" | ||
module_dir = os.path.dirname(os.path.abspath(__file__)) | ||
conf_path = os.path.join(module_dir, "reliability_conf.yaml") | ||
reliability_conf = parse_reliability_conf_yaml(conf_path) | ||
|
||
if reliability_conf.adapter.lower() == "chat": | ||
adapter = dspy.ChatAdapter() | ||
elif reliability_conf.adapter.lower() == "json": | ||
adapter = dspy.JSONAdapter() | ||
else: | ||
raise ValueError(f"Unknown adapter specification '{adapter}' in reliability_conf.yaml") | ||
|
||
model_name, should_ignore_failure = request.param | ||
model_params = reliability_conf.models.get(model_name) | ||
if model_params: | ||
lm = dspy.LM(**model_params) | ||
dspy.configure(lm=lm, adapter=adapter) | ||
else: | ||
pytest.skip( | ||
f"Skipping test because no reliability testing YAML configuration was found" f" for model {model_name}." | ||
) | ||
|
||
# Store `should_ignore_failure` flag on the request node for use in post-test handling | ||
request.node.should_ignore_failure = should_ignore_failure | ||
request.node.model_name = model_name | ||
|
||
|
||
@pytest.hookimpl(tryfirst=True, hookwrapper=True) | ||
def pytest_runtest_makereport(item, call): | ||
""" | ||
Hook to conditionally ignore failures in a given test case for known failing models. | ||
""" | ||
outcome = yield | ||
rep = outcome.get_result() | ||
|
||
should_ignore_failure = getattr(item, "should_ignore_failure", False) | ||
|
||
if should_ignore_failure and rep.failed: | ||
rep.outcome = "passed" | ||
rep.wasxfail = "Ignoring failure for known failing model" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
adapter: chat | ||
model_list: | ||
# The model to use for judging the correctness of program | ||
# outputs throughout reliability test suites. We recommend using | ||
# a high quality model as the judge, such as OpenAI GPT-4o | ||
- model_name: "judge" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "gpt-4o" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "gpt-4o-mini" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "gpt-4-turbo" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "gpt-o1" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "gpt-o1-mini" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "claude-3.5-sonnet" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "claude-3.5-haiku" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "gemini-1.5-pro" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "gemini-1.5-flash" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "llama-3.1-405b-instruct" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "llama-3.1-70b-instruct" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "llama-3.1-8b-instruct" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" | ||
- model_name: "llama-3.2-3b-instruct" | ||
litellm_params: | ||
# model: "<litellm_provider>/<litellm_model_name>" | ||
# api_key: "api key" | ||
# api_base: "<api_base>" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from enum import Enum | ||
from typing import List | ||
|
||
import pydantic | ||
|
||
import dspy | ||
from tests.reliability.utils import assert_program_output_correct, known_failing_models | ||
|
||
|
||
def test_qa_with_pydantic_answer_model(): | ||
class Answer(pydantic.BaseModel): | ||
value: str | ||
certainty: float = pydantic.Field( | ||
description="A value between 0 and 1 indicating the model's confidence in the answer." | ||
) | ||
comments: List[str] = pydantic.Field( | ||
description="At least two comments providing additional details about the answer." | ||
) | ||
|
||
class QA(dspy.Signature): | ||
question: str = dspy.InputField() | ||
answer: Answer = dspy.OutputField() | ||
|
||
program = dspy.Predict(QA) | ||
answer = program(question="What is the capital of France?").answer | ||
|
||
assert_program_output_correct( | ||
program_output=answer.value, | ||
grading_guidelines="The answer should be Paris. Answer should not contain extraneous information.", | ||
) | ||
assert_program_output_correct( | ||
program_output=answer.comments, grading_guidelines="The comments should be relevant to the answer" | ||
) | ||
assert answer.certainty >= 0 | ||
assert answer.certainty <= 1 | ||
assert len(answer.comments) >= 2 | ||
|
||
|
||
def test_color_classification_using_enum(): | ||
Color = Enum("Color", ["RED", "GREEN", "BLUE"]) | ||
|
||
class Colorful(dspy.Signature): | ||
text: str = dspy.InputField() | ||
color: Color = dspy.OutputField() | ||
|
||
program = dspy.Predict(Colorful) | ||
color = program(text="The sky is blue").color | ||
|
||
assert color == Color.BLUE | ||
|
||
|
||
def test_entity_extraction_with_multiple_primitive_outputs(): | ||
class ExtractEntityFromDescriptionOutput(pydantic.BaseModel): | ||
entity_hu: str = pydantic.Field(description="The extracted entity in Hungarian, cleaned and lowercased.") | ||
entity_en: str = pydantic.Field(description="The English translation of the extracted Hungarian entity.") | ||
is_inverted: bool = pydantic.Field( | ||
description="Boolean flag indicating if the input is connected in an inverted way." | ||
) | ||
categories: str = pydantic.Field(description="English categories separated by '|' to which the entity belongs.") | ||
review: bool = pydantic.Field( | ||
description="Boolean flag indicating low confidence or uncertainty in the extraction." | ||
) | ||
|
||
class ExtractEntityFromDescription(dspy.Signature): | ||
"""Extract an entity from a Hungarian description, provide its English translation, categories, and an inverted flag.""" | ||
|
||
description: str = dspy.InputField(description="The input description in Hungarian.") | ||
entity: ExtractEntityFromDescriptionOutput = dspy.OutputField( | ||
description="The extracted entity and its properties." | ||
) | ||
|
||
program = dspy.ChainOfThought(ExtractEntityFromDescription) | ||
|
||
extracted_entity = program(description="A kávé egy növényi eredetű ital, amelyet a kávébabból készítenek.").entity | ||
assert_program_output_correct( | ||
program_output=extracted_entity.entity_hu, | ||
grading_guidelines="The translation of the text into English should be equivalent to 'coffee'", | ||
) | ||
assert_program_output_correct( | ||
program_output=extracted_entity.entity_hu, | ||
grading_guidelines="The text should be equivalent to 'coffee'", | ||
) | ||
assert_program_output_correct( | ||
program_output=extracted_entity.categories, | ||
grading_guidelines=( | ||
"The text should contain English language categories that apply to the word 'coffee'." | ||
" The categories should be separated by the character '|'." | ||
), | ||
) |
Oops, something went wrong.