onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
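The wheel bundles the DirectML execution provider (DirectML.dll, onnxruntime.dll, and the onnxruntime_pybind11_state.pyd bindings listed above) together with the pure-Python quantization, tools, and transformers packages. For orientation before the file diff below, here is a minimal, hypothetical usage sketch; the model path, input shape, and dtype are placeholders and assumptions, not anything shipped in the wheel:

import numpy as np
import onnxruntime as ort

# onnxruntime-directml registers DirectML as "DmlExecutionProvider";
# the CPU provider stays available as a fallback.
session = ort.InferenceSession(
    "model.onnx",  # placeholder path, not part of the package
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Assumes a single float32 input; replace the dummy tensor with real data.
inp = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
outputs = session.run(None, {inp.name: np.zeros(shape, dtype=np.float32)})
print([o.shape for o in outputs])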
onnxruntime/quantization/onnx_model.py
@@ -0,0 +1,600 @@
+# --------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from pathlib import Path
+
+import onnx
+import onnx.helper as onnx_helper
+import onnx.numpy_helper as onnx_numpy_helper
+from onnx.onnx_pb import ModelProto
+
+from .quant_utils import attribute_to_kwarg, find_by_name
+
+
+def _clean_initializers_helper(graph, model):
+    """Clean unused initializers from graph.
+
+    Returns:
+        A cleaned graph without unused initializers
+        A list of tensor names, which are not produced by this graph and its subgraphs
+    """
+    requesting_tensor_names = set()
+    requesting_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name)
+    requesting_tensor_names.update(g_out.name for g_out in graph.output if g_out.name)
+
+    new_nodes = []
+    for node in graph.node:
+        new_node = node
+        graph_attrs = [
+            attr
+            for attr in node.attribute
+            if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
+        ]
+        if graph_attrs:
+            kwargs = {}
+            for attr in node.attribute:
+                new_attribute = {}
+                if attr.type == onnx.AttributeProto.GRAPH:
+                    (
+                        cleaned_sub_graph,
+                        sub_requesting_tensor_names,
+                    ) = _clean_initializers_helper(attr.g, model)
+                    new_attribute = {attr.name: cleaned_sub_graph}
+                    requesting_tensor_names.update(sub_requesting_tensor_names)
+                elif attr.type == onnx.AttributeProto.GRAPHS:
+                    cleaned_graphes = []
+                    for subgraph in attr.graphs:
+                        (
+                            cleaned_sub_graph,
+                            sub_requesting_tensor_names,
+                        ) = _clean_initializers_helper(subgraph, model)
+                        cleaned_graphes.append(cleaned_sub_graph)
+                        requesting_tensor_names.update(sub_requesting_tensor_names)
+                    new_attribute = {attr.name: cleaned_graphes}
+                else:
+                    new_attribute = attribute_to_kwarg(attr)
+                kwargs.update(new_attribute)
+            new_node = onnx_helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs)
+        new_nodes.append(new_node)
+
+    graph.ClearField("node")
+    graph.node.extend(new_nodes)
+
+    requesting_tensor_names.difference_update(output for node in graph.node for output in node.output)
+
+    unused_initializer = []
+    for initializer in graph.initializer:
+        if initializer.name in requesting_tensor_names:
+            requesting_tensor_names.remove(initializer.name)
+        else:
+            # mark it for removal; removing it here directly would cause misbehavior
+            unused_initializer.append(initializer)
+
+    name_to_input = {input.name: input for input in graph.input}
+    for initializer in unused_initializer:
+        graph.initializer.remove(initializer)
+        if initializer.name in name_to_input:
+            try:
+                graph.input.remove(name_to_input[initializer.name])
+            except StopIteration:
+                if model.ir_version < 4:
+                    print(f"Warning: invalid weight name {initializer.name} found in the graph (not a graph input)")
+
+    requesting_tensor_names.difference_update(input.name for input in graph.input)
+
+    return graph, requesting_tensor_names
+
+
+class ONNXModel:
+    def __init__(self, model: ModelProto):
+        self.model = model
+
+    def nodes(self):
+        return self.model.graph.node
+
+    def initializer(self):
+        return self.model.graph.initializer
+
+    def initializer_extend(self, inits):
+        if len(inits) == 0:
+            raise ValueError("Cannot add an empty list.")
+        for init in self.initializer():
+            self._check_init(init, "gain")
+        for init in inits:
+            self._check_init(init)
+            self.model.graph.initializer.append(init)
+
+    def graph(self):
+        return self.model.graph
+
+    def ir_version(self):
+        return self.model.ir_version
+
+    def opset_import(self):
+        return self.model.opset_import
+
+    def set_opset_import(self, domain, version):
+        for opset in self.model.opset_import:
+            if opset.domain == domain:
+                opset.version = version
+                return
+
+        self.model.opset_import.extend([onnx_helper.make_opsetid(domain, version)])
+
+    def remove_node(self, node):
+        if node in self.model.graph.node:
+            self.model.graph.node.remove(node)
+
+    def remove_nodes(self, nodes_to_remove):
+        for node in nodes_to_remove:
+            self.remove_node(node)
+
+    def add_node(self, node):
+        self.model.graph.node.extend([self._check_node(node)])
+
+    def add_nodes(self, nodes_to_add):
+        for node in nodes_to_add:
+            self.add_node(node)
+
+    def add_initializer(self, tensor):
+        if find_by_name(tensor.name, self.model.graph.initializer) is None:
+            self._check_init(tensor)
+            self.model.graph.initializer.extend([tensor])
+
+    def get_initializer(self, name):
+        for tensor in self.model.graph.initializer:
+            if tensor.name == name:
+                return tensor
+        return None
+
+    def find_graph_input(self, input_name):
+        for input in self.model.graph.input:
+            if input.name == input_name:
+                return input
+        return None
+
+    def find_graph_output(self, output_name):
+        for output in self.model.graph.output:
+            if output.name == output_name:
+                return output
+        return None
+
+    def get_tensor_type(self, tensor_name: str):
+        tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info}
+
+        if tensor_name in tensor_type_map:
+            return tensor_type_map[tensor_name].tensor_type
+
+        g_input = self.find_graph_input(tensor_name)
+        if g_input:
+            return g_input.type.tensor_type
+
+        g_output = self.find_graph_output(tensor_name)
+        if g_output:
+            return g_output.type.tensor_type
+
+        return None
+
+    def get_constant_value(self, output_name):
+        for node in self.model.graph.node:
+            if node.op_type == "Constant":
+                if node.output[0] == output_name:
+                    for attr in node.attribute:
+                        if attr.name == "value":
+                            return onnx_numpy_helper.to_array(attr.t)
+
+        # Fallback to initializer since constant folding may have been applied.
+        initializer = self.get_initializer(output_name)
+        if initializer is not None:
+            return onnx_numpy_helper.to_array(initializer)
+
+        return None
+
+    def get_initializer_name_set(self):
+        return {initializer.name for initializer in self.model.graph.initializer}
+
+    def remove_initializer(self, tensor):
+        if tensor in self.model.graph.initializer:
+            self.model.graph.initializer.remove(tensor)
+            for input in self.model.graph.input:
+                if input.name == tensor.name:
+                    self.model.graph.input.remove(input)
+                    break
+
+    def remove_initializers(self, init_to_remove):
+        for initializer in init_to_remove:
+            self.remove_initializer(initializer)
+
+    def get_non_initializer_inputs(self):
+        initializer_names = self.get_initializer_name_set()
+        non_initializer_inputs = set()
+        for input in self.model.graph.input:
+            if input.name not in initializer_names:
+                non_initializer_inputs.add(input.name)
+        return non_initializer_inputs
+
+    def input_name_to_nodes(self):
+        input_name_to_nodes = {}
+        for node in self.model.graph.node:
+            for input_name in node.input:
+                if input_name:  # Could be empty when it is optional
+                    if input_name not in input_name_to_nodes:
+                        input_name_to_nodes[input_name] = [node]
+                    else:
+                        input_name_to_nodes[input_name].append(node)
+        return input_name_to_nodes
+
+    def output_name_to_node(self):
+        output_name_to_node = {}
+        for node in self.model.graph.node:
+            for output_name in node.output:
+                if output_name:  # Could be empty when it is optional
+                    output_name_to_node[output_name] = node
+        return output_name_to_node
+
+    def get_children(self, node, input_name_to_nodes=None):
+        if input_name_to_nodes is None:
+            input_name_to_nodes = self.input_name_to_nodes()
+
+        children = []
+        for output in node.output:
+            if output in input_name_to_nodes:
+                for node in input_name_to_nodes[output]:
+                    children.append(node)  # noqa: PERF402
+        return children
+
+    def get_parents(self, node, output_name_to_node=None):
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        parents = []
+        for input in node.input:
+            if input in output_name_to_node:
+                parents.append(output_name_to_node[input])
+        return parents
+
+    def get_parent(self, node, idx, output_name_to_node=None):
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        if len(node.input) <= idx:
+            return None
+
+        input = node.input[idx]
+        if input not in output_name_to_node:
+            return None
+
+        return output_name_to_node[input]
+
+    def find_node_by_name(self, node_name, new_nodes_list, graph):
+        """Find out if a node exists in a graph or a node is in the
+        new set of nodes created during quantization.
+
+        Returns:
+            The node found or None.
+        """
+        graph_nodes_list = list(graph.node)  # deep copy
+        graph_nodes_list.extend(new_nodes_list)
+        node = find_by_name(node_name, graph_nodes_list)
+        return node
+
+    def get_largest_node_name_suffix(self, node_name_prefix):
+        """
+        Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`.
+        Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3.
+        """
+        suffix = -1
+
+        for node in self.model.graph.node:
+            if node.name and node.name.startswith(node_name_prefix):
+                try:
+                    index = int(node.name[len(node_name_prefix) :])
+                    suffix = max(index, suffix)
+                except ValueError:
+                    continue
+
+        return suffix
+
+    def get_largest_initializer_name_suffix(self, initializer_name_prefix):
+        """
+        Gets the largest initializer name integer suffix for all initializer names that begin
+        with `initializer_name_prefix`. This can be used to create unique initializer names.
+
+        Example: for initializer names 'my_weight_0' and 'my_weight_3', this method returns 3 if
+        `initializer_name_prefix` is 'my_weight_'.
+        """
+        suffix = -1
+
+        for initializer in self.model.graph.initializer:
+            if initializer.name.startswith(initializer_name_prefix):
+                try:
+                    index = int(initializer.name[len(initializer_name_prefix) :])
+                    suffix = max(index, suffix)
+                except ValueError:
+                    continue
+
+        return suffix
+
+    def find_nodes_by_initializer(self, graph, initializer):
+        """
+        Find all nodes with given initializer as an input.
+        """
+        nodes = []
+        for node in graph.node:
+            for node_input in node.input:
+                if node_input == initializer.name:
+                    nodes.append(node)
+        return nodes
+
+    @staticmethod
+    def __get_initializer(name, graph_path):
+        for gid in range(len(graph_path) - 1, -1, -1):
+            graph = graph_path[gid]
+            for tensor in graph.initializer:
+                if tensor.name == name:
+                    return tensor, graph
+        return None, None
+
+    @staticmethod
+    def __replace_gemm_with_matmul(graph_path):
+        new_nodes = []
+        graph = graph_path[-1]
+        for node in graph.node:
+            graph_attrs = [attr for attr in node.attribute if attr.type == 5 or attr.type == 10]
+            if graph_attrs:
+                kwargs = {}
+                for attr in node.attribute:
+                    if attr.type == 5:
+                        graph_path.append(attr.g)
+                        kv = {attr.name: ONNXModel.__replace_gemm_with_matmul(graph_path)}
+                    elif attr.type == 10:
+                        value = []
+                        for subgraph in attr.graphs:
+                            graph_path.append(subgraph)
+                            value.extend([ONNXModel.__replace_gemm_with_matmul(graph_path)])
+                        kv = {attr.name: value}
+                    else:
+                        kv = attribute_to_kwarg(attr)
+                    kwargs.update(kv)
+                node = onnx_helper.make_node(  # noqa: PLW2901
+                    node.op_type, node.input, node.output, name=node.name, **kwargs
+                )
+
+            if node.op_type == "Gemm":
+                alpha = 1.0
+                beta = 1.0
+                transA = 0  # noqa: N806
+                transB = 0  # noqa: N806
+                for attr in node.attribute:
+                    if attr.name == "alpha":
+                        alpha = onnx_helper.get_attribute_value(attr)
+                    elif attr.name == "beta":
+                        beta = onnx_helper.get_attribute_value(attr)
+                    elif attr.name == "transA":
+                        transA = onnx_helper.get_attribute_value(attr)  # noqa: N806
+                    elif attr.name == "transB":
+                        transB = onnx_helper.get_attribute_value(attr)  # noqa: N806
+                if alpha == 1.0 and beta == 1.0 and transA == 0:
+                    inputB = node.input[1]  # noqa: N806
+                    if transB == 1:
+                        B, Bs_graph = ONNXModel.__get_initializer(node.input[1], graph_path)  # noqa: N806
+                        if B:
+                            # assume B is not used by any other node
+                            B_array = onnx_numpy_helper.to_array(B)  # noqa: N806
+                            B_trans = onnx_numpy_helper.from_array(B_array.T)  # noqa: N806
+                            B_trans.name = B.name
+                            Bs_graph.initializer.remove(B)
+                            for input in Bs_graph.input:
+                                if input.name == inputB:
+                                    Bs_graph.input.remove(input)
+                                    break
+                            Bs_graph.initializer.extend([B_trans])
+                        else:
+                            inputB += "_Transposed"  # noqa: N806
+                            transpose_node = onnx_helper.make_node(
+                                "Transpose",
+                                inputs=[node.input[1]],
+                                outputs=[inputB],
+                                name=node.name + "_Transpose" if node.name else "",
+                            )
+                            new_nodes.append(transpose_node)
+
+                    matmul_node = onnx_helper.make_node(
+                        "MatMul",
+                        inputs=[node.input[0], inputB],
+                        outputs=[node.output[0] + ("_MatMul" if len(node.input) > 2 else "")],
+                        name=node.name + "_MatMul" if node.name else "",
+                    )
+                    new_nodes.append(matmul_node)
+
+                    if len(node.input) > 2:
+                        add_node = onnx_helper.make_node(
+                            "Add",
+                            inputs=[node.output[0] + "_MatMul", node.input[2]],
+                            outputs=node.output,
+                            name=node.name + "_Add" if node.name else "",
+                        )
+                        new_nodes.append(add_node)
+
+                # unsupported
+                else:
+                    new_nodes.append(node)
+
+            # not GEMM
+            else:
+                new_nodes.append(node)
+
+        graph.ClearField("node")
+        graph.node.extend(new_nodes)
+        graph_path.pop()
+        return graph
+
+    def replace_gemm_with_matmul(self):
+        graph_path = [self.graph()]
+        ONNXModel.__replace_gemm_with_matmul(graph_path)
+
+    def save_model_to_file(self, output_path, use_external_data_format=False):
+        """
+        Save model to external data, which is needed for model size > 2GB
+        """
+        self.topological_sort()
+        if use_external_data_format:
+            onnx.external_data_helper.convert_model_to_external_data(
+                self.model,
+                all_tensors_to_one_file=True,
+                location=Path(output_path).name + ".data",
+                convert_attribute=True,
+            )
+        for init in self.model.graph.initializer:
+            self._check_init(init, "end")
+        onnx.save_model(self.model, output_path)
+
+    @staticmethod
+    def replace_node_input(node, old_input_name, new_input_name):
+        assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
+        for j in range(len(node.input)):
+            if node.input[j] == old_input_name:
+                node.input[j] = new_input_name
+
+    def replace_input_of_all_nodes(self, old_input_name, new_input_name):
+        for node in self.model.graph.node:
+            ONNXModel.replace_node_input(node, old_input_name, new_input_name)
+
+    def replace_input_of_nodes(self, old_input_name, new_input_name, node_names_set):
+        for node in self.model.graph.node:
+            if node.name in node_names_set:
+                ONNXModel.replace_node_input(node, old_input_name, new_input_name)
+
+    @staticmethod
+    def replace_node_output(node, old_output_name, new_output_name):
+        assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
+        for j in range(len(node.output)):
+            if node.output[j] == old_output_name:
+                node.output[j] = new_output_name
+
+    def replace_output_of_all_nodes(self, old_output_name, new_output_name):
+        for node in self.model.graph.node:
+            ONNXModel.replace_node_output(node, old_output_name, new_output_name)
+
+    def replace_output_of_nodes(self, old_output_name, new_output_name, node_names_set):
+        for node in self.model.graph.node:
+            if node.name in node_names_set:
+                ONNXModel.replace_node_output(node, old_output_name, new_output_name)
+
+    def remove_unused_constant(self):
+        input_name_to_nodes = self.input_name_to_nodes()
+
+        # remove unused constant
+        unused_nodes = []
+        nodes = self.nodes()
+        for node in nodes:
+            if (
+                node.op_type == "Constant"
+                and not self.is_graph_output(node.output[0])
+                and node.output[0] not in input_name_to_nodes
+            ):
+                unused_nodes.append(node)
+
+        self.remove_nodes(unused_nodes)
+
+        ununsed_weights = []
+        for w in self.initializer():
+            if w.name not in input_name_to_nodes and not self.is_graph_output(w.name):
+                ununsed_weights.append(w)
+                # Remove from graph.input
+                for graph_input in self.graph().input:
+                    if graph_input.name == w.name:
+                        self.graph().input.remove(graph_input)
+
+        self.remove_initializers(ununsed_weights)
+
+    def is_graph_output(self, output_name):
+        return any(output.name == output_name for output in self.model.graph.output)
+
+    def is_graph_input(self, tensor_name: str) -> bool:
+        return any(input.name == tensor_name for input in self.model.graph.input)
+
+    # TODO: use OnnxModel.graph_topological_sort(self.model.graph) from transformers.onnx_model
+    # Currently it breaks the OpenVINO/Linux training GPU pipeline, so hold off for the 1.8 release
+    def topological_sort(self):
+        deps_count = [0] * len(self.nodes())  # dependency count of each node
+        deps_to_nodes = {}  # input name to node indices
+        sorted_nodes = []  # initialize sorted_nodes
+        for node_idx, node in enumerate(self.nodes()):
+            # CANNOT use len(node.input) directly because input can be optional
+            deps_count[node_idx] = sum(1 for _ in node.input if _)
+            if deps_count[node_idx] == 0:  # Constant doesn't depend on any inputs
+                sorted_nodes.append(self.nodes()[node_idx])
+                continue
+
+            for input_name in node.input:
+                if not input_name:
+                    continue
+                if input_name not in deps_to_nodes:
+                    deps_to_nodes[input_name] = [node_idx]
+                else:
+                    deps_to_nodes[input_name].append(node_idx)
+
+        initializer_names = [init.name for init in self.initializer()]
+        graph_input_names = [input.name for input in self.model.graph.input]
+        input_names = initializer_names + graph_input_names
+        input_names.sort()
+        prev_input_name = None
+        for input_name in input_names:
+            if prev_input_name == input_name:
+                continue
+
+            prev_input_name = input_name
+            if input_name in deps_to_nodes:
+                for node_idx in deps_to_nodes[input_name]:
+                    deps_count[node_idx] = deps_count[node_idx] - 1
+                    if deps_count[node_idx] == 0:
+                        sorted_nodes.append(self.nodes()[node_idx])
+
+        start = 0
+        end = len(sorted_nodes)
+
+        while start < end:
+            for output in sorted_nodes[start].output:
+                if output in deps_to_nodes:
+                    for node_idx in deps_to_nodes[output]:
+                        deps_count[node_idx] = deps_count[node_idx] - 1
+                        if deps_count[node_idx] == 0:
+                            sorted_nodes.append(self.nodes()[node_idx])
+                            end = end + 1
+            start = start + 1
+
+        assert end == len(self.graph().node), "Graph is not a DAG"
+        self.graph().ClearField("node")
+        self.graph().node.extend(sorted_nodes)
+
+    def clean_initializers(self):
+        return _clean_initializers_helper(self.graph(), self.model)
+
+    def _check_init(self, init, test=None):
+        if init.data_type == onnx.TensorProto.FLOAT8E4M3FN:
+            if init.HasField("raw_data"):
+                b = list(init.raw_data)
+                if any((i & 127) == 127 for i in b):
+                    raise ValueError(f"Initializer {init.name!r} has nan.")
+        return init
+
+    def _check_node(self, node):
+        """
+        Quantization to float 8 does not use a quantized bias but a float 16 bias.
+        This function checks that DequantizeLinear is not used to
+        dequantize from float 16.
+        """
+        if node.op_type == "DequantizeLinear":
+            zero_point = node.input[2]
+            init = self.get_initializer(zero_point)
+            dtype = init.data_type
+            if dtype in {
+                onnx.TensorProto.FLOAT16,
+                onnx.TensorProto.FLOAT,
+                onnx.TensorProto.DOUBLE,
+                onnx.TensorProto.BFLOAT16,
+            }:
+                raise RuntimeError(f"Unsupported DequantizeLinear operator, dequantization from {dtype}.")
+        return node
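The ONNXModel helper above is the graph-editing core used by the quantization tooling in this wheel. A minimal, hypothetical usage sketch follows; the file paths are placeholders, and only methods defined in the diff above are used:

import onnx
from onnxruntime.quantization.onnx_model import ONNXModel

wrapper = ONNXModel(onnx.load("model.onnx"))  # placeholder input path
wrapper.replace_gemm_with_matmul()            # rewrite eligible Gemm nodes as MatMul (+ Add)
wrapper.remove_unused_constant()              # drop Constant nodes and initializers nothing consumes
wrapper.clean_initializers()                  # prune unused initializers, including in subgraphs
# save_model_to_file() topologically sorts the graph first and can spill tensors
# to an external ".data" file for models larger than 2 GB.
wrapper.save_model_to_file("model.cleaned.onnx", use_external_data_format=False)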