onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,413 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ # This script converts Longformer model from huggingface transformers 4.0 or later to ONNX.
8
+ # It translates LongformerSelfAttention to the LongformerAttention operator in ONNX Runtime.
9
+ #
10
+ # Before running this script, prepare a python environment in Linux with PyTorch 1.9.0 and other packages installed.
11
+ # Then run "python setup.py install" in ./torch_extensions directory. If your python version is not 3.8, you will need
12
+ # to update this script with the correct name of longformer_attention.cpython-*.so (search TODO below).
13
+ #
14
+ # It is tested in Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0.
15
+ # Warning: Using PyTorch 1.10 or newer version might encounter issue in exporting, but they are fine for benchmarking.
16
+ #
17
+ # Example commands to export longformer base model in Linux:
18
+ # conda create -n longformer python=3.8
19
+ # conda activate longformer
20
+ # python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
21
+ # python3 -m pip install flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0
22
+ # python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu
23
+ # cd ./torch_extensions
24
+ # rm -rf build
25
+ # python setup.py install
26
+ # cd ..
27
+ # python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx
28
+ # python convert_to_onnx.py --model longformer-base-4096 --precision fp16 --optimize_onnx --no_merge_qkv
29
+ #
30
+ # GPU is not needed for this script. You can run it in CPU. For --optimize_onnx, you can use either onnxruntime or onnxruntime-gpu package.
31
+ #
32
+ # For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or newer version.
33
+
34
+ import argparse
35
+ import inspect
36
+ from pathlib import Path
37
+
38
+ import torch
39
+ import transformers
40
+ from longformer_helper import PRETRAINED_LONGFORMER_MODELS
41
+ from onnx import load_model
42
+ from onnx_model_bert import BertOnnxModel
43
+ from packaging import version
44
+ from torch.onnx import register_custom_op_symbolic
45
+ from torch.onnx.symbolic_helper import parse_args
46
+ from torch_onnx_export_helper import torch_onnx_export
47
+ from transformers import LongformerModel, LongformerSelfAttention
48
+
49
# Layout of the stacked q/k/v projection weights/biases fed to the fused op.
# Supports format 0 or 1:
#   0 (default): stack on a new leading dim -> weight shape (3, hidden, hidden)
#   1: concatenate along the hidden dim     -> weight shape (hidden, 3*hidden)
weight_bias_format = 0
51
+
52
+
53
@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i")
def my_longformer_attention(
    g,
    input,
    weight,
    bias,
    mask,
    global_weight,
    global_bias,
    global_mask,
    num_heads,
    window,
):
    """Symbolic function used during torch.onnx export.

    Maps the custom op ``onnxruntime::LongformerAttention`` to the ONNX Runtime
    contrib operator ``com.microsoft::LongformerAttention``. The seven tensor
    arguments are forwarded as node inputs; ``num_heads`` and ``window`` become
    integer attributes.
    """
    tensor_inputs = (input, weight, bias, mask, global_weight, global_bias, global_mask)
    return g.op(
        "com.microsoft::LongformerAttention",
        *tensor_inputs,
        num_heads_i=num_heads,
        window_i=window,
    )
78
+
79
+
80
# namespace is onnxruntime which is registered in longformer_attention.cpp
# Register the symbolic function for all opset versions >= 9 so that export of
# the custom op produces the contrib operator node defined above.
register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9)

# TODO: search the directory to find correct output filename of "python setup.py install" when python version is not 3.8
# Load the compiled PyTorch extension that provides
# torch.ops.onnxruntime.LongformerAttention used by the patched forward below.
# NOTE(review): the path is hard-coded for a Linux/python 3.8 build output.
torch.ops.load_library(
    r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so"
)
87
+
88
+
89
def parse_arguments():
    """Parse command line arguments.

    Returns:
        args: Namespace
    """
    parser = argparse.ArgumentParser()

    model_names = ", ".join(PRETRAINED_LONGFORMER_MODELS.keys())
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="longformer-base-4096",
        help="Checkpoint directory or pre-trained model names in the list: " + model_names,
    )

    # Boolean flags default to False via set_defaults; store_true turns them on.
    parser.add_argument(
        "--export_padding",
        required=False,
        action="store_true",
        help="Export padding logic to ONNX graph. If not enabled, user need pad input so that sequence length is multiple of window size.",
    )
    parser.set_defaults(export_padding=False)

    parser.add_argument(
        "--no_merge_qkv",
        required=False,
        action="store_true",
        help="Stack the weights of q, k and v on dimension 0 instead of dimension 1.",
    )
    parser.set_defaults(no_merge_qkv=False)

    parser.add_argument(
        "-o",
        "--optimize_onnx",
        required=False,
        action="store_true",
        help="Use optimizer.py to optimize onnx model.",
    )
    parser.set_defaults(optimize_onnx=False)

    parser.add_argument(
        "-p",
        "--precision",
        required=False,
        type=str,
        default="fp32",
        choices=["fp32", "fp16"],
        help="Precision of model to run: fp32 for full precision, fp16 for mixed precision",
    )

    return parser.parse_args()
144
+
145
+
146
# Create a dummy input for ONNX export.
def get_dummy_inputs(config, export_padding, device):
    """Build dummy (input_ids, attention_mask, global_attention_mask) tensors.

    A sequence whose length is a multiple of the attention window produces no
    padding logic in the exported graph; adding one extra token forces the
    padding branch to be traced when export_padding is True.
    """
    window = config.attention_window[0]
    sequence_length = window + 1 if export_padding else window

    input_ids = torch.arange(sequence_length).unsqueeze(0).to(device)

    # Mask out only the last token.
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
    attention_mask[:, -1] = 0

    # Only the first token gets global attention.
    global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
    global_attention_mask[:, 0] = 1

    return input_ids, attention_mask, global_attention_mask
161
+
162
+
163
# A new function to replace LongformerSelfAttention.forward
# For transformers 4.0.0
def my_longformer_self_attention_forward_4(
    self,
    hidden_states,
    attention_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
):
    """Replacement forward that routes attention through the fused custom op.

    Stacks the q/k/v (and global q/k/v) projection weights and biases into the
    layout selected by the module-level ``weight_bias_format``, then calls
    torch.ops.onnxruntime.LongformerAttention so that ONNX export captures a
    single fused node instead of the original attention subgraph.

    Returns a one-element tuple ``(attn_output,)`` like the original forward.
    NOTE(review): the assertions below hard-code the dummy-input pattern
    (batch 1, only first token global, only last token masked) — this function
    is only valid during export with get_dummy_inputs.
    """
    global_mask = is_index_global_attn.int()
    # The following check is based on the dummy inputs (only the first token is global).
    assert (
        len(global_mask.shape) == 2
        and global_mask.shape[0] == 1
        and global_mask.count_nonzero().item() == 1
        and global_mask.tolist()[0][0] == 1
    )

    input_mask = is_index_masked.float()
    # TODO: The filtering value may be -10000.0 or -inf. Check the huggingface implementation.
    input_mask = input_mask.masked_fill(is_index_masked, -10000.0)
    # Yet another way to generate input_mask = torch.masked_fill(attention_mask, is_index_global_attn, 0.0)

    # TODO: add postprocessing of ONNX model to calculate based on graph input: input_mask = (attention_mask - 1) * 10000.0
    # TODO: add postprocessing of ONNX model to use graph input directly: global_mask = global_attention_mask

    # The following check is based on the dummy inputs (only the last token is masked).
    assert (
        len(input_mask.shape) == 2
        and input_mask.shape[0] == 1
        and input_mask.count_nonzero().item() == 1
        and input_mask.tolist()[0][-1] == -10000.0
    )

    # Stack transposed q/k/v weights; dim 0 gives (3, hidden, hidden),
    # dim 1 gives an interleaved layout that is then flattened below.
    weight = torch.stack(
        (
            self.query.weight.transpose(0, 1),
            self.key.weight.transpose(0, 1),
            self.value.weight.transpose(0, 1),
        ),
        dim=weight_bias_format,
    )

    if weight_bias_format == 1:
        # shape is (hidden_size, 3*hidden_size) for format 1, otherwise (3, hidden_size, hidden_size) by default
        weight = weight.reshape(self.embed_dim, 3 * self.embed_dim)

    # Same stacking for the global-attention projections.
    global_weight = torch.stack(
        (
            self.query_global.weight.transpose(0, 1),
            self.key_global.weight.transpose(0, 1),
            self.value_global.weight.transpose(0, 1),
        ),
        dim=weight_bias_format,
    )

    if weight_bias_format == 1:
        global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim)

    if weight_bias_format == 1:
        # Format 1: biases are packed as two 3*hidden vectors.
        bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0)
        bias = bias.reshape(3 * self.embed_dim)
        global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0)
        global_bias = global_bias.reshape(3 * self.embed_dim)
    else:
        # Format 0: five biases (q, k, v, global k, global v) go into `bias`
        # and the global query bias is passed separately — presumably the
        # layout the fused kernel expects for this format; verify against
        # longformer_attention.cpp before changing.
        bias = torch.stack(
            (self.query.bias, self.key.bias, self.value.bias, self.key_global.bias, self.value_global.bias), dim=0
        )
        bias = bias.reshape(5 * self.embed_dim)
        global_bias = self.query_global.bias
        global_bias = global_bias.reshape(1 * self.embed_dim)

    attn_output = torch.ops.onnxruntime.LongformerAttention(
        hidden_states,
        weight,
        bias,
        input_mask,
        global_weight,
        global_bias,
        global_mask,
        self.num_heads,
        self.one_sided_attn_window_size,
    )

    assert attn_output.size() == hidden_states.size(), "Unexpected size"

    outputs = (attn_output,)
    return outputs
