From 1fee93174de709f07211967cf97850d6de95c51a Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 13 Apr 2026 20:54:48 +0100 Subject: [PATCH] feat: add background thread for on-the-fly visualization during sampling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Run perform_quick_update on a daemon thread so the sampler is not blocked while matplotlib renders plots. Uses a latest-only queue pattern — if a new best-fit arrives before the previous visualization finishes, the stale request is silently replaced. JAX arrays are conditionally converted to numpy before handing to the worker thread. Co-Authored-By: Claude Opus 4.6 --- autofit/config/general.yaml | 3 +- autofit/non_linear/analysis/analysis.py | 10 ++ autofit/non_linear/fitness.py | 34 +++- autofit/non_linear/quick_update.py | 107 ++++++++++++ autofit/non_linear/search/abstract_search.py | 9 + .../non_linear/search/nest/nautilus/search.py | 4 +- test_autofit/non_linear/test_quick_update.py | 163 ++++++++++++++++++ 7 files changed, 324 insertions(+), 6 deletions(-) create mode 100644 autofit/non_linear/quick_update.py create mode 100644 test_autofit/non_linear/test_quick_update.py diff --git a/autofit/config/general.yaml b/autofit/config/general.yaml index 4130d5c93..739bb0420 100644 --- a/autofit/config/general.yaml +++ b/autofit/config/general.yaml @@ -1,6 +1,7 @@ updates: iterations_per_quick_update: 1e99 # Non-linear search iterations between every quick update, which just displays the maximum likelihood model fit. - iterations_per_full_update: 1e99 # Non-linear search iterations between every full update, which outputs all visuals and result fits (e.g. model.result, search.summary), this exits the search and can be slow. + iterations_per_full_update: 1e99 # Non-linear search iterations between every full update, which outputs all visuals and result fits (e.g. model.result, search.summary), this exits the search and can be slow. + quick_update_background: false # If True, quick-update visualization runs on a background thread so sampling is not blocked. hpc: hpc_mode: false # If True, use HPC mode, which disables GUI visualization, logging to screen and other settings which are not suited to running on a super computer. iterations_per_quick_update: 1e99 # Non-linear search iterations between every quick update, which just displays the maximum likelihood model fit. diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 6137560ae..1201744cc 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -299,6 +299,16 @@ def make_result( analysis=analysis, ) + @property + def supports_background_update(self) -> bool: + """Whether this analysis supports background quick updates.""" + return False + + @property + def supports_jax_visualization(self) -> bool: + """Whether the visualizer can work directly with JAX arrays.""" + return False + def perform_quick_update(self, paths, instance): raise NotImplementedError diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index c36800fc5..0d5153165 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -44,6 +44,7 @@ def __init__( use_jax_vmap : bool = False, batch_size : Optional[int] = None, iterations_per_quick_update: Optional[int] = None, + background_quick_update: bool = False, ): """ Interfaces with any non-linear search to fit the model to the data and return a log likelihood via @@ -129,6 +130,20 @@ def __init__( self.quick_update_max_lh = -self._xp.inf self.quick_update_count = 0 + self._background_quick_update = None + + if background_quick_update and self.iterations_per_quick_update is not None: + from autofit.non_linear.quick_update import BackgroundQuickUpdate + + convert_jax = ( + getattr(self.analysis, "_use_jax", False) + and not getattr(self.analysis, "supports_jax_visualization", False) + ) + + self._background_quick_update = BackgroundQuickUpdate( + convert_jax=convert_jax, + ) + if self.paths is not None: self.check_log_likelihood(fitness=self) @@ -314,10 +329,15 @@ def manage_quick_update(self, parameters, log_likelihood): instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters, xp=self._xp) - try: - self.analysis.perform_quick_update(self.paths, instance) - except NotImplementedError: - pass + if self._background_quick_update is not None: + self._background_quick_update.submit( + self.analysis, self.paths, instance, + ) + else: + try: + self.analysis.perform_quick_update(self.paths, instance) + except NotImplementedError: + pass result_info = text_util.result_max_lh_info_from( max_log_likelihood_sample=self.quick_update_max_lh_parameters.tolist(), @@ -333,6 +353,12 @@ def manage_quick_update(self, parameters, log_likelihood): logger.info(f"Quick update complete in {time.time() - start_time} seconds.") + def shutdown_quick_update(self): + """Shut down the background quick-update worker, if one is running.""" + if self._background_quick_update is not None: + self._background_quick_update.shutdown() + self._background_quick_update = None + @timeout(timeout_seconds) def __call__(self, parameters, *kwargs): """ diff --git a/autofit/non_linear/quick_update.py b/autofit/non_linear/quick_update.py new file mode 100644 index 000000000..ebaa8d8eb --- /dev/null +++ b/autofit/non_linear/quick_update.py @@ -0,0 +1,107 @@ +import copy +import logging +import threading + +import numpy as np + +logger = logging.getLogger(__name__) + + +def _convert_jax_to_numpy(instance): + """ + Return a deep copy of *instance* with every JAX array replaced by a + NumPy array. Plain NumPy values and non-array attributes are left + unchanged. + + This is used so that the background visualisation thread never + touches JAX / GPU state, which is not thread-safe. + """ + instance = copy.deepcopy(instance) + + for attr in vars(instance): + value = getattr(instance, attr) + if hasattr(value, "device"): + setattr(instance, attr, np.asarray(value)) + + return instance + + +class BackgroundQuickUpdate: + """ + Runs ``analysis.perform_quick_update`` on a background daemon thread so + that the sampler is not blocked while matplotlib renders and saves plots. + + Uses a **latest-only** pattern: if a new best-fit arrives before the + previous visualisation finishes, the stale request is silently replaced. + + Parameters + ---------- + convert_jax + If ``True``, JAX arrays on the model instance are converted to + NumPy before handing them to the worker thread. + """ + + def __init__(self, convert_jax: bool = False): + self._convert_jax = convert_jax + + self._lock = threading.Lock() + self._pending = None + self._has_work = threading.Event() + self._stop = threading.Event() + + self._thread = threading.Thread( + target=self._worker, + daemon=True, + name="quick-update-worker", + ) + self._thread.start() + + def submit(self, analysis, paths, instance): + """ + Enqueue a quick-update request. If a previous request is still + pending (not yet picked up by the worker), it is replaced. + """ + + if self._convert_jax: + instance = _convert_jax_to_numpy(instance) + + with self._lock: + self._pending = (analysis, paths, instance) + + self._has_work.set() + + def shutdown(self, timeout: float = 10.0): + """Signal the worker to stop after draining pending work.""" + self._stop.set() + self._has_work.set() + self._thread.join(timeout=timeout) + + def _process_pending(self): + with self._lock: + work = self._pending + self._pending = None + + if work is None: + return + + analysis, paths, instance = work + + try: + analysis.perform_quick_update(paths, instance) + except NotImplementedError: + pass + except Exception: + logger.exception( + "Background quick-update raised an exception (ignored)." + ) + + def _worker(self): + while True: + self._has_work.wait() + self._has_work.clear() + + self._process_pending() + + if self._stop.is_set(): + self._process_pending() + break diff --git a/autofit/non_linear/search/abstract_search.py b/autofit/non_linear/search/abstract_search.py index ccb2c1392..b15062ce1 100644 --- a/autofit/non_linear/search/abstract_search.py +++ b/autofit/non_linear/search/abstract_search.py @@ -214,6 +214,12 @@ def __init__( self.iterations_per_full_update = float((iterations_per_full_update or conf.instance["general"]["updates"]["iterations_per_full_update"])) + self.quick_update_background = bool( + conf.instance["general"]["updates"].get( + "quick_update_background", False, + ) + ) + if conf.instance["general"]["hpc"]["hpc_mode"]: self.iterations_per_quick_update = float(conf.instance["general"]["hpc"][ "iterations_per_quick_update" @@ -664,6 +670,9 @@ def start_resume_fit(self, analysis: Analysis, model: AbstractPriorModel) -> Res analysis=analysis, ) + if hasattr(fitness, "shutdown_quick_update"): + fitness.shutdown_quick_update() + samples = self.perform_update( model=model, analysis=analysis, diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index b9b568b83..b772b42e2 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -192,6 +192,7 @@ def _fit(self, model: AbstractPriorModel, analysis): fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, iterations_per_quick_update=self.iterations_per_quick_update, + background_quick_update=self.quick_update_background, use_jax_vmap=self.use_jax_vmap, batch_size=self.n_batch, ) @@ -210,7 +211,8 @@ def _fit(self, model: AbstractPriorModel, analysis): paths=self.paths, fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, - iterations_per_quick_update=self.iterations_per_quick_update + iterations_per_quick_update=self.iterations_per_quick_update, + background_quick_update=self.quick_update_background, ) search_internal = self.fit_multiprocessing( diff --git a/test_autofit/non_linear/test_quick_update.py b/test_autofit/non_linear/test_quick_update.py new file mode 100644 index 000000000..be6eecac5 --- /dev/null +++ b/test_autofit/non_linear/test_quick_update.py @@ -0,0 +1,163 @@ +import threading +import time + +import numpy as np +import pytest + +from autofit.non_linear.quick_update import BackgroundQuickUpdate, _convert_jax_to_numpy + + +class MockPaths: + pass + + +class MockAnalysis: + """Records calls to perform_quick_update for assertions.""" + + def __init__(self, delay=0.0): + self.calls = [] + self._lock = threading.Lock() + self._delay = delay + + def perform_quick_update(self, paths, instance): + if self._delay: + time.sleep(self._delay) + with self._lock: + self.calls.append(instance) + + +class ErrorAnalysis: + """Always raises from perform_quick_update.""" + + def perform_quick_update(self, paths, instance): + raise RuntimeError("Visualization failed") + + +class TestBackgroundQuickUpdate: + def test_single_submit(self): + analysis = MockAnalysis() + worker = BackgroundQuickUpdate() + + worker.submit(analysis, MockPaths(), "instance_1") + worker.shutdown() + + assert analysis.calls == ["instance_1"] + + def test_latest_only(self): + """When multiple submits happen before the worker picks up, only the + last one should be processed.""" + analysis = MockAnalysis(delay=0.2) + worker = BackgroundQuickUpdate() + + # First submit — will be picked up and start processing (with delay) + worker.submit(analysis, MockPaths(), "instance_1") + time.sleep(0.05) # let the worker pick it up and start processing + + # These two land while the worker is busy with instance_1 + worker.submit(analysis, MockPaths(), "instance_2") + worker.submit(analysis, MockPaths(), "instance_3") + + worker.shutdown() + + # instance_1 was already being processed, instance_3 replaces instance_2 + assert analysis.calls == ["instance_1", "instance_3"] + + def test_shutdown_joins_cleanly(self): + worker = BackgroundQuickUpdate() + worker.shutdown(timeout=2.0) + assert not worker._thread.is_alive() + + def test_exception_does_not_crash(self): + analysis = ErrorAnalysis() + worker = BackgroundQuickUpdate() + + worker.submit(analysis, MockPaths(), "instance") + time.sleep(0.1) # let the worker process it + + # Worker should still be alive and functional + assert worker._thread.is_alive() + worker.shutdown() + + def test_exception_followed_by_valid(self): + """After an exception, subsequent submits should still work.""" + error_analysis = ErrorAnalysis() + good_analysis = MockAnalysis() + worker = BackgroundQuickUpdate() + + worker.submit(error_analysis, MockPaths(), "bad") + time.sleep(0.1) + + worker.submit(good_analysis, MockPaths(), "good") + worker.shutdown() + + assert good_analysis.calls == ["good"] + + +class TestConvertJaxToNumpy: + def test_converts_array_with_device_attr(self): + class FakeJaxArray: + def __init__(self, data): + self.data = data + self.device = "gpu:0" + + def __array__(self, dtype=None): + return np.array(self.data, dtype=dtype) + + class Instance: + def __init__(self): + self.param = FakeJaxArray([1.0, 2.0, 3.0]) + self.name = "test" + + instance = Instance() + converted = _convert_jax_to_numpy(instance) + + assert isinstance(converted.param, np.ndarray) + np.testing.assert_array_equal(converted.param, [1.0, 2.0, 3.0]) + assert converted.name == "test" + # Original should be unchanged + assert hasattr(instance.param, "device") + + def test_leaves_plain_values_alone(self): + class Instance: + def __init__(self): + self.x = 1.0 + self.arr = np.array([1, 2]) + + instance = Instance() + converted = _convert_jax_to_numpy(instance) + + assert converted.x == 1.0 + np.testing.assert_array_equal(converted.arr, [1, 2]) + + +class TestConvertJaxFlag: + def test_convert_jax_false(self): + analysis = MockAnalysis() + worker = BackgroundQuickUpdate(convert_jax=False) + + obj = {"key": "value"} # not a real instance, just checking pass-through + worker.submit(analysis, MockPaths(), obj) + worker.shutdown() + + assert analysis.calls[0] is obj + + def test_convert_jax_true(self): + class FakeJaxArray: + def __init__(self, data): + self.data = data + self.device = "gpu:0" + + def __array__(self, dtype=None): + return np.array(self.data, dtype=dtype) + + class Instance: + def __init__(self): + self.param = FakeJaxArray([1.0]) + + analysis = MockAnalysis() + worker = BackgroundQuickUpdate(convert_jax=True) + + worker.submit(analysis, MockPaths(), Instance()) + worker.shutdown() + + assert isinstance(analysis.calls[0].param, np.ndarray)