From 7146298fe9c5340913c10d2c40d5b8110bd360ce Mon Sep 17 00:00:00 2001 From: Youxin Chen Date: Wed, 2 Apr 2025 13:19:49 +0800 Subject: [PATCH] [tools/onnx-subgraph] add onnx process APIs definition add onnx process function definition, implementation will be added in the cpp file in next PR ONE-DCO-1.0-Signed-off-by: Youxin Chen --- tools/onnx_subgraph/include/graph.h | 109 ++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/tools/onnx_subgraph/include/graph.h b/tools/onnx_subgraph/include/graph.h index f39c6d58926..8633726a375 100644 --- a/tools/onnx_subgraph/include/graph.h +++ b/tools/onnx_subgraph/include/graph.h @@ -76,6 +76,115 @@ template <> struct hash } // namespace std +/** + * @brief Extracts the names and shapes of initializers from the ONNX graph. + * + * @param [in] graph The ONNX graph from which to extract initializers. + * @pre The ONNX graph should be valid and contain initializers. + * @post The names and shapes of the initializers are stored in an unordered set of NodeTensor + * objects. + * @exception None + * @return An unordered set of NodeTensor objects containing the names and shapes of the + * initializers. + */ +std::unordered_set getInitializer(const onnx::GraphProto &graph); + +/** + * @brief Extracts the names and shapes of inputs, outputs, and value_info from the ONNX graph. + * + * @param [in] graph The ONNX graph from which to extract inputs, outputs, and value_info. + * @pre The ONNX graph should be valid and contain inputs, outputs, and value_info. + * @post The names and shapes of the inputs, outputs, and value_info are stored in an unordered + * set of NodeTensor objects. + * @exception None + * @return An unordered set of NodeTensor objects containing the names and shapes of the inputs, + * outputs, and value_info. + */ +std::unordered_set getIOvalue(const onnx::GraphProto &graph); + +/** + * @brief Determines the input tensors of the graph that are not produced by any node in the + * graph. + * + * @param [in] g The ONNX GraphProto object representing the graph. + * @param [in] initializerNames A set of NodeTensor objects representing the initializers in the + * graph. + * @param [out] graphInputs A set of NodeTensor objects representing the input tensors of the + * graph. + * @pre The GraphProto object g should be valid and contain nodes with proper input and output + * lists. + * @post The graphInputs set will be populated with NodeTensor objects that are inputs to the + * graph. + * @exception None + * @return None + */ +void determineGraphInput(const onnx::GraphProto &g, + const std::unordered_set &initializerNames, + std::unordered_set &graphInputs); + +/** + * @brief Determines the output tensors of the graph that are either outputs of the original + * graph or are used as inputs in other parts of the graph. + * + * @param [in] originalGraph The original ONNX GraphProto object representing the graph. + * @param [in] g The ONNX GraphProto object representing the graph to analyze. + * @param [in] allgraphInputs_1 A vector of sets of NodeTensor objects representing the first + * set of inputs to the graph. + * @param [in] allgraphInputs_2 A vector of sets of NodeTensor objects representing the second + * set of inputs to the graph. + * @param [out] graphOutputs A set of NodeTensor objects representing the output tensors of the + * graph. + * @pre The GraphProto objects originalGraph and g should be valid and contain nodes with + * proper input and output lists. + * @post The graphOutputs set will be populated with NodeTensor objects that are outputs of the + * graph. + * @exception None + * @return None + */ +void determineGraphOutput(const onnx::GraphProto &originalGraph, const onnx::GraphProto &g, + std::vector> &allgraphInputs_1, + std::vector> &allgraphInputs_2, + std::unordered_set &graphOutputs); + +/** + * @brief Finds the name of the node that produces a specified output tensor in the given ONNX + * graph. + * + * @param [in] g The ONNX GraphProto object representing the graph. + * @param [in] outputTensorName The name of the output tensor to find the producing node for. + * @pre The GraphProto object g should be valid and contain nodes with proper input and output + * lists. + * @post None + * @exception None + * @return The name of the node that produces the specified output tensor, or an empty string if + * no such node is found. + */ +std::string findInputNode(const onnx::GraphProto &g, const std::string &outputTensorName); + +/** + * @brief Collects the names of all nodes in the given ONNX graph. + * + * @param [in] graph The ONNX GraphProto object representing the graph. + * @pre The GraphProto object graph should be valid and contain nodes with proper names. + * @post None + * @exception None + * @return An unordered set containing the names of all nodes in the graph. + */ +std::unordered_set collectNodeNames(const onnx::GraphProto &graph); + +/** + * @brief Merges nodes from the source graph into the target graph. + * + * @param [in,out] targetGraph The ONNX GraphProto object to which nodes will be added. + * @param [in] sourceGraph The ONNX GraphProto object from which nodes will be copied. + * @pre Both GraphProto objects should be valid. + * @post Nodes from sourceGraph are added to targetGraph. + * @exception Exits the program with an error message if the number of nodes in targetGraph does not + * match the expected size after merging. + * @return None + */ +void mergeGraphs(onnx::GraphProto &targetGraph, onnx::GraphProto &sourceGraph); + /** * @brief Loads an ONNX model from a file and returns the graph contained within. *