onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,389 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import copy
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ import onnx
15
+
16
+ from ...calibrate import CalibrationDataReader, CalibrationMethod
17
+ from ...quant_utils import QuantType
18
+ from ...quantize import StaticQuantConfig
19
+ from ...tensor_quant_overrides import TensorQuantOverridesHelper
20
+ from .mixed_precision_overrides_utils import MixedPrecisionTensorQuantOverridesFixer
21
+
22
# Quantization types grouped by bit width. The 16-bit and 4-bit groups are
# used below to decide when ONNX opset < 21 requires 'com.microsoft' Q/DQ ops.
Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16}
Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8}
Q4_TYPES = {QuantType.QInt4, QuantType.QUInt4}
# Operator types excluded from the default set of quantized op types.
OP_TYPES_TO_EXCLUDE = {"Cast"}
MODEL_SIZE_THRESHOLD = 2147483648  # Quant model should use external data if >= 2GB
27
+
28
+
29
def warn_unable_to_override(
    node: onnx.NodeProto,
    what_str: str,
    tensor_name: str,
    io_kind: str,
):
    """
    Log a warning that an existing user-provided quantization override for a
    node's input/output tensor prevented this tool from applying its own
    QNN-compatibility override.

    Params:
        node: The node whose tensor could not be overridden.
        what_str: Description of what could not be overridden (e.g., a quant type).
        tensor_name: Name of the input/output tensor.
        io_kind: Whether the tensor is an "input" or "output" of the node.
    """
    # Lazy %-style arguments: the message is only formatted if the WARNING
    # level is actually enabled, instead of eagerly building f-strings.
    logging.warning(
        "Unable to override %s for %s node's %s "
        "because it has already been overridden! Check the initial quantization overrides provided "
        "to get_qnn_qdq_config() if the generated QDQ model does not run on QNN EP. "
        "Node name: %s, %s name: %s",
        what_str,
        node.op_type,
        io_kind,
        node.name,
        io_kind,
        tensor_name,
    )
