From 74b19c8ac88575f93a76d56a9c8b1a544f715f8e Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 26 Nov 2024 09:20:38 -0800 Subject: [PATCH] Switch settings from contextvar to thread local storage (for Colab) (#1860) --- dsp/utils/settings.py | 50 ++++++++++++++++++++++---------------- dspy/utils/parallelizer.py | 39 ++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 33 deletions(-) diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 4ffbd23d9..16ae6f93b 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -1,8 +1,6 @@ import copy import threading - from contextlib import contextmanager -from contextvars import ContextVar from dsp.utils.utils import dotdict DEFAULT_CONFIG = dotdict( @@ -31,8 +29,14 @@ # Global base configuration main_thread_config = copy.deepcopy(DEFAULT_CONFIG) -# Initialize the context variable with an empty dict as default -dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict()) + +class ThreadLocalOverrides(threading.local): + def __init__(self): + self.overrides = dotdict() # Initialize thread-local overrides + + +# Create the thread-local storage +thread_local_overrides = ThreadLocalOverrides() class Settings: @@ -53,7 +57,7 @@ def __new__(cls): return cls._instance def __getattr__(self, name): - overrides = dspy_ctx_overrides.get() + overrides = getattr(thread_local_overrides, 'overrides', dotdict()) if name in overrides: return overrides[name] elif name in main_thread_config: @@ -76,7 +80,7 @@ def __setitem__(self, key, value): self.__setattr__(key, value) def __contains__(self, key): - overrides = dspy_ctx_overrides.get() + overrides = getattr(thread_local_overrides, 'overrides', dotdict()) return key in overrides or key in main_thread_config def get(self, key, default=None): @@ -86,45 +90,49 @@ def get(self, key, default=None): return default def copy(self): - overrides = dspy_ctx_overrides.get() + overrides = getattr(thread_local_overrides, 'overrides', dotdict()) return dotdict({**main_thread_config, **overrides}) @property def config(self): config = self.copy() - del config['lock'] + if 'lock' in config: + del config['lock'] return config # Configuration methods - def configure(self, return_token=False, **kwargs): + def configure(self, **kwargs): global main_thread_config - overrides = dspy_ctx_overrides.get() - new_overrides = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}) - token = dspy_ctx_overrides.set(new_overrides) + + # Get or initialize thread-local overrides + overrides = getattr(thread_local_overrides, 'overrides', dotdict()) + thread_local_overrides.overrides = dotdict( + {**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs} + ) # Update main_thread_config, in the main thread only if threading.current_thread() is threading.main_thread(): - main_thread_config = new_overrides - - if return_token: - return token + main_thread_config = thread_local_overrides.overrides @contextmanager def context(self, **kwargs): """Context manager for temporary configuration changes.""" - token = self.configure(return_token=True, **kwargs) + global main_thread_config + original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy() + original_main_thread_config = main_thread_config.copy() + + self.configure(**kwargs) try: yield finally: - dspy_ctx_overrides.reset(token) + thread_local_overrides.overrides = original_overrides if threading.current_thread() is threading.main_thread(): - global main_thread_config - main_thread_config = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **dspy_ctx_overrides.get()}) + main_thread_config = original_main_thread_config def __repr__(self): - overrides = dspy_ctx_overrides.get() + overrides = getattr(thread_local_overrides, 'overrides', dotdict()) combined_config = {**main_thread_config, **overrides} return repr(combined_config) diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py index c6b5f3d5f..f40ee98d4 100644 --- a/dspy/utils/parallelizer.py +++ b/dspy/utils/parallelizer.py @@ -5,14 +5,11 @@ import threading import traceback import contextlib - -from contextvars import copy_context from tqdm.contrib.logging import logging_redirect_tqdm from concurrent.futures import ThreadPoolExecutor, as_completed logger = logging.getLogger(__name__) - class ParallelExecutor: def __init__( self, @@ -80,10 +77,16 @@ def _execute_isolated_single_thread(self, function, data): if self.cancel_jobs.is_set(): break - # Create an isolated context for each task - task_ctx = copy_context() - result = task_ctx.run(function, item) - results.append(result) + # Create an isolated context for each task using thread-local overrides + from dsp.utils.settings import thread_local_overrides + original_overrides = thread_local_overrides.overrides + thread_local_overrides.overrides = thread_local_overrides.overrides.copy() + + try: + result = function(item) + results.append(result) + finally: + thread_local_overrides.overrides = original_overrides if self.compare_results: # Assumes score is the last element of the result tuple @@ -137,18 +140,30 @@ def interrupt_handler(sig, frame): # If not in the main thread, skip setting signal handlers yield - def cancellable_function(index_item): + def cancellable_function(parent_overrides, index_item): index, item = index_item if self.cancel_jobs.is_set(): return index, job_cancelled - return index, function(item) + + # Create an isolated context for each task using thread-local overrides + from dsp.utils.settings import thread_local_overrides + original_overrides = thread_local_overrides.overrides + thread_local_overrides.overrides = parent_overrides.copy() + + try: + return index, function(item) + finally: + thread_local_overrides.overrides = original_overrides with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager(): + # Capture the parent thread's overrides + from dsp.utils.settings import thread_local_overrides + parent_overrides = thread_local_overrides.overrides.copy() + futures = {} for pair in enumerate(data): - # Capture the context for each task - task_ctx = copy_context() - future = executor.submit(task_ctx.run, cancellable_function, pair) + # Pass the parent thread's overrides to each thread + future = executor.submit(cancellable_function, parent_overrides, pair) futures[future] = pair pbar = tqdm.tqdm(