From 946caaa77eb912a49856c75a141c96491165255e Mon Sep 17 00:00:00 2001
From: Seok Namkoong
Date: Wed, 25 Feb 2026 11:02:00 +0900
Subject: [PATCH] [DRAFT][passes] Passes and tests for VLM PatchEmbed optimization

This draft is for introducing passes and tests for VLM PatchEmbed optimization.

TICO-DCO-1.0-TICO-DCO-1.0-Signed-off-by: Seok Namkoong
---
 .../test_convert_conv3d_to_conv2d.py          |  29 +++++
 .../test_convert_permute_to_reshape.py        |  61 ++++++++++
 test/utils/pass_value_test.py                 |   4 +-
 tico/passes/convert_conv3d_to_conv2d.py       |  86 +++++++++++++-
 tico/passes/convert_permute_to_reshape.py     | 105 ++++++++++++++++++
 5 files changed, 283 insertions(+), 2 deletions(-)
 create mode 100644 test/unit_test/pass_test/test_convert_permute_to_reshape.py
 create mode 100644 tico/passes/convert_permute_to_reshape.py

diff --git a/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py b/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
index f5fd038d..5af54636 100644
--- a/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
+++ b/test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
@@ -264,3 +264,32 @@ def test_pass(self):
         self.run_value_test(ConvertConv3dToConv2d())
         self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 0)
         self.assertGreaterEqual(num_of_ops(self.exported_program(), ops.aten.conv2d), 2)
+
+
+class Conv3dPerfectFitKernel(torch.nn.Module):
+    """Conv3D whose kernel spans the whole (T, H, W) extent of the input."""
+
+    def __init__(self):
+        super().__init__()
+        self.conv3d = torch.nn.Conv3d(
+            in_channels=3,
+            out_channels=1024,
+            kernel_size=(2, 16, 16),
+            stride=(2, 16, 16),
+            padding=(0, 0, 0),
+        )
+
+    def forward(self, input):
+        return self.conv3d(input)
+
+    def get_example_inputs(self):
+        return (torch.randn(5, 3, 2, 16, 16),), {}
+
+
+class ConvertConv3dPerfectFitKernelTest(SinglePassValueTest):
+    def test_pass(self):
+        self.setup(Conv3dPerfectFitKernel())
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 1)
+        self.run_value_test(ConvertConv3dToConv2d())
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 0)
+        self.assertGreaterEqual(num_of_ops(self.exported_program(), ops.aten.conv2d), 1)
diff --git a/test/unit_test/pass_test/test_convert_permute_to_reshape.py b/test/unit_test/pass_test/test_convert_permute_to_reshape.py
new file mode 100644
index 00000000..97a19c06
--- /dev/null
+++ b/test/unit_test/pass_test/test_convert_permute_to_reshape.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tico.passes import ops
+from tico.passes.convert_permute_to_reshape import ConvertPermuteToReshape
+
+from test.utils.helper import num_of_ops
+from test.utils.pass_value_test import SinglePassValueTest
+
+
+class PermuteBasic(torch.nn.Module):
+    """Permute [1, 5, 1, 3] -> [5, 1, 3, 1]: only a size-1 dim moves, so the
+    element order is unchanged and the permute can become a reshape."""
+
+    def forward(self, x):
+        return torch.permute(x, (1, 2, 3, 0))
+
+    def get_example_inputs(self):
+        return (torch.rand([1, 5, 1, 3]),), {}
+
+
+class PermuteBasicTest(SinglePassValueTest):
+    def test_pass(self):
+        self.setup(PermuteBasic())
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1)
+        self.run_value_test(ConvertPermuteToReshape(True))
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 0)
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 1)
+
+
+class PermuteBasicNegative(torch.nn.Module):
+    """Permute [1, 5, 1, 3] -> [1, 3, 1, 5]: the non-1 dims swap places, so the
+    element order changes and the permute must be kept."""
+
+    def forward(self, x):
+        return torch.permute(x, (2, 3, 0, 1))
+
+    def get_example_inputs(self):
+        return (torch.rand([1, 5, 1, 3]),), {}
+
+
+class PermuteBasicNegativeTest(SinglePassValueTest):
+    def test_pass(self):
+        self.setup(PermuteBasicNegative())
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1)
+        self.run_value_test(ConvertPermuteToReshape(True))
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.permute), 1)
+        self.assertEqual(num_of_ops(self.exported_program(), ops.aten.reshape), 0)
diff --git a/test/utils/pass_value_test.py b/test/utils/pass_value_test.py
index a9fdc06e..84130691 100644
--- a/test/utils/pass_value_test.py
+++ b/test/utils/pass_value_test.py
@@ -78,6 +78,7 @@ def run_value_test(self, test_pass: PassBase):  # type: ignore[override]
         # inference after pass
         ret_after = self.ep.module()(*self.forward_args, **self.forward_kwargs)
 
