diff --git a/pyproject.toml b/pyproject.toml index ddad831cb..897ec77f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "matplotlib-scalebar>=0.8", "networkx>=2.6", "numba>=0.56.4", + "numba-progress>=1.1", "numpy>=1.23", "omnipath>=1.0.7", "pandas>=2.1", diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 17e6a7a83..999c84e60 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -7,7 +7,7 @@ import os import warnings from collections.abc import Callable, Generator, Hashable, Iterable, Sequence -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from enum import Enum from multiprocessing import Manager from queue import Queue @@ -19,7 +19,11 @@ import numpy as np import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel -from tqdm.auto import tqdm + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + from numba_progress import ProgressBar __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -243,6 +247,25 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper +def progress_bar( + total: int, + *, + show_progress_bar: bool = True, + unit: str = "item", + desc: str | None = None, +) -> AbstractContextManager[ProgressBar | None]: + """Create a progress bar usable both inside and outside :mod:`numba` functions.""" + if not show_progress_bar: + return nullcontext(None) + + from numba_progress import ProgressBar + + kwargs: dict[str, Any] = {"total": total, "unit": unit} + if desc is not None: + kwargs["desc"] = desc + return ProgressBar(**kwargs) + + def thread_map( fn: Callable[..., Any], items: Sequence[Any], @@ -262,9 +285,9 @@ def thread_map( n_jobs Number of worker threads. ``1`` runs sequentially (no pool overhead). show_progress_bar - Whether to display a ``tqdm`` progress bar. + Whether to display a progress bar (see :func:`progress_bar`). unit - Label shown next to the ``tqdm`` counter. + Label shown next to the progress counter. Returns ------- @@ -273,17 +296,25 @@ def thread_map( """ from concurrent.futures import ThreadPoolExecutor - if n_jobs == 1: - it: Iterable[Any] = map(fn, items) - if show_progress_bar and tqdm is not None: - it = tqdm(it, total=len(items), unit=unit) - return list(it) - - with ThreadPoolExecutor(max_workers=n_jobs) as pool: - it = pool.map(fn, items) - if show_progress_bar and tqdm is not None: - it = tqdm(it, total=len(items), unit=unit) - return list(it) + items = list(items) + + with progress_bar(len(items), show_progress_bar=show_progress_bar, unit=unit) as pbar: + if n_jobs == 1: + results = [] + for item in items: + results.append(fn(item)) + if pbar is not None: + pbar.update(1) + return results + + with ThreadPoolExecutor(max_workers=n_jobs) as pool: + results = [] + # ``pool.map`` yields in submission order, so results stay aligned with *items*. + for res in pool.map(fn, items): + results.append(res) + if pbar is not None: + pbar.update(1) + return results def _get_n_cores(n_cores: int | None) -> int: