Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/unit_test/pass_test/test_convert_conv3d_to_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,32 @@ def test_pass(self):
self.run_value_test(ConvertConv3dToConv2d())
self.assertEqual(num_of_ops(self.exported_program(), ops.aten.conv3d), 0)
self.assertGreaterEqual(num_of_ops(self.exported_program(), ops.aten.conv2d), 2)


class Conv3dPerfectFitKernel(torch.nn.Module):
    """Conv3D whose kernel has exactly the spatial extent of the example input.

    The example input is (5, 3, 2, 16, 16) and the kernel is (2, 16, 16) with
    no padding, so every (batch, out_channel) pair yields a single output
    element.
    """

    def __init__(self):
        super().__init__()
        # Kernel and stride share the same (T, H, W) extent.
        window = (2, 16, 16)
        self.conv3d = torch.nn.Conv3d(
            in_channels=3,
            out_channels=1024,
            kernel_size=window,
            stride=window,
            padding=(0, 0, 0),
        )

    def forward(self, input):
        """Apply the single conv3d layer to `input`."""
        result = self.conv3d(input)
        return result

    def get_example_inputs(self):
        """Return `(args, kwargs)` holding one random (5, 3, 2, 16, 16) tensor."""
        sample = torch.randn(5, 3, 2, 16, 16)
        return (sample,), {}


class ConvertConv3dPerfectFitKernelTest(SinglePassValueTest):
    """Value test for ConvertConv3dToConv2d on a kernel that spans the whole input."""

    def test_pass(self):
        def count(op):
            # Re-read the exported program on every call so the count
            # reflects the graph state at that point of the test.
            return num_of_ops(self.exported_program(), op)

        self.setup(Conv3dPerfectFitKernel())
        # Before the pass: exactly one conv3d in the graph.
        self.assertEqual(count(ops.aten.conv3d), 1)

        self.run_value_test(ConvertConv3dToConv2d())

        # After the pass: no conv3d left, at least one conv2d introduced.
        self.assertEqual(count(ops.aten.conv3d), 0)
        self.assertGreaterEqual(count(ops.aten.conv2d), 1)
61 changes: 61 additions & 0 deletions test/unit_test/pass_test/test_convert_permute_to_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tico.passes import ops
from tico.passes.convert_permute_to_reshape import ConvertPermuteToReshape

from test.utils.helper import num_of_ops
from test.utils.pass_value_test import SinglePassValueTest


class PermuteBasic(torch.nn.Module):
    """Permute a (1, 5, 1, 3) tensor with dims (1, 2, 3, 0).

    The two non-unit axes (sizes 5 and 3) keep their relative order in the
    output, which has shape (5, 1, 3, 1).
    """

    # Permutation applied in `forward`.
    _DIMS = (1, 2, 3, 0)
    # Shape of the example input tensor.
    _INPUT_SHAPE = [1, 5, 1, 3]

    def forward(self, x):
        """Return `x` permuted by `_DIMS`."""
        permuted = torch.permute(x, self._DIMS)
        return permuted

    def get_example_inputs(self):
        """Return `(args, kwargs)` holding one random tensor of `_INPUT_SHAPE`."""
        sample = torch.rand(self._INPUT_SHAPE)
        return (sample,), {}


class PermuteBasicTest(SinglePassValueTest):
    """ConvertPermuteToReshape must turn the PermuteBasic permute into a reshape."""

    def test_pass(self):
        def count(op):
            # Re-read the exported program on every call so the count
            # reflects the graph state at that point of the test.
            return num_of_ops(self.exported_program(), op)

        self.setup(PermuteBasic())
        # Before the pass: a single permute.
        self.assertEqual(count(ops.aten.permute), 1)

        self.run_value_test(ConvertPermuteToReshape(True))

        # After the pass: the permute is gone and one reshape took its place.
        self.assertEqual(count(ops.aten.permute), 0)
        self.assertEqual(count(ops.aten.reshape), 1)


class PermuteBasicNegative(torch.nn.Module):
    """Permute a (1, 5, 1, 3) tensor with dims (2, 3, 0, 1).

    The two non-unit axes (sizes 5 and 3) swap their relative order, so the
    output of shape (1, 3, 1, 5) is not a plain reshape of the input.
    """

    # Permutation applied in `forward`.
    _DIMS = (2, 3, 0, 1)
    # Shape of the example input tensor.
    _INPUT_SHAPE = [1, 5, 1, 3]

    def forward(self, x):
        """Return `x` permuted by `_DIMS`."""
        permuted = torch.permute(x, self._DIMS)
        return permuted

    def get_example_inputs(self):
        """Return `(args, kwargs)` holding one random tensor of `_INPUT_SHAPE`."""
        sample = torch.rand(self._INPUT_SHAPE)
        return (sample,), {}


