onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,821 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ #
+ # This script benchmarks the latency or peak memory usage of Longformer model inference.
+ # Please run convert_to_onnx.py to get an ONNX model before running the benchmark.
+ #
+ # It is tested with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.11.0, transformers 4.18.0 and CUDA 11.3, for example:
+ #   conda create -n gpu_env python=3.8
+ #   conda activate gpu_env
+ #   pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
+ #   pip3 install onnx transformers onnxruntime-gpu numpy sympy psutil py3nvml
+ #   python benchmark_longformer.py
+ #
+ # When no arguments are given, pre-defined tests run on the longformer-base-4096 model.
+ 
+ # Benchmark the latency:
+ #   python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 512 1024 2048 4096 \
+ #       --global_lengths 8 --onnx ./longformer-base-4096_fp16.onnx -t 100
+ #
+ # Benchmark GPU peak memory:
+ #   export ORT_LONGFORMER_COMPACT_MEMORY=0
+ #   python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
+ #       --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
+ #   export ORT_LONGFORMER_COMPACT_MEMORY=1
+ #   python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 \
+ #       --global_lengths 8 --onnx ./longformer-base-4096_fp32.onnx --memory -t 10 --engine onnxruntime
+ #
+ # By default, the compact memory kernel is enabled. To disable it, set the environment variable ORT_LONGFORMER_COMPACT_MEMORY=0.
+ 
+ import argparse
+ import csv
+ import logging
+ import math
+ import os
+ import re
+ import sys
+ import timeit
+ import traceback
+ from concurrent.futures import ProcessPoolExecutor
+ from datetime import datetime
+ from typing import Any
+ 
+ import benchmark_helper
+ import numpy as np
+ import torch
+ from longformer_helper import PRETRAINED_LONGFORMER_MODELS, LongformerHelper, LongformerInputs
+ from transformers import LongformerModel
+ 
+ import onnxruntime
+ 
+ logger = logging.getLogger("")
+ 
+ 
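+ # Measure PyTorch eager-mode inference latency for every combination of batch size,
+ # sequence length and global length, returning one result dict per combination.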
+ def test_torch_latency(
+     device,
+     model,
+     model_name,
+     batch_sizes,
+     sequence_lengths,
+     global_lengths,
+     test_times,
+     num_threads,
+ ) -> list[dict[str, Any]]:
+     if num_threads > 0:
+         torch.set_num_threads(num_threads)
+ 
+     results = []
+     for batch_size in batch_sizes:
+         for sequence_length in sequence_lengths:
+             for global_length in global_lengths:
+                 logger.info(f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}")
+                 inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
+                     batch_size, sequence_length, global_length, device
+                 )
+                 input_list = inputs.to_list()
+ 
+                 _ = model(*input_list)
+                 runtimes = timeit.repeat(lambda: model(*input_list), repeat=test_times, number=1)  # noqa: B023
+                 result = {
+                     "engine": "torch",  # TODO: test torchscript
+                     "version": torch.__version__,
+                     "device": "cuda",
+                     "optimizer": "",
+                     "precision": "fp32",
+                     "io_binding": "",
+                     "model_name": model_name,
+                     "description": model_name + " [torch]",
+                     "inputs": 3,
+                     "threads": num_threads,
+                     "batch_size": batch_size,
+                     "sequence_length": sequence_length,
+                     "global_length": global_length,
+                     "datetime": str(datetime.now()),
+                     "memory": "NA",
+                     "diff_max": 0,
+                     "diff_90_percentile": 0,
+                     "diff_95_percentile": 0,
+                     "diff_99_percentile": 0,
+                     "use_compact_memory": "NA",
+                 }
+                 result.update(benchmark_helper.get_latency_result(runtimes, batch_size))
+                 logger.info("%s", result)
+                 results.append(result)
+     return results
+ 
+ 
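+ # Run the same dummy inputs through both PyTorch and ONNX Runtime, and return the
+ # maximum absolute difference of the last_state output.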
+ def test_parity(device, model, ort_session, batch_size, sequence_length, global_length, verbose=True):
+     parameters = f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}"
+     logger.info(f"Comparing Torch and ORT outputs for {parameters}...")
+     dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
+         batch_size, sequence_length, global_length, device
+     )
+     ort_inputs = dummy_inputs.get_ort_inputs()
+     ort_outputs = ort_session.run(None, ort_inputs)
+     input_list = dummy_inputs.to_list()
+     torch_outputs = model(*input_list)
+     max_diff = np.amax(torch_outputs[0].cpu().numpy() - ort_outputs[0])
+     logger.info(f"last_state max diff = {max_diff}")
+     if verbose and (math.isnan(max_diff) or max_diff > 0.001):
+         print("torch last_state:", torch_outputs[0])
+         print("ort last_state:", ort_outputs[0])
+     return float(max_diff)
+ 
+ 
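+ # Benchmark ONNX Runtime latency (optionally with IO binding) over the test grid and,
+ # unless disabled, collect parity statistics against the PyTorch model.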
+ def test_ort_latency(
+     device,
+     model,
+     model_name,
+     description,
+     ort_session,
+     batch_sizes,
+     sequence_lengths,
+     global_lengths,
+     test_times,
+     num_threads,
+     optimizer=False,
+     precision="fp32",
+     disable_io_binding=False,
+     verbose=True,
+     use_compact_memory=False,
+     use_half4=False,
+     disable_parity=False,
+ ) -> list[dict[str, Any]]:
+     results = []
+     for batch_size in batch_sizes:
+         for sequence_length in sequence_lengths:
+             for global_length in global_lengths:
+                 assert global_length <= model.config.attention_window[0], (
+                     "Limitation of current implementation: number of global tokens <= attention_window"
+                 )
+ 
+                 logger.info(
+                     f"Testing batch_size={batch_size} sequence_length={sequence_length} global_length={global_length} "
+                     f"optimizer={optimizer}, precision={precision} io_binding={not disable_io_binding}..."
+                 )
+                 dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
+                     batch_size, sequence_length, global_length, device
+                 )
+ 
+                 # Run OnnxRuntime
+                 ort_inputs = dummy_inputs.get_ort_inputs()
+ 
+                 if verbose:
+                     print(ort_inputs)
+ 
+                 # run one query for warm up
+                 ort_outputs = ort_session.run(None, ort_inputs)
+ 
+                 result_template = {
+                     "model_name": model_name,
+                     "description": description,
+                     "inputs": 3,
+                     "engine": "OnnxRuntime",
+                     "version": str(onnxruntime.__version__),
+                     "device": "cuda",
+                     "precision": str(precision),
+                     "optimizer": int(optimizer),
+                     "threads": int(num_threads),
+                     "batch_size": int(batch_size),
+                     "sequence_length": int(sequence_length),
+                     "global_length": int(global_length),
+                     "test_times": int(test_times),
+                     "datetime": str(datetime.now()),
+                     "memory": "",
+                     "diff_max": None,
+                     "diff_90_percentile": None,
+                     "diff_95_percentile": None,
+                     "diff_99_percentile": None,
+                     "use_compact_memory": use_compact_memory,
+                     "use_half4": use_half4,
+                 }
+ 
+                 if not disable_io_binding:
+                     max_last_state_size = max(batch_sizes) * max(sequence_lengths) * model.config.hidden_size
+                     max_pooler_size = max(batch_sizes) * max(sequence_lengths)
+                     result = benchmark_helper.inference_ort_with_io_binding(
+                         ort_session,
+                         ort_inputs,
+                         result_template=result_template,
+                         repeat_times=test_times,
+                         ort_output_names=["last_state", "pooler"],
+                         ort_outputs=ort_outputs,
+                         output_buffers=[],
+                         output_buffer_max_sizes=[max_last_state_size, max_pooler_size],
+                         batch_size=batch_size,
+                         device=device,
+                         data_type=np.longlong,  # input data type
+                     )
+                 else:
+                     result = benchmark_helper.inference_ort(
+                         ort_session,
+                         ort_inputs,
+                         result_template=result_template,
+                         repeat_times=test_times,
+                         batch_size=batch_size,
+                     )
+ 
+                 # measure result difference between PyTorch and OnnxRuntime
+                 if not disable_parity:
+                     diff_results = [
+                         test_parity(
+                             device,
+                             model,
+                             ort_session,
+                             batch_size,
+                             sequence_length,
+                             global_length,
+                             verbose,
+                         )
+                         for _ in range(test_times)
+                     ]
+ 
+                     result["diff_max"] = max(diff_results)
+                     result["diff_90_percentile"] = np.percentile(diff_results, 90)
+                     result["diff_95_percentile"] = np.percentile(diff_results, 95)
+                     result["diff_99_percentile"] = np.percentile(diff_results, 99)
+ 
+                 results.append(result)
+     return results
+ 
+ 
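+ # Measure GPU peak memory for a single configuration by running inference in a fresh session.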
+ def test_ort_memory(
+     device,
+     onnx_model_path,
+     batch_size,
+     sequence_length,
+     global_length,
+     test_times,
+     num_threads,
+ ) -> dict[str, Any]:
+     logger.info(
+         f"Testing memory for model={onnx_model_path}, batch_size={batch_size}, sequence_length={sequence_length}, "
+         f"global_length={global_length}, test_times={test_times}, num_threads={num_threads}"
+     )
+ 
+     def inference():
+         # Update Arena strategy so that we can measure the minimum memory required
+         cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
+         provider_options = {"CUDAExecutionProvider": cuda_provider_options}
+         session = benchmark_helper.create_onnxruntime_session(
+             onnx_model_path,
+             use_gpu=True,
+             enable_all_optimization=True,
+             num_threads=num_threads,
+             provider_options=provider_options,
+         )
+ 
+         dummy_inputs: LongformerInputs = LongformerHelper.get_dummy_inputs(
+             batch_size, sequence_length, global_length, device
+         )
+         ort_inputs = dummy_inputs.get_ort_inputs()
+         for _ in range(test_times):
+             _ = session.run(None, ort_inputs)
+ 
+     memory_used = benchmark_helper.measure_memory(is_gpu=True, func=inference)
+ 
+     return {
+         "onnx_model": onnx_model_path,
+         "batch_size": batch_size,
+         "sequence_length": sequence_length,
+         "global_length": global_length,
+         "test_times": test_times,
+         "num_threads": num_threads,
+         "memory": memory_used,
+     }
+ 
+ 
+ def load_torch_model(model_name, device):
+     torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS.get(model_name, model_name)
+     model = LongformerModel.from_pretrained(torch_model_name_or_dir)
+     model.to(device)
+     return model
+ 
+ 
+ def find_onnx_model(model_name, onnx_dir="."):
+     # Search onnx model in the following order: optimized fp16 model, optimized fp32 model, raw model
+     onnx_model_path = os.path.join(onnx_dir, model_name + ".onnx")
+     optimized_fp32_model = os.path.join(onnx_dir, model_name + "_fp32.onnx")
+     optimized_fp16_model = os.path.join(onnx_dir, model_name + "_fp16.onnx")
+     if os.path.isfile(optimized_fp16_model):
+         onnx_model_path = optimized_fp16_model
+     elif os.path.isfile(optimized_fp32_model):
+         onnx_model_path = optimized_fp32_model
+     return onnx_model_path
+ 
+ 
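+ # Validate that exactly one configuration was requested, then measure its GPU peak memory.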
+ def test_memory(args, device) -> dict[str, Any]:
+     if len(args.batch_sizes) > 1:
+         raise RuntimeError("For memory test, only one batch_size (-b) is allowed.")
+     if len(args.sequence_lengths) > 1:
+         raise RuntimeError("For memory test, only one sequence_length (-s) is allowed.")
+     if len(args.global_lengths) > 1:
+         raise RuntimeError("For memory test, only one global_length (-g) is allowed.")
+ 
+     model_name = args.model
+     onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
+ 
+     torch.cuda.empty_cache()
+     return test_ort_memory(
+         device,
+         onnx_model_path,
+         args.batch_sizes[0],
+         args.sequence_lengths[0],
+         args.global_lengths[0],
+         args.test_times,
+         args.num_threads,
+     )
+ 
+ 
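+ # Create an ONNX Runtime CUDA session for the model and run the latency benchmark,
+ # deriving precision and a human-readable description from the ONNX file name.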
+ def test_ort(args, device) -> list[dict[str, Any]]:
+     model_name = args.model
+ 
+     onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx
+ 
+     optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx")  # noqa: PIE810
+     precision = "fp32" if not onnx_model_path.endswith("_fp16.onnx") else "fp16"
+ 
+     model = load_torch_model(model_name, device)
+ 
+     num_threads = args.num_threads
+ 
+     cuda_provider_options = {"arena_extend_strategy": "kSameAsRequested"}
+     provider_options = {"CUDAExecutionProvider": cuda_provider_options}
+     session = benchmark_helper.create_onnxruntime_session(
+         onnx_model_path,
+         use_gpu=True,
+         enable_all_optimization=True,
+         num_threads=num_threads,
+         provider_options=provider_options,
+     )
+     if session is None:
+         raise RuntimeError(f"Failed to create ORT session from ONNX file {onnx_model_path}")
+ 
+     use_compact_memory = os.environ.get("ORT_LONGFORMER_COMPACT_MEMORY", "1") == "1"
+     description = onnx_model_path
+     if not use_compact_memory:
+         description += "[non_compact_memory]"
+ 
+     if args.use_half4:
+         description += "[half4]" if precision == "fp16" else "[float4]"
+     else:
+         description += "[half2]" if precision == "fp16" else "[float4]"
+ 
+     return test_ort_latency(
+         device,
+         model,
+         model_name,
+         description,
+         session,
+         args.batch_sizes,
+         args.sequence_lengths,
+         args.global_lengths,
+         args.test_times,
+         num_threads,
+         optimized,
+         precision,
+         args.disable_io_binding,
+         args.verbose,
+         use_compact_memory,
+         args.use_half4,
+         args.disable_parity,
+     )
+ 
+ 
+ def test_torch(args, device) -> list[dict[str, Any]]:
+     model = load_torch_model(args.model, device)
+     return test_torch_latency(
+         device,
+         model,
+         args.model,
+         args.batch_sizes,
+         args.sequence_lengths,
+         args.global_lengths,
+         args.test_times,
+         args.num_threads,
+     )
+ 
+ 
+ def test_latency(args, device) -> list[dict[str, Any]]:
+     if args.engine == "onnxruntime":
+         return test_ort(args, device)
+ 
+     return test_torch(args, device)
+ 
+ 
+ def parse_arguments(argv=None):
+     parser = argparse.ArgumentParser()
+ 
+     parser.add_argument(
+         "-m",
+         "--model",
+         required=False,
+         type=str,
+         default="longformer-base-4096",
+         help="Checkpoint directory or pre-trained model names in the list: "
+         + ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()),
+     )
+ 
+     parser.add_argument(
+         "-e",
+         "--engine",
+         required=False,
+         type=str,
+         default="onnxruntime",
+         choices=["onnxruntime", "torch"],
+         help="Engine to benchmark.",
+     )
+ 
+     parser.add_argument(
+         "-t",
+         "--test_times",
+         required=False,
+         default=1000,
+         type=int,
+         help="Number of repeat times to get average inference latency.",
+     )
+ 
+     parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1])
+ 
+     # If --export_padding is not used when exporting the onnx model, there is no padding in the ONNX model,
+     # and you will need to pad inputs yourself before running the onnx model.
+     # Here, we only test sequence lengths that are multiples of the attention window size.
+     parser.add_argument(
+         "-s",
+         "--sequence_lengths",
+         nargs="+",
+         type=int,
+         default=[512, 1024, 2048, 4096],
+         help="Sequence lengths. It could have multiple values in latency test. "
+         "If --export_padding is not used, sequence length shall be a multiple of the window size.",
+     )
+ 
+     parser.add_argument("--onnx", required=False, type=str, default=None, help="Onnx model path")
+ 
+     parser.add_argument(
+         "-g",
+         "--global_lengths",
+         nargs="+",
+         type=int,
+         default=[0],
+         help="Number of global tokens. It could have multiple values in latency test.",
+     )
+ 
+     parser.add_argument(
+         "-n",
+         "--num_threads",
+         required=False,
+         type=int,
+         default=0,
+         help="Threads to use.",
+     )
+ 
+     parser.add_argument(
+         "--disable_io_binding",
+         required=False,
+         action="store_true",
+         help="Do not use IO Binding.",
+     )
+ 
+     parser.add_argument(
+         "--memory",
+         required=False,
+         action="store_true",
+         help="Test memory usage instead of latency.",
+     )
+ 
+     parser.add_argument("--verbose", required=False, action="store_true", help="Print more information.")
+     parser.set_defaults(verbose=False)
+ 
+     parser.add_argument("--use_half4", required=False, action="store_true", help="Use half4 kernel.")
+     parser.set_defaults(use_half4=False)
+ 
+     parser.add_argument("--disable_parity", required=False, action="store_true", help="Do not run parity test.")
+     parser.set_defaults(disable_parity=False)
+ 
+     args = parser.parse_args(argv)
+ 
+     return args
+ 
+ 
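+ # Append per-configuration latency results to a CSV file, one row per result dict.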
+ def output_details(results, csv_filename):
+     latency_results = [result for result in results if "average_latency_ms" in result]
+     if len(latency_results) == 0:
+         print("No latency results for output.")
+         return
+ 
+     with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+         column_names = [
+             "engine",
+             "version",
+             "device",
+             "precision",
+             "optimizer",
+             "io_binding",
+             "model_name",
+             "inputs",
+             "threads",
+             "datetime",
+             "test_times",
+             "description",
+             "batch_size",
+             "sequence_length",
+             "global_length",
+             "use_compact_memory",
+             "use_half4",
+             "diff_max",
+             "diff_90_percentile",
+             "diff_95_percentile",
+             "diff_99_percentile",
+             "memory",
+             "QPS",
+             "average_latency_ms",
+             "latency_variance",
+             "latency_90_percentile",
+             "latency_95_percentile",
+             "latency_99_percentile",
+         ]
+ 
+         csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+         csv_writer.writeheader()
+         for result in latency_results:
+             print(result)
+             csv_writer.writerow(result)
+ 
+         csv_file.flush()
+ 
+     print(f"Detail results are saved to csv file: {csv_filename}")
+ 
+ 
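+ # Run one benchmark configuration: disable autograd, fix the random seed, then
+ # dispatch to either the memory test or the latency test.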
+ def run(args) -> list[dict[str, Any]]:
+     torch.set_grad_enabled(False)
+ 
+     # set random seed manually to get deterministic results
+     benchmark_helper.set_random_seed(123)
+ 
+     # Currently, the longformer attention operator can only run on GPU (no CPU implementation yet).
+     device = torch.device("cuda:0")
+ 
+     if args.memory:
+         return [test_memory(args, device)]  # Convert to a list so that the return type is the same as test_latency
+ 
+     return test_latency(args, device)
+ 
+ 
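+ # Run the benchmark in a spawned subprocess so that GPU resources are released once each test finishes.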
+ def launch_test(arguments) -> list[dict[str, Any]]:
+     if not torch.cuda.is_available():
+         raise RuntimeError("Please install PyTorch with CUDA support, and use a machine with a GPU for testing GPU performance.")
+ 
+     with ProcessPoolExecutor() as executor:
+         results = list(executor.map(run, [arguments]))
+         assert len(results) == 1
+         return results[0]
+ 
+ 
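+ # Run a predefined grid of latency (and optionally memory) tests for longformer-base-4096,
+ # setting the compact-memory and half4 kernel environment variables first.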
+ def run_tests(
+     use_compact_memory=True,
+     run_torch=False,
+     run_memory=True,
+     use_io_binding=True,
+     use_fp16=True,
+     use_merged_qkv_weights=True,
+     use_half4=True,
+     batch_size=1,
+ ):
+     compact_memory = "1" if use_compact_memory else "0"
+     os.environ["ORT_LONGFORMER_COMPACT_MEMORY"] = compact_memory
+     logger.info(f"ORT_LONGFORMER_COMPACT_MEMORY={compact_memory}")
+ 
+     os.environ["ORT_LONGFORMER_USE_HALF4"] = "1" if use_half4 else "0"
+     logger.info("ORT_LONGFORMER_USE_HALF4={}".format("1" if use_half4 else "0"))  # noqa: G001
+ 
+     results = []
+     test_times = 1000
+     sequence_lengths = [4096, 2048, 1024, 512]
+     batch_sizes = [batch_size]
+     for model_name in ["longformer-base-4096"]:
+         for batch_size in batch_sizes:
+             for sequence_length in sequence_lengths:
+                 for global_length in [16]:
+                     if run_torch:
+                         engine_name = "torch"
+                         args = parse_arguments(
+                             f"-e {engine_name} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} "
+                             f"-m {model_name}".split(" ")
+                         )
+                         results += run(args)
+ 
+                     engine_name = "onnxruntime"
+                     file_format = 1 if use_merged_qkv_weights else 0
+                     onnx_path = (
+                         f"{model_name}_f{file_format}_fp16.onnx"
+                         if use_fp16
+                         else f"{model_name}_f{file_format}_fp32.onnx"
+                     )
+                     if not os.path.exists(onnx_path):
+                         raise RuntimeError(f"ONNX file does not exist: {onnx_path}")
+ 
+                     arguments = (
+                         f"-e {engine_name} --onnx {onnx_path} "
+                         f"-b {batch_size} -s {sequence_length} -g {global_length} -m {model_name}"
+                     )
+ 
+                     if not use_io_binding:
+                         arguments += " --disable_io_binding"
+ 
+                     if use_half4:
+                         arguments += " --use_half4"
+ 
+                     # Disable parity test to avoid out of memory for large batch size
+                     if batch_size >= 4:
+                         arguments += " --disable_parity"
+ 
+                     memory_results = None
+                     try:
+                         if run_memory:
+                             args = parse_arguments(f"{arguments} -t 10 --memory".split(" "))
+                             memory_results = launch_test(args)
+ 
+                         args = parse_arguments(f"{arguments} -t {test_times}".split(" "))
+                         latency_results = launch_test(args)
+                     except KeyboardInterrupt as exc:
+                         raise RuntimeError("Keyboard Interrupted") from exc
+                     except Exception:
+                         traceback.print_exc()
+                         continue
+ 
+                     if len(latency_results) == 1:
+                         latency_results[0]["memory"] = memory_results[0]["memory"] if memory_results else "N/A"
+                     else:
+                         raise RuntimeError("length of latency_results should be 1")
+ 
+                     logger.info("%s", latency_results)
+ 
+                     results += latency_results
+     return results
+ 
+ 
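+ # Pivot the detailed results into a summary CSV: one row per description, one column
+ # per (batch_size, sequence_length) pair, averaging the chosen data field.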
+ def output_summary(results, csv_filename, data_field="average_latency_ms"):
+     with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+         header_names = [
+             "model_name",
+             "precision",
+             "engine",
+             "version",
+             "global_length",
+             "use_compact_memory",
+             "use_half4",
+             "description",
+         ]
+ 
+         description_list = list({result["description"] for result in results})
+         description_list.sort()
+ 
+         batch_sizes = list({result["batch_size"] for result in results})
+         batch_sizes.sort()
+ 
+         sequence_lengths = list({result["sequence_length"] for result in results})
+         sequence_lengths.sort()
+ 
+         data_names = []
+         for sequence_length in sequence_lengths:
+             for batch_size in batch_sizes:
+                 data_names.append(f"b{batch_size}_s{sequence_length}")
+ 
+         csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
+         csv_writer.writeheader()
+ 
+         for description in description_list:
+             row = {}
+ 
+             sum_latency = {}
+             sum_latency.update(dict.fromkeys(data_names, 0))
+ 
+             count_latency = {}
+             count_latency.update(dict.fromkeys(data_names, 0))
+ 
+             for result in results:
+                 if result["description"] == description and result[data_field]:
+                     headers = {k: v for k, v in result.items() if k in header_names}
+                     if not row:
+                         row.update(headers)
+                     else:
+                         for k in header_names:
+                             if row[k] != headers[k]:
+                                 raise RuntimeError("Description shall be unique")
+ 
+                     batch_size = result["batch_size"]
+                     sequence_length = result["sequence_length"]
+                     key = f"b{batch_size}_s{sequence_length}"
+ 
+                     try:
+                         latency = float(result[data_field])
+                     except ValueError:
+                         continue
+ 
+                     sum_latency[key] += latency
+                     count_latency[key] += 1
+ 
+             if row:
+                 for key in data_names:
+                     if key in count_latency and count_latency[key] > 0:
+                         row[key] = sum_latency[key] / count_latency[key]
+                     else:
+                         row[key] = ""
+ 
+                 csv_writer.writerow(row)
+ 
+         csv_file.flush()
+ 
+ 
+ def run_experiments(use_fp16, batch_size, is_baseline=False):
+     """Run experiments to compare different algorithms on one batch size."""
+     test_results = run_tests(
+         use_fp16=use_fp16,
+         use_merged_qkv_weights=True,
+         use_half4=False,
+         batch_size=batch_size,
+     )
+ 
+     if is_baseline:
+         return test_results
+ 
+     if use_fp16:
+         test_results += run_tests(
+             use_fp16=use_fp16,
+             use_merged_qkv_weights=True,
+             use_half4=True,
+             batch_size=batch_size,
+         )
+ 
+         test_results += run_tests(
+             use_fp16=use_fp16,
+             use_merged_qkv_weights=False,
+             use_half4=True,
+             batch_size=batch_size,
+         )
+ 
+     test_results += run_tests(
+         use_fp16=use_fp16,
+         use_merged_qkv_weights=False,
+         use_half4=False,
+         batch_size=batch_size,
+     )
+ 
+     return test_results
+ 
+ 
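+ # Entry point: with command-line arguments, run a single test and save a detail CSV;
+ # without arguments, sweep pre-defined fp16 and fp32 experiments and write detail and summary CSV files.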
+ def main():
+     torch.multiprocessing.set_start_method("spawn")
+ 
+     args = parse_arguments()
+ 
+     benchmark_helper.setup_logger(args.verbose)
+ 
+     if len(sys.argv) > 1:
+         test_results = launch_test(args)
+         time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+         csv_filename = f"benchmark_detail_{time_stamp}.csv"
+         output_details(test_results, csv_filename)
+         return
+ 
+     gpu_list = benchmark_helper.get_gpu_info()
+     logger.info("GPU info: %s", gpu_list)
+     fp16_batch_sizes = [16, 8, 4, 2, 1]
+     fp32_batch_sizes = [4, 2, 1]
+     if gpu_list and gpu_list[0]["total"] >= 32 * 1024 * 1024 * 1024:  # 32 GB
+         fp16_batch_sizes = [64, 32, 16, 8, 4, 2, 1]
+         fp32_batch_sizes = [16, 8, 4, 2, 1]
+ 
+     gpu_name = re.sub(r"(?u)[^-\w.]", "_", gpu_list[0]["name"]) if gpu_list else "gpu"
+     is_baseline = os.environ.get("ORT_LONGFORMER_BASELINE", "0") == "1"
+     experiment_name = f"longformer_base_{gpu_name}" + ("_baseline" if is_baseline else "")
+     logger.info(
+         f"experiment_name={experiment_name}, fp16_batch_sizes={fp16_batch_sizes}, fp32_batch_sizes={fp32_batch_sizes}"
+     )
+ 
+     total_runs = 1
+     all_results = []
+     for _ in range(total_runs):
+         for batch_size in fp16_batch_sizes:
+             fp16_results = run_experiments(use_fp16=True, batch_size=batch_size, is_baseline=is_baseline)
+             output_details(fp16_results, "longformer_base_fp16.csv")
+             all_results += fp16_results
+     for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
+         output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
+ 
+     all_results = []
+     for _ in range(total_runs):
+         for batch_size in fp32_batch_sizes:
+             fp32_results = run_experiments(use_fp16=False, batch_size=batch_size, is_baseline=is_baseline)
+             output_details(fp32_results, "longformer_base_fp32.csv")
+             all_results += fp32_results
+     for metric_name in ["average_latency_ms", "QPS", "memory", "diff_90_percentile"]:
+         output_summary(all_results, f"{experiment_name}_{metric_name}.csv", metric_name)
+ 
+ 
+ if __name__ == "__main__":
+     main()