onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2 @@
1
+ from .preprocess import qnn_preprocess_model # noqa: F401
2
+ from .quant_config import get_qnn_qdq_config # noqa: F401
@@ -0,0 +1,132 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import onnx
9
+
10
+ from ...fusions import Fusion
11
+ from ...onnx_model import ONNXModel
12
+
13
+
14
class FusionLpNormalization(Fusion):
    """Fuses a ReduceL2 -> Clip -> Expand -> Div subgraph into one LpNormalization node."""

    def __init__(self, model: ONNXModel, epsilon: float = 1e-12):
        """
        Params:
            model: The ONNXModel to search for fusable subgraphs.
            epsilon: Expected Clip 'min' value guarding against division by zero (~1e-12).
        """
        super().__init__(model, "LpNormalization", "ReduceL2")
        self.epsilon = epsilon

    @staticmethod
    def _single_consumer(tensor_name, input_name_to_nodes, op_type):
        """Return the sole consumer of `tensor_name` if there is exactly one and its op matches `op_type`, else None."""
        consumers = input_name_to_nodes.get(tensor_name)
        if consumers and len(consumers) == 1 and consumers[0].op_type == op_type:
            return consumers[0]
        return None

    def fuse(
        self,
        reduce_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing a ReduceL2 node into a single
        LpNormalization node.

        Pattern:
            [root] --> ReduceL2 -----> Clip --> Expand ----> Div -->
              |        (axis=-1)   (min=epsilon) (shape=root) ^
              |      (keepdims=True)                          |
              +-----------------------------------------------+

        Requirements checked below:
        - ReduceL2 reduces only the last axis and keeps dims.
        - Clip has only a 'min' that is approximately `self.epsilon` (no 'max').
        - Expand's output is the second input of Div; Div's first input is the root tensor.
        """
        # ReduceL2 must feed exactly one Clip node.
        clip_node = self._single_consumer(reduce_node.output[0], input_name_to_nodes, "Clip")
        if clip_node is None:
            return

        # keepdims must be truthy.
        # NOTE(review): an absent keepdims attribute defaults to 1 in the ONNX spec, yet this
        # bails when the attribute is missing — conservative (fusion is merely skipped); confirm intended.
        if not self.get_node_attribute(reduce_node, "keepdims"):
            return

        # Need the input's static shape to identify the last dimension.
        reduce_input_ttype = self.model.get_tensor_type(reduce_node.input[0])
        if not reduce_input_ttype:
            return

        reduce_input_shape = self.tensor_shape_to_list(reduce_input_ttype)
        if not reduce_input_shape:
            return

        # Axes: an attribute before opset 18, an input from opset 18 onward.
        axes = self.get_node_attribute(reduce_node, "axes")
        if not axes and len(reduce_node.input) > 1:
            axes = self.model.get_constant_value(reduce_node.input[1])

        if not axes or len(axes) != 1:
            return

        # The single reduced axis must be the last dimension (negative or positive form).
        if axes[0] != -1 and axes[0] != len(reduce_input_shape) - 1:
            return

        # Clip's min may be an attribute (opset < 11) or a constant input.
        clip_min = self.get_node_attribute(clip_node, "min")
        if clip_min is None and len(clip_node.input) > 1:
            clip_min = self.model.get_constant_value(clip_node.input[1])

        clip_max = self.get_node_attribute(clip_node, "max")  # TODO: clip_max could be FLOAT_MAX
        if clip_max is None and len(clip_node.input) > 2:
            clip_max = self.model.get_constant_value(clip_node.input[2])

        # Require: no max, a positive min approximately equal to epsilon.
        if clip_max is not None or clip_min is None or not (clip_min > 0) or not (abs(clip_min - self.epsilon) < 1e-13):
            return

        # Clip must feed exactly one Expand node.
        expand_node = self._single_consumer(clip_node.output[0], input_name_to_nodes, "Expand")
        if expand_node is None:
            return

        # Expand must feed exactly one Div node.
        div_node = self._single_consumer(expand_node.output[0], input_name_to_nodes, "Div")
        if div_node is None:
            return

        # Div's first input must be the subgraph root; its second input must be Expand's output.
        # ONNX validation guarantees their shapes match once both feed the same Div.
        if div_node.input[0] != reduce_node.input[0] or div_node.input[1] != expand_node.output[0]:
            return

        subgraph_input = reduce_node.input[0]
        subgraph_output = div_node.output[0]

        subgraph_nodes = [reduce_node, clip_node, expand_node, div_node]
        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
            return

        self.nodes_to_remove.extend(subgraph_nodes)
        self.nodes_to_add.append(
            onnx.helper.make_node(
                self.fused_op_type,
                name=self.create_unique_node_name(),
                inputs=[subgraph_input],
                outputs=[subgraph_output],
                p=2,
                axis=-1,
            )
        )
@@ -0,0 +1,413 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ from dataclasses import dataclass
10
+
11
+ import onnx
12
+
13
+ from ...quant_utils import QuantType
14
+ from ...tensor_quant_overrides import QuantTypeInfo, TensorQuantOverridesHelper
15
+
16
+
17
@dataclass
class TensorTypeRequest:
    """
    Bundles desired quantization type requests for a tensor, distinguishing the type
    at the producer end from the type(s) received by consumers.
    """

    # Quant type at the producer end; None means the default activation quant type.
    producer: QuantTypeInfo | None

    # Quant type received by a set of consumer nodes; element [1] is the set of consumer
    # node names. None means all consumers receive the default activation quant type.
    consumers: tuple[QuantTypeInfo, set[str]] | None
31
+
32
+
33
+ class MixedPrecisionTensorQuantOverridesFixer:
34
+ """
35
+ Helper that generates tensor quantization overrides for mixed-precision QDQ models.
36
+
37
+ Specifically, this helper fixes an initial set of quantization overrides that assign a non-default
38
+ activation quantization type to one or more tensors by doing the following:
39
+ - Inferring which other tensors need to be overridden to the non-default activation quantization type.
40
+ - Inserting quantization data type conversions.
41
+
42
+ Example:
43
+ --------
44
+
45
+ Float model:
46
+
47
+ input_0 --> Op1 --> Op3 --> Op5 --> Op6 --> output_0
48
+ ^
49
+ |
50
+ input_1 --> Op2 -+-> Op4 ----+
51
+ |
52
+ +-> Op7 --> output_1
53
+ |
54
+ +-> Op8 --> output_2
55
+
56
+ If we'd like to quantize this model to uint8 precision, but would like to make sure tensor "Op4_out"
57
+ is quantized to 16-bit, then we would specify the following initial tensor quantization overrides:
58
+
59
+ ```
60
+ init_overrides = {"Op4_out": [{"quant_type": QuantType.QUInt16}]}
61
+ ```
62
+
63
+ These initial overrides may not create a valid model because Op4 and Op5 may require both the input and output
64
+ to be the same type (e.g., uint16). This helper fixes the overrides so that input/output data types
65
+ are valid:
66
+
67
+ ```
68
+ overrides = TensorQuantOverridesHelper(init_overrides)
69
+
70
+ fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(overrides, model, QuantType.QUInt8)
71
+ fixer.apply(
72
+ default_activation_qtype=QuantType.QUInt8,
73
+ default_activation_symmetric=False,
74
+ )
75
+ ```
76
+
77
+ The above snippet generates the following "fixed" overrides (get via overrides.get_dict()):
78
+
79
+ {
80
+ "Op2_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op4"}}}],
81
+ "Op3_out": [{"quant_type": QUInt8, "convert": {"quant_type": QUInt16, "recv_nodes": {"Op5"}}}],
82
+ "Op4_out": [{"quant_type": QUInt16}],
83
+ "Op5_out": [{"quant_type": QUInt16, "convert": {"quant_type": QUInt8, "recv_nodes": {"Op6"}}}]
84
+ }
85
+
86
+ How to interpret the fixed overrides:
87
+ - Op2's output is consumed by Op4, Op7, and Op8. Op4 consumes the converted u16 type,
88
+ but Op7 and Op8 consume the original u8 type.
89
+ - Op3's output is converted from u8 to u16. Op5 consumes the converted u16 type.
90
+ - Op4's output is just u16 (not converted). All consumers of Op4_out get the u16 type.
91
+ - Op5's output is converted from u16 to u8. Op6 consumes the u8 type.
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ overrides: TensorQuantOverridesHelper,
97
+ producers: dict[str, onnx.NodeProto],
98
+ consumers: dict[str, list[onnx.NodeProto]],
99
+ value_infos: dict[str, onnx.ValueInfoProto],
100
+ initializers: dict[str, onnx.TensorProto],
101
+ ):
102
+ """
103
+ Params:
104
+ overrides: The initial tensor quantization overrides to fix.
105
+ producers: Dictionary that maps a tensor name to the producer node that generates the tensor.
106
+ consumers: Dictionary that maps a tensor name to the consumer nodes that take the tensor as input.
107
+ value_infos: Dictionary that maps a tensor name to its onnx.ValueInfoProto.
108
+ initializers: Dictionary that maps an initializer name to its onnx.TensorProto.
109
+ """
110
+ self.overrides = overrides
111
+ self.consumers = consumers
112
+ self.producers = producers
113
+ self.value_infos = value_infos
114
+ self.initializers = initializers
115
+
116
@staticmethod
def create_from_model(
    overrides: TensorQuantOverridesHelper, model: onnx.ModelProto, default_activation_qtype: QuantType
) -> MixedPrecisionTensorQuantOverridesFixer:
    """
    Builds a MixedPrecisionTensorQuantOverridesFixer from a loaded ONNX model.

    Params:
        overrides: The initial tensor quantization overrides to fix.
        model: Loaded ONNX model
        default_activation_qtype: The intended default activation quantization type.
            Used to validate the initial overrides.

    Returns:
        Initialized MixedPrecisionTensorQuantOverridesFixer object
    """
    # Shape inference populates graph.value_info, which is needed for tensor type lookups.
    model = onnx.shape_inference.infer_shapes(model)

    # Name-keyed lookup tables for initializers and value_infos.
    # Graph outputs and inputs are merged in afterwards (inputs last, so they win on name clashes).
    initializers = {init.name: init for init in model.graph.initializer}
    value_infos = {vi.name: vi for vi in model.graph.value_info}
    for graph_output in model.graph.output:
        value_infos[graph_output.name] = graph_output
    for graph_input in model.graph.input:
        value_infos[graph_input.name] = graph_input

    # Reject invalid user-provided overrides up front.
    valid, err = overrides.is_valid(initializers, set(value_infos), default_activation_qtype)
    if not valid:
        pprint_overrides = overrides.pprint_str(indent=4)
        logging.error(f"Provided invalid tensor quantization overrides:\n{pprint_overrides}")
        raise ValueError(err)

    # Map every tensor name to its single producer node and to its list of consumer nodes.
    consumers = {}
    producers = {}
    for node in model.graph.node:
        for input_name in node.input:
            if input_name:
                consumers.setdefault(input_name, []).append(node)

        for output_name in node.output:
            producers[output_name] = node

    return MixedPrecisionTensorQuantOverridesFixer(overrides, producers, consumers, value_infos, initializers)