class PermuteBasicNegativeTest(SinglePassValueTest):
    """ConvertPermuteToReshape must leave the PermuteBasicNegative permute untouched."""

    def test_pass(self):
        def count(op):
            # Re-read the exported program on every call so the count
            # reflects the graph state at that point of the test.
            return num_of_ops(self.exported_program(), op)

        self.setup(PermuteBasicNegative())
        # Before the pass: a single permute.
        self.assertEqual(count(ops.aten.permute), 1)

        self.run_value_test(ConvertPermuteToReshape(True))

        # After the pass: the permute is still there and no reshape was added.
        self.assertEqual(count(ops.aten.permute), 1)
        self.assertEqual(count(ops.aten.reshape), 0)
4 changes: 3 additions & 1 deletion test/utils/pass_value_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def run_value_test(self, test_pass: PassBase): # type: ignore[override]
# inference after pass
ret_after = self.ep.module()(*self.forward_args, **self.forward_kwargs)

self.assertEqual(ret_before.shape, ret_after.shape)
self.assertTrue(torch.allclose(ret_before, ret_after, atol=1e-06))


Expand All @@ -93,4 +94,5 @@ def run_value_test(self, test_passes: list): # type: ignore[override]
# inference after pass
ret_after = self.ep.module()(*self.forward_args, **self.forward_kwargs)

self.assertTrue(torch.equal(ret_before, ret_after))
self.assertEqual(ret_before.shape, ret_after.shape)
self.assertTrue(torch.allclose(ret_before, ret_after, atol=1e-06))
86 changes: 85 additions & 1 deletion tico/passes/convert_conv3d_to_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,88 @@ def convert(self, exported_program: ExportedProgram, node: torch.fx.Node) -> boo

return modified

def optimized_convert(
    self, exported_program: ExportedProgram, node: torch.fx.Node
) -> bool:
    """
    Replace a conv3d whose kernel covers the whole input with a single conv2d.

    When the kernel has exactly the (T, H, W) extent of the input, there is
    no padding, dilation is 1 and groups == 1, each (batch, out_channel)
    pair produces one output element: a dot product between the flattened
    input sample and the flattened filter. That is expressed here as

        input  -> reshape [1, 1, N, C_in*T*H*W]
        weight -> reshape [C_out, 1, 1, C_in*kT*kH*kW]
        conv2d (stride 1, padding 0, dilation 1) -> [1, C_out, N, 1]
        permute [2, 1, 0, 3]                     -> [N, C_out, 1, 1]
        reshape                                  -> [N, C_out, 1, 1, 1]

    The stride of the original conv3d is irrelevant in this case because
    the kernel fits the input exactly once.

    Args:
        exported_program: Program whose graph contains `node`.
        node: The conv3d node to convert.

    Returns:
        True if `node` was replaced, False if the fast path does not apply
        (the caller is expected to fall back to the generic conversion).

    Raises:
        NotYetSupportedError: If the input or the weight is not 5D.
    """
    logger = logging.getLogger(__name__)
    graph = exported_program.graph_module.graph

    # Extract conv3d arguments
    args = Conv3DArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

    input = args.input
    weight = args.weight
    bias = args.bias
    groups = args.groups
    # NOTE(review): assumes Conv3DArgs exposes `padding` and `dilation`
    # following the aten.conv3d schema - confirm against its definition.
    padding = args.padding
    dilation = args.dilation

    input_shape = extract_shape(input)
    weight_shape = extract_shape(weight)

    if len(input_shape) != 5:
        raise NotYetSupportedError(
            f"Only support 5D input tensor: node's input shape: {input_shape}"
        )

    if len(weight_shape) != 5:
        raise NotYetSupportedError(
            f"Only support 5D weight tensor: node's weight shape: {weight_shape}"
        )

    N, C_in, T_in, H_in, W_in = input_shape
    C_out, C_in_weight, kT, kH, kW = weight_shape

    # Any padding enlarges the effective input, so the output would no
    # longer be a single element per (batch, channel) pair.
    if isinstance(padding, str):
        no_padding = padding == "valid"
    elif isinstance(padding, int):
        no_padding = padding == 0
    else:
        no_padding = all(p == 0 for p in padding)

    if isinstance(dilation, int):
        unit_dilation = dilation == 1
    else:
        unit_dilation = all(d == 1 for d in dilation)

    kernel_covers_input = T_in == kT and H_in == kH and W_in == kW

    if not (kernel_covers_input and groups == 1 and no_padding and unit_dilation):
        return False

    with graph.inserting_before(node):
        input_reshape = create_node(
            graph,
            torch.ops.aten.reshape.default,
            args=(input, [1, 1, N, C_in * T_in * H_in * W_in]),
            origin=node,
        )
        weight_reshape = create_node(
            graph,
            torch.ops.aten.reshape.default,
            args=(weight, [C_out, 1, 1, C_in_weight * kT * kH * kW]),
            origin=node,
        )
        conv2d = create_node(
            graph,
            torch.ops.aten.conv2d.default,
            args=(
                input_reshape,
                weight_reshape,
                bias,
                [1, 1],  # stride
                [0, 0],  # padding
                [1, 1],  # dilation
                groups,
            ),
            origin=node,
        )
        # conv2d output is [1, C_out, N, 1]; move the batch axis back to front.
        conv2d_permute = create_node(
            graph,
            torch.ops.aten.permute.default,
            args=(conv2d, [2, 1, 0, 3]),
            origin=node,
        )
        conv2d_reshape = create_node(
            graph,
            torch.ops.aten.reshape.default,
            args=(conv2d_permute, [N, C_out, 1, 1, 1]),
            origin=node,
        )

    # Replace the original node
    node.replace_all_uses_with(conv2d_reshape, propagate_meta=False)
    logger.debug(f"{node.name} is replaced with optimized conv2d decomposition")

    return True