252
+
253
+
254
# For transformers 4.3.0
def my_longformer_self_attention_forward_4_3(
    self,
    hidden_states,
    attention_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
    output_attentions=False,
):
    """Adapter for the transformers 4.3.0 forward signature.

    Rejects the unsupported ``output_attentions`` option, then delegates to the
    4.0.0 implementation that calls the fused operator.
    """
    # The fused operator cannot return attention probabilities.
    assert output_attentions is False
    return my_longformer_self_attention_forward_4(
        self,
        hidden_states,
        attention_mask=attention_mask,
        is_index_masked=is_index_masked,
        is_index_global_attn=is_index_global_attn,
        is_global_attn=is_global_attn,
    )
273
+
274
+
275
# For transformers 4.3.2 or later versions
def my_longformer_self_attention_forward_4_3_2(
    self,
    hidden_states,
    attention_mask=None,
    layer_head_mask=None,
    is_index_masked=None,
    is_index_global_attn=None,
    is_global_attn=None,
    output_attentions=False,
):
    """LongformerSelfAttention.forward replacement for transformers >= 4.3.2.

    Neither per-layer head masking nor attention-weight output is supported
    by the fused ONNX attention path, so both must stay at their defaults;
    the remaining arguments are forwarded to the shared transformers-4.x
    implementation.
    """
    assert output_attentions is False
    assert layer_head_mask is None
    forwarded = (hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn)
    return my_longformer_self_attention_forward_4(self, *forwarded)
