diff --git a/perfkitbenchmarker/benchmark_spec.py b/perfkitbenchmarker/benchmark_spec.py index e77e2bccdc..7fdc27248e 100644 --- a/perfkitbenchmarker/benchmark_spec.py +++ b/perfkitbenchmarker/benchmark_spec.py @@ -862,6 +862,8 @@ def ConstructVirtualMachineGroup( ) if group_spec.cidr: # apply cidr range to all vms in vm_group group_spec.vm_spec.cidr = group_spec.cidr + vm_class = virtual_machine.GetVmClass(cloud, os_type) + vm_class.AdjustVmSpec(group_spec.vm_spec, disk_spec) vm = self._CreateVirtualMachine(group_spec.vm_spec, os_type, cloud) vm.vm_group = group_name if disk_spec and not vm.is_static: diff --git a/perfkitbenchmarker/providers/azure/azure_virtual_machine.py b/perfkitbenchmarker/providers/azure/azure_virtual_machine.py index 91c6c55000..f309d1e848 100644 --- a/perfkitbenchmarker/providers/azure/azure_virtual_machine.py +++ b/perfkitbenchmarker/providers/azure/azure_virtual_machine.py @@ -744,6 +744,22 @@ class AzureVirtualMachine( low_priority_status_code: int | None spot_early_termination: bool + @classmethod + def AdjustVmSpec(cls, vm_spec, disk_spec): + super().AdjustVmSpec(vm_spec, disk_spec) + if disk_spec and disk_spec.disk_type in ( + azure_disk.PREMIUM_STORAGE_V2, + azure_disk.ULTRA_STORAGE, + ): + if vm_spec.zone and not util.GetAvailabilityZoneFromZone(vm_spec.zone): + region = util.GetRegionFromZone(vm_spec.zone) + vm_spec.zone = f'{region}-1' + logging.info( + 'Forcing zone to %s for VM spec because disk type %s requires it.', + vm_spec.zone, + disk_spec.disk_type, + ) + def __init__(self, vm_spec): """Initialize an Azure virtual machine. diff --git a/perfkitbenchmarker/virtual_machine.py b/perfkitbenchmarker/virtual_machine.py index df10b5dcc4..81fb6fbcbe 100644 --- a/perfkitbenchmarker/virtual_machine.py +++ b/perfkitbenchmarker/virtual_machine.py @@ -338,6 +338,11 @@ class BaseVirtualMachine(os_mixin.BaseOsMixin, resource.BaseResource): # inheritence, we need it here. cpu_arch: str + @classmethod + def AdjustVmSpec(cls, vm_spec, disk_spec): + """Adjusts the vm_spec based on the disk_spec before VM creation.""" + pass + def __init__(self, vm_spec: virtual_machine_spec.BaseVmSpec): """Initialize BaseVirtualMachine class. diff --git a/tests/providers/azure/azure_virtual_machine_test.py b/tests/providers/azure/azure_virtual_machine_test.py index 0cfb43b942..da8061559d 100644 --- a/tests/providers/azure/azure_virtual_machine_test.py +++ b/tests/providers/azure/azure_virtual_machine_test.py @@ -149,6 +149,30 @@ def testInsufficientSpotCapacity(self): with self.assertRaises(errors.Benchmarks.InsufficientCapacityCloudFailure): vm._Create() + def testAdjustVmSpecForcesZone(self): + spec = azure_virtual_machine.AzureVmSpec( + _COMPONENT, machine_type='Standard_D2s_v5', zone='centralus' + ) + disk_spec = mock.Mock(disk_type='PremiumV2_LRS') + azure_virtual_machine.AzureVirtualMachine.AdjustVmSpec(spec, disk_spec) + self.assertEqual(spec.zone, 'centralus-1') + + def testAdjustVmSpecDoesNotForceZoneIfAlreadyZonal(self): + spec = azure_virtual_machine.AzureVmSpec( + _COMPONENT, machine_type='Standard_D2s_v5', zone='centralus-2' + ) + disk_spec = mock.Mock(disk_type='PremiumV2_LRS') + azure_virtual_machine.AzureVirtualMachine.AdjustVmSpec(spec, disk_spec) + self.assertEqual(spec.zone, 'centralus-2') + + def testAdjustVmSpecDoesNotForceZoneForNonZonalDisk(self): + spec = azure_virtual_machine.AzureVmSpec( + _COMPONENT, machine_type='Standard_D2s_v5', zone='centralus' + ) + disk_spec = mock.Mock(disk_type='Premium_LRS') + azure_virtual_machine.AzureVirtualMachine.AdjustVmSpec(spec, disk_spec) + self.assertEqual(spec.zone, 'centralus') + class AzurePublicIPAddressTest(pkb_common_test_case.PkbCommonTestCase):