+        self.assertEqual(ret_before.shape, ret_after.shape)
         self.assertTrue(torch.allclose(ret_before, ret_after, atol=1e-06))
 
 
@@ -93,4 +94,5 @@ def
run_value_test(self, test_passes: list):  # type: ignore[override]
         # inference after pass
         ret_after = self.ep.module()(*self.forward_args, **self.forward_kwargs)
 
-        self.assertTrue(torch.equal(ret_before, ret_after))
+        self.assertEqual(ret_before.shape, ret_after.shape)
+        self.assertTrue(torch.allclose(ret_before, ret_after, atol=1e-06))
diff --git a/tico/passes/convert_conv3d_to_conv2d.py b/tico/passes/convert_conv3d_to_conv2d.py
index 5664c72f..9d04d314 100644
--- a/tico/passes/convert_conv3d_to_conv2d.py
+++ b/tico/passes/convert_conv3d_to_conv2d.py
@@ -403,6 +403,88 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo
 
         return modified
 
+    def optimized_convert(
+        self, exported_program: ExportedProgram, node: torch.fx.Node
+    ) -> bool:
+        """
+        Rewrite a conv3d whose kernel covers the whole unpadded input as one
+        conv2d. Returns True if `node` was replaced, False if not eligible.
+        Raises NotYetSupportedError if the input or the weight is not 5D.
+        """
+        logger = logging.getLogger(__name__)
+        graph = exported_program.graph_module.graph
+        args = Conv3DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input, weight, bias = args.input, args.weight, args.bias
+        padding, groups = args.padding, args.groups
+
+        input_shape = extract_shape(input)
+        weight_shape = extract_shape(weight)
+        if len(input_shape) != 5:
+            raise NotYetSupportedError(
+                f"Only support 5D input tensor: node's input shape: {input_shape}"
+            )
+        if len(weight_shape) != 5:
+            raise NotYetSupportedError(
+                f"Only support 5D weight tensor: node's weight shape: {weight_shape}"
+            )
+
+        N, C_in, T_in, H_in, W_in = input_shape
+        C_out, C_in_weight, kT, kH, kW = weight_shape
+
+        # Padding shifts the kernel window off the input, so the flattened dot
+        # product below would be wrong. A string padding means `conv3d.padding`.
+        if isinstance(padding, (str, int)):
+            is_unpadded = padding in ("valid", 0)
+        else:
+            is_unpadded = not any(padding)
+        is_perfect_fit = T_in == kT and H_in == kH and W_in == kW
+        if not (is_perfect_fit and is_unpadded and groups == 1):
+            return False
+
+        with graph.inserting_before(node):
+            input_reshape = create_node(
+                graph,
+                torch.ops.aten.reshape.default,
+                args=(input, [1, 1, N, C_in * T_in * H_in * W_in]),
+                origin=node,
+            )
+            weight_reshape = create_node(
+                graph,
+                torch.ops.aten.reshape.default,
+                args=(weight, [C_out, 1, 1, C_in_weight * kT * kH * kW]),
+                origin=node,
+            )
+            conv2d = create_node(
+                graph,
+                torch.ops.aten.conv2d.default,
+                args=(
+                    input_reshape,
+                    weight_reshape,
+                    bias,
+                    [1, 1],  # stride
+                    [0, 0],  # padding
+                    [1, 1],  # dilation
+                    groups,
+                ),
+                origin=node,
+            )
+            conv2d_permute = create_node(
+                graph,
+                torch.ops.aten.permute.default,
+                args=(conv2d, [2, 1, 0, 3]),
+                origin=node,
+            )
+            conv2d_reshape = create_node(
+                graph,
+                torch.ops.aten.reshape.default,
+                args=(conv2d_permute, [N, C_out, 1, 1, 1]),
+                origin=node,
+            )
+
+        node.replace_all_uses_with(conv2d_reshape, propagate_meta=False)
+        logger.debug(f"{node.name} is replaced with optimized conv2d decomposition")
+        return True
+
     def call(self, exported_program: ExportedProgram) -> PassResult:
         target_conv_op = [torch.ops.aten.conv3d.default, torch.ops.aten.conv3d.padding]
         graph_module = exported_program.graph_module
