onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,387 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+from __future__ import annotations
+
+import copy
+import logging
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import onnx
+
+from ...calibrate import CalibrationDataReader, CalibrationMethod
+from ...quant_utils import QuantType
+from ...quantize import StaticQuantConfig
+from ...tensor_quant_overrides import TensorQuantOverridesHelper
+from .mixed_precision_overrides_utils import MixedPrecisionTensorQuantOverridesFixer
+
+Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16}
+Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8}
+Q4_TYPES = {QuantType.QInt4, QuantType.QUInt4}
+OP_TYPES_TO_EXCLUDE = {"Cast"}
+MODEL_SIZE_THRESHOLD = 2147483648  # Quant model should use external data if >= 2GB
+
+
+def warn_unable_to_override(
+    node: onnx.NodeProto,
+    what_str: str,
+    tensor_name: str,
+    io_kind: str,
+):
+    logging.warning(
+        f"Unable to override {what_str} for {node.op_type} node's {io_kind} "
+        "because it has already been overridden! Check the initial quantization overrides provided "
+        "to get_qnn_qdq_config() if the generated QDQ model does not run on QNN EP. "
+        f"Node name: {node.name}, {io_kind} name: {tensor_name}"
+    )
+
+
+def get_qnn_qdq_config(
+    model_input: str | Path | onnx.ModelProto,
+    calibration_data_reader: CalibrationDataReader,
+    calibrate_method: CalibrationMethod = CalibrationMethod.MinMax,
+    activation_type: QuantType = QuantType.QUInt8,
+    weight_type: QuantType = QuantType.QUInt8,
+    per_channel: bool = False,
+    init_overrides: dict[str, list[dict[str, Any]]] | None = None,
+    add_qtype_converts: bool = True,
+    activation_symmetric: bool = False,
+    weight_symmetric: bool | None = None,
+    keep_removable_activations: bool = False,
+    stride: int | None = None,
+) -> StaticQuantConfig:
+    """
+    Returns a static quantization configuration suitable for running QDQ models on QNN EP.
+    This is done primarily by setting tensor-level quantization overrides.
+
+    Params:
+        model_input: Path to the input model file or ModelProto.
+        calibration_data_reader: Calibration data reader.
+        calibrate_method: The calibration method. Defaults to MinMax.
+        activation_type: The default activation quantization type. Defaults to QUInt8.
+        weight_type: The default weight quantization type. Defaults to QUInt8.
+        per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel.
+            Defaults to false. Alternatively, use the tensor-level `init_overrides` to select individual operators
+            and their quantization axes.
+
+            If set, the quantization tool uses per-channel quantization for the following operator types and inputs:
+                - Conv:
+                    - input[1] on axis 0
+                    - input[2] (bias) on axis 0
+                - ConvTranspose:
+                    - input[1] on axis 1
+                    - input[2] (bias) on axis 0
+        init_overrides: Initial tensor-level quantization overrides. Defaults to None. This function updates a copy
+            of these overrides with any necessary adjustments and includes them in the returned
+            configuration object (i.e., config.extra_options['TensorQuantOverrides']).
+
+            The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list
+            contains a single dictionary. For per-channel quantization, the list contains either a dictionary for
+            each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis'
+            key must be present in the first dictionary for per-channel quantization.
+
+            Each dictionary contains optional overrides with the following keys and values.
+                'quant_type' = QuantType : The tensor's quantization data type.
+                'axis' = Int : The per-channel axis. Must be present for per-channel weights.
+                'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
+                'zero_point' = Int : The zero-point value to use. Must also specify `scale` if set.
+                'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if
+                    `scale` or `zero_point` are also set.
+                'reduce_range' = Bool : If the quantization range should be reduced. Invalid if
+                    `scale` or `zero_point` are also set. Only valid for initializers.
+                'rmax' = Float : Override the maximum real tensor value in calibration data.
+                    Invalid if `scale` or `zero_point` are also set.
+                'rmin' = Float : Override the minimum real tensor value in calibration data.
+                    Invalid if `scale` or `zero_point` are also set.
+                'convert' = Dict : A nested dictionary with the same keys for an activation
+                    tensor that should be converted to another quantization type.
+                'convert["recv_nodes"]' = Set : Set of node names that consume the converted activation;
+                    other nodes get the original type. If not specified,
+                    all consumer nodes are assumed to get the converted type.
+        add_qtype_converts: True if this function should automatically add "convert" entries to the provided
+            `init_overrides` to ensure that operators use valid input/output types (activations only).
+            Ex: if you override the output of an Add to 16-bit, this option ensures that the activation inputs
+            of the Add are also up-converted to 16-bit and that data types for surrounding ops are converted
+            appropriately. Refer to the documentation in mixed_precision_overrides_utils.py for additional details.
+        activation_symmetric: True if activations should be quantized symmetrically (i.e., rmax == -rmin) by default.
+            Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16,
+            the zero-point values are 128 and 32,768, respectively.
+        weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default.
+            Defaults to None. If set to None, weight_symmetric is assumed true if the weight_type is a signed int.
+        keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not
+            be removed, and will be explicitly represented in the QDQ model. If false, these activations
+            are automatically removed if activations are asymmetrically quantized. Keeping these activations
+            is necessary if optimizations or EP transformations will later remove
+            QuantizeLinear/DequantizeLinear operators from the model.
+        stride: Stride for strided MinMax calibration (the 'CalibStridedMinMax' extra option). Defaults to None.
+
+    Returns:
+        A StaticQuantConfig object
+    """
+    if weight_symmetric is None:
+        weight_symmetric = weight_type in {QuantType.QInt8, QuantType.QInt16}
+
+    model = (
+        model_input
+        if isinstance(model_input, onnx.ModelProto)
+        else onnx.load_model(model_input, load_external_data=False)
+    )
+
+    op_types = set()
+    model_has_external_data = False
+    name_to_initializer = {}
+
+    # Build map of initializers (name -> initializer) and
+    # check if the model has external data.
+    for initializer in model.graph.initializer:
+        name_to_initializer[initializer.name] = initializer
+        if onnx.external_data_helper.uses_external_data(initializer):
+            model_has_external_data = True
+
+    overrides_helper = TensorQuantOverridesHelper(copy.deepcopy(init_overrides) if init_overrides else {})
+
+    if not overrides_helper.empty() and add_qtype_converts:
+        # Fix mixed-precision overrides.
+        overrides_fixer = MixedPrecisionTensorQuantOverridesFixer.create_from_model(
+            overrides_helper, model, activation_type
+        )
+        overrides_fixer.apply(activation_type, activation_symmetric)
+
+    # Setup quantization overrides for specific operator types to ensure compatibility with QNN EP.
+    qnn_compat = QnnCompatibilityOverrides(
+        activation_type,
+        weight_type,
+        activation_symmetric,
+        weight_symmetric,
+        per_channel,
+        overrides_helper,
+        name_to_initializer,
+    )
+
+    for node in model.graph.node:
+        op_types.add(node.op_type)
+        qnn_compat.process_node(node)
+
+    extra_options = {
+        "MinimumRealRange": 0.0001,
+        "DedicatedQDQPair": False,  # Let ORT optimizer duplicate DQ nodes
+        "QDQKeepRemovableActivations": keep_removable_activations,
+        "TensorQuantOverrides": overrides_helper.get_dict(),
+        "ActivationSymmetric": activation_symmetric,
+        "WeightSymmetric": weight_symmetric,
+        "CalibStridedMinMax": stride,
+    }
+
+    # ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain
+    # on Q/DQ operators if using 16-bit or 4-bit quantization.
+    onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")
+    if onnx_opset.version < 21:
+        opset21_types = Q16_TYPES.union(Q4_TYPES)
+        overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types())
+        if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types:
+            extra_options["UseQDQContribOps"] = True
+
+    return StaticQuantConfig(
+        calibration_data_reader,
+        calibrate_method=calibrate_method,
+        activation_type=activation_type,
+        weight_type=weight_type,
+        op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)),
+        per_channel=per_channel,
+        use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
+        extra_options=extra_options,
+    )
+
+
+class QnnCompatibilityOverrides:
+    """
+    Helper that processes nodes to generate quantization overrides that make the resulting QDQ model
+    compatible with QNN EP.
+    """
+
+    def __init__(
+        self,
+        default_activation_qtype: QuantType,
+        default_weight_qtype: QuantType,
+        activation_symmetric: bool,
+        weight_symmetric: bool,
+        per_channel: bool,
+        overrides: TensorQuantOverridesHelper,
+        initializers: dict[str, onnx.TensorProto],
+    ):
+        self.default_activation_qtype = default_activation_qtype
+        self.default_weight_qtype = default_weight_qtype
+        self.activation_symmetric = activation_symmetric
+        self.weight_symmetric = weight_symmetric
+        self.per_channel = per_channel
+        self.overrides = overrides
+        self.initializers = initializers
+
+        self.process_fns = {
+            "MatMul": self._process_matmul,
+            "LayerNormalization": self._process_layernorm,
+            "Sigmoid": self._process_sigmoid,
+            "Tanh": self._process_tanh,
+        }
+
+    def process_node(self, node: onnx.NodeProto):
+        process_fn = self.process_fns.get(node.op_type)
+
+        if process_fn is not None:
+            process_fn(node)
+
+    def _make_static_inputs_use_default_weight_type(self, node: onnx.NodeProto):
+        """
+        Overrides initializer input(s) to use the default weight type if:
+            - The default weight type is 8-bit
+            - One of the inputs is a 16-bit activation
+            - The other input is an initializer (per-tensor quantized)
+
+        This is necessary because the quantization tool does not assign MatMul or LayerNorm initializer
+        inputs the default weight type. Instead, it assigns the default activation type.
+        """
+        if self.default_weight_qtype not in Q8_TYPES:
+            return
+
+        input_16bit_act_name = None
+        input_weight_name = None
+
+        # Loop through the first 2 inputs to find a 16-bit activation and a (per-tensor) weight.
+        for i in range(2):
+            input_name = node.input[i]
+            if not input_name:
+                continue
+
+            is_weight = input_name in self.initializers
+            qtype_info = self.overrides.get_node_input_qtype_info(
+                input_name,
+                node.name,
+                default_qtype=None if is_weight else self.default_activation_qtype,
+            )
+
+            if qtype_info.axis is not None:
+                return  # Don't process MatMul with a per-channel quantized input.
+
+            if (
+                is_weight
+                and qtype_info.quant_type == self.default_weight_qtype
+                and qtype_info.symmetric == self.weight_symmetric
+            ):
+                return  # Weight is already overridden to use the desired weight type.
+
+            if is_weight:
+                input_weight_name = input_name
+            elif qtype_info.quant_type in Q16_TYPES:
+                input_16bit_act_name = input_name
+
+        # Override initializer input to use the default weight type.
+        if input_16bit_act_name and input_weight_name:
+            did_update = self.overrides.update_tensor_overrides(
+                input_weight_name,
+                {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric},
+                overwrite=False,
+            )
+
+            if not did_update:
+                warn_unable_to_override(node, "quant_type/symmetric", input_weight_name, "input weight")
+
+    def _process_matmul(self, node: onnx.NodeProto):
+        assert node.op_type == "MatMul", f"Expected MatMul, but got {node.op_type}"
+
+        if not self.per_channel:
+            self._make_static_inputs_use_default_weight_type(node)
+            return
+
+        # QNN does not support per-channel MatMul. However, the ORT quantization tool attempts to use per-channel
+        # quantization for MatMul by default *if* the global per_channel setting is enabled. So, we need to
+        # provide explicit per-tensor quantization overrides for MatMul if per_channel is enabled and
+        # the user did not provide any other overrides.
+        for input_name in node.input:
+            is_weight_no_overrides = input_name in self.initializers and input_name not in self.overrides
+            if is_weight_no_overrides:
+                self.overrides.update_tensor_overrides(
+                    input_name,
+                    {"quant_type": self.default_weight_qtype, "symmetric": self.weight_symmetric},
+                )
+
+    def _process_layernorm(self, node: onnx.NodeProto):
+        assert node.op_type == "LayerNormalization", f"Expected LayerNormalization, but got {node.op_type}"
+
+        if not self.per_channel:
+            self._make_static_inputs_use_default_weight_type(node)
+            return
+
+        has_weight_no_overrides = node.input[1] in self.initializers and node.input[1] not in self.overrides
+        has_bias_no_overrides = (
+            len(node.input) > 2
+            and node.input[2]
+            and node.input[2] in self.initializers
+            and node.input[2] not in self.overrides
+        )
+
+        if has_weight_no_overrides or has_bias_no_overrides:
+            # TODO: Make bias input not per-channel. QNN needs it to be per-tensor, but the quantizer
+            # tries to make it per-channel if the weight is also per-channel.
+            raise ValueError(
+                "get_qnn_qdq_config() does not currently support the global per_channel option with LayerNormalization."
+                " Please try using custom overrides that make bias per-tensor quantized."
+            )
+
+    def _process_sigmoid(self, node: onnx.NodeProto):
+        """
+        Overrides 16-bit Sigmoid's output scale and zero-point as per QNN requirements.
+        """
+        assert node.op_type == "Sigmoid", f"Expected Sigmoid, but got {node.op_type}"
+        output_type = self.overrides.get_node_output_qtype_info(
+            node.output[0], self.default_activation_qtype
+        ).quant_type
+
+        if output_type == QuantType.QUInt16:
+            self.overrides.update_tensor_overrides(
+                node.output[0],
+                {
+                    "quant_type": output_type,
+                    "scale": np.array(1.0 / 65536.0, dtype=np.float32),
+                    "zero_point": np.array(0, dtype=np.uint16),
+                },
+            )
+        elif output_type == QuantType.QInt16:
+            self.overrides.update_tensor_overrides(
+                node.output[0],
+                {
+                    "quant_type": output_type,
+                    "scale": np.array(1.0 / 32768.0, dtype=np.float32),
+                    "zero_point": np.array(0, dtype=np.int16),
+                },
+            )
+
+    def _process_tanh(self, node: onnx.NodeProto):
+        """
+        Overrides 16-bit Tanh's output scale and zero-point as per QNN requirements.
+        """
+        assert node.op_type == "Tanh", f"Expected Tanh, but got {node.op_type}"
+        output_type = self.overrides.get_node_output_qtype_info(
+            node.output[0], self.default_activation_qtype
+        ).quant_type
+
+        if output_type == QuantType.QUInt16:
+            self.overrides.update_tensor_overrides(
+                node.output[0],
+                {
+                    "quant_type": output_type,
+                    "scale": np.array(1.0 / 32768.0, dtype=np.float32),
+                    "zero_point": np.array(32768, dtype=np.uint16),
+                },
+            )
+        elif output_type == QuantType.QInt16:
+            self.overrides.update_tensor_overrides(
+                node.output[0],
+                {
+                    "quant_type": output_type,
+                    "scale": np.array(1.0 / 32768.0, dtype=np.float32),
+                    "zero_point": np.array(0, dtype=np.int16),
+                },
+            )
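
For orientation, here is a minimal usage sketch of the get_qnn_qdq_config() API documented above. It is not part of the package diff: the model path, input name, overridden tensor name, and the random-data reader are illustrative assumptions, and it assumes the qnn subpackage's __init__.py re-exports get_qnn_qdq_config and that quantize() from this wheel's quantize.py accepts the returned StaticQuantConfig.

import numpy as np

from onnxruntime.quantization import CalibrationDataReader, QuantType
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config
from onnxruntime.quantization.quantize import quantize


class RandomDataReader(CalibrationDataReader):
    """Feeds a few random samples for calibration (illustrative only)."""

    def __init__(self, input_name: str, shape: tuple[int, ...], count: int = 8):
        self._samples = iter(
            [{input_name: np.random.rand(*shape).astype(np.float32)} for _ in range(count)]
        )

    def get_next(self):
        # Return None when calibration data is exhausted, per the CalibrationDataReader contract.
        return next(self._samples, None)


reader = RandomDataReader("input", (1, 3, 224, 224))  # hypothetical model input

# 16-bit activations with 8-bit weights, plus an initial per-tensor override
# in the TensorQuantOverrides format described in the docstring above.
qnn_config = get_qnn_qdq_config(
    "model.onnx",  # hypothetical path
    reader,
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
    init_overrides={"layer1_output": [{"quant_type": QuantType.QUInt16}]},
)

quantize("model.onnx", "model.qdq.onnx", qnn_config)

Note that with QUInt16 activations on a model whose ONNX opset is below 21, the logic above would also set the UseQDQContribOps extra option, so the emitted Q/DQ operators use the 'com.microsoft' domain.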
@@ -0,0 +1,3 @@
+from .fusion import Fusion  # noqa: F401
+from .fusion_gelu import FusionGelu  # noqa: F401
+from .fusion_layernorm import FusionLayerNormalization  # noqa: F401
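
The second hunk only re-exports the fusion helpers. As a hedged sketch of how they might be driven (assuming each class wraps an ONNXModel from onnxruntime/quantization/onnx_model.py and exposes an apply() method inherited from Fusion; neither assumption is confirmed by this diff):

import onnx

from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization
from onnxruntime.quantization.onnx_model import ONNXModel

# Wrap a loaded model; the path and the apply() call pattern are assumptions.
wrapped = ONNXModel(onnx.load("model.onnx"))

FusionGelu(wrapped).apply()                # fuse Gelu subgraphs into a single op
FusionLayerNormalization(wrapped).apply()  # fuse LayerNormalization subgraphs

wrapped.save_model_to_file("model.fused.onnx")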