onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,343 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ import os
+ import time
+
+ import numpy as np
+ import packaging.version as pv
+ import torch
+ from benchmark_helper import setup_logger
+ from dist_settings import get_rank, get_size
+ from llama_inputs import (
+     add_io_bindings_as_ortvalues,
+     convert_inputs_for_ort,
+     get_merged_sample_with_past_kv_inputs,
+     get_sample_inputs,
+     get_sample_with_past_kv_inputs,
+     verify_ort_inputs,
+ )
+ from llama_torch import setup_torch_model
+ from models.torch_export_patches.cache_helper import make_dynamic_cache
+ from transformers import AutoConfig
+ from transformers import __version__ as transformers_version
+ from transformers.cache_utils import DynamicCache
+
+ import onnxruntime as ort
+
+ logger = logging.getLogger("")
+
+
+ def get_sequence_lengths(args: argparse.Namespace, config: AutoConfig):
+     past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8)
+     max_sequence_length = config.max_position_embeddings
+     return past_sequence_length, curr_sequence_length, max_sequence_length
+
+
+ def get_inputs(args: argparse.Namespace, config: AutoConfig):
+     # Dummy values for parity
+     world_size = get_size()
+     batch_size = 2
+     past_sequence_length, sequence_length, max_sequence_length = get_sequence_lengths(args, config)
+
+     if args.merged:
+         inputs = get_merged_sample_with_past_kv_inputs(
+             config,
+             args.device,
+             batch_size,
+             seq_len=sequence_length,
+             past_seq_len=past_sequence_length,
+             max_seq_len=max_sequence_length,
+             use_fp16=args.use_fp16,
+             use_buffer_share=args.use_buffer_share,
+             return_dict=True,
+             world_size=world_size,
+         )
+     elif args.use_past_kv:
+         inputs = get_sample_with_past_kv_inputs(
+             config,
+             args.device,
+             batch_size,
+             sequence_length,
+             use_fp16=args.use_fp16,
+             return_dict=True,
+             world_size=world_size,
+         )
+     else:
+         inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True)
+
+     return inputs
+
+
+ def torch_deepcopy(value):
+     if isinstance(value, (int, float, str)):
+         return value
+     if isinstance(value, tuple):
+         return tuple(torch_deepcopy(v) for v in value)
+     if isinstance(value, list):
+         return [torch_deepcopy(v) for v in value]
+     if isinstance(value, set):
+         return {torch_deepcopy(v) for v in value}
+     if isinstance(value, dict):
+         return {k: torch_deepcopy(v) for k, v in value.items()}
+     if isinstance(value, np.ndarray):
+         return value.copy()
+     if hasattr(value, "clone"):
+         return value.clone()
+     if isinstance(value, DynamicCache):
+         return make_dynamic_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache, strict=False))))
+     # We should have code using serialization/deserialization, assuming a model
+     # cannot be exported without them.
+     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
+
+
+ def verify_parity(
+     args: argparse.Namespace,
+     location: str,
+     use_auth_token: bool,
+     kv_cache_ortvalues: dict,
+     pytorch_model: None | torch.nn.Module = None,
+     config: None | AutoConfig = None,
+ ):
+     # On a machine with less than 36 GB of GPU memory, the PyTorch model should be unloaded in time to free GPU memory for ORT.
+     py_model = pytorch_model
+     if py_model is None:
+         config, py_model = setup_torch_model(
+             args,
+             location,
+             use_auth_token,
+             torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
+             device=args.device,
+         )
+
+     inputs = get_inputs(args, config)
+
+     if "past_key_values" in inputs and pv.Version(transformers_version) >= pv.Version("4.45"):
+         # Using DynamicCache
+         inputs["past_key_values"] = make_dynamic_cache(inputs["past_key_values"])
+
+     # Run inference with PyTorch
+     # If there is a cache in the inputs, we need to make a copy because the model modifies it in place.
+     # DynamicCache inherits from torch.nn.Module in some versions of transformers,
+     # so we need to make the copy manually.
+     inputs_after_deepcopy = torch_deepcopy(inputs)
+     if args.execution_provider != "cpu":
+         torch.cuda.synchronize()
+     start_time = time.time()
+     pt_outputs = py_model(**inputs_after_deepcopy).logits.detach().cpu().numpy()
+     if args.execution_provider != "cpu":
+         torch.cuda.synchronize()
+     end_time = time.time()
+     logger.info(f"PyTorch took {end_time - start_time} s")
+
+     if args.small_gpu and py_model is not None:
+         del py_model
+         torch.cuda.empty_cache()
+
+     # Run inference with ORT
+     past_sequence_length, _, max_sequence_length = get_sequence_lengths(args, config)
+     inputs = convert_inputs_for_ort(
+         inputs,
+         use_buffer_share=args.use_buffer_share,
+         past_seq_len=past_sequence_length,
+         max_seq_len=max_sequence_length,
+     )
+
+     ep = f"{args.execution_provider.upper()}ExecutionProvider"
+     if ep == "CUDAExecutionProvider":
+         ep = (ep, {"device_id": args.rank})
+     ort_model = ort.InferenceSession(
+         args.onnx_model_path,
+         sess_options=ort.SessionOptions(),
+         providers=[ep],
+     )
+     inputs = verify_ort_inputs(ort_model, inputs)
+
+     # Add IO bindings for non-CPU execution providers
+     if args.execution_provider != "cpu":
+         io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
+             ort_model,
+             ort_inputs=inputs,
+             device=args.execution_provider,
+             device_id=int(args.rank),
+             use_buffer_share=args.use_buffer_share,
+             kv_cache_ortvalues=kv_cache_ortvalues,
+         )
+
+         io_binding.synchronize_inputs()
+         start_time = time.time()
+         ort_model.run_with_iobinding(io_binding)
+         io_binding.synchronize_outputs()
+         end_time = time.time()
+
+         ort_outputs = io_binding.copy_outputs_to_cpu()[0]  # Get logits
+         del ort_model
+
+     else:
+         start_time = time.time()
+         ort_outputs = ort_model.run(None, inputs)
+         end_time = time.time()
+
+         ort_outputs = ort_outputs[0]  # Get logits
+
+     logger.info(f"ONNX Runtime took {end_time - start_time} s")
+
+     # Compare PyTorch and ONNX Runtime accuracy
+     tol = 2e1 if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path else 5e-1
+     parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol)
+     logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}")
+     if not parity:
+         logger.warning(f"Max diff: {np.max(pt_outputs - ort_outputs)}")
+     return kv_cache_ortvalues
+
+
+ def get_args(argv: list[str]):
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-m",
+         "--model_name",
+         required=False,
+         help="Model name in Hugging Face",
+     )
+
+     parser.add_argument(
+         "-t",
+         "--torch_model_directory",
+         required=False,
+         default=os.path.join("."),
+         help="Path to folder containing PyTorch model and associated files if saved on disk",
+     )
+
+     parser.add_argument(
+         "-o",
+         "--onnx_model_path",
+         required=True,
+         default=os.path.join("."),
+         help="Path to ONNX model (with external data files saved in the same folder as the model)",
+     )
+
+     parser.add_argument(
+         "-ep",
+         "--execution_provider",
+         required=False,
+         default="cpu",
+         choices=["cpu", "cuda"],
+         help="Execution provider to verify parity with",
+     )
+
+     parser.add_argument(
+         "-v",
+         "--verbose",
+         action="store_true",
+         help="Print verbose logs",
+     )
+     parser.set_defaults(verbose=False)
+
+     parser.add_argument(
+         "-p",
+         "--use_past_kv",
+         action="store_true",
+         help="Use past key and past value as inputs to the model. Necessary for decoder_with_past_model.onnx models.",
+     )
+     parser.set_defaults(use_past_kv=False)
+
+     parser.add_argument(
+         "-g",
+         "--use_buffer_share",
+         action="store_true",
+         help="Use if model has GroupQueryAttention and you want to enable past-present buffer sharing",
+     )
+     parser.set_defaults(use_buffer_share=False)
+
+     parser.add_argument(
+         "--merged",
+         action="store_true",
+         help="Use merged model (i.e. decoder_merged_model.onnx).",
+     )
+     parser.set_defaults(merged=False)
+
+     parser.add_argument(
+         "-fp",
+         "--precision",
+         required=True,
+         choices=["int4", "int8", "fp16", "fp32"],
+         help="Precision of model",
+     )
+
+     parser.add_argument(
+         "--cache_dir",
+         required=False,
+         type=str,
+         default="./model_cache",
+         help="Model cache dir that overrides the default HF cache dir, to avoid flooding the /home dir",
+     )
+
+     # This argument is mainly used for CI, because the CI machine has at most 24 GB of GPU memory.
+     parser.add_argument(
+         "--small_gpu",
+         action="store_true",
+         help="Load the Llama model onto the GPU for each parity check when running on a machine with less than 36 GB of GPU memory.",
+     )
+
+     args = parser.parse_args() if argv == [] else parser.parse_args(argv)
+
+     # Use FP32 precision for FP32, INT8, and INT4 CPU models; use FP16 precision for FP16 and INT4 GPU models
+     args.precision = (
+         "fp32"
+         if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu")
+         else "fp16"
+     )
+     return args
+
+
+ def main(argv: list[str] = []):  # noqa: B006
+     args = get_args(argv)
+     setup_logger(args.verbose)
+     logger.info(f"Arguments: {args}")
+     rank = get_rank()
+
+     # Load model and config
+     setattr(args, "use_fp16", args.precision == "fp16")  # noqa: B010
+     args.rank = rank
+     setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{rank}")  # noqa: B010
+     setattr(args, "device", torch.device(args.device_name))  # noqa: B010
+     use_auth_token = args.torch_model_directory == os.path.join(".")
+     location = args.model_name if use_auth_token else args.torch_model_directory
+
+     kv_cache_ortvalues = {}
+     if not args.merged:
+         verify_parity(args, location, use_auth_token, kv_cache_ortvalues)
+     else:
+         config = llama = None
+         if not args.small_gpu:
+             config, llama = setup_torch_model(
+                 args,
+                 location,
+                 use_auth_token,
+                 torch_dtype=(torch.float16 if args.use_fp16 else torch.float32),
+                 device=args.device,
+             )
+
+         # Verify prompt processing in merged model (decoder_model.onnx)
+         args.use_past_kv = False
+         kv_cache_ortvalues = verify_parity(
+             args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config
+         )
+
+         # Verify token generation in merged model (decoder_with_past_model.onnx)
+         args.use_past_kv = True
+         verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config)
+
+
+ if __name__ == "__main__":
+     seed = 2
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     main()
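
The hunk above appears to correspond to onnxruntime/transformers/models/llama/llama_parity.py (its +343 line count matches entry 237 in the listing). Because main() accepts an argv list, the parity check can be driven programmatically as well as from the command line. A minimal sketch, assuming the llama scripts folder is on sys.path; the model id and ONNX path are hypothetical placeholders, not files shipped in this wheel:

    # Minimal sketch: run an FP32 CPU parity check programmatically.
    from llama_parity import main  # assumes the llama scripts folder is on sys.path

    main([
        "-m", "meta-llama/Llama-2-7b-hf",        # hypothetical Hugging Face model id
        "-o", "./llama2-7b/decoder_model.onnx",  # hypothetical ONNX model path
        "-ep", "cpu",
        "-fp", "fp32",
    ])
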
@@ -0,0 +1,47 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ import logging
+ import os
+
+ import torch
+ from dist_settings import barrier, get_rank, get_size
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ logger = logging.getLogger("")
+
+
+ def setup_torch_model(args, location, auth, torch_dtype=torch.float32, device=None):
+     world_size = get_size()
+     logger.info(f"world_size: {world_size}")
+     rank = get_rank()
+     barrier()
+
+     if not os.path.exists(args.cache_dir):
+         os.makedirs(args.cache_dir, exist_ok=True)
+
+     for i in range(world_size):
+         if i == rank % (world_size):
+             l_config = AutoConfig.from_pretrained(
+                 location, use_auth_token=auth, cache_dir=args.cache_dir, trust_remote_code=auth
+             )
+             l_config.use_cache = True
+             l_config._attn_implementation = "eager"  # "eager" uses LlamaAttention for attention layer
+             llama = AutoModelForCausalLM.from_pretrained(
+                 location,
+                 use_auth_token=auth,
+                 trust_remote_code=auth,
+                 config=l_config,
+                 torch_dtype=torch_dtype,
+                 cache_dir=args.cache_dir,
+             )
+             if world_size > 1:
+                 llama.parallel_model()
+             if device:
+                 llama.to(device)
+             llama.eval()
+             llama.requires_grad_(False)
+         barrier()
+     return l_config, llama
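
This hunk matches onnxruntime/transformers/models/llama/llama_torch.py (entry 238, +47). Note that setup_torch_model reads only cache_dir from its args argument, so a bare Namespace suffices for a single-rank run. A sketch under that assumption; the model id is a hypothetical placeholder:

    import argparse

    import torch
    from llama_torch import setup_torch_model  # assumes the llama scripts folder is on sys.path

    args = argparse.Namespace(cache_dir="./model_cache")  # the only attribute the helper reads
    config, model = setup_torch_model(
        args,
        location="meta-llama/Llama-2-7b-hf",  # hypothetical Hugging Face model id
        auth=True,                            # also passed through as trust_remote_code
        torch_dtype=torch.float32,
        device=torch.device("cpu"),
    )
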
@@ -0,0 +1,108 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ import argparse
+
+ import numpy as np
+ import torch
+ from benchmark_helper import create_onnxruntime_session
+ from datasets import load_dataset
+ from llama_inputs import get_position_ids
+ from torch.nn.functional import pad
+ from torch.utils.data import DataLoader
+ from transformers import LlamaTokenizer
+
+
+ class QuantKVDataLoader:
+     def __init__(self, args: argparse.Namespace, onnx_model_path: str = ""):
+         self.batch_size = 1
+         self.pad_max = args.pad_max
+
+         tokenizer = LlamaTokenizer.from_pretrained(args.original_model_name, use_auth_token=args.use_auth_token)
+         dataset = load_dataset(args.smooth_quant_dataset, split="train")
+         dataset = dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
+         dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
+
+         self.dataloader = DataLoader(
+             dataset,
+             batch_size=self.batch_size,
+             shuffle=False,
+             collate_fn=self.collate_batch,
+         )
+         self.decoder_model = (
+             create_onnxruntime_session(
+                 onnx_model_path,
+                 args.execution_provider != "cpu",  # use_gpu
+                 provider=args.execution_provider,
+                 verbose=args.verbose,
+             )
+             if onnx_model_path
+             else None
+         )
+
+     def collate_batch(self, batch):
+         input_ids_batched = []
+         attention_mask_batched = []
+         position_ids_batched = []
+         labels = []
+
+         for text in batch:
+             # Set inputs for model
+             input_ids = text["input_ids"]
+             attention_mask = torch.ones(len(input_ids))
+             position_ids = get_position_ids(attention_mask, use_past_kv=False)
+             label = len(input_ids) - 1
+
+             # Pad input data because all model inputs must have same shape
+             pad_len = self.pad_max - input_ids.shape[0]
+             input_ids = pad(input_ids, (0, pad_len), value=1)
+             attention_mask = pad(attention_mask, (0, pad_len), value=0)
+             position_ids = pad(position_ids, (0, pad_len), value=0)
+
+             input_ids_batched.append(input_ids)
+             attention_mask_batched.append(attention_mask)
+             position_ids_batched.append(position_ids)
+             labels.append(label)
+
+         input_ids_batched = torch.vstack(input_ids_batched)
+         attention_mask_batched = torch.vstack(attention_mask_batched)
+         position_ids_batched = torch.vstack(position_ids_batched)
+         labels = torch.tensor(labels)
+
+         return (input_ids_batched, attention_mask_batched, position_ids_batched), labels
+
+     def __iter__(self):
+         try:
+             for (input_ids, attention_mask, position_ids), labels in self.dataloader:
+                 # Inputs for decoder_model.onnx
+                 inputs = {
+                     "input_ids": input_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
+                     "attention_mask": attention_mask[:, :-1].detach().cpu().numpy().astype(np.int64),
+                     "position_ids": position_ids[:, :-1].detach().cpu().numpy().astype(np.int64),
+                 }
+                 label = labels.detach().cpu().numpy()
+
+                 if self.decoder_model is not None:
+                     # Run decoder_model.onnx to get inputs for decoder_with_past_model.onnx
+                     outputs = self.decoder_model.run(None, inputs)
+
+                     for i in range(int((len(outputs) - 1) / 2)):
+                         inputs[f"past_key_values.{i}.key"] = outputs[i * 2 + 1]
+                         inputs[f"past_key_values.{i}.value"] = outputs[i * 2 + 2]
+                     past_sequence_length = inputs["past_key_values.0.key"].shape[2]
+
+                     inputs["input_ids"] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype(np.int64)
+                     attn_mask_torch = torch.ones((self.batch_size, past_sequence_length + 1), dtype=torch.int64)
+                     inputs["attention_mask"] = attn_mask_torch.detach().cpu().numpy().astype(np.int64)
+                     inputs["position_ids"] = (
+                         get_position_ids(attn_mask_torch, use_past_kv=True).detach().cpu().numpy().astype(np.int64)
+                     )
+
+                 # Yield (inputs, label) tuple for Intel's Neural Compressor:
+                 # https://github.com/intel/neural-compressor/blob/d4baed9ea11614e1f0dc8a1f4f55b73ed3ed585c/neural_compressor/quantization.py#L55-L62
+                 yield (inputs, label)
+
+         except StopIteration:
+             return
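
This hunk matches onnxruntime/transformers/models/llama/quant_kv_dataloader.py (entry 239, +108). When onnx_model_path is left empty, no decoder session is created and the loader yields prompt-only inputs together with the label index Neural Compressor expects. A sketch with hypothetical Namespace values mirroring the attributes the class reads:

    import argparse

    from quant_kv_dataloader import QuantKVDataLoader  # assumes the llama scripts folder is on sys.path

    args = argparse.Namespace(
        pad_max=196,                                     # hypothetical padding length
        original_model_name="meta-llama/Llama-2-7b-hf",  # hypothetical Hugging Face model id
        use_auth_token=True,
        smooth_quant_dataset="NeelNanda/pile-10k",       # hypothetical calibration dataset id
        execution_provider="cpu",
        verbose=False,
    )

    loader = QuantKVDataLoader(args)  # empty onnx_model_path, so decoder_model is None
    for inputs, label in loader:
        print(sorted(inputs), label)  # ['attention_mask', 'input_ids', 'position_ids']
        break
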
@@ -0,0 +1,12 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import os.path
+ import sys
+
+ sys.path.append(os.path.dirname(__file__))
+
+ transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+ if transformers_dir not in sys.path:
+     sys.path.append(transformers_dir)