Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempt,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -497,6 +500,9 @@ def execute(self, context: Context) -> Any:
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempt,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down Expand Up @@ -668,6 +674,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down Expand Up @@ -775,6 +784,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
# timeout is set to ensure that if a trigger dies, the timeout does not restart
Expand Down Expand Up @@ -901,6 +913,9 @@ def execute(self, context: Context):
waiter_delay=self.poll_interval,
waiter_max_attempts=self.max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ class RedshiftCreateClusterTrigger(AwsBaseWaiterTrigger):

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
Comment thread
KarshVashi marked this conversation as resolved.
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -59,10 +62,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftPauseClusterTrigger(AwsBaseWaiterTrigger):
Expand All @@ -80,10 +90,13 @@ class RedshiftPauseClusterTrigger(AwsBaseWaiterTrigger):

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -96,10 +109,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftCreateClusterSnapshotTrigger(AwsBaseWaiterTrigger):
Expand All @@ -117,10 +137,13 @@ class RedshiftCreateClusterSnapshotTrigger(AwsBaseWaiterTrigger):

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -133,10 +156,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftResumeClusterTrigger(AwsBaseWaiterTrigger):
Expand All @@ -154,10 +184,13 @@ class RedshiftResumeClusterTrigger(AwsBaseWaiterTrigger):

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 15,
waiter_max_attempts: int = 999999,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -170,10 +203,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger):
Expand All @@ -188,10 +228,13 @@ class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger):

def __init__(
self,
*,
cluster_identifier: str,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
waiter_delay: int = 30,
waiter_max_attempts: int = 30,
**kwargs,
):
super().__init__(
serialized_fields={"cluster_identifier": cluster_identifier},
Expand All @@ -204,10 +247,17 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
region_name=region_name,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return RedshiftHook(aws_conn_id=self.aws_conn_id)
return RedshiftHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)


class RedshiftClusterTrigger(BaseTrigger):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

from airflow.providers.amazon.aws.triggers.redshift_cluster import (
RedshiftClusterTrigger,
RedshiftCreateClusterSnapshotTrigger,
RedshiftCreateClusterTrigger,
RedshiftDeleteClusterTrigger,
RedshiftPauseClusterTrigger,
RedshiftResumeClusterTrigger,
)
from airflow.triggers.base import TriggerEvent

Expand Down Expand Up @@ -115,3 +120,105 @@ async def test_redshift_cluster_sensor_trigger_exception(self, mock_cluster_stat
# so we validate for length of task to be 1
assert len(task) == 1
assert TriggerEvent({"status": "error", "message": "Test exception"}) in task


WAITER_TRIGGER_PARAMS = [
pytest.param(
RedshiftCreateClusterTrigger,
15,
999999,
id="RedshiftCreateClusterTrigger",
),
pytest.param(
RedshiftPauseClusterTrigger,
15,
999999,
id="RedshiftPauseClusterTrigger",
),
pytest.param(
RedshiftCreateClusterSnapshotTrigger,
15,
999999,
id="RedshiftCreateClusterSnapshotTrigger",
),
pytest.param(
RedshiftResumeClusterTrigger,
15,
999999,
id="RedshiftResumeClusterTrigger",
),
pytest.param(
RedshiftDeleteClusterTrigger,
30,
30,
id="RedshiftDeleteClusterTrigger",
),
]


class TestRedshiftWaiterTriggers:
"""Tests for the five Redshift triggers that inherit from ``AwsBaseWaiterTrigger``."""

@pytest.mark.parametrize(
("trigger_cls", "default_delay", "default_max_attempts"),
WAITER_TRIGGER_PARAMS,
)
def test_serialization(self, trigger_cls, default_delay, default_max_attempts):
trigger = trigger_cls(
cluster_identifier="test_cluster",
aws_conn_id="aws_default",
region_name="us-east-1",
)

classpath, kwargs = trigger.serialize()
assert classpath == f"airflow.providers.amazon.aws.triggers.redshift_cluster.{trigger_cls.__name__}"
assert kwargs == {
"cluster_identifier": "test_cluster",
"waiter_delay": default_delay,
"waiter_max_attempts": default_max_attempts,
"aws_conn_id": "aws_default",
"region_name": "us-east-1",
}

@pytest.mark.parametrize(
("trigger_cls", "default_delay", "default_max_attempts"),
WAITER_TRIGGER_PARAMS,
)
def test_serialization_with_verify_and_botocore_config(
self, trigger_cls, default_delay, default_max_attempts
):
trigger = trigger_cls(
cluster_identifier="test_cluster",
aws_conn_id="aws_default",
verify=False,
botocore_config={"connect_timeout": 30},
)

_, kwargs = trigger.serialize()
assert kwargs["verify"] is False
assert kwargs["botocore_config"] == {"connect_timeout": 30}
Comment thread
KarshVashi marked this conversation as resolved.

@pytest.mark.parametrize(
("trigger_cls", "default_delay", "default_max_attempts"),
WAITER_TRIGGER_PARAMS,
)
@mock.patch("airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftHook")
def test_hook_propagates_verify_and_botocore_config(
self, mock_hook_cls, trigger_cls, default_delay, default_max_attempts
):
trigger = trigger_cls(
cluster_identifier="test_cluster",
aws_conn_id="test_conn",
region_name="eu-west-1",
verify="/path/to/ca-bundle.crt",
botocore_config={"read_timeout": 60},
)

trigger.hook()

mock_hook_cls.assert_called_once_with(
aws_conn_id="test_conn",
region_name="eu-west-1",
verify="/path/to/ca-bundle.crt",
config={"read_timeout": 60},
)
Loading