onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/benchmark_helper.py
@@ -0,0 +1,643 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+
+ import csv
+ import logging
+ import os
+ import random
+ import sys
+ import time
+ import timeit
+ from abc import ABC, abstractmethod
+ from concurrent.futures import ThreadPoolExecutor
+ from datetime import datetime
+ from enum import Enum
+ from time import sleep
+ from typing import Any
+
+ import numpy
+ import torch
+ import transformers
+ from packaging import version
+
+ import onnxruntime
+
+ logger = logging.getLogger(__name__)
+
+
+ class Precision(Enum):
+     FLOAT32 = "fp32"
+     FLOAT16 = "fp16"
+     INT8 = "int8"
+     INT4 = "int4"
+
+     def __str__(self):
+         return self.value
+
+
+ class OptimizerInfo(Enum):
+     # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
+     # graph optimization level is not 0 (disable all).
+     NOOPT = "no_opt"
+     BYORT = "by_ort"
+     BYSCRIPT = "by_script"
+
+     def __str__(self):
+         return self.value
+
+
+ class ConfigModifier:
+     def __init__(self, num_layers):
+         self.num_layers = num_layers
+
+     def modify(self, config):
+         if self.num_layers is None:
+             return
+         if hasattr(config, "num_hidden_layers"):
+             config.num_hidden_layers = self.num_layers
+             logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}")
+         if hasattr(config, "encoder_layers"):
+             config.encoder_layers = self.num_layers
+             logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}")
+         if hasattr(config, "decoder_layers"):
+             config.decoder_layers = self.num_layers
+             logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}")
+
+     def get_layer_num(self):
+         return self.num_layers
+
+
+ IO_BINDING_DATA_TYPE_MAP = {
+     "float32": numpy.float32,
+     # TODO: Add more.
+ }
+
+ def create_onnxruntime_session(
+     onnx_model_path,
+     use_gpu,
+     provider=None,
+     enable_all_optimization=True,
+     num_threads=-1,
+     enable_profiling=False,
+     verbose=False,
+     enable_mlas_gemm_fastmath_arm64_bfloat16=False,
+     provider_options={},  # map execution provider name to its option  # noqa: B006
+ ):
+     sess_options = onnxruntime.SessionOptions()
+
+     if enable_all_optimization:
+         sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+     else:
+         sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
+
+     if enable_profiling:
+         sess_options.enable_profiling = True
+
+     if num_threads > 0:
+         sess_options.intra_op_num_threads = num_threads
+         logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")
+
+     if verbose:
+         sess_options.log_severity_level = 0
+     else:
+         sess_options.log_severity_level = 4
+
+     if provider in onnxruntime.get_available_providers():
+         providers = [provider]
+     elif use_gpu:
+         if provider == "dml":
+             providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
+         elif provider == "migraphx":
+             providers = [
+                 "MIGraphXExecutionProvider",
+                 "CPUExecutionProvider",
+             ]
+         elif provider == "cuda" or provider is None:
+             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+         elif provider == "tensorrt":
+             providers = [
+                 "TensorrtExecutionProvider",
+                 "CUDAExecutionProvider",
+                 "CPUExecutionProvider",
+             ]
+         else:
+             raise RuntimeError(f"The execution provider is not supported: {provider}")
+     else:
+         providers = ["CPUExecutionProvider"]
+
+     if provider_options:
+         providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]
+
+     if enable_mlas_gemm_fastmath_arm64_bfloat16:
+         sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")
+
+     session = None
+     try:
+         session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
+     except Exception:
+         logger.exception(f"Failed to create session for {onnx_model_path} with providers={providers}")
+
+     return session
+
+
+ def setup_logger(verbose=True):
+     if verbose:
+         logging.basicConfig(
+             format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
+             level=logging.DEBUG,
+         )
+     else:
+         logging.basicConfig(format="%(message)s", level=logging.INFO)
+         logging.getLogger("transformers").setLevel(logging.WARNING)
+
+
+ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
+     if cache_dir and not os.path.exists(cache_dir):
+         os.makedirs(cache_dir)
+
+     if output_dir and not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     if use_gpu:
+         if provider == "dml":
+             assert "DmlExecutionProvider" in onnxruntime.get_available_providers(), (
+                 "Please install onnxruntime-directml package to test GPU inference."
+             )
+
+         else:
+             assert not set(onnxruntime.get_available_providers()).isdisjoint(
+                 ["CUDAExecutionProvider", "MIGraphXExecutionProvider"]
+             ), "Please install onnxruntime-gpu package, or install migraphx, to test GPU inference."
+
+     logger.info(f"PyTorch Version:{torch.__version__}")
+     logger.info(f"Transformers Version:{transformers.__version__}")
+     logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")
+
+     # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
+     assert version.parse(torch.__version__) >= version.parse("1.10.0")
+     assert version.parse(transformers.__version__) >= version.parse("4.12.0")
+     assert version.parse(onnxruntime.__version__) >= version.parse("1.10.0")
+
+
+ def get_latency_result(latency_list, batch_size):
+     latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
+     latency_variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
+     throughput = batch_size * (1000.0 / latency_ms)
+
+     return {
+         "test_times": len(latency_list),
+         "latency_variance": f"{latency_variance:.2f}",
+         "latency_90_percentile": f"{numpy.percentile(latency_list, 90) * 1000.0:.2f}",
+         "latency_95_percentile": f"{numpy.percentile(latency_list, 95) * 1000.0:.2f}",
+         "latency_99_percentile": f"{numpy.percentile(latency_list, 99) * 1000.0:.2f}",
+         "average_latency_ms": f"{latency_ms:.2f}",
+         "QPS": f"{throughput:.2f}",
+     }
+
+
+ def output_details(results, csv_filename):
+     with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+         column_names = [
+             "engine",
+             "version",
+             "providers",
+             "device",
+             "precision",
+             "optimizer",
+             "io_binding",
+             "model_name",
+             "inputs",
+             "threads",
+             "batch_size",
+             "sequence_length",
+             "custom_layer_num",
+             "datetime",
+             "test_times",
+             "QPS",
+             "average_latency_ms",
+             "latency_variance",
+             "latency_90_percentile",
+             "latency_95_percentile",
+             "latency_99_percentile",
+         ]
+
+         csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+         csv_writer.writeheader()
+         for result in results:
+             csv_writer.writerow(result)
+
+     logger.info(f"Detail results are saved to csv file: {csv_filename}")
+
+
+ def output_summary(results, csv_filename, args):
+     with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+         header_names = [
+             "model_name",
+             "inputs",
+             "custom_layer_num",
+             "engine",
+             "version",
+             "providers",
+             "device",
+             "precision",
+             "optimizer",
+             "io_binding",
+             "threads",
+         ]
+         data_names = []
+         for batch_size in args.batch_sizes:
+             if args.sequence_lengths == [""]:
+                 data_names.append(f"b{batch_size}")
+             else:
+                 for sequence_length in args.sequence_lengths:
+                     data_names.append(f"b{batch_size}_s{sequence_length}")
+
+         csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
+         csv_writer.writeheader()
+         for model_name in args.models:
+             for input_count in [1, 2, 3]:
+                 for engine_name in args.engines:
+                     for io_binding in [True, False, ""]:
+                         for threads in args.num_threads:
+                             row = {}
+                             for result in results:
+                                 if (
+                                     result["model_name"] == model_name
+                                     and result["inputs"] == input_count
+                                     and result["engine"] == engine_name
+                                     and result["io_binding"] == io_binding
+                                     and result["threads"] == threads
+                                 ):
+                                     headers = {k: v for k, v in result.items() if k in header_names}
+                                     if not row:
+                                         row.update(headers)
+                                         row.update(dict.fromkeys(data_names, ""))
+                                     else:
+                                         for k in header_names:
+                                             assert row[k] == headers[k]
+                                     b = result["batch_size"]
+                                     s = result["sequence_length"]
+                                     if s:
+                                         row[f"b{b}_s{s}"] = result["average_latency_ms"]
+                                     else:
+                                         row[f"b{b}"] = result["average_latency_ms"]
+                             if row:
+                                 csv_writer.writerow(row)
+
+     logger.info(f"Summary results are saved to csv file: {csv_filename}")
+
+
+ def output_fusion_statistics(model_fusion_statistics, csv_filename):
+     with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+         column_names = [
+             "model_filename",
+             "datetime",
+             "transformers",
+             "torch",
+             *list(next(iter(model_fusion_statistics.values())).keys()),
+         ]
+         csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+         csv_writer.writeheader()
+         for key in model_fusion_statistics:
+             model_fusion_statistics[key]["datetime"] = str(datetime.now())
+             model_fusion_statistics[key]["transformers"] = transformers.__version__
+             model_fusion_statistics[key]["torch"] = torch.__version__
+             model_fusion_statistics[key]["model_filename"] = key
+             csv_writer.writerow(model_fusion_statistics[key])
+     logger.info(f"Fusion statistics are saved to csv file: {csv_filename}")
+
+
+ def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
+     result = {}
+     timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat)  # Dry run
+     latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
+     result.update(result_template)
+     result.update({"io_binding": False})
+     result.update(get_latency_result(latency_list, batch_size))
+     return result
+
+
+ def inference_ort_with_io_binding(
+     ort_session,
+     ort_inputs,
+     result_template,
+     repeat_times,
+     ort_output_names,
+     ort_outputs,
+     output_buffers,
+     output_buffer_max_sizes,
+     batch_size,
+     device,
+     data_type=numpy.longlong,
+     warm_up_repeat=0,
+ ):
+     result = {}
+
+     # Bind inputs and outputs to onnxruntime session
+     io_binding = ort_session.io_binding()
+     # Bind inputs to device
+     for name in ort_inputs:
+         np_input = torch.from_numpy(ort_inputs[name]).to(device)
+         input_type = IO_BINDING_DATA_TYPE_MAP.get(str(ort_inputs[name].dtype), data_type)
+         io_binding.bind_input(
+             name,
+             np_input.device.type,
+             0,
+             input_type,
+             np_input.shape,
+             np_input.data_ptr(),
+         )
+     # Bind outputs buffers with the sizes needed if not allocated already
+     if len(output_buffers) == 0:
+         allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)
+
+     for i, ort_output_name in enumerate(ort_output_names):
+         io_binding.bind_output(
+             ort_output_name,
+             output_buffers[i].device.type,
+             0,
+             numpy.float32,
+             ort_outputs[i].shape,
+             output_buffers[i].data_ptr(),
+         )
+
+     timeit.repeat(
+         lambda: ort_session.run_with_iobinding(io_binding),
+         number=1,
+         repeat=warm_up_repeat,
+     )  # Dry run
+
+     latency_list = timeit.repeat(
+         lambda: ort_session.run_with_iobinding(io_binding),
+         number=1,
+         repeat=repeat_times,
+     )
+     result.update(result_template)
+     result.update({"io_binding": True})
+     result.update(get_latency_result(latency_list, batch_size))
+     return result
+
+
+ def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device):  # noqa: N802
+     # Allocate output tensors with the largest test size needed. So the allocated memory can be reused
+     # for each test run.
+
+     for i in output_buffer_max_sizes:
+         output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))
+
+
+ def set_random_seed(seed=123):
+     """Set random seed manually to get deterministic results"""
+     random.seed(seed)
+     numpy.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     # torch.backends.cudnn.enabled = False
+     # torch.backends.cudnn.benchmark = False
+     # torch.backends.cudnn.deterministic = True
+
+
+ def get_gpu_info() -> list[dict[str, Any]] | None:
+     from py3nvml.py3nvml import (  # noqa: PLC0415
+         NVMLError,
+         nvmlDeviceGetCount,
+         nvmlDeviceGetHandleByIndex,
+         nvmlDeviceGetMemoryInfo,
+         nvmlDeviceGetName,
+         nvmlInit,
+         nvmlShutdown,
+     )
+
+     try:
+         nvmlInit()
+         result = []
+         device_count = nvmlDeviceGetCount()
+         if not isinstance(device_count, int):
+             return None
+
+         for i in range(device_count):
+             info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
+             if isinstance(info, str):
+                 return None
+             result.append(
+                 {
+                     "id": i,
+                     "name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
+                     "total": info.total,
+                     "free": info.free,
+                     "used": info.used,
+                 }
+             )
+         nvmlShutdown()
+         return result
+     except NVMLError as error:
+         print(f"Error fetching GPU information using nvml: {error}")
+         return None
+
+
+ class MemoryMonitor(ABC):
+     def __init__(self, keep_measuring=True):
+         self.keep_measuring = keep_measuring
+
+     def measure_cpu_usage(self):
+         import psutil  # noqa: PLC0415
+
+         max_usage = 0
+         while True:
+             max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
+             sleep(0.005)  # 5ms
+             if not self.keep_measuring:
+                 break
+         return max_usage
+
+     @abstractmethod
+     def measure_gpu_usage(self) -> list[dict[str, Any]] | None:
+         raise NotImplementedError()
+
+
+ class CudaMemoryMonitor(MemoryMonitor):
+     def __init__(self, keep_measuring=True):
+         super().__init__(keep_measuring)
+
+     def measure_gpu_usage(self) -> list[dict[str, Any]] | None:
+         from py3nvml.py3nvml import (  # noqa: PLC0415
+             NVMLError,
+             nvmlDeviceGetCount,
+             nvmlDeviceGetHandleByIndex,
+             nvmlDeviceGetMemoryInfo,
+             nvmlDeviceGetName,
+             nvmlInit,
+             nvmlShutdown,
+         )
+
+         max_gpu_usage = []
+         gpu_name = []
+         try:
+             nvmlInit()
+             device_count = nvmlDeviceGetCount()
+             if not isinstance(device_count, int):
+                 logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
+                 return None
+
+             max_gpu_usage = [0 for i in range(device_count)]
+             gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
+             while True:
+                 for i in range(device_count):
+                     info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
+                     if isinstance(info, str):
+                         logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
+                         return None
+                     max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
+                 sleep(0.005)  # 5ms
+                 if not self.keep_measuring:
+                     break
+             nvmlShutdown()
+             return [
+                 {
+                     "device_id": i,
+                     "name": gpu_name[i],
+                     "max_used_MB": max_gpu_usage[i],
+                 }
+                 for i in range(device_count)
+             ]
+         except NVMLError as error:
+             logger.error("Error fetching GPU information using nvml: %s", error)
+             return None
+
+
+ class RocmMemoryMonitor(MemoryMonitor):
+     def __init__(self, keep_measuring=True):
+         super().__init__(keep_measuring)
+         rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
+         if os.path.exists(rocm_smi_path):
+             if rocm_smi_path not in sys.path:
+                 sys.path.append(rocm_smi_path)
+         try:
+             import rocm_smi  # noqa: PLC0415
+
+             self.rocm_smi = rocm_smi
+             self.rocm_smi.initializeRsmi()
+         except ImportError:
+             self.rocm_smi = None
+
+     def get_used_memory(self, dev):
+         if self.rocm_smi is None:
+             return -1
+         return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024
+
+     def measure_gpu_usage(self):
+         if self.rocm_smi is None:
+             return None
+
+         device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0
+         max_gpu_usage = [0 for i in range(device_count)]
+         gpu_name = [f"GPU{i}" for i in range(device_count)]
+         while True:
+             for i in range(device_count):
+                 max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
+             time.sleep(0.005)  # 5ms
+             if not self.keep_measuring:
+                 break
+         return [
+             {
+                 "device_id": i,
+                 "name": gpu_name[i],
+                 "max_used_MB": max_gpu_usage[i],
+             }
+             for i in range(device_count)
+         ]
+
+
+ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
+     memory_monitor_type = None
+     if monitor_type == "rocm":
+         memory_monitor_type = RocmMemoryMonitor
+     else:
+         memory_monitor_type = CudaMemoryMonitor
+
+     monitor = memory_monitor_type(False)
+
+     if is_gpu:
+         if start_memory is not None:
+             memory_before_test = start_memory
+         else:
+             memory_before_test = monitor.measure_gpu_usage()
+         if memory_before_test is None:
+             return None
+
+         if func is None:
+             return memory_before_test
+
+         with ThreadPoolExecutor() as executor:
+             monitor = memory_monitor_type()
+             mem_thread = executor.submit(monitor.measure_gpu_usage)
+             try:
+                 fn_thread = executor.submit(func)
+                 _ = fn_thread.result()
+             finally:
+                 monitor.keep_measuring = False
+                 max_usage = mem_thread.result()
+
+         if max_usage is None:
+             return None
+
+         logger.info(f"GPU memory usage: before={memory_before_test} peak={max_usage}")
+         if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
+             # When there are multiple GPUs, we will check the one with maximum usage.
+             max_used = 0
+             for i, memory_before in enumerate(memory_before_test):
+                 before = memory_before["max_used_MB"]
+                 after = max_usage[i]["max_used_MB"]
+                 used = after - before
+                 max_used = max(max_used, used)
+             return max_used
+         return None
+
+     # CPU memory
+     if start_memory is not None:
+         memory_before_test = start_memory
+     else:
+         memory_before_test = monitor.measure_cpu_usage()
+
+     if func is None:
+         return memory_before_test
+
+     with ThreadPoolExecutor() as executor:
+         monitor = memory_monitor_type()
+         mem_thread = executor.submit(monitor.measure_cpu_usage)
+         try:
+             fn_thread = executor.submit(func)
+             _ = fn_thread.result()
+         finally:
+             monitor.keep_measuring = False
+             max_usage = mem_thread.result()
+
+     logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
+     return max_usage - memory_before_test
+
+
+ def get_ort_environment_variables():
+     # Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
+     env_names = [
+         "ORT_DISABLE_FUSED_ATTENTION",
+         "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
+         "ORT_DISABLE_FUSED_CROSS_ATTENTION",
+         "ORT_DISABLE_TRT_FLASH_ATTENTION",
+         "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
+         "ORT_TRANSFORMER_OPTIONS",
+         "ORT_CUDA_GEMM_OPTIONS",
+     ]
+     env = ""
+     for name in env_names:
+         value = os.getenv(name)
+         if value is None:
+             continue
+         if env:
+             env += ","
+         env += f"{name}={value}"
+     return env
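
For orientation, a minimal sketch of how the helpers in this file compose into a latency benchmark. The model path and input name below are hypothetical placeholders, not part of the package; query session.get_inputs() for the real input signature of your model:

    import numpy
    from onnxruntime.transformers.benchmark_helper import create_onnxruntime_session, inference_ort, setup_logger

    setup_logger(verbose=False)
    # "model.onnx" and "input_ids" are placeholders; substitute your model's path and input names.
    session = create_onnxruntime_session("model.onnx", use_gpu=False, num_threads=4)
    if session is not None:
        ort_inputs = {"input_ids": numpy.ones((1, 128), dtype=numpy.int64)}
        result = inference_ort(session, ort_inputs, {"engine": "onnxruntime"}, repeat_times=100, batch_size=1, warm_up_repeat=10)
        print(result["average_latency_ms"], result["QPS"])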
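
The IO-binding path needs output shapes and buffer sizes up front. One way to obtain them (an assumption of this sketch, not something the file prescribes) is a single plain run; note that bind_output above hard-codes numpy.float32, so this assumes float32 model outputs and a CUDA-capable onnxruntime build:

    import numpy
    from onnxruntime.transformers.benchmark_helper import create_onnxruntime_session, inference_ort_with_io_binding

    session = create_onnxruntime_session("model.onnx", use_gpu=True, provider="cuda")  # placeholder path
    if session is not None:
        ort_inputs = {"input_ids": numpy.ones((1, 128), dtype=numpy.int64)}
        ort_output_names = [output.name for output in session.get_outputs()]
        ort_outputs = session.run(ort_output_names, ort_inputs)  # one plain run to learn output shapes
        output_buffers = []  # filled by allocateOutputBuffers on the first call, then reused
        output_buffer_max_sizes = [int(numpy.prod(output.shape)) for output in ort_outputs]
        result = inference_ort_with_io_binding(
            session, ort_inputs, {"engine": "onnxruntime"}, 100,
            ort_output_names, ort_outputs, output_buffers, output_buffer_max_sizes,
            batch_size=1, device="cuda",
        )
        print(result["average_latency_ms"])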
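
And a sketch of measure_memory on its CPU branch, which polls the process RSS via psutil in a background thread and returns the peak increase in MB; the workload function is a stand-in, and psutil must be installed:

    import numpy
    from onnxruntime.transformers.benchmark_helper import measure_memory

    def workload():
        # Stand-in allocation (~80 MB) so the background monitor observes a peak.
        numpy.zeros((10 * 1024 * 1024,), dtype=numpy.float64).sum()

    delta_mb = measure_memory(is_gpu=False, func=workload)
    print(f"peak CPU memory increase: {delta_mb:.1f} MB")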