From 4cc2d62a0d76d6076ef12e0a6cb94ea97256b44a Mon Sep 17 00:00:00 2001 From: shashwat1198 Date: Mon, 10 Feb 2025 14:27:26 +0000 Subject: [PATCH 1/2] PR for MoveLinearPastEltwiseMul transformation --- src/finn/transformation/streamline/reorder.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 8ac2d7dad6..324271e4ea 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -594,6 +594,73 @@ def apply(self, model): model = model.transform(InferShapes()) return (model, graph_modified) +class MoveLinearPastEltwiseMul(Transformation): + """Move linear operations (mul) past elementwise mul operations where possible. + Specifically,matches and transforms the following patterns: + (x*A) * (y*B) -> (xy)*(A*B) + where x and y are dynamic inputs, A, B are constant tensors (in general). + """ + + def move_node(self, graph, n, prod0, prod1, node_ind): + # found! move one of the muls to output, remove the other one + lin0_in0 = prod0.input[0] + lin1_in0 = prod1.input[0] + in0 = n.input[0] + out = n.output[0] + # connect the eltwise mul inputs to mul inputs + n.input[0] = lin0_in0 + n.input[1] = lin1_in0 + # connect mul0 output to eltwise mul output + prod0.output[0] = out + # connect the input of mul0 and output of eltwise mul together + n.output[0] = in0 + prod0.input[0] = in0 + # move prod0 node past eltwise mul node, and remove prod1 + graph.node.remove(prod1) + graph.node.remove(prod0) + graph.node.insert(node_ind - 2, prod0) + + def apply(self, model): + graph = model.graph + node_ind = 0 + graph_modified = False + nodes = [n for n in graph.node] + for n in nodes: + node_ind += 1 + #checking if the operation is eltwisemul + if n.op_type == "Mul": + in0 = n.input[0] + in1 = n.input[1] + if in0 is None or in1 is None: + continue + A = model.get_initializer(in0) + B = model.get_initializer(in1) + if A is not None or B is not None: + continue + # check for mul with same initializer on both inputs + prod0 = model.find_producer(in0) + prod1 = model.find_producer(in1) + if prod0 is None or prod1 is None or (prod0 == prod1): + continue + if len(prod0.input) < 2 or len(prod1.input) < 2: + continue + init0 = model.get_initializer(prod0.input[1]) + init1 = model.get_initializer(prod1.input[1]) + # if either initializer is None, skip + if init0 is None or init1 is None: + continue + if prod0.op_type == "Mul" and prod1.op_type == "Mul": + # Adding the update intializer condition + init = init0*init1 + # update initializer of prod0, the node which will move + model.set_initializer(prod0.input[1],init) + self.move_node(graph,n,prod0,prod1,node_ind) + node_ind -= 1 + graph_modified = True + else: + continue + model = model.transform(InferShapes()) + return (model, graph_modified) class MoveScalarLinearPastInvariants(Transformation): """Move scalar linear operations (mul, add) past functions which are invariant From 075bdf8a9f0ff1b5d7b716ed5ef7429ec1b7efc0 Mon Sep 17 00:00:00 2001 From: shashwat1198 Date: Mon, 10 Feb 2025 15:12:01 +0000 Subject: [PATCH 2/2] Fix pre-commit issues Signed-off-by: shashwat1198 --- src/finn/transformation/streamline/reorder.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/finn/transformation/streamline/reorder.py b/src/finn/transformation/streamline/reorder.py index 324271e4ea..acea4d0632 100644 --- a/src/finn/transformation/streamline/reorder.py +++ b/src/finn/transformation/streamline/reorder.py @@ -594,6 +594,7 @@ def apply(self, model): model = model.transform(InferShapes()) return (model, graph_modified) + class MoveLinearPastEltwiseMul(Transformation): """Move linear operations (mul) past elementwise mul operations where possible. Specifically,matches and transforms the following patterns: @@ -627,7 +628,7 @@ def apply(self, model): nodes = [n for n in graph.node] for n in nodes: node_ind += 1 - #checking if the operation is eltwisemul + # checking if the operation is eltwisemul if n.op_type == "Mul": in0 = n.input[0] in1 = n.input[1] @@ -650,11 +651,11 @@ def apply(self, model): if init0 is None or init1 is None: continue if prod0.op_type == "Mul" and prod1.op_type == "Mul": - # Adding the update intializer condition - init = init0*init1 + # Adding the update intializer condition + init = init0 * init1 # update initializer of prod0, the node which will move - model.set_initializer(prod0.input[1],init) - self.move_node(graph,n,prod0,prod1,node_ind) + model.set_initializer(prod0.input[1], init) + self.move_node(graph, n, prod0, prod1, node_ind) node_ind -= 1 graph_modified = True else: @@ -662,6 +663,7 @@ def apply(self, model): model = model.transform(InferShapes()) return (model, graph_modified) + class MoveScalarLinearPastInvariants(Transformation): """Move scalar linear operations (mul, add) past functions which are invariant to them. Specifically, matches and transforms the following patterns: