onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,402 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import List, Optional, Union
12
+
13
+ import numpy
14
+ import onnx
15
+ import torch
16
+ from io_binding_helper import TypeHelper
17
+ from models.t5.past_helper import PastKeyValuesHelper
18
+ from onnx_model import OnnxModel
19
+ from torch_onnx_export_helper import torch_onnx_export
20
+ from transformers import WhisperConfig, file_utils
21
+ from whisper_openai_helper import WhisperDecoderInitOpenai
22
+
23
+ from onnxruntime import InferenceSession
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class WhisperDecoderInit(torch.nn.Module):
    """A Whisper decoder to create initial past key values.

    This module is invoked exactly once, at the start of decoding, to turn the
    encoder output and the first decoder tokens into logits plus the initial
    past key/value cache.
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
        config: WhisperConfig,
        decoder_start_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.config = config
        # Fall back to the config's start token only when no explicit id is given.
        if decoder_start_token_id is None:
            self.decoder_start_token_id = self.config.decoder_start_token_id
        else:
            self.decoder_start_token_id = decoder_start_token_id

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
    ):
        # Wrap the precomputed encoder states so the HF decoder skips rerunning
        # the encoder and consumes `last_hidden_state` directly.
        encoder_outputs = file_utils.ModelOutput()
        encoder_outputs["last_hidden_state"] = encoder_hidden_states
        encoder_outputs["hidden_states"] = None
        encoder_outputs["attentions"] = None

        decoder_result = self.decoder.model(
            None,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            past_key_values=None,
            use_cache=True,
            return_dict=True,
        )
        # Project hidden states to vocabulary logits via the model's output head.
        logits = self.decoder.proj_out(decoder_result[0])
        return logits, decoder_result.past_key_values, decoder_result.encoder_last_hidden_state
66
+
67
+
68
class WhisperDecoder(torch.nn.Module):
    """A Whisper decoder with past key values.

    Used for incremental decoding steps after the first one: takes the current
    decoder tokens plus the flat list of past key/value tensors and returns
    logits and the updated self-attention present states.
    """

    def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None):
        super().__init__()
        self.decoder = decoder
        self.config = config
        self.model_impl = model_impl
        # The OpenAI wrapper is only needed for model_impl == "openai" callers.
        if model is not None:
            self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder)

    def forward(self, decoder_input_ids, *past):
        # Random tensor with the expected (batch, 3000, d_model) shape stands in
        # for real encoder output during export/tracing.
        # NOTE(review): when `past` is empty the HF decoder would cross-attend to
        # this random tensor — presumably this path is only exercised with past
        # states present; confirm against callers.
        encoder_outputs = file_utils.ModelOutput()
        dummy_encoder_hidden_states = torch.randn((decoder_input_ids.shape[0], 3000, int(self.config.d_model)))
        encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states
        encoder_outputs["hidden_states"] = dummy_encoder_hidden_states
        encoder_outputs["attentions"] = None

        if self.model_impl == "openai":
            # NOTE: a previous `dummy_encoder_hidden_states.unsqueeze(0)` call was
            # removed here: torch.Tensor.unsqueeze is out-of-place and its result
            # was discarded, so the statement had no effect.
            dec_out, present = self.whisper_decoder_openai_init(
                decoder_input_ids, dummy_encoder_hidden_states, past=past
            )
            return dec_out, present

        # Regroup the flat past tensor list into per-layer tuples for HF.
        past_key_values = PastKeyValuesHelper.back_group_by_layer(past) if len(past) > 0 else None

        decoder_out = self.decoder(
            None,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        logits = decoder_out[0]
        # Only self-attention present states are returned; cross-attention states
        # do not change after the initial step.
        present_self, _ = PastKeyValuesHelper.group_by_self_and_cross(decoder_out.past_key_values)
        return logits, present_self
109
+
110
+
111
class WhisperDecoderInputs:
    """Container for WhisperDecoder inputs: decoder token ids plus an optional
    flat list of past key/value tensors (self-attention entries first, then
    cross-attention entries)."""

    def __init__(
        self,
        decoder_input_ids,
        past_key_values=None,
    ):
        # (batch_size, sequence_length) token ids for the decoder.
        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
        # Flat list of past tensors, or None when decoding from scratch.
        self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values

    @staticmethod
    def create_dummy(
        config: WhisperConfig,
        batch_size: int,
        encode_sequence_length: int,
        past_decode_sequence_length: int,
        device: torch.device,
        float16: bool = False,
        use_int32_inputs: bool = False,
        model_impl: str = "hf",
    ):  # -> WhisperDecoderInputs:
        """Create dummy inputs for WhisperDecoder.

        Args:
            config (WhisperConfig): model configuration (heads, layers, d_model, vocab size)
            batch_size (int): batch size
            encode_sequence_length (int): sequence length of input_ids for encoder
            past_decode_sequence_length (int): past sequence length of input_ids for decoder
            device (torch.device): device of output tensors
            float16 (bool): whether the model uses float32 or float16 in input
            use_int32_inputs(bool): whether use int32 instead of int64 for some inputs
            model_impl (str): "hf" or "openai"; selects the cross-attention past length

        Returns:
            WhisperDecoderInputs: dummy inputs for decoder
        """
        num_attention_heads: int = config.encoder_attention_heads
        num_layers: int = config.decoder_layers  # + config.encoder_layers
        vocab_size: int = config.vocab_size

        # Use head_size, use hidden_size / num_attention_heads here.
        # For example, whisper-large, d_model=1280 and num_heads=20
        head_size: int = config.d_model // config.encoder_attention_heads

        sequence_length: int = 1  # fixed for decoding
        decoder_input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=(torch.int32 if use_int32_inputs else torch.int64),
            device=device,
        )

        float_type = torch.float16 if float16 else torch.float32

        if past_decode_sequence_length > 0:
            self_attention_past_shape = [
                batch_size,
                num_attention_heads,
                past_decode_sequence_length,
                head_size,
            ]
            # For the OpenAI implementation the cross-attention cache uses the
            # decode length instead of the encode length.
            cross_attention_past_shape = [
                batch_size,
                num_attention_heads,
                encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
                head_size,
            ]

            # Self-attention past tensors (2 per layer: key, value) come first...
            past = []
            for _ in range(2 * num_layers):
                past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))

            # ...followed by cross-attention past tensors. Ordering must match
            # PastKeyValuesHelper.get_past_names.
            for _ in range(2 * num_layers):
                past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
        else:
            past = None

        return WhisperDecoderInputs(decoder_input_ids, past)

    def to_list(self) -> List:
        """Return inputs as a flat list: [decoder_input_ids, *past_key_values]."""
        input_list = [self.decoder_input_ids]
        if self.past_key_values:
            input_list.extend(self.past_key_values)
        return input_list

    def to_fp32(self):
        """Return a copy of these inputs with all past tensors cast to float32."""
        past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
        return WhisperDecoderInputs(
            self.decoder_input_ids.clone(),
            past,
        )
201
+
202
+
203
class WhisperDecoderHelper:
    """Static helpers to export a Whisper decoder to ONNX and verify the
    exported model against the PyTorch baseline."""

    @staticmethod
    def export_onnx(
        decoder: WhisperDecoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export decoder to ONNX

        Args:
            decoder (Union[WhisperDecoder, WhisperDecoderNoPastState]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs
        """
        assert isinstance(decoder, (WhisperDecoder, WhisperDecoderInit))

        # A WhisperDecoderInit (first step) gets no past; WhisperDecoder gets a
        # small dummy past (length 6) so past inputs appear in the ONNX graph.
        # NOTE(review): `decoder.model_impl` is only assigned in WhisperDecoder's
        # __init__ in this file; a WhisperDecoderInit passed here would raise
        # AttributeError despite the assert above — confirm intended usage.
        inputs = WhisperDecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3000,
            past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
            device=device,
            use_int32_inputs=use_int32_inputs,
            model_impl=decoder.model_impl,
        )
        input_list = inputs.to_list()

        # Fix past disappearing bug - duplicate first past entry
        # input_list.insert(2, input_list[2])

        past_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=True)
        # First 2*layers present names are the self-attention states.
        present_self_names = present_names[: 2 * decoder.config.decoder_layers]

        # With-past decoder: past inputs, self-attention present outputs only.
        # Init decoder: no past inputs, all present states are outputs.
        input_past_names = past_names if isinstance(decoder, WhisperDecoder) else []
        output_present_names = present_self_names if isinstance(decoder, WhisperDecoder) else present_names
        output_names = ["logits", *output_present_names]

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["input_ids"]
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {0: "batch_size"},
            "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length / 2"},
            "logits": {0: "batch_size", 1: "sequence_length"},
        }

        # Past inputs: axis 2 is the decode length for self states, the encode
        # length for cross states.
        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
            else:  # self attention past state
                if isinstance(decoder, WhisperDecoder):
                    # Self state grows by one token per step.
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        # Export to a temp path first when external data is used, then repack
        # all tensors into a single external-data file next to the target path.
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                decoder,
                args=tuple(input_list),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: WhisperDecoderInputs):
        """Run inference of ONNX model."""
        logger.debug("start onnxruntime_inference")

        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
        }

        if inputs.past_key_values:
            # 4 past tensors per layer: self key/value + cross key/value.
            assert len(inputs.past_key_values) % 4 == 0
            num_layers = int(len(inputs.past_key_values) / 4)
            past_names = PastKeyValuesHelper.get_past_names(num_layers)
            for i, past_tensor in enumerate(inputs.past_key_values):
                ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())

        ort_outputs = ort_session.run(None, ort_inputs)
        return ort_outputs

    @staticmethod
    def verify_onnx(
        model: Union[WhisperDecoder, WhisperDecoderInit],
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
        max_cases: int = 4,
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        # Infer fp16 vs fp32 from the ONNX session's declared input type.
        float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"

        # (batch_size, encode_sequence_length, past_decode_sequence_length) tuples.
        test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
        test_cases_max_diff = []
        for (
            batch_size,
            encode_sequence_length,
            past_decode_sequence_length,
        ) in test_cases[:max_cases]:
            # The init decoder takes no past, so force its past length to zero.
            if isinstance(model, WhisperDecoderInit):
                dec_seq_len = 0
            else:
                dec_seq_len = past_decode_sequence_length

            inputs = WhisperDecoderInputs.create_dummy(
                model.config,
                batch_size,
                encode_sequence_length,
                dec_seq_len,
                device=device,
                float16=float16,
                use_int32_inputs=use_int32_inputs,
            )

            # We use fp32 PyTorch model as baseline even when ONNX model is fp16
            input_list = inputs.to_fp32().to_list()

            # Run inference of PyTorch model
            with torch.no_grad():
                torch_outputs = model(*input_list)

            ort_outputs = WhisperDecoderHelper.onnxruntime_inference(ort_session, inputs)

            max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
            max_diff_all = max_diff
            logger.debug(f"logits max_diff={max_diff}")

            # NOTE(review): `config.num_layers` is used here while create_dummy
            # uses `config.decoder_layers` — confirm WhisperConfig exposes both
            # with the same meaning.
            for i in range(2 * model.config.num_layers):
                max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
                logger.debug(f"self attention past state {i} max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

            if isinstance(model, WhisperDecoderInit):
                # Init decoder also emits cross-attention states after the self ones.
                for i in range(2 * model.config.num_layers):
                    max_diff = numpy.amax(
                        numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])
                    )
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

            test_cases_max_diff.append(max_diff_all)
            logger.info(
                "batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
                batch_size,
                encode_sequence_length,
                past_decode_sequence_length,
                max_diff_all,
            )

        # NOTE(review): returns the max diff of the LAST test case only;
        # `test_cases_max_diff` is collected but unused — possibly intended to
        # return max(test_cases_max_diff). Confirm with callers before changing.
        return max_diff_all
@@ -0,0 +1,164 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import List
12
+
13
+ import numpy
14
+ import onnx
15
+ import torch
16
+ from onnx_model import OnnxModel
17
+ from torch_onnx_export_helper import torch_onnx_export
18
+ from transformers import WhisperConfig
19
+
20
+ from onnxruntime import InferenceSession
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class WhisperEncoder(torch.nn.Module):
    """Wrapper around a Whisper encoder that outputs only the last hidden state."""

    def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"):
        super().__init__()
        self.encoder = encoder
        self.config = config
        self.model_impl = model_impl

    def forward(self, input_features):
        # The OpenAI implementation exposes the encoder as a callable that
        # returns the hidden state directly; the Hugging Face wrapper nests it
        # under .model.encoder and returns a tuple whose first element is the
        # last hidden state.
        if self.model_impl != "openai":
            return self.encoder.model.encoder(input_features)[0]
        return self.encoder(input_features)
38
+
39
+
40
class WhisperEncoderInputs:
    """Container for the single (spectrogram) input tensor of the Whisper encoder."""

    def __init__(self, input_features):
        # Despite the attribute name (kept for backward compatibility with
        # callers that read .input_ids), this holds float spectrogram features
        # produced by torch.randn in create_dummy, not integer token ids, so
        # the annotation is FloatTensor rather than the original LongTensor.
        self.input_ids: torch.FloatTensor = input_features

    @staticmethod
    def create_dummy(
        batch_size: int,
        sequence_length: int,
        feature_size: int,
        device: torch.device,
        use_int32_inputs: bool = False,
    ):
        """Create dummy inputs for Whisper encoder.

        Args:
            batch_size (int): batch size
            sequence_length (int): sequence length
            feature_size (int): feature size for spectrogram input
            device (torch.device): device of output tensors
            use_int32_inputs (bool, optional): unused here; accepted for API
                symmetry with the decoder input helpers. Defaults to False.

        Returns:
            WhisperEncoderInputs: dummy inputs for encoder
        """
        input_features = torch.randn(
            size=(batch_size, feature_size, sequence_length),
            device=device,
        )
        return WhisperEncoderInputs(input_features)

    def to_list(self) -> List:
        """Return the inputs as a positional-argument list (empty when unset)."""
        if self.input_ids is None:
            return []
        return [self.input_ids]
