onnxruntime-directml 1.24.1 (onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/onnxruntime/transformers/past_helper.py
@@ -0,0 +1,149 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import logging
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ class PastKeyValuesHelper:
+     """Helper functions to process past key values for encoder-decoder models."""
+
+     @staticmethod
+     def get_past_names(num_layers, present: bool = False):
+         past_self_names = []
+         past_cross_names = []
+         for i in range(num_layers):
+             past_self_names.extend(
+                 [f"present_key_self_{i}", f"present_value_self_{i}"]
+                 if present
+                 else [f"past_key_self_{i}", f"past_value_self_{i}"]
+             )
+             past_cross_names.extend(
+                 [f"present_key_cross_{i}", f"present_value_cross_{i}"]
+                 if present
+                 else [f"past_key_cross_{i}", f"past_value_cross_{i}"]
+             )
+         return past_self_names + past_cross_names
+
+     @staticmethod
+     def group_by_self_or_cross(present_key_values):
+         """Split present state from grouped by layer to grouped by self/cross attention.
+         Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+         After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+         """
+         present_self = []
+         present_cross = []
+         for present_layer_i in present_key_values:
+             assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+             (
+                 present_key_self,
+                 present_value_self,
+                 present_key_cross,
+                 present_value_cross,
+             ) = present_layer_i
+             present_self.extend([present_key_self, present_value_self])
+             present_cross.extend([present_key_cross, present_value_cross])
+         return present_self, present_cross
+
+     @staticmethod
+     def group_by_layer(past, num_layers):
+         """Reorder past state from grouped by self/cross attention to grouped by layer.
+         Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+         After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+         """
+         assert len(past) == 4 * num_layers
+         return tuple(
+             [
+                 past[2 * i],
+                 past[2 * i + 1],
+                 past[2 * num_layers + 2 * i],
+                 past[2 * num_layers + 2 * i + 1],
+             ]
+             for i in range(num_layers)
+         )
+
+     @staticmethod
+     def back_group_by_layer(past_key_values: tuple[tuple[torch.Tensor]]):
+         """Categorize present_key_values from self and cross attention to layer by layer.
+
+         Reorder past state from grouped by self/cross attention to grouped by layer.
+         Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
+                 past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+         After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+                (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+
+         Args:
+             past_key_values: From past_key_values of a model (grouped by self and cross attention)
+
+         Returns:
+             past_tuples: present key and values grouped by layer.
+         """
+         past_tuples = ()
+         half_idx = len(past_key_values) // 2
+         for i in range(len(past_key_values) // 4):
+             idx = 2 * i
+             past_tuples += (
+                 (
+                     past_key_values[idx],
+                     past_key_values[idx + 1],
+                     past_key_values[half_idx + idx],
+                     past_key_values[half_idx + idx + 1],
+                 ),
+             )
+         return past_tuples
+
+     @staticmethod
+     def group_by_self_and_cross(present_key_values: tuple[torch.Tensor], concat: bool = False):
+         """Categorize present_key_values into self and cross attention.
+
+         Split present state from grouped by layer to grouped by self/cross attention.
+         Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+                 (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+         After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
+                (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+
+         Args:
+             present_key_values: From past_key_values of a model (grouped by layer)
+             concat: If True, return self attention and cross attention key/values concatenated into one list
+
+         Returns:
+             present_self (Tuple[torch.Tensor]): present key and values from self attention
+             present_cross (Tuple[torch.Tensor]): present key and values from cross attention
+         """
+         present_self: list[torch.Tensor] = []
+         present_cross: list[torch.Tensor] = []
+         for present_layer_i in present_key_values:
+             assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+             present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
+             present_self.extend([present_key_self, present_value_self])
+             present_cross.extend([present_key_cross, present_value_cross])
+         if concat:
+             return present_self + present_cross
+         else:
+             return present_self, present_cross
+
+     @staticmethod
+     def get_input_names(past_key_values: tuple[tuple[torch.Tensor]], encoder=True):
+         """Process input names of the model wrapper.
+
+         Args:
+             past_key_values: Consider `self` and `cross` past_key_values
+
+         Returns:
+             names (List[str]): input names
+         """
+         names = []
+         num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
+         prefix = "past_" if not encoder else "present_"
+         for i in range(num_layers):
+             names.extend([prefix + s for s in [f"key_self_{i}", f"value_self_{i}"]])
+         for i in range(num_layers):
+             names.extend([prefix + s for s in [f"key_cross_{i}", f"value_cross_{i}"]])
+         return names
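
The diff above adds PastKeyValuesHelper. For orientation, here is a minimal sketch (not part of the package) of how the helpers round-trip between the per-layer layout and the flat self/cross layout; the layer count, tensor shape, and the assumption that the wheel exposes the module as onnxruntime.transformers.past_helper are placeholders for illustration:

    import torch

    from onnxruntime.transformers.past_helper import PastKeyValuesHelper

    num_layers = 2

    def make():
        # Placeholder shape: (batch, heads, sequence, head_size).
        return torch.zeros(1, 4, 3, 8)

    # Grouped by layer, as a decoder returns them:
    # one (key_self, value_self, key_cross, value_cross) tuple per layer.
    by_layer = tuple((make(), make(), make(), make()) for _ in range(num_layers))

    # Flatten to all self-attention tensors followed by all cross-attention tensors.
    flat = PastKeyValuesHelper.group_by_self_and_cross(by_layer, concat=True)

    # ... and back to per-layer 4-tuples.
    assert len(PastKeyValuesHelper.back_group_by_layer(tuple(flat))) == num_layers

    # Matching ONNX input names for the flat layout.
    print(PastKeyValuesHelper.get_past_names(num_layers))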
--- /dev/null
+++ b/onnxruntime/transformers/profile_result_processor.py
@@ -0,0 +1,358 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ """This profile result processor prints out the kernel time spent on each node of the model.
+ Example of processing a profile result file from onnxruntime_perf_test:
+     python profile_result_processor.py --input profile_2021-10-25_12-02-41.json
+ """
+
+ import argparse
+ import json
+
+ _NODES_TYPE_CONTAINING_SUBGRAPH = frozenset(("Scan", "Loop", "If"))
+
+
+ def parse_arguments(argv=None):
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-i",
+         "--input",
+         required=False,
+         type=str,
+         help="Set the input file for reading the profile results",
+     )
+
+     parser.add_argument(
+         "--threshold",
+         required=False,
+         type=float,
+         default=0.01,
+         help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.",
+     )
+
+     parser.add_argument(
+         "--provider",
+         required=False,
+         type=str,
+         default="cuda",
+         help="Execution provider to use",
+     )
+
+     parser.add_argument(
+         "--kernel_time_only",
+         required=False,
+         action="store_true",
+         help="Only include the kernel time and no fence time",
+     )
+
+     parser.set_defaults(kernel_time_only=False)
+
+     parser.add_argument("-v", "--verbose", required=False, action="store_true")
+     parser.set_defaults(verbose=False)
+
+     return parser.parse_args(argv)
+
+
+ def load_profile_json(profile_file):
+     print(f"loading profile output {profile_file} ...")
+
+     with open(profile_file) as opened_file:
+         sess_time = json.load(opened_file)
+
+     assert isinstance(sess_time, list)
+     return sess_time
+
+
+ def parse_kernel_results(sess_time, threshold=0):
+     """Parse profile data and output kernels in two sections - top expensive kernels, and kernel time grouped by operator.
+
+     Args:
+         sess_time (List[Dict]): profile data
+         threshold (float, optional): Minimum ratio of duration among all. Defaults to 0.
+
+     Returns:
+         List[str]: lines of string for output.
+     """
+     kernel_name_to_op_name = {}
+     kernel_time = {}
+     kernel_freq = {}
+     total = 0
+     session_init = False
+     for item in sess_time:
+         # Skip all MemcpyHostToDevice before session_initialization
+         if item["cat"] == "Session" and item["name"] == "session_initialization":
+             session_init = True
+         if not session_init:
+             continue
+
+         if item["cat"] == "Kernel" and "dur" in item and "args" in item and "op_name" in item["args"]:
+             kernel_name = item["name"]
+
+             op_name = item["args"]["op_name"]
+             if op_name in _NODES_TYPE_CONTAINING_SUBGRAPH:
+                 continue
+
+             # Handle MemcpyHostToDevice and MemcpyDeviceToHost here
+             if not op_name:
+                 op_name = f"({kernel_name})"
+
+             if kernel_name in kernel_time:
+                 kernel_time[kernel_name] += item["dur"]
+                 kernel_freq[kernel_name] += 1
+             else:
+                 kernel_time[kernel_name] = item["dur"]
+                 kernel_freq[kernel_name] = 1
+                 kernel_name_to_op_name[kernel_name] = op_name
+
+             total += item["dur"]
+
+     if not kernel_time:
+         return ["No kernel record found!"]
+
+     # Output items with run time ratio > threshold, sorted by duration in descending order.
+     lines = []
+     lines.append(f"\nTop expensive kernels with Time% >= {threshold * 100:.2f}:")
+     lines.append("-" * 64)
+     lines.append("Total(μs)\tTime%\tCalls\tAvg(μs)\tKernel")
+     for kernel_name, duration in sorted(kernel_time.items(), key=lambda x: x[1], reverse=True):
+         ratio = duration / total
+         if ratio < threshold:
+             continue
+
+         calls = kernel_freq[kernel_name]
+         avg_time = duration / float(calls)
+         lines.append(f"{duration:10d}\t{ratio * 100.0:5.2f}\t{calls:5d}\t{avg_time:8.1f}\t{kernel_name}")
+
+     # Group by operator
+     op_time = {}
+     for kernel_name, op_name in kernel_name_to_op_name.items():
+         duration = kernel_time[kernel_name]
+         if op_name in op_time:
+             op_time[op_name] += duration
+         else:
+             op_time[op_name] = duration
+
+     lines.append("\nGroup kernel time by operator:")
+     lines.append("-" * 64)
+     lines.append("Total(μs)\tTime%\tOperator")
+     for op_name, duration in sorted(op_time.items(), key=lambda x: x[1], reverse=True):
+         ratio = duration / total
+         lines.append(f"{duration:10d}\t{ratio * 100.0:5.2f}\t{op_name}")
+
+     return lines
+
+
+ def parse_node_results(sess_time, kernel_time_only=False, threshold=0):
+     """Parse profile data and output nodes in two sections - nodes in the original order, and top expensive nodes.
+
+     Args:
+         sess_time (List[Dict]): profile data
+         kernel_time_only (bool, optional): Only include items for kernel time. Defaults to False.
+         threshold (float, optional): Minimum ratio of duration among all. Defaults to 0.
+
+     Returns:
+         List[str]: lines of string for output.
+     """
+     node_name_list = []
+     node_time = {}
+     node_freq = {}
+     node_provider = {}
+     total = 0
+     for item in sess_time:
+         if item["cat"] == "Node" and "dur" in item and "args" in item and "op_name" in item["args"]:
+             node_name = (
+                 item["name"].replace("_kernel_time", "").replace("_fence_before", "").replace("_fence_after", "")
+             )
+
+             if "provider" in item["args"]:
+                 if item["args"]["provider"] == "CPUExecutionProvider":
+                     device = "CPU"
+                 elif item["args"]["provider"] == "CUDAExecutionProvider":
+                     device = "CUDA"
+                 elif item["args"]["provider"] == "DmlExecutionProvider":
+                     device = "DML"
+                 else:
+                     # Fall back to the raw provider name so `device` is always bound.
+                     device = item["args"]["provider"]
+
+                 if node_name not in node_provider:
+                     node_provider[node_name] = device
+                 else:
+                     assert node_provider[node_name] == device
+             elif kernel_time_only:
+                 continue
+
+             op_name = item["args"]["op_name"]
+             if op_name in _NODES_TYPE_CONTAINING_SUBGRAPH:
+                 continue
+
+             if node_name in node_time:
+                 node_time[node_name] += item["dur"]
+                 node_freq[node_name] += 1
+             else:
+                 node_time[node_name] = item["dur"]
+                 node_freq[node_name] = 1
+                 node_name_list.append(node_name)
+
+             total += item["dur"]
+
+     # Output items in the original order.
+     lines = [
+         "\nNodes in the original order:",
+         "-" * 64,
+         "Total(μs)\tTime%\tAcc %\tAvg(μs)\tCalls\tProvider\tNode",
+     ]
+     before_percentage = 0.0
+     for node_name in node_name_list:
+         duration = node_time[node_name]
+         calls = node_freq[node_name]
+         avg_time = duration / float(calls)
+         percentage = (duration / total) * 100.0
+         provider = node_provider.get(node_name, "")
+         before_percentage += percentage
+         lines.append(
+             f"{duration:10d}\t{percentage:5.2f}\t{before_percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}"
+         )
+
+     # Output items with run time ratio > threshold, sorted by duration in descending order.
+     lines.append(f"\nTop expensive nodes with Time% >= {threshold * 100:.2f}:")
+     lines.append("-" * 64)
+     lines.append("Total(μs)\tTime%\tAvg(μs)\tCalls\tProvider\tNode")
+     for node_name, duration in sorted(node_time.items(), key=lambda x: x[1], reverse=True):
+         ratio = duration / total
+         if ratio < threshold:
+             continue
+
+         calls = node_freq[node_name]
+         avg_time = duration / float(calls)
+         percentage = (duration / total) * 100.0
+         provider = node_provider.get(node_name, "")
+         lines.append(f"{duration:10d}\t{percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}")
+
+     return lines
+
+
+ def group_node_results(sess_time):
+     """Group results by operator name.
+
+     Args:
+         sess_time (List[Dict]): profile data
+
+     Returns:
+         List[str]: lines of string for output.
+     """
+     op_kernel_time = {}
+     op_kernel_records = {}
+     total_kernel_time = 0
+
+     provider_op_kernel_time = {}
+     provider_op_kernel_records = {}
+     provider_kernel_time = {}
+
+     op_fence_time = {}
+     total_fence_time = 0
+
+     provider_counter = {}
+     for item in sess_time:
+         if item["cat"] == "Node" and "dur" in item and "args" in item and "op_name" in item["args"]:
+             op_name = item["args"]["op_name"]
+
+             # TODO: shall we have a separate group for nodes with subgraphs?
+             if op_name in _NODES_TYPE_CONTAINING_SUBGRAPH:
+                 continue
+
+             if "provider" not in item["args"]:
+                 if "fence" in item["name"]:
+                     if op_name in op_fence_time:
+                         op_fence_time[op_name] += item["dur"]
+                     else:
+                         op_fence_time[op_name] = item["dur"]
+                     total_fence_time += item["dur"]
+                 continue
+
+             provider = item["args"].get("provider", "")
+             if provider in provider_counter:
+                 provider_counter[provider] += 1
+             else:
+                 provider_counter[provider] = 1
+
+             key = f"{provider}:{op_name}"
+             if key in provider_op_kernel_time:
+                 provider_op_kernel_time[key] += item["dur"]
+                 provider_op_kernel_records[key] += 1
+             else:
+                 provider_op_kernel_time[key] = item["dur"]
+                 provider_op_kernel_records[key] = 1
+
+             if provider in provider_kernel_time:
+                 provider_kernel_time[provider] += item["dur"]
+             else:
+                 provider_kernel_time[provider] = item["dur"]
+
+             if op_name in op_kernel_time:
+                 op_kernel_time[op_name] += item["dur"]
+                 op_kernel_records[op_name] += 1
+             else:
+                 op_kernel_time[op_name] = item["dur"]
+                 op_kernel_records[op_name] = 1
+
+             total_kernel_time += item["dur"]
+
+     lines = ["", "Grouped by operator"]
+     lines.append("-" * 64)
+     lines.append("Total(μs)\tTime%\tKernel(μs)\tKernel%\tCalls\tAvgKernel(μs)\tFence(μs)\tOperator")
+     for op_name, kernel_time in sorted(op_kernel_time.items(), key=lambda x: x[1], reverse=True):
+         fence_time = op_fence_time.get(op_name, 0)
+         kernel_time_ratio = kernel_time / total_kernel_time
+         total_time = kernel_time + fence_time
+         time_ratio = total_time / (total_kernel_time + total_fence_time)
+         kernel_calls = op_kernel_records[op_name]
+         avg_kernel_time = kernel_time / kernel_calls
+         lines.append(
+             f"{total_time:10d}\t{time_ratio * 100.0:5.2f}\t{kernel_time:11d}\t{kernel_time_ratio * 100.0:5.2f}\t{kernel_calls:5d}\t{avg_kernel_time:14.1f}\t{fence_time:10d}\t{op_name}"
+         )
+
+     lines += ["", "Grouped by provider + operator"]
+     lines.append("-" * 64)
+     lines.append("Kernel(μs)\tProvider%\tCalls\tAvgKernel(μs)\tProvider\tOperator")
+     for key, kernel_time in sorted(provider_op_kernel_time.items(), key=lambda x: x[1], reverse=True):
+         provider, op_name = key.split(":", 1)
+         short_ep = provider.replace("ExecutionProvider", "")
+         calls = provider_op_kernel_records[key]
+         avg_kernel_time = kernel_time / calls
+         provider_time_ratio = kernel_time / provider_kernel_time[provider]
+         lines.append(
+             f"{kernel_time:10d}\t{provider_time_ratio * 100.0:9.2f}\t{calls:5d}\t{avg_kernel_time:14.1f}\t{short_ep:8s}\t{op_name}"
+         )
+
+     return lines
+
+
+ def process_results(profile_file, args):
+     profile_records = load_profile_json(profile_file)
+
+     lines = parse_kernel_results(profile_records, args.threshold)
+
+     lines += parse_node_results(profile_records, args.kernel_time_only, args.threshold)
+
+     lines += group_node_results(profile_records)
+
+     return lines
+
+
+ if __name__ == "__main__":
+     arguments = parse_arguments()
+     print("Arguments", arguments)
+
+     from benchmark_helper import setup_logger
+
+     setup_logger(arguments.verbose)
+
+     profile_file = arguments.input
+
+     results = process_results(profile_file, arguments)
+
+     for line in results:
+         print(line)
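
For context, a profile JSON that this script consumes can be produced with onnxruntime's built-in profiler. The sketch below is not part of the package; the model path and input shape are placeholder assumptions:

    import numpy as np
    import onnxruntime as ort

    sess_options = ort.SessionOptions()
    sess_options.enable_profiling = True  # write a profile_*.json when profiling ends

    # Placeholder model and input; DmlExecutionProvider matches this DirectML wheel.
    session = ort.InferenceSession("model.onnx", sess_options, providers=["DmlExecutionProvider"])
    input_name = session.get_inputs()[0].name
    session.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})

    profile_file = session.end_profiling()  # returns the JSON file name

The resulting file can then be summarized as the module docstring shows:

    python profile_result_processor.py --input <profile_file>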