Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/azure-cli/azure/cli/command_modules/acr/_docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,18 @@ def _get_aad_token_after_challenge(cli_ctx,

scope = _resolve_acr_scope(cli_ctx)

# this might be a cross tenant scenario, so pass subscription to get_raw_token
creds, _, tenant = profile.get_raw_token(subscription=get_subscription_id(cli_ctx),
resource=scope)
# this might be a cross tenant scenario, so pass subscription to get_raw_token.
# In some environments (e.g. AzureML MSI/SSO), acquiring a token for the ACR-specific
# audience may fail while the ARM management token is still available. Fall back to the
# ARM token since ACR's /oauth2/exchange endpoint accepts ARM tokens via the
# access_token grant_type.
try:
creds, _, tenant = profile.get_raw_token(subscription=get_subscription_id(cli_ctx),
resource=scope)
except CLIError as e:
logger.debug("Failed to get AAD token for ACR scope '%s' (%s). "
"Falling back to ARM management token.", scope, str(e))
creds, _, tenant = profile.get_raw_token(subscription=get_subscription_id(cli_ctx))

headers = {'Content-Type': 'application/x-www-form-urlencoded'}
content = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,67 @@ def _validate_access_token_request(self, mock_requests_get, mock_requests_post,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
verify=mock.ANY)

@mock.patch('azure.cli.command_modules.acr._docker_utils.get_subscription_id', autospec=True)
@mock.patch('azure.cli.command_modules.acr._docker_utils.get_registry_by_name', autospec=True)
@mock.patch('requests.post', autospec=True)
@mock.patch('requests.get', autospec=True)
@mock.patch('azure.cli.core._profile.Profile.get_raw_token', autospec=True)
def test_get_docker_credentials_arm_token_fallback(
self, mock_get_raw_token, mock_requests_get, mock_requests_post,
mock_get_registry_by_name, mock_get_subscription):
"""When acquiring an ACR-specific AAD token fails (e.g. AzureML MSI/SSO environments),
az acr login should fall back to the ARM management token for the /oauth2/exchange call."""
from knack.util import CLIError as KnackCLIError
from azure.mgmt.containerregistry.models import Registry, Sku

test_registry = 'testregistry'
test_login_server = '{}.azurecr.io'.format(test_registry)
cmd = self._setup_cmd()

mock_get_subscription.return_value = TEST_SUBSCRIPTION

registry = Registry(location='westus', sku=Sku(name='Standard'))
registry.login_server = test_login_server
mock_get_registry_by_name.return_value = registry, None

# Simulate: ACR-scope token fails (SSO error), ARM token succeeds
arm_token_response = ('Bearer', TEST_AAD_ACCESS_TOKEN, {}), TEST_SUBSCRIPTION, TEST_TENANT
mock_get_raw_token.side_effect = [
KnackCLIError("SSO failure, to mitigated it please try to click Jupyter/JupyterLab."),
arm_token_response,
]

# Set up the challenge and refresh token HTTP responses
challenge_response = mock.MagicMock()
challenge_response.headers = {
'WWW-Authenticate': 'Bearer realm="https://{}/oauth2/token",service="{}"'.format(
test_login_server, test_login_server)
}
challenge_response.status_code = 401
mock_requests_get.return_value = challenge_response

token_response = mock.MagicMock()
token_response.headers = {}
token_response.status_code = 200
token_response.content = json.dumps({
'refresh_token': TEST_ACR_REFRESH_TOKEN,
'access_token': TEST_ACR_ACCESS_TOKEN}).encode()
mock_requests_post.return_value = token_response

# get_login_credentials should succeed via the ARM token fallback
login_server, username, password = get_login_credentials(cmd, test_registry)

self.assertEqual(login_server, test_login_server)
self.assertEqual(username, EMPTY_GUID)
self.assertEqual(password, TEST_ACR_REFRESH_TOKEN)

# Verify fallback: first call used ACR scope, second used ARM default (no resource)
acr_scope = "https://{}.azure.net".format(ACR_AUDIENCE_RESOURCE_NAME)
calls = mock_get_raw_token.call_args_list
self.assertEqual(len(calls), 2)
self.assertEqual(calls[0], mock.call(mock.ANY, resource=acr_scope, subscription=mock.ANY))
self.assertEqual(calls[1], mock.call(mock.ANY, subscription=mock.ANY))

@mock.patch('azure.cli.command_modules.acr.helm.get_access_credentials', autospec=True)
@mock.patch('requests.request', autospec=True)
def test_helm_list(self, mock_requests_get, mock_get_access_credentials):
Expand Down
Loading