diff --git a/papermill/engines.py b/papermill/engines.py index c44d548d..89f93650 100644 --- a/papermill/engines.py +++ b/papermill/engines.py @@ -3,9 +3,9 @@ import datetime import sys from functools import wraps +from importlib.metadata import entry_points import dateutil -import entrypoints from .clientwrap import PapermillNotebookClient from .exceptions import PapermillException @@ -34,7 +34,7 @@ def register_entry_points(self): Load handlers provided by other packages """ - for entrypoint in entrypoints.get_group_all("papermill.engine"): + for entrypoint in entry_points().select(group="papermill.engine"): self.register(entrypoint.name, entrypoint.load()) def get_engine(self, name=None): diff --git a/papermill/iorw.py b/papermill/iorw.py index 14a0122c..50f0f3c8 100644 --- a/papermill/iorw.py +++ b/papermill/iorw.py @@ -4,8 +4,8 @@ import sys import warnings from contextlib import contextmanager +from importlib.metadata import entry_points -import entrypoints import nbformat import requests import yaml @@ -116,7 +116,7 @@ def register(self, scheme, handler): def register_entry_points(self): # Load handlers provided by other packages - for entrypoint in entrypoints.get_group_all("papermill.io"): + for entrypoint in entry_points().select(group="papermill.io"): self.register(entrypoint.name, entrypoint.load()) def get_handler(self, path, extensions=None): diff --git a/papermill/tests/test_abs.py b/papermill/tests/test_abs.py index 57dff73f..cdbd0358 100644 --- a/papermill/tests/test_abs.py +++ b/papermill/tests/test_abs.py @@ -4,7 +4,7 @@ from azure.identity import EnvironmentCredential -from ..abs import AzureBlobStore +from papermill.abs import AzureBlobStore class MockBytesIO: diff --git a/papermill/tests/test_adl.py b/papermill/tests/test_adl.py index 3aa98544..1c27f978 100644 --- a/papermill/tests/test_adl.py +++ b/papermill/tests/test_adl.py @@ -1,9 +1,9 @@ import unittest from unittest.mock import MagicMock, Mock, patch -from ..adl import ADL -from ..adl import core as adl_core -from ..adl import lib as adl_lib +from papermill.adl import ADL +from papermill.adl import core as adl_core +from papermill.adl import lib as adl_lib class ADLTest(unittest.TestCase): diff --git a/papermill/tests/test_autosave.py b/papermill/tests/test_autosave.py index 74ae06e8..dd3f80ae 100644 --- a/papermill/tests/test_autosave.py +++ b/papermill/tests/test_autosave.py @@ -6,9 +6,10 @@ import nbformat -from .. import engines -from ..engines import NotebookExecutionManager -from ..execute import execute_notebook +from papermill import engines +from papermill.engines import NotebookExecutionManager +from papermill.execute import execute_notebook + from . import get_notebook_path diff --git a/papermill/tests/test_cli.py b/papermill/tests/test_cli.py index 72ef5f9a..c25161a5 100755 --- a/papermill/tests/test_cli.py +++ b/papermill/tests/test_cli.py @@ -15,8 +15,9 @@ import pytest from click.testing import CliRunner -from .. import cli -from ..cli import _is_float, _is_int, _resolve_type, papermill +from papermill import cli +from papermill.cli import _is_float, _is_int, _resolve_type, papermill + from . import get_notebook_path, kernel_name diff --git a/papermill/tests/test_clientwrap.py b/papermill/tests/test_clientwrap.py index cfa2a81a..4409f2bd 100644 --- a/papermill/tests/test_clientwrap.py +++ b/papermill/tests/test_clientwrap.py @@ -3,9 +3,10 @@ import nbformat -from ..clientwrap import PapermillNotebookClient -from ..engines import NotebookExecutionManager -from ..log import logger +from papermill.clientwrap import PapermillNotebookClient +from papermill.engines import NotebookExecutionManager +from papermill.log import logger + from . import get_notebook_path diff --git a/papermill/tests/test_engines.py b/papermill/tests/test_engines.py index db5ee17c..d6111b1b 100644 --- a/papermill/tests/test_engines.py +++ b/papermill/tests/test_engines.py @@ -6,10 +6,11 @@ import dateutil from nbformat.notebooknode import NotebookNode -from .. import engines, exceptions -from ..engines import Engine, NBClientEngine, NotebookExecutionManager -from ..iorw import load_notebook_node -from ..log import logger +from papermill import engines, exceptions +from papermill.engines import Engine, NBClientEngine, NotebookExecutionManager +from papermill.iorw import load_notebook_node +from papermill.log import logger + from . import get_notebook_path @@ -489,10 +490,14 @@ def test_getting(self): self.assertRaises(exceptions.PapermillException, self.papermill_engines.get_engine, "non-existent") def test_registering_entry_points(self): - fake_entrypoint = Mock(load=Mock()) + fake_entrypoint = Mock() fake_entrypoint.name = "fake-engine" + fake_entrypoint.load.return_value = Mock() + + mock_entry_points = Mock() + mock_entry_points.select.return_value = [fake_entrypoint] - with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all: + with patch("papermill.engines.entry_points", return_value=mock_entry_points): self.papermill_engines.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.engine") + mock_entry_points.select.assert_called_once_with(group="papermill.engine") self.assertEqual(self.papermill_engines.get_engine("fake-engine"), fake_entrypoint.load.return_value) diff --git a/papermill/tests/test_exceptions.py b/papermill/tests/test_exceptions.py index 0a7e7a8d..6aaa72a9 100644 --- a/papermill/tests/test_exceptions.py +++ b/papermill/tests/test_exceptions.py @@ -4,7 +4,7 @@ import pytest -from .. import exceptions +from papermill import exceptions @pytest.fixture diff --git a/papermill/tests/test_execute.py b/papermill/tests/test_execute.py index 09600d64..d0584e11 100644 --- a/papermill/tests/test_execute.py +++ b/papermill/tests/test_execute.py @@ -10,12 +10,13 @@ import nbformat from nbformat import validate -from .. import engines, translators -from ..exceptions import PapermillExecutionError, strip_color -from ..execute import execute_notebook -from ..iorw import load_notebook_node -from ..log import logger -from ..utils import chdir +from papermill import engines, translators +from papermill.exceptions import PapermillExecutionError, strip_color +from papermill.execute import execute_notebook +from papermill.iorw import load_notebook_node +from papermill.log import logger +from papermill.utils import chdir + from . import get_notebook_path, kernel_name execute_notebook = partial(execute_notebook, kernel_name=kernel_name) diff --git a/papermill/tests/test_gcs.py b/papermill/tests/test_gcs.py index ebd635a0..081f05d2 100644 --- a/papermill/tests/test_gcs.py +++ b/papermill/tests/test_gcs.py @@ -1,8 +1,8 @@ import unittest from unittest.mock import patch -from ..exceptions import PapermillRateLimitException -from ..iorw import GCSHandler, fallback_gs_is_retriable +from papermill.exceptions import PapermillRateLimitException +from papermill.iorw import GCSHandler, fallback_gs_is_retriable try: try: diff --git a/papermill/tests/test_hdfs.py b/papermill/tests/test_hdfs.py index 44c024df..1e5da20c 100644 --- a/papermill/tests/test_hdfs.py +++ b/papermill/tests/test_hdfs.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock, patch -from ..iorw import HDFSHandler +from papermill.iorw import HDFSHandler class MockHadoopFileSystem(MagicMock): diff --git a/papermill/tests/test_iorw.py b/papermill/tests/test_iorw.py index ab09f01a..555cb00a 100644 --- a/papermill/tests/test_iorw.py +++ b/papermill/tests/test_iorw.py @@ -10,9 +10,9 @@ import pytest from requests.exceptions import ConnectionError -from .. import iorw -from ..exceptions import PapermillException -from ..iorw import ( +from papermill import iorw +from papermill.exceptions import PapermillException +from papermill.iorw import ( ADLHandler, HttpHandler, LocalHandler, @@ -24,6 +24,7 @@ papermill_io, read_yaml_file, ) + from . import get_notebook_path FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures') @@ -101,12 +102,16 @@ def test_get_notebook_node_handler(self): self.assertIsInstance(self.papermill_io.get_handler(test_nb), NotebookNodeHandler) def test_entrypoint_register(self): - fake_entrypoint = Mock(load=Mock()) + fake_entrypoint = Mock() fake_entrypoint.name = "fake-from-entry-point://" + fake_entrypoint.load.return_value = Mock() + + mock_entry_points = Mock() + mock_entry_points.select.return_value = [fake_entrypoint] - with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all: + with patch("papermill.iorw.entry_points", return_value=mock_entry_points): self.papermill_io.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.io") + mock_entry_points.select.assert_called_once_with(group="papermill.io") fake_ = self.papermill_io.get_handler("fake-from-entry-point://") assert fake_ == fake_entrypoint.load.return_value @@ -206,6 +211,7 @@ def test_write_local_directory(self): with patch.object(io, 'open'): # Shouldn't raise with missing directory LocalHandler().write("buffer", "local.ipynb") + os.unlink("local.ipynb") def test_write_passed_cwd(self): with TemporaryDirectory() as temp_dir: diff --git a/papermill/tests/test_parameterize.py b/papermill/tests/test_parameterize.py index 431caa12..e1235600 100644 --- a/papermill/tests/test_parameterize.py +++ b/papermill/tests/test_parameterize.py @@ -1,9 +1,10 @@ import unittest from datetime import datetime -from ..exceptions import PapermillMissingParameterException -from ..iorw import load_notebook_node -from ..parameterize import add_builtin_parameters, parameterize_notebook, parameterize_path +from papermill.exceptions import PapermillMissingParameterException +from papermill.iorw import load_notebook_node +from papermill.parameterize import add_builtin_parameters, parameterize_notebook, parameterize_path + from . import get_notebook_path diff --git a/papermill/tests/test_s3.py b/papermill/tests/test_s3.py index bf006830..34060765 100644 --- a/papermill/tests/test_s3.py +++ b/papermill/tests/test_s3.py @@ -7,7 +7,7 @@ import pytest from moto import mock_aws -from ..s3 import S3, Bucket, Key, Prefix +from papermill.s3 import S3, Bucket, Key, Prefix @pytest.fixture @@ -168,7 +168,7 @@ def s3_client(): mock_aws.start() client = boto3.client('s3') - client.create_bucket(Bucket=test_bucket_name, CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}) + client.create_bucket(Bucket=test_bucket_name, CreateBucketConfiguration={'LocationConstraint': 'eu-central-1'}) client.put_object(Bucket=test_bucket_name, Key=test_file_path, Body=test_nb_content) client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body='') yield S3() diff --git a/papermill/tests/test_translators.py b/papermill/tests/test_translators.py index 0edc1f07..95651bf1 100644 --- a/papermill/tests/test_translators.py +++ b/papermill/tests/test_translators.py @@ -4,9 +4,9 @@ import pytest from nbformat.v4 import new_code_cell -from .. import translators -from ..exceptions import PapermillException -from ..models import Parameter +from papermill import translators +from papermill.exceptions import PapermillException +from papermill.models import Parameter @pytest.mark.parametrize( diff --git a/papermill/tests/test_utils.py b/papermill/tests/test_utils.py index 4228b436..a6ef8a54 100644 --- a/papermill/tests/test_utils.py +++ b/papermill/tests/test_utils.py @@ -6,8 +6,8 @@ import pytest from nbformat.v4 import new_code_cell, new_notebook -from ..exceptions import PapermillParameterOverwriteWarning -from ..utils import ( +from papermill.exceptions import PapermillParameterOverwriteWarning +from papermill.utils import ( any_tagged_cell, chdir, merge_kwargs, diff --git a/pyproject.toml b/pyproject.toml index 99336a0a..91d54c4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -197,6 +197,7 @@ quiet-level = 3 ignore-words-list = "dne, compiletime" [tool.pytest.ini_options] +testpaths = [ "papermill/tests" ] env = [ "AWS_SECRET_ACCESS_KEY=foobar_secret", "AWS_ACCESS_KEY_ID=foobar_key", diff --git a/requirements.txt b/requirements.txt index 6f8ebb8a..3bee7442 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,5 @@ nbformat >= 5.2.0 nbclient >= 0.2.0 tqdm >= 4.32.2 requests -entrypoints tenacity >= 5.0.2 aiohttp >=3.9.0; python_version=="3.12"