onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,146 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import onnx
9
+
10
+ from ..onnx_model import ONNXModel
11
+ from .fusion import Fusion
12
+
13
+
14
class FusionLayerNormalization(Fusion):
    """Fuses a decomposed LayerNormalization subgraph (rooted at a ReduceMean node)
    into a single LayerNormalization node."""

    def __init__(self, model: ONNXModel):
        # Search for fusion opportunities at every "ReduceMean" node; the fused
        # replacement op type is "LayerNormalization".
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def fuse(
        self,
        reduce_mean_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing a ReduceMean node into a single
        LayerNormalization node.

              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
             (axis=2 or -1)      |   (Y=2)   (axis=2 or -1)  (E-6 or E-12 or 0)    ^
                                 |                                                 |
                                 +-------------------------------------------------+

        Or, using Mul instead of Pow:

              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Mul --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
             (axis=2 or -1)      | (in0=in1)  (axis=2 or -1)  (E-6 or E-12 or 0)   ^
                                 |                                                 |
                                 +-------------------------------------------------+

        It also handles cases of duplicated sub nodes exported from older version of PyTorch:

              +----------------------+
              |                      v
              |           +-------> Sub-----------------------------------------------+
              |           |                                                           |
              |           |                                                           v
          [Root] --> ReduceMean --> Sub --> (Pow or Mul) --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+
        """
        # The mean node must feed one or two Sub nodes (two for the duplicated-Sub
        # export pattern shown above).
        children = self.model.get_children(reduce_mean_node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

        root_input = reduce_mean_node.input[0]

        # Each Sub must compute (root - mean), i.e. take the same root tensor as input 0.
        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != "Sub" or children[1].input[0] != root_input:
                return

        # Locate the Div node ((x - mean) / stddev) directly downstream of a Sub.
        div_node = None
        for child in children:
            div_node = self.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
            if div_node is not None:
                break
        if div_node is None:
            return

        # Walk upstream from Div's denominator (input 1) to match the variance
        # computation: Sqrt <- Add(eps) <- ReduceMean <- (Pow|Mul) [<- Cast] <- Sub.
        path_id, parent_nodes, _ = self.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Mul", "Sub"], [1, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Mul", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
        if path_id < 0:
            return

        # The Sub at the end of the matched path must be one of the mean node's
        # children, tying the variance branch back to the same (x - mean).
        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        # parent_nodes[1] is the Add that applies epsilon; its constant input is
        # the epsilon value. Reject implausible epsilons (<= 0 or > 1e-4).
        second_add_node = parent_nodes[1]
        i, add_weight = self.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            # Skip fusion since epsilon value is not expected.
            return

        # The squaring step must be Pow(x, 2) or Mul(x, x).
        pow_or_mul_node = parent_nodes[3]
        if pow_or_mul_node.op_type == "Pow" and self.find_constant_input(pow_or_mul_node, 2.0) != 1:
            return
        elif pow_or_mul_node.op_type == "Mul" and pow_or_mul_node.input[0] != pow_or_mul_node.input[1]:
            return

        # Downstream of Div: Mul applies the scale (gamma), then Add applies the bias (beta).
        mul_node = input_name_to_nodes[div_node.output[0]][0]
        if mul_node.op_type != "Mul":
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        # Collect every node of the matched subgraph; all of them are removed on fusion.
        subgraph_nodes = [reduce_mean_node]
        subgraph_nodes.extend(children)
        subgraph_nodes.extend(parent_nodes[:-1])

        subgraph_nodes.extend([last_add_node, mul_node, div_node])
        # Only fuse if no intermediate output is consumed outside the subgraph
        # (otherwise removing these nodes would break other consumers).
        if not self.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            return

        # The scale is Mul's "other" input (the one that is not Div's output);
        # it must be a rank-1 constant.
        weight_input = mul_node.input[1 - self.input_index(div_node.output[0], mul_node)]
        if not self.is_constant_with_specified_rank(weight_input, 1):
            return

        # Likewise the bias is Add's "other" input and must be a rank-1 constant.
        bias_input = last_add_node.input[1 - self.input_index(mul_node.output[0], last_add_node)]
        if not self.is_constant_with_specified_rank(bias_input, 1):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # Emit the fused node: LayerNormalization(X, scale, bias) with the matched epsilon.
        normalize_node = onnx.helper.make_node(
            "LayerNormalization",
            name=self.create_unique_node_name(),
            inputs=[reduce_mean_node.input[0], weight_input, bias_input],
            outputs=[last_add_node.output[0]],
        )
        normalize_node.attribute.extend([onnx.helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)
@@ -0,0 +1,96 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+ import onnx
10
+
11
+ from ..onnx_model import ONNXModel
12
+ from .fusion import Fusion
13
+
14
+
15
class ReplaceUpsampleWithResize(Fusion):
    """Replace deprecated Upsample nodes with equivalent Resize nodes.

    Upsample was removed from the ONNX standard in opset 10; Resize is its
    replacement. The Upsample signature varies by opset:
      - opset > 7: scales is the (optional) second *input* of the node.
      - opset == 7: scales is a ``floats`` attribute.
      - opset < 7: separate ``height_scale`` / ``width_scale`` attributes (NCHW).
    """

    def __init__(self, model: ONNXModel, opset):
        """Initialize.

        Args:
            model: ONNXModel wrapper of the model to transform.
            opset: ONNX opset version of the model, which determines where the
                Upsample scales are stored (input vs. attribute).
        """
        super().__init__(model, "Resize", "Upsample")
        self.opset = opset

    def fuse(
        self,
        node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """Replace a single Upsample node with an equivalent Resize node."""
        # Upsample's "mode" attribute defaults to "nearest" when absent; without
        # this default, make_node(..., mode=None) below would fail.
        mode = "nearest"
        for attr in node.attribute:
            if attr.name == "mode":
                mode = attr.s.decode("utf-8")
                break

        if self.opset > 7:
            # Scales already arrive as a tensor input; pass it through unchanged.
            # The input rank is unknown here, so a correctly sized roi constant
            # cannot be built. Pass an empty name for the optional roi input
            # instead: Resize only reads roi when coordinate_transformation_mode
            # is "tf_crop_and_resize", which is not used here. (The previous code
            # built a roi tensor sized by len() of the scales *input name* string,
            # yielding an invalid roi initializer.)
            scales_input = node.input[1] if len(node.input) > 1 else ""
            resize_inputs = [node.input[0], "", scales_input]
        else:
            if self.opset == 7:
                # Opset 7 stores the per-axis scales in a "scales" attribute.
                scales = None
                for attr in node.attribute:
                    if attr.name == "scales":
                        scales = attr.floats
                        break

                scales_array = np.array(list(scales), np.float32)
            else:
                # Pre-7 Upsample has only height/width scales (NCHW layout is
                # assumed, so N and C scales are 1).
                h_scale = 1
                w_scale = 1
                for attr in node.attribute:
                    if attr.name == "height_scale":
                        h_scale = attr.float
                    elif attr.name == "width_scale":
                        w_scale = attr.float

                scales_array = np.array([1, 1, h_scale, w_scale], np.float32)

            # Materialize the scales as a Constant node feeding the new Resize.
            scales_tensor = onnx.helper.make_tensor(
                name=node.name + "_scales",
                data_type=onnx.TensorProto.FLOAT,
                dims=scales_array.shape,
                vals=scales_array.flatten().tolist(),
            )

            scales_node = onnx.helper.make_node(
                "Constant", inputs=[], outputs=[node.name + "_scales"], value=scales_tensor
            )

            self.nodes_to_add.append(scales_node)

            # Here the rank is known from the scales, so a no-op roi
            # ([0, ..., 0, 1, ..., 1], length 2 * rank) can be provided.
            rank = len(scales_array)
            roi_tensor = onnx.helper.make_tensor(
                name=node.name + "_roi",
                data_type=onnx.TensorProto.FLOAT,
                dims=(rank * 2,),
                vals=[0] * rank + [1] * rank,
            )

            roi_node = onnx.helper.make_node("Constant", inputs=[], outputs=[node.name + "_roi"], value=roi_tensor)
            self.nodes_to_add.append(roi_node)

            resize_inputs = [node.input[0], node.name + "_roi", node.name + "_scales"]

        # nearest_mode="floor" matches Upsample's nearest-neighbor rounding behavior.
        resize_node = onnx.helper.make_node(
            op_type="Resize", inputs=resize_inputs, outputs=node.output, mode=mode, nearest_mode="floor"
        )

        self.nodes_to_remove.append(node)
        self.nodes_to_add.append(resize_node)

    def apply(self) -> bool:
        """Apply the replacement; re-sort the graph topologically if anything changed."""
        if super().apply():
            # New Constant nodes were appended out of order; restore a valid ordering.
            self.model.topological_sort()
            return True
        return False
@@ -0,0 +1,239 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import argparse
8
+ import logging
9
+ import os
10
+
11
+ import numpy as np
12
+ import numpy.typing as npt
13
+ import onnx
14
+ from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto
15
+
16
+ from onnxruntime.capi._pybind_state import quantize_matmul_bnb4
17
+
18
+ from .onnx_model import ONNXModel
19
+ from .quant_utils import attribute_to_kwarg
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class MatMulBnb4Quantizer:
    """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type"""

    ##################
    # quantization types, must be consistent with native code type
    # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h

    # 4b floating point with bias of 3
    FP4 = 0

    # 4b NormalFloat
    NF4 = 1

    def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None):
        """Wrap `model` and record the quantization settings.

        Args:
            model: ONNX model whose constant MatMul weights will be quantized.
            quant_type: MatMulBnb4Quantizer.FP4 (0) or MatMulBnb4Quantizer.NF4 (1).
            block_size: number of consecutive weight elements sharing one absmax scale.
            nodes_to_exclude: optional iterable of node names to skip during quantization.
        """
        nodes_to_exclude = nodes_to_exclude or []
        assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4]
        self.model = ONNXModel(model)
        self.quant_type = quant_type
        self.block_size = block_size
        self.nodes_to_exclude = set(nodes_to_exclude)

    @staticmethod
    def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
        """Find initializer `name`, searching from the innermost graph outward.

        Returns (tensor, owning_graph), or (None, None) when no graph on the
        path declares an initializer with that name.
        """
        for gid in range(len(graph_path) - 1, -1, -1):
            graph = graph_path[gid]
            for tensor in graph.initializer:
                if tensor.name == name:
                    return tensor, graph
        return None, None

    def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray]:
        """4b quantize fp32/fp16 weight.

        Returns:
            (packed, absmax): `packed` is a uint8 array holding two 4b codes per
            byte; `absmax` holds one per-block scale in the weight's original dtype.
        """
        if len(fpweight.shape) != 2:
            raise ValueError("Current bnb4 block quantization only supports 2D tensors!")
        # need to copy since the transposed weight still has the original memory layout
        # Linear4bit quantizes its weight data which is the transposed weight
        fpweight_t = fpweight.transpose().copy()

        rows, cols = fpweight.shape
        numel = rows * cols
        block_size = self.block_size
        num_blocks = (numel + block_size - 1) // block_size
        quantized_numel = (numel + 1) // 2  # two 4b values packed into each output byte

        packed = np.zeros(quantized_numel, dtype="uint8")
        absmax = np.zeros(num_blocks, dtype=fpweight.dtype)
        # block wise quantization, fpweight_t is flattened and divided into blocks
        # (native routine fills `packed` and `absmax` in place)
        quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows)

        return (packed, absmax)

    def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto:
        """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""

        if node.op_type != "MatMul":
            return node  # only care about MatMul for now

        logger.debug(f"start to quantize {node.name} ...")
        if node.name in self.nodes_to_exclude:
            logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
            return node

        inputB = node.input[1]  # noqa: N806
        B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack)  # noqa: N806
        if B is None:
            logger.debug("MatMul doesn't have const weight. Skip to quantize")
            return node  # only care about constant weight

        B_array = onnx.numpy_helper.to_array(B)  # noqa: N806
        if len(B_array.shape) != 2:
            logger.debug("MatMul weight is not 2D. Skip to quantize")
            return node  # can only process 2-D matrix

        packed, absmax = self.bnb4_block_quant(B_array)
        B_quant = onnx.numpy_helper.from_array(packed)  # noqa: N806
        B_quant.name = B.name + "_Bnb4"
        # drop the graph input that aliased the original weight, if one exists
        for input in Bs_graph.input:
            if input.name == inputB:
                Bs_graph.input.remove(input)
                break

        absmax_tensor = onnx.numpy_helper.from_array(absmax)
        absmax_tensor.name = B.name + "_absmax"

        Bs_graph.initializer.extend([B_quant, absmax_tensor])

        # contrib-op attributes; K/N are the original (un-transposed) weight dims
        kwargs = {}
        rows, cols = B_array.shape
        kwargs["K"] = rows
        kwargs["N"] = cols
        kwargs["block_size"] = self.block_size
        kwargs["quant_type"] = self.quant_type

        matmul_bnb4_node = onnx.helper.make_node(
            "MatMulBnb4",
            inputs=[node.input[0], B_quant.name, absmax_tensor.name],
            outputs=[node.output[0]],
            name=node.name + "_Bnb4" if node.name else "",
            domain="com.microsoft",
            **kwargs,
        )

        logger.debug(f"complete quantization of {node.name} ...")

        return matmul_bnb4_node

    def _process_subgraph(self, graph_stack: list[GraphProto]) -> GraphProto:
        """Quantize MatMul nodes in the graph at the top of `graph_stack`, recursing into subgraphs.

        Pops the processed graph off the stack before returning it.
        """
        new_nodes = []
        graph = graph_stack[-1]

        for node in graph.node:
            # graph-valued attributes (e.g. If/Loop/Scan bodies) require recursion
            graph_attrs = [
                attr
                for attr in node.attribute
                if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS
            ]
            if graph_attrs:
                kwargs = {}
                for attr in node.attribute:
                    if attr.type == onnx.AttributeProto.GRAPH:
                        # recursive call to take care of sub-graph
                        graph_stack.append(attr.g)
                        kv = {attr.name: self._process_subgraph(graph_stack)}
                    elif attr.type == onnx.AttributeProto.GRAPHS:
                        value = []
                        for subgraph in attr.graphs:
                            # recursive call to take care of sub-graph
                            graph_stack.append(subgraph)
                            value.extend([self._process_subgraph(graph_stack)])
                        kv = {attr.name: value}
                    else:
                        kv = attribute_to_kwarg(attr)
                    kwargs.update(kv)
                # rebuild the node so its attributes carry the processed subgraphs
                node = onnx.helper.make_node(  # noqa: PLW2901
                    node.op_type, node.input, node.output, name=node.name, **kwargs
                )

            new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack))

        graph.ClearField("node")
        graph.node.extend(new_nodes)
        graph_stack.pop()
        return graph

    def process(self):
        """Quantize eligible MatMul weights in place and register the com.microsoft opset."""
        # use a stack to keep track of sub-graphs
        graph_stack = [self.model.graph()]
        opset_import = self.model.opset_import()

        has_ms_domain = False
        for opset in opset_import:
            if opset.domain == "com.microsoft":
                has_ms_domain = True
        if not has_ms_domain:
            # MatMulBnb4 is a contrib op; the model must import the com.microsoft domain
            opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])

        self._process_subgraph(graph_stack)
        self.model.clean_initializers()
