Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _log_job_state(self):
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)

Expand Down Expand Up @@ -1490,7 +1490,7 @@ def iter_outputs(
if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
raise RuntimeError(
f"Cannot read outputs until BatchPredictionJob has succeeded, "
f"current state: {self._gca_resource.state}"
f"current state: {self._gca_resource.state.name}"
)

output_info = self._gca_resource.output_info
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def _block_until_complete(self):
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)
log_wait = min(log_wait * multiplier, max_wait)
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _block_until_complete(self) -> None:
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)
log_wait = min(log_wait * multiplier, max_wait)
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/aiplatform/training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def _block_until_complete(self):
% (
self.__class__.__name__,
self._gca_resource.name,
self._gca_resource.state,
self._gca_resource.state.name,
)
)
log_wait = min(log_wait * _WAIT_TIME_MULTIPLIER, _MAX_WAIT_TIME)
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/aiplatform/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,21 @@ def test_dashboard_uri_format(self):
)
assert uri == expected

@pytest.mark.usefixtures("fake_job_getter_mock")
def test_log_job_state_uses_symbolic_name(self):
"""_log_job_state must log the enum name, not the integer value (regression for Python 3.11+)."""
fake_job = self.FakeJob(job_name=_TEST_JOB_RESOURCE_NAME)
fake_job._gca_resource = mock.Mock()
fake_job._gca_resource.name = _TEST_JOB_RESOURCE_NAME
fake_job._gca_resource.state = gca_job_state_compat.JobState.JOB_STATE_RUNNING

with mock.patch.object(jobs._LOGGER, "info") as mock_info:
fake_job._log_job_state()

logged_msg = mock_info.call_args[0][0]
assert "JOB_STATE_RUNNING" in logged_msg
assert "current state:\n3" not in logged_msg


@pytest.fixture
def get_batch_prediction_job_mock():
Expand Down Expand Up @@ -713,6 +728,21 @@ def test_batch_prediction_iter_dirs_while_running(self):
)
bp.iter_outputs()

@pytest.mark.usefixtures("get_batch_prediction_job_running_bq_output_mock")
def test_batch_prediction_iter_dirs_while_running_error_uses_symbolic_state_name(
self,
):
"""RuntimeError message must use symbolic state name, not integer (regression for Python 3.11+)."""
with pytest.raises(RuntimeError) as exc_info:
bp = jobs.BatchPredictionJob(
batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
)
bp.iter_outputs()

error_msg = str(exc_info.value)
assert "JOB_STATE_RUNNING" in error_msg
assert "current state: 3" not in error_msg

@pytest.mark.usefixtures("get_batch_prediction_job_empty_output_mock")
def test_batch_prediction_iter_dirs_invalid_output_info(self):
"""
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from google.cloud.aiplatform import (
pipeline_job_schedules,
schedules as aiplatform_schedules,
)
from google.cloud.aiplatform.preview.pipelinejob import (
pipeline_jobs as preview_pipeline_jobs,
Expand Down Expand Up @@ -434,6 +435,48 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

def test_block_until_complete_logs_symbolic_state_name(self):
Comment thread
racinmat marked this conversation as resolved.
"""State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
state_sequence = [
gca_schedule.Schedule.State.ACTIVE, # first loop check
gca_schedule.Schedule.State.COMPLETED, # second check exits loop
]
state_index = [0]

def get_state():
s = state_sequence[state_index[0]]
state_index[0] = min(state_index[0] + 1, len(state_sequence) - 1)
return s

mock_schedule = mock.Mock()
type(mock_schedule).state = mock.PropertyMock(side_effect=get_state)

active_gca = gca_schedule.Schedule(
name=_TEST_PIPELINE_JOB_SCHEDULE_NAME,
state=gca_schedule.Schedule.State.ACTIVE,
)
mock_schedule._gca_resource = active_gca

logged_messages = []

# time.time: first call sets previous_time=0; second gives 10 → triggers log (10 >= 5)
time_vals = iter([0.0, 10.0, 20.0])
with mock.patch(
"google.cloud.aiplatform.schedules.time.time", side_effect=time_vals
), mock.patch(
"google.cloud.aiplatform.schedules.time.sleep"
), mock.patch.object(
aiplatform_schedules._LOGGER,
"info",
side_effect=lambda msg, *a, **kw: logged_messages.append(msg),
):
aiplatform_schedules._Schedule._block_until_complete(mock_schedule)

