From 257105e894549f6ed2a63f55c73e9334ce46e0d8 Mon Sep 17 00:00:00 2001 From: "selman.ozleyen" Date: Fri, 19 Jun 2026 16:06:24 +0200 Subject: [PATCH 1/3] use numba progress for in threadpool --- pyproject.toml | 1 + src/squidpy/_utils.py | 84 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9a67f5d83..cbc1eb1ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "matplotlib-scalebar>=0.8", "networkx>=2.6", "numba>=0.56.4", + "numba-progress>=1.1.0", "numpy>=1.23", "omnipath>=1.0.7", "pandas>=2.1", diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 17e6a7a83..6c1e95dac 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -7,7 +7,7 @@ import os import warnings from collections.abc import Callable, Generator, Hashable, Iterable, Sequence -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from enum import Enum from multiprocessing import Manager from queue import Queue @@ -19,7 +19,11 @@ import numpy as np import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel -from tqdm.auto import tqdm + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + from numba_progress import ProgressBar __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -243,6 +247,48 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper +def progress_bar( + total: int, + *, + show_progress_bar: bool = True, + unit: str = "item", + desc: str | None = None, +) -> AbstractContextManager[ProgressBar | None]: + """Create a progress bar usable both inside and outside :mod:`numba` functions. + + Wraps :class:`numba_progress.ProgressBar`, which polls an atomic counter from a background + thread. The same object can be ``.update()``-ed from pure Python *and* passed into an + ``njit(nogil=True)`` function, so squidpy uses a single progress-bar implementation instead of + one for numba and one for everything else. Functions built on :func:`parallelize` keep their own + queue-based bar and do not use this. + + Parameters + ---------- + total + Total number of iterations. + show_progress_bar + Whether to display the progress bar. When ``False`` a no-op context yielding ``None`` is + returned, so call sites guard updates with ``if pbar is not None``. + unit + Label shown next to the counter. + desc + Optional prefix description. + + Returns + ------- + A context manager yielding a :class:`numba_progress.ProgressBar` (or ``None`` when disabled). + """ + if not show_progress_bar: + return nullcontext(None) + + from numba_progress import ProgressBar + + kwargs: dict[str, Any] = {"total": total, "unit": unit} + if desc is not None: + kwargs["desc"] = desc + return ProgressBar(**kwargs) + + def thread_map( fn: Callable[..., Any], items: Sequence[Any], @@ -262,9 +308,9 @@ def thread_map( n_jobs Number of worker threads. ``1`` runs sequentially (no pool overhead). show_progress_bar - Whether to display a ``tqdm`` progress bar. + Whether to display a progress bar (see :func:`progress_bar`). unit - Label shown next to the ``tqdm`` counter. + Label shown next to the progress counter. Returns ------- @@ -273,17 +319,25 @@ def thread_map( """ from concurrent.futures import ThreadPoolExecutor - if n_jobs == 1: - it: Iterable[Any] = map(fn, items) - if show_progress_bar and tqdm is not None: - it = tqdm(it, total=len(items), unit=unit) - return list(it) - - with ThreadPoolExecutor(max_workers=n_jobs) as pool: - it = pool.map(fn, items) - if show_progress_bar and tqdm is not None: - it = tqdm(it, total=len(items), unit=unit) - return list(it) + items = list(items) + + with progress_bar(len(items), show_progress_bar=show_progress_bar, unit=unit) as pbar: + if n_jobs == 1: + results = [] + for item in items: + results.append(fn(item)) + if pbar is not None: + pbar.update(1) + return results + + with ThreadPoolExecutor(max_workers=n_jobs) as pool: + results = [] + # ``pool.map`` yields in submission order, so results stay aligned with *items*. + for res in pool.map(fn, items): + results.append(res) + if pbar is not None: + pbar.update(1) + return results def _get_n_cores(n_cores: int | None) -> int: From 15ec58fe0f2efce97a26a61a3542b396affce30e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Fri, 19 Jun 2026 16:10:03 +0200 Subject: [PATCH 2/3] Update _utils.py --- src/squidpy/_utils.py | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 6c1e95dac..999c84e60 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -254,30 +254,7 @@ def progress_bar( unit: str = "item", desc: str | None = None, ) -> AbstractContextManager[ProgressBar | None]: - """Create a progress bar usable both inside and outside :mod:`numba` functions. - - Wraps :class:`numba_progress.ProgressBar`, which polls an atomic counter from a background - thread. The same object can be ``.update()``-ed from pure Python *and* passed into an - ``njit(nogil=True)`` function, so squidpy uses a single progress-bar implementation instead of - one for numba and one for everything else. Functions built on :func:`parallelize` keep their own - queue-based bar and do not use this. - - Parameters - ---------- - total - Total number of iterations. - show_progress_bar - Whether to display the progress bar. When ``False`` a no-op context yielding ``None`` is - returned, so call sites guard updates with ``if pbar is not None``. - unit - Label shown next to the counter. - desc - Optional prefix description. - - Returns - ------- - A context manager yielding a :class:`numba_progress.ProgressBar` (or ``None`` when disabled). - """ + """Create a progress bar usable both inside and outside :mod:`numba` functions.""" if not show_progress_bar: return nullcontext(None) From 69a88ba069368a993872ba514572d8e214970bb7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Jun 2026 15:57:40 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 10681183e..897ec77f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ dependencies = [ "matplotlib-scalebar>=0.8", "networkx>=2.6", "numba>=0.56.4", - "numba-progress>=1.1.0", + "numba-progress>=1.1", "numpy>=1.23", "omnipath>=1.0.7", "pandas>=2.1",