163
+
164
def apply(
    self,
    default_activation_qtype: QuantType,
    default_activation_symmetric: bool,
):
    """
    Fixes the initial tensor quantization overrides (in-place) for use in mixed-precision QDQ models.

    Params:
        default_activation_qtype: The intended default activation quantization type.
        default_activation_symmetric: The intended default symmetry used to quantize activations.

    Raises:
        ValueError: If a type request is inconsistent with existing overrides, or if the
            resulting overrides fail validation.
    """
    type_requests = self.get_desired_tensor_types(default_activation_qtype, default_activation_symmetric)

    # Use type requests to "fix" tensor quantization overrides by adding
    # quantization type conversions where necessary.
    for tensor_name, type_req in type_requests.items():
        # Set comprehension instead of set([...]) (idiomatic, avoids the throwaway list).
        all_consumers = {node.name for node in self.consumers.get(tensor_name, [])}
        has_producer_req = type_req.producer is not None
        has_consumer_req = bool(type_req.consumers)

        # Only producer type: Add conversion back to default activation type
        if has_producer_req and not has_consumer_req:
            self._update_converted_tensor(
                tensor_name, type_req.producer, QuantTypeInfo(default_activation_qtype), all_consumers
            )
        # Only consumers
        elif not has_producer_req and has_consumer_req:
            prod_type_info = self.overrides.get_node_output_qtype_info(tensor_name, default_activation_qtype)
            consumer_type_info = type_req.consumers[0]

            if prod_type_info != consumer_type_info:
                self._update_converted_tensor(
                    tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
                )
            else:
                # Producer already emits the requested type; just make sure no existing
                # "convert" entry re-converts these same consumers.
                if not self._check_nodes_are_not_convert_consumers(tensor_name, type_req.consumers[1]):
                    raise ValueError(
                        f"Tensor override for '{tensor_name}' converts the type for consumers that need the original type."
                    )
        # Both producer and consumers
        elif has_producer_req and has_consumer_req:
            prod_type_info = type_req.producer
            consumer_type_info = type_req.consumers[0]

            if prod_type_info != consumer_type_info:
                self._update_converted_tensor(
                    tensor_name, prod_type_info, consumer_type_info, type_req.consumers[1]
                )
            else:
                consumers_for_original_type = all_consumers.difference(type_req.consumers[1])

                # Truthiness check instead of len(...) == 0 (PEP 8).
                if not consumers_for_original_type:
                    # All consumers want the overridden type, so no need for convert nodes!
                    # Just add the override to the tensor if not already present.
                    if tensor_name not in self.overrides:
                        self.overrides[tensor_name] = [{}]
                        prod_type_info.save_to_dict(self.overrides[tensor_name][0])

                    assert "convert" not in self.overrides[tensor_name][0]
                else:
                    # Some consumers don't want the overridden type.
                    self._update_converted_tensor(
                        tensor_name,
                        prod_type_info,
                        QuantTypeInfo(default_activation_qtype),
                        consumers_for_original_type,
                    )
        else:
            raise ValueError(f"TypeRequest for tensor {tensor_name} has no producer or consumers.")

    # Done. Check if the overrides are valid.
    valid, err = self.overrides.is_valid(self.initializers, set(self.value_infos), default_activation_qtype)
    if not valid:
        pprint_overrides = self.overrides.pprint_str(indent=4)
        logging.error(
            f"Generated invalid tensor quantization overrides for mixed-precision QDQ model:\n{pprint_overrides}"
        )
        raise ValueError(err)
