onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)

This diff represents the content of a publicly available package version that has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
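
Of the files listed above, onnxruntime/capi/onnxruntime_inference_collection.py provides the InferenceSession Python API and onnxruntime/capi/DirectML.dll backs the DirectML execution provider that distinguishes this wheel from the stock onnxruntime package. As a minimal, hedged sketch of typical usage (the model path and input shape are illustrative assumptions, not part of the package):

    import numpy as np
    import onnxruntime as ort

    # Assumption: "model.onnx" is any local ONNX model; replace with a real path.
    session = ort.InferenceSession(
        "model.onnx", providers=["DmlExecutionProvider", "CPUExecutionProvider"]
    )
    input_name = session.get_inputs()[0].name
    # Assumption: the model takes one float32 input of shape (1, 3, 224, 224).
    dummy = np.random.rand(1, 3, 224, 224).astype(np.float32)
    outputs = session.run(None, {input_name: dummy})
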
onnxruntime/transformers/models/llama/benchmark.py
@@ -0,0 +1,703 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import argparse
+import datetime
+import gc
+import itertools
+import logging
+import os
+import sys
+import time
+
+import numpy as np
+import onnx
+import psutil
+import torch
+from benchmark_helper import measure_memory, setup_logger
+from dist_settings import get_rank, get_size
+from llama_inputs import (
+    add_io_bindings_as_ortvalues,
+    get_merged_sample_with_past_kv_inputs,
+    get_msft_sample_inputs,
+    get_sample_inputs,
+    get_sample_with_past_kv_inputs,
+    verify_ort_inputs,
+)
+from optimum.onnxruntime import ORTModelForCausalLM
+from torch.profiler import ProfilerActivity, profile, record_function
+from tqdm import trange
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+import onnxruntime as ort
+
+logger = logging.getLogger(__name__)
+
+
+# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two
+def get_ort_model_inputs_len(args, model):
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        return 0
+    if args.benchmark_type == "hf-ort":
+        try:
+            # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268)
+            return len(model.inputs_names)
+        except Exception:
+            # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54)
+            return len(model.decoder.input_names)
+    return len(model.get_inputs())
+
+
+def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
+    init_inputs, iter_inputs = None, None
+
+    # For past_present_share_buffer:
+    # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
+    # Set max_seq_len to config value for other models
+    max_seq_len = 2048 if args.benchmark_type == "ort-msft" else args.config.max_position_embeddings
+
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        init_inputs = get_sample_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            args.sequence_length,
+            return_dict=True,
+        )
+        iter_inputs = get_sample_with_past_kv_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            args.sequence_length,
+            use_fp16=args.use_fp16,
+            return_dict=True,
+        )
+
+    elif args.benchmark_type in {"hf-ort"}:
+        if ort_model_inputs_len == 3:  # [input_ids, attention_mask, position_ids]
+            # Using split models in Optimum (e.g. created by Optimum export)
+            init_inputs = get_sample_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                args.sequence_length,
+                return_dict=True,
+            )
+            iter_inputs = get_sample_with_past_kv_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                args.sequence_length,
+                use_fp16=args.use_fp16,
+                return_dict=True,
+            )
+        else:
+            # Using merged model in Optimum (e.g. created by convert_to_onnx export)
+            init_inputs = get_merged_sample_with_past_kv_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                seq_len=args.sequence_length,
+                past_seq_len=0,
+                max_seq_len=max_seq_len,
+                use_fp16=args.use_fp16,
+                use_buffer_share=args.use_buffer_share,
+                engine="pt",
+                return_dict=True,
+            )
+            iter_inputs = get_merged_sample_with_past_kv_inputs(
+                args.config,
+                args.target_device,
+                args.batch_size,
+                seq_len=1,
+                past_seq_len=args.sequence_length,
+                max_seq_len=max_seq_len,
+                use_fp16=args.use_fp16,
+                use_buffer_share=args.use_buffer_share,
+                engine="pt",
+                return_dict=True,
+            )
+
+    elif args.benchmark_type == "ort-convert-to-onnx":
+        # Microsoft export from convert_to_onnx
+        init_inputs = get_merged_sample_with_past_kv_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            seq_len=args.sequence_length,
+            past_seq_len=0,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            engine="ort",
+            return_dict=True,
+            world_size=args.world_size,
+        )
+        iter_inputs = get_merged_sample_with_past_kv_inputs(
+            args.config,
+            args.target_device,
+            args.batch_size,
+            seq_len=1,
+            past_seq_len=args.sequence_length,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            engine="ort",
+            return_dict=True,
+            world_size=args.world_size,
+        )
+
+    elif args.benchmark_type == "ort-msft":
+        # Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+        split_kv = ort_model_inputs_len > 5  # original inputs: [x, attn_mask, k_cache, v_cache, pos]
+
+        init_inputs = get_msft_sample_inputs(
+            args.config,
+            args.batch_size,
+            past_seq_len=0,
+            seq_len=args.sequence_length,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            split_kv=split_kv,
+        )
+        iter_inputs = get_msft_sample_inputs(
+            args.config,
+            args.batch_size,
+            past_seq_len=args.sequence_length,
+            seq_len=1,
+            max_seq_len=max_seq_len,
+            use_fp16=args.use_fp16,
+            use_buffer_share=args.use_buffer_share,
+            split_kv=split_kv,
+        )
+
+    else:
+        raise Exception("Unable to auto-detect inputs for provided model")
+
+    return init_inputs, iter_inputs
+
+
+def get_model(args: argparse.Namespace):
+    model, sess_options = None, None
+    start_time, end_time = None, None
+
+    # There are multiple sources that the model could come from:
+    # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face
+    # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token
+    # 3) Benchmark LLaMA-2 from local download of model
+    # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx)
+    # 5) Benchmark LLaMA-2 from convert_to_onnx
+
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name
+        start_time = time.time()
+        model = AutoModelForCausalLM.from_pretrained(
+            source,
+            torch_dtype=torch.float16 if args.use_fp16 else torch.float32,
+            use_auth_token=args.auth,
+            trust_remote_code=args.auth,
+            use_cache=True,
+            cache_dir=args.cache_dir,
+        ).to(args.target_device)
+        end_time = time.time()
+
+        if args.benchmark_type == "hf-pt-compile":
+            model = torch.compile(model)
+
+    elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}:
+        sess_options = ort.SessionOptions()
+        sess_options.enable_profiling = args.profile
+        if args.verbose:
+            sess_options.log_verbosity_level = 1
+            sess_options.log_severity_level = 1
+
+    else:
+        raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+    if args.benchmark_type == "hf-ort":
+        # Optimum export or convert_to_onnx.py export
+        provider = args.execution_provider[0] if type(args.execution_provider) is tuple else args.execution_provider
+        provider_options = args.execution_provider[1] if type(args.execution_provider) is tuple else None
+
+        decoder_file_name = None
+        decoder_with_past_file_name = None
+        for filename in os.listdir(args.hf_ort_dir_path):
+            if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename:
+                continue
+            if "decoder_model" in filename or filename == "model.onnx":
+                decoder_file_name = filename
+            if "decoder_with_past_model" in filename:
+                decoder_with_past_file_name = filename
+            if "decoder_merged_model" in filename:
+                decoder_file_name = filename
+                decoder_with_past_file_name = filename
+
+        start_time = time.time()
+        model = ORTModelForCausalLM.from_pretrained(
+            args.hf_ort_dir_path,
+            decoder_file_name=decoder_file_name,
+            decoder_with_past_file_name=decoder_with_past_file_name,
+            use_auth_token=args.auth,
+            trust_remote_code=args.auth,
+            use_io_binding=True,  # Large perf gain even for cpu due to avoiding output copy.
+            use_merged=(True if decoder_file_name == "model.onnx" else None),
+            provider=provider,
+            provider_options=provider_options,
+            session_options=sess_options,
+        )
+        end_time = time.time()
+
+    if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
+        # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx
+        logger.info(f"Loading model from {args.ort_model_path.format(args.rank)}")
+        start_time = time.time()
+        model = ort.InferenceSession(
+            args.ort_model_path.format(args.rank),
+            sess_options,
+            providers=[args.execution_provider],
+        )
+        end_time = time.time()
+
+    logger.info(f"Loaded model in {end_time - start_time} s")
+    return model
+
+
+def time_fn(args, fn, inputs):
+    # Warm up
+    warmup_range = (
+        range(args.warmup_runs)
+        if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
+        else trange(args.warmup_runs, file=sys.stdout, desc="Warm up")
+    )
+
+    if args.verbose:
+        outputs = fn(inputs)
+        logger.info(outputs)
+
+    input_sync = lambda *kwargs: (  # noqa: E731
+        args.io_binding.synchronize_inputs()
+        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
+        else lambda *kwargs: (
+            torch.cuda.synchronize()
+            if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
+            else lambda *kwargs: None
+        )
+    )  # no-op function
+
+    output_sync = lambda *kwargs: (  # noqa: E731
+        args.io_binding.synchronize_outputs()
+        if args.device != "cpu" and args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}  # ORT synchronize
+        else lambda *kwargs: (
+            torch.cuda.synchronize()
+            if args.device != "cpu" and torch.cuda.is_available()  # PyTorch synchronize
+            else lambda *kwargs: None
+        )
+    )  # no-op function
+
+    for _ in warmup_range:
+        input_sync()
+        fn(inputs)
+        output_sync()
+
+    # Benchmark
+    total_time = 0
+    bench_range = (
+        range(args.num_runs)
+        if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}
+        else trange(args.num_runs, file=sys.stdout, desc="Benchmark")
+    )
+    for _ in bench_range:
+        input_sync()
+        start_time = time.time()
+
+        fn(inputs)
+
+        output_sync()
+        end_time = time.time()
+
+        total_time += end_time - start_time
+
+    # Newline print after trange in order to print metrics on new lines without progress bar on same line
+    if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}:
+        logger.info("")
+
+    latency = total_time / args.num_runs
+    throughput = args.batch_size / latency
+
+    if args.rank == 0:
+        logger.info(f"Batch Size: {args.batch_size}")
+        logger.info(f"Sequence Length: {args.sequence_length}")
+        logger.info(f"Latency: {latency} s")
+        logger.info(f"Throughput: {throughput} tps")
+    return
+
+
+def profile_fn(args, fn, inputs, inputs_type):
+    # Filename prefix format:
+    # "b<batch-size>_s<sequence-length>_<benchmark-type>-<precision>-<device>_<inference-step>_<inputs-type>_<current-time>"
+    prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"
+    filename = None
+
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
+        # Profile PyTorch kernels
+        with profile(  # noqa: SIM117
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
+        ) as prof:
+            with record_function("model_inference"):
+                fn(inputs)
+        prof_data = prof.key_averages(group_by_stack_n=5).table(sort_by=args.pt_filter_by, row_limit=args.pt_num_rows)
+
+        filename = os.path.join(args.log_folder, f"{prefix}.log")
+        with open(filename, "w") as f:
+            f.write(prof_data)
+
+    else:
+        # Profile ORT kernels
+        fn(inputs)
+
+        # Set new log name for ORT profile log generated
+        filename = f"{prefix}.json"
+
+    return filename
+
+
+def measure_fn(args, fn, inputs):
+    # Measure CPU usage
+    pid = os.getpid()
+    process = psutil.Process(pid)
+    process.cpu_percent(interval=0.1)
+
+    fn(inputs)
+    if args.rank == 0:
+        logger.info(f"CPU usage: {process.cpu_percent(interval=None) / psutil.cpu_count(logical=False)}%")
+
+    # Measure memory usage
+    gc.collect()
+    torch.cuda.empty_cache()
+    measure_memory(is_gpu=(args.device != "cpu"), func=lambda: fn(inputs))
+
+    # Flush output so memory usage is printed
+    sys.stdout.flush()
+
+
+def run_hf_inference(args, init_inputs, iter_inputs, model):
+    # Inference steps to measure
+    def get_logits(inputs):
+        # Inference pass without decoding
+        outputs = model(**inputs)
+        return outputs
+
+    # Examples of other inference steps that can be measured:
+    # To use, uncomment the function and assign it to `generate_fn`
+
+    # def get_pred_ids(inputs):
+    #     # Inference pass with predicted token ids generation
+    #     predicted_ids = model.generate(**inputs)
+    #     return predicted_ids
+
+    # def gen_and_dec(inputs):
+    #     # Inference pass with generation and decoding
+    #     predicted_ids = get_pred_ids(inputs)
+    #     transcription = []
+    #     for bs in range(args.batch_size):
+    #         for rs in range(args.num_return_sequences):
+    #             transcription.append(
+    #                 args.tokenizer.batch_decode(
+    #                     predicted_ids[bs * args.num_return_sequences + rs], skip_special_tokens=True
+    #                 )[0]
+    #             )
+    #     return transcription
+
+    generate_fn = get_logits
+
+    if args.benchmark_type == "hf-pt-compile":
+        # Run forward pass once with each set of inputs to process through Dynamo
+        generate_fn(init_inputs)
+        generate_fn(iter_inputs)
+
+    if args.profile:
+        new_logname = profile_fn(args, generate_fn, init_inputs, "prompt")
+        if args.benchmark_type == "hf-ort":
+            # Turn profiling off to stop appending to log
+            old_logname = model.decoder.session.end_profiling()
+            logger.warning(f"Renaming {old_logname} to {new_logname}")
+            os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+        new_logname = profile_fn(args, generate_fn, iter_inputs, "token")
+        if args.benchmark_type == "hf-ort":
+            # Turn profiling off to stop appending to log
+            old_logname = model.decoder_with_past.session.end_profiling()
+            logger.warning(f"Renaming {old_logname} to {new_logname}")
+            os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+        return
+
+    # PyTorch evaluations
+    logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
+    time_fn(args, generate_fn, init_inputs)
+    measure_fn(args, generate_fn, init_inputs)
+
+    logger.info("\nEvaluating `model(inputs)` step with past_key_values")
+    time_fn(args, generate_fn, iter_inputs)
+    measure_fn(args, generate_fn, iter_inputs)
+
+
+def run_ort_inference(args, init_inputs, iter_inputs, model):
+    def prepare_ort_inputs(inputs, kv_cache_ortvalues):
+        # Verify model inputs
+        inputs = verify_ort_inputs(model, inputs)
+
+        # Add IO bindings for non-CPU execution providers
+        if args.device != "cpu":
+            io_binding, kv_cache_ortvalues = add_io_bindings_as_ortvalues(
+                model, inputs, args.device, int(args.rank), args.use_buffer_share, kv_cache_ortvalues
+            )
+            setattr(args, "io_binding", io_binding)  # noqa: B010
+            return io_binding, kv_cache_ortvalues
+
+        return inputs, kv_cache_ortvalues
+
+    def with_io_binding(io_binding):
+        # Inference pass with IO binding
+        model.run_with_iobinding(io_binding)
+
+    def without_io_binding(inputs):
+        # Inference pass without IO binding
+        outputs = model.run(None, inputs)
+        return outputs
+
+    generate_fn = with_io_binding if args.device != "cpu" else without_io_binding
+    kv_cache_ortvalues = {}
+
+    if args.profile:
+        ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
+        new_logname = profile_fn(args, generate_fn, ort_init_inputs, "prompt")
+
+        # Turn profiling off to stop appending to log file
+        old_logname = model.end_profiling()
+        logger.warning(f"Renaming {old_logname} to {new_logname}")
+        os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+
+        # Re-initialize model for new log file instead of appending to old log file
+        model = get_model(args)
+        ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
+        new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token")
+
+        # Turn profiling off to stop appending to log
+        old_logname = model.end_profiling()
+        logger.warning(f"Renaming {old_logname} to {new_logname}")
+        os.rename(old_logname, os.path.join(args.log_folder, new_logname))
+        return
+
+    # ORT evaluations
+    logger.info("\nEvaluating `model(inputs)` step to get past_key_values")
+    ort_init_inputs, kv_cache_ortvalues = prepare_ort_inputs(init_inputs, kv_cache_ortvalues)
+    time_fn(args, generate_fn, ort_init_inputs)
+    measure_fn(args, generate_fn, ort_init_inputs)
+
+    logger.info("\nEvaluating `model(inputs)` step with past_key_values")
+    ort_iter_inputs, kv_cache_ortvalues = prepare_ort_inputs(iter_inputs, kv_cache_ortvalues)
+    time_fn(args, generate_fn, ort_iter_inputs)
+    measure_fn(args, generate_fn, ort_iter_inputs)
+
+
+def run_inference(args, init_inputs, iter_inputs, model):
+    if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}:
+        run_hf_inference(args, init_inputs, iter_inputs, model)
+    elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
+        run_ort_inference(args, init_inputs, iter_inputs, model)
+    else:
+        raise Exception(f"Cannot recognize {args.benchmark_type}")
+
+
+def get_args(rank=0):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-bt",
+        "--benchmark-type",
+        type=str,
+        required=True,
+        choices=[
+            "hf-pt-eager",
+            "hf-pt-compile",
+            "hf-ort",
+            "ort-msft",
+            "ort-convert-to-onnx",
+        ],
+    )
+    parser.add_argument(
+        "-m",
+        "--model-name",
+        type=str,
+        required=True,
+        help="Hugging Face name of model (e.g. 'meta-llama/Llama-2-7b-hf')",
+    )
+    parser.add_argument(
+        "-a", "--auth", default=False, action="store_true", help="Use Hugging Face authentication token to access model"
+    )
+
+    # Args for choosing the model
+    parser.add_argument(
+        "-p",
+        "--precision",
+        required=True,
+        type=str,
+        default="fp32",
+        choices=["int4", "int8", "fp16", "fp32"],
+        help="Precision for model. For ONNX models, the model's precision should be set before running this script.",
+    )
+    parser.add_argument(
+        "--hf-pt-dir-path",
+        type=str,
+        default="",
+        help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)",
+    )
+    parser.add_argument(
+        "--hf-ort-dir-path",
+        type=str,
+        default="",
+        help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)",
+    )
+    parser.add_argument(
+        "--ort-model-path",
+        type=str,
+        default="",
+        help="Path to ONNX model",
+    )
+
+    # Args for running and evaluating the model
+    parser.add_argument(
+        "-b",
+        "--batch-sizes",
+        default="1 2",
+    )
+    parser.add_argument(
+        "-s",
+        "--sequence-lengths",
+        default="32 64 128 256 512",
+    )
+    parser.add_argument(
+        "-d",
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        choices=["cpu", "cuda", "rocm"],
+    )
+    parser.add_argument("-id", "--device-id", type=int, default=0)
+    parser.add_argument("-w", "--warmup-runs", type=int, default=5)
+    parser.add_argument("-n", "--num-runs", type=int, default=10)
+    parser.add_argument("--seed", type=int, default=2)
+
+    # Args for decoding logic
+    parser.add_argument("--max-length", type=int, default=32)
+    parser.add_argument("--num-return-sequences", type=int, default=1)
+
+    # Args for accessing detailed info
+    parser.add_argument("--profile", default=False, action="store_true")
+    parser.add_argument(
+        "--pt-filter-by", type=str, default="self_cpu_time_total", help="What to filter PyTorch profiler by"
+    )
+    parser.add_argument("--pt-num-rows", type=int, default=1000, help="Number of rows for PyTorch profiler to display")
+    parser.add_argument("--verbose", default=False, action="store_true")
+    parser.add_argument("--log-folder", type=str, default=os.path.join("."), help="Folder to cache log files")
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        required=True,
+        default="./model_cache",
+        help="Cache dir where Hugging Face files are stored",
+    )
+
+    args = parser.parse_args()
+
+    # Set seed properties
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+
+    # Set runtime properties
+    if "ort" in args.benchmark_type:
+        setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider")  # noqa: B010
+        if args.execution_provider == "CUDAExecutionProvider":
+            args.execution_provider = (args.execution_provider, {"device_id": rank})
+        elif args.execution_provider == "ROCMExecutionProvider":
+            args.execution_provider = (args.execution_provider, {"device_id": rank})
+            args.device = "cuda"
+
+    # Check that paths have been specified for any benchmarking with ORT
+    if args.benchmark_type == "hf-ort":
+        assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`"
+    if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}:
+        assert args.ort_model_path, "Please specify a path to `--ort-model-path`"
+
+    args.batch_sizes = args.batch_sizes.split(" ")
+    args.sequence_lengths = args.sequence_lengths.split(" ")
+
+    # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+    args.precision = (
+        "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
+    )
+
+    # Check that only one (batch_size, sequence_length) combination is set for profiling
+    if args.profile:
+        assert (
+            len(args.batch_sizes) == 1 and len(args.sequence_lengths) == 1
+        ), "Please provide only one (batch_size, sequence_length) combination for profiling"
+
+    return args
+
+
+def main():
+    rank = get_rank()
+    world_size = get_size()
+
+    args = get_args(rank)
+    setup_logger(args.verbose)
+    logger.info(args.__dict__)
+    torch.backends.cudnn.benchmark = True
+
+    args.rank = rank
+    args.world_size = world_size
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
+    )
+    config = AutoConfig.from_pretrained(
+        args.model_name, cache_dir=args.cache_dir, use_auth_token=args.auth, trust_remote_code=args.auth
+    )
+    target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device
+    use_fp16 = args.precision == "fp16"
+
+    setattr(args, "tokenizer", tokenizer)  # noqa: B010
+    setattr(args, "config", config)  # noqa: B010
+    setattr(args, "target_device", target_device)  # noqa: B010
+    setattr(args, "use_fp16", use_fp16)  # noqa: B010
+
+    # Get model and model info
+    model = get_model(args)
+    ort_model_inputs_len = get_ort_model_inputs_len(args, model)
+
+    # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA)
+    if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}:
+        onnx_model = onnx.load_model(args.ort_model_path.format(args.rank), load_external_data=False)
+        gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node))
+
+        use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu"
+        setattr(args, "use_buffer_share", use_buffer_share)  # noqa: B010
+    else:
+        setattr(args, "use_buffer_share", False)  # noqa: B010
+
+    # Measure prompt cost (init_inputs) and generated token cost (iter_inputs)
+    for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths):
+        if args.rank == 0:
+            logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...")
+        setattr(args, "batch_size", int(batch_size))  # noqa: B010
+        setattr(args, "sequence_length", int(sequence_length))  # noqa: B010
+
+        init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len)
+        run_inference(args, init_inputs, iter_inputs, model)
+
+
+if __name__ == "__main__":
+    main()
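
The file shown above is the LLaMA-2 benchmarking script shipped under onnxruntime/transformers/models/llama/. Based on its argparse definitions, a plausible invocation looks like the following; the model name comes from the script's own help text, while the ONNX model path and cache directory are illustrative assumptions:

    python benchmark.py \
        --benchmark-type ort-convert-to-onnx \
        --model-name meta-llama/Llama-2-7b-hf \
        --precision fp16 \
        --ort-model-path ./llama2-7b-fp16/model.onnx \
        --cache-dir ./model_cache \
        --device cuda \
        --auth

Per the script, --cache-dir is always required, --ort-model-path is asserted for the ort-msft and ort-convert-to-onnx benchmark types, and --auth is only needed when the Hugging Face model is gated.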