@@ -414,7 +496,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
         for node in graph.nodes:
             if not is_target_node(node, target_conv_op):
                 continue
-            modified |= self.convert(exported_program, node)
+            modified |= self.optimized_convert(exported_program, node) or self.convert(
+                exported_program, node
+            )
 
         graph.eliminate_dead_code()
         graph.lint()
diff --git a/tico/passes/convert_permute_to_reshape.py b/tico/passes/convert_permute_to_reshape.py
new file mode 100644
index 00000000..df0bee0c
--- /dev/null
+++ b/tico/passes/convert_permute_to_reshape.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch.fx
+import torch
+from torch.export import ExportedProgram
+
+from tico.passes import ops
+from tico.serialize.circle_mapping import extract_shape
+from tico.utils import logging
+from tico.utils.graph import create_node
+from tico.utils.passes import PassBase, PassResult
+from tico.utils.trace_decorators import trace_graph_diff_on_pass
+from tico.utils.utils import is_target_node
+from tico.utils.validate_args_kwargs import PermuteArgs
+
+
+@trace_graph_diff_on_pass
+class ConvertPermuteToReshape(PassBase):
+    """
+    This pass replaces `aten.permute` with `aten.reshape` when the permutation
+    only moves size-1 dims, so the element order of the data is unchanged.
+    """
+
+    def __init__(self, enabled: bool = False):
+        super().__init__()
+        self.enabled = enabled
+
+    def call(self, exported_program: ExportedProgram) -> PassResult:
+        """
+        Replace each order-preserving permute with a reshape.
+
+        Returns PassResult(False) without touching the graph when disabled.
+        """
+        if not self.enabled:
+            return PassResult(False)
+
+        logger = logging.getLogger(__name__)
+
+        graph_module = exported_program.graph_module
+        graph = graph_module.graph
+        modified = False
+
+        for node in graph.nodes:
+            if not isinstance(node, torch.fx.Node) or not is_target_node(
+                node, ops.aten.permute
+            ):
+                continue
+
+            # Extract permute arguments
+            args = PermuteArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+
+            input = args.input
+            dims = args.dims
+
+            input_shape = extract_shape(input)
+
+            # When the dims of non-1 size keep their relative order, the
+            # permute does not reorder any element and a reshape is equivalent.
+            #
+            # For example, if
+            # - input.shape = [1, x, 1, y]
+            # - torch.permute(input, [1, 2, 3, 0])
+            # then only the size-1 dims 2 and 0 move; 'x' stays before 'y'.
+            #
+            # Negative dims are normalized first so that positions compare correctly.
+            rank = len(input_shape)
+            kept_dims = [dim % rank for dim in dims if input_shape[dim] != 1]
+            is_same_order = all(
+                prev < cur for prev, cur in zip(kept_dims, kept_dims[1:])
+            )
+            if not is_same_order:
+                continue
+
+            with graph.inserting_before(node):
+                reshape = create_node(
+                    graph,
+                    torch.ops.aten.reshape.default,
+                    args=(input, [input_shape[dim] for dim in dims]),
+                    origin=node,
+                )
+
+            node.replace_all_uses_with(reshape, propagate_meta=False)
+            modified = True
+            logger.debug(f"{node.name} is replaced with {reshape.name} operators")
+
+        graph.eliminate_dead_code()
+        graph.lint()
+        graph_module.recompile()
+
+        return PassResult(modified)