onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_nhwc_conv.py
@@ -0,0 +1,99 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+from fusion_base import Fusion
+from fusion_utils import FusionUtils
+from onnx import helper, numpy_helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionNhwcConv(Fusion):
+    """Convert Conv to NhwcConv"""
+
+    def __init__(self, model: OnnxModel, update_weight=False):
+        super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv")
+        self.update_weight = update_weight
+        self.fusion_utils = FusionUtils(model)
+
+    def create_transpose_node(self, input_name: str, perm: list[int], output_name=None):
+        """Append a Transpose node after an input"""
+        node_name = self.model.create_node_name("Transpose")
+
+        if output_name is None:
+            output_name = node_name + "_out" + "-" + input_name
+
+        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
+        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
+
+        return transpose_node
+
+    def fuse(self, conv, input_name_to_nodes, output_name_to_node):
+        # Add a Transpose node to convert the input from NCHW to NHWC.
+        input_transpose_node = self.create_transpose_node(conv.input[0], [0, 2, 3, 1])
+
+        nhwc_conv_input = input_transpose_node.output[0]
+
+        # Create a tensor for transposed weights (already in NHWC format).
+        node_name = self.model.create_node_name("NhwcConv")
+
+        # Make sure the weights are 4D.
+        weight_tensor = self.model.get_initializer(conv.input[1])
+        if weight_tensor is None:
+            return
+        weight = numpy_helper.to_array(weight_tensor)
+        if len(weight.shape) != 4:
+            return
+
+        dtype = self.model.get_dtype(nhwc_conv_input)
+        if not (dtype is not None and weight_tensor.data_type == dtype):
+            cast_node = self.fusion_utils.add_cast_node(
+                input_name=nhwc_conv_input,
+                to_type=weight_tensor.data_type,
+                output_name_to_node=output_name_to_node,
+            )
+            nhwc_conv_input = cast_node.output[0]
+
+        if self.update_weight:
+            # Transpose weights from NCHW to NHWC.
+            weight = weight.transpose(0, 2, 3, 1)
+
+            weight_name = node_name + "_weight_NHWC"
+            self.add_initializer(
+                name=weight_name,
+                data_type=weight_tensor.data_type,
+                dims=list(weight.shape),
+                vals=weight,
+            )
+            weight_transpose_node = None
+        else:
+            weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
+            weight_name = weight_transpose_node.output[0]
+
+        nhwc_output_name = node_name + "_out" + "-" + conv.output[0]
+        nhwc_conv = helper.make_node(
+            "NhwcConv",
+            inputs=[nhwc_conv_input, weight_name, *conv.input[2:]],
+            outputs=[nhwc_output_name],
+            name=node_name + "-" + conv.name,
+        )
+        nhwc_conv.attribute.extend(conv.attribute)
+        nhwc_conv.domain = "com.microsoft"
+
+        output_transpose_node = self.create_transpose_node(nhwc_conv.output[0], [0, 3, 1, 2], conv.output[0])
+
+        self.nodes_to_remove.append(conv)
+
+        nodes_to_add = [input_transpose_node, nhwc_conv, output_transpose_node]
+        if weight_transpose_node:
+            nodes_to_add.append(weight_transpose_node)
+        for node in nodes_to_add:
+            self.node_name_to_graph_name[node.name] = self.this_graph_name
+        self.nodes_to_add.extend(nodes_to_add)
+
+        self.increase_counter("NhwcConv")
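
For context, `fuse` above replaces each eligible `Conv` with a three-node pattern. A minimal, self-contained sketch of that pattern (not taken from the package; the tensor names `X`, `W`, `Y` are illustrative):

```python
from onnx import helper

# NCHW input -> NHWC, so the convolution can run in channels-last layout
pre = helper.make_node("Transpose", inputs=["X"], outputs=["X_nhwc"], perm=[0, 2, 3, 1])

# NhwcConv is a com.microsoft contrib op; the fusion copies the original Conv
# attributes (pads, strides, dilations, group) onto it via attribute.extend(...)
conv = helper.make_node(
    "NhwcConv", inputs=["X_nhwc", "W"], outputs=["Y_nhwc"], domain="com.microsoft"
)

# NHWC output -> back to NCHW, reusing the original Conv output name so
# downstream nodes are untouched
post = helper.make_node("Transpose", inputs=["Y_nhwc"], outputs=["Y"], perm=[0, 3, 1, 2])
```

With `update_weight=True` the weight initializer is transposed to NHWC once at optimization time; otherwise a fourth `Transpose` node is inserted on the weight input and the layout change happens at inference time.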
onnxruntime/transformers/fusion_options.py
@@ -0,0 +1,340 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from argparse import ArgumentParser
+from enum import Enum
+
+
+class AttentionMaskFormat:
+    # Build a 1D mask index (sequence length). It requires right-side padding! Recommended for BERT models for best performance.
+    MaskIndexEnd = 0
+
+    # For experiment only. Do not use it in production.
+    MaskIndexEndAndStart = 1
+
+    # Raw attention mask with 0 meaning padding (or no attention) and 1 otherwise.
+    AttentionMask = 2
+
+    # No attention mask
+    NoMask = 3
+
+
+class AttentionOpType(Enum):
+    Attention = "Attention"
+    MultiHeadAttention = "MultiHeadAttention"
+    GroupQueryAttention = "GroupQueryAttention"
+    PagedAttention = "PagedAttention"
+
+    def __str__(self):
+        return self.value
+
+    # Override __hash__ and __eq__ so that members compare by string value.
+    def __hash__(self):
+        return hash(self.value)
+
+    def __eq__(self, other):
+        return other.value == self.value
+
+
+class FusionOptions:
+    """Options for fusion in graph optimization"""
+
+    def __init__(self, model_type):
+        self.enable_gelu = True
+        self.enable_layer_norm = True
+        self.enable_attention = True
+        self.enable_rotary_embeddings = True
+
+        # Use MultiHeadAttention instead of the Attention operator. The differences:
+        # (1) Attention has merged weights for the Q/K/V projections, which might be faster in some cases since
+        #     three MatMuls are merged into one.
+        # (2) Attention can only handle self-attention; MultiHeadAttention can handle both self- and cross-attention.
+        self.use_multi_head_attention = False
+        self.disable_multi_head_attention_bias = False
+
+        self.enable_skip_layer_norm = True
+        self.enable_embed_layer_norm = True
+        self.enable_bias_skip_layer_norm = True
+        self.enable_bias_gelu = True
+        self.enable_gelu_approximation = False
+        self.enable_qordered_matmul = True
+
+        self.enable_shape_inference = True
+        self.enable_gemm_fast_gelu = False
+        self.group_norm_channels_last = True
+
+        if model_type == "clip":
+            self.enable_embed_layer_norm = False
+
+        # Default to the mask index (sequence length) format for BERT models so that fused attention can speed them up.
+        # Note that embed layer normalization will convert a 2D mask to 1D when the mask type is MaskIndexEnd.
+        self.attention_mask_format = AttentionMaskFormat.AttentionMask
+        if model_type == "bert":
+            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
+        elif model_type == "vit":
+            self.attention_mask_format = AttentionMaskFormat.NoMask
+
+        self.attention_op_type = None
+
+        # Options for Stable Diffusion
+        if model_type in ["unet", "vae", "clip"]:
+            self.enable_nhwc_conv = True
+            self.enable_group_norm = True
+            self.enable_skip_group_norm = True
+            self.enable_bias_splitgelu = True
+            self.enable_packed_qkv = True
+            self.enable_packed_kv = True
+            self.enable_bias_add = True
+
+    def use_raw_attention_mask(self, use_raw_mask=True):
+        if use_raw_mask:
+            self.attention_mask_format = AttentionMaskFormat.AttentionMask
+        else:
+            self.attention_mask_format = AttentionMaskFormat.MaskIndexEnd
+
+    def disable_attention_mask(self):
+        self.attention_mask_format = AttentionMaskFormat.NoMask
+
+    def set_attention_op_type(self, attn_op_type: AttentionOpType):
+        self.attention_op_type = attn_op_type
+
+    @staticmethod
+    def parse(args):
+        options = FusionOptions(args.model_type)
+        if args.disable_gelu:
+            options.enable_gelu = False
+        if args.disable_layer_norm:
+            options.enable_layer_norm = False
+        if args.disable_rotary_embeddings:
+            options.enable_rotary_embeddings = False
+        if args.disable_attention:
+            options.enable_attention = False
+        if args.use_multi_head_attention:
+            options.use_multi_head_attention = True
+        if args.disable_skip_layer_norm:
+            options.enable_skip_layer_norm = False
+        if args.disable_embed_layer_norm:
+            options.enable_embed_layer_norm = False
+        if args.disable_bias_skip_layer_norm:
+            options.enable_bias_skip_layer_norm = False
+        if args.disable_bias_gelu:
+            options.enable_bias_gelu = False
+        if args.enable_gelu_approximation:
+            options.enable_gelu_approximation = True
+        if args.disable_shape_inference:
+            options.enable_shape_inference = False
+        if args.enable_gemm_fast_gelu:
+            options.enable_gemm_fast_gelu = True
+        if args.use_mask_index:
+            options.use_raw_attention_mask(False)
+        if args.use_raw_attention_mask:
+            options.use_raw_attention_mask(True)
+        if args.no_attention_mask:
+            options.disable_attention_mask()
+
+        if args.model_type in ["unet", "vae", "clip"]:
+            if args.use_group_norm_channels_first:
+                options.group_norm_channels_last = False
+            if args.disable_nhwc_conv:
+                options.enable_nhwc_conv = False
+            if args.disable_group_norm:
+                options.enable_group_norm = False
+            if args.disable_skip_group_norm:
+                options.enable_skip_group_norm = False
+            if args.disable_bias_splitgelu:
+                options.enable_bias_splitgelu = False
+            if args.disable_packed_qkv:
+                options.enable_packed_qkv = False
+            if args.disable_packed_kv:
+                options.enable_packed_kv = False
+            if args.disable_bias_add:
+                options.enable_bias_add = False
+
+        return options
+
+    @staticmethod
+    def add_arguments(parser: ArgumentParser):
+        parser.add_argument(
+            "--disable_attention",
+            required=False,
+            action="store_true",
+            help="disable Attention fusion",
+        )
+        parser.set_defaults(disable_attention=False)
+
+        parser.add_argument(
+            "--disable_skip_layer_norm",
+            required=False,
+            action="store_true",
+            help="disable SkipLayerNormalization fusion",
+        )
+        parser.set_defaults(disable_skip_layer_norm=False)
+
+        parser.add_argument(
+            "--disable_embed_layer_norm",
+            required=False,
+            action="store_true",
+            help="disable EmbedLayerNormalization fusion",
+        )
+        parser.set_defaults(disable_embed_layer_norm=False)
+
+        parser.add_argument(
+            "--disable_bias_skip_layer_norm",
+            required=False,
+            action="store_true",
+            help="disable Add Bias and SkipLayerNormalization fusion",
+        )
+        parser.set_defaults(disable_bias_skip_layer_norm=False)
+
+        parser.add_argument(
+            "--disable_bias_gelu",
+            required=False,
+            action="store_true",
+            help="disable Add Bias and Gelu/FastGelu fusion",
+        )
+        parser.set_defaults(disable_bias_gelu=False)
+
+        parser.add_argument(
+            "--disable_layer_norm",
+            required=False,
+            action="store_true",
+            help="disable LayerNormalization fusion",
+        )
+        parser.set_defaults(disable_layer_norm=False)
+
+        parser.add_argument(
+            "--disable_gelu",
+            required=False,
+            action="store_true",
+            help="disable Gelu fusion",
+        )
+        parser.set_defaults(disable_gelu=False)
+
+        parser.add_argument(
+            "--enable_gelu_approximation",
+            required=False,
+            action="store_true",
+            help="enable Gelu/BiasGelu to FastGelu conversion",
+        )
+        parser.set_defaults(enable_gelu_approximation=False)
+
+        parser.add_argument(
+            "--disable_shape_inference",
+            required=False,
+            action="store_true",
+            help="disable symbolic shape inference",
+        )
+        parser.set_defaults(disable_shape_inference=False)
+
+        parser.add_argument(
+            "--enable_gemm_fast_gelu",
+            required=False,
+            action="store_true",
+            help="enable GemmFastGelu fusion",
+        )
+        parser.set_defaults(enable_gemm_fast_gelu=False)
+
+        parser.add_argument(
+            "--use_mask_index",
+            required=False,
+            action="store_true",
+            help="use mask index to activate fused attention to speed up. It requires right-side padding!",
+        )
+        parser.set_defaults(use_mask_index=False)
+
+        parser.add_argument(
+            "--use_raw_attention_mask",
+            required=False,
+            action="store_true",
+            help="use raw attention mask. Use this option if your input is not right-side padded. This might deactivate fused attention and degrade performance.",
+        )
+        parser.set_defaults(use_raw_attention_mask=False)
+
+        parser.add_argument(
+            "--no_attention_mask",
+            required=False,
+            action="store_true",
+            help="no attention mask. Only works for model_type=bert",
+        )
+        parser.set_defaults(no_attention_mask=False)
+
+        parser.add_argument(
+            "--use_multi_head_attention",
+            required=False,
+            action="store_true",
+            help="Use MultiHeadAttention instead of Attention operator for testing purposes. "
+            "Note that MultiHeadAttention might be slower than Attention when qkv are not packed.",
+        )
+        parser.set_defaults(use_multi_head_attention=False)
+
+        parser.add_argument(
+            "--disable_group_norm",
+            required=False,
+            action="store_true",
+            help="do not fuse GroupNorm. Only works for model_type=unet or vae",
+        )
+        parser.set_defaults(disable_group_norm=False)
+
+        parser.add_argument(
+            "--disable_skip_group_norm",
+            required=False,
+            action="store_true",
+            help="do not fuse Add + GroupNorm to SkipGroupNorm. Only works for model_type=unet or vae",
+        )
+        parser.set_defaults(disable_skip_group_norm=False)
+
+        parser.add_argument(
+            "--disable_packed_kv",
+            required=False,
+            action="store_true",
+            help="do not use packed kv for cross attention in MultiHeadAttention. Only works for model_type=unet",
+        )
+        parser.set_defaults(disable_packed_kv=False)
+
+        parser.add_argument(
+            "--disable_packed_qkv",
+            required=False,
+            action="store_true",
+            help="do not use packed qkv for self attention in MultiHeadAttention. Only works for model_type=unet",
+        )
+        parser.set_defaults(disable_packed_qkv=False)
+
+        parser.add_argument(
+            "--disable_bias_add",
+            required=False,
+            action="store_true",
+            help="do not fuse BiasAdd. Only works for model_type=unet",
+        )
+        parser.set_defaults(disable_bias_add=False)
+
+        parser.add_argument(
+            "--disable_bias_splitgelu",
+            required=False,
+            action="store_true",
+            help="do not fuse BiasSplitGelu. Only works for model_type=unet",
+        )
+        parser.set_defaults(disable_bias_splitgelu=False)
+
+        parser.add_argument(
+            "--disable_nhwc_conv",
+            required=False,
+            action="store_true",
+            help="do not use NhwcConv. Only works for model_type=unet or vae",
+        )
+        parser.set_defaults(disable_nhwc_conv=False)
+
+        parser.add_argument(
+            "--use_group_norm_channels_first",
+            required=False,
+            action="store_true",
+            help="use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae",
+        )
+        parser.set_defaults(use_group_norm_channels_first=False)
+
+        parser.add_argument(
+            "--disable_rotary_embeddings",
+            required=False,
+            action="store_true",
+            help="do not fuse rotary embeddings into RotaryEmbedding op",
+        )
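
`FusionOptions` is consumed by the graph optimizer rather than used standalone. A short usage sketch, assuming the documented `optimize_model` entry point in `onnxruntime.transformers.optimizer` (the file paths and the head/hidden sizes are illustrative):

```python
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

opts = FusionOptions("bert")
opts.enable_gelu_approximation = True  # same effect as the --enable_gelu_approximation flag
opts.use_raw_attention_mask(False)     # same as --use_mask_index: 1D mask index, requires right-side padding

# num_heads/hidden_size must match the model; 12/768 is BERT-base
opt_model = optimizer.optimize_model(
    "bert-base-cased.onnx",
    model_type="bert",
    num_heads=12,
    hidden_size=768,
    optimization_options=opts,
)
opt_model.save_model_to_file("bert-base-cased_opt.onnx")
```

The static `add_arguments`/`parse` pair exists so the same switches can be driven from a command line: `add_arguments` registers the flags on an `ArgumentParser`, and `parse` maps the parsed namespace back onto a `FusionOptions` instance.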