243
+
244
def get_desired_tensor_types(
    self,
    default_activation_qtype: QuantType,
    default_activation_symmetric: bool,
) -> dict[str, TensorTypeRequest]:
    """
    Walks the initial tensor quantization overrides and collects, per tensor, a TensorTypeRequest
    object describing the quantization type wanted at that tensor. These requests later drive the
    generation of the "fixed" overrides.

    Params:
        default_activation_qtype: The intended default activation quantization type.
        default_activation_symmetric: The intended default symmetry used to quantize activations.

    Returns:
        TensorTypeRequest objects as a dict that maps a tensor name to its requested types.
    """
    default_activation_type_info = QuantTypeInfo(default_activation_qtype, default_activation_symmetric)
    type_requests = {}

    # Scan tensor overrides for type conversion requests.
    for tensor_name, override_list in self.overrides.items():
        if not self.__is_tensor_quantizable(tensor_name):
            continue  # Skip non-quantizable tensors (e.g., not a float)

        if tensor_name in self.initializers:
            continue  # Skip initializers

        if not override_list or len(override_list) > 1:
            continue  # Skip per-channel stuff

        override_dict = override_list[0]
        quant_type_info = QuantTypeInfo.load_from_dict(override_dict, default_activation_type_info.quant_type)

        # Only tensors overridden to a non-default type (without an explicit conversion)
        # generate type requests.
        if quant_type_info == default_activation_type_info or "convert" in override_dict:
            continue

        producer_node = self.producers.get(tensor_name)  # None if this is a model input
        if producer_node is not None:
            self._add_type_requests_for_node(type_requests, quant_type_info, producer_node)

        # Request the same type at every consumer of `tensor_name`.
        for consumer_node in self.consumers.get(tensor_name, []):
            self._add_type_requests_for_node(type_requests, quant_type_info, consumer_node)

    return type_requests
