onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
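For orientation, the wheel listed below bundles the DirectML execution provider (onnxruntime/capi/DirectML.dll and the onnxruntime_pybind11_state.pyd bindings). The following is a minimal usage sketch added for illustration, not taken from the package; the model path "model.onnx" and the input name "input" are placeholders:

    import onnxruntime as ort

    # Prefer the DirectML provider and fall back to CPU if it is unavailable.
    session = ort.InferenceSession(
        "model.onnx",
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )
    print(session.get_providers())
    # outputs = session.run(None, {"input": input_array})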
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,347 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ # Generate test data for a longformer model, so that we can use onnxruntime_perf_test.exe to evaluate the inference latency.
+
+ import argparse
+ import os
+ import random
+ from pathlib import Path
+
+ import numpy as np
+ from bert_test_data import fake_input_ids_data, fake_input_mask_data, output_test_data
+ from onnx import ModelProto, TensorProto
+ from onnx_model import OnnxModel
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--model", required=True, type=str, help="bert onnx model path.")
+
+     parser.add_argument(
+         "--output_dir",
+         required=False,
+         type=str,
+         default=None,
+         help="output test data path. If not specified, a sub-directory under the model directory is used.",
+     )
+
+     parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input")
+
+     parser.add_argument(
+         "--sequence_length",
+         required=False,
+         type=int,
+         default=128,
+         help="maximum sequence length of input",
+     )
+
+     parser.add_argument(
+         "-a",
+         "--average_sequence_length",
+         default=-1,
+         type=int,
+         help="average sequence length excluding padding",
+     )
+
+     parser.add_argument(
+         "-r",
+         "--random_sequence_length",
+         required=False,
+         action="store_true",
+         help="use uniform random instead of fixed sequence length",
+     )
+     parser.set_defaults(random_sequence_length=False)
+
+     parser.add_argument(
+         "--global_tokens",
+         required=False,
+         type=int,
+         default=10,
+         help="number of global tokens",
+     )
+
+     parser.add_argument(
+         "--input_ids_name",
+         required=False,
+         type=str,
+         default=None,
+         help="input name for input ids",
+     )
+
+     parser.add_argument(
+         "--input_mask_name",
+         required=False,
+         type=str,
+         default=None,
+         help="input name for attention mask",
+     )
+
+     parser.add_argument(
+         "--global_mask_name",
+         required=False,
+         type=str,
+         default=None,
+         help="input name for global attention mask",
+     )
+
+     parser.add_argument(
+         "--samples",
+         required=False,
+         type=int,
+         default=1,
+         help="number of test cases to be generated",
+     )
+
+     parser.add_argument("--seed", required=False, type=int, default=3, help="random seed")
+
+     parser.add_argument(
+         "--verbose",
+         required=False,
+         action="store_true",
+         help="print verbose information",
+     )
+     parser.set_defaults(verbose=False)
+
+     args = parser.parse_args()
+     return args
+
+
+ def get_longformer_inputs(onnx_file, input_ids_name=None, input_mask_name=None, global_mask_name=None):
+     """
+     Get graph inputs for longformer model.
+     """
+     model = ModelProto()
+     with open(onnx_file, "rb") as f:
+         model.ParseFromString(f.read())
+
+     onnx_model = OnnxModel(model)
+     graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
+
+     if input_ids_name is not None:
+         input_ids = onnx_model.find_graph_input(input_ids_name)
+         if input_ids is None:
+             raise ValueError(f"Graph does not have input named {input_ids_name}")
+
+         input_mask = None
+         if input_mask_name:
+             input_mask = onnx_model.find_graph_input(input_mask_name)
+             if input_mask is None:
+                 raise ValueError(f"Graph does not have input named {input_mask_name}")
+
+         global_mask = None
+         if global_mask_name:
+             global_mask = onnx_model.find_graph_input(global_mask_name)
+             if global_mask is None:
+                 raise ValueError(f"Graph does not have input named {global_mask_name}")
+
+         expected_inputs = 1 + (1 if input_mask else 0) + (1 if global_mask else 0)
+         if len(graph_inputs) != expected_inputs:
+             raise ValueError(f"Expect the graph to have {expected_inputs} inputs. Got {len(graph_inputs)}")
+
+         return input_ids, input_mask, global_mask
+
+     if len(graph_inputs) != 3:
+         raise ValueError(f"Expect the graph to have 3 inputs. Got {len(graph_inputs)}")
+
+     # Try to guess the inputs based on naming.
+     input_ids = None
+     input_mask = None
+     global_mask = None
+     for input in graph_inputs:
+         input_name_lower = input.name.lower()
+         if "global" in input_name_lower:
+             global_mask = input
+         elif "mask" in input_name_lower:
+             input_mask = input
+         else:
+             input_ids = input
+
+     if input_ids and input_mask and global_mask:
+         return input_ids, input_mask, global_mask
+
+     raise ValueError("Failed to assign all 3 inputs. You might try renaming the graph inputs.")
+
+
+ def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_tokens):
+     """
+     Fake data based on the graph input of the global attention mask.
+     Args:
+         global_mask (TensorProto): graph input of the global attention mask tensor.
+     Returns:
+         data (np.array): the data for input tensor
+     """
+     data_type = global_mask.type.tensor_type.elem_type
+     assert data_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64]
+
+     if num_global_tokens > 0:
+         assert num_global_tokens <= sequence_length
+         data = np.zeros((batch_size, sequence_length), dtype=np.int32)
+         temp = np.ones((batch_size, num_global_tokens), dtype=np.int32)
+         data[: temp.shape[0], : temp.shape[1]] = temp
+     else:
+         data = np.zeros((batch_size, sequence_length), dtype=np.int32)
+
+     if data_type == TensorProto.FLOAT:
+         data = np.float32(data)
+     elif data_type == TensorProto.INT64:
+         data = np.int64(data)
+
+     return data
+
+
+ def fake_test_data(
+     batch_size,
+     sequence_length,
+     test_cases,
+     dictionary_size,
+     verbose,
+     random_seed,
+     input_ids,
+     input_mask,
+     global_mask,
+     num_global_tokens,
+     average_sequence_length,
+     random_sequence_length,
+ ):
+     """
+     Generate fake input data for test.
+     """
+     assert input_ids is not None
+
+     np.random.seed(random_seed)
+     random.seed(random_seed)
+
+     all_inputs = []
+     for _ in range(test_cases):
+         input_1 = fake_input_ids_data(input_ids, batch_size, sequence_length, dictionary_size)
+         inputs = {input_ids.name: input_1}
+
+         if input_mask:
+             inputs[input_mask.name] = fake_input_mask_data(
+                 input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length
+             )
+
+         if global_mask:
+             inputs[global_mask.name] = fake_global_mask_data(
+                 global_mask, batch_size, sequence_length, num_global_tokens
+             )
+
+         if verbose and len(all_inputs) == 0:
+             print("Example inputs", inputs)
+         all_inputs.append(inputs)
+
+     return all_inputs
+
+
+ def generate_test_data(
+     batch_size,
+     sequence_length,
+     test_cases,
+     seed,
+     verbose,
+     input_ids,
+     input_mask,
+     global_mask,
+     num_global_tokens,
+     average_sequence_length,
+     random_sequence_length,
+ ):
+     dictionary_size = 10000
+     all_inputs = fake_test_data(
+         batch_size,
+         sequence_length,
+         test_cases,
+         dictionary_size,
+         verbose,
+         seed,
+         input_ids,
+         input_mask,
+         global_mask,
+         num_global_tokens,
+         average_sequence_length,
+         random_sequence_length,
+     )
+     if len(all_inputs) != test_cases:
+         print("Failed to create test data for test.")
+     return all_inputs
+
+
+ def create_longformer_test_data(
+     model,
+     output_dir,
+     batch_size,
+     sequence_length,
+     test_cases,
+     seed,
+     verbose,
+     input_ids_name,
+     input_mask_name,
+     global_mask_name,
+     num_global_tokens,
+     average_sequence_length,
+     random_sequence_length,
+ ):
+     input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name, global_mask_name)
+     all_inputs = generate_test_data(
+         batch_size,
+         sequence_length,
+         test_cases,
+         seed,
+         verbose,
+         input_ids,
+         input_mask,
+         global_mask,
+         num_global_tokens,
+         average_sequence_length,
+         random_sequence_length,
+     )
+
+     for i, inputs in enumerate(all_inputs):
+         output_test_data(output_dir, i, inputs)
+
+
+ def main():
+     args = parse_arguments()
+
+     output_dir = args.output_dir
+     if output_dir is None:
+         # Default output directory is a sub-directory under the directory of model.
+         output_dir = os.path.join(
+             Path(args.model).parent,
+             f"b{args.batch_size}_s{args.sequence_length}_g{args.global_tokens}",
+         )
+
+     if output_dir is not None:
+         # create the output directory if it does not exist
+         path = Path(output_dir)
+         path.mkdir(parents=True, exist_ok=True)
+     else:
+         print("Directory existed. test data files will be overwritten.")
+
+     if args.average_sequence_length <= 0:
+         args.average_sequence_length = args.sequence_length
+
+     create_longformer_test_data(
+         args.model,
+         output_dir,
+         args.batch_size,
+         args.sequence_length,
+         args.samples,
+         args.seed,
+         args.verbose,
+         args.input_ids_name,
+         args.input_mask_name,
+         args.global_mask_name,
+         args.global_tokens,
+         args.average_sequence_length,
+         args.random_sequence_length,
+     )
+
+     print("Test data is saved to directory:", output_dir)
+
+
+ if __name__ == "__main__":
+     main()
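The script above is normally driven from the command line via argparse, but its helpers can also be called directly. The following sketch is added for illustration and is not part of the package; it assumes it runs from the longformer model directory (so generate_test_data imports cleanly), and "longformer.onnx" plus the output folder name are placeholders. Per the header comment, the generated test_data_set_* folders are intended for onnxruntime_perf_test:

    from pathlib import Path
    from generate_test_data import create_longformer_test_data

    # The parent output directory must exist; each test case gets its own subfolder.
    Path("b1_s512_g10").mkdir(parents=True, exist_ok=True)

    create_longformer_test_data(
        model="longformer.onnx",
        output_dir="b1_s512_g10",
        batch_size=1,
        sequence_length=512,
        test_cases=1,
        seed=3,
        verbose=True,
        input_ids_name=None,   # let the script guess inputs from graph names
        input_mask_name=None,
        global_mask_name=None,
        num_global_tokens=10,
        average_sequence_length=512,
        random_sequence_length=False,
    )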
@@ -0,0 +1,76 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ # This script helps create dummy inputs for the Longformer model.
+
+ import logging
+
+ import numpy
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+ PRETRAINED_LONGFORMER_MODELS = {
+     "longformer-base-4096": "allenai/longformer-base-4096",
+     "longformer-large-4096": "allenai/longformer-large-4096",
+     "longformer-random-tiny": "patrickvonplaten/longformer-random-tiny",  # A tiny model for debugging
+ }
+
+
+ class LongformerInputs:
+     def __init__(self, input_ids, attention_mask, global_attention_mask):
+         self.input_ids: torch.LongTensor = input_ids
+         self.attention_mask: torch.FloatTensor | torch.HalfTensor = attention_mask
+         self.global_attention_mask: torch.FloatTensor | torch.HalfTensor = global_attention_mask
+
+     def to_list(self) -> list:
+         return [v for v in [self.input_ids, self.attention_mask, self.global_attention_mask] if v is not None]
+
+     def to_tuple(self) -> tuple:
+         return tuple(v for v in self.to_list())
+
+     def get_ort_inputs(self) -> dict:
+         return {
+             "input_ids": numpy.ascontiguousarray(self.input_ids.cpu().numpy()),
+             "attention_mask": numpy.ascontiguousarray(self.attention_mask.cpu().numpy()),
+             "global_attention_mask": numpy.ascontiguousarray(self.global_attention_mask.cpu().numpy()),
+         }
+
+
+ class LongformerHelper:
+     """A helper class for Longformer model conversion, inference and verification."""
+
+     @staticmethod
+     def get_dummy_inputs(
+         batch_size: int,
+         sequence_length: int,
+         num_global_tokens: int,
+         device: torch.device,
+         vocab_size: int = 100,
+     ) -> LongformerInputs:
+         """Create random inputs for Longformer model.
+         Returns torch tensors for input_ids, attention_mask and global_attention_mask.
+         """
+
+         input_ids = torch.randint(
+             low=0,
+             high=vocab_size - 1,
+             size=(batch_size, sequence_length),
+             dtype=torch.long,
+             device=device,
+         )
+         attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
+         global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device)
+         global_token_index = list(range(num_global_tokens))
+         global_attention_mask[:, global_token_index] = 1
+         return LongformerInputs(input_ids, attention_mask, global_attention_mask)
+
+     @staticmethod
+     def get_output_shapes(batch_size: int, sequence_length: int, hidden_size: int) -> dict[str, list[int]]:
+         """Returns a dictionary with output name as key, and shape as value."""
+         return {
+             "last_state": [batch_size, sequence_length, hidden_size],
+             "pooler": [batch_size, sequence_length],
+         }
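As a quick illustration of how these helpers fit together, here is a minimal sketch added for this write-up, not part of the package. It assumes it runs from the longformer model directory, that "longformer.onnx" is a placeholder path for an exported model, and that the model uses the default input names hard-coded in get_ort_inputs ("input_ids", "attention_mask", "global_attention_mask"):

    import onnxruntime
    import torch
    from longformer_helper import LongformerHelper

    # Build random CPU inputs and convert them to the numpy feed dict expected by ORT.
    dummy = LongformerHelper.get_dummy_inputs(
        batch_size=1, sequence_length=512, num_global_tokens=10, device=torch.device("cpu")
    )
    ort_inputs = dummy.get_ort_inputs()

    session = onnxruntime.InferenceSession("longformer.onnx", providers=["CPUExecutionProvider"])
    outputs = session.run(None, ort_inputs)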
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)