onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
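Beyond the file inventory, the wheel ships the DirectML execution provider (onnxruntime/capi/DirectML.dll) alongside the regular Python inference bindings and a few sample models under onnxruntime/datasets. Below is a minimal sketch of exercising that provider after installing this wheel; the use of get_example and the provider names are assumptions based on the files listed above, not part of the diff itself:

    import onnxruntime
    from onnxruntime.datasets import get_example  # helper that resolves paths to the bundled sample models

    # Prefer DirectML when present, fall back to CPU otherwise.
    providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
    session = onnxruntime.InferenceSession(get_example("sigmoid.onnx"), providers=providers)
    print(session.get_providers())  # shows which providers the session actually bound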
onnxruntime/transformers/benchmark.py
@@ -0,0 +1,944 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Copyright 2018 The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Benchmarking the inference of pretrained transformer models.
+    PyTorch/TorchScript benchmark is based on https://github.com/huggingface/transformers/blob/master/examples/benchmarks.py.
+    One difference is that random input_ids is generated in this benchmark.
+
+    For onnxruntime, this script will convert a pretrained model to ONNX, and optimize it when -o parameter is used.
+
+    Example commands:
+        Export all models to ONNX, optimize and validate them:
+            python benchmark.py -b 0 -o -v -i 1 2 3
+        Run OnnxRuntime on GPU for all models:
+            python benchmark.py -g
+        Run OnnxRuntime on GPU for all models with fp32 optimization:
+            python benchmark.py -g -o
+        Run OnnxRuntime on GPU with fp16 optimization:
+            python benchmark.py -g -o -p "fp16"
+        Run TorchScript on GPU for all models:
+            python benchmark.py -e torchscript -g
+        Run TorchScript on GPU for all models with fp16:
+            python benchmark.py -e torchscript -g -p "fp16"
+        Run ONNXRuntime and TorchScript on CPU for all models with quantization:
+            python benchmark.py -e torchscript onnxruntime -p "int8" -o
+        Run OnnxRuntime with the ROCM provider and graph optimization script:
+            python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm
+        Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support:
+            python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm
+
+    It is recommended to use run_benchmark.sh to launch benchmark.
+"""
+
+import argparse
+import logging
+import os
+import timeit
+from datetime import datetime
+
+import numpy
+import psutil
+from benchmark_helper import (
+    ConfigModifier,
+    OptimizerInfo,
+    Precision,
+    create_onnxruntime_session,
+    get_latency_result,
+    inference_ort,
+    inference_ort_with_io_binding,
+    output_details,
+    output_fusion_statistics,
+    output_summary,
+    setup_logger,
+)
+from fusion_options import FusionOptions
+from huggingface_models import MODEL_CLASSES, MODELS
+from onnx_exporter import (
+    create_onnxruntime_input,
+    export_onnx_model_from_pt,
+    export_onnx_model_from_tf,
+    load_pretrained_model,
+)
+from packaging import version
+from quantize_helper import QuantizeHelper
+
+logger = logging.getLogger("")
+
+cpu_count = psutil.cpu_count(logical=False)
+
+# Set OMP environment variable before importing onnxruntime or torch.
+if "OMP_NUM_THREADS" not in os.environ:
+    os.environ["OMP_NUM_THREADS"] = str(cpu_count)
+
+import torch  # noqa: E402
+from transformers import AutoConfig, AutoTokenizer, LxmertConfig  # noqa: E402
+
+
+def run_onnxruntime(
+    use_gpu,
+    provider,
+    model_names,
+    model_class,
+    config_modifier,
+    precision,
+    num_threads,
+    batch_sizes,
+    sequence_lengths,
+    repeat_times,
+    input_counts,
+    optimizer_info,
+    validate_onnx,
+    cache_dir,
+    onnx_dir,
+    verbose,
+    overwrite,
+    disable_ort_io_binding,
+    use_raw_attention_mask,
+    model_fusion_statistics,
+    model_source,
+    enable_arm64_bfloat16_fastmath_mlas_gemm,
+    args,
+):
+    import onnxruntime
+
+    results = []
+    if (
+        use_gpu
+        and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers())
+        and ("ROCMExecutionProvider" not in onnxruntime.get_available_providers())
+        and ("DmlExecutionProvider" not in onnxruntime.get_available_providers())
+    ):
+        logger.error(
+            "Please install onnxruntime-gpu or onnxruntime-directml package instead of onnxruntime, and use a machine with GPU for testing gpu performance."
+        )
+        return results
+
+    warm_up_repeat = 0
+    if provider == "tensorrt":
+        optimizer_info = OptimizerInfo.NOOPT
+        warm_up_repeat = 5
+        if "TensorrtExecutionProvider" not in onnxruntime.get_available_providers():
+            logger.error(
+                "Please install onnxruntime-gpu-tensorrt package, and use a machine with GPU for testing gpu performance."
+            )
+            return results
+
+    if optimizer_info == OptimizerInfo.NOOPT:
+        logger.warning(
+            f"OptimizerInfo is set to {optimizer_info}, graph optimizations specified in FusionOptions are not applied."
+        )
+
+    for model_name in model_names:
+        all_input_names = MODELS[model_name][0]
+        for num_inputs in input_counts:
+            if num_inputs > len(all_input_names):
+                break
+
+            input_names = all_input_names[:num_inputs]
+            args.model_type = MODELS[model_name][3]
+            fusion_options = FusionOptions.parse(args)
+
+            if "pt" in model_source:
+                with torch.no_grad():
+                    (
+                        onnx_model_file,
+                        is_valid_onnx_model,
+                        vocab_size,
+                        max_sequence_length,
+                    ) = export_onnx_model_from_pt(
+                        model_name,
+                        MODELS[model_name][1],
+                        MODELS[model_name][2],
+                        MODELS[model_name][3],
+                        model_class,
+                        config_modifier,
+                        cache_dir,
+                        onnx_dir,
+                        input_names,
+                        use_gpu,
+                        precision,
+                        optimizer_info,
+                        validate_onnx,
+                        use_raw_attention_mask,
+                        overwrite,
+                        model_fusion_statistics,
+                        fusion_options,
+                    )
+            if "tf" in model_source:
+                (
+                    onnx_model_file,
+                    is_valid_onnx_model,
+                    vocab_size,
+                    max_sequence_length,
+                ) = export_onnx_model_from_tf(
+                    model_name,
+                    MODELS[model_name][1],
+                    MODELS[model_name][2],
+                    MODELS[model_name][3],
+                    model_class,
+                    config_modifier,
+                    cache_dir,
+                    onnx_dir,
+                    input_names,
+                    use_gpu,
+                    precision,
+                    optimizer_info,
+                    validate_onnx,
+                    use_raw_attention_mask,
+                    overwrite,
+                    model_fusion_statistics,
+                    fusion_options,
+                )
+
+            if not is_valid_onnx_model:
+                continue
+
+            ort_session = create_onnxruntime_session(
+                onnx_model_file,
+                use_gpu,
+                provider,
+                enable_all_optimization=True,
+                num_threads=num_threads,
+                verbose=verbose,
+                enable_mlas_gemm_fastmath_arm64_bfloat16=enable_arm64_bfloat16_fastmath_mlas_gemm,
+            )
+            if ort_session is None:
+                continue
+
+            ort_output_names = [node_arg.name for node_arg in ort_session.get_outputs()]
+            output_buffers = []
+            device = "cuda" if use_gpu else "cpu"
+            config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+            max_last_state_size = numpy.prod(
+                [
+                    max(batch_sizes),
+                    max(sequence_lengths),
+                    max(vocab_size, config.hidden_size),
+                ]
+            )
+            max_pooler_size = numpy.prod([max(batch_sizes), config.hidden_size])
+            for batch_size in batch_sizes:
+                if batch_size <= 0:
+                    continue
+                for sequence_length in sequence_lengths:
+                    if max_sequence_length is not None and sequence_length > max_sequence_length:
+                        continue
+
+                    input_value_type = numpy.int64 if "pt" in model_source else numpy.int32
+                    ort_inputs = create_onnxruntime_input(
+                        vocab_size,
+                        batch_size,
+                        sequence_length,
+                        input_names,
+                        config,
+                        input_value_type,
+                    )
+                    result_template = {
+                        "engine": "onnxruntime",
+                        "version": onnxruntime.__version__,
+                        "providers": provider,
+                        "device": device,
+                        "optimizer": optimizer_info,
+                        "precision": precision,
+                        "io_binding": not disable_ort_io_binding,
+                        "model_name": model_name,
+                        "inputs": num_inputs,
+                        "threads": num_threads,
+                        "batch_size": batch_size,
+                        "sequence_length": sequence_length,
+                        "custom_layer_num": config_modifier.get_layer_num(),
+                        "datetime": str(datetime.now()),
+                    }
+
+                    if config.model_type in ["vit", "swin"]:
+                        logger.info(
+                            f"Run onnxruntime on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
+                        )
+                    else:
+                        logger.info(f"Run onnxruntime on {model_name} with input shape {[batch_size, sequence_length]}")
+
+                    if disable_ort_io_binding:
+                        result = inference_ort(
+                            ort_session,
+                            ort_inputs,
+                            result_template,
+                            repeat_times,
+                            batch_size,
+                            warm_up_repeat,
+                        )
+                    else:
+                        # Get output sizes from a dummy ort run
+                        ort_outputs = ort_session.run(ort_output_names, ort_inputs)
+                        output_buffer_max_sizes = [max_last_state_size]
+                        for i in range(len(ort_outputs)):
+                            if i == 2 and MODELS[model_name][3] == "gpt":
+                                # past state output max size
+                                output_buffer_max_sizes.append(max_pooler_size)
+                            else:
+                                output_buffer_max_sizes.append(max_last_state_size)
+
+                        data_type = numpy.longlong if "pt" in model_source else numpy.intc
+                        result = inference_ort_with_io_binding(
+                            ort_session,
+                            ort_inputs,
+                            result_template,
+                            repeat_times,
+                            ort_output_names,
+                            ort_outputs,
+                            output_buffers,
+                            output_buffer_max_sizes,
+                            batch_size,
+                            device,
+                            data_type,
+                            warm_up_repeat,
+                        )
+                    logger.info(result)
+                    results.append(result)
+
+    return results
+
+
+def run_pytorch(
+    use_gpu,
+    model_names,
+    model_class,
+    config_modifier,
+    precision,
+    num_threads,
+    batch_sizes,
+    sequence_lengths,
+    repeat_times,
+    torchscript,
+    torch2,
+    cache_dir,
+    verbose,
+):
+    results = []
+    if use_gpu and not torch.cuda.is_available():
+        logger.error("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
+        return results
+
+    torch.set_grad_enabled(False)
+
+    for model_name in model_names:
+        config = AutoConfig.from_pretrained(model_name, torchscript=torchscript, cache_dir=cache_dir)
+        config_modifier.modify(config)
+        model = load_pretrained_model(
+            model_name,
+            config=config,
+            cache_dir=cache_dir,
+            custom_model_class=model_class,
+        )
+
+        if config.model_type in ["vit", "swin"]:
+            # These models don't use sequence lengths, so just pick the first sequence length so that the summary still works
+            sequence_lengths = [sequence_lengths[0]]
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+
+            max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024)
+
+        logger.debug(f"Model {model}")
+        logger.debug(f"Number of parameters {model.num_parameters()}")
+
+        if precision == Precision.FLOAT16:
+            model.half()
+
+        device = torch.device("cuda:0" if use_gpu else "cpu")
+        model.to(device)
+
+        if precision == Precision.INT8:
+            model = QuantizeHelper.quantize_torch_model(model)
+
+        for batch_size in batch_sizes:
+            if batch_size <= 0:
+                continue
+
+            for sequence_length in sequence_lengths:
+                if config.model_type in ["vit", "swin"]:
+                    logger.info(
+                        f"Run PyTorch on {model_name} with input shape {[batch_size, 3, config.image_size, config.image_size]}"
+                    )
+                    input_ids = torch.randn(
+                        size=(batch_size, 3, config.image_size, config.image_size),
+                        dtype=torch.float16 if precision == Precision.FLOAT16 else torch.float32,
+                        device=device,
+                    )
+                else:
+                    if max_input_size is not None and sequence_length > max_input_size:
+                        continue
+
+                    logger.info(f"Run PyTorch on {model_name} with input shape {[batch_size, sequence_length]}")
+                    input_ids = torch.randint(
+                        low=0,
+                        high=config.vocab_size - 1,
+                        size=(batch_size, sequence_length),
+                        dtype=torch.long,
+                        device=device,
+                    )
+                try:
+                    inference = (
+                        torch.jit.trace(model, input_ids) if torchscript else torch.compile(model) if torch2 else model
+                    )
+                    inference(input_ids)
+
+                    runtimes = timeit.repeat(lambda: inference(input_ids), repeat=repeat_times, number=1)  # noqa: B023
+
+                    result = {
+                        "engine": "torchscript" if torchscript else "torch2" if torch2 else "torch",
+                        "version": torch.__version__,
+                        "providers": "NA",
+                        "device": "cuda" if use_gpu else "cpu",
+                        "optimizer": "",
+                        "precision": precision,
+                        "io_binding": "",
+                        "model_name": model_name,
+                        "inputs": 1,
+                        "threads": num_threads,
+                        "batch_size": batch_size,
+                        "sequence_length": sequence_length,
+                        "custom_layer_num": config_modifier.get_layer_num(),
+                        "datetime": str(datetime.now()),
+                    }
+                    result.update(get_latency_result(runtimes, batch_size))
+                    logger.info(result)
+                    results.append(result)
+                except RuntimeError as e:
+                    logger.exception(e)
+                    torch.cuda.empty_cache()
+
+    return results
+
+
+def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool):
+    from functools import wraps
+
+    import tensorflow as tf
+
+    def run_func(func):
+        @wraps(func)
+        def run_in_eager_mode(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        @wraps(func)
+        @tf.function(experimental_compile=use_xla)
+        def run_in_graph_mode(*args, **kwargs):
+            return func(*args, **kwargs)
+
+        if do_eager_mode is True:
+            assert (
+                use_xla is False
+            ), "Cannot run model in XLA, if `args.eager_mode` is set to `True`. Please set `args.eager_mode=False`."
+            return run_in_eager_mode
+        else:
+            return run_in_graph_mode
+
+    return run_func
+
+
+def run_tensorflow(
+    use_gpu,
+    model_names,
+    model_class,
+    config_modifier,
+    precision,
+    num_threads,
+    batch_sizes,
+    sequence_lengths,
+    repeat_times,
+    cache_dir,
+    verbose,
+):
+    results = []
+
+    import tensorflow as tf
+
+    tf.config.threading.set_intra_op_parallelism_threads(num_threads)
+
+    if not use_gpu:
+        tf.config.set_visible_devices([], "GPU")
+
+    if use_gpu and not tf.test.is_built_with_cuda():
+        logger.error("Please install Tensorflow-gpu, and use a machine with GPU for testing gpu performance.")
+        return results
+
+    if use_gpu:  # Restrict TensorFlow to only use the first GPU
+        physical_devices = tf.config.list_physical_devices("GPU")
+        try:
+            tf.config.set_visible_devices(physical_devices[0], "GPU")
+            tf.config.experimental.set_memory_growth(physical_devices[0], True)
+            tf.distribute.OneDeviceStrategy(device="/gpu:0")
+        except RuntimeError as e:
+            logger.exception(e)
+
+    if precision == Precision.FLOAT16 or precision == Precision.INT8:
+        raise NotImplementedError("Mixed precision is currently not supported.")
+
+    for model_name in model_names:
+        config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
+        config_modifier.modify(config)
+
+        model = load_pretrained_model(
+            model_name,
+            config=config,
+            cache_dir=cache_dir,
+            custom_model_class=model_class,
+            is_tf_model=True,
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+
+        max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024)
+
+        for batch_size in batch_sizes:
+            if batch_size <= 0:
+                continue
+
+            for sequence_length in sequence_lengths:
+                if max_input_size is not None and sequence_length > max_input_size:
+                    continue
+
+                logger.info(f"Run Tensorflow on {model_name} with input shape {[batch_size, sequence_length]}")
+
+                import random
+
+                rng = random.Random()
+                values = [rng.randint(0, config.vocab_size - 1) for i in range(batch_size * sequence_length)]
+                input_ids = tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32)
+
+                try:
+                    # Disable both for better inference perf
+                    @run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
+                    def encoder_forward():
+                        return model(input_ids, training=False)  # noqa: B023
+
+                    @run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
+                    def encoder_decoder_forward():
+                        return model(input_ids, decoder_input_ids=input_ids, training=False)  # noqa: B023
+
+                    @run_with_tf_optimizations(do_eager_mode=False, use_xla=False)
+                    def lxmert_forward():
+                        feats = tf.random.normal([1, 1, config.visual_feat_dim])  # noqa: B023
+                        pos = tf.random.normal([1, 1, config.visual_pos_dim])  # noqa: B023
+                        return model(  # noqa: B023
+                            input_ids,  # noqa: B023
+                            visual_feats=feats,
+                            visual_pos=pos,
+                            training=False,
+                        )
+
+                    inference = encoder_forward
+                    if config.is_encoder_decoder:
+                        inference = encoder_decoder_forward
+                    elif isinstance(config, LxmertConfig):
+                        inference = lxmert_forward
+
+                    inference()
+
+                    runtimes = timeit.repeat(lambda: inference(), repeat=repeat_times, number=1)  # noqa: B023
+
+                    result = {
+                        "engine": "tensorflow",
+                        "version": tf.__version__,
+                        "providers": "NA",
+                        "device": "cuda" if use_gpu else "cpu",
+                        "optimizer": "",
+                        "precision": precision,
+                        "io_binding": "",
+                        "model_name": model_name,
+                        "inputs": 1,
+                        "threads": num_threads,
+                        "batch_size": batch_size,
+                        "sequence_length": sequence_length,
+                        "custom_layer_num": config_modifier.get_layer_num(),
+                        "datetime": str(datetime.now()),
+                    }
+                    result.update(get_latency_result(runtimes, batch_size))
+                    logger.info(result)
+                    results.append(result)
+                except RuntimeError as e:
+                    logger.exception(e)
+                    from numba import cuda
+
+                    device = cuda.get_current_device()
+                    device.reset()
+
+    return results
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-m",
+        "--models",
+        required=False,
+        nargs="+",
+        type=str,
+        default=["bert-base-cased", "roberta-base", "gpt2"],
+        choices=list(MODELS.keys()),
+        help="Pre-trained models in the list: " + ", ".join(MODELS.keys()),
+    )
+
+    parser.add_argument(
+        "--model_source",
+        required=False,
+        nargs=1,
+        type=str,
+        default="pt",
+        choices=["pt", "tf"],
+        help="Export onnx from pt or tf",
+    )
+
+    parser.add_argument(
+        "--model_class",
+        required=False,
+        type=str,
+        default=None,
+        choices=list(MODEL_CLASSES),
+        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
+    )
+
+    parser.add_argument(
+        "-e",
+        "--engines",
+        required=False,
+        nargs="+",
+        type=str,
+        default=["onnxruntime"],
+        choices=["onnxruntime", "torch", "torch2", "torchscript", "tensorflow"],
+        help="Engines to benchmark",
+    )
+
+    parser.add_argument(
+        "-c",
+        "--cache_dir",
+        required=False,
+        type=str,
+        default=os.path.join(".", "cache_models"),
+        help="Directory to cache pre-trained models",
+    )
+
+    parser.add_argument(
+        "--onnx_dir",
+        required=False,
+        type=str,
+        default=os.path.join(".", "onnx_models"),
+        help="Directory to store onnx models",
+    )
+
+    parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="Run on gpu device")
+
+    parser.add_argument(
+        "--provider",
+        required=False,
+        type=str,
+        default=None,
+        help="Execution provider to use",
+    )
+
+    parser.add_argument(
+        "-p",
+        "--precision",
+        type=Precision,
+        default=Precision.FLOAT32,
+        choices=list(Precision),
+        help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization",
+    )
+
+    parser.add_argument("--verbose", required=False, action="store_true", help="Print more information")
+
+    parser.add_argument(
+        "--overwrite",
+        required=False,
+        action="store_true",
+        help="Overwrite existing models",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--optimizer_info",
+        type=OptimizerInfo,
+        default=OptimizerInfo.BYSCRIPT,
+        choices=list(OptimizerInfo),
+        help="Optimizer info: Use optimizer.py to optimize onnx model as default. Can also choose from by_ort and no_opt",
+    )
+
+    parser.add_argument(
+        "-v",
+        "--validate_onnx",
+        required=False,
+        action="store_true",
+        help="Validate ONNX model",
+    )
+
+    parser.add_argument(
+        "-f",
+        "--fusion_csv",
+        required=False,
+        default=None,
+        help="CSV file for saving summary results of graph optimization.",
+    )
+
+    parser.add_argument(
+        "-d",
+        "--detail_csv",
+        required=False,
+        default=None,
+        help="CSV file for saving detail results.",
+    )
+
+    parser.add_argument(
+        "-r",
+        "--result_csv",
+        required=False,
+        default=None,
+        help="CSV file for saving summary results.",
+    )
+
+    parser.add_argument(
+        "-i",
+        "--input_counts",
+        required=False,
+        nargs="+",
+        default=[1],
+        type=int,
+        choices=[1, 2, 3],
+        help="Number of ONNX model inputs. Please use 1 for fair comparison with Torch or TorchScript.",
+    )
+
+    parser.add_argument(
+        "-t",
+        "--test_times",
+        required=False,
+        default=100,
+        type=int,
+        help="Number of repeat times to get average inference latency.",
+    )
+
+    parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])
+
+    parser.add_argument(
+        "-s",
+        "--sequence_lengths",
+        nargs="+",
+        type=int,
+        default=[4, 8, 16, 32, 64, 128, 256],
+    )
+
+    parser.add_argument(
+        "--disable_ort_io_binding",
+        required=False,
+        action="store_true",
+        help="Disable running ONNX Runtime with bound inputs and outputs.",
+    )
+    parser.set_defaults(disable_ort_io_binding=False)
+
+    parser.add_argument(
+        "-n",
+        "--num_threads",
+        required=False,
+        nargs="+",
+        type=int,
+        default=[0],
+        help="Threads to use",
+    )
+
+    parser.add_argument(
+        "--force_num_layers",
+        required=False,
+        type=int,
+        default=None,
+        help="Manually set the model's layer number",
+    )
+
+    parser.add_argument(
+        "--enable_arm64_bfloat16_fastmath_mlas_gemm",
+        required=False,
+        action="store_true",
+        help="Enable bfloat16 mlas gemm kernels on aarch64. Supported only for CPU EP ",
+    )
+    parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False)
+
+    FusionOptions.add_arguments(parser)
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_arguments()
+
+    setup_logger(args.verbose)
+
+    if args.precision == Precision.FLOAT16 and not args.use_gpu:
+        logger.error("fp16 is for GPU only")
+        return
+
+    if args.precision == Precision.INT8 and args.use_gpu and args.provider not in ["migraphx", "rocm"]:
+        logger.error("int8 is for CPU only")
+        return
+
+    if len(args.models) == 1 and MODELS[args.models[0]][3] in ["vit", "swin"]:
+        args.sequence_lengths = [""]
+
+    args.num_threads = sorted({cpu_count if x <= 0 else x for x in args.num_threads})
+
+    logger.info(f"Arguments: {args}")
+
+    if not os.path.exists(args.cache_dir):
+        try:
+            os.mkdir(args.cache_dir)
+        except OSError:
+            logger.error("Creation of the directory %s failed", args.cache_dir)
+
+    enable_torch = "torch" in args.engines
+    enable_torch2 = "torch2" in args.engines
+    enable_torchscript = "torchscript" in args.engines
+    enable_onnxruntime = "onnxruntime" in args.engines
+    enable_tensorflow = "tensorflow" in args.engines
+
+    if enable_torch2 and version.parse(torch.__version__) < version.parse("2.0.0"):
+        logger.error(f"PyTorch version must be >=2.0.0 and you are using {torch.__version__}")
+        return
+
+    config_modifier = ConfigModifier(args.force_num_layers)
+
+    results = []
+
+    for num_threads in args.num_threads:
+        torch.set_num_threads(num_threads)
+        logger.debug(torch.__config__.parallel_info())
+        if enable_torch or enable_torch2 or enable_torchscript:
+            if args.input_counts != [1]:
+                logger.warning("--input_counts is not implemented for torch or torchscript engine.")
+
+            if enable_torchscript:
+                results += run_pytorch(
+                    args.use_gpu,
+                    args.models,
+                    args.model_class,
+                    config_modifier,
+                    args.precision,
+                    num_threads,
+                    args.batch_sizes,
+                    args.sequence_lengths,
+                    args.test_times,
+                    True,
+                    False,
+                    args.cache_dir,
+                    args.verbose,
+                )
+
+            if enable_torch:
+                results += run_pytorch(
+                    args.use_gpu,
+                    args.models,
+                    args.model_class,
+                    config_modifier,
+                    args.precision,
+                    num_threads,
+                    args.batch_sizes,
+                    args.sequence_lengths,
+                    args.test_times,
+                    False,
+                    False,
+                    args.cache_dir,
+                    args.verbose,
+                )
+
+            if enable_torch2:
+                results += run_pytorch(
+                    args.use_gpu,
+                    args.models,
+                    args.model_class,
+                    config_modifier,
+                    args.precision,
+                    num_threads,
+                    args.batch_sizes,
+                    args.sequence_lengths,
+                    args.test_times,
+                    False,
+                    True,
+                    args.cache_dir,
+                    args.verbose,
+                )
+
+        if enable_tensorflow:
+            results += run_tensorflow(
+                args.use_gpu,
+                args.models,
+                args.model_class,
+                config_modifier,
+                args.precision,
+                num_threads,
+                args.batch_sizes,
+                args.sequence_lengths,
+                args.test_times,
+                args.cache_dir,
+                args.verbose,
+            )
+
+        model_fusion_statistics = {}
+        if enable_onnxruntime:
+            try:
+                use_raw_attention_mask = not args.use_mask_index
+                results += run_onnxruntime(
+                    args.use_gpu,
+                    args.provider,
+                    args.models,
+                    args.model_class,
+                    config_modifier,
+                    args.precision,
+                    num_threads,
+                    args.batch_sizes,
+                    args.sequence_lengths,
+                    args.test_times,
+                    args.input_counts,
+                    args.optimizer_info,
+                    args.validate_onnx,
+                    args.cache_dir,
+                    args.onnx_dir,
+                    args.verbose,
+                    args.overwrite,
+                    args.disable_ort_io_binding,
+                    use_raw_attention_mask,
+                    model_fusion_statistics,
+                    args.model_source,
+                    args.enable_arm64_bfloat16_fastmath_mlas_gemm,
+                    args,
+                )
+            except Exception:
+                logger.exception("Exception")
+
+    time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+    if model_fusion_statistics:
+        csv_filename = args.fusion_csv or f"benchmark_fusion_{time_stamp}.csv"
+        output_fusion_statistics(model_fusion_statistics, csv_filename)
+
+    if len(results) == 0:
+        if args.batch_sizes != [0]:
+            logger.warning("No results available.")
+        return
+
+    csv_filename = args.detail_csv or f"benchmark_detail_{time_stamp}.csv"
+    output_details(results, csv_filename)
+
+    csv_filename = args.result_csv or f"benchmark_summary_{time_stamp}.csv"
+    output_summary(results, csv_filename, args)
+
+
+if __name__ == "__main__":
+    main()
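As a quick sanity check before driving this script with -g against the DirectML build, the available providers can be inspected the same way run_onnxruntime does before it bails out; this is a minimal sketch and not part of the packaged file:

    import onnxruntime

    # run_onnxruntime() requires a CUDA, ROCm, or DirectML provider when use_gpu is set.
    available = onnxruntime.get_available_providers()
    print(available)
    assert "DmlExecutionProvider" in available, "install onnxruntime-directml on a DirectML-capable machine"
    # With the provider present, the docstring's GPU examples apply, e.g. python benchmark.py -g -o -p "fp16".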