diff --git a/tools/onnx_subgraph/src/lib/graph.cpp b/tools/onnx_subgraph/src/lib/graph.cpp index ecd095c0d5a..bef970990fb 100644 --- a/tools/onnx_subgraph/src/lib/graph.cpp +++ b/tools/onnx_subgraph/src/lib/graph.cpp @@ -16,8 +16,154 @@ #include "graph.h" +std::unordered_set getInitializer(const onnx::GraphProto &graph) +{ + std::unordered_set initializerNames; + + for (const auto &initializer : graph.initializer()) + { + NodeTensor nt; + nt.name = initializer.name(); + std::vector shape; + + for (const auto &dim : initializer.dims()) + { + shape.push_back(dim); + } + + nt.shape = shape; + initializerNames.insert(nt); + } + + return initializerNames; +} + +std::unordered_set getIOvalue(const onnx::GraphProto &graph) +{ + std::unordered_set IOvalue; + + for (const auto &value_info : graph.value_info()) + { + NodeTensor nt; + nt.name = value_info.name(); + std::vector shape; + + for (const auto &dim : value_info.type().tensor_type().shape().dim()) + { + shape.push_back(dim.dim_value()); + } + + nt.shape = shape; + IOvalue.insert(nt); + } + + for (auto value_info : graph.input()) + { + NodeTensor nt; + nt.name = value_info.name(); + std::vector shape; + + for (const auto &dim : value_info.type().tensor_type().shape().dim()) + { + shape.push_back(dim.dim_value()); + } + + nt.shape = shape; + IOvalue.insert(nt); + } + + for (auto value_info : graph.output()) + { + NodeTensor nt; + nt.name = value_info.name(); + std::vector shape; + + for (const auto &dim : value_info.type().tensor_type().shape().dim()) + { + shape.push_back(dim.dim_value()); + } + + nt.shape = shape; + IOvalue.insert(nt); + } + + return IOvalue; +} + +std::unordered_set::const_iterator +isInputFromInitializer(const std::string &name, const std::unordered_set &tensors) +{ + return std::find_if(tensors.begin(), tensors.end(), + [&](const NodeTensor &tensor) { return tensor.name == name; }); +} + +void determineGraphInput(const onnx::GraphProto &g, + const std::unordered_set &initializerNames, + std::unordered_set &graphInputs) +{ + std::unordered_set allnodeOutputs; + + // Iterate over each node in the graph to collect all outputs + for (const auto &node : g.node()) + { + // Get the output list of the current node + const auto &outputs = node.output(); + + // Insert each output into the set of all node outputs + for (const auto &output : outputs) + { + allnodeOutputs.insert(output); + } + } + + // Iterate over each node in the graph to identify inputs not produced by any node + for (const auto &node : g.node()) + { + // Get the input list of the current node + const auto &inputs = node.input(); + + // Check each input to determine if it is an external input to the graph + for (const auto &input : inputs) + { + // If the input is not found in the set of all node outputs, it is a graph input + if (std::find(allnodeOutputs.begin(), allnodeOutputs.end(), input) == allnodeOutputs.end()) + { + auto iter = isInputFromInitializer(input, initializerNames); + NodeTensor nt; + nt.name = input; + + if (iter != initializerNames.end()) + { + graphInputs.insert(*iter); + } + } + } + } +} + onnx::GraphProto GetGraphFromOnnx(std::string &path) { onnx::ModelProto model; + + std::ifstream input(path, std::ios::ate | std::ios::binary); + if (!input.is_open()) + { + std::cout << "Error: Failed to open file: " << path << std::endl; + exit(-1); + } + + // get current position in file + std::streamsize size = input.tellg(); + + // move to start of file + input.seekg(0, std::ios::beg); + + // read raw data + std::vector buffer(size); + input.read(buffer.data(), size); + + // parse protobuf + model.ParseFromArray(buffer.data(), size); + return model.graph(); }