onnxruntime-directml 1.20.0 (cp313-cp313-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/benchmark_helper.py
@@ -0,0 +1,646 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import csv
+import logging
+import os
+import random
+import sys
+import time
+import timeit
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
+from enum import Enum
+from time import sleep
+from typing import Any, Dict, List, Optional
+
+import coloredlogs
+import numpy
+import torch
+import transformers
+from packaging import version
+
+import onnxruntime
+
+logger = logging.getLogger(__name__)
+
+
+class Precision(Enum):
+    FLOAT32 = "fp32"
+    FLOAT16 = "fp16"
+    INT8 = "int8"
+    INT4 = "int4"
+
+    def __str__(self):
+        return self.value
+
+
+class OptimizerInfo(Enum):
+    # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as
+    # graph optimization level is not 0 (disable all).
+    NOOPT = "no_opt"
+    BYORT = "by_ort"
+    BYSCRIPT = "by_script"
+
+    def __str__(self):
+        return self.value
+
+
+class ConfigModifier:
+    def __init__(self, num_layers):
+        self.num_layers = num_layers
+
+    def modify(self, config):
+        if self.num_layers is None:
+            return
+        if hasattr(config, "num_hidden_layers"):
+            config.num_hidden_layers = self.num_layers
+            logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}")
+        if hasattr(config, "encoder_layers"):
+            config.encoder_layers = self.num_layers
+            logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}")
+        if hasattr(config, "decoder_layers"):
+            config.decoder_layers = self.num_layers
+            logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}")
+
+    def get_layer_num(self):
+        return self.num_layers
+
+
+IO_BINDING_DATA_TYPE_MAP = {
+    "float32": numpy.float32,
+    # TODO: Add more.
+}
+
+
+def create_onnxruntime_session(
+    onnx_model_path,
+    use_gpu,
+    provider=None,
+    enable_all_optimization=True,
+    num_threads=-1,
+    enable_profiling=False,
+    verbose=False,
+    enable_mlas_gemm_fastmath_arm64_bfloat16=False,
+    provider_options={},  # map execution provider name to its option  # noqa: B006
+):
+    session = None
+    try:
+        sess_options = onnxruntime.SessionOptions()
+
+        if enable_all_optimization:
+            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        else:
+            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
+
+        if enable_profiling:
+            sess_options.enable_profiling = True
+
+        if num_threads > 0:
+            sess_options.intra_op_num_threads = num_threads
+            logger.debug(f"Session option: intra_op_num_threads={sess_options.intra_op_num_threads}")
+
+        if verbose:
+            sess_options.log_severity_level = 0
+        else:
+            sess_options.log_severity_level = 4
+
+        logger.debug(f"Create session for onnx model: {onnx_model_path}")
+        if use_gpu:
+            if provider == "dml":
+                providers = ["DmlExecutionProvider", "CPUExecutionProvider"]
+            elif provider == "rocm":
+                providers = ["ROCMExecutionProvider", "CPUExecutionProvider"]
+            elif provider == "migraphx":
+                providers = [
+                    "MIGraphXExecutionProvider",
+                    "ROCMExecutionProvider",
+                    "CPUExecutionProvider",
+                ]
+            elif provider == "cuda":
+                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+            elif provider == "tensorrt":
+                providers = [
+                    "TensorrtExecutionProvider",
+                    "CUDAExecutionProvider",
+                    "CPUExecutionProvider",
+                ]
+            else:
+                providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        else:
+            providers = ["CPUExecutionProvider"]
+
+        if provider_options:
+            providers = [(name, provider_options[name]) if name in provider_options else name for name in providers]
+
+        if enable_mlas_gemm_fastmath_arm64_bfloat16:
+            sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")
+
+        session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers)
+    except Exception:
+        logger.error("Exception", exc_info=True)  # noqa: G201
+
+    return session
+
+
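A minimal usage sketch of create_onnxruntime_session with the DirectML provider this wheel ships; the model path is a placeholder, not part of the package:

    # Sketch only: "model.onnx" is an assumed path to any valid ONNX model.
    session = create_onnxruntime_session("model.onnx", use_gpu=True, provider="dml")
    if session is not None:
        # DML sessions keep CPU as fallback for unsupported nodes.
        print(session.get_providers())  # e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']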
+def setup_logger(verbose=True):
+    if verbose:
+        coloredlogs.install(
+            level="DEBUG",
+            fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
+        )
+    else:
+        coloredlogs.install(fmt="%(message)s")
+        logging.getLogger("transformers").setLevel(logging.WARNING)
+
+
+def prepare_environment(cache_dir, output_dir, use_gpu, provider=None):
+    if cache_dir and not os.path.exists(cache_dir):
+        os.makedirs(cache_dir)
+
+    if output_dir and not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    if use_gpu:
+        if provider == "dml":
+            assert (
+                "DmlExecutionProvider" in onnxruntime.get_available_providers()
+            ), "Please install onnxruntime-directml package to test GPU inference."
+
+        else:
+            assert not set(onnxruntime.get_available_providers()).isdisjoint(
+                ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"]
+            ), "Please install onnxruntime-gpu package, or install ROCm support, to test GPU inference."
+
+    logger.info(f"PyTorch Version:{torch.__version__}")
+    logger.info(f"Transformers Version:{transformers.__version__}")
+    logger.info(f"OnnxRuntime Version:{onnxruntime.__version__}")
+
+    # Support three major versions of PyTorch and OnnxRuntime, and up to 9 months of transformers.
+    assert version.parse(torch.__version__) >= version.parse("1.10.0")
+    assert version.parse(transformers.__version__) >= version.parse("4.12.0")
+    assert version.parse(onnxruntime.__version__) >= version.parse("1.10.0")
+
+
+def get_latency_result(latency_list, batch_size):
+    latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
+    latency_variance = numpy.var(latency_list, dtype=numpy.float64) * 1000.0
+    throughput = batch_size * (1000.0 / latency_ms)
+
+    return {
+        "test_times": len(latency_list),
+        "latency_variance": f"{latency_variance:.2f}",
+        "latency_90_percentile": f"{numpy.percentile(latency_list, 90) * 1000.0:.2f}",
+        "latency_95_percentile": f"{numpy.percentile(latency_list, 95) * 1000.0:.2f}",
+        "latency_99_percentile": f"{numpy.percentile(latency_list, 99) * 1000.0:.2f}",
+        "average_latency_ms": f"{latency_ms:.2f}",
+        "QPS": f"{throughput:.2f}",
+    }
+
+
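get_latency_result expects raw latencies in seconds (as produced by timeit.repeat with number=1) and reports milliseconds; a small worked sketch with made-up measurements:

    # Three made-up latencies in seconds for a batch of 8.
    stats = get_latency_result([0.010, 0.012, 0.011], batch_size=8)
    print(stats["average_latency_ms"])  # '11.00'  (mean of 11 ms)
    print(stats["QPS"])                 # '727.27' (= 8 * 1000 / 11.00)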
+def output_details(results, csv_filename):
+    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+        column_names = [
+            "engine",
+            "version",
+            "providers",
+            "device",
+            "precision",
+            "optimizer",
+            "io_binding",
+            "model_name",
+            "inputs",
+            "threads",
+            "batch_size",
+            "sequence_length",
+            "custom_layer_num",
+            "datetime",
+            "test_times",
+            "QPS",
+            "average_latency_ms",
+            "latency_variance",
+            "latency_90_percentile",
+            "latency_95_percentile",
+            "latency_99_percentile",
+        ]
+
+        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+        csv_writer.writeheader()
+        for result in results:
+            csv_writer.writerow(result)
+
+    logger.info(f"Detail results are saved to csv file: {csv_filename}")
+
+
+def output_summary(results, csv_filename, args):
+    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+        header_names = [
+            "model_name",
+            "inputs",
+            "custom_layer_num",
+            "engine",
+            "version",
+            "providers",
+            "device",
+            "precision",
+            "optimizer",
+            "io_binding",
+            "threads",
+        ]
+        data_names = []
+        for batch_size in args.batch_sizes:
+            if args.sequence_lengths == [""]:
+                data_names.append(f"b{batch_size}")
+            else:
+                for sequence_length in args.sequence_lengths:
+                    data_names.append(f"b{batch_size}_s{sequence_length}")
+
+        csv_writer = csv.DictWriter(csv_file, fieldnames=header_names + data_names)
+        csv_writer.writeheader()
+        for model_name in args.models:
+            for input_count in [1, 2, 3]:
+                for engine_name in args.engines:
+                    for io_binding in [True, False, ""]:
+                        for threads in args.num_threads:
+                            row = {}
+                            for result in results:
+                                if (
+                                    result["model_name"] == model_name
+                                    and result["inputs"] == input_count
+                                    and result["engine"] == engine_name
+                                    and result["io_binding"] == io_binding
+                                    and result["threads"] == threads
+                                ):
+                                    headers = {k: v for k, v in result.items() if k in header_names}
+                                    if not row:
+                                        row.update(headers)
+                                        row.update({k: "" for k in data_names})
+                                    else:
+                                        for k in header_names:
+                                            assert row[k] == headers[k]
+                                    b = result["batch_size"]
+                                    s = result["sequence_length"]
+                                    if s:
+                                        row[f"b{b}_s{s}"] = result["average_latency_ms"]
+                                    else:
+                                        row[f"b{b}"] = result["average_latency_ms"]
+                            if row:
+                                csv_writer.writerow(row)
+
+    logger.info(f"Summary results are saved to csv file: {csv_filename}")
+
+
+def output_fusion_statistics(model_fusion_statistics, csv_filename):
+    with open(csv_filename, mode="a", newline="", encoding="ascii") as csv_file:
+        column_names = [
+            "model_filename",
+            "datetime",
+            "transformers",
+            "torch",
+            *list(next(iter(model_fusion_statistics.values())).keys()),
+        ]
+        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+        csv_writer.writeheader()
+        for key in model_fusion_statistics:
+            model_fusion_statistics[key]["datetime"] = str(datetime.now())
+            model_fusion_statistics[key]["transformers"] = transformers.__version__
+            model_fusion_statistics[key]["torch"] = torch.__version__
+            model_fusion_statistics[key]["model_filename"] = key
+            csv_writer.writerow(model_fusion_statistics[key])
+    logger.info(f"Fusion statistics are saved to csv file: {csv_filename}")
+
+
+def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, warm_up_repeat=0):
+    result = {}
+    timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=warm_up_repeat)  # Dry run
+    latency_list = timeit.repeat(lambda: ort_session.run(None, ort_inputs), number=1, repeat=repeat_times)
+    result.update(result_template)
+    result.update({"io_binding": False})
+    result.update(get_latency_result(latency_list, batch_size))
+    return result
+
+
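A hedged sketch of driving inference_ort; the session, input name, and shapes are assumptions for illustration:

    # Assumes `session` serves a model with one int64 input named "input_ids".
    ort_inputs = {"input_ids": numpy.ones((1, 128), dtype=numpy.int64)}
    template = {"engine": "onnxruntime", "model_name": "demo", "batch_size": 1, "sequence_length": 128}
    result = inference_ort(session, ort_inputs, template, repeat_times=100, batch_size=1, warm_up_repeat=3)
    print(result["average_latency_ms"], result["QPS"])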
+def inference_ort_with_io_binding(
+    ort_session,
+    ort_inputs,
+    result_template,
+    repeat_times,
+    ort_output_names,
+    ort_outputs,
+    output_buffers,
+    output_buffer_max_sizes,
+    batch_size,
+    device,
+    data_type=numpy.longlong,
+    warm_up_repeat=0,
+):
+    result = {}
+
+    # Bind inputs and outputs to onnxruntime session
+    io_binding = ort_session.io_binding()
+    # Bind inputs to device
+    for name in ort_inputs:
+        np_input = torch.from_numpy(ort_inputs[name]).to(device)
+        input_type = IO_BINDING_DATA_TYPE_MAP.get(str(ort_inputs[name].dtype), data_type)
+        io_binding.bind_input(
+            name,
+            np_input.device.type,
+            0,
+            input_type,
+            np_input.shape,
+            np_input.data_ptr(),
+        )
+    # Bind output buffers with the sizes needed if not allocated already
+    if len(output_buffers) == 0:
+        allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device)
+
+    for i, ort_output_name in enumerate(ort_output_names):
+        io_binding.bind_output(
+            ort_output_name,
+            output_buffers[i].device.type,
+            0,
+            numpy.float32,
+            ort_outputs[i].shape,
+            output_buffers[i].data_ptr(),
+        )
+
+    timeit.repeat(
+        lambda: ort_session.run_with_iobinding(io_binding),
+        number=1,
+        repeat=warm_up_repeat,
+    )  # Dry run
+
+    latency_list = timeit.repeat(
+        lambda: ort_session.run_with_iobinding(io_binding),
+        number=1,
+        repeat=repeat_times,
+    )
+    result.update(result_template)
+    result.update({"io_binding": True})
+    result.update(get_latency_result(latency_list, batch_size))
+    return result
+
+
+def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device):  # noqa: N802
+    # Allocate output tensors with the largest test size needed. So the allocated memory can be reused
+    # for each test run.
+
+    for i in output_buffer_max_sizes:
+        output_buffers.append(torch.empty(i, dtype=torch.float32, device=device))
+
+
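Note that output_buffer_max_sizes is a list of float32 element counts (not bytes), one per output; a hypothetical sizing for a single logits output, with all bounds assumed:

    # Sketch: largest (batch, sequence, hidden) any test config will produce.
    max_batch, max_seq, hidden = 4, 512, 768
    output_buffer_max_sizes = [max_batch * max_seq * hidden]
    output_buffers = []  # filled lazily by allocateOutputBuffers on the first call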
+def set_random_seed(seed=123):
+    """Set random seed manually to get deterministic results"""
+    random.seed(seed)
+    numpy.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    # torch.backends.cudnn.enabled = False
+    # torch.backends.cudnn.benchmark = False
+    # torch.backends.cudnn.deterministic = True
+
+
+def get_gpu_info() -> Optional[List[Dict[str, Any]]]:
+    from py3nvml.py3nvml import (
+        NVMLError,
+        nvmlDeviceGetCount,
+        nvmlDeviceGetHandleByIndex,
+        nvmlDeviceGetMemoryInfo,
+        nvmlDeviceGetName,
+        nvmlInit,
+        nvmlShutdown,
+    )
+
+    try:
+        nvmlInit()
+        result = []
+        device_count = nvmlDeviceGetCount()
+        if not isinstance(device_count, int):
+            return None
+
+        for i in range(device_count):
+            info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
+            if isinstance(info, str):
+                return None
+            result.append(
+                {
+                    "id": i,
+                    "name": nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)),
+                    "total": info.total,
+                    "free": info.free,
+                    "used": info.used,
+                }
+            )
+        nvmlShutdown()
+        return result
+    except NVMLError as error:
+        print(f"Error fetching GPU information using nvml: {error}")
+        return None
+
+
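get_gpu_info depends on the optional py3nvml package and returns None when NVML is unavailable, so callers should guard the result; a short sketch:

    gpus = get_gpu_info()  # requires `pip install py3nvml` and an NVIDIA driver
    if gpus:
        for gpu in gpus:  # NVML reports memory sizes in bytes
            print(gpu["id"], gpu["name"], f"{gpu['used'] / 1024**2:.0f} MB used")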
+class MemoryMonitor(ABC):
+    def __init__(self, keep_measuring=True):
+        self.keep_measuring = keep_measuring
+
+    def measure_cpu_usage(self):
+        import psutil
+
+        max_usage = 0
+        while True:
+            max_usage = max(max_usage, psutil.Process(os.getpid()).memory_info().rss / 1024**2)
+            sleep(0.005)  # 5ms
+            if not self.keep_measuring:
+                break
+        return max_usage
+
+    @abstractmethod
+    def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
+        raise NotImplementedError()
+
+
+class CudaMemoryMonitor(MemoryMonitor):
+    def __init__(self, keep_measuring=True):
+        super().__init__(keep_measuring)
+
+    def measure_gpu_usage(self) -> Optional[List[Dict[str, Any]]]:
+        from py3nvml.py3nvml import (
+            NVMLError,
+            nvmlDeviceGetCount,
+            nvmlDeviceGetHandleByIndex,
+            nvmlDeviceGetMemoryInfo,
+            nvmlDeviceGetName,
+            nvmlInit,
+            nvmlShutdown,
+        )
+
+        max_gpu_usage = []
+        gpu_name = []
+        try:
+            nvmlInit()
+            device_count = nvmlDeviceGetCount()
+            if not isinstance(device_count, int):
+                logger.error(f"nvmlDeviceGetCount result is not integer: {device_count}")
+                return None
+
+            max_gpu_usage = [0 for i in range(device_count)]
+            gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)]
+            while True:
+                for i in range(device_count):
+                    info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i))
+                    if isinstance(info, str):
+                        logger.error(f"nvmlDeviceGetMemoryInfo returns str: {info}")
+                        return None
+                    max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2)
+                sleep(0.005)  # 5ms
+                if not self.keep_measuring:
+                    break
+            nvmlShutdown()
+            return [
+                {
+                    "device_id": i,
+                    "name": gpu_name[i],
+                    "max_used_MB": max_gpu_usage[i],
+                }
+                for i in range(device_count)
+            ]
+        except NVMLError as error:
+            logger.error("Error fetching GPU information using nvml: %s", error)
+            return None
+
+
+class RocmMemoryMonitor(MemoryMonitor):
+    def __init__(self, keep_measuring=True):
+        super().__init__(keep_measuring)
+        rocm_smi_path = "/opt/rocm/libexec/rocm_smi"
+        if os.path.exists(rocm_smi_path):
+            if rocm_smi_path not in sys.path:
+                sys.path.append(rocm_smi_path)
+        try:
+            import rocm_smi
+
+            self.rocm_smi = rocm_smi
+            self.rocm_smi.initializeRsmi()
+        except ImportError:
+            self.rocm_smi = None
+
+    def get_used_memory(self, dev):
+        if self.rocm_smi is None:
+            return -1
+        return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024
+
+    def measure_gpu_usage(self):
+        if self.rocm_smi is None:
+            return None
+
+        device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0
+        max_gpu_usage = [0 for i in range(device_count)]
+        gpu_name = [f"GPU{i}" for i in range(device_count)]
+        while True:
+            for i in range(device_count):
+                max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i))
+            time.sleep(0.005)  # 5ms
+            if not self.keep_measuring:
+                break
+        return [
+            {
+                "device_id": i,
+                "name": gpu_name[i],
+                "max_used_MB": max_gpu_usage[i],
+            }
+            for i in range(device_count)
+        ]
+
+
+def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None):
+    memory_monitor_type = None
+    if monitor_type == "rocm":
+        memory_monitor_type = RocmMemoryMonitor
+    else:
+        memory_monitor_type = CudaMemoryMonitor
+
+    monitor = memory_monitor_type(False)
+
+    if is_gpu:
+        if start_memory is not None:
+            memory_before_test = start_memory
+        else:
+            memory_before_test = monitor.measure_gpu_usage()
+        if memory_before_test is None:
+            return None
+
+        if func is None:
+            return memory_before_test
+
+        with ThreadPoolExecutor() as executor:
+            monitor = memory_monitor_type()
+            mem_thread = executor.submit(monitor.measure_gpu_usage)
+            try:
+                fn_thread = executor.submit(func)
+                _ = fn_thread.result()
+            finally:
+                monitor.keep_measuring = False
+                max_usage = mem_thread.result()
+
+        if max_usage is None:
+            return None
+
+        logger.info(f"GPU memory usage: before={memory_before_test} peak={max_usage}")
+        if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage):
+            # When there are multiple GPUs, we will check the one with maximum usage.
+            max_used = 0
+            for i, memory_before in enumerate(memory_before_test):
+                before = memory_before["max_used_MB"]
+                after = max_usage[i]["max_used_MB"]
+                used = after - before
+                max_used = max(max_used, used)
+            return max_used
+        return None
+
+    # CPU memory
+    if start_memory is not None:
+        memory_before_test = start_memory
+    else:
+        memory_before_test = monitor.measure_cpu_usage()
+
+    if func is None:
+        return memory_before_test
+
+    with ThreadPoolExecutor() as executor:
+        monitor = memory_monitor_type()
+        mem_thread = executor.submit(monitor.measure_cpu_usage)
+        try:
+            fn_thread = executor.submit(func)
+            _ = fn_thread.result()
+        finally:
+            monitor.keep_measuring = False
+            max_usage = mem_thread.result()
+
+    logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB")
+    return max_usage - memory_before_test
+
+
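A sketch of measure_memory around a placeholder CPU workload (the CPU path needs psutil installed; fn is illustrative):

    def fn():
        _ = [0] * (10**7)  # placeholder allocation-heavy workload

    # Peak RSS growth in MB while fn runs, sampled every 5 ms on a second thread.
    delta_mb = measure_memory(is_gpu=False, func=fn)
    print(f"peak CPU memory increase: {delta_mb:.1f} MB")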
+def get_ort_environment_variables():
+    # Environment variables might impact ORT performance on transformer models. Note that they are for testing only.
+    env_names = [
+        "ORT_DISABLE_FUSED_ATTENTION",
+        "ORT_ENABLE_FUSED_CAUSAL_ATTENTION",
+        "ORT_DISABLE_FUSED_CROSS_ATTENTION",
+        "ORT_DISABLE_TRT_FLASH_ATTENTION",
+        "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
+        "ORT_TRANSFORMER_OPTIONS",
+        "ORT_CUDA_GEMM_OPTIONS",
+    ]
+    env = ""
+    for name in env_names:
+        value = os.getenv(name)
+        if value is None:
+            continue
+        if env:
+            env += ","
+        env += f"{name}={value}"
+    return env
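get_ort_environment_variables reports only the variables that are actually set, comma-separated; assuming none of the listed variables are set beforehand:

    os.environ["ORT_DISABLE_FUSED_ATTENTION"] = "1"
    print(get_ort_environment_variables())  # -> ORT_DISABLE_FUSED_ATTENTION=1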