Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempt,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -497,6 +500,9 @@ def execute(self, context: Context) -> Any:
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempt,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down Expand Up @@ -668,6 +674,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down Expand Up @@ -775,6 +784,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down Expand Up @@ -901,6 +913,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,20 @@ class RedshiftCreateClusterTrigger(AwsBaseWaiterTrigger):
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: The AWS region where the cluster is. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
Comment thread
KarshVashi marked this conversation as resolved.
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -59,10 +65,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftPauseClusterTrigger(AwsBaseWaiterTrigger):
Expand All @@ -76,14 +89,20 @@ class RedshiftPauseClusterTrigger(AwsBaseWaiterTrigger):
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: The AWS region where the cluster is. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -96,10 +115,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftCreateClusterSnapshotTrigger(AwsBaseWaiterTrigger):
Expand All @@ -113,14 +139,20 @@ class RedshiftCreateClusterSnapshotTrigger(AwsBaseWaiterTrigger):
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: The AWS region where the cluster is. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -133,10 +165,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftResumeClusterTrigger(AwsBaseWaiterTrigger):
Expand All @@ -150,14 +189,20 @@ class RedshiftResumeClusterTrigger(AwsBaseWaiterTrigger):
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param region_name: The AWS region where the cluster is. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -170,10 +215,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger):
Expand All @@ -184,14 +236,20 @@ class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger):
:param waiter_max_attempts: The maximum number of attempts to be made.
:param aws_conn_id: The Airflow connection used for AWS credentials.
:param waiter_delay: The amount of time in seconds to wait between attempts.
:param region_name: The AWS region where the cluster is. Used to build the hook.
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
"""

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 30,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -204,10 +262,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftClusterTrigger(BaseTrigger):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterSnapshotTrigger,
RedshiftCreateClusterTrigger,
RedshiftDeleteClusterTrigger,
RedshiftPauseClusterTrigger,
RedshiftResumeClusterTrigger,
)
from airflow.triggers.base import TriggerEvent

Expand Down Expand Up @@ -115,3 +120,106 @@ async def test_redshift_cluster_sensor_trigger_exception(self, mock_cluster_stat
# so we validate for length of task to be 1
assert len(task) == 1
assert TriggerEvent({"status": "error", "message": "Test exception"}) in task


WAITER_TRIGGER_PARAMS = [
pytest.param(
RedshiftCreateClusterTrigger,
15,
999999,
id="RedshiftCreateClusterTrigger",
),
pytest.param(
RedshiftPauseClusterTrigger,
15,
999999,
id="RedshiftPauseClusterTrigger",
),
pytest.param(
RedshiftCreateClusterSnapshotTrigger,
15,
999999,
id="RedshiftCreateClusterSnapshotTrigger",
),
pytest.param(
RedshiftResumeClusterTrigger,
15,
999999,
id="RedshiftResumeClusterTrigger",
),
pytest.param(
RedshiftDeleteClusterTrigger,
30,
30,
id="RedshiftDeleteClusterTrigger",
),
]


class TestRedshiftWaiterTriggers:
"""Tests for the five Redshift triggers that inherit from ``AwsBaseWaiterTrigger``."""

@pytest.mark.parametrize(
("trigger_cls", "default_delay", "default_max_attempts"),
WAITER_TRIGGER_PARAMS,
)
def test_serialization(self, trigger_cls, default_delay, default_max_attempts):
trigger = trigger_cls(
cluster_identifier="test_cluster",
aws_conn_id="aws_default",
region_name="us-east-1",
)

classpath, kwargs = trigger.serialize()
assert classpath == f"airflow.providers.amazon.aws.triggers.redshift_cluster.{trigger_cls.__name__}"
assert kwargs == {
"cluster_identifier": "test_cluster",
"waiter_delay": default_delay,
"waiter_max_attempts": default_max_attempts,
"aws_conn_id": "aws_default",
"region_name": "us-east-1",
}

@pytest.mark.parametrize(
("trigger_cls", "default_delay", "default_max_attempts"),
WAITER_TRIGGER_PARAMS,
)
def test_serialization_with_verify_and_botocore_config(
self, trigger_cls, default_delay, default_max_attempts
):
trigger = trigger_cls(
cluster_identifier="test_cluster",
aws_conn_id="aws_default",
verify=False,
botocore_config={"connect_timeout": 30},
)

_, kwargs = trigger.serialize()
assert kwargs["verify"] is False
assert kwargs["botocore_config"] == {"connect_timeout": 30}
Comment thread
KarshVashi marked this conversation as resolved.
assert "region_name" not in kwargs

@pytest.mark.parametrize(
("trigger_cls", "default_delay", "default_max_attempts"),
WAITER_TRIGGER_PARAMS,
)
@mock.patch("airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftHook")
def test_hook_propagates_verify_and_botocore_config(
self, mock_hook_cls, trigger_cls, default_delay, default_max_attempts
):
trigger = trigger_cls(
cluster_identifier="test_cluster",
aws_conn_id="test_conn",
region_name="eu-west-1",
verify="/path/to/ca-bundle.crt",
botocore_config={"read_timeout": 60},
)

trigger.hook()

mock_hook_cls.assert_called_once_with(
aws_conn_id="test_conn",
region_name="eu-west-1",
verify="/path/to/ca-bundle.crt",
config={"read_timeout": 60},
)
Loading