From 00ff785ca09ee77a51b9a8818804d440205c9fba Mon Sep 17 00:00:00 2001 From: chenyx113 Date: Tue, 4 Mar 2025 20:43:38 +0800 Subject: [PATCH 1/3] [tools/onnx-subgraph] add multi subgraphs inference code add code for multi subgraphs inference ONE-DCO-1.0-Signed-off-by: Youxin Chen --- .../onnx_subgraph/single_vs_multiple_onnx.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tools/onnx_subgraph/single_vs_multiple_onnx.py b/tools/onnx_subgraph/single_vs_multiple_onnx.py index 7442fa520af..a0788e1b7a8 100644 --- a/tools/onnx_subgraph/single_vs_multiple_onnx.py +++ b/tools/onnx_subgraph/single_vs_multiple_onnx.py @@ -36,6 +36,53 @@ class ModelInference: def __init__(self, model_path, subgraphsiostxt_path): self.model_path = model_path self.subgraphsiostxt_path = subgraphsiostxt_path + self.sessions, self.sorted_file_paths = self.load_sessions() + + def load_sessions(self): + with open(self.subgraphsiostxt_path, 'r') as file: + content = file.read() + subgraph_order_map = {} + matches = re.findall(r'(\w+)subgraph(\d+): order(\d+)', content) + + for match in matches: + subgraph_type, subgraph_number, order = match + # lower_subgraph_type = subgraph_type.lower() + file_path = os.path.join(self.model_path, + f"{subgraph_type}subgraph{subgraph_number}.onnx") + if int(order) in subgraph_order_map: + subgraph_order_map[int(order)].append(file_path) + else: + subgraph_order_map[int(order)] = [file_path] + + sorted_file_paths = [] + for order in sorted(subgraph_order_map.keys()): + sorted_file_paths.extend(subgraph_order_map[order]) + + sessions = [ort.InferenceSession(model) for model in sorted_file_paths] + return sessions, sorted_file_paths + + def infer_multiple_onnx_models(self, + initial_input_data, + output_names_to_collect=None): + input_data = initial_input_data + collected_outputs = {} + + for i, (session, + model_file) in enumerate(zip(self.sessions, self.sorted_file_paths)): + input_names = [inp.name for inp in session.get_inputs()] + output_names = [out.name for out in session.get_outputs()] + model_input_data = {name: input_data[name] for name in input_names} + outputs = session.run(None, model_input_data) + current_model_outputs = dict(zip(output_names, outputs)) + if output_names_to_collect is not None: + for output_name in output_names_to_collect: + if output_name in current_model_outputs: + collected_outputs[output_name] = current_model_outputs[ + output_name] + + if i < len(self.sessions) - 1: + input_data.update(current_model_outputs) + return collected_outputs def infer_single_onnx_model(model_file, input_data): session = ort.InferenceSession(model_file) @@ -116,7 +163,16 @@ def prepare_initial_input_data(onnx_model_path, default_input_data): "x": np.random.rand(1, 3, 256, 256).astype(np.float32), } initial_input_data = prepare_initial_input_data(args.single, default_input_data) + # Perform inference using a single ONNX model output_single = ModelInference.infer_single_onnx_model(args.single, initial_input_data) print("Single model inference completed!") + + # Retrieve all output names from the single model + output_names_list = list(output_single.keys()) + + # Perform inference using multiple split subgraph models + output_multiple = model_inference.infer_multiple_onnx_models( + initial_input_data, output_names_list) + print("Multiple subgraph inference completed!") From 079eb67a5a6e334e71dc8698ffe18c94820550df Mon Sep 17 00:00:00 2001 From: chenyx113 Date: Wed, 5 Mar 2025 11:41:16 +0800 Subject: [PATCH 2/3] Update single_vs_multiple_onnx.py update code as the review comment --- tools/onnx_subgraph/single_vs_multiple_onnx.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/onnx_subgraph/single_vs_multiple_onnx.py b/tools/onnx_subgraph/single_vs_multiple_onnx.py index a0788e1b7a8..b4db35ff116 100644 --- a/tools/onnx_subgraph/single_vs_multiple_onnx.py +++ b/tools/onnx_subgraph/single_vs_multiple_onnx.py @@ -74,7 +74,9 @@ def infer_multiple_onnx_models(self, model_input_data = {name: input_data[name] for name in input_names} outputs = session.run(None, model_input_data) current_model_outputs = dict(zip(output_names, outputs)) - if output_names_to_collect is not None: + if output_names_to_collect is None: + return {} + else: for output_name in output_names_to_collect: if output_name in current_model_outputs: collected_outputs[output_name] = current_model_outputs[ @@ -163,7 +165,6 @@ def prepare_initial_input_data(onnx_model_path, default_input_data): "x": np.random.rand(1, 3, 256, 256).astype(np.float32), } initial_input_data = prepare_initial_input_data(args.single, default_input_data) - # Perform inference using a single ONNX model output_single = ModelInference.infer_single_onnx_model(args.single, initial_input_data) From 48074e1829d0d958799f5cf5ac2da8afaef31bad Mon Sep 17 00:00:00 2001 From: chenyx113 Date: Wed, 5 Mar 2025 14:12:23 +0800 Subject: [PATCH 3/3] Update single_vs_multiple_onnx.py move parameter exception checking out of loop --- tools/onnx_subgraph/single_vs_multiple_onnx.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tools/onnx_subgraph/single_vs_multiple_onnx.py b/tools/onnx_subgraph/single_vs_multiple_onnx.py index b4db35ff116..86d3723ec09 100644 --- a/tools/onnx_subgraph/single_vs_multiple_onnx.py +++ b/tools/onnx_subgraph/single_vs_multiple_onnx.py @@ -64,6 +64,8 @@ def load_sessions(self): def infer_multiple_onnx_models(self, initial_input_data, output_names_to_collect=None): + if output_names_to_collect is None: + return {} input_data = initial_input_data collected_outputs = {} @@ -74,13 +76,10 @@ def infer_multiple_onnx_models(self, model_input_data = {name: input_data[name] for name in input_names} outputs = session.run(None, model_input_data) current_model_outputs = dict(zip(output_names, outputs)) - if output_names_to_collect is None: - return {} - else: - for output_name in output_names_to_collect: - if output_name in current_model_outputs: - collected_outputs[output_name] = current_model_outputs[ - output_name] + + for output_name in output_names_to_collect: + if output_name in current_model_outputs: + collected_outputs[output_name] = current_model_outputs[output_name] if i < len(self.sessions) - 1: input_data.update(current_model_outputs)