Skip to content
105 changes: 71 additions & 34 deletions lisa/microsoft/testsuites/device_passthrough/functional_tests.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import re
from typing import TYPE_CHECKING, Any, Dict, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast

from lisa import Environment, Node, TestCaseMetadata, TestSuite, TestSuiteMetadata
from lisa.base_tools import Cat
from lisa.operating_system import Windows
from lisa.platform_ import Platform
from lisa.sut_orchestrator import CLOUD_HYPERVISOR, HYPERV
from lisa.sut_orchestrator import CLOUD_HYPERVISOR, HYPERV, OPENVMM
from lisa.testsuite import TestResult, simple_requirement
from lisa.tools import Lspci
from lisa.util import LisaException, SkippedException
Expand All @@ -20,13 +20,17 @@
from lisa.sut_orchestrator.libvirt.schema import (
DeviceAddressSchema as LibvirtDeviceAddressSchema,
)
from lisa.sut_orchestrator.openvmm.context import (
DeviceAddressSchema as OpenVmmDeviceAddressSchema,
)

HostDeviceAddressSchema = Union[
HypervDeviceAddressSchema,
LibvirtDeviceAddressSchema,
OpenVmmDeviceAddressSchema,
]

SUPPORTED_PASSTHROUGH_PLATFORMS = [CLOUD_HYPERVISOR, HYPERV]
SUPPORTED_PASSTHROUGH_PLATFORMS = [CLOUD_HYPERVISOR, HYPERV, OPENVMM]


@TestSuiteMetadata(
Expand All @@ -44,8 +48,8 @@ class DevicePassthroughFunctionalTests(TestSuite):
@TestCaseMetadata(
description="""
Check if passthrough device is visible to guest.
This testcase supports the CLOUD_HYPERVISOR and HYPERV platforms
of LISA. Please refer below runbook snippet.
This testcase supports the CLOUD_HYPERVISOR, HYPERV, and OPENVMM
platforms of LISA. Please refer below runbook snippet.

platform:
- type: cloud-hypervisor
Expand Down Expand Up @@ -89,32 +93,22 @@ def verify_device_passthrough_on_guest(
if platform is None:
raise SkippedException(
"Device passthrough validation requires a LISA platform context. "
"Verify the runbook uses cloud-hypervisor or hyperv."
)
platform_name = platform.type_name()
node_context: Any

if platform_name == CLOUD_HYPERVISOR:
# Import at runtime to avoid libvirt dependency on other platforms.
from lisa.sut_orchestrator.libvirt.context import (
get_node_context as get_libvirt_node_context,
"Verify the runbook uses cloud-hypervisor, hyperv, or openvmm."
)
platform_name = self._get_platform_name(platform, node)
node_context = self._get_passthrough_context(node, platform_name)

node_context = get_libvirt_node_context(node)
elif platform_name == HYPERV:
from lisa.sut_orchestrator.hyperv.context import (
get_node_context as get_hyperv_node_context,
)
if not node_context.passthrough_devices:
raise SkippedException("No passthrough devices are assigned to node")

node_context = get_hyperv_node_context(node)
else:
host_node = getattr(node_context, "host", None)
if host_node is None and environment.platform is not None:
host_node = getattr(environment.platform, "host_node", None)
if host_node is None and platform_name != HYPERV:
raise SkippedException(
f"Device passthrough validation is not supported on '{platform_name}'"
"No host node is available for passthrough device validation"
)

if not node_context.passthrough_devices:
raise SkippedException("No passthrough devices are assigned to node")

expected_devices: Dict[Tuple[str, str, str], int] = {}
for passthrough_context in node_context.passthrough_devices:
pool_type = str(passthrough_context.pool_type.value)
Expand All @@ -124,7 +118,7 @@ def verify_device_passthrough_on_guest(
)
for host_device in passthrough_context.device_list:
vendor_device_id = self._vendor_device_from_host_device(
platform, host_device
platform_name, platform, host_node, host_device
)
key = (
pool_type,
Expand All @@ -147,12 +141,48 @@ def verify_device_passthrough_on_guest(
f"Vendor/Device ID: {ven_id}:{dev_id}"
)

@staticmethod
def _get_platform_name(platform: Platform, node: Node) -> str:
node_type = node.type_name()
if node_type == OPENVMM:
return node_type

return platform.type_name()

@staticmethod
def _get_passthrough_context(node: Node, platform_name: str) -> Any:
if platform_name == OPENVMM:
from lisa.sut_orchestrator.openvmm.context import (
get_node_context as get_openvmm_node_context,
)

return get_openvmm_node_context(node)

if platform_name == CLOUD_HYPERVISOR:
from lisa.sut_orchestrator.libvirt.context import (
get_node_context as get_libvirt_node_context,
)

return get_libvirt_node_context(node)

if platform_name == HYPERV:
from lisa.sut_orchestrator.hyperv.context import (
get_node_context as get_hyperv_node_context,
)

return get_hyperv_node_context(node)

raise SkippedException(
f"Device passthrough validation is not supported on '{platform_name}'"
)

@staticmethod
def _vendor_device_from_host_device(
platform_name: str,
platform: Platform,
host_node: Optional[Node],
device: "HostDeviceAddressSchema",
) -> Dict[str, str]:
platform_name = platform.type_name()
if platform_name == HYPERV:
hyperv_device = cast("HypervDeviceAddressSchema", device)
instance_id = hyperv_device.instance_id
Expand All @@ -171,19 +201,26 @@ def _vendor_device_from_host_device(
"device_id": match.group("device_id").lower(),
}

if platform_name != CLOUD_HYPERVISOR:
if platform_name not in [CLOUD_HYPERVISOR, OPENVMM]:
raise LisaException(
f"Device passthrough host device lookup is not supported on "
f"'{platform_name}'. Use a cloud-hypervisor or hyperv platform."
f"'{platform_name}'. Use a cloud-hypervisor, hyperv, or openvmm "
"platform."
)

cloud_hypervisor = cast("CloudHypervisorPlatform", platform)
libvirt_device = cast("LibvirtDeviceAddressSchema", device)
if host_node is None:
raise LisaException(
"No host node is available for passthrough device vendor lookup"
)
if platform_name == CLOUD_HYPERVISOR:
cloud_hypervisor = cast("CloudHypervisorPlatform", platform)
host_node = cloud_hypervisor.host_node
pci_device = cast(Any, device)
bdf = (
f"{libvirt_device.domain}:{libvirt_device.bus}:"
f"{libvirt_device.slot}.{libvirt_device.function}"
f"{pci_device.domain}:{pci_device.bus}:"
f"{pci_device.slot}.{pci_device.function}"
).lower()
cat = cloud_hypervisor.host_node.tools[Cat]
cat = host_node.tools[Cat]
vendor_raw = cat.read(f"/sys/bus/pci/devices/{bdf}/vendor", sudo=True).strip()
device_raw = cat.read(f"/sys/bus/pci/devices/{bdf}/device", sudo=True).strip()
# Normalize to 4-digit lowercase hex used by lspci identifiers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from lisa.environment import Environment, Node
from lisa.operating_system import Windows
from lisa.sut_orchestrator import CLOUD_HYPERVISOR, HYPERV
from lisa.sut_orchestrator import CLOUD_HYPERVISOR, HYPERV, OPENVMM
from lisa.testsuite import TestResult
from lisa.tools import Dhclient, Kill, PowerShell, Sysctl
from lisa.tools.ip import Ip
Expand All @@ -47,7 +47,7 @@
from lisa.util.logger import get_logger
from lisa.util.parallel import run_in_parallel

SUPPORTED_PASSTHROUGH_PLATFORMS = [CLOUD_HYPERVISOR, HYPERV]
SUPPORTED_PASSTHROUGH_PLATFORMS = [CLOUD_HYPERVISOR, HYPERV, OPENVMM]
WINDOWS_NTTTCP_MAX_SERVER_THREADS = 64
WINDOWS_NTTTCP_MAX_MIXED_TCP_CONNECTIONS = 512
WINDOWS_NTTTCP_RECEIVER_WAIT_TIMEOUT = 90
Expand Down Expand Up @@ -790,6 +790,13 @@ def _refresh_passthrough_nic_address(
return passthrough_nic_ip

def _get_passthrough_node_context(self, node: Node) -> Any:
if node.type_name() == OPENVMM:
from lisa.sut_orchestrator.openvmm.context import (
get_node_context as get_openvmm_node_context,
)

return get_openvmm_node_context(node)

try:
from lisa.sut_orchestrator.libvirt.context import (
get_node_context as get_libvirt_node_context,
Expand Down
7 changes: 5 additions & 2 deletions lisa/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@ def execute_async(
)

def cleanup(self) -> None:
for guest in self.guests:
guests = list(self.guests)
self._guests = []
for guest in guests:
try:
guest.cleanup()
except Exception:
Expand Down Expand Up @@ -747,6 +749,7 @@ def set_connection_info(
username: str = "root",
password: str = "",
private_key_file: str = "",
proxy_jump_boxes: Optional[List[schema.ConnectionInfo]] = None,
) -> None:
if not address and not public_address:
raise LisaException(
Expand Down Expand Up @@ -774,7 +777,7 @@ def set_connection_info(
password,
private_key_file,
)
self._shell = SshShell(self._connection_info)
self._shell = SshShell(self._connection_info, proxy_jump_boxes)

self.public_address = public_address
self.public_port = public_port
Expand Down
33 changes: 31 additions & 2 deletions lisa/runners/lisa_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,9 @@ def _get_runnable_test_results(
try:
if result.check_environment(
environment=tested_environment,
environment_platform_type=self.platform.type_name(),
environment_platform_type=self._get_test_platform_type(
result, self.platform.type_name()
),
save_reason=True,
) and (
not result.runtime_data.use_new_environment
Expand Down Expand Up @@ -872,7 +874,9 @@ def _merge_test_requirements(
test_req: TestCaseRequirement = test_result.runtime_data.requirement
environment_requirement: Optional[EnvironmentSpace] = None

check_result = test_result.check_platform(platform_type)
check_result = test_result.check_platform(
self._get_test_platform_type(test_result, platform_type)
)
if not check_result.result:
test_result.set_status(TestStatus.SKIPPED, check_result.reasons)
continue
Expand Down Expand Up @@ -926,6 +930,31 @@ def _create_guest_parent_requirement(

return EnvironmentSpace(nodes=[schema.NodeSpace()])

def _get_test_platform_type(
self, test_result: TestResult, platform_type: str
) -> str:
Comment thread
vyadavmsft marked this conversation as resolved.
if not getattr(self, "_guest_enabled", False):
return platform_type

requirement = test_result.runtime_data.requirement
if (
not requirement
or not requirement.platform_type
or len(requirement.platform_type.items) == 0
):
return platform_type

platform_runbook = cast(schema.Platform, self.platform.runbook)
for guest_runbook in platform_runbook.guests:
guest_platform_type = getattr(guest_runbook, "type", "")
if (
guest_platform_type
and test_result.check_platform(guest_platform_type).result
):
return guest_platform_type

return platform_type

def _merge_platform_requirement(
self,
environment_requirement: EnvironmentSpace,
Expand Down
Loading
Loading