288
+
289
def _add_type_requests_for_node(
    self,
    type_requests: dict[str, TensorTypeRequest],
    quant_type_info: QuantTypeInfo,
    node: onnx.NodeProto,
):
    """
    Records TensorTypeRequest entries for a node, requesting the same quantization type
    (given by `quant_type_info`) for all of the node's quantizable inputs and outputs.

    Params:
        type_requests: Dictionary of type requests to append to for this node.
        quant_type_info: The quantization type to use for inputs and outputs.
        node: The node for which the TensorTypeRequest objects are created and added to type_requests.
    """
    # Producer side: this node generates its outputs with the requested type.
    for output_name in node.output:
        if not self.__is_tensor_quantizable(output_name):
            continue

        request = type_requests.get(output_name)
        if request is None:
            type_requests[output_name] = TensorTypeRequest(quant_type_info, None)
            continue

        # A tensor has exactly one producer, so conflicting producer types are an error.
        if request.producer is not None and request.producer != quant_type_info:
            raise ValueError(f"Tensor {output_name} has multiple types.")

        request.producer = quant_type_info

    # Consumer side: this node consumes its (non-initializer) inputs with the requested type.
    for input_name in node.input:
        if not input_name or input_name in self.initializers or not self.__is_tensor_quantizable(input_name):
            continue

        request = type_requests.setdefault(input_name, TensorTypeRequest(None, None))
        if request.consumers is None:
            request.consumers = (quant_type_info, set())

        # All consumers of a tensor must agree on the requested type.
        if request.consumers[0] != quant_type_info:
            raise ValueError(f"Tensor {input_name} has consumers requesting different types.")

        # Consumers are tracked by node name, so unnamed nodes cannot be handled.
        if not node.name:
            raise ValueError(
                f"Node of type {node.op_type} with output 0 {node.output[0]} does not have a name!"
            )

        request.consumers[1].add(node.name)
