Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ dependencies = [
"matplotlib-scalebar>=0.8",
"networkx>=2.6",
"numba>=0.56.4",
"numba-progress>=1.1",
"numpy>=1.23",
"omnipath>=1.0.7",
"pandas>=2.1",
Expand Down
61 changes: 46 additions & 15 deletions src/squidpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import warnings
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from enum import Enum
from multiprocessing import Manager
from queue import Queue
Expand All @@ -19,7 +19,11 @@
import numpy as np
import xarray as xr
from spatialdata.models import Image2DModel, Labels2DModel
from tqdm.auto import tqdm

if TYPE_CHECKING:
from contextlib import AbstractContextManager

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to annotate the return type


from numba_progress import ProgressBar

__all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]

Expand Down Expand Up @@ -243,6 +247,25 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
return wrapper


def progress_bar(
total: int,
*,
show_progress_bar: bool = True,
unit: str = "item",
desc: str | None = None,
) -> AbstractContextManager[ProgressBar | None]:
"""Create a progress bar usable both inside and outside :mod:`numba` functions."""
if not show_progress_bar:
return nullcontext(None)

from numba_progress import ProgressBar

kwargs: dict[str, Any] = {"total": total, "unit": unit}
if desc is not None:
kwargs["desc"] = desc
return ProgressBar(**kwargs)


def thread_map(
fn: Callable[..., Any],
items: Sequence[Any],
Expand All @@ -262,9 +285,9 @@ def thread_map(
n_jobs
Number of worker threads. ``1`` runs sequentially (no pool overhead).
show_progress_bar
Whether to display a ``tqdm`` progress bar.
Whether to display a progress bar (see :func:`progress_bar`).
unit
Label shown next to the ``tqdm`` counter.
Label shown next to the progress counter.

Returns
-------
Expand All @@ -273,17 +296,25 @@ def thread_map(
"""
from concurrent.futures import ThreadPoolExecutor

if n_jobs == 1:
it: Iterable[Any] = map(fn, items)
if show_progress_bar and tqdm is not None:
it = tqdm(it, total=len(items), unit=unit)
return list(it)

with ThreadPoolExecutor(max_workers=n_jobs) as pool:
it = pool.map(fn, items)
if show_progress_bar and tqdm is not None:
it = tqdm(it, total=len(items), unit=unit)
return list(it)
items = list(items)

with progress_bar(len(items), show_progress_bar=show_progress_bar, unit=unit) as pbar:
if n_jobs == 1:
results = []
for item in items:
results.append(fn(item))
if pbar is not None:
pbar.update(1)
return results

with ThreadPoolExecutor(max_workers=n_jobs) as pool:
results = []
# ``pool.map`` yields in submission order, so results stay aligned with *items*.
for res in pool.map(fn, items):
results.append(res)
if pbar is not None:
pbar.update(1)
return results


def _get_n_cores(n_cores: int | None) -> int:
Expand Down
Loading