onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
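
As a quick orientation before the diff, here is a minimal usage sketch for this wheel (an illustrative sketch, not package content: the model file "model.onnx" and its input shape are assumptions). The DirectML build registers DmlExecutionProvider, with CPUExecutionProvider as the fallback.

# Minimal sketch: run inference with the DirectML build of ONNX Runtime.
# Assumes a model file "model.onnx" with one float32 input of shape (1, 3, 224, 224).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],  # DirectML first, CPU fallback
)
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: np.zeros((1, 3, 224, 224), dtype=np.float32)})
print([output.shape for output in outputs])
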
onnxruntime/transformers/models/sam2/benchmark_sam2.py
@@ -0,0 +1,625 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ """
+ Benchmark performance of the SAM2 image encoder or decoder with ORT or PyTorch. See benchmark_sam2.sh for usage.
+ """
+
+ import argparse
+ import csv
+ import statistics
+ import time
+ from datetime import datetime
+ from typing import List, Mapping, Optional
+
+ import torch
+ from image_decoder import SAM2ImageDecoder
+ from image_encoder import SAM2ImageEncoder
+ from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model
+
+ from onnxruntime import InferenceSession, SessionOptions, get_available_providers
+ from onnxruntime.transformers.io_binding_helper import CudaSession
+
+
+ class TestConfig:
+     def __init__(
+         self,
+         model_type: str,
+         onnx_path: str,
+         sam2_dir: str,
+         device: torch.device,
+         component: str = "image_encoder",
+         provider="CPUExecutionProvider",
+         torch_compile_mode="max-autotune",
+         batch_size: int = 1,
+         height: int = 1024,
+         width: int = 1024,
+         num_labels: int = 1,
+         num_points: int = 1,
+         num_masks: int = 1,
+         multi_mask_output: bool = False,
+         use_tf32: bool = True,
+         enable_cuda_graph: bool = False,
+         dtype=torch.float32,
+         prefer_nhwc: bool = False,
+         warm_up: int = 5,
+         enable_nvtx_profile: bool = False,
+         enable_torch_profile: bool = False,
+         repeats: int = 1000,
+         verbose: bool = False,
+     ):
+         assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
+         assert height >= 160 and height <= 4096
+         assert width >= 160 and width <= 4096
+
+         self.model_type = model_type
+         self.onnx_path = onnx_path
+         self.sam2_dir = sam2_dir
+         self.component = component
+         self.provider = provider
+         self.torch_compile_mode = torch_compile_mode
+         self.batch_size = batch_size
+         self.height = height
+         self.width = width
+         self.num_labels = num_labels
+         self.num_points = num_points
+         self.num_masks = num_masks
+         self.multi_mask_output = multi_mask_output
+         self.device = device
+         self.use_tf32 = use_tf32
+         self.enable_cuda_graph = enable_cuda_graph
+         self.dtype = dtype
+         self.prefer_nhwc = prefer_nhwc
+         self.warm_up = warm_up
+         self.enable_nvtx_profile = enable_nvtx_profile
+         self.enable_torch_profile = enable_torch_profile
+         self.repeats = repeats
+         self.verbose = verbose
+
+         if self.component == "image_encoder":
+             assert self.height == 1024 and self.width == 1024, "Only image size 1024x1024 is allowed for image encoder."
+
+     def __repr__(self):
+         return f"{vars(self)}"
+
+     def shape_dict(self) -> Mapping[str, List[int]]:
+         if self.component == "image_encoder":
+             return encoder_shape_dict(self.batch_size, self.height, self.width)
+         else:
+             return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)
+
+     def random_inputs(self) -> Mapping[str, torch.Tensor]:
+         dtype = self.dtype
+         if self.component == "image_encoder":
+             return {"image": torch.randn(self.batch_size, 3, self.height, self.width, dtype=dtype, device=self.device)}
+         else:
+             return {
+                 "image_features_0": torch.rand(1, 32, 256, 256, dtype=dtype, device=self.device),
+                 "image_features_1": torch.rand(1, 64, 128, 128, dtype=dtype, device=self.device),
+                 "image_embeddings": torch.rand(1, 256, 64, 64, dtype=dtype, device=self.device),
+                 "point_coords": torch.randint(
+                     0, 1024, (self.num_labels, self.num_points, 2), dtype=dtype, device=self.device
+                 ),
+                 "point_labels": torch.randint(
+                     0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device
+                 ),
+                 "input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=dtype, device=self.device),
+                 "has_input_masks": torch.ones(self.num_labels, dtype=dtype, device=self.device),
+                 "original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device),
+             }
+
+
+ def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
+     if config.verbose:
+         print(f"create session for {vars(config)}")
+
+     if config.provider == "CUDAExecutionProvider":
+         device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
+         provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
+         provider_options["use_tf32"] = int(config.use_tf32)
+         if config.prefer_nhwc:
+             provider_options["prefer_nhwc"] = 1
+         providers = [(config.provider, provider_options), "CPUExecutionProvider"]
+     else:
+         providers = ["CPUExecutionProvider"]
+
+     ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
+     return ort_session
+
+
+ def create_session(config: TestConfig, session_options=None) -> CudaSession:
+     ort_session = create_ort_session(config, session_options)
+     cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
+     cuda_session.allocate_buffers(config.shape_dict())
+     return cuda_session
+
+
+ class OrtTestSession:
+     """A wrapper of ORT session to test relevance and performance."""
+
+     def __init__(self, config: TestConfig, session_options=None):
+         self.ort_session = create_session(config, session_options)
+         self.feed_dict = config.random_inputs()
+
+     def infer(self):
+         return self.ort_session.infer(self.feed_dict)
+
+
+ def measure_latency(cuda_session: CudaSession, input_dict):
+     start = time.time()
+     _ = cuda_session.infer(input_dict)
+     end = time.time()
+     return end - start
+
+
+ def run_torch(config: TestConfig):
+     device_type = config.device.type
+     is_cuda = device_type == "cuda"
+
+     # Turn on TF32 for Ampere GPUs which could help when data type is float32.
+     if is_cuda and torch.cuda.get_device_properties(0).major >= 8 and config.use_tf32:
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+
+     enabled_auto_cast = is_cuda and config.dtype != torch.float32
+     ort_inputs = config.random_inputs()
+
+     with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
+         sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
+         if config.component == "image_encoder":
+             if is_cuda and config.torch_compile_mode != "none":
+                 sam2_model.image_encoder.forward = torch.compile(
+                     sam2_model.image_encoder.forward,
+                     mode=config.torch_compile_mode,  # "reduce-overhead" if you want to reduce latency of first run.
+                     fullgraph=True,
+                     dynamic=False,
+                 )
+
+             image_shape = config.shape_dict()["image"]
+             img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
+             sam2_encoder = SAM2ImageEncoder(sam2_model)
+
+             if is_cuda and config.torch_compile_mode != "none":
+                 print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")
+
+             for _ in range(config.warm_up):
+                 _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
+
+             if is_cuda and config.enable_nvtx_profile:
+                 import nvtx
+                 from cuda import cudart
+
+                 cudart.cudaProfilerStart()
+                 print("Start nvtx profiling on encoder ...")
+                 with nvtx.annotate("one_run"):
+                     sam2_encoder(img, enable_nvtx_profile=True)
+                 cudart.cudaProfilerStop()
+
+             if is_cuda and config.enable_torch_profile:
+                 with torch.profiler.profile(
+                     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+                     record_shapes=True,
+                 ) as prof:
+                     print("Start torch profiling on encoder ...")
+                     with torch.profiler.record_function("encoder"):
+                         sam2_encoder(img)
+                 print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+                 prof.export_chrome_trace("torch_image_encoder.json")
+
+             if config.repeats == 0:
+                 return
+
+             print(f"Start {config.repeats} runs of performance tests...")
+             start = time.time()
+             for _ in range(config.repeats):
+                 _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
+                 if is_cuda:
+                     torch.cuda.synchronize()
+         else:
+             torch_inputs = (
+                 ort_inputs["image_features_0"],
+                 ort_inputs["image_features_1"],
+                 ort_inputs["image_embeddings"],
+                 ort_inputs["point_coords"],
+                 ort_inputs["point_labels"],
+                 ort_inputs["input_masks"],
+                 ort_inputs["has_input_masks"],
+                 ort_inputs["original_image_size"],
+             )
+
+             sam2_decoder = SAM2ImageDecoder(
+                 sam2_model,
+                 multimask_output=config.multi_mask_output,
+             )
+
+             if is_cuda and config.torch_compile_mode != "none":
+                 sam2_decoder.forward = torch.compile(
+                     sam2_decoder.forward,
+                     mode=config.torch_compile_mode,
+                     fullgraph=True,
+                     dynamic=False,
+                 )
+
+             # warm up
+             for _ in range(config.warm_up):
+                 _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
+
+             if is_cuda and config.enable_nvtx_profile:
+                 import nvtx
+                 from cuda import cudart
+
+                 cudart.cudaProfilerStart()
+                 print("Start nvtx profiling on decoder...")
+                 with nvtx.annotate("one_run"):
+                     sam2_decoder(*torch_inputs, enable_nvtx_profile=True)
+                 cudart.cudaProfilerStop()
+
+             if is_cuda and config.enable_torch_profile:
+                 with torch.profiler.profile(
+                     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+                     record_shapes=True,
+                 ) as prof:
+                     print("Start torch profiling on decoder ...")
+                     with torch.profiler.record_function("decoder"):
+                         sam2_decoder(*torch_inputs)
+                 print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+                 prof.export_chrome_trace("torch_image_decoder.json")
+
+             if config.repeats == 0:
+                 return
+
+             print(f"Start {config.repeats} runs of performance tests...")
+             start = time.time()
+             for _ in range(config.repeats):
+                 _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
+                 if is_cuda:
+                     torch.cuda.synchronize()
+
+         end = time.time()
+         return (end - start) / config.repeats
+
+
+ def run_test(
+     args: argparse.Namespace,
+     csv_writer: Optional[csv.DictWriter] = None,
+ ):
+     use_gpu: bool = args.use_gpu
+     enable_cuda_graph: bool = args.use_cuda_graph
+     repeats: int = args.repeats
+
+     if use_gpu:
+         device_id = torch.cuda.current_device()
+         device = torch.device("cuda", device_id)
+         provider = "CUDAExecutionProvider"
+     else:
+         device_id = 0
+         device = torch.device("cpu")
+         enable_cuda_graph = False
+         provider = "CPUExecutionProvider"
+
+     dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
+     config = TestConfig(
+         model_type=args.model_type,
+         onnx_path=args.onnx_path,
+         sam2_dir=args.sam2_dir,
+         component=args.component,
+         provider=provider,
+         batch_size=args.batch_size,
+         height=args.height,
+         width=args.width,
+         device=device,
+         use_tf32=True,
+         enable_cuda_graph=enable_cuda_graph,
+         dtype=dtypes[args.dtype],
+         prefer_nhwc=args.prefer_nhwc,
+         repeats=args.repeats,
+         warm_up=args.warm_up,
+         enable_nvtx_profile=args.enable_nvtx_profile,
+         enable_torch_profile=args.enable_torch_profile,
+         torch_compile_mode=args.torch_compile_mode,
+         verbose=False,
+     )
+
+     if args.engine == "ort":
+         sess_options = SessionOptions()
+         sess_options.intra_op_num_threads = args.intra_op_num_threads
+         if config.enable_nvtx_profile:
+             sess_options.enable_profiling = True
+             sess_options.log_severity_level = 4
+             sess_options.log_verbosity_level = 0
+
+         session = create_session(config, sess_options)
+         input_dict = config.random_inputs()
+
+         # warm up session
+         try:
+             for _ in range(config.warm_up):
+                 _ = measure_latency(session, input_dict)
+         except Exception as e:
+             print(f"Failed to run {config=}. Exception: {e}")
+             return
+
+         if config.enable_nvtx_profile:
+             import nvtx
+             from cuda import cudart
+
+             cudart.cudaProfilerStart()
+             with nvtx.annotate("one_run"):
+                 _ = session.infer(input_dict)
+             cudart.cudaProfilerStop()
+             session.ort_session.end_profiling()
+
+         if repeats == 0:
+             return
+
+         latency_list = []
+         for _ in range(repeats):
+             latency = measure_latency(session, input_dict)
+             latency_list.append(latency)
+         average_latency = statistics.mean(latency_list)
+
+         del session
+     else:  # torch
+         with torch.no_grad():
+             try:
+                 average_latency = run_torch(config)
+             except Exception as e:
+                 print(f"Failed to run {config=}. Exception: {e}")
+                 return
+
+         if repeats == 0:
+             return
+
+     engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
+     row = {
+         "model_type": args.model_type,
+         "component": args.component,
+         "dtype": args.dtype,
+         "use_gpu": use_gpu,
+         "enable_cuda_graph": enable_cuda_graph,
+         "prefer_nhwc": config.prefer_nhwc,
+         "use_tf32": config.use_tf32,
+         "batch_size": args.batch_size,
+         "height": args.height,
+         "width": args.width,
+         "multi_mask_output": args.multimask_output,
+         "num_labels": config.num_labels,
+         "num_points": config.num_points,
+         "num_masks": config.num_masks,
+         "intra_op_num_threads": args.intra_op_num_threads,
+         "warm_up": config.warm_up,
+         "repeats": repeats,
+         "enable_nvtx_profile": args.enable_nvtx_profile,
+         "torch_compile_mode": args.torch_compile_mode,
+         "engine": engine,
+         "average_latency": average_latency,
+     }
+
+     if csv_writer is not None:
+         csv_writer.writerow(row)
+
+     print(f"{vars(config)}")
+     print(f"{row}")
+
+
+ def run_perf_test(args):
+     features = "gpu" if args.use_gpu else "cpu"
+     csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
+         features,
+         args.engine,
+         datetime.now().strftime("%Y%m%d-%H%M%S"),
+     )
+     with open(csv_filename, mode="a", newline="") as csv_file:
+         column_names = [
+             "model_type",
+             "component",
+             "dtype",
+             "use_gpu",
+             "enable_cuda_graph",
+             "prefer_nhwc",
+             "use_tf32",
+             "batch_size",
+             "height",
+             "width",
+             "multi_mask_output",
+             "num_labels",
+             "num_points",
+             "num_masks",
+             "intra_op_num_threads",
+             "warm_up",
+             "repeats",
+             "enable_nvtx_profile",
+             "torch_compile_mode",
+             "engine",
+             "average_latency",
+         ]
+         csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+         csv_writer.writeheader()
+
+         run_test(args, csv_writer)
+
+
+ def _parse_arguments():
+     parser = argparse.ArgumentParser(description="Benchmark SAM2 for ONNX Runtime and PyTorch.")
+
+     parser.add_argument(
+         "--component",
+         required=False,
+         choices=["image_encoder", "image_decoder"],
+         default="image_encoder",
+         help="component to benchmark. Choices are image_encoder and image_decoder.",
+     )
+
+     parser.add_argument(
+         "--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
+     )
+
+     parser.add_argument(
+         "--use_gpu",
+         required=False,
+         action="store_true",
+         help="Use GPU for inference.",
+     )
+     parser.set_defaults(use_gpu=False)
+
+     parser.add_argument(
+         "--use_cuda_graph",
+         required=False,
+         action="store_true",
+         help="Use cuda graph in onnxruntime.",
+     )
+     parser.set_defaults(use_cuda_graph=False)
+
+     parser.add_argument(
+         "--intra_op_num_threads",
+         required=False,
+         type=int,
+         choices=[0, 1, 2, 4, 8, 16],
+         default=0,
+         help="intra_op_num_threads for onnxruntime.",
+     )
+
+     parser.add_argument(
+         "--batch_size",
+         required=False,
+         type=int,
+         default=1,
+         help="batch size",
+     )
+
+     parser.add_argument(
+         "--height",
+         required=False,
+         type=int,
+         default=1024,
+         help="image height",
+     )
+
+     parser.add_argument(
+         "--width",
+         required=False,
+         type=int,
+         default=1024,
+         help="image width",
+     )
+
+     parser.add_argument(
+         "--repeats",
+         required=False,
+         type=int,
+         default=1000,
+         help="number of repeats for performance test. Default is 1000.",
+     )
+
+     parser.add_argument(
+         "--warm_up",
+         required=False,
+         type=int,
+         default=5,
+         help="number of runs for warm up. Default is 5.",
+     )
+
+     parser.add_argument(
+         "--engine",
+         required=False,
+         type=str,
+         default="ort",
+         choices=["ort", "torch"],
+         help="engine for inference",
+     )
+
+     parser.add_argument(
+         "--multimask_output",
+         required=False,
+         default=False,
+         action="store_true",
+         help="Export mask_decoder or image_decoder with multimask_output",
+     )
+
+     parser.add_argument(
+         "--prefer_nhwc",
+         required=False,
+         default=False,
+         action="store_true",
+         help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
+     )
+
+     parser.add_argument(
+         "--enable_nvtx_profile",
+         required=False,
+         default=False,
+         action="store_true",
+         help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
+     )
+
+     parser.add_argument(
+         "--enable_torch_profile",
+         required=False,
+         default=False,
+         action="store_true",
+         help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
+     )
+
+     parser.add_argument(
+         "--model_type",
+         required=False,
+         type=str,
+         default="sam2_hiera_large",
+         choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
+         help="sam2 model name",
+     )
+
+     parser.add_argument(
+         "--sam2_dir",
+         required=False,
+         type=str,
+         default="./segment-anything-2",
+         help="Root directory of the segment-anything-2 git repository",
+     )
+
+     parser.add_argument(
+         "--onnx_path",
+         required=False,
+         type=str,
+         default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
+         help="path of onnx model",
+     )
+
+     parser.add_argument(
+         "--torch_compile_mode",
+         required=False,
+         type=str,
+         default=None,
+         choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
+         help="torch compile mode. none will disable torch compile.",
+     )
+
+     args = parser.parse_args()
+
+     return args
+
+
+ if __name__ == "__main__":
+     args = _parse_arguments()
+     print(f"arguments:{args}")
+
+     if args.torch_compile_mode is None:
+         # image decoder will fail with compile modes other than "none".
+         args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"
+
+     if args.use_gpu:
+         assert torch.cuda.is_available()
+         if args.engine == "ort":
+             assert "CUDAExecutionProvider" in get_available_providers()
+             args.enable_torch_profile = False
+     else:
+         # Only support cuda profiling for now.
+         assert not args.enable_nvtx_profile
+         assert not args.enable_torch_profile
+
+     if args.enable_nvtx_profile or args.enable_torch_profile:
+         run_test(args)
+     else:
+         run_perf_test(args)
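
For reference, a hypothetical way to drive run_test above without the command line (a sketch under stated assumptions: the paths are the script's own defaults, the flag values are illustrative, and the script must run next to image_encoder.py, image_decoder.py, and sam2_utils.py). Note that --use_gpu asserts CUDAExecutionProvider is available, which this DirectML wheel does not ship, so the sketch stays on CPU.

# Hypothetical sketch: roughly equivalent to
#   python benchmark_sam2.py --engine ort --dtype fp32 --repeats 100
from argparse import Namespace

args = Namespace(
    component="image_encoder",  # or "image_decoder"
    dtype="fp32",
    use_gpu=False,
    use_cuda_graph=False,
    intra_op_num_threads=0,
    batch_size=1,
    height=1024,  # the image encoder requires 1024x1024 input
    width=1024,
    repeats=100,
    warm_up=5,
    engine="ort",
    multimask_output=False,
    prefer_nhwc=False,
    enable_nvtx_profile=False,
    enable_torch_profile=False,
    model_type="sam2_hiera_large",
    sam2_dir="./segment-anything-2",  # script default; assumed checkout location
    onnx_path="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",  # script default
    torch_compile_mode="none",
)
run_test(args)  # prints the config and the result row; pass a csv.DictWriter to also record CSV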