338
+
339
def _update_converted_tensor(
    self,
    tensor_name: str,
    producer_type_info: QuantTypeInfo,
    consumer_type_info: QuantTypeInfo,
    consumer_names: set[str],
):
    """
    Updates the tensor quantization overrides for a tensor that is converted from one type to another.

    Params:
        tensor_name: The name of the tensor for which to update overrides.
        producer_type_info: Info for the tensor's produced type.
        consumer_type_info: Info for the tensor's consumed (i.e., converted) type.
        consumer_names: Nodes names of consumers that consume the converted type.
    """
    # Create an override entry (with the producer's type) if none exists for this tensor.
    if tensor_name not in self.overrides or not self.overrides[tensor_name]:
        self.overrides[tensor_name] = [{}]
        producer_type_info.save_to_dict(self.overrides[tensor_name][0])

    tensor_overrides = self.overrides[tensor_name][0]
    if producer_type_info != QuantTypeInfo.load_from_dict(tensor_overrides):
        raise ValueError(f"Desired producer quant_type for {tensor_name} doesn't match existing type.")

    if not consumer_names:
        return

    # Record the converted ("received") type and the set of nodes that receive it.
    if "convert" not in tensor_overrides:
        tensor_overrides["convert"] = {}
        consumer_type_info.save_to_dict(tensor_overrides["convert"])

    convert_dict = tensor_overrides["convert"]
    if consumer_type_info != QuantTypeInfo.load_from_dict(convert_dict):
        raise ValueError(f"Desired consumer quant_type for {tensor_name} doesn't match existing type.")

    convert_dict.setdefault("recv_nodes", set()).update(consumer_names)
376
+
377
def _check_nodes_are_not_convert_consumers(self, tensor_name: str, node_names: set[str]):
    """
    Returns true if the given nodes do not consume/receive a converted quantization type.

    Params:
        tensor_name: The name of the tensor to check.
        node_names: Set of node names that should not be consumers of the converted type.
    """
    if tensor_name not in self.overrides or not self.overrides[tensor_name]:
        return True  # No overrides at all, so no conversion either.

    convert_dict = self.overrides[tensor_name][0].get("convert")
    if convert_dict is None:
        return True  # Tensor is not converted.

    if "recv_nodes" not in convert_dict:
        # Without an explicit recv_nodes set, the conversion applies to every consumer,
        # which necessarily includes the given nodes.
        return False

    return convert_dict["recv_nodes"].isdisjoint(node_names)
399
+
400
def __is_tensor_quantizable(self, tensor_name):
    """Returns True if the tensor is a float/float16 initializer, or an activation whose
    value_info reports a float/float16 tensor type."""
    float_elem_types = (onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT16)

    weight = self.initializers.get(tensor_name)
    if weight is not None:
        return weight.data_type in float_elem_types

    value_info = self.value_infos.get(tensor_name)
    if value_info is not None:
        type_proto = value_info.type
        return type_proto.HasField("tensor_type") and type_proto.tensor_type.elem_type in float_elem_types

    return False