Skip to content

Commit

Permalink
Switch settings from contextvar to thread local storage (for Colab) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
okhat authored and isaacbmiller committed Dec 11, 2024
1 parent 6ec4a76 commit e28358e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 33 deletions.
50 changes: 29 additions & 21 deletions dsp/utils/settings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import copy
import threading

from contextlib import contextmanager
from contextvars import ContextVar
from dsp.utils.utils import dotdict

DEFAULT_CONFIG = dotdict(
Expand Down Expand Up @@ -31,8 +29,14 @@
# Global base configuration
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)

# Initialize the context variable with an empty dict as default
dspy_ctx_overrides = ContextVar('dspy_ctx_overrides', default=dotdict())

class ThreadLocalOverrides(threading.local):
    """Per-thread holder for settings overrides.

    Because this subclasses ``threading.local``, ``__init__`` runs separately
    in every thread that touches an instance, so each thread gets its own
    ``overrides`` mapping that starts out empty.
    """

    def __init__(self):
        # Each thread begins with no overrides of its own.
        self.overrides = dotdict()


# Create the module-level thread-local storage; each thread sees its own `overrides`
thread_local_overrides = ThreadLocalOverrides()


class Settings:
Expand All @@ -53,7 +57,7 @@ def __new__(cls):
return cls._instance

def __getattr__(self, name):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
if name in overrides:
return overrides[name]
elif name in main_thread_config:
Expand All @@ -76,7 +80,7 @@ def __setitem__(self, key, value):
self.__setattr__(key, value)

def __contains__(self, key):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
return key in overrides or key in main_thread_config

def get(self, key, default=None):
Expand All @@ -86,45 +90,49 @@ def get(self, key, default=None):
return default

def copy(self):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
return dotdict({**main_thread_config, **overrides})

@property
def config(self):
config = self.copy()
del config['lock']
if 'lock' in config:
del config['lock']
return config

# Configuration methods

def configure(self, return_token=False, **kwargs):
def configure(self, **kwargs):
global main_thread_config
overrides = dspy_ctx_overrides.get()
new_overrides = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs})
token = dspy_ctx_overrides.set(new_overrides)

# Get or initialize thread-local overrides
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
thread_local_overrides.overrides = dotdict(
{**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}
)

# Update main_thread_config, in the main thread only
if threading.current_thread() is threading.main_thread():
main_thread_config = new_overrides

if return_token:
return token
main_thread_config = thread_local_overrides.overrides

@contextmanager
def context(self, **kwargs):
"""Context manager for temporary configuration changes."""
token = self.configure(return_token=True, **kwargs)
global main_thread_config
original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
original_main_thread_config = main_thread_config.copy()

self.configure(**kwargs)
try:
yield
finally:
dspy_ctx_overrides.reset(token)
thread_local_overrides.overrides = original_overrides

if threading.current_thread() is threading.main_thread():
global main_thread_config
main_thread_config = dotdict({**copy.deepcopy(DEFAULT_CONFIG), **dspy_ctx_overrides.get()})
main_thread_config = original_main_thread_config

def __repr__(self):
overrides = dspy_ctx_overrides.get()
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
combined_config = {**main_thread_config, **overrides}
return repr(combined_config)

Expand Down
39 changes: 27 additions & 12 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
import threading
import traceback
import contextlib

from contextvars import copy_context
from tqdm.contrib.logging import logging_redirect_tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

logger = logging.getLogger(__name__)


class ParallelExecutor:
def __init__(
self,
Expand Down Expand Up @@ -80,10 +77,16 @@ def _execute_isolated_single_thread(self, function, data):
if self.cancel_jobs.is_set():
break

# Create an isolated context for each task
task_ctx = copy_context()
result = task_ctx.run(function, item)
results.append(result)
# Create an isolated context for each task using thread-local overrides
from dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = thread_local_overrides.overrides.copy()

try:
result = function(item)
results.append(result)
finally:
thread_local_overrides.overrides = original_overrides

if self.compare_results:
# Assumes score is the last element of the result tuple
Expand Down Expand Up @@ -137,18 +140,30 @@ def interrupt_handler(sig, frame):
# If not in the main thread, skip setting signal handlers
yield

def cancellable_function(index_item):
def cancellable_function(parent_overrides, index_item):
index, item = index_item
if self.cancel_jobs.is_set():
return index, job_cancelled
return index, function(item)

# Create an isolated context for each task using thread-local overrides
from dsp.utils.settings import thread_local_overrides
original_overrides = thread_local_overrides.overrides
thread_local_overrides.overrides = parent_overrides.copy()

try:
return index, function(item)
finally:
thread_local_overrides.overrides = original_overrides

with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
# Capture the parent thread's overrides
from dsp.utils.settings import thread_local_overrides
parent_overrides = thread_local_overrides.overrides.copy()

futures = {}
for pair in enumerate(data):
# Capture the context for each task
task_ctx = copy_context()
future = executor.submit(task_ctx.run, cancellable_function, pair)
# Pass the parent thread's overrides to each thread
future = executor.submit(cancellable_function, parent_overrides, pair)
futures[future] = pair

pbar = tqdm.tqdm(
Expand Down

0 comments on commit e28358e

Please sign in to comment.