onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,216 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+
8
+ from fusion_base import Fusion
9
+ from fusion_utils import FusionUtils
10
+ from onnx import helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionQOrderedMatMul(Fusion):
    """Fuse a quantized MatMul + Bias-Add (optionally followed by a residual Add)
    QDQ subgraph into a single com.microsoft `QOrderedMatMul` node.

    The matched pattern is:
        DequantizeLinear -> MatMul -> Add(bias) [-> Add(residual)] -> QuantizeLinear
    with the weight ('B') input of the MatMul also flowing through a
    DequantizeLinear node, and all Q/DQ nodes carrying fusible scales/zero-points.
    """

    def __init__(self, model: OnnxModel):
        # Search for fusion opportunities anchored at "MatMul" nodes.
        super().__init__(model, "QOrderedMatMul", "MatMul")

    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
        """Attempt the QOrderedMatMul fusion anchored at `node` (a MatMul).

        Returns early (without modification) whenever any part of the expected
        pattern is missing; on success, records the replacement node in
        `self.nodes_to_add` and the matched subgraph in `self.nodes_to_remove`.
        """
        matmul_children = self.model.get_children(node, input_name_to_nodes)

        # Should only have 1 child - Bias Add
        if len(matmul_children) != 1 or matmul_children[0].op_type != "Add":
            return

        bias_add_node = matmul_children[0]

        # At least one of the inputs to Bias Add node must be a constant
        bias_add_node_index = 0
        if (
            self.model.get_constant_value(bias_add_node.input[0]) is None
            and self.model.get_constant_value(bias_add_node.input[1]) is None
        ):
            return

        # Track which Add input holds the constant bias (0 or 1).
        if self.model.get_constant_value(bias_add_node.input[0]) is None:
            bias_add_node_index = 1

        bias_add_children = self.model.get_children(bias_add_node, input_name_to_nodes)

        if len(bias_add_children) != 1:
            return

        bias_add_child = bias_add_children[0]

        # Bias Add can have another Add downstream (Residual Add layer)
        residual_add_node = None

        downstream_quantize_node = None

        if bias_add_child.op_type == "Add":
            residual_add_node = bias_add_child

            residual_add_children = self.model.get_children(residual_add_node, input_name_to_nodes)

            # The residual Add must feed exactly one QuantizeLinear.
            if len(residual_add_children) != 1 or residual_add_children[0].op_type != "QuantizeLinear":
                return

            downstream_quantize_node = residual_add_children[0]

        elif bias_add_child.op_type == "QuantizeLinear":
            downstream_quantize_node = bias_add_child

        else:
            return

        # Make sure the downstream QuantizeLinear has the proper zero points and scales
        if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
            return

        # The first input to MatMul should flow through a DequantizeLinear node
        first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear"], [0])],
            output_name_to_node,
        )

        # If Attention is not fused, this is the pattern to look for
        # leading up to the MatMul
        reshape_node_0 = None
        transpose_node_0 = None
        if first_path_id < 0:
            # Fallback pattern: Reshape -> Transpose -> DQ -> Q feeding input 0.
            first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
                node,
                [(["Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear"], [0, 0, 0, 0])],
                output_name_to_node,
            )

            if first_path_id < 0:
                return

            reshape_node_0 = first_input_parent_nodes[0]
            transpose_node_0 = first_input_parent_nodes[1]
            dequantize_node_0 = first_input_parent_nodes[2]
        else:
            dequantize_node_0 = first_input_parent_nodes[0]

        # Make sure the upstream DequantizeLinear-0 has the proper zero points and scales
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_0, self.model):
            return

        # The second input to MatMul should flow through a DequantizeLinear node
        dequantize_node_1 = None
        is_weight_transpose_required = True

        # Preferred weight pattern: DQ -> Q -> Transpose -> DQ (already transposed,
        # so no offline transpose is needed in this case).
        weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
            node,
            [(["DequantizeLinear", "QuantizeLinear", "Transpose", "DequantizeLinear"], [1, 0, 0, 0])],
            output_name_to_node,
        )

        if weight_path_id < 0:
            # Fallback: a lone DequantizeLinear feeding input 1; the weight will
            # need an offline transpose below.
            weight_path_id, weight_nodes, _ = self.model.match_parent_paths(
                node,
                [(["DequantizeLinear"], [1])],
                output_name_to_node,
            )

            if weight_path_id < 0:
                return

            dequantize_node_1 = weight_nodes[0]
        else:
            is_weight_transpose_required = False
            dequantize_node_1 = weight_nodes[3]

        # Check if weight 'B' is a constant
        if self.model.get_constant_value(dequantize_node_1.input[0]) is None:
            return

        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
        # Per-channel scales are supported for weights alone
        if not FusionUtils.check_qdq_node_for_fusion(dequantize_node_1, self.model, False):
            return

        # Make sure the upstream flow into the Residual Add node flows through a DQ node
        residual_add_dequantize_node = None

        if residual_add_node is not None:
            residual_path_id, residual_input_parent_nodes, _ = self.model.match_parent_paths(
                residual_add_node,
                [
                    (["DequantizeLinear"], [1]),
                ],
                output_name_to_node,
            )

            if residual_path_id < 0:
                return

            residual_add_dequantize_node = residual_input_parent_nodes[0]

        # Make sure the upstream DequantizeLinear to the Residual Add has the proper zero points and scales
        if residual_add_dequantize_node is not None and not FusionUtils.check_qdq_node_for_fusion(
            residual_add_dequantize_node, self.model
        ):
            return

        # Subgraph nodes to be fused
        subgraph_nodes = [node, bias_add_node]  # MatMul + Bias Add

        if residual_add_node is not None:
            subgraph_nodes.extend([residual_add_node])  # Residual Add

        subgraph_nodes.extend(weight_nodes)
        subgraph_nodes.extend([downstream_quantize_node])  # Downstream Q node

        # Abort if removing the subgraph would orphan outputs consumed elsewhere.
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, downstream_quantize_node.output, input_name_to_nodes, output_name_to_node
        ):
            logger.debug("It is not safe to fuse QOrderedMatMul node. Skip")
            return

        # Deal with the case where-in the Attention subgraph is not fused:
        # re-wire the Transpose to consume the quantized input directly.
        if transpose_node_0 is not None:
            self.model.replace_node_input(transpose_node_0, transpose_node_0.input[0], dequantize_node_0.input[0])

        # Make inputs: [A, scale_A, B, scale_B, scale_Y, bias, (residual, scale_residual)]
        fused_node_inputs = [
            reshape_node_0.output[0] if reshape_node_0 is not None else dequantize_node_0.input[0],
            dequantize_node_0.input[1],
            dequantize_node_1.input[0],
            dequantize_node_1.input[1],
            downstream_quantize_node.input[1],
            bias_add_node.input[bias_add_node_index],
        ]

        if residual_add_node is not None:
            fused_node_inputs.append(residual_add_dequantize_node.input[0])
            fused_node_inputs.append(residual_add_dequantize_node.input[1])

        # The MatMul weight 'B' and 'bias' need some post-processing
        # Transpose weight 'B' from order ROW to order COL
        # This offline transpose is needed only while using the CUDA EP
        # TODO: Make this fusion logic EP-agnostic ?
        # NOTE: this mutates the weight initializer in place.
        if is_weight_transpose_required:
            weight_tensor = self.model.get_initializer(dequantize_node_1.input[0])
            FusionUtils.transpose_2d_int8_tensor(weight_tensor)

        fused_node = helper.make_node(
            "QOrderedMatMul",
            inputs=fused_node_inputs,
            outputs=[downstream_quantize_node.output[0]],
            name=self.model.create_node_name("QOrderedMatMul", name_prefix="QOrderedMatMul"),
        )

        # Data ordering attributes for the contrib op (1 = ROW, 0 = COL order
        # — presumably; verify against the QOrderedMatMul kernel spec).
        fused_node.attribute.extend([helper.make_attribute("order_A", 1)])
        fused_node.attribute.extend([helper.make_attribute("order_B", 0)])
        fused_node.attribute.extend([helper.make_attribute("order_Y", 1)])

        fused_node.domain = "com.microsoft"

        self.nodes_to_remove.extend(subgraph_nodes)
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
@@ -0,0 +1,74 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import logging
7
+
8
+ from fusion_base import Fusion
9
+ from onnx import helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class FusionQuickGelu(Fusion):
    """Fuse the QuickGelu pattern `x * Sigmoid(x * alpha)` (alpha ~= 1.702)
    into a single com.microsoft `QuickGelu` node."""

    def __init__(self, model: OnnxModel):
        # Fusion is anchored at "Mul" nodes (the outermost Mul of the pattern).
        super().__init__(model, "QuickGelu", ["Mul"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the subgraph ending at `node` into `QuickGelu`.

        Returns without modification when the pattern does not match; on
        success, queues the three matched nodes for removal and the fused
        node for insertion.
        """
        # Fuse the following subgraph to `QuickGelu`
        #
        #    root_input
        #     /    \
        #    |     Mul ----+
        #    |  (B = ~1.702) |
        #    \      |       |
        #     \   Sigmoid   |---- `QuickGelu`
        #      \   /        |
        #       \ /         |
        #       Mul ----+
        #        |
        #    root_output

        if node.op_type != "Mul":
            logger.debug("fuse_quickgelu: failed to match second Mul node")
            return

        second_mul_node = node
        root_input = second_mul_node.input[0]

        sigmoid_node = self.model.match_parent_path(second_mul_node, ["Sigmoid"], [1])
        if sigmoid_node is None:
            logger.debug("fuse_quickgelu: failed to match Sigmoid node")
            return
        sigmoid_node = sigmoid_node[0]

        first_mul_node = self.model.match_parent_path(sigmoid_node, ["Mul"], [0])
        if first_mul_node is None:
            logger.debug("fuse_quickgelu: failed to match first Mul node")
            return
        first_mul_node = first_mul_node[0]

        # Bug fix: get_constant_value returns None when the input is not a
        # constant/initializer; the previous code called .item() on it directly,
        # which raised AttributeError instead of skipping the fusion.
        approximation_constant = self.model.get_constant_value(first_mul_node.input[1])
        if approximation_constant is None:
            logger.debug("fuse_quickgelu: multiplier of first Mul node is not a constant")
            return

        approximation_value = approximation_constant.item()
        # 1.7021484375 is the fp16-representable value of QuickGelu's alpha.
        if abs(approximation_value - 1.7021484375) >= 1e-3:
            logger.debug("fuse_quickgelu: failed to match approximation value")
            return

        # Both Muls must share the same root input for the pattern to be valid.
        if first_mul_node.input[0] != root_input:
            logger.debug("fuse_quickgelu: failed to match root input with first Mul node's input")
            return

        new_node = helper.make_node(
            "QuickGelu",
            inputs=[root_input],
            outputs=[second_mul_node.output[0]],
            name=self.model.create_node_name("QuickGelu"),
        )
        new_node.domain = "com.microsoft"
        # Preserve the exact alpha found in the graph rather than a hard-coded one.
        new_node.attribute.extend([helper.make_attribute("alpha", approximation_value)])

        self.nodes_to_remove.extend([first_mul_node, sigmoid_node, second_mul_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.increase_counter("QuickGelu")
@@ -0,0 +1,173 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+
8
+ import numpy as np
9
+ from fusion_base import Fusion
10
+ from onnx import TensorProto, helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionReshape(Fusion):
    """Replace a Reshape whose target shape is computed at run time from the input's
    own dimensions (Shape -> Gather -> Unsqueeze -> ... -> Concat) with a constant
    shape initializer.

    Dimensions copied verbatim from the input become 0 in the constant (Reshape's
    "keep this dim" convention), and a dimension produced by a Mul/Div over input
    dims becomes -1 (inferred), so the replacement stays valid for any input size.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "Reshape", "Reshape")
        # Set once any fusion happens: the now-dangling shape-computation nodes
        # should then be pruned from the graph (see TODO at the end of fuse()).
        self.prune_graph: bool = False

    def replace_reshape_node(self, shape, reshape_node, concat_node):
        """Rewire `reshape_node` to read its shape (list of ints) from a new Constant
        node instead of `concat_node`'s output, and schedule `concat_node` for removal.
        """
        shape_value = np.asarray(shape, dtype=np.int64)
        constant_shape_name = self.model.create_node_name("Constant", "constant_shape")
        new_node = helper.make_node(
            "Constant",
            inputs=[],
            outputs=[constant_shape_name],
            value=helper.make_tensor(
                name="const_tensor",
                data_type=TensorProto.INT64,
                dims=shape_value.shape,
                vals=bytes(shape_value),  # raw little-endian int64 bytes
                raw=True,
            ),
        )
        reshape_node.input[1] = constant_shape_name
        reshape_node.name = self.model.create_node_name("Reshape", "Reshape_Fuse")
        self.nodes_to_remove.extend([concat_node])
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

    def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
        """Try to fold the shape-producing subgraph of one Reshape node.

        Only matches a Concat of 3 or 4 components where the first two are the
        input's dims 0 and 1 (Shape -> Gather(0/1) -> Unsqueeze) and any remaining
        component is either a compile-time constant or a Mul/Div over input dims.
        Returns silently when the pattern does not match.
        """
        # The shape input must be produced by a node (not an initializer/graph input).
        if reshape_node.input[1] not in output_name_to_node:
            return

        concat_node = output_name_to_node[reshape_node.input[1]]
        if concat_node.op_type != "Concat" or len(concat_node.input) < 3 or len(concat_node.input) > 4:
            return

        # Component 0 must be dim 0 of some tensor: Shape -> Gather -> Unsqueeze.
        path0 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [0, 0, 0],
            output_name_to_node,
        )
        if path0 is None:
            return

        (unsqueeze_0, gather_0, shape_0) = path0

        # Component 1 must likewise be a gathered dim.
        path1 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [1, 0, 0],
            output_name_to_node,
        )
        if path1 is None:
            return
        (unsqueeze_1, gather_1, shape_1) = path1

        shape = []
        # Gather index 0 means "dim 0 of the input" -> encode as Reshape's 0 (copy dim).
        gather_value = self.model.get_constant_value(gather_0.input[1])
        if gather_value == 0:
            shape.append(0)

        # Gather index 1 means "dim 1 of the input" -> also encoded as copy-dim 0.
        gather_value = self.model.get_constant_value(gather_1.input[1])
        if gather_value == 1:
            shape.append(0)

        # Both leading dims must have matched (shape == [0, 0]) to proceed.
        if len(shape) != 2:
            return

        path2 = []
        path3 = []
        # Track every Shape node used, to verify later that they all read the
        # same tensor the Reshape itself consumes.
        shape_nodes = [shape_0, shape_1]
        # Case: 3 components and the third is dynamic -> must be a product of two
        # input dims (e.g. batch*seq); encode it as -1 (inferred).
        if len(concat_node.input) == 3 and self.model.get_constant_value(concat_node.input[2]) is None:
            path2 = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Mul", "Gather", "Shape"],
                [2, 0, 0, 0],
                output_name_to_node,
            )
            if path2 is None:
                path2 = self.model.match_parent_path(
                    concat_node,
                    ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"],
                    [2, 0, 0, 0, 0],
                    output_name_to_node,
                )  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path2 is None:
                    return

            # Second operand of the Mul must also come from the input's shape.
            path3 = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Mul", "Gather", "Shape"],
                [2, 0, 1, 0],
                output_name_to_node,
            )
            if path3 is None:
                path3 = self.model.match_parent_path(
                    concat_node,
                    ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"],
                    [2, 0, 1, 0, 0],
                    output_name_to_node,
                )  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path3 is None:
                    return

            shape_nodes.extend([path2[-1], path3[-1]])
            shape.append(-1)
        # Case: third component is a compile-time constant -> copy it directly.
        elif len(concat_node.input) > 2:
            concat_value = self.model.get_constant_value(concat_node.input[2])
            if concat_value is None:
                return
            if isinstance(concat_value, np.ndarray):
                shape.extend(concat_value.tolist())
            else:
                shape.append(concat_value)

        # Case: 4 components and the fourth is dynamic -> must be a quotient of an
        # input dim (e.g. hidden/num_heads); also encoded as -1.
        if len(concat_node.input) == 4 and self.model.get_constant_value(concat_node.input[3]) is None:
            # Reshape allows at most one -1; bail out if one was already emitted.
            if -1 in shape:
                return

            path2 = self.model.match_parent_path(
                concat_node,
                ["Unsqueeze", "Div", "Gather", "Shape"],
                [3, 0, 0, 0],
                output_name_to_node,
            )
            if path2 is None:
                path2 = self.model.match_parent_path(
                    concat_node,
                    ["Unsqueeze", "Div", "Squeeze", "Slice", "Shape"],
                    [3, 0, 0, 0, 0],
                    output_name_to_node,
                )  # GPT2 exported by PyTorch 1.4 with opset_version=11
                if path2 is None:
                    return
            shape_nodes.extend([path2[-1]])
            shape.append(-1)
        # Case: fourth component is a compile-time constant -> copy it directly.
        elif len(concat_node.input) > 3:
            concat_value = self.model.get_constant_value(concat_node.input[3])
            if concat_value is None:
                return

            if isinstance(concat_value, np.ndarray):
                shape.extend(concat_value.tolist())
            else:
                shape.append(concat_value)

        # The fusion is only sound when every matched Shape node reads the very
        # tensor being reshaped; otherwise 0/-1 would refer to the wrong dims.
        root_input = reshape_node.input[0]
        same_shape_input = True
        for shape_node in shape_nodes:
            if shape_node.input[0] != root_input:
                same_shape_input = False

        if not same_shape_input:
            return

        self.replace_reshape_node(shape, reshape_node, concat_node)

        # TODO(tlwu): Subgraph blocks pruning un-used nodes. Add code to remove un-used nodes safely.
        self.prune_graph = True