onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,109 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+
8
+ from fusion_base import Fusion
9
+ from fusion_utils import FusionUtils
10
+ from numpy import ndarray
11
+ from onnx import NodeProto, TensorProto
12
+ from onnx_model import OnnxModel
13
+
14
+ logger = getLogger(__name__)
15
+
16
+
17
class FusionShape(Fusion):
    """Replace a Concat that reassembles a tensor's shape dimension-by-dimension
    with the output of the Shape node it was built from."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "Shape", "Concat")
        self.utils = FusionUtils(model)
        # Symbolic shape inference is expensive; run it lazily and only once.
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> int | None:
        """Return the rank recorded in a value-info proto, or None when no shape is present."""
        tensor_type = tensor_proto.type.tensor_type
        if not tensor_type.HasField("shape"):
            return None
        return len(tensor_type.shape.dim)

    def get_dimensions(self, input_name: str) -> int | None:
        """Return the rank of the named value, consulting cached symbolic shape
        inference when the model itself has no shape for it."""
        known_shape = self.model.get_shape(input_name)
        if known_shape is not None:
            return len(known_shape)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

        if self.shape_infer is None:
            return None
        return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])

    def fuse(
        self,
        concat_node: NodeProto,
        input_name_to_nodes: dict[str, list[NodeProto]],
        output_name_to_node: dict[str, NodeProto],
    ):
        """Simplify a subgraph like

            (2d_input)
               /    \\
           Shape    Shape
              |       |
        Gather(i=0) Gather(i=1)
              |       |
        Unsqueeze(0) Unsqueeze(0)
               \\    /
               Concat
                  |

        into (2d_input) --> Shape --> , i.e. route all consumers of the Concat
        output straight to the Shape output.
        """
        opset_version = self.model.get_opset_version()

        num_inputs = len(concat_node.input)
        root_input = None
        shape_output = None
        for axis in range(num_inputs):
            matched = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Gather", "Shape"],
                [axis, 0, 0],
                output_name_to_node,
            )
            if matched is None:
                return

            unsqueeze_node, gather_node, shape_node = matched
            if axis == 0:
                shape_output = shape_node.output[0]
                root_input = shape_node.input[0]
                # The Concat must emit exactly one element per dimension of the root tensor.
                if self.get_dimensions(root_input) != num_inputs:
                    return
            elif shape_node.input[0] != root_input:
                # Every branch has to take the shape of the same tensor.
                return

            if not FusionUtils.check_node_attribute(unsqueeze_node, "axis", 0, default_value=0):
                return

            # Unsqueeze "axes" moved from an attribute to an input in opset 13.
            if opset_version < 13:
                if not FusionUtils.check_node_attribute(unsqueeze_node, "axes", [0]):
                    return
            elif not self.utils.check_node_input_value(unsqueeze_node, 1, [0]):
                return

            # Branch at position `axis` must gather exactly dimension `axis`.
            gathered_index = self.model.get_constant_value(gather_node.input[1])
            if not (
                isinstance(gathered_index, ndarray)
                and gathered_index.size == 1
                and gathered_index.item() == axis
            ):
                return

        # Graph outputs must keep a producing node, so only rewire internal edges.
        if self.model.find_graph_output(concat_node.output[0]) is None:
            self.model.replace_input_of_all_nodes(concat_node.output[0], shape_output)
            self.increase_counter("Reshape")
            self.prune_graph = True
@@ -0,0 +1,165 @@
1
+ import logging
2
+
3
+ from fusion_base import Fusion
4
+ from fusion_skiplayernorm import FusionSkipLayerNormalization
5
+ from onnx import helper
6
+ from onnx_model import OnnxModel
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
class FusionSimplifiedLayerNormalization(Fusion):
    """Fuse the RMSNorm subgraph into a single SimplifiedLayerNormalization node.

    RMSNorm formula:
        S = Pow(X, 2) or S = Mul(X, X)
        MS = ReduceMean(S)
        MSEps = Add(MS, epsilon)
        RMS = Sqrt(MSEps)
        InvRMS = Div(1, RMS) or InvRMS = Reciprocal(RMS)
        Normalized = Mul(X, InvRMS)
        Y = Mul(Normalized, Scale)
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SimplifiedLayerNormalization", "Mul")

    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
        """Try to match one of three RMSNorm path variants ending at `node` and,
        on success, replace the whole subgraph with SimplifiedLayerNormalization.

        Variant 1 (explicit reciprocal via Div):

            (root_input) ----------------------------------------+
                |                                                |
                v                                                v
               Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul --> Mul (node)
              (B=2)              (A/B=eps)             (A=1)         (A/B=scale)

        (the squaring may also be Mul(X, X) instead of Pow(X, 2)).
        """
        if node.op_type != "Mul":
            return

        return_indice = []
        sim_ln_nodes = self.model.match_parent_path(
            node,
            ["Mul", "Div", "Sqrt", "Add", "ReduceMean"],
            [None, 1, 1, 0, None],
            output_name_to_node=output_name_to_node,
            return_indice=return_indice,
        )

        if sim_ln_nodes:
            mul_node, div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
            # The Div must be a reciprocal: Div(1, RMS).
            if not self.model.has_constant_input(div_node, 1.0):
                return
            node_parent = mul_node
        else:
            # Variant 2: Div(1, RMS) represented as Reciprocal(RMS):
            #
            # (root_input) -----------------------------------------------+
            #     |                                                       |
            #     v                                                       v
            #    Pow --> ReduceMean --> Add ---> Sqrt --> Reciprocal --> Mul --> Mul (node)
            #   (B=2)              (A/B=eps)                                  (A/B=scale)
            #
            return_indice = []
            sim_ln_nodes = self.model.match_parent_path(
                node,
                ["Mul", "Reciprocal", "Sqrt", "Add", "ReduceMean"],
                [None, 1, 0, 0, None],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )
            if sim_ln_nodes is not None:
                mul_node, _reciprocal_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
                node_parent = mul_node
            else:
                # Variant 3: the final scale Mul consumes the Div directly:
                #
                # (root_input) --------------------------------+
                #     |                                        |
                #     v                                        v
                #    Pow --> ReduceMean --> Add ---> Sqrt --> Div --> Mul (node)
                #   (B=2)              (A/B=eps)                   (A/B=scale)
                #
                return_indice = []
                sim_ln_nodes = self.model.match_parent_path(
                    node,
                    ["Div", "Sqrt", "Add", "ReduceMean"],
                    [None, 1, 0, None],
                    output_name_to_node=output_name_to_node,
                    return_indice=return_indice,
                )
                if sim_ln_nodes is not None:
                    div_node, _sqrt_node, add_node, reduce_mean_node = sim_ln_nodes
                    node_parent = div_node
                else:
                    return

        # The ReduceMean input must be the square of root_input: Pow(X, 2) or Mul(X, X).
        reduce_mean_parent = self.model.get_parent(reduce_mean_node, 0, output_name_to_node)
        if reduce_mean_parent is None or reduce_mean_parent.op_type not in ["Pow", "Mul"]:
            return

        if reduce_mean_parent.op_type == "Pow":
            if self.model.find_constant_input(reduce_mean_parent, 2.0) != 1:
                return
        else:
            assert reduce_mean_parent.op_type == "Mul"
            # Bug fix: NodeProto is not subscriptable — the original code indexed the
            # node itself (reduce_mean_parent[0]) which raises TypeError. Compare the
            # node's two inputs instead to verify Mul(X, X) squares a single tensor.
            if reduce_mean_parent.input[0] != reduce_mean_parent.input[1]:
                return

        # The normalized branch must multiply by the same root input that was squared.
        root_input = reduce_mean_parent.input[0]
        if root_input not in node_parent.input:
            return

        _i, epsilon = self.model.get_constant_input(add_node)
        if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
            logger.warning(f"epsilon value is not expected: {epsilon}")
            return

        # ReduceMean must have keepdims == 1 so the normalized shape matches the input.
        keepdims = self.model.get_node_attribute(reduce_mean_node, "keepdims")
        if not keepdims:
            return

        # ReduceMean axes must refer only to the last dimension.
        # Axes became an input in opset 18. Before then, axes was an attribute.
        axes = self.model.get_node_attribute(reduce_mean_node, "axes")
        if (not axes) and len(reduce_mean_node.input) > 1:
            axes = self.model.get_constant_value(reduce_mean_node.input[1])
        # Make sure only one axis as required by SimplifiedLayerNormalization spec.
        if not axes or len(axes) != 1:
            return

        self.nodes_to_remove.extend(sim_ln_nodes)
        self.nodes_to_remove.append(reduce_mean_parent)
        self.nodes_to_remove.append(node)

        # return_indice[0] tells which input of the final Mul is the normalized
        # tensor; the other input is the scale.
        normalize_node = helper.make_node(
            "SimplifiedLayerNormalization",
            inputs=[root_input, node.input[1 - return_indice[0]]],
            outputs=[node.output[0]],
            name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="RMSNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        normalize_node.attribute.extend([helper.make_attribute("axis", axes[0])])
        normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
157
+ self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
158
+
159
+
160
class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization):
    """Fuse Add + SimplifiedLayerNormalization into SkipSimplifiedLayerNormalization.

    The matching logic is inherited unchanged from FusionSkipLayerNormalization;
    only the fused operator name and the normalization op being searched differ.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Delegate entirely to the parent's skip-layer-norm pattern matcher.
        super().fuse(node, input_name_to_nodes, output_name_to_node)
@@ -0,0 +1,254 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ from fusion_base import Fusion
8
+ from fusion_utils import NumpyHelper
9
+ from onnx import helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
class FusionSkipGroupNorm(Fusion):
    """
    Fuse Add + GroupNorm into one node: SkipGroupNorm.

    The matched GroupNorm consumes channels-last (NHWC) input via a
    Transpose(perm=[0,2,3,1]) node, while the Add runs in NCHW; the fusion
    therefore also inserts/rewires the Transpose nodes around the new node.
    Requires symbolic shape inference to classify which Add input is the skip.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipGroupNorm", "GroupNorm")
        # Update shape inference is needed since other fusions might add new edge which does not have shape info yet.
        self.shape_infer_helper = self.model.infer_runtime_shape(update=True)

        if self.shape_infer_helper is None:
            # fuse() returns immediately in this case, so the whole pass becomes a no-op.
            logger.warning("SkipGroupNorm fusion will be skipped since symbolic shape inference disabled or failed.")

    def create_transpose_node(self, input_name: str, perm: list[int], output_name: str | None = None):
        """Append a Transpose node after an input.

        Args:
            input_name: name of the edge to transpose.
            perm: value for the Transpose node's ``perm`` attribute.
            output_name: output edge name; derived from the generated node name when None.

        Returns:
            The new (not yet added) Transpose NodeProto.
        """
        node_name = self.model.create_node_name("Transpose")
        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name
        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
        return transpose_node

    def get_skip_index(self, add, is_channel_last: bool):
        """Add has two inputs. This classifies which input is skip based on shape info (skip allows broadcast).

        Only 4D inputs are supported. When shapes are identical, input 1 is chosen
        as skip. Otherwise the skip side must match on batch and channel and have
        both spatial dims equal to 1 (broadcast case).

        Returns:
            Tuple (skip, broadcast): skip is 0 or 1 (index of the skip input of Add),
            or -1 when the shapes do not fit any supported pattern; broadcast is True
            for the H == W == 1 broadcast case.
        """
        skip = -1
        broadcast = False

        # Caller (fuse) has already returned early when shape inference is unavailable.
        assert self.shape_infer_helper is not None
        shape_a = self.shape_infer_helper.get_edge_shape(add.input[0])
        shape_b = self.shape_infer_helper.get_edge_shape(add.input[1])
        assert shape_a is not None and shape_b is not None

        if len(shape_a) == 4 and len(shape_b) == 4:
            if shape_a == shape_b:
                skip = 1
            else:
                # Map channel/height/width axis positions for the given layout.
                c = 3 if is_channel_last else 1
                h = 1 if is_channel_last else 2
                w = 2 if is_channel_last else 3
                # Batch and channel must match; only H and W may broadcast (both equal to 1).
                if shape_a[0] == shape_b[0] and shape_a[c] == shape_b[c]:
                    if shape_b[h] == 1 and shape_b[w] == 1:
                        skip = 1
                        broadcast = True
                    elif shape_a[h] == 1 and shape_a[w] == 1:
                        skip = 0
                        broadcast = True

        if skip < 0:
            logger.debug(
                "skip SkipGroupNorm fusion since shape of Add inputs (%s, %s) are not expected",
                add.input[0],
                add.input[1],
            )
        return skip, broadcast

    def has_multiple_consumers(self, output_name, input_name_to_nodes):
        """Whether an output has multiple consumers (like graph output or more than one children nodes)"""
        return self.model.find_graph_output(output_name) is not None or (
            output_name in input_name_to_nodes and len(input_name_to_nodes[output_name]) > 1
        )

    def remove_if_safe(self, node, input_name_to_nodes):
        """Remove a node if it is safe (only one children, and not graph output)"""
        if not self.has_multiple_consumers(node.output[0], input_name_to_nodes):
            self.nodes_to_remove.extend([node])

    def is_bias_1d(self, bias_name: str):
        """Whether bias is an initializer of one dimension"""
        initializer = self.model.get_initializer(bias_name)
        if initializer is None:
            # Not a constant initializer (e.g. a runtime tensor) -> cannot fold as bias.
            return False

        bias_weight = NumpyHelper.to_array(initializer)
        if bias_weight is None:
            logger.debug("Bias weight not found")
            return False

        if len(bias_weight.shape) != 1:
            logger.debug("Bias weight is not 1D")
            return False
        return True

    def match_bias_path(self, node, input_name_to_nodes, output_name_to_node):
        """
        Match the bias graph pattern from an Transpose node after Reshape node like in below example.
        It checks whether the bias is 1D initializer. If so, remove Add and redirect MatMul output to Reshape.

        Returns:
            The bias initializer name on a successful match, otherwise None.
        """
        # Before Fusion:
        #     MatMul (bias)
        #        \  /    (shape)
        #        Add     /
        #          \    /
        #    (a)  Reshape
        #      \      |
        #  Transpose([0, 3, 1, 2])  Transpose([0, 3, 1, 2])  --- the start node, this func only handles the above nodes.
        #          \    /
        #           Add
        #          /   \
        #        (c)  Transpose([0,2,3,1])
        #                 |
        #             GroupNorm
        #                 |
        #                (d)
        #
        # After Fusion (the nodes below Reshape is handled in the fuse function):
        #    MatMul  (shape)
        #       \     /
        #  (a)  Reshape
        #    \    /
        #  SkipGroupNorm
        #     /     \
        #   (d)  Transpose([0, 3, 1, 2])
        #                \
        #                (c)

        # add_input_index receives which input of the matched Add led to MatMul.
        add_input_index = []
        bias_nodes = self.model.match_parent_path(
            node, ["Reshape", "Add", "MatMul"], [0, 0, None], output_name_to_node, add_input_index
        )
        if bias_nodes is None:
            return None

        (reshape, add_bias, matmul) = bias_nodes
        # The non-MatMul input of the Add is the bias candidate.
        bias = bias_nodes[1].input[1 - add_input_index[0]]
        if not self.is_bias_1d(bias):
            return None

        # Bypass the bias Add: Reshape now reads MatMul output directly;
        # the Add itself is removed only when it has a single consumer.
        reshape.input[0] = matmul.output[0]
        self.remove_if_safe(add_bias, input_name_to_nodes)

        return bias

    def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node):
        """Match whether an output is from a Transpose(perm=[0,3,1,2]) node.

        On a match, schedules the Transpose for removal when safe and returns it;
        otherwise returns None.
        """
        parent = output_name_to_node.get(output_name, None)
        if parent is not None and parent.op_type == "Transpose":
            permutation = OnnxModel.get_node_attribute(parent, "perm")
            if permutation == [0, 3, 1, 2]:
                self.remove_if_safe(parent, input_name_to_nodes)
                return parent
        return None

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the (Add -> Transpose -> GroupNorm) pattern rooted at a GroupNorm node."""
        # This fusion requires shape information, so skip it if shape is not available.
        if self.shape_infer_helper is None:
            return

        # Before Fusion:
        #    (a)  (b)
        #      \  /
        #      Add
        #       /\
        #     (c) Transpose([0,2,3,1])
        #            \
        #         GroupNorm
        #            |
        #           (d)
        #
        # After Fusion:
        #      (a)          (b)
        #       |            |
        #  Transpose([0,2,3,1])   Transpose([0,2,3,1])
        #          \        /
        #        SkipGroupNorm
        #          /      \
        #         /     Transpose([0, 3, 1, 2])
        #        /          \
        #      (d)          (c)
        nodes = self.model.match_parent_path(node, ["Transpose", "Add"], [0, 0], output_name_to_node)
        if nodes is None:
            return

        (transpose, add) = nodes
        # Do not fuse a sub-graph already claimed by an earlier match in this pass.
        if transpose in self.nodes_to_remove or add in self.nodes_to_remove:
            return

        # The Transpose output must feed only the GroupNorm, since it will be removed.
        if self.has_multiple_consumers(transpose.output[0], input_name_to_nodes):
            return

        permutation = OnnxModel.get_node_attribute(transpose, "perm")
        if permutation != [0, 2, 3, 1]:
            return

        inputs = []
        bias = None
        for i in range(2):
            matched_transpose = self.match_transpose_from_nhwc(add.input[i], input_name_to_nodes, output_name_to_node)
            if matched_transpose:
                # When there is an Transpose node before Add (see examples in match_bias_path), we do not need to
                # insert another Transpose node. The existing Transpose node will be removed in prune_graph if it
                # has only one consumer.
                inputs.append(matched_transpose.input[0])
                # See whether it match bias pattern.
                if bias is None:
                    bias = self.match_bias_path(matched_transpose, input_name_to_nodes, output_name_to_node)
            else:
                # Otherwise, insert a Transpose node before Add.
                new_transpose = self.create_transpose_node(add.input[i], [0, 2, 3, 1])
                self.model.add_node(new_transpose, self.this_graph_name)
                inputs.append(new_transpose.output[0])

        # Add inputs are in NCHW at this point, hence is_channel_last=False.
        skip, broadcast = self.get_skip_index(add, is_channel_last=False)
        if skip < 0:
            return

        # SkipGroupNorm input order: input, gamma, beta, skip[, bias].
        # node.input[1] and node.input[2] are the original GroupNorm weight inputs.
        inputs = [inputs[1 - skip], node.input[1], node.input[2], inputs[skip]]
        if bias:
            inputs = [*inputs, bias]

        # NOTE(review): `outputs` aliases node.output (a protobuf repeated field), so the
        # append below also mutates the original GroupNorm node; that node is removed at
        # the end of this function, so the mutation is presumably harmless — confirm.
        outputs = node.output

        new_node_name = self.model.create_node_name(self.fused_op_type, name_prefix="SkipGroupNorm")
        if self.has_multiple_consumers(add.output[0], input_name_to_nodes):
            # The Add output is consumed elsewhere: expose it as an optional second output
            # of SkipGroupNorm (NHWC), transposed back to NCHW for existing consumers.
            add_out_name = new_node_name + "_add_out"
            outputs.append(add_out_name)

            # Insert a Transpose node after add output.
            add_out_transpose = self.create_transpose_node(add_out_name, [0, 3, 1, 2], add.output[0])
            self.model.add_node(add_out_transpose, self.this_graph_name)

        skip_group_norm = helper.make_node(
            self.fused_op_type,
            inputs=inputs,
            outputs=outputs,
            name=new_node_name,
        )
        # SkipGroupNorm is an ONNX Runtime contrib op.
        skip_group_norm.domain = "com.microsoft"

        # Record which fusion variant was applied, for fusion statistics.
        self.increase_counter(
            f"SkipGroupNorm(add_out={int(len(outputs) > 1)} bias={int(bias is not None)} broadcast={int(broadcast)})"
        )

        # Pass attributes from GroupNorm node to SkipGroupNorm
        for att in node.attribute:
            skip_group_norm.attribute.extend([att])

        self.nodes_to_remove.extend([add, transpose, node])
        self.nodes_to_add.append(skip_group_norm)
        self.node_name_to_graph_name[skip_group_norm.name] = self.this_graph_name
        self.prune_graph = True