onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_biasgelu.py
@@ -0,0 +1,66 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+
+ from fusion_base import Fusion
+ from fusion_utils import NumpyHelper
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionBiasGelu(Fusion):
+     def __init__(self, model: OnnxModel, is_fastgelu):
+         if is_fastgelu:
+             super().__init__(model, "FastGelu", "FastGelu", "add bias")
+         else:
+             super().__init__(model, "BiasGelu", "Gelu")
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         gelu_op_type = node.op_type
+         fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu"
+
+         if len(node.input) != 1:
+             return
+
+         nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
+         if nodes is None:
+             return
+         (add, matmul) = nodes
+
+         bias_weight = None
+         # bias should be one dimension
+         bias_index = -1
+         for i, input in enumerate(add.input):
+             initializer = self.model.get_initializer(input)
+             if initializer is None:
+                 continue
+             bias_index = i
+             bias_weight = NumpyHelper.to_array(initializer)
+             break
+         if bias_weight is None:
+             return
+         if len(bias_weight.shape) != 1:
+             return
+
+         subgraph_nodes = [node, add]
+         if not self.model.is_safe_to_fuse_nodes(
+             subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
+         ):
+             return
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+
+         fused_node = helper.make_node(
+             fuse_op_type,
+             inputs=[matmul.output[0], add.input[bias_index]],
+             outputs=node.output,
+             name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"),
+         )
+         fused_node.domain = "com.microsoft"
+         self.nodes_to_add.append(fused_node)
+         self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
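
For context, here is a minimal usage sketch (illustrative, not part of the wheel) of how a fusion like FusionBiasGelu is normally reached through the package's public optimizer entry point; the input path, num_heads, and hidden_size below are placeholder assumptions that must match the actual model:

    from onnxruntime.transformers import optimizer
    from onnxruntime.transformers.fusion_options import FusionOptions

    options = FusionOptions("bert")
    options.enable_bias_gelu = True  # keep the Add-bias + (Fast)Gelu rewrite enabled

    opt_model = optimizer.optimize_model(
        "model.onnx",            # placeholder input path
        model_type="bert",
        num_heads=12,            # must match the exported model
        hidden_size=768,
        optimization_options=options,
    )
    opt_model.save_model_to_file("model_opt.onnx")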
onnxruntime/transformers/fusion_biassplitgelu.py
@@ -0,0 +1,110 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from logging import getLogger
+
+ from fusion_base import Fusion
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionBiasSplitGelu(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "BiasSplitGelu", "Gelu")
+
+     def fuse(self, gelu_node, input_name_to_nodes: dict, output_name_to_node: dict):
+         """
+          [root] --->Add --------------------> Slice ---------------> Mul -->
+                      |                          ^                     ^
+                      |                          |                     |
+                      +--------------------------+-----Slice --> Gelu--+
+                      |                          |       ^
+                      |                          |-------|
+                      |                          |       |
+                      |                         Mul     Mul
+                      |                          ^       ^
+                      v                          |       |
+                     Shape ---> Gather --> Add --> Div --+
+         """
+         if gelu_node.output[0] not in input_name_to_nodes:
+             return
+         children = input_name_to_nodes[gelu_node.output[0]]
+         if len(children) != 1 or children[0].op_type != "Mul":
+             return
+         mul_after_gelu = children[0]
+
+         slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node)
+         if slice_before_gelu is None:
+             return
+
+         if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3:
+             return
+
+         add_output = slice_before_gelu.input[0]
+
+         start_index_nodes = self.model.match_parent_path(
+             slice_before_gelu,
+             ["Div", "Add", "Gather", "Shape", "Add"],
+             [1, 0, 0, 0, 0],
+             output_name_to_node,  # Mul(1) is optional
+         )
+         if start_index_nodes is None:
+             start_index_nodes = self.model.match_parent_path(
+                 slice_before_gelu,
+                 ["Mul", "Div", "Add", "Gather", "Shape", "Add"],
+                 [1, 0, 0, 0, 0, 0],
+                 output_name_to_node,
+             )
+
+         if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output:
+             return
+
+         end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node)
+
+         if (
+             end_index_nodes is None or end_index_nodes[1] not in start_index_nodes
+         ):  # the Div is parent of both two Mul nodes
+             return
+
+         slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node)
+         if slice_before_mul is None:
+             return
+
+         if (
+             slice_before_mul.input[2] != slice_before_gelu.input[1]
+         ):  # end index of slice_before_mul is start index of slice_before_gelu
+             return
+
+         subgraph_nodes = [
+             *start_index_nodes,
+             end_index_nodes[0],
+             mul_after_gelu,
+             gelu_node,
+             slice_before_mul,
+             slice_before_gelu,
+         ]
+         subgraph_output = mul_after_gelu.output[0]
+         if not self.model.is_safe_to_fuse_nodes(
+             subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
+         ):
+             logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.")
+             return
+
+         add_node = start_index_nodes[-1]
+         bias_index, _value = self.model.get_constant_input(add_node)
+         if not isinstance(bias_index, int):
+             return
+         self.nodes_to_remove.extend(subgraph_nodes)
+         node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu")
+         fused_node = helper.make_node(
+             "BiasSplitGelu",
+             inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]],
+             outputs=[subgraph_output],
+             name=node_name,
+         )
+         fused_node.domain = "com.microsoft"
+         self.nodes_to_add.append(fused_node)
+         self.node_name_to_graph_name[node_name] = self.this_graph_name
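
The same Fusion base API can also be driven directly against a loaded model. A minimal sketch (illustrative, not part of the wheel), assuming the flat intra-package imports (fusion_biassplitgelu, onnx_model) resolve because the script runs with onnxruntime/transformers on sys.path, and with unet.onnx as a placeholder path:

    import onnx
    from fusion_biassplitgelu import FusionBiasSplitGelu
    from onnx_model import OnnxModel

    model = OnnxModel(onnx.load("unet.onnx"))  # placeholder path
    FusionBiasSplitGelu(model).apply()  # visits every Gelu node through fuse() above
    model.topological_sort()
    onnx.save(model.model, "unet_fused.onnx")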
onnxruntime/transformers/fusion_conformer_attention.py
@@ -0,0 +1,222 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+
+ from fusion_attention import AttentionMask, FusionAttention
+ from onnx_model import OnnxModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class FusionConformerAttention(FusionAttention):
+     """
+     Fuse Conformer Attention subgraph into one MultiHeadAttention node.
+     """
+
+     def __init__(
+         self,
+         model: OnnxModel,
+         hidden_size: int,
+         num_heads: int,
+         attention_mask: AttentionMask,
+     ):
+         super().__init__(model, hidden_size, num_heads, attention_mask)
+
+     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+         # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+         qkv_nodes = self.model.match_parent_path(
+             normalize_node,
+             ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+             [1, None, 0, 0, 0],
+         )
+         if qkv_nodes is None:
+             logger.debug("fuse_conformer_attention: failed to match qkv path")
+             return
+
+         reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[-3], qkv_nodes[-2], qkv_nodes[-1]
+
+         past_v, present_v = "", ""
+         v_nodes = self.model.match_parent_path(
+             matmul_qkv,
+             ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
+             [1, 1, 0, 0, 1],
+         )
+         if v_nodes is None:
+             v_nodes = self.model.match_parent_path(
+                 matmul_qkv,
+                 ["Transpose", "Reshape", "Add", "MatMul"],
+                 [1, 0, 0, 0],
+             )
+             if v_nodes is None:
+                 logger.debug("fuse_conformer_attention: failed to match v path")
+                 return
+         else:
+             concat_v = v_nodes[0]
+             concat_parent = self.model.get_parent(concat_v, 0, None)
+             present_v = concat_v.output[0]
+             past_v = concat_parent.output[0]
+
+         add_v, matmul_v = v_nodes[-2], v_nodes[-1]
+
+         attn_mask = ""
+         qk_nodes = self.model.match_parent_path(
+             matmul_qkv,
+             ["Softmax", "Add", "MatMul"],
+             [0, 0, 0],
+         )
+         if qk_nodes is None:
+             qk_nodes = self.model.match_parent_path(
+                 matmul_qkv,
+                 ["Where", "Softmax", "Where", "Add", "MatMul"],
+                 [0, 2, 0, 2, 0],
+             )
+             if qk_nodes is None:
+                 logger.debug("fuse_conformer_attention: failed to match qk path")
+                 return
+
+             where_qk = qk_nodes[2]
+             mask_nodes = self.model.match_parent_path(
+                 where_qk,
+                 ["Equal", "Unsqueeze", "Cast"],
+                 [0, 0, 0],
+             )
+             if mask_nodes is not None:
+                 attn_mask = mask_nodes[-1].output[0]
+
+         add_qk, matmul_qk = qk_nodes[-2], qk_nodes[-1]
+
+         q_nodes = self.model.match_parent_path(
+             matmul_qk,
+             ["Div", "Transpose", "Reshape", "Add", "MatMul"],
+             [0, 0, 0, 0, 1],
+         )
+         if q_nodes is None:
+             q_nodes = self.model.match_parent_path(
+                 matmul_qk,
+                 ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
+                 [0, 0, 0, 0, 0],
+             )
+             if q_nodes is None:
+                 logger.debug("fuse_conformer_attention: failed to match q path")
+                 return
+
+         reshape_q, add_q, matmul_q = q_nodes[-3], q_nodes[-2], q_nodes[-1]
+
+         extra_q_nodes = self.model.match_parent_path(
+             add_qk,
+             ["Reshape", "Transpose", "MatMul", "Transpose", "Reshape", "Div"],
+             [1, 0, 0, 0, 0, 0],
+         )
+         if extra_q_nodes is not None and q_nodes[0] != extra_q_nodes[-1]:
+             logger.debug("fuse_conformer_attention: failed to match extra q path")
+             return
+
+         past_k, present_k = "", ""
+         k_nodes = self.model.match_parent_path(
+             matmul_qk,
+             ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
+             [1, 0, 1, 0, 0, 1],
+         )
+         if k_nodes is None:
+             k_nodes = self.model.match_parent_path(
+                 matmul_qk,
+                 ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
+                 [1, 0, 0, 0, 0],
+             )
+             if k_nodes is None:
+                 k_nodes = self.model.match_parent_path(
+                     matmul_qk,
+                     ["Transpose", "Reshape", "Add", "MatMul"],
+                     [1, 0, 0, 0],
+                 )
+                 if k_nodes is None:
+                     logger.debug("fuse_conformer_attention: failed to match k path")
+                     return
+         else:
+             concat_k = k_nodes[1]
+             concat_parent = self.model.get_parent(concat_k, 0, None)
+             past_k = concat_parent.output[0]
+             present_k = concat_k.output[0]
+
+         add_k, matmul_k = k_nodes[-2], k_nodes[-1]
+
+         num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+         if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
+             logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
+             return
+
+         new_node = None
+         use_packed_attention_op = (
+             matmul_q.input[0] == matmul_k.input[0] and matmul_k.input[0] == matmul_v.input[0] and extra_q_nodes is None
+         )
+         if use_packed_attention_op:
+             # Self-attention, use Attention op
+             new_node = self.create_attention_node(
+                 mask_index=attn_mask,
+                 q_matmul=matmul_q,
+                 k_matmul=matmul_k,
+                 v_matmul=matmul_v,
+                 q_add=add_q,
+                 k_add=add_k,
+                 v_add=add_v,
+                 num_heads=num_heads,
+                 hidden_size=hidden_size,
+                 first_input=matmul_q.input[0],
+                 output=reshape_qkv.output[0],
+                 add_qk_str=add_qk.input[1],
+                 past_k=past_k,
+                 past_v=past_v,
+                 present_k=present_k,
+                 present_v=present_v,
+             )
+         else:
+             new_node = self.create_multihead_attention_node(
+                 q_matmul=matmul_q,
+                 k_matmul=matmul_k,
+                 v_matmul=matmul_v,
+                 q_add=add_q,
+                 k_add=add_k,
+                 v_add=add_v,
+                 num_heads=num_heads,
+                 hidden_size=hidden_size,
+                 output=reshape_qkv.output[0],
+                 key_padding_mask=attn_mask,
+                 add_qk=add_qk.input[1],
+                 past_k=past_k,
+                 past_v=past_v,
+                 present_k=present_k,
+                 present_v=present_v,
+             )
+
+         if new_node is None:
+             logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
+             return
+
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+         self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
+         self.nodes_to_remove.extend(qk_nodes)
+
+         # When using MultiHeadAttention, keep MatMul nodes unfused in original graph
+         if not use_packed_attention_op:
+             if q_nodes[-1].op_type == "MatMul":
+                 q_nodes.pop()
+             if k_nodes[-1].op_type == "MatMul":
+                 k_nodes.pop()
+             if v_nodes[-1].op_type == "MatMul":
+                 v_nodes.pop()
+
+         if extra_q_nodes is None:
+             # Don't remove Q nodes for conformer-transducer (CT) model since it has
+             # an extra set of nodes attached to the output of the Q path that are not
+             # part of the attention computation
+             self.nodes_to_remove.extend(q_nodes)
+
+         self.nodes_to_remove.extend(k_nodes)
+         self.nodes_to_remove.extend(v_nodes)
+
+         # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+         self.prune_graph = True
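
As with the other attention fusions, this pass is typically invoked through optimize_model with the matching model type. A minimal sketch (illustrative, not part of the wheel); the path and dimensions are placeholders, and model_type="conformer" assumes the conformer entry registered in optimizer.py:

    from onnxruntime.transformers import optimizer

    opt_model = optimizer.optimize_model(
        "conformer.onnx",        # placeholder path to the exported model
        model_type="conformer",
        num_heads=8,             # must match the exported model
        hidden_size=512,
    )
    fused = opt_model.get_nodes_by_op_type("MultiHeadAttention")
    print(f"fused attention nodes: {len(fused)}")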
onnxruntime/transformers/fusion_constant_fold.py
@@ -0,0 +1,144 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+
+ from fusion_base import Fusion
+ from fusion_utils import NumpyHelper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionConstantFold(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "", ["Transpose"])
+         self.count = 0
+
+     def apply(self):
+         super().apply()
+         if self.count > 0:
+             logger.info(f"Constant Folded: {self.count}")
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         """
+         Apply multiple fusions on Transpose nodes that can be constant folded.
+         """
+         self.fuse_1(node, input_name_to_nodes, output_name_to_node)
+         self.fuse_2(node, input_name_to_nodes, output_name_to_node)
+
+     def fuse_1(self, node, input_name_to_nodes, output_name_to_node):
+         """
+         Constant fold any initializer data representing a MatMul's
+         weights that are stored in a Transpose op
+
+         Ex: Transpose --> Gemm or Transpose --> MatMul
+         """
+         # Check if Transpose node only has one input and one output
+         if len(node.input) != 1 or len(node.output) != 1:
+             logger.debug("fuse_constant_fold: node has more than one input or output")
+             return
+
+         # Check if input is initializer data
+         proto = self.model.get_initializer(node.input[0])
+         if proto is None:
+             logger.debug("fuse_constant_fold: failed to identify initializer input")
+             return
+
+         # Check that all nodes using input are Transpose ops that also only use the initializer data as input
+         skip = False
+         for child_node in input_name_to_nodes[node.input[0]]:
+             if not (child_node.op_type == "Transpose" and len(node.input) == 1):
+                 skip = True
+                 break
+         if skip:
+             logger.debug("fuse_constant_fold: other non-Transpose nodes use the initializer")
+             return
+
+         # Check that all nodes using output are Gemm or MatMul ops
+         for child_node in input_name_to_nodes[node.output[0]]:
+             if not (child_node.op_type == "Gemm" or child_node.op_type == "MatMul"):
+                 skip = True
+                 break
+         if skip:
+             logger.debug("fuse_constant_fold: other non-Gemm and non-MatMul nodes use the transposed data")
+             return
+
+         # Check if initializer data is 2D
+         weight = NumpyHelper.to_array(proto)
+         if len(weight.shape) != 2:
+             logger.debug("fuse_constant_fold: shape of initializer data is not 2D")
+             return
+
+         # Remove old TensorProto and add new TensorProto while re-using same name
+         name = proto.name
+         dtype = proto.data_type
+         self.remove_initializer(proto)
+         self.add_initializer(
+             name=name,
+             data_type=dtype,
+             dims=[weight.shape[1], weight.shape[0]],
+             vals=weight.T,
+         )
+
+         # Update weights input to be the initializer name and not
+         # the output of the Transpose op
+         for child_node in input_name_to_nodes[node.output[0]]:
+             for i in range(len(child_node.input)):
+                 if child_node.input[i] == node.output[0]:
+                     child_node.input[i] = node.input[0]
+
+                     if child_node.op_type == "Gemm" and (i == 0 or i == 1):
+                         # Ensure that transA/transB is set to 0 in Gemm
+                         key = "transA" if i == 0 else "transB"
+                         for j, attr_key in enumerate(child_node.attribute):
+                             if attr_key.name == key:
+                                 child_node.attribute[j].i = 0
+
+         # Add node to list of nodes to remove
+         self.nodes_to_remove.append(node)
+         self.count += 1
+
+     def fuse_2(self, node, input_name_to_nodes, output_name_to_node):
+         """
+         Constant fold any Transpose --> Transpose ops since the root input
+         is the final result
+
+         Ex: root_input --> Transpose --> Transpose --> next_node  to  root_input --> next_node
+         """
+         # Check if Transpose node only has one input and one output
+         if len(node.input) != 1 or len(node.output) != 1:
+             logger.debug("fuse_constant_fold: node has more than one input or output")
+             return
+
+         # Check if parent node is Transpose node with only one input and one output
+         parent_node = self.model.match_parent(node, "Transpose", 0)
+         if parent_node is None:
+             logger.debug("fuse_constant_fold: failed to identify parent Transpose node")
+             return
+         if len(parent_node.input) != 1 or len(parent_node.output) != 1:
+             logger.debug("fuse_constant_fold: parent node has more than one input or output")
+             return
+
+         node_perm = node.attribute[0].ints
+         parent_node_perm = parent_node.attribute[0].ints
+
+         if node_perm != parent_node_perm:
+             logger.debug("fuse_constant_fold: Transpose node permutations aren't identical")
+             return
+
+         # For nodes that use output of child Transpose node as an input,
+         # replace that input with root_input
+         root_input = parent_node.input[0]
+         output_nodes = input_name_to_nodes[node.output[0]]
+         for output_node in output_nodes:
+             for i, input_ in enumerate(output_node.input):
+                 if input_ == node.output[0]:
+                     output_node.input[i] = root_input
+
+         # Add node to list of nodes to remove
+         self.nodes_to_remove.append(node)
+         self.nodes_to_remove.append(parent_node)
+         self.count += 1
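
A small self-contained check (illustrative, not part of the wheel) of the identity fuse_2 relies on: two stacked Transpose ops with the same permutation reduce to the original input exactly when that permutation is its own inverse (p[p[i]] == i), as with a pairwise swap:

    import numpy as np

    x = np.random.rand(2, 3, 4)
    p = (0, 2, 1)  # self-inverse permutation: applying it twice restores the layout
    assert np.array_equal(np.transpose(np.transpose(x, p), p), x)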