183
+
184
+
185
def parse_args():
    """Parse command-line arguments for the blockwise FP4/NF4 MatMul quantizer.

    Returns:
        argparse.Namespace with input_model, output_model, quant_type (int),
        block_size (int), verbose (bool), and nodes_to_exclude (list[str]).
    """
    parser = argparse.ArgumentParser(
        description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices.

A weight matrix is partitioned into blocks, where each block is a contiguous
subset inside the flattened transposed weight matrix. Each block is quantized
into a set of 4b integers with an absolute value scaling factor.
"""
    )

    parser.add_argument("--input_model", required=True, help="Path to the input model file")
    parser.add_argument("--output_model", required=True, help="Path to the output model file")
    parser.add_argument(
        "--quant_type",
        required=False,
        # without type=int a CLI-supplied value arrives as str and always fails
        # the int-valued choices check (e.g. "0" not in [0, 1])
        type=int,
        default=1,
        choices=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4],
        help="Quantization data type. 0: FP4, 1: NF4",
    )
    parser.add_argument(
        "--block_size",
        required=False,
        # without type=int a CLI-supplied value is a str and breaks the
        # block-size arithmetic in bnb4_block_quant
        type=int,
        default=64,
        help="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64",
    )
    parser.add_argument("-v", "--verbose", required=False, action="store_true")
    parser.set_defaults(verbose=False)
    parser.add_argument(
        "--nodes_to_exclude",
        nargs="+",
        type=str,
        required=False,
        default=[],
        help="Specify the nodes to be excluded from quantization with node names",
    )

    return parser.parse_args()
222
+
223
+
224
if __name__ == "__main__":
    # Script entry point: parse CLI args, quantize the model, write it out.
    args = parse_args()
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    input_model_path = args.input_model
    output_model_path = args.output_model

    # refuse to clobber an existing output file
    if os.path.exists(output_model_path):
        logger.error(f"file {output_model_path} already exists")
        raise Exception(f"file {output_model_path} already exists")

    quant = MatMulBnb4Quantizer(
        onnx.load(input_model_path),
        args.quant_type,
        args.block_size,
        nodes_to_exclude=args.nodes_to_exclude,
    )
    quant.process()
    quant.model.save_model_to_file(output_model_path, True)