state_log = next((m for m in logged_messages if "current state" in m), None)
assert state_log is not None, "No 'current state' log message found"
assert "ACTIVE" in state_log
assert "current state:\n1" not in state_log

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,45 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@mock.patch.object(pipeline_jobs, "_JOB_WAIT_TIME", 0)
@mock.patch.object(pipeline_jobs, "_LOG_WAIT_TIME", 0)
def test_block_until_complete_logs_symbolic_state_name(
self,
mock_pipeline_service_create,
mock_pipeline_service_get,
mock_pipeline_bucket_exists,
):
"""State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

logged_messages = []

with patch.object(
storage.Blob, "download_as_bytes"
) as mock_load, mock.patch.object(
pipeline_jobs._LOGGER,
"info",
side_effect=lambda msg, *a, **kw: logged_messages.append(msg),
):
mock_load.return_value = _TEST_PIPELINE_SPEC_JSON.encode()

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
)
job.run(sync=True, create_request_timeout=None)

state_log = next((m for m in logged_messages if "current state" in m), None)
assert state_log is not None, "No 'current state' log message found"
assert "PIPELINE_STATE_RUNNING" in state_log
assert "current state:\n3" not in state_log

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/aiplatform/test_training_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,6 +1354,59 @@ def teardown_method(self):
pathlib.Path(self._local_script_file_name).unlink()
initializer.global_pool.shutdown(wait=True)

def test_block_until_complete_logs_symbolic_state_name(
self, mock_model_service_get
):
"""State log must use symbolic enum name, not a bare integer (regression for Python 3.11+)."""
aiplatform.init(project=_TEST_PROJECT, staging_bucket=_TEST_BUCKET_NAME)

logged_messages = []

with mock.patch.object(
pipeline_service_client.PipelineServiceClient, "create_training_pipeline"
) as mock_create, mock.patch.object(
source_utils._TrainingScriptPythonPackager, "package_and_copy_to_gcs"
) as mock_pkg, mock.patch.object(
pipeline_service_client.PipelineServiceClient, "get_training_pipeline"
) as mock_get, mock.patch.object(
training_jobs, "_LOG_WAIT_TIME", 0
), mock.patch.object(
training_jobs, "_JOB_WAIT_TIME", 0
), mock.patch.object(
training_jobs._LOGGER,
"info",
side_effect=lambda msg, *a, **kw: logged_messages.append(msg),
):
mock_pkg.return_value = _TEST_OUTPUT_PYTHON_PACKAGE_PATH
mock_create.return_value = gca_training_pipeline.TrainingPipeline(
name=_TEST_PIPELINE_RESOURCE_NAME,
state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
)
_running = gca_training_pipeline.TrainingPipeline(
name=_TEST_PIPELINE_RESOURCE_NAME,
state=gca_pipeline_state.PipelineState.PIPELINE_STATE_RUNNING,
training_task_inputs={},
)
_succeeded = gca_training_pipeline.TrainingPipeline(
name=_TEST_PIPELINE_RESOURCE_NAME,
state=gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
training_task_inputs={},
model_to_upload=gca_model.Model(name=_TEST_MODEL_NAME),
)
mock_get.side_effect = [_running, _running] + [_succeeded] * 8
job = training_jobs.CustomTrainingJob(
display_name=_TEST_DISPLAY_NAME,
script_path=self._local_script_file_name,
container_uri=_TEST_TRAINING_CONTAINER_IMAGE,
)
job.run(base_output_dir=_TEST_BASE_OUTPUT_DIR, sync=True)

state_log = next((m for m in logged_messages if "current state" in m), None)
assert state_log is not None, "No 'current state' log message found"
assert "PIPELINE_STATE_RUNNING" in state_log
assert "current state:\n3" not in state_log

@mock.patch.object(training_jobs, "_JOB_WAIT_TIME", 1)
@mock.patch.object(training_jobs, "_LOG_WAIT_TIME", 1)
@pytest.mark.parametrize("sync", [True, False])
Expand Down
Loading