74
+
75
+
76
class WhisperEncoderHelper:
    """Utilities to export, run, and validate the Whisper encoder ONNX model."""

    @staticmethod
    def export_onnx(
        encoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export encoder to ONNX

        Args:
            encoder (WhisperEncoder): encoder object
            device (torch.device): device of encoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): forwarded to dummy-input creation
                (currently unused there). Defaults to False.
        """
        config = encoder.config
        encoder_inputs = WhisperEncoderInputs.create_dummy(
            batch_size=2,
            sequence_length=3000,
            feature_size=config.num_mel_bins,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                encoder,
                args=tuple(encoder_inputs.to_list()),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=["input_features"],
                output_names=["hidden_states"],
                dynamic_axes={
                    # Key must match input_names ("input_features"); the
                    # original used "input_ids", so torch.onnx silently ignored
                    # these dynamic axes for the input tensor.
                    "input_features": {0: "batch_size", 1: "feature_size", 2: "sequence_length"},
                    "hidden_states": {0: "batch_size", 1: "sequence_length"},
                },
                opset_version=17,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                # Re-save with every external tensor gathered into one file
                # next to the final model path.
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: WhisperEncoderInputs):
        """Run inference of ONNX model."""
        # Feed name must match the exported graph input name
        # ("input_features", set by export_onnx's input_names); the original
        # fed "input_ids", which the session would reject as an unknown input.
        ort_inputs = {
            "input_features": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()),
        }

        return ort_session.run(None, ort_inputs)

    @staticmethod
    def verify_onnx(
        model: WhisperEncoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        inputs = WhisperEncoderInputs.create_dummy(
            batch_size=4,
            sequence_length=11,
            # feature_size is a required parameter of create_dummy; the
            # original call omitted it and raised TypeError. Use the model's
            # mel-bin count, as export_onnx does.
            feature_size=model.config.num_mel_bins,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        input_list = inputs.to_list()

        # No gradients needed for a verification forward pass (matches the
        # decoder verify helper in this package).
        with torch.no_grad():
            torch_outputs = model(*input_list)

        ort_outputs = WhisperEncoderHelper.onnxruntime_inference(ort_session, inputs)

        max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0]))

        logger.info(f"max_diff={max_diff}")

        return max_diff