onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
+++ onnxruntime/transformers/models/whisper/convert_to_onnx.py
@@ -0,0 +1,536 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import argparse
+ import copy
+ import logging
+ import os
+
+ import torch
+ from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
+ from whisper_chain import chain_model
+ from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper
+
+ from onnxruntime import quantization
+
+ logger = logging.getLogger("")
+
+ PROVIDERS = {
+     "cpu": "CPUExecutionProvider",
+     "cuda": "CUDAExecutionProvider",
+     "rocm": "ROCMExecutionProvider",
+ }
+
+
+ def parse_arguments(argv=None):
+     parser = argparse.ArgumentParser()
+
+     conversion_args = parser.add_argument_group("Conversion Process Args")
+     optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)")
+     optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)")
+     quant_args = parser.add_argument_group("INT8 Quantization Args")
+
+     #################################
+     # Conversion options for Whisper
+     #################################
+
+     conversion_args.add_argument(
+         "-m",
+         "--model_name_or_path",
+         required=False,
+         default=PRETRAINED_WHISPER_MODELS[0],
+         type=str,
+         help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS),
+     )
+
+     conversion_args.add_argument(
+         "--model_impl",
+         required=False,
+         default="hf",
+         choices=["hf", "openai"],
+         type=str,
+         help="Select implementation for export of encoder and decoder subgraphs",
+     )
+
+     conversion_args.add_argument(
+         "--cache_dir",
+         required=False,
+         type=str,
+         default=os.path.join(".", "cache_models"),
+         help="Directory to cache pre-trained models",
+     )
+
+     conversion_args.add_argument(
+         "--output",
+         required=False,
+         type=str,
+         default=os.path.join(".", "onnx_models"),
+         help="Output directory",
+     )
+
+     conversion_args.add_argument(
+         "-o",
+         "--optimize_onnx",
+         required=False,
+         action="store_true",
+         help="Use optimizer.py to optimize onnx model",
+     )
+     conversion_args.set_defaults(optimize_onnx=False)
+
+     conversion_args.add_argument(
+         "--use_gpu",
+         required=False,
+         action="store_true",
+         help="Use GPU for model inference",
+     )
+     conversion_args.set_defaults(use_gpu=False)
+
+     conversion_args.add_argument(
+         "-p",
+         "--precision",
+         required=False,
+         type=Precision,
+         default=Precision.FLOAT32,
+         choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8],
+         help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization",
+     )
+
+     conversion_args.add_argument(
+         "--use_int64_inputs",
+         required=False,
+         action="store_true",
+         help="Use int64 instead of int32 for input_ids and attention_mask.",
+     )
+     conversion_args.set_defaults(use_int64_inputs=False)
+
+     conversion_args.add_argument(
+         "--disable_auto_mixed_precision",
+         required=False,
+         action="store_true",
+         help="Use pure fp16 instead of mixed precision",
+     )
+     conversion_args.set_defaults(disable_auto_mixed_precision=False)
+
+     conversion_args.add_argument(
+         "-r",
+         "--provider",
+         required=False,
+         type=str,
+         default="cpu",
+         choices=list(PROVIDERS.keys()),
+         help="Provider to benchmark. Default is CPUExecutionProvider.",
+     )
+
+     conversion_args.add_argument(
+         "--verbose",
+         required=False,
+         action="store_true",
+         help="Enable verbose logging",
+     )
+     conversion_args.set_defaults(verbose=False)
+
+     conversion_args.add_argument(
+         "-e",
+         "--use_external_data_format",
+         required=False,
+         action="store_true",
+         help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.",
+     )
+     conversion_args.set_defaults(use_external_data_format=False)
+
+     conversion_args.add_argument(
+         "-w",
+         "--overwrite",
+         required=False,
+         action="store_true",
+         help="Overwrite existing ONNX model",
+     )
+     conversion_args.set_defaults(overwrite=False)
+
+     conversion_args.add_argument(
+         "--separate_encoder_and_decoder_init",
+         required=False,
+         action="store_true",
+         help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.",
+     )
+     conversion_args.set_defaults(separate_encoder_and_decoder_init=False)
+
+     conversion_args.add_argument(
+         "--no_beam_search_op",
+         required=False,
+         action="store_true",
+         help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.",
+     )
+     conversion_args.set_defaults(no_beam_search_op=False)
+
+     conversion_args.add_argument(
+         "--state_dict_path",
+         type=str,
+         default="",
+         help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
+     )
+
+     #############################################################
+     # Optional inputs for Whisper
+     # (listed below in the order that WhisperBeamSearch expects)
+     #############################################################
+
+     optional_inputs.add_argument(
+         "-v",
+         "--use_vocab_mask",
+         required=False,
+         action="store_true",
+         help="Use vocab_mask as an extra graph input to enable specific logits processing",
+     )
+     optional_inputs.set_defaults(use_vocab_mask=False)
+
+     optional_inputs.add_argument(
+         "-u",
+         "--use_prefix_vocab_mask",
+         required=False,
+         action="store_true",
+         help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing",
+     )
+     optional_inputs.set_defaults(use_prefix_vocab_mask=False)
+
+     optional_inputs.add_argument(
+         "-f",
+         "--use_forced_decoder_ids",
+         required=False,
+         action="store_true",
+         help="Use decoder_input_ids as an extra graph input to the beam search op",
+     )
+     optional_inputs.set_defaults(use_forced_decoder_ids=False)
+
+     optional_inputs.add_argument(
+         "-l",
+         "--use_logits_processor",
+         required=False,
+         action="store_true",
+         help="Use logits_processor as an extra graph input to enable specific logits processing",
+     )
+     optional_inputs.set_defaults(use_logits_processor=False)
+
+     optional_inputs.add_argument(
+         "--collect_cross_qk",
+         required=False,
+         action="store_true",
+         help="Beam search model collects stacked cross QK.",
+     )
+     optional_inputs.set_defaults(collect_cross_qk=False)
+
+     optional_inputs.add_argument(
+         "--extra_decoding_ids",
+         required=False,
+         action="store_true",
+         help="Need extra starting decoding ids for some features, such as cross QK. Default is false.",
+     )
+     optional_inputs.set_defaults(extra_decoding_ids=False)
+
+     optional_inputs.add_argument(
+         "-t",
+         "--use_temperature",
+         required=False,
+         action="store_true",
+         help="Use temperature as an extra graph input for the WhisperBeamSearch op",
+     )
+     optional_inputs.set_defaults(use_temperature=False)
+
+     optional_inputs.add_argument(
+         "--no_repeat_ngram_size",
+         type=int,
+         default=0,
+ help="default to 0",
+     )
+
+     #############################################################
+     # Optional outputs for Whisper
+     # (listed below in the order that WhisperBeamSearch expects)
+     #############################################################
+
+     optional_outputs.add_argument(
+         "--output_sequence_scores",
+         required=False,
+         action="store_true",
+ help="Beam search model output scores for each generated sequence.",
+     )
+     optional_outputs.set_defaults(output_sequence_scores=False)
+
+     optional_outputs.add_argument(
+         "--output_scores",
+         required=False,
+         action="store_true",
+ help="Beam search model output scores over vocab per generated token.",
+     )
+     optional_outputs.set_defaults(output_scores=False)
+
+     optional_outputs.add_argument(
+         "--output_cross_qk",
+         required=False,
+         action="store_true",
+ help="Beam search model output collected qk as output. Also hint collect_cross_qk",
+     )
+     optional_outputs.set_defaults(output_cross_qk=False)
+
+     optional_outputs.add_argument(
+         "--cross_qk_onnx_model",
+         required=False,
+         type=str,
+         default=None,
+         help="The model which consumes cross_qk outputs.",
+     )
+
+     optional_outputs.add_argument(
+         "--output_no_speech_probs",
+         required=False,
+         action="store_true",
+ help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.",
+     )
+     optional_outputs.set_defaults(output_no_speech_probs=False)
+
+     ###################################
+     # Quantization options for Whisper
+     ###################################
+
+     quant_args.add_argument(
+         "--quantize_embedding_layer",
+         required=False,
+         action="store_true",
+         help="Quantize MatMul, GEMM, and Gather.",
+     )
+     quant_args.set_defaults(quantize_embedding_layer=False)
+
+     quant_args.add_argument(
+         "--quantize_per_channel",
+         required=False,
+         action="store_true",
+ help="Quantize weights per each channel.",
+     )
+     quant_args.set_defaults(quantize_per_channel=False)
+
+     quant_args.add_argument(
+         "--quantize_reduce_range",
+         required=False,
+         action="store_true",
+         help="Quantize weights with 7 bits.",
+     )
+     quant_args.set_defaults(quantize_reduce_range=False)
+
+     args = parser.parse_args(argv)
+     args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk
+
+     return args
+
+
+ def export_onnx_models(
+     model_name_or_path,
+     model_impl,
+     cache_dir,
+     output_dir,
+     use_gpu,
+     use_external_data_format,
+     optimize_onnx,
+     precision,
+     verbose,
+     use_forced_decoder_ids: bool = False,
+     merge_encoder_and_decoder_init: bool = True,
+     overwrite: bool = False,
+     disable_auto_mixed_precision: bool = False,
+     use_int32_inputs: bool = True,
+     quantize_embedding_layer: bool = False,
+     quantize_per_channel: bool = False,
+     quantize_reduce_range: bool = False,
+     state_dict_path: str = "",
+     provider: str = "cpu",
+ ):
+     device = torch.device("cuda:0" if use_gpu else "cpu")
+
+     models = WhisperHelper.load_model(
+         model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path
+     )
+     config = models["decoder"].config
+
+     if (not use_external_data_format) and (config.num_hidden_layers > 24):
+ logger.info("Try use_external_data_format when model size > 2GB")
+
+     output_paths = []
+     for name, model in models.items():
+         print(f"========> Handling {name} model......")
+         model.to(device)
+         filename_suffix = "_" + name
+
+         onnx_path = WhisperHelper.get_onnx_path(
+             output_dir,
+             model_name_or_path,
+             suffix=filename_suffix,
+             new_folder=False,
+         )
+
+         if overwrite or not os.path.exists(onnx_path):
+             logger.info(f"Exporting ONNX model to {onnx_path}")
+             # We have to clone the model before exporting to ONNX; otherwise verify_onnx will report a large difference.
+             device_to_export = torch.device("cpu")
+             cloned_model = copy.deepcopy(model).to(device_to_export)
+             WhisperHelper.export_onnx(
+                 cloned_model,
+                 device_to_export,
+                 onnx_path,
+                 verbose,
+                 use_external_data_format,
+                 use_int32_inputs=use_int32_inputs,
+             )
+         else:
+             logger.info(f"Skipping export: ONNX model {onnx_path} already exists")
+
+         # Optimize the ONNX graph and/or quantize it to the requested precision.
+         if optimize_onnx or precision != Precision.FLOAT32:
+             output_path = WhisperHelper.get_onnx_path(
+                 output_dir,
+                 model_name_or_path,
+                 suffix=filename_suffix + "_" + str(precision),
+                 new_folder=False,
+             )
+
+             if overwrite or not os.path.exists(output_path):
+                 if optimize_onnx:
+                     logger.info(f"Optimizing model to {output_path}")
+                     WhisperHelper.optimize_onnx(
+                         onnx_path,
+                         output_path,
+                         precision == Precision.FLOAT16,
+                         config.encoder_attention_heads,
+                         config.d_model,
+                         use_external_data_format,
+                         auto_mixed_precision=not disable_auto_mixed_precision,
+                         use_gpu=use_gpu,
+                         provider=provider,
+                     )
+                     onnx_path = output_path
+
+                 if precision == Precision.INT8:
+                     quantization.quantize_dynamic(
+                         onnx_path,
+                         output_path,
+                         op_types_to_quantize=(
+                             ["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"]
+                         ),
+                         use_external_data_format=use_external_data_format,
+                         per_channel=quantize_per_channel,
+                         reduce_range=quantize_reduce_range,
+                         extra_options={"MatMulConstBOnly": True},
+                     )
+             else:
+ logger.info(f"Skip optimizing: existing ONNX model {onnx_path}")
+         else:
+             output_path = onnx_path
+
+         ort_session = create_onnxruntime_session(
+             output_path,
+             use_gpu=use_gpu,
+             provider=provider,
+         )
+         assert ort_session is not None
+
+         output_paths.append(output_path)
+
+     return output_paths
+
+
+ def main(argv=None):
+     args = parse_arguments(argv)
+
+     setup_logger(args.verbose)
+
+ logger.info(f"Arguments:{args}")
+
+     cache_dir = args.cache_dir
+     output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
+     prepare_environment(cache_dir, output_dir, args.use_gpu)
+
+     if args.precision == Precision.FLOAT16:
+         assert args.use_gpu, "fp16 requires --use_gpu"
+
+     if args.optimize_onnx:
+         logger.warning("Applying graph optimization for Whisper...")
+
+     output_paths = export_onnx_models(
+         args.model_name_or_path,
+         args.model_impl,
+         cache_dir,
+         output_dir,
+         args.use_gpu,
+         args.use_external_data_format,
+         args.optimize_onnx,
+         args.precision,
+         args.verbose,
+         args.use_forced_decoder_ids,
+         not args.separate_encoder_and_decoder_init,
+         args.overwrite,
+         args.disable_auto_mixed_precision,
+         not args.use_int64_inputs,
+         args.quantize_embedding_layer,
+         args.quantize_per_channel,
+         args.quantize_reduce_range,
+         args.state_dict_path,
+         args.provider,
+     )
+
+     max_diff = 0
+     if not args.no_beam_search_op:
+ logger.info("Chaining model ... :")
+         args.beam_model_output_dir = WhisperHelper.get_onnx_path(
+             output_dir,
+             args.model_name_or_path,
+             suffix="_beamsearch",
+             new_folder=False,
+         )
+         for path in output_paths:
+             if "encoder_decoder" in path:
+                 args.encoder_path = path
+             elif "decoder" in path:
+                 args.decoder_path = path
+         chain_model(args)
+         output_paths.append(args.beam_model_output_dir)
+
+         # Check chained model
+         ort_session = create_onnxruntime_session(
+             args.beam_model_output_dir,
+             use_gpu=args.use_gpu,
+             provider=args.provider,
+         )
+         device = torch.device("cuda:0" if args.use_gpu else "cpu")
+
+         # Wrap parity check in try-except to allow export to continue in case this produces an error
+         try:
+             with torch.no_grad():
+                 # Verify batched decoding with prompts for whisper openai implementation
+                 if args.model_impl == "openai" and args.use_forced_decoder_ids:
+                     max_diff = WhisperHelper.verify_onnx(
+                         args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
+                     )
+                 else:
+                     max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
+                 if max_diff > 1e-4:
+                     logger.warning("PyTorch and ONNX Runtime results are NOT close")
+                 else:
+                     logger.info("PyTorch and ONNX Runtime results are close")
+         except Exception as e:
+             logger.warning(
+                 f"An error occurred while trying to verify parity between PyTorch and ONNX Runtime: {e}", exc_info=True
+             )
+
+         # Remove extra ONNX models saved in output directory
+         for fle in os.listdir(output_dir):
+             if "_beamsearch" not in fle:
+                 os.remove(os.path.join(output_dir, fle))
+         output_paths = [args.beam_model_output_dir]
+
+     logger.info(f"Done! Outputs: {output_paths}")
+     return max_diff
+
+
+ if __name__ == "__main__":
+     main()
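
Usage sketch (editor's addition, not part of the published wheel): the file above is a CLI script, and since main() accepts an argv list, it can also be driven programmatically. This is a minimal sketch, assuming torch, onnxruntime, and the sibling helper modules (benchmark_helper, whisper_chain, whisper_helper) are importable and that you run from this file's directory; the module name convert_to_onnx comes from entry 272 in the listing above, and the flag names are taken directly from parse_arguments.

    # illustrative_export.py -- a hedged sketch, not shipped with the package
    from convert_to_onnx import main

    # Export the default pretrained Whisper model to ./onnx_models and build the
    # chained WhisperBeamSearch model (the default, since --no_beam_search_op is
    # not passed), overwriting any earlier export.
    max_diff = main([
        "--output", "./onnx_models",
        "--overwrite",
        "--use_external_data_format",
    ])

    # main() returns the maximum parity difference between PyTorch and ONNX
    # Runtime as measured by WhisperHelper.verify_onnx; <= 1e-4 is logged as close.
    print(f"Max parity difference: {max_diff}")

Passing "-p int8" instead would route the exported graphs through quantization.quantize_dynamic as seen in export_onnx_models above, while "-p fp16" additionally requires --use_gpu.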