296
+
297
+
298
def export_longformer(model: LongformerModel, onnx_model_path: str, export_padding: bool):
    """Export longformer model to ONNX

    Args:
        model (LongformerModel): longformer model
        onnx_model_path (str): output onnx path
        export_padding (bool): whether export padding logic to ONNX so that input string can be any length.

    Raises:
        RuntimeError: This tool requires transformers 4.0.0 or later.
        RuntimeError: LongformerSelfAttention.forward arguments are different.
    """
    # Fail fast: check the installed transformers version before spending time
    # on a dry-run inference that would be wasted when the version is unsupported.
    if version.parse(transformers.__version__) < version.parse("4.0.0"):
        raise RuntimeError("This tool requires transformers 4.0.0 or later.")

    input_ids, attention_mask, global_attention_mask = get_dummy_inputs(
        model.config, export_padding, device=torch.device("cpu")
    )

    # Dry run with the original implementation to make sure the model itself works.
    _ = model(
        input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
    )

    # Here we replace LongformerSelfAttention.forward using our implementation for exporting ONNX model.
    # The forward signature differs across transformers versions, so dispatch on the argument names.
    key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args)
    args_to_func = {
        "self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2,
        "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3,
        "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4,
    }

    if key not in args_to_func:
        print(
            "Current arguments",
            inspect.getfullargspec(LongformerSelfAttention.forward).args,
        )
        raise RuntimeError(
            "LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)."
        )

    # Store for restoring later
    original_forward = LongformerSelfAttention.forward

    LongformerSelfAttention.forward = args_to_func[key]

    try:
        example_inputs = (input_ids, attention_mask, global_attention_mask)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        torch_onnx_export(
            model,
            example_inputs,
            onnx_model_path,
            opset_version=12,
            input_names=["input_ids", "attention_mask", "global_attention_mask"],
            output_names=["last_state", "pooler"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "global_attention_mask": {0: "batch_size", 1: "sequence_length"},
                "last_state": {0: "batch_size", 1: "sequence_length"},
                "pooler": {0: "batch_size", 1: "sequence_length"},
            },
            custom_opsets={"com.microsoft": 1},
        )
        print(f"ONNX model exported to {onnx_model_path}")
    finally:
        # Restore the original implementation even when the export raises,
        # so the monkey-patched class does not leak into the caller's process.
        LongformerSelfAttention.forward = original_forward
369
+
370
+
371
def optimize_longformer(onnx_model_path: str, fp32_model_path: str, fp16_model_path=None):
    """Optimize longformer onnx model

    Args:
        onnx_model_path (str): path of original ONNX model.
        fp32_model_path (str): path of optimized fp32 model.
        fp16_model_path (str, optional): path of optimized fp16 model. Defaults to None.
    """
    onnx_model = load_model(onnx_model_path, format=None, load_external_data=True)
    bert_optimizer = BertOnnxModel(onnx_model)
    bert_optimizer.optimize()

    # Both outputs are small enough to be stored inside a single model file.
    use_external_data_format = False

    if fp32_model_path:
        bert_optimizer.save_model_to_file(fp32_model_path, use_external_data_format)
        print(f"optimized fp32 model saved to {fp32_model_path}")

    if fp16_model_path:
        # Convert weights in place, keeping graph inputs/outputs as float32.
        bert_optimizer.convert_float_to_float16(keep_io_types=True)
        bert_optimizer.save_model_to_file(fp16_model_path, use_external_data_format)
        print(f"optimized fp16 model saved to {fp16_model_path}")
392
+
393
+
394
def main(args):
    """Export the requested Longformer model to ONNX and optionally optimize it."""
    model_name = args.model
    onnx_model_path = f"{model_name}.onnx"

    # Format 1 packs QKV weights/biases into merged tensors; format 0 keeps them stacked.
    global weight_bias_format  # noqa: PLW0603
    weight_bias_format = 0 if args.no_merge_qkv else 1

    model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name])

    export_longformer(model, onnx_model_path, args.export_padding)

    if args.optimize_onnx or args.precision != "fp32":
        path_prefix = f"{model_name}_f{weight_bias_format}"
        fp32_model_path = f"{path_prefix}_fp32.onnx"
        fp16_model_path = f"{path_prefix}_fp16.onnx" if args.precision == "fp16" else None
        optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path)
409
+
410
+
411
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run the export pipeline.
    main(parse_arguments())