diff --git a/src/azure-cli/HISTORY.rst b/src/azure-cli/HISTORY.rst index 7bbafa33310..93f9cd30238 100644 --- a/src/azure-cli/HISTORY.rst +++ b/src/azure-cli/HISTORY.rst @@ -13,6 +13,7 @@ Release History * `az acr task logs`: Align log streaming with the default TLS behavior used by the rest of Azure CLI commands (#33486) * `az acr run/build`: Align log streaming with the default TLS behavior used by the rest of Azure CLI commands (#33486) * `az acr login`: Harden binary resolution and credential passing (#33373) +* `az acr login`: Fix ARM token exchange fallback for AzureML MSI/SSO environments (#33699) **AKS** diff --git a/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py b/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py index dd30846daa3..72d54b0acfc 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py +++ b/src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py @@ -165,9 +165,18 @@ def _get_aad_token_after_challenge(cli_ctx, scope = _resolve_acr_scope(cli_ctx) - # this might be a cross tenant scenario, so pass subscription to get_raw_token - creds, _, tenant = profile.get_raw_token(subscription=get_subscription_id(cli_ctx), - resource=scope) + # this might be a cross tenant scenario, so pass subscription to get_raw_token. + # In some environments (e.g. AzureML MSI/SSO), acquiring a token for the ACR-specific + # audience may fail while the ARM management token is still available. Fall back to the + # ARM token since ACR's /oauth2/exchange endpoint accepts ARM tokens via the + # access_token grant_type. + try: + creds, _, tenant = profile.get_raw_token(subscription=get_subscription_id(cli_ctx), + resource=scope) + except CLIError as e: + logger.debug("Failed to get AAD token for ACR scope '%s' (%s). " + "Falling back to ARM management token.", scope, str(e)) + creds, _, tenant = profile.get_raw_token(subscription=get_subscription_id(cli_ctx)) headers = {'Content-Type': 'application/x-www-form-urlencoded'} content = { diff --git a/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py b/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py index 95c8e3c34db..41ed3f6b9f2 100644 --- a/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py +++ b/src/azure-cli/azure/cli/command_modules/acr/tests/latest/test_acr_commands_mock.py @@ -1268,6 +1268,67 @@ def _validate_access_token_request(self, mock_requests_get, mock_requests_post, headers={'Content-Type': 'application/x-www-form-urlencoded'}, verify=mock.ANY) + @mock.patch('azure.cli.command_modules.acr._docker_utils.get_subscription_id', autospec=True) + @mock.patch('azure.cli.command_modules.acr._docker_utils.get_registry_by_name', autospec=True) + @mock.patch('requests.post', autospec=True) + @mock.patch('requests.get', autospec=True) + @mock.patch('azure.cli.core._profile.Profile.get_raw_token', autospec=True) + def test_get_docker_credentials_arm_token_fallback( + self, mock_get_raw_token, mock_requests_get, mock_requests_post, + mock_get_registry_by_name, mock_get_subscription): + """When acquiring an ACR-specific AAD token fails (e.g. AzureML MSI/SSO environments), + az acr login should fall back to the ARM management token for the /oauth2/exchange call.""" + from knack.util import CLIError as KnackCLIError + from azure.mgmt.containerregistry.models import Registry, Sku + + test_registry = 'testregistry' + test_login_server = '{}.azurecr.io'.format(test_registry) + cmd = self._setup_cmd() + + mock_get_subscription.return_value = TEST_SUBSCRIPTION + + registry = Registry(location='westus', sku=Sku(name='Standard')) + registry.login_server = test_login_server + mock_get_registry_by_name.return_value = registry, None + + # Simulate: ACR-scope token fails (SSO error), ARM token succeeds + arm_token_response = ('Bearer', TEST_AAD_ACCESS_TOKEN, {}), TEST_SUBSCRIPTION, TEST_TENANT + mock_get_raw_token.side_effect = [ + KnackCLIError("SSO failure, to mitigated it please try to click Jupyter/JupyterLab."), + arm_token_response, + ] + + # Set up the challenge and refresh token HTTP responses + challenge_response = mock.MagicMock() + challenge_response.headers = { + 'WWW-Authenticate': 'Bearer realm="https://{}/oauth2/token",service="{}"'.format( + test_login_server, test_login_server) + } + challenge_response.status_code = 401 + mock_requests_get.return_value = challenge_response + + token_response = mock.MagicMock() + token_response.headers = {} + token_response.status_code = 200 + token_response.content = json.dumps({ + 'refresh_token': TEST_ACR_REFRESH_TOKEN, + 'access_token': TEST_ACR_ACCESS_TOKEN}).encode() + mock_requests_post.return_value = token_response + + # get_login_credentials should succeed via the ARM token fallback + login_server, username, password = get_login_credentials(cmd, test_registry) + + self.assertEqual(login_server, test_login_server) + self.assertEqual(username, EMPTY_GUID) + self.assertEqual(password, TEST_ACR_REFRESH_TOKEN) + + # Verify fallback: first call used ACR scope, second used ARM default (no resource) + acr_scope = "https://{}.azure.net".format(ACR_AUDIENCE_RESOURCE_NAME) + calls = mock_get_raw_token.call_args_list + self.assertEqual(len(calls), 2) + self.assertEqual(calls[0], mock.call(mock.ANY, resource=acr_scope, subscription=mock.ANY)) + self.assertEqual(calls[1], mock.call(mock.ANY, subscription=mock.ANY)) + @mock.patch('azure.cli.command_modules.acr.helm.get_access_credentials', autospec=True) @mock.patch('requests.request', autospec=True) def test_helm_list(self, mock_requests_get, mock_get_access_credentials):