Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions perfkitbenchmarker/benchmark_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,8 @@ def ConstructVirtualMachineGroup(
)
if group_spec.cidr: # apply cidr range to all vms in vm_group
group_spec.vm_spec.cidr = group_spec.cidr
vm_class = virtual_machine.GetVmClass(cloud, os_type)
vm_class.AdjustVmSpec(group_spec.vm_spec, disk_spec)
vm = self._CreateVirtualMachine(group_spec.vm_spec, os_type, cloud)
vm.vm_group = group_name
if disk_spec and not vm.is_static:
Expand Down
16 changes: 16 additions & 0 deletions perfkitbenchmarker/providers/azure/azure_virtual_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,22 @@ class AzureVirtualMachine(
low_priority_status_code: int | None
spot_early_termination: bool

@classmethod
def AdjustVmSpec(cls, vm_spec, disk_spec):
super().AdjustVmSpec(vm_spec, disk_spec)
if disk_spec and disk_spec.disk_type in (
azure_disk.PREMIUM_STORAGE_V2,
azure_disk.ULTRA_STORAGE,
):
if vm_spec.zone and not util.GetAvailabilityZoneFromZone(vm_spec.zone):
region = util.GetRegionFromZone(vm_spec.zone)
vm_spec.zone = f'{region}-1'
logging.info(
'Forcing zone to %s for VM spec because disk type %s requires it.',
vm_spec.zone,
disk_spec.disk_type,
)

def __init__(self, vm_spec):
"""Initialize an Azure virtual machine.

Expand Down
5 changes: 5 additions & 0 deletions perfkitbenchmarker/virtual_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ class BaseVirtualMachine(os_mixin.BaseOsMixin, resource.BaseResource):
# inheritence, we need it here.
cpu_arch: str

@classmethod
def AdjustVmSpec(cls, vm_spec, disk_spec):
"""Adjusts the vm_spec based on the disk_spec before VM creation."""
pass

def __init__(self, vm_spec: virtual_machine_spec.BaseVmSpec):
"""Initialize BaseVirtualMachine class.

Expand Down
24 changes: 24 additions & 0 deletions tests/providers/azure/azure_virtual_machine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,30 @@ def testInsufficientSpotCapacity(self):
with self.assertRaises(errors.Benchmarks.InsufficientCapacityCloudFailure):
vm._Create()

def testAdjustVmSpecForcesZone(self):
spec = azure_virtual_machine.AzureVmSpec(
_COMPONENT, machine_type='Standard_D2s_v5', zone='centralus'
)
disk_spec = mock.Mock(disk_type='PremiumV2_LRS')
azure_virtual_machine.AzureVirtualMachine.AdjustVmSpec(spec, disk_spec)
self.assertEqual(spec.zone, 'centralus-1')

def testAdjustVmSpecDoesNotForceZoneIfAlreadyZonal(self):
spec = azure_virtual_machine.AzureVmSpec(
_COMPONENT, machine_type='Standard_D2s_v5', zone='centralus-2'
)
disk_spec = mock.Mock(disk_type='PremiumV2_LRS')
azure_virtual_machine.AzureVirtualMachine.AdjustVmSpec(spec, disk_spec)
self.assertEqual(spec.zone, 'centralus-2')

def testAdjustVmSpecDoesNotForceZoneForNonZonalDisk(self):
spec = azure_virtual_machine.AzureVmSpec(
_COMPONENT, machine_type='Standard_D2s_v5', zone='centralus'
)
disk_spec = mock.Mock(disk_type='Premium_LRS')
azure_virtual_machine.AzureVirtualMachine.AdjustVmSpec(spec, disk_spec)
self.assertEqual(spec.zone, 'centralus')


class AzurePublicIPAddressTest(pkb_common_test_case.PkbCommonTestCase):

Expand Down