Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions papermill/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import datetime
import sys
from functools import wraps
from importlib.metadata import entry_points

import dateutil
import entrypoints

from .clientwrap import PapermillNotebookClient
from .exceptions import PapermillException
Expand Down Expand Up @@ -34,7 +34,7 @@ def register_entry_points(self):

Load handlers provided by other packages
"""
for entrypoint in entrypoints.get_group_all("papermill.engine"):
for entrypoint in entry_points().select(group="papermill.engine"):
self.register(entrypoint.name, entrypoint.load())

def get_engine(self, name=None):
Expand Down
4 changes: 2 additions & 2 deletions papermill/iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import sys
import warnings
from contextlib import contextmanager
from importlib.metadata import entry_points

import entrypoints
import nbformat
import requests
import yaml
Expand Down Expand Up @@ -116,7 +116,7 @@ def register(self, scheme, handler):

def register_entry_points(self):
# Load handlers provided by other packages
for entrypoint in entrypoints.get_group_all("papermill.io"):
for entrypoint in entry_points().select(group="papermill.io"):
self.register(entrypoint.name, entrypoint.load())

def get_handler(self, path, extensions=None):
Expand Down
2 changes: 1 addition & 1 deletion papermill/tests/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from azure.identity import EnvironmentCredential

from ..abs import AzureBlobStore
from papermill.abs import AzureBlobStore


class MockBytesIO:
Expand Down
6 changes: 3 additions & 3 deletions papermill/tests/test_adl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest
from unittest.mock import MagicMock, Mock, patch

from ..adl import ADL
from ..adl import core as adl_core
from ..adl import lib as adl_lib
from papermill.adl import ADL
from papermill.adl import core as adl_core
from papermill.adl import lib as adl_lib


class ADLTest(unittest.TestCase):
Expand Down
7 changes: 4 additions & 3 deletions papermill/tests/test_autosave.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

import nbformat

from .. import engines
from ..engines import NotebookExecutionManager
from ..execute import execute_notebook
from papermill import engines
from papermill.engines import NotebookExecutionManager
from papermill.execute import execute_notebook

from . import get_notebook_path


Expand Down
5 changes: 3 additions & 2 deletions papermill/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import pytest
from click.testing import CliRunner

from .. import cli
from ..cli import _is_float, _is_int, _resolve_type, papermill
from papermill import cli
from papermill.cli import _is_float, _is_int, _resolve_type, papermill

from . import get_notebook_path, kernel_name


Expand Down
7 changes: 4 additions & 3 deletions papermill/tests/test_clientwrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import nbformat

from ..clientwrap import PapermillNotebookClient
from ..engines import NotebookExecutionManager
from ..log import logger
from papermill.clientwrap import PapermillNotebookClient
from papermill.engines import NotebookExecutionManager
from papermill.log import logger

from . import get_notebook_path


Expand Down
19 changes: 12 additions & 7 deletions papermill/tests/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import dateutil
from nbformat.notebooknode import NotebookNode

from .. import engines, exceptions
from ..engines import Engine, NBClientEngine, NotebookExecutionManager
from ..iorw import load_notebook_node
from ..log import logger
from papermill import engines, exceptions
from papermill.engines import Engine, NBClientEngine, NotebookExecutionManager
from papermill.iorw import load_notebook_node
from papermill.log import logger

from . import get_notebook_path


Expand Down Expand Up @@ -489,10 +490,14 @@ def test_getting(self):
self.assertRaises(exceptions.PapermillException, self.papermill_engines.get_engine, "non-existent")

def test_registering_entry_points(self):
fake_entrypoint = Mock(load=Mock())
fake_entrypoint = Mock()
fake_entrypoint.name = "fake-engine"
fake_entrypoint.load.return_value = Mock()

mock_entry_points = Mock()
mock_entry_points.select.return_value = [fake_entrypoint]

with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
with patch("papermill.engines.entry_points", return_value=mock_entry_points):
self.papermill_engines.register_entry_points()
mock_get_group_all.assert_called_once_with("papermill.engine")
mock_entry_points.select.assert_called_once_with(group="papermill.engine")
self.assertEqual(self.papermill_engines.get_engine("fake-engine"), fake_entrypoint.load.return_value)
2 changes: 1 addition & 1 deletion papermill/tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from .. import exceptions
from papermill import exceptions


@pytest.fixture
Expand Down
13 changes: 7 additions & 6 deletions papermill/tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
import nbformat
from nbformat import validate

from .. import engines, translators
from ..exceptions import PapermillExecutionError, strip_color
from ..execute import execute_notebook
from ..iorw import load_notebook_node
from ..log import logger
from ..utils import chdir
from papermill import engines, translators
from papermill.exceptions import PapermillExecutionError, strip_color
from papermill.execute import execute_notebook
from papermill.iorw import load_notebook_node
from papermill.log import logger
from papermill.utils import chdir

from . import get_notebook_path, kernel_name

execute_notebook = partial(execute_notebook, kernel_name=kernel_name)
Expand Down
4 changes: 2 additions & 2 deletions papermill/tests/test_gcs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import unittest
from unittest.mock import patch

from ..exceptions import PapermillRateLimitException
from ..iorw import GCSHandler, fallback_gs_is_retriable
from papermill.exceptions import PapermillRateLimitException
from papermill.iorw import GCSHandler, fallback_gs_is_retriable

try:
try:
Expand Down
2 changes: 1 addition & 1 deletion papermill/tests/test_hdfs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from unittest.mock import MagicMock, patch

from ..iorw import HDFSHandler
from papermill.iorw import HDFSHandler


class MockHadoopFileSystem(MagicMock):
Expand Down
18 changes: 12 additions & 6 deletions papermill/tests/test_iorw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import pytest
from requests.exceptions import ConnectionError

from .. import iorw
from ..exceptions import PapermillException
from ..iorw import (
from papermill import iorw
from papermill.exceptions import PapermillException
from papermill.iorw import (
ADLHandler,
HttpHandler,
LocalHandler,
Expand All @@ -24,6 +24,7 @@
papermill_io,
read_yaml_file,
)

from . import get_notebook_path

FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures')
Expand Down Expand Up @@ -101,12 +102,16 @@ def test_get_notebook_node_handler(self):
self.assertIsInstance(self.papermill_io.get_handler(test_nb), NotebookNodeHandler)

def test_entrypoint_register(self):
fake_entrypoint = Mock(load=Mock())
fake_entrypoint = Mock()
fake_entrypoint.name = "fake-from-entry-point://"
fake_entrypoint.load.return_value = Mock()

mock_entry_points = Mock()
mock_entry_points.select.return_value = [fake_entrypoint]

with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all:
with patch("papermill.iorw.entry_points", return_value=mock_entry_points):
self.papermill_io.register_entry_points()
mock_get_group_all.assert_called_once_with("papermill.io")
mock_entry_points.select.assert_called_once_with(group="papermill.io")
fake_ = self.papermill_io.get_handler("fake-from-entry-point://")
assert fake_ == fake_entrypoint.load.return_value

Expand Down Expand Up @@ -206,6 +211,7 @@ def test_write_local_directory(self):
with patch.object(io, 'open'):
# Shouldn't raise with missing directory
LocalHandler().write("buffer", "local.ipynb")
os.unlink("local.ipynb")

def test_write_passed_cwd(self):
with TemporaryDirectory() as temp_dir:
Expand Down
7 changes: 4 additions & 3 deletions papermill/tests/test_parameterize.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import unittest
from datetime import datetime

from ..exceptions import PapermillMissingParameterException
from ..iorw import load_notebook_node
from ..parameterize import add_builtin_parameters, parameterize_notebook, parameterize_path
from papermill.exceptions import PapermillMissingParameterException
from papermill.iorw import load_notebook_node
from papermill.parameterize import add_builtin_parameters, parameterize_notebook, parameterize_path

from . import get_notebook_path


Expand Down
4 changes: 2 additions & 2 deletions papermill/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from moto import mock_aws

from ..s3 import S3, Bucket, Key, Prefix
from papermill.s3 import S3, Bucket, Key, Prefix


@pytest.fixture
Expand Down Expand Up @@ -168,7 +168,7 @@ def s3_client():
mock_aws.start()

client = boto3.client('s3')
client.create_bucket(Bucket=test_bucket_name, CreateBucketConfiguration={'LocationConstraint': 'us-west-2'})
client.create_bucket(Bucket=test_bucket_name, CreateBucketConfiguration={'LocationConstraint': 'eu-central-1'})
client.put_object(Bucket=test_bucket_name, Key=test_file_path, Body=test_nb_content)
client.put_object(Bucket=test_bucket_name, Key=test_empty_file_path, Body='')
yield S3()
Expand Down
6 changes: 3 additions & 3 deletions papermill/tests/test_translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import pytest
from nbformat.v4 import new_code_cell

from .. import translators
from ..exceptions import PapermillException
from ..models import Parameter
from papermill import translators
from papermill.exceptions import PapermillException
from papermill.models import Parameter


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions papermill/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pytest
from nbformat.v4 import new_code_cell, new_notebook

from ..exceptions import PapermillParameterOverwriteWarning
from ..utils import (
from papermill.exceptions import PapermillParameterOverwriteWarning
from papermill.utils import (
any_tagged_cell,
chdir,
merge_kwargs,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ quiet-level = 3
ignore-words-list = "dne, compiletime"

[tool.pytest.ini_options]
testpaths = [ "papermill/tests" ]
env = [
"AWS_SECRET_ACCESS_KEY=foobar_secret",
"AWS_ACCESS_KEY_ID=foobar_key",
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ nbformat >= 5.2.0
nbclient >= 0.2.0
tqdm >= 4.32.2
requests
entrypoints
tenacity >= 5.0.2
aiohttp >=3.9.0; python_version=="3.12"