Skip to content

Commit

Permalink
make version check loose (#1946)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub authored Dec 18, 2024
1 parent 422973a commit e0055ca
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 5 deletions.
4 changes: 1 addition & 3 deletions dspy/primitives/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,7 @@ def save(self, path, save_program=False):
with open(path, "wb") as f:
cloudpickle.dump(state, f)
else:
raise ValueError(
f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}"
)
raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")

def load(self, path):
"""Load the saved module. You may also want to check out dspy.load, if you want to
Expand Down
5 changes: 3 additions & 2 deletions dspy/utils/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@


def get_dependency_versions():
cloudpickle_version = '.'.join(cloudpickle.__version__.split('.')[:2])
return {
"python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
"python": f"{sys.version_info.major}.{sys.version_info.minor}",
"dspy": importlib_metadata.version("dspy"),
"cloudpickle": cloudpickle.__version__,
"cloudpickle": cloudpickle_version,
}


Expand Down
55 changes: 55 additions & 0 deletions tests/primitives/test_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import dspy
import threading
from dspy.utils.dummies import DummyLM
import logging
from unittest.mock import patch


def test_deepcopy_basic():
Expand Down Expand Up @@ -106,3 +108,56 @@ def dummy_metric(example, pred, trace=None):

assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature)
assert new_cot.predict.demos == compiled_cot.predict.demos


def test_load_with_version_mismatch(tmp_path):
from dspy.primitives.module import logger

# Mock versions during save
save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"}

# Mock versions during load
load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"}

predict = dspy.Predict("question->answer")

# Create a custom handler to capture log messages
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.messages = []

def emit(self, record):
self.messages.append(record.getMessage())

# Add handler and set level
handler = ListHandler()
original_level = logger.level
logger.addHandler(handler)
logger.setLevel(logging.WARNING)

try:
save_path = tmp_path / "program.pkl"
# Mock version during save
with patch("dspy.primitives.module.get_dependency_versions", return_value=save_versions):
predict.save(save_path)

# Mock version during load
with patch("dspy.primitives.module.get_dependency_versions", return_value=load_versions):
loaded_predict = dspy.Predict("question->answer")
loaded_predict.load(save_path)

# Assert warnings were logged, and one warning for each mismatched dependency.
assert len(handler.messages) == 3

for msg in handler.messages:
assert "There is a mismatch of" in msg

# Verify the model still loads correctly despite version mismatches
assert isinstance(loaded_predict, dspy.Predict)
assert str(predict.signature) == str(loaded_predict.signature)

finally:
# Clean up: restore original level and remove handler
logger.setLevel(original_level)
logger.removeHandler(handler)
55 changes: 55 additions & 0 deletions tests/utils/test_saving.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import dspy
from dspy.utils import DummyLM
from unittest.mock import patch
import pytest
from dspy.utils.saving import get_dependency_versions
import logging


def test_save_predict(tmp_path):
Expand Down Expand Up @@ -74,3 +78,54 @@ def dummy_metric(example, pred, trace=None):
loaded_predict = dspy.load(tmp_path)
assert compiled_predict.demos == loaded_predict.demos
assert compiled_predict.signature == loaded_predict.signature


def test_load_with_version_mismatch(tmp_path):
from dspy.utils.saving import logger

# Mock versions during save
save_versions = {"python": "3.9", "dspy": "2.4.0", "cloudpickle": "2.0"}

# Mock versions during load
load_versions = {"python": "3.10", "dspy": "2.5.0", "cloudpickle": "2.1"}

predict = dspy.Predict("question->answer")

# Create a custom handler to capture log messages
class ListHandler(logging.Handler):
def __init__(self):
super().__init__()
self.messages = []

def emit(self, record):
self.messages.append(record.getMessage())

# Add handler and set level
handler = ListHandler()
original_level = logger.level
logger.addHandler(handler)
logger.setLevel(logging.WARNING)

try:
# Mock version during save
with patch("dspy.utils.saving.get_dependency_versions", return_value=save_versions):
predict.save(tmp_path, save_program=True)

# Mock version during load
with patch("dspy.utils.saving.get_dependency_versions", return_value=load_versions):
loaded_predict = dspy.load(tmp_path)

# Assert warnings were logged, and one warning for each mismatched dependency.
assert len(handler.messages) == 3

for msg in handler.messages:
assert "There is a mismatch of" in msg

# Verify the model still loads correctly despite version mismatches
assert isinstance(loaded_predict, dspy.Predict)
assert predict.signature == loaded_predict.signature

finally:
# Clean up: restore original level and remove handler
logger.setLevel(original_level)
logger.removeHandler(handler)

0 comments on commit e0055ca

Please sign in to comment.