def call(self, exported_program: ExportedProgram) -> PassResult:
target_conv_op = [torch.ops.aten.conv3d.default, torch.ops.aten.conv3d.padding]
graph_module = exported_program.graph_module
Expand All @@ -414,7 +496,9 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
for node in graph.nodes:
if not is_target_node(node, target_conv_op):
continue
modified |= self.convert(exported_program, node)
modified |= self.optimized_convert(exported_program, node) or self.convert(
exported_program, node
)

graph.eliminate_dead_code()
graph.lint()
Expand Down
105 changes: 105 additions & 0 deletions tico/passes/convert_permute_to_reshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import torch.fx
import torch
from torch.export import ExportedProgram

from tico.passes import ops
from tico.serialize.circle_mapping import extract_shape
from tico.utils import logging
from tico.utils.graph import create_node
from tico.utils.passes import PassBase, PassResult
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.utils import is_target_node
from tico.utils.validate_args_kwargs import PermuteArgs


@trace_graph_diff_on_pass
class ConvertPermuteToReshape(PassBase):
    """
    This pass replaces `aten.permute` with `aten.reshape` when
    the order of output data is exactly the same as the input data.

    That is the case when all axes whose size is not 1 keep their relative
    order under the permutation; only size-1 axes are moved around.

    The pass is a no-op unless it is constructed with `enabled=True`.
    """

    def __init__(self, enabled: bool = False):
        super().__init__()
        self.enabled = enabled

    @staticmethod
    def _keeps_element_order(input_shape, dims) -> bool:
        """
        Return True if permuting `input_shape` by `dims` keeps every non-1
        axis in its original relative order.

        `dims` must already be normalized to non-negative axis indices.

        For example, if
         - input.shape = [1, x, 1, y]
         - torch.permute(input, [1, 2, 3, 0])
        then axes 1 and 3 (sizes 'x' and 'y') keep the same order.
        """
        last_dim = -1
        for dim in dims:
            # Size-1 axes do not affect the element order.
            if input_shape[dim] == 1:
                continue
            if dim <= last_dim:
                return False
            last_dim = dim
        return True

    def call(self, exported_program: ExportedProgram) -> PassResult:
        """
        Rewrite every order-preserving permute in the graph into a reshape.

        Returns a `PassResult` that is True if at least one node was replaced.
        """
        if not self.enabled:
            return PassResult(False)

        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False

        for node in graph.nodes:
            if not isinstance(node, torch.fx.Node) or not is_target_node(
                node, ops.aten.permute
            ):
                continue

            # Extract permute arguments
            args = PermuteArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

            input = args.input
            dims = args.dims

            input_shape = extract_shape(input)

            # Normalize negative dims (e.g. -1 -> ndims - 1). Without this,
            # the order comparison below would reject permutations that are
            # written with negative indices even when they keep the order.
            ndims = len(input_shape)
            normalized_dims = [(d if d >= 0 else d + ndims) for d in dims]

            if not self._keeps_element_order(input_shape, normalized_dims):
                continue

            with graph.inserting_before(node):
                reshape = create_node(
                    graph,
                    torch.ops.aten.reshape.default,
                    args=(input, [input_shape[dim] for dim in normalized_dims]),
                    origin=node,
                )

            node.replace_all_uses_with(reshape, propagate_meta=False)
            modified = True
            logger.debug(f"{node.name} is replaced with {reshape.name} operators")

        graph.eliminate_dead_code()
        graph.lint()
        graph_module.recompile()

        return PassResult(modified)