From 7e102fe7933961f91dd29b5429af7049c2c719b5 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 17 Dec 2024 05:46:24 -0800 Subject: [PATCH] Cleanup and fixing some imports (#1949) --- dspy/__init__.py | 4 +- dspy/evaluate/__init__.py | 12 +- dspy/evaluate/auto_evaluation.py | 42 --- dspy/functional/__init__.py | 1 - dspy/functional/functional.py | 450 ------------------------------- dspy/primitives/box.py | 157 ----------- 6 files changed, 3 insertions(+), 663 deletions(-) delete mode 100644 dspy/functional/__init__.py delete mode 100644 dspy/functional/functional.py delete mode 100644 dspy/primitives/box.py diff --git a/dspy/__init__.py b/dspy/__init__.py index 4718af5cb..6f4428525 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -6,11 +6,9 @@ import dspy.retrievers -# Functional must be imported after primitives, predict and signatures -from dspy.functional import * # isort: skip from dspy.evaluate import Evaluate # isort: skip from dspy.clients import * # isort: skip -from dspy.adapters import * # isort: skip +from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, Image # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging from dspy.utils.asyncify import asyncify from dspy.utils.saving import load diff --git a/dspy/evaluate/__init__.py b/dspy/evaluate/__init__.py index 3e0f16fdd..c2336c3f9 100644 --- a/dspy/evaluate/__init__.py +++ b/dspy/evaluate/__init__.py @@ -1,13 +1,5 @@ from dspy.dsp.utils import EM, normalize_text -from dspy.evaluate import auto_evaluation +from dspy.evaluate.metrics import answer_exact_match, answer_passage_match from dspy.evaluate.evaluate import Evaluate -from dspy.evaluate import metrics - -__all__ = [ - "auto_evaluation", - "Evaluate", - "metrics", - "EM", - "normalize_text", -] +from dspy.evaluate.auto_evaluation import SemanticF1, CompleteAndGrounded diff --git a/dspy/evaluate/auto_evaluation.py b/dspy/evaluate/auto_evaluation.py index d98332143..001a9043e 100644 --- a/dspy/evaluate/auto_evaluation.py +++ b/dspy/evaluate/auto_evaluation.py @@ -98,45 +98,3 @@ def forward(self, example, pred, trace=None): score = f1_score(groundedness.groundedness, completeness.completeness) return score if trace is None else score >= self.threshold - - - -# """ -# Soon-to-be deprecated Signatures & Modules Below. -# """ - - -# class AnswerCorrectnessSignature(dspy.Signature): -# """Verify that the predicted answer matches the gold answer.""" - -# question = dspy.InputField() -# gold_answer = dspy.InputField(desc="correct answer for question") -# predicted_answer = dspy.InputField(desc="predicted answer for question") -# is_correct = dspy.OutputField(desc="True or False") - - -# class AnswerCorrectness(dspy.Module): -# def __init__(self): -# super().__init__() -# self.evaluate_correctness = dspy.ChainOfThought(AnswerCorrectnessSignature) - -# def forward(self, question, gold_answer, predicted_answer): -# return self.evaluate_correctness(question=question, gold_answer=gold_answer, predicted_answer=predicted_answer) - - -# class AnswerFaithfulnessSignature(dspy.Signature): -# """Verify that the predicted answer is based on the provided context.""" - -# context = dspy.InputField(desc="relevant facts for producing answer") -# question = dspy.InputField() -# answer = dspy.InputField(desc="often between 1 and 5 words") -# is_faithful = dspy.OutputField(desc="True or False") - - -# class AnswerFaithfulness(dspy.Module): -# def __init__(self): -# super().__init__() -# self.evaluate_faithfulness = dspy.ChainOfThought(AnswerFaithfulnessSignature) - -# def forward(self, context, question, answer): -# return self.evaluate_faithfulness(context=context, question=question, answer=answer) diff --git a/dspy/functional/__init__.py b/dspy/functional/__init__.py deleted file mode 100644 index a1ce1bb7f..000000000 --- a/dspy/functional/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# from .functional import FunctionalModule, TypedChainOfThought, TypedPredictor, cot, predictor diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py deleted file mode 100644 index 2c4ba3d1b..000000000 --- a/dspy/functional/functional.py +++ /dev/null @@ -1,450 +0,0 @@ -# import json -# import ujson -# import logging -# import inspect -# import typing -# import pydantic - -# from functools import lru_cache -# from pydantic.fields import FieldInfo -# from typing import Annotated, Callable, List, Tuple, Union # noqa: UP035 - -# import dspy -# from dspy.dsp.adapters import passages2text -# from dspy.primitives.prediction import Prediction -# from dspy.signatures.signature import ensure_signature, make_signature - -# @lru_cache(maxsize=None) -# def warn_once(msg: str): -# logging.warning(msg) - - -# def predictor(*args: tuple, **kwargs) -> Callable[..., dspy.Module]: -# def _predictor(func) -> dspy.Module: -# """Decorator that creates a predictor module based on the provided function.""" -# signature = _func_to_signature(func) -# *_, output_key = signature.output_fields.keys() -# return _StripOutput(TypedPredictor(signature, **kwargs), output_key) - -# # if we have only a single callable argument, the decorator was invoked with no key word arguments -# # so we just return the wrapped function -# if len(args) == 1 and callable(args[0]) and len(kwargs) == 0: -# return _predictor(args[0]) -# return _predictor - - -# def cot(*args: tuple, **kwargs) -> Callable[..., dspy.Module]: -# def _cot(func) -> dspy.Module: -# """Decorator that creates a chain of thought module based on the provided function.""" -# signature = _func_to_signature(func) -# *_, output_key = signature.output_fields.keys() -# return _StripOutput(TypedChainOfThought(signature, **kwargs), output_key) - -# # if we have only a single callable argument, the decorator was invoked with no key word arguments -# # so we just return the wrapped function -# if len(args) == 1 and callable(args[0]) and len(kwargs) == 0: -# return _cot(args[0]) -# return _cot - - -# class _StripOutput(dspy.Module): -# def __init__(self, predictor, output_key): -# super().__init__() -# self.predictor = predictor -# self.output_key = output_key - -# def copy(self): -# return _StripOutput(self.predictor.copy(), self.output_key) - -# def forward(self, **kwargs): -# prediction = self.predictor(**kwargs) -# return prediction[self.output_key] - - -# class FunctionalModule(dspy.Module): -# """To use the @cot and @predictor decorators, your module needs to inherit form this class.""" - -# def __init__(self): -# super().__init__() -# for name in dir(self): -# attr = getattr(self, name) -# if isinstance(attr, dspy.Module): -# self.__dict__[name] = attr.copy() - - -# def TypedChainOfThought(signature, instructions=None, reasoning=None, *, max_retries=3) -> dspy.Module: # noqa: N802 -# """Just like TypedPredictor, but adds a ChainOfThought OutputField.""" -# signature = ensure_signature(signature, instructions) -# output_keys = ", ".join(signature.output_fields.keys()) - -# default_rationale = dspy.OutputField( -# prefix="Reasoning: Let's think step by step in order to", -# desc="${produce the " + output_keys + "}. We ...", -# ) -# reasoning = reasoning or default_rationale - -# return TypedPredictor( -# signature.prepend( -# "reasoning", -# reasoning, -# ), -# max_retries=max_retries, -# ) - - -# class TypedPredictor(dspy.Module): -# def __init__(self, signature, instructions=None, *, max_retries=3, wrap_json=False, explain_errors=False): -# """Like dspy.Predict, but enforces type annotations in the signature. - -# Args: -# signature: The signature of the module. Can use type annotations. -# instructions: A description of what the model should do. -# max_retries: The number of times to retry the prediction if the output is invalid. -# wrap_json: If True, json objects in the input will be wrapped in ```json ... ``` -# explain_errors: If True, the model will try to explain the errors it encounters. -# """ -# super().__init__() - -# # Warn: deprecation warning. -# warn_once( -# "\t*** Since DSPy 2.5.16+, TypedPredictors are now deprecated, underperform, and are about to be removed! ***\n" -# "Please use standard predictors, e.g. dspy.Predict and dspy.ChainOfThought.\n" -# "They now support type annotations and other features of TypedPredictors and " -# "tend to work much better out of the box.\n" -# "Please let us know if you face any issues: https://github.com/stanfordnlp/dspy/issues" -# ) - -# signature = ensure_signature(signature, instructions) -# self.predictor = dspy.Predict(signature, _parse_values=False) -# self.max_retries = max_retries -# self.wrap_json = wrap_json -# self.explain_errors = explain_errors - -# @property -# def signature(self) -> dspy.Signature: -# return self.predictor.signature - -# @signature.setter -# def signature(self, value: dspy.Signature): -# self.predictor.signature = value - -# def copy(self) -> "TypedPredictor": -# return TypedPredictor( -# self.signature, -# max_retries=self.max_retries, -# wrap_json=self.wrap_json, -# explain_errors=self.explain_errors, -# ) - -# def __repr__(self): -# """Return a string representation of the TypedPredictor object.""" -# return f"TypedPredictor({self.signature})" - -# def _make_example(self, field) -> str: -# # Note: DSPy will cache this call so we only pay the first time TypedPredictor is called. -# if hasattr(field, "model_json_schema"): -# pass -# schema = field.json_schema_extra["schema"] -# parser = field.json_schema_extra["parser"] -# if self.wrap_json: -# schema = "```json\n" + schema + "\n```\n" -# json_object = dspy.Predict( -# make_signature( -# "json_schema -> json_object", -# "Make a very succinct json object that validates with the following schema", -# ), -# _parse_values=False, -# )(json_schema=schema).json_object -# # We use the parser to make sure the json object is valid. -# try: -# parser(_unwrap_json(json_object, parser)) -# except (pydantic.ValidationError, ValueError): -# return "" # Unable to make an example -# return json_object -# # TODO: Another fun idea is to only (but automatically) do this if the output fails. -# # We could also have a more general "suggest solution" prompt that tries to fix the output -# # More directly. -# # TODO: Instead of using a language model to create the example, we can also just use a -# # library like https://pypi.org/project/polyfactory/ that's made exactly to do this. - -# def _format_error( -# self, -# error: Exception, -# task_description: Union[str, FieldInfo], -# model_output: str, -# lm_explain: bool, -# ) -> str: -# if isinstance(error, pydantic.ValidationError): -# errors = [] -# for e in error.errors(): -# fields = ", ".join(map(str, e["loc"])) -# errors.append(f"{e['msg']}: {fields} (error type: {e['type']})") -# error_text = "; ".join(errors) -# else: -# error_text = repr(error) - -# if self.explain_errors and lm_explain: -# if isinstance(task_description, FieldInfo): -# args = task_description.json_schema_extra -# task_description = args["prefix"] + " " + args["desc"] -# return ( -# error_text -# + "\n" -# + self._make_explanation( -# task_description=task_description, -# model_output=model_output, -# error=error_text, -# ) -# ) - -# return error_text - -# def _make_explanation(self, task_description: str, model_output: str, error: str) -> str: -# class Signature(dspy.Signature): -# """I gave my language model a task, but it failed. - -# Figure out what went wrong, and write instructions to help it avoid the error next time. -# """ - -# task_description: str = dspy.InputField(desc="What I asked the model to do") -# language_model_output: str = dspy.InputField(desc="The output of the model") -# error: str = dspy.InputField(desc="The validation error triggered by the models output") -# explanation: str = dspy.OutputField(desc="Explain what the model did wrong") -# advice: str = dspy.OutputField( -# desc="Instructions for the model to do better next time. A single paragraph.", -# ) - -# # TODO: We could also try repair the output here. For example, if the output is a float, but the -# # model returned a "float + explanation", the repair could be to remove the explanation. - -# return dspy.Predict(Signature)( -# task_description=task_description, -# language_model_output=model_output, -# error=error, -# _parse_values=False, -# ).advice - -# def _prepare_signature(self) -> dspy.Signature: -# """Add formats and parsers to the signature fields, based on the type annotations of the fields.""" -# signature = self.signature -# for name, field in self.signature.fields.items(): -# is_output = field.json_schema_extra["__dspy_field_type"] == "output" -# type_ = field.annotation -# if is_output: -# if type_ is bool: - -# def parse(x): -# x = x.strip().lower() -# if x not in ("true", "false"): -# raise ValueError("Respond with true or false") -# return x == "true" - -# signature = signature.with_updated_fields( -# name, -# desc=field.json_schema_extra.get("desc", "") -# + (" (Respond with true or false)" if type_ != str else ""), -# format=lambda x: x if isinstance(x, str) else str(x), -# parser=parse, -# ) -# elif type_ in (str, int, float): -# signature = signature.with_updated_fields( -# name, -# desc=field.json_schema_extra.get("desc", "") -# + (f" (Respond with a single {type_.__name__} value)" if type_ != str else ""), -# format=lambda x: x if isinstance(x, str) else str(x), -# parser=type_, -# ) -# else: -# # Anything else we wrap in a pydantic object -# if ( -# inspect.isclass(type_) -# and typing.get_origin(type_) not in (list, tuple) # To support Python 3.9 -# and issubclass(type_, pydantic.BaseModel) -# ): -# def to_json(x): -# return x.model_dump_json() -# def from_json(x, type_=type_): -# return type_.model_validate_json(x) -# schema = json.dumps(type_.model_json_schema()) -# else: -# adapter = pydantic.TypeAdapter(type_) -# def to_json(x): -# return adapter.serializer.to_json(x) -# def from_json(x, type_=adapter): -# return type_.validate_json(x) -# schema = json.dumps(adapter.json_schema()) -# if self.wrap_json: -# def to_json(x, inner=to_json): -# return "```json\n" + inner(x) + "\n```\n" -# schema = "```json\n" + schema + "\n```" -# signature = signature.with_updated_fields( -# name, -# desc=field.json_schema_extra.get("desc", "") -# + (". Respond with a single JSON object. JSON Schema: " + schema), -# format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)), -# parser=lambda x, from_json=from_json: from_json(_unwrap_json(x, from_json)), -# schema=schema, -# type_=type_, -# ) -# else: # If input field -# is_json = False -# def format_(x): -# return x if isinstance(x, str) else str(x) -# if type_ in (List[str], list[str], Tuple[str], tuple[str]): -# format_ = passages2text -# # Special formatting for lists of known types. Maybe the output fields sohuld have this too? -# elif typing.get_origin(type_) in (List, list, Tuple, tuple): -# (inner_type,) = typing.get_args(type_) -# if inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel): -# def format_(x): -# return x if isinstance(x, str) else "[" + ",".join(i.model_dump_json() for i in x) + "]" -# else: -# def format_(x): -# return x if isinstance(x, str) else json.dumps(x) -# is_json = True -# elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): -# def format_(x): -# return x if isinstance(x, str) else x.model_dump_json() -# is_json = True -# if self.wrap_json and is_json: -# def format_(x, inner=format_): -# return x if isinstance(x, str) else "```json\n" + inner(x) + "\n```\n" -# signature = signature.with_updated_fields(name, format=format_) - -# return signature - -# def forward(self, **kwargs) -> dspy.Prediction: -# modified_kwargs = kwargs.copy() -# # We have to re-prepare the signature on every forward call, because the base -# # signature might have been modified by an optimizer or something like that. -# signature = self._prepare_signature() -# for try_i in range(self.max_retries): -# result = self.predictor(**modified_kwargs, new_signature=signature) -# errors = {} -# parsed_results = [] -# # Parse the outputs -# for completion in result.completions: -# parsed = {} -# for name, field in signature.output_fields.items(): -# try: -# value = completion[name] -# parser = field.json_schema_extra.get("parser", lambda x: x) -# parsed[name] = parser(value) -# except (pydantic.ValidationError, ValueError) as e: -# errors[name] = self._format_error( -# e, -# signature.fields[name], -# value, -# lm_explain=try_i + 1 < self.max_retries, -# ) - -# # If we can, we add an example to the error message -# current_desc = field.json_schema_extra.get("desc", "") -# i = current_desc.find("JSON Schema: ") -# if i == -1: -# continue # Only add examples to JSON objects -# suffix, current_desc = current_desc[i:], current_desc[:i] -# prefix = "You MUST use this format: " -# if ( -# try_i + 1 < self.max_retries -# and prefix not in current_desc -# and (example := self._make_example(field)) -# ): -# signature = signature.with_updated_fields( -# name, -# desc=current_desc + "\n" + prefix + example + "\n" + suffix, -# ) -# # No reason trying to parse the general signature, or run more completions, if we already have errors -# if errors: -# break -# # Instantiate the actual signature with the parsed values. -# # This allow pydantic to validate the fields defined in the signature. -# try: -# _ = self.signature(**kwargs, **parsed) -# parsed_results.append(parsed) -# except pydantic.ValidationError as e: -# errors["general"] = self._format_error( -# e, -# signature.instructions, -# "\n\n".join( -# "> " + field.json_schema_extra["prefix"] + " " + completion[name] -# for name, field in signature.output_fields.items() -# ), -# lm_explain=try_i + 1 < self.max_retries, -# ) -# if errors: -# # Add new fields for each error -# for name, error in errors.items(): -# modified_kwargs[f"error_{name}_{try_i}"] = error -# if name == "general": -# error_prefix = "General:" -# else: -# error_prefix = signature.output_fields[name].json_schema_extra["prefix"] -# number = "" if try_i == 0 else f" ({try_i+1})" -# signature = signature.append( -# f"error_{name}_{try_i}", -# dspy.InputField( -# prefix=f"Past Error{number} in {error_prefix}", -# desc="An error to avoid in the future", -# ), -# ) -# else: -# # If there are no errors, we return the parsed results -# return Prediction.from_completions( -# {key: [r[key] for r in parsed_results] for key in signature.output_fields}, -# ) -# raise ValueError( -# "Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", -# errors, -# ) - - -# def _func_to_signature(func): -# """Make a dspy.Signature based on a function definition.""" -# sig = inspect.signature(func) -# annotations = typing.get_type_hints(func, include_extras=True) -# output_key = func.__name__ -# instructions = func.__doc__ -# fields = {} - -# # Input fields -# for param in sig.parameters.values(): -# if param.name == "self": -# continue -# # We default to str as the type of the input -# annotation = annotations.get(param.name, str) -# kwargs = {} -# if typing.get_origin(annotation) is Annotated: -# desc = next((arg for arg in typing.get_args(annotation) if isinstance(arg, str)), None) -# if desc is not None: -# kwargs["desc"] = desc -# fields[param.name] = (annotation, dspy.InputField(**kwargs)) - -# # Output field -# kwargs = {} -# annotation = annotations.get("return", str) -# if typing.get_origin(annotation) is Annotated: -# desc = next((arg for arg in typing.get_args(annotation) if isinstance(arg, str)), None) -# if desc is not None: -# kwargs["desc"] = desc -# fields[output_key] = (annotation, dspy.OutputField(**kwargs)) - -# return dspy.Signature(fields, instructions) - - -# def _unwrap_json(output, from_json: Callable[[str], Union[pydantic.BaseModel, str, None]]): -# try: -# parsing_result = from_json(output) -# if isinstance(parsing_result, pydantic.BaseModel): -# return parsing_result.model_dump_json() -# else: -# return output -# except (ValueError, pydantic.ValidationError, AttributeError): -# output = output.strip() -# if output.startswith("```"): -# if not output.startswith("```json"): -# raise ValueError("json output should start with ```json") from None -# if not output.endswith("```"): -# raise ValueError("Don't write anything after the final json ```") from None -# output = output[7:-3].strip() -# return ujson.dumps(ujson.loads(output)) # ujson is a bit more robust than the standard json \ No newline at end of file diff --git a/dspy/primitives/box.py b/dspy/primitives/box.py deleted file mode 100644 index db1b51d8f..000000000 --- a/dspy/primitives/box.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -TODO: If we want to have Prediction::{**keys, completions, box} where box.{key} will behave as a value but also include -the completions internally. - -The main thing left is to determine the semantic (and then implement them) for applying operations on the object. - -If we have a string (query) and completions (5 queries), and we modify the string, what happens to the completions? - -- Option 1: We modify the string and the (other) completions are not affected. -- Option 2: We modify the string and the (other) completions too. -- Option 3: We modify the string and the (other) completions are deleted. - -Option 2 seems most reasonable, but it depends on what metadata the box is going to store. - -It seems that a box fundamentally has two functions then: -- Store a value and its "alternatives" (and hence allow transparent application over operations on value/alts) - - But not all operations can/should be applied on all as a map. - - I guess mainly "give me a string or list or dict or tuple or int or float" has to commit to a value. - - There also needs to be a .item(). -- Potentially track the "source" of the value (i.e., the predictor that generated it, and its inputs) -- Give the value (eventually) to something that will consume the main value (implicitly or explicitly) or all/some of its alternatives explicitly. - -It might be wise to make this responsible for a smaller scope for now: - -- Just one string (and its alternatives). -- No source tracking. -- Allow operations on the string to map over the alternatives. -- Seamless extraction at code boundaries. - - Basically, code will either treat this as string implicitly - (and hence only know about the one value, and on best effort basis we update the alternatives) - - Or code will explicitly work with the string or explicitly work with the full set of alternatives. - -- By default, all programs (and their sub-programs) will be running inside a context in which preserve_boxes=True. -- But outside the program, once we see that none of the parent contexts have preserve_boxes=True, we can automatically - unpack all boxes before returning to user. - -Okay, so we'll have predictors return a `pred` in which `pred.query` is a box. - -You'd usually do one of: - -### Things that just give you one string - 1- Print `pred.query` or save it in a dict somewhere or a file somewhere. - 2- Call `pred.query.item()` to get the string explicitly. - 3- Modifications in freeform Python. - - Modify it by calling `pred.query = 'new query'` altogether. - - Modify it by doing `pred.query += 'new query'` or templating `f'{pred.query} new query'`. - - Other modifications are not allowed on strings (e.g., `pred.query[0] = 'a'` or `pred.query[0] += 'a'`). - - Cast to boolean after a comparison: `if pred.query == 'something': ...` - - Pytorch would say RuntimeError: Boolean value of Tensor with more than one value is ambiguous - - But we can keep the primary value and use that in the boolean. - - So technically, comparison can stick around, giving you multiple internal bools. - -Overall, I think it's coherent semantics, for the time being, to say that any of the above will just give you a string back and lose all tracking. - - -### Things that give you a list of strings - 1- Explicitly asking for the candidates/completions. - 2- Then you could filter or map that list arbitrarily. - -In this case, it's just that Box will serve as syntactic sugar. If you don't want to think about `n` at all, you can -pretend you have a string. If you do anything arbitrary on it, it indeed becomes a string. -If you later decide to treat it as a list, it's easy to do so without losing that info when you say `pred.query`. - -### Things that are more interesting - -A) You can now pass pred.query to a DSPy predictor (or searcher, etc) and it can either naively work with the string, -like pass it to a template, or it can explicitly ask for the list of candidates and do something with that. - -This will need a lot more string-specific operations though: -- endswith, startswith, contains, split, strip, lower, upper, etc. -- when doing ' '.join() must do map(str, values_to_join). No implicit __str__ conversion! -- We can probably automate this by having a general fallback? That either returns one value or maps that over all of them. - -B) When you say dspy.assert pred.sentence1.endswith('blue'), it will actually check all the alternatives and locally filter them if possible. -It may keep the bad ones somewhere just in case too. - -We could make this a little more explicit like dspy.assert(pred.sentence1, lambda x: x.endswith('blue')) - -C) When program_temperature is high, we can actually have some more interesting logic here. When you try to do things that are "selective", -maybe we'll randomly give you one of the strings (that remain valid in the box, based on assertions). - -This could lead to interesting efficiency, because we can basically rerun the program, it'll still generate n=10 candidates, -but each time it'll use a different one. So when branch_index changes, you get a new candidate each time, but it should be consistent in the same box. -I guess this could be done by shuffling the same N=10 things. So basically, there's a user N and there's a system-level M. - -We can sometimes optimize things by doing M=5. So we'll generate Nx5 candidates in one or more calls (depending on value of Nx5). -Then depending on the branch_idx, we'll return a fixed set of N candidates. But we have the rest. - -""" - - -class BoxType(type): - # List of operations to override - ops = [ - # Arithmetic operations - 'add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow', - 'lshift', 'rshift', 'and', 'or', 'xor', - # 'r'-prefixed versions of arithmetic operations - 'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rmod', - 'rpow', 'rlshift', 'rrshift', 'rand', 'ror', 'rxor', - # Sequence operations - 'getitem', 'setitem', 'delitem', 'contains', - # Unary and other operations - 'neg', 'pos', 'abs', 'invert', 'round', 'len', - 'getitem', 'setitem', 'delitem', 'contains', 'iter', - # Mappings operations (for dicts) - 'get', 'keys', 'values', 'items', - # Comparison - 'eq', 'ne', 'lt', 'le', 'gt', 'ge', - ] - - def __init__(cls, name, bases, attrs): - def create_method(op): - def method(self, other=None): - if op in ['len', 'keys', 'values', 'items']: - return getattr(self._value, op)() - elif isinstance(other, Box): - return Box(getattr(self._value, f'__{op}__')(other._value)) - elif other is not None: - return Box(getattr(self._value, f'__{op}__')(other)) - else: - return NotImplemented - return method - - for op in BoxType.ops: - setattr(cls, f'__{op}__', create_method(op)) - - super().__init__(name, bases, attrs) - - -class Box(metaclass=BoxType): - def __init__(self, value, source=False): - self._value = value - self._source = source - - def __repr__(self): - return repr(self._value) - - def __str__(self): - return str(self._value) - - def __bool__(self): - return bool(self._value) - - # if method is missing just call it on the _value - def __getattr__(self, name): - return Box(getattr(self._value, name)) - - # # Unlike the others, this one collapses to a bool directly - # def __eq__(self, other): - # if isinstance(other, Box): - # return self._value == other._value - # else: - # return self._value == other - - # def __ne__(self, other): - # return not self.__eq__(other)