41
+
42
+
43
def get_qnn_qdq_config(
    model_input: str | Path | onnx.ModelProto,
    calibration_data_reader: CalibrationDataReader,
    calibrate_method: CalibrationMethod = CalibrationMethod.MinMax,
    activation_type: QuantType = QuantType.QUInt8,
    weight_type: QuantType = QuantType.QUInt8,
    per_channel: bool = False,
    init_overrides: dict[str, list[dict[str, Any]]] | None = None,
    add_qtype_converts: bool = True,
    activation_symmetric: bool = False,
    weight_symmetric: bool | None = None,
    keep_removable_activations: bool = False,
    stride: int | None = None,
    calibration_providers: list[str] | None = None,
    op_types_to_quantize: list[str] | None = None,
    nodes_to_exclude: list[str] | None = None,
) -> StaticQuantConfig:
    """
    Returns a static quantization configuration suitable for running QDQ models on QNN EP.
    This is done primarily by setting tensor-level quantization overrides.

    Params:
        model_input: Path to the input model file or ModelProto.
        calibration_data_reader: Calibration data reader.
        calibrate_method: The calibration method. Defaults to MinMax.
        activation_type: The default activation quantization type. Defaults to QUInt8.
        weight_type: The default weight quantization type. Defaults to QUInt8.
        per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel.
            Defaults to false. Alternatively, use the tensor-level `init_overrides` to select individual operators
            and their quantization axes.

            If set, the quantization tool uses per-channel quantization for the following operator types and inputs:
                - Conv:
                    - input[1] on axis 0
                    - input[2] (bias) on axis 0
                - ConvTranspose:
                    - input[1] on axis 1
                    - input[2] (bias) on axis 0
        init_overrides: Initial tensor-level quantization overrides. Defaults to None. This function updates a copy
            of these overrides with any necessary adjustments and includes them in the returned
            configuration object (i.e., config.extra_options['TensorQuantOverrides']).

            The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list
            contains a single dictionary. For per-channel quantization, the list contains either a dictionary for
            each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis'
            key must be present in the first dictionary for per-channel quantization.

            Each dictionary contains optional overrides with the following keys and values.
                'quant_type' = QuantType : The tensor's quantization data type.
                'axis' = Int : The per-channel axis. Must be present for per-channel weights.
                'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
                'zero_point' = Int : The zero-point value to use. Must also specify `scale` if set.
                'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
                    set `scale` or `zero_point`.
                'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
                    set `scale` or `zero_point`. Only valid for initializers.
                'rmax' = Float : Override the maximum real tensor value in calibration data.
                    Invalid if also set `scale` or `zero_point`.
                'rmin' = Float : Override the minimum real tensor value in calibration data.
                    Invalid if also set `scale` or `zero_point`.
                'convert' = Dict : A nested dictionary with the same keys for an activation
                    tensor that should be converted to another quantization type.
                'convert["recv_nodes"] = Set : Set of node names that consume the converted activation,
                    other nodes get the original type. If not specified,
                    assume all consumer nodes get the converted type.
        add_qtype_converts: True if this function should automatically add "convert" entries to the provided
            `init_overrides` to ensure that operators use valid input/output types (activations only).
            Ex: if you override the output of an Add to 16-bit, this option ensures that the activation inputs
            of the Add are also up-converted to 16-bit and that data types for surrounding ops are converted
            appropriately. Refer to the documentation in mixed_precision_overrides_utils.py for additional details.
        activation_symmetric: True if activations should be quantized symmetrically (i.e, rmax == -rmin) by default.
            Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16,
            the zero-point values are 128 and 32,768, respectively.
        weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default.
            Defaults to None. If set to None, weight_symmetric is assumed true if the weight_type is a signed int.
        keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not
            be removed, and will be explicitly represented in the QDQ model. If false, these activations
            are automatically removed if activations are asymmetrically quantized. Keeping these activations
            is necessary if optimizations or EP transformations will later remove
            QuantizeLinear/DequantizeLinear operators from the model.
        stride: Optional stride for strided min/max calibration. Forwarded to the quantizer as the
            'CalibStridedMinMax' extra option. Defaults to None.
        calibration_providers: Execution providers to run the session during calibration. Default is None which uses
            [ "CPUExecutionProvider" ].
        op_types_to_quantize: If set to None, all operator types will be quantized except for OP_TYPES_TO_EXCLUDE
        nodes_to_exclude: List of nodes names to exclude from quantization. The nodes in this list will be excluded from
            quantization when it is not None.

    Returns:
        A StaticQuantConfig object
    """
    # Default weight symmetry follows the sign of the weight type.
    if weight_symmetric is None:
        weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16}

    # Accept either an in-memory ModelProto or a path; external data is not
    # loaded here (only its presence is detected below).
    model = (
        model_input
        if isinstance(model_input, onnx.ModelProto)
        else onnx.load_model(model_input, load_external_data=False)
    )

    op_types = set()
    model_has_external_data = False
    name_to_initializer = {}

    # Build map of initializers (name -> initializer) and
    # check if the model has external data.
    for initializer in model.graph.initializer:
        name_to_initializer[initializer.name] = initializer
        if onnx.external_data_helper.uses_external_data(initializer):
            model_has_external_data = True

    # Work on a deep copy so the caller's init_overrides are never mutated.
    overrides_helper = TensorQuantOverridesHelper(copy.deepcopy(init_overrides) if init_overrides else {})

    if not overrides_helper.empty() and add_qtype_converts:
        # Fix mixed-precision overrides.
        overrides_fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(
            overrides_helper, model, activation_type
        )
        overrides_fixer.apply(activation_type, activation_symmetric)

    # Setup quantization overrides for specific operator types to ensure compatibility with QNN EP.
    qnn_compat = QnnCompatibilityOverrides(
        activation_type,
        weight_type,
        activation_symmetric,
        weight_symmetric,
        per_channel,
        overrides_helper,
        name_to_initializer,
    )

    # Pre-build sets for O(1) membership tests in the node loop below.
    op_types_to_quantize_set = set(op_types_to_quantize) if op_types_to_quantize else None
    nodes_to_exclude_set = set(nodes_to_exclude) if nodes_to_exclude else None

    for node in model.graph.node:
        if op_types_to_quantize_set and node.op_type not in op_types_to_quantize_set:
            continue
        if nodes_to_exclude_set and node.name in nodes_to_exclude_set:
            continue
        op_types.add(node.op_type)
        qnn_compat.process_node(node)

    extra_options = {
        "MinimumRealRange": 0.0001,
        "DedicatedQDQPair": False,  # Let ORT optimizer duplicate DQ nodes
        "QDQKeepRemovableActivations": keep_removable_activations,
        "TensorQuantOverrides": overrides_helper.get_dict(),
        "ActivationSymmetric": activation_symmetric,
        "WeightSymmetric": weight_symmetric,
        "CalibStridedMinMax": stride,
    }

    # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain
    # on Q/DQ operators if using 16-bit or 4-bit quantization.
    onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")
    if onnx_opset.version < 21:
        opset21_types = Q16_TYPES.union(Q4_TYPES)
        overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types())
        if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types:
            extra_options["UseQDQContribOps"] = True

    return StaticQuantConfig(
        calibration_data_reader,
        calibrate_method=calibrate_method,
        activation_type=activation_type,
        weight_type=weight_type,
        op_types_to_quantize=(
            op_types_to_quantize if op_types_to_quantize else list(op_types.difference(OP_TYPES_TO_EXCLUDE))
        ),
        nodes_to_exclude=nodes_to_exclude,
        per_channel=per_channel,
        use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
        calibration_providers=calibration_providers,
        extra_options=extra_options,
    )
