From e0055cabd591a87ffac1f3d54808a63e311471fc Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Tue, 17 Dec 2024 18:35:51 -0800 Subject: [PATCH] make version check loose (#1946) --- dspy/primitives/module.py | 4 +-- dspy/utils/saving.py | 5 +-- tests/primitives/test_module.py | 55 +++++++++++++++++++++++++++++++++ tests/utils/test_saving.py | 55 +++++++++++++++++++++++++++++++++ 4 files changed, 114 insertions(+), 5 deletions(-) diff --git a/dspy/primitives/module.py b/dspy/primitives/module.py index 8bec91620..6fba01eb1 100644 --- a/dspy/primitives/module.py +++ b/dspy/primitives/module.py @@ -219,9 +219,7 @@ def save(self, path, save_program=False): with open(path, "wb") as f: cloudpickle.dump(state, f) else: - raise ValueError( - f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}" - ) + raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}") def load(self, path): """Load the saved module. You may also want to check out dspy.load, if you want to diff --git a/dspy/utils/saving.py b/dspy/utils/saving.py index 95659a72d..283e5dce6 100644 --- a/dspy/utils/saving.py +++ b/dspy/utils/saving.py @@ -10,10 +10,11 @@ def get_dependency_versions(): + cloudpickle_version = '.'.join(cloudpickle.__version__.split('.')[:2]) return { - "python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", + "python": f"{sys.version_info.major}.{sys.version_info.minor}", "dspy": importlib_metadata.version("dspy"), - "cloudpickle": cloudpickle.__version__, + "cloudpickle": cloudpickle_version, } diff --git a/tests/primitives/test_module.py b/tests/primitives/test_module.py index d0532fca1..1d458eddb 100644 --- a/tests/primitives/test_module.py +++ b/tests/primitives/test_module.py @@ -1,6 +1,8 @@ import dspy import threading from dspy.utils.dummies import DummyLM +import logging +from unittest.mock import patch def test_deepcopy_basic(): @@ -106,3 +108,56 @@ def dummy_metric(example, pred, trace=None): assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature) assert new_cot.predict.demos == compiled_cot.predict.demos + + +def test_load_with_version_mismatch(tmp_path): + from dspy.primitives.module import logger + + # Mock versions during save + save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"} + + # Mock versions during load + load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"} + + predict = dspy.Predict("question->answer") + + # Create a custom handler to capture log messages + class ListHandler(logging.Handler): + def __init__(self): + super().__init__() + self.messages = [] + + def emit(self, record): + self.messages.append(record.getMessage()) + + # Add handler and set level + handler = ListHandler() + original_level = logger.level + logger.addHandler(handler) + logger.setLevel(logging.WARNING) + + try: + save_path = tmp_path / "program.pkl" + # Mock version during save + with patch("dspy.primitives.module.get_dependency_versions", return_value=save_versions): + predict.save(save_path) + + # Mock version during load + with patch("dspy.primitives.module.get_dependency_versions", return_value=load_versions): + loaded_predict = dspy.Predict("question->answer") + loaded_predict.load(save_path) + + # Assert warnings were logged, and one warning for each mismatched dependency. + assert len(handler.messages) == 3 + + for msg in handler.messages: + assert "There is a mismatch of" in msg + + # Verify the model still loads correctly despite version mismatches + assert isinstance(loaded_predict, dspy.Predict) + assert str(predict.signature) == str(loaded_predict.signature) + + finally: + # Clean up: restore original level and remove handler + logger.setLevel(original_level) + logger.removeHandler(handler) diff --git a/tests/utils/test_saving.py b/tests/utils/test_saving.py index 6a0089946..6b28c942c 100644 --- a/tests/utils/test_saving.py +++ b/tests/utils/test_saving.py @@ -1,5 +1,9 @@ import dspy from dspy.utils import DummyLM +from unittest.mock import patch +import pytest +from dspy.utils.saving import get_dependency_versions +import logging def test_save_predict(tmp_path): @@ -74,3 +78,54 @@ def dummy_metric(example, pred, trace=None): loaded_predict = dspy.load(tmp_path) assert compiled_predict.demos == loaded_predict.demos assert compiled_predict.signature == loaded_predict.signature + + +def test_load_with_version_mismatch(tmp_path): + from dspy.utils.saving import logger + + # Mock versions during save + save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"} + + # Mock versions during load + load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"} + + predict = dspy.Predict("question->answer") + + # Create a custom handler to capture log messages + class ListHandler(logging.Handler): + def __init__(self): + super().__init__() + self.messages = [] + + def emit(self, record): + self.messages.append(record.getMessage()) + + # Add handler and set level + handler = ListHandler() + original_level = logger.level + logger.addHandler(handler) + logger.setLevel(logging.WARNING) + + try: + # Mock version during save + with patch("dspy.utils.saving.get_dependency_versions", return_value=save_versions): + predict.save(tmp_path, save_program=True) + + # Mock version during load + with patch("dspy.utils.saving.get_dependency_versions", return_value=load_versions): + loaded_predict = dspy.load(tmp_path) + + # Assert warnings were logged, and one warning for each mismatched dependency. + assert len(handler.messages) == 3 + + for msg in handler.messages: + assert "There is a mismatch of" in msg + + # Verify the model still loads correctly despite version mismatches + assert isinstance(loaded_predict, dspy.Predict) + assert predict.signature == loaded_predict.signature + + finally: + # Clean up: restore original level and remove handler + logger.setLevel(original_level) + logger.removeHandler(handler)