216
+
217
+
218
class QnnCompatibilityOverrides:
    """
    Helper that processes nodes to generate quantization overrides that make the resulting QDQ model
    compatible with QNN EP.

    The constructor captures the session-wide quantization defaults; ``process_node`` is then called
    once per graph node and mutates the shared ``overrides`` helper in place.
    """

    def __init__(
        self,
        default_activation_qtype: QuantType,
        default_weight_qtype: QuantType,
        activation_symmetric: bool,
        weight_symmetric: bool,
        per_channel: bool,
        overrides: TensorQuantOverridesHelper,
        initializers: dict[str, onnx.TensorProto],
    ):
        self.default_activation_qtype = default_activation_qtype
        self.default_weight_qtype = default_weight_qtype
        self.activation_symmetric = activation_symmetric
        self.weight_symmetric = weight_symmetric
        self.per_channel = per_channel
        self.overrides = overrides
        self.initializers = initializers

        # Dispatch table: only these op types need QNN-specific override fix-ups.
        self.process_fns = {
            "MatMul": self._process_matmul,
            "LayerNormalization": self._process_layernorm,
            "Sigmoid": self._process_sigmoid,
            "Tanh": self._process_tanh,
        }

    def process_node(self, node: onnx.NodeProto):
        """Apply QNN-compatibility overrides for `node`, if its op type requires any."""
        process_fn = self.process_fns.get(node.op_type)

        if process_fn is not None:
            process_fn(node)

    def _make_static_inputs_use_default_weight_type(self, node: onnx.NodeProto):
        """
        Overrides initializer input(s) to use the default weight type if:
        - The default weight type is 8-bit
        - One of the inputs is a 16-bit activation
        - The other input is an initializer (per-tensor quantized)

        This is necessary because the quantization tool does not assign MatMul or LayerNorm initializer
        inputs the default weight type. Instead, it assigns the default activation type.
        """
        if self.default_weight_qtype not in Q8_TYPES:
            return

        input_16bit_act_name = None
        input_weight_name = None

        # Loop through the first 2 inputs to find a 16-bit activation and a (per-tensor) weight.
        # Slicing (instead of indexing with range(2)) tolerates nodes with fewer than two inputs.
        for input_name in node.input[:2]:
            if not input_name:
                continue  # Optional input is absent.

            is_weight = input_name in self.initializers
            qtype_info = self.overrides.get_node_input_qtype_info(
                input_name,
                node.name,
                default_qtype=None if is_weight else self.default_activation_qtype,
            )

            if qtype_info.axis is not None:
                return  # Don't process nodes with a per-channel quantized input.

            if (
                is_weight
                and qtype_info.quant_type == self.default_weight_qtype
                and qtype_info.symmetric == self.weight_symmetric
            ):
                return  # Return. Weight is already overridden to use the desired weight type.

            if is_weight:
                input_weight_name = input_name
            elif qtype_info.quant_type in Q16_TYPES:
                input_16bit_act_name = input_name

        # Override initializer input to use the default weight type.
        if input_16bit_act_name and input_weight_name:
            did_update = self.overrides.update_tensor_overrides(
                input_weight_name,
                {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric},
                overwrite=False,
            )

            if not did_update:
                warn_unable_to_override(node, "quant_type/symmetric", input_weight_name, "input weight")

    def _process_matmul(self, node: onnx.NodeProto):
        """Ensure MatMul initializer inputs are quantized in a QNN-compatible (per-tensor) manner."""
        assert node.op_type == "MatMul", f"Expected MatMul, but got {node.op_type}"

        if not self.per_channel:
            self._make_static_inputs_use_default_weight_type(node)
            return

        # QNN does not support per-channel MatMul. However, the ORT quantization tool attempts to use per-channel
        # quantization for MatMul by default *if* the global per_channel setting is enabled. So, we need to
        # provide explicit per-tensor quantization overrides for MatMul if per_channel is enabled and
        # the user did not provide any other overrides.
        for input_name in node.input:
            is_weight_no_overrides = input_name in self.initializers and input_name not in self.overrides
            if is_weight_no_overrides:
                self.overrides.update_tensor_overrides(
                    input_name,
                    {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric},
                )

    def _process_layernorm(self, node: onnx.NodeProto):
        """Ensure LayerNormalization initializer inputs use the default weight type (per-tensor case only)."""
        assert node.op_type == "LayerNormalization", f"Expected LayerNormalization, but got {node.op_type}"

        if not self.per_channel:
            self._make_static_inputs_use_default_weight_type(node)

    def _override_16bit_output_qparams(
        self,
        node: onnx.NodeProto,
        qparams_by_type: dict[QuantType, tuple[np.ndarray, np.ndarray]],
    ):
        """
        Force `node`'s first output to use a fixed (scale, zero_point) pair when its quantization
        type is one of the 16-bit types listed in `qparams_by_type`. Shared implementation for the
        Sigmoid/Tanh handlers, whose structure was previously duplicated.
        """
        output_type = self.overrides.get_node_output_qtype_info(
            node.output[0], self.default_activation_qtype
        ).quant_type

        qparams = qparams_by_type.get(output_type)
        if qparams is not None:
            scale, zero_point = qparams
            self.overrides.update_tensor_overrides(
                node.output[0],
                {"quant_type": output_type, "scale": scale, "zero_point": zero_point},
            )

    def _process_sigmoid(self, node: onnx.NodeProto):
        """
        Overrides 16-bit Sigmoid's output scale and zero-point as per QNN requirements.
        """
        assert node.op_type == "Sigmoid", f"Expected Sigmoid, but got {node.op_type}"
        self._override_16bit_output_qparams(
            node,
            {
                QuantType.QUInt16: (
                    np.array(1.0 / 65536.0, dtype=np.float32),
                    np.array(0, dtype=np.uint16),
                ),
                QuantType.QInt16: (
                    np.array(1.0 / 32768.0, dtype=np.float32),
                    np.array(0, dtype=np.int16),
                ),
            },
        )

    def _process_tanh(self, node: onnx.NodeProto):
        """
        Overrides 16-bit Tanh's output scale and zero-point as per QNN requirements.
        """
        assert node.op_type == "Tanh", f"Expected Tanh, but got {node.op_type}"
        self._override_16bit_output_qparams(
            node,
            {
                QuantType.QUInt16: (
                    np.array(1.0 / 32768.0, dtype=np.float32),
                    np.array(32768, dtype=np.uint16),
                ),
                QuantType.QInt16: (
                    np.array(1.0 / 32768.0, dtype=np.float32),
                    np.array(0, dtype=np.int16),
                ),
            },
        )
@@ -0,0 +1,4 @@
1
+ from .fusion import Fusion # noqa: F401
2
+ from .fusion_gelu import FusionGelu # noqa: F401
3
+ from .fusion_layernorm import FusionLayerNormalization # noqa: F401
4
+ from .replace_upsample_with_resize import ReplaceUpsampleWithResize # noqa: F401