onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
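As a quick orientation to what the files listed below provide at runtime, here is a minimal sketch (not part of the diff) of loading an ONNX model with this wheel's DirectML execution provider. "model.onnx" is a placeholder path, and the float32 dummy input is an assumption about the model's first input; the provider name "DmlExecutionProvider" is the one registered by onnxruntime-directml builds.

import numpy as np
import onnxruntime as ort

# The DirectML build registers "DmlExecutionProvider"; CPU is listed as a fallback.
print(ort.get_available_providers())

session = ort.InferenceSession(
    "model.onnx",  # placeholder path to any ONNX model
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Build a dummy input matching the model's first input (float32 is assumed here),
# substituting 1 for any symbolic/unknown dimensions, then run inference.
input_meta = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in input_meta.shape]
dummy = np.zeros(shape, dtype=np.float32)
outputs = session.run(None, {input_meta.name: dummy})
print([o.shape for o in outputs])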
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
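The largest single addition in this release is the new Stable Diffusion benchmark script; the +1519-line hunk below matches onnxruntime/transformers/models/stable_diffusion/benchmark.py from the list above. Its core measurement loop times each pipeline call and reports average and median latency (with GPU memory measured around warm-up runs). A condensed sketch of that pattern, with a hypothetical `pipe` callable standing in for a diffusers pipeline, is:

import statistics
import time

def benchmark_pipeline(pipe, prompts, batch_size=1, steps=30, height=512, width=512):
    # Hypothetical helper mirroring the timing loop in the hunk below:
    # call the pipeline once per prompt, record wall-clock latency,
    # and summarize the results with average and median.
    latencies = []
    for prompt in prompts:
        start = time.time()
        pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps)
        latencies.append(time.time() - start)
    return {
        "average_latency": sum(latencies) / len(latencies),
        "median_latency": statistics.median(latencies),
    }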
@@ -0,0 +1,1519 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import argparse
7
+ import csv
8
+ import logging
9
+ import os
10
+ import statistics
11
+ import sys
12
+ import time
13
+ from pathlib import Path
14
+
15
+ # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package.
16
+ import torch
17
+ from benchmark_helper import measure_memory
18
+
19
+ SD_MODELS = {
20
+ "1.5": "runwayml/stable-diffusion-v1-5",
21
+ "2.0": "stabilityai/stable-diffusion-2",
22
+ "2.1": "stabilityai/stable-diffusion-2-1",
23
+ "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0",
24
+ "3.0M": "stabilityai/stable-diffusion-3-medium-diffusers",
25
+ "3.5M": "stabilityai/stable-diffusion-3.5-medium",
26
+ "3.5L": "stabilityai/stable-diffusion-3.5-large",
27
+ "Flux.1S": "black-forest-labs/FLUX.1-schnell",
28
+ "Flux.1D": "black-forest-labs/FLUX.1-dev",
29
+ }
30
+
31
+ PROVIDERS = {
32
+ "cuda": "CUDAExecutionProvider",
33
+ "migraphx": "MIGraphXExecutionProvider",
34
+ "tensorrt": "TensorrtExecutionProvider",
35
+ }
36
+
37
+
38
+ def example_prompts():
39
+ prompts = [
40
+ "a photo of an astronaut riding a horse on mars",
41
+ "cute grey cat with blue eyes, wearing a bowtie, acrylic painting",
42
+ "a cute magical flying dog, fantasy art drawn by disney concept artists, highly detailed, digital painting",
43
+ "an illustration of a house with large barn with many cute flower pots and beautiful blue sky scenery",
44
+ "one apple sitting on a table, still life, reflective, full color photograph, centered, close-up product",
45
+ "background texture of stones, masterpiece, artistic, stunning photo, award winner photo",
46
+ "new international organic style house, tropical surroundings, architecture, 8k, hdr",
47
+ "beautiful Renaissance Revival Estate, Hobbit-House, detailed painting, warm colors, 8k, trending on Artstation",
48
+ "blue owl, big green eyes, portrait, intricate metal design, unreal engine, octane render, realistic",
49
+ "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k",
50
+ ]
51
+
52
+ negative_prompt = "bad composition, ugly, abnormal, malformed"
53
+
54
+ return prompts, negative_prompt
55
+
56
+
57
+ def warmup_prompts():
58
+ return "warm up", "bad"
59
+
60
+
61
+ def measure_gpu_memory(monitor_type, func, start_memory=None):
62
+ return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory)
63
+
64
+
65
+ def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool):
66
+ from diffusers import DDIMScheduler, OnnxStableDiffusionPipeline # noqa: PLC0415
67
+
68
+ import onnxruntime # noqa: PLC0415
69
+
70
+ if directory is not None:
71
+ assert os.path.exists(directory)
72
+ session_options = onnxruntime.SessionOptions()
73
+ pipe = OnnxStableDiffusionPipeline.from_pretrained(
74
+ directory,
75
+ provider=provider,
76
+ sess_options=session_options,
77
+ )
78
+ else:
79
+ pipe = OnnxStableDiffusionPipeline.from_pretrained(
80
+ model_name,
81
+ revision="onnx",
82
+ provider=provider,
83
+ use_auth_token=True,
84
+ )
85
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
86
+ pipe.set_progress_bar_config(disable=True)
87
+
88
+ if disable_safety_checker:
89
+ pipe.safety_checker = None
90
+ pipe.feature_extractor = None
91
+
92
+ return pipe
93
+
94
+
95
+ def get_torch_pipeline(model_name: str, disable_safety_checker: bool, enable_torch_compile: bool, use_xformers: bool):
96
+ if "FLUX" in model_name:
97
+ from diffusers import FluxPipeline # noqa: PLC0415
98
+
99
+ pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
100
+ if enable_torch_compile:
101
+ pipe.transformer.to(memory_format=torch.channels_last)
102
+ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
103
+ return pipe
104
+
105
+ if "stable-diffusion-3" in model_name:
106
+ from diffusers import StableDiffusion3Pipeline # noqa: PLC0415
107
+
108
+ pipe = StableDiffusion3Pipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
109
+ if enable_torch_compile:
110
+ pipe.transformer.to(memory_format=torch.channels_last)
111
+ pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
112
+ return pipe
113
+
114
+ from diffusers import DDIMScheduler, StableDiffusionPipeline # noqa: PLC0415
115
+ from torch import channels_last, float16 # noqa: PLC0415
116
+
117
+ pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=float16).to("cuda")
118
+
119
+ pipe.unet.to(memory_format=channels_last) # in-place operation
120
+
121
+ if use_xformers:
122
+ pipe.enable_xformers_memory_efficient_attention()
123
+
124
+ if enable_torch_compile:
125
+ pipe.unet = torch.compile(pipe.unet)
126
+ pipe.vae = torch.compile(pipe.vae)
127
+ pipe.text_encoder = torch.compile(pipe.text_encoder)
128
+ print("Torch compiled unet, vae and text_encoder")
129
+
130
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
131
+ pipe.set_progress_bar_config(disable=True)
132
+
133
+ if disable_safety_checker:
134
+ pipe.safety_checker = None
135
+ pipe.feature_extractor = None
136
+
137
+ return pipe
138
+
139
+
140
+ def get_image_filename_prefix(engine: str, model_name: str, batch_size: int, steps: int, disable_safety_checker: bool):
141
+ short_model_name = model_name.split("/")[-1].replace("stable-diffusion-", "sd")
142
+ return f"{engine}_{short_model_name}_b{batch_size}_s{steps}" + ("" if disable_safety_checker else "_safe")
143
+
144
+
145
+ def run_ort_pipeline(
146
+ pipe,
147
+ batch_size: int,
148
+ image_filename_prefix: str,
149
+ height,
150
+ width,
151
+ steps,
152
+ num_prompts,
153
+ batch_count,
154
+ start_memory,
155
+ memory_monitor_type,
156
+ skip_warmup: bool = False,
157
+ ):
158
+ from diffusers import OnnxStableDiffusionPipeline # noqa: PLC0415
159
+
160
+ assert isinstance(pipe, OnnxStableDiffusionPipeline)
161
+
162
+ prompts, negative_prompt = example_prompts()
163
+
164
+ def warmup():
165
+ if skip_warmup:
166
+ return
167
+ prompt, negative = warmup_prompts()
168
+ pipe(
169
+ prompt=[prompt] * batch_size,
170
+ height=height,
171
+ width=width,
172
+ num_inference_steps=steps,
173
+ negative_prompt=[negative] * batch_size,
174
+ )
175
+
176
+ # Run warm up, and measure GPU memory of two runs
177
+ # (The first run has cuDNN/MIOpen algo search, so it might need more memory.)
178
+ first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
179
+ second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
180
+
181
+ warmup()
182
+
183
+ latency_list = []
184
+ for i, prompt in enumerate(prompts):
185
+ if i >= num_prompts:
186
+ break
187
+ inference_start = time.time()
188
+ images = pipe(
189
+ prompt=[prompt] * batch_size,
190
+ height=height,
191
+ width=width,
192
+ num_inference_steps=steps,
193
+ negative_prompt=[negative_prompt] * batch_size,
194
+ ).images
195
+ inference_end = time.time()
196
+ latency = inference_end - inference_start
197
+ latency_list.append(latency)
198
+ print(f"Inference took {latency:.3f} seconds")
199
+ for k, image in enumerate(images):
200
+ image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
201
+
202
+ from onnxruntime import __version__ as ort_version # noqa: PLC0415
203
+
204
+ return {
205
+ "engine": "onnxruntime",
206
+ "version": ort_version,
207
+ "height": height,
208
+ "width": width,
209
+ "steps": steps,
210
+ "batch_size": batch_size,
211
+ "batch_count": batch_count,
212
+ "num_prompts": num_prompts,
213
+ "average_latency": sum(latency_list) / len(latency_list),
214
+ "median_latency": statistics.median(latency_list),
215
+ "first_run_memory_MB": first_run_memory,
216
+ "second_run_memory_MB": second_run_memory,
217
+ }
218
+
219
+
220
+ def get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size) -> dict:
221
+ # Flux does not support negative prompt
222
+ kwargs = (
223
+ (
224
+ {"negative_prompt": negative_prompt}
225
+ if use_num_images_per_prompt
226
+ else {"negative_prompt": [negative_prompt] * batch_size}
227
+ )
228
+ if not is_flux
229
+ else {}
230
+ )
231
+
232
+ # Fix the random seed so that we can inspect the output quality easily.
233
+ if torch.cuda.is_available():
234
+ kwargs["generator"] = torch.Generator(device="cuda").manual_seed(123)
235
+
236
+ return kwargs
237
+
238
+
239
+ def run_torch_pipeline(
240
+ pipe,
241
+ batch_size: int,
242
+ image_filename_prefix: str,
243
+ height,
244
+ width,
245
+ steps,
246
+ num_prompts,
247
+ batch_count,
248
+ start_memory,
249
+ memory_monitor_type,
250
+ skip_warmup=False,
251
+ ):
252
+ prompts, negative_prompt = example_prompts()
253
+
254
+ import diffusers # noqa: PLC0415
255
+
256
+ is_flux = isinstance(pipe, diffusers.FluxPipeline)
257
+
258
+ def warmup():
259
+ if skip_warmup:
260
+ return
261
+ prompt, negative = warmup_prompts()
262
+ extra_kwargs = get_negative_prompt_kwargs(negative, False, is_flux, batch_size)
263
+ pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs)
264
+
265
+ # Run warm up, and measure GPU memory of two runs (The first run has cuDNN algo search so it might need more memory)
266
+ first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
267
+ second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
268
+
269
+ warmup()
270
+
271
+ torch.set_grad_enabled(False)
272
+
273
+ latency_list = []
274
+ for i, prompt in enumerate(prompts):
275
+ if i >= num_prompts:
276
+ break
277
+ torch.cuda.synchronize()
278
+ inference_start = time.time()
279
+ extra_kwargs = get_negative_prompt_kwargs(negative_prompt, False, is_flux, batch_size)
280
+ images = pipe(
281
+ prompt=[prompt] * batch_size,
282
+ height=height,
283
+ width=width,
284
+ num_inference_steps=steps,
285
+ **extra_kwargs,
286
+ ).images
287
+
288
+ torch.cuda.synchronize()
289
+ inference_end = time.time()
290
+ latency = inference_end - inference_start
291
+ latency_list.append(latency)
292
+ print(f"Inference took {latency:.3f} seconds")
293
+ for k, image in enumerate(images):
294
+ image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
295
+
296
+ return {
297
+ "engine": "torch",
298
+ "version": torch.__version__,
299
+ "height": height,
300
+ "width": width,
301
+ "steps": steps,
302
+ "batch_size": batch_size,
303
+ "batch_count": batch_count,
304
+ "num_prompts": num_prompts,
305
+ "average_latency": sum(latency_list) / len(latency_list),
306
+ "median_latency": statistics.median(latency_list),
307
+ "first_run_memory_MB": first_run_memory,
308
+ "second_run_memory_MB": second_run_memory,
309
+ }
310
+
311
+
312
+ def run_ort(
313
+ model_name: str,
314
+ directory: str,
315
+ provider: str,
316
+ batch_size: int,
317
+ disable_safety_checker: bool,
318
+ height: int,
319
+ width: int,
320
+ steps: int,
321
+ num_prompts: int,
322
+ batch_count: int,
323
+ start_memory,
324
+ memory_monitor_type,
325
+ tuning: bool,
326
+ skip_warmup: bool = False,
327
+ ):
328
+ provider_and_options = provider
329
+ if tuning and provider in ["CUDAExecutionProvider"]:
330
+ provider_and_options = (provider, {"tunable_op_enable": 1, "tunable_op_tuning_enable": 1})
331
+
332
+ load_start = time.time()
333
+ pipe = get_ort_pipeline(model_name, directory, provider_and_options, disable_safety_checker)
334
+ load_end = time.time()
335
+ print(f"Model loading took {load_end - load_start} seconds")
336
+
337
+ image_filename_prefix = get_image_filename_prefix("ort", model_name, batch_size, steps, disable_safety_checker)
338
+ result = run_ort_pipeline(
339
+ pipe,
340
+ batch_size,
341
+ image_filename_prefix,
342
+ height,
343
+ width,
344
+ steps,
345
+ num_prompts,
346
+ batch_count,
347
+ start_memory,
348
+ memory_monitor_type,
349
+ skip_warmup=skip_warmup,
350
+ )
351
+
352
+ result.update(
353
+ {
354
+ "model_name": model_name,
355
+ "directory": directory,
356
+ "provider": provider.replace("ExecutionProvider", ""),
357
+ "disable_safety_checker": disable_safety_checker,
358
+ "enable_cuda_graph": False,
359
+ }
360
+ )
361
+ return result
362
+
363
+
364
+ def get_optimum_ort_pipeline(
365
+ model_name: str,
366
+ directory: str,
367
+ provider="CUDAExecutionProvider",
368
+ disable_safety_checker: bool = True,
369
+ use_io_binding: bool = False,
370
+ ):
371
+ from optimum.onnxruntime import ORTPipelineForText2Image # noqa: PLC0415
372
+
373
+ if directory is not None and os.path.exists(directory):
374
+ pipeline = ORTPipelineForText2Image.from_pretrained(directory, provider=provider, use_io_binding=use_io_binding)
375
+ else:
376
+ pipeline = ORTPipelineForText2Image.from_pretrained(
377
+ model_name,
378
+ export=True,
379
+ provider=provider,
380
+ use_io_binding=use_io_binding,
381
+ )
382
+ pipeline.save_pretrained(directory)
383
+
384
+ if disable_safety_checker:
385
+ pipeline.safety_checker = None
386
+ pipeline.feature_extractor = None
387
+
388
+ return pipeline
389
+
390
+
391
+ def run_optimum_ort_pipeline(
392
+ pipe,
393
+ batch_size: int,
394
+ image_filename_prefix: str,
395
+ height,
396
+ width,
397
+ steps,
398
+ num_prompts,
399
+ batch_count,
400
+ start_memory,
401
+ memory_monitor_type,
402
+ use_num_images_per_prompt=False,
403
+ skip_warmup=False,
404
+ ):
405
+ print("Pipeline type", type(pipe))
406
+ from optimum.onnxruntime.modeling_diffusion import ORTFluxPipeline # noqa: PLC0415
407
+
408
+ is_flux = isinstance(pipe, ORTFluxPipeline)
409
+
410
+ prompts, negative_prompt = example_prompts()
411
+
412
+ def warmup():
413
+ if skip_warmup:
414
+ return
415
+ prompt, negative = warmup_prompts()
416
+ extra_kwargs = get_negative_prompt_kwargs(negative, use_num_images_per_prompt, is_flux, batch_size)
417
+ if use_num_images_per_prompt:
418
+ pipe(
419
+ prompt=prompt,
420
+ height=height,
421
+ width=width,
422
+ num_inference_steps=steps,
423
+ num_images_per_prompt=batch_count,
424
+ **extra_kwargs,
425
+ )
426
+ else:
427
+ pipe(prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs)
428
+
429
+ # Run warm up, and measure GPU memory of two runs.
430
+ # The first run has algo search for cuDNN/MIOpen, so it might need more memory.
431
+ first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
432
+ second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
433
+
434
+ warmup()
435
+
436
+ extra_kwargs = get_negative_prompt_kwargs(negative_prompt, use_num_images_per_prompt, is_flux, batch_size)
437
+
438
+ latency_list = []
439
+ for i, prompt in enumerate(prompts):
440
+ if i >= num_prompts:
441
+ break
442
+ inference_start = time.time()
443
+ if use_num_images_per_prompt:
444
+ images = pipe(
445
+ prompt=prompt,
446
+ height=height,
447
+ width=width,
448
+ num_inference_steps=steps,
449
+ num_images_per_prompt=batch_size,
450
+ **extra_kwargs,
451
+ ).images
452
+ else:
453
+ images = pipe(
454
+ prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, **extra_kwargs
455
+ ).images
456
+ inference_end = time.time()
457
+ latency = inference_end - inference_start
458
+ latency_list.append(latency)
459
+ print(f"Inference took {latency:.3f} seconds")
460
+ for k, image in enumerate(images):
461
+ image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
462
+
463
+ from onnxruntime import __version__ as ort_version # noqa: PLC0415
464
+
465
+ return {
466
+ "engine": "optimum_ort",
467
+ "version": ort_version,
468
+ "height": height,
469
+ "width": width,
470
+ "steps": steps,
471
+ "batch_size": batch_size,
472
+ "batch_count": batch_count,
473
+ "num_prompts": num_prompts,
474
+ "average_latency": sum(latency_list) / len(latency_list),
475
+ "median_latency": statistics.median(latency_list),
476
+ "first_run_memory_MB": first_run_memory,
477
+ "second_run_memory_MB": second_run_memory,
478
+ }
479
+
480
+
481
+ def run_optimum_ort(
482
+ model_name: str,
483
+ directory: str,
484
+ provider: str,
485
+ batch_size: int,
486
+ disable_safety_checker: bool,
487
+ height: int,
488
+ width: int,
489
+ steps: int,
490
+ num_prompts: int,
491
+ batch_count: int,
492
+ start_memory,
493
+ memory_monitor_type,
494
+ use_io_binding: bool = False,
495
+ skip_warmup: bool = False,
496
+ ):
497
+ load_start = time.time()
498
+ pipe = get_optimum_ort_pipeline(
499
+ model_name, directory, provider, disable_safety_checker, use_io_binding=use_io_binding
500
+ )
501
+ load_end = time.time()
502
+ print(f"Model loading took {load_end - load_start} seconds")
503
+
504
+ full_model_name = model_name + "_" + Path(directory).name if directory else model_name
505
+ image_filename_prefix = get_image_filename_prefix(
506
+ "optimum", full_model_name, batch_size, steps, disable_safety_checker
507
+ )
508
+ result = run_optimum_ort_pipeline(
509
+ pipe,
510
+ batch_size,
511
+ image_filename_prefix,
512
+ height,
513
+ width,
514
+ steps,
515
+ num_prompts,
516
+ batch_count,
517
+ start_memory,
518
+ memory_monitor_type,
519
+ skip_warmup=skip_warmup,
520
+ )
521
+
522
+ result.update(
523
+ {
524
+ "model_name": model_name,
525
+ "directory": directory,
526
+ "provider": provider.replace("ExecutionProvider", ""),
527
+ "disable_safety_checker": disable_safety_checker,
528
+ "enable_cuda_graph": False,
529
+ }
530
+ )
531
+ return result
532
+
533
+
534
+ def run_ort_trt_static(
535
+ work_dir: str,
536
+ version: str,
537
+ batch_size: int,
538
+ disable_safety_checker: bool,
539
+ height: int,
540
+ width: int,
541
+ steps: int,
542
+ num_prompts: int,
543
+ batch_count: int,
544
+ start_memory,
545
+ memory_monitor_type,
546
+ max_batch_size: int,
547
+ nvtx_profile: bool = False,
548
+ use_cuda_graph: bool = True,
549
+ ):
550
+ print("[I] Initializing ORT TensorRT EP accelerated StableDiffusionXL txt2img pipeline (static input shape)")
551
+
552
+ # Register TensorRT plugins
553
+ from trt_utilities import init_trt_plugins # noqa: PLC0415
554
+
555
+ init_trt_plugins()
556
+
557
+ assert batch_size <= max_batch_size
558
+
559
+ from diffusion_models import PipelineInfo # noqa: PLC0415
560
+
561
+ pipeline_info = PipelineInfo(version)
562
+ short_name = pipeline_info.short_name()
563
+
564
+ from engine_builder import EngineType, get_engine_paths # noqa: PLC0415
565
+ from pipeline_stable_diffusion import StableDiffusionPipeline # noqa: PLC0415
566
+
567
+ engine_type = EngineType.ORT_TRT
568
+ onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(work_dir, pipeline_info, engine_type)
569
+
570
+ # Initialize pipeline
571
+ pipeline = StableDiffusionPipeline(
572
+ pipeline_info,
573
+ scheduler="DDIM",
574
+ output_dir=output_dir,
575
+ verbose=False,
576
+ nvtx_profile=nvtx_profile,
577
+ max_batch_size=max_batch_size,
578
+ use_cuda_graph=use_cuda_graph,
579
+ framework_model_dir=framework_model_dir,
580
+ engine_type=engine_type,
581
+ )
582
+
583
+ # Load TensorRT engines and pytorch modules
584
+ pipeline.backend.build_engines(
585
+ engine_dir,
586
+ framework_model_dir,
587
+ onnx_dir,
588
+ 17,
589
+ opt_image_height=height,
590
+ opt_image_width=width,
591
+ opt_batch_size=batch_size,
592
+ static_batch=True,
593
+ static_image_shape=True,
594
+ max_workspace_size=0,
595
+ device_id=torch.cuda.current_device(),
596
+ )
597
+
598
+ # Here we use static batch and image size, so the resource allocation only needs to be done once.
599
+ # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency.
600
+ pipeline.load_resources(height, width, batch_size)
601
+
602
+ def warmup():
603
+ prompt, negative = warmup_prompts()
604
+ pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps)
605
+
606
+ # Run warm up, and measure GPU memory of two runs
607
+ # The first run has algo search so it might need more memory
608
+ first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
609
+ second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
610
+
611
+ warmup()
612
+
613
+ image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, steps, disable_safety_checker)
614
+
615
+ latency_list = []
616
+ prompts, negative_prompt = example_prompts()
617
+ for i, prompt in enumerate(prompts):
618
+ if i >= num_prompts:
619
+ break
620
+ inference_start = time.time()
621
+ # Use warmup mode here since non-warmup mode will save image to disk.
622
+ images, pipeline_time = pipeline.run(
623
+ [prompt] * batch_size,
624
+ [negative_prompt] * batch_size,
625
+ height,
626
+ width,
627
+ denoising_steps=steps,
628
+ guidance=7.5,
629
+ seed=123,
630
+ )
631
+ inference_end = time.time()
632
+ latency = inference_end - inference_start
633
+ latency_list.append(latency)
634
+ print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
635
+ for k, image in enumerate(images):
636
+ image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
637
+
638
+ pipeline.teardown()
639
+
640
+ from tensorrt import __version__ as trt_version # noqa: PLC0415
641
+
642
+ from onnxruntime import __version__ as ort_version # noqa: PLC0415
643
+
644
+ return {
645
+ "model_name": pipeline_info.name(),
646
+ "engine": "onnxruntime",
647
+ "version": ort_version,
648
+ "provider": f"tensorrt({trt_version})",
649
+ "directory": engine_dir,
650
+ "height": height,
651
+ "width": width,
652
+ "steps": steps,
653
+ "batch_size": batch_size,
654
+ "batch_count": batch_count,
655
+ "num_prompts": num_prompts,
656
+ "average_latency": sum(latency_list) / len(latency_list),
657
+ "median_latency": statistics.median(latency_list),
658
+ "first_run_memory_MB": first_run_memory,
659
+ "second_run_memory_MB": second_run_memory,
660
+ "disable_safety_checker": disable_safety_checker,
661
+ "enable_cuda_graph": use_cuda_graph,
662
+ }
663
+
664
+
665
+ def run_tensorrt_static(
666
+ work_dir: str,
667
+ version: str,
668
+ model_name: str,
669
+ batch_size: int,
670
+ disable_safety_checker: bool,
671
+ height: int,
672
+ width: int,
673
+ steps: int,
674
+ num_prompts: int,
675
+ batch_count: int,
676
+ start_memory,
677
+ memory_monitor_type,
678
+ max_batch_size: int,
679
+ nvtx_profile: bool = False,
680
+ use_cuda_graph: bool = True,
681
+ skip_warmup: bool = False,
682
+ ):
683
+ print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)")
684
+
685
+ from cuda import cudart # noqa: PLC0415
686
+
687
+ # Register TensorRT plugins
688
+ from trt_utilities import init_trt_plugins # noqa: PLC0415
689
+
690
+ init_trt_plugins()
691
+
692
+ assert batch_size <= max_batch_size
693
+
694
+ from diffusion_models import PipelineInfo # noqa: PLC0415
695
+
696
+ pipeline_info = PipelineInfo(version)
697
+
698
+ from engine_builder import EngineType, get_engine_paths # noqa: PLC0415
699
+ from pipeline_stable_diffusion import StableDiffusionPipeline # noqa: PLC0415
700
+
701
+ engine_type = EngineType.TRT
702
+ onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
703
+ work_dir, pipeline_info, engine_type
704
+ )
705
+
706
+ # Initialize pipeline
707
+ pipeline = StableDiffusionPipeline(
708
+ pipeline_info,
709
+ scheduler="DDIM",
710
+ output_dir=output_dir,
711
+ verbose=False,
712
+ nvtx_profile=nvtx_profile,
713
+ max_batch_size=max_batch_size,
714
+ use_cuda_graph=True,
715
+ engine_type=engine_type,
716
+ )
717
+
718
+ # Load TensorRT engines and pytorch modules
719
+ pipeline.backend.load_engines(
720
+ engine_dir=engine_dir,
721
+ framework_model_dir=framework_model_dir,
722
+ onnx_dir=onnx_dir,
723
+ onnx_opset=17,
724
+ opt_batch_size=batch_size,
725
+ opt_image_height=height,
726
+ opt_image_width=width,
727
+ static_batch=True,
728
+ static_shape=True,
729
+ enable_all_tactics=False,
730
+ timing_cache=timing_cache,
731
+ )
732
+
733
+ # activate engines
734
+ max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory())
735
+ _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
736
+ pipeline.backend.activate_engines(shared_device_memory)
737
+
738
+ # Here we use static batch and image size, so the resource allocation only needs to be done once.
739
+ # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency.
740
+ pipeline.load_resources(height, width, batch_size)
741
+
742
+ def warmup():
743
+ if skip_warmup:
744
+ return
745
+ prompt, negative = warmup_prompts()
746
+ pipeline.run([prompt] * batch_size, [negative] * batch_size, height, width, denoising_steps=steps)
747
+
748
+ # Run warm up, and measure GPU memory of two runs
749
+ # The first run has algo search so it might need more memory
750
+ first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
751
+ second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
752
+
753
+ warmup()
754
+
755
+ image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker)
756
+
757
+ latency_list = []
758
+ prompts, negative_prompt = example_prompts()
759
+ for i, prompt in enumerate(prompts):
760
+ if i >= num_prompts:
761
+ break
762
+ inference_start = time.time()
763
+ # Use warmup mode here since non-warmup mode will save image to disk.
764
+ images, pipeline_time = pipeline.run(
765
+ [prompt] * batch_size,
766
+ [negative_prompt] * batch_size,
767
+ height,
768
+ width,
769
+ denoising_steps=steps,
770
+ seed=123,
771
+ )
772
+ inference_end = time.time()
773
+ latency = inference_end - inference_start
774
+ latency_list.append(latency)
775
+ print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
776
+ for k, image in enumerate(images):
777
+ image.save(f"{image_filename_prefix}_{i}_{k}.jpg")
778
+
779
+ pipeline.teardown()
780
+
781
+ import tensorrt as trt # noqa: PLC0415
782
+
783
+ return {
784
+ "engine": "tensorrt",
785
+ "version": trt.__version__,
786
+ "provider": "default",
787
+ "height": height,
788
+ "width": width,
789
+ "steps": steps,
790
+ "batch_size": batch_size,
791
+ "batch_count": batch_count,
792
+ "num_prompts": num_prompts,
793
+ "average_latency": sum(latency_list) / len(latency_list),
794
+ "median_latency": statistics.median(latency_list),
795
+ "first_run_memory_MB": first_run_memory,
796
+ "second_run_memory_MB": second_run_memory,
797
+ "enable_cuda_graph": use_cuda_graph,
798
+ }
799
+
800
+
801
+ def run_tensorrt_static_xl(
802
+ work_dir: str,
803
+ version: str,
804
+ batch_size: int,
805
+ disable_safety_checker: bool,
806
+ height: int,
807
+ width: int,
808
+ steps: int,
809
+ num_prompts: int,
810
+ batch_count: int,
811
+ start_memory,
812
+ memory_monitor_type,
813
+ max_batch_size: int,
814
+ nvtx_profile: bool = False,
815
+ use_cuda_graph=True,
816
+ skip_warmup: bool = False,
817
+ ):
818
+ print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)")
819
+
820
+ import tensorrt as trt # noqa: PLC0415
821
+ from cuda import cudart # noqa: PLC0415
822
+ from trt_utilities import init_trt_plugins # noqa: PLC0415
823
+
824
+ # Validate image dimensions
825
+ image_height = height
826
+ image_width = width
827
+ if image_height % 8 != 0 or image_width % 8 != 0:
828
+ raise ValueError(
829
+ f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}."
830
+ )
831
+
832
+ # Register TensorRT plugins
833
+ init_trt_plugins()
834
+
835
+ assert batch_size <= max_batch_size
836
+
837
+ from diffusion_models import PipelineInfo # noqa: PLC0415
838
+ from engine_builder import EngineType, get_engine_paths # noqa: PLC0415
839
+
840
+ def init_pipeline(pipeline_class, pipeline_info):
841
+ engine_type = EngineType.TRT
842
+
843
+ onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
844
+ work_dir, pipeline_info, engine_type
845
+ )
846
+
847
+ # Initialize pipeline
848
+ pipeline = pipeline_class(
849
+ pipeline_info,
850
+ scheduler="DDIM",
851
+ output_dir=output_dir,
852
+ verbose=False,
853
+ nvtx_profile=nvtx_profile,
854
+ max_batch_size=max_batch_size,
855
+ use_cuda_graph=use_cuda_graph,
856
+ framework_model_dir=framework_model_dir,
857
+ engine_type=engine_type,
858
+ )
859
+
860
+ pipeline.backend.load_engines(
861
+ engine_dir=engine_dir,
862
+ framework_model_dir=framework_model_dir,
863
+ onnx_dir=onnx_dir,
864
+ onnx_opset=17,
865
+ opt_batch_size=batch_size,
866
+ opt_image_height=height,
867
+ opt_image_width=width,
868
+ static_batch=True,
869
+ static_shape=True,
870
+ enable_all_tactics=False,
871
+ timing_cache=timing_cache,
872
+ )
873
+ return pipeline
874
+
875
+ from pipeline_stable_diffusion import StableDiffusionPipeline # noqa: PLC0415
876
+
877
+ pipeline_info = PipelineInfo(version)
878
+ pipeline = init_pipeline(StableDiffusionPipeline, pipeline_info)
879
+
880
+ max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory())
881
+ _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
882
+ pipeline.backend.activate_engines(shared_device_memory)
883
+
884
+ # Here we use static batch and image size, so the resource allocation only needs to be done once.
885
+ # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency.
886
+ pipeline.load_resources(image_height, image_width, batch_size)
887
+
888
+ def run_sd_xl_inference(prompt, negative_prompt, seed=None):
889
+ return pipeline.run(
890
+ prompt,
891
+ negative_prompt,
892
+ image_height,
893
+ image_width,
894
+ denoising_steps=steps,
895
+ guidance=5.0,
896
+ seed=seed,
897
+ )
898
+
899
+ def warmup():
900
+ if skip_warmup:
901
+ return
902
+ prompt, negative = warmup_prompts()
903
+ run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size)
904
+
905
+ # Run warm up, and measure GPU memory of two runs
906
+ # The first run has algo search so it might need more memory
907
+ first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
908
+ second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
909
+
910
+ warmup()
911
+
912
+ model_name = pipeline_info.name()
913
+ image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, steps, disable_safety_checker)
914
+
915
+ latency_list = []
916
+ prompts, negative_prompt = example_prompts()
917
+ for i, prompt in enumerate(prompts):
918
+ if i >= num_prompts:
919
+ break
920
+ inference_start = time.time()
921
+ # Use warmup mode here since non-warmup mode will save image to disk.
922
+ images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123)
923
+ inference_end = time.time()
924
+ latency = inference_end - inference_start
925
+ latency_list.append(latency)
926
+ print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
927
+ for k, image in enumerate(images):
928
+ image.save(f"{image_filename_prefix}_{i}_{k}.png")
929
+
930
+ pipeline.teardown()
931
+
932
+ return {
933
+ "model_name": model_name,
934
+ "engine": "tensorrt",
935
+ "version": trt.__version__,
936
+ "provider": "default",
937
+ "height": height,
938
+ "width": width,
939
+ "steps": steps,
940
+ "batch_size": batch_size,
941
+ "batch_count": batch_count,
942
+ "num_prompts": num_prompts,
943
+ "average_latency": sum(latency_list) / len(latency_list),
944
+ "median_latency": statistics.median(latency_list),
945
+ "first_run_memory_MB": first_run_memory,
946
+ "second_run_memory_MB": second_run_memory,
947
+ "enable_cuda_graph": use_cuda_graph,
948
+ }
949
+
950
+
951
+ def run_ort_trt_xl(
+     work_dir: str,
+     version: str,
+     batch_size: int,
+     disable_safety_checker: bool,
+     height: int,
+     width: int,
+     steps: int,
+     num_prompts: int,
+     batch_count: int,
+     start_memory,
+     memory_monitor_type,
+     max_batch_size: int,
+     nvtx_profile: bool = False,
+     use_cuda_graph=True,
+     skip_warmup: bool = False,
+ ):
+     from demo_utils import initialize_pipeline  # noqa: PLC0415
+     from engine_builder import EngineType  # noqa: PLC0415
+
+     pipeline = initialize_pipeline(
+         version=version,
+         engine_type=EngineType.ORT_TRT,
+         work_dir=work_dir,
+         height=height,
+         width=width,
+         use_cuda_graph=use_cuda_graph,
+         max_batch_size=max_batch_size,
+         opt_batch_size=batch_size,
+     )
+
+     assert batch_size <= max_batch_size
+
+     pipeline.load_resources(height, width, batch_size)
+
+     def run_sd_xl_inference(prompt, negative_prompt, seed=None):
+         return pipeline.run(
+             prompt,
+             negative_prompt,
+             height,
+             width,
+             denoising_steps=steps,
+             guidance=5.0,
+             seed=seed,
+         )
+
+     def warmup():
+         if skip_warmup:
+             return
+         prompt, negative = warmup_prompts()
+         run_sd_xl_inference([prompt] * batch_size, [negative] * batch_size)
+
+     # Run warm up, and measure GPU memory of two runs.
+     # The first run has algo search so it might need more memory.
+     first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
+     second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory)
+
+     warmup()
+
+     model_name = pipeline.pipeline_info.name()
+     image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, steps, disable_safety_checker)
+
+     latency_list = []
+     prompts, negative_prompt = example_prompts()
+     for i, prompt in enumerate(prompts):
+         if i >= num_prompts:
+             break
+         inference_start = time.time()
+         # Run inference for this prompt; generated images are saved to disk below.
+         images, pipeline_time = run_sd_xl_inference([prompt] * batch_size, [negative_prompt] * batch_size, seed=123)
+         inference_end = time.time()
+         latency = inference_end - inference_start
+         latency_list.append(latency)
+         print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time}")
+         for k, image in enumerate(images):
+             filename = f"{image_filename_prefix}_{i}_{k}.png"
+             image.save(filename)
+             print("Image saved to", filename)
+
+     pipeline.teardown()
+
+     from tensorrt import __version__ as trt_version  # noqa: PLC0415
+
+     from onnxruntime import __version__ as ort_version  # noqa: PLC0415
+
+     return {
+         "model_name": model_name,
+         "engine": "onnxruntime",
+         "version": ort_version,
+         "provider": f"tensorrt({trt_version})",
+         "height": height,
+         "width": width,
+         "steps": steps,
+         "batch_size": batch_size,
+         "batch_count": batch_count,
+         "num_prompts": num_prompts,
+         "average_latency": sum(latency_list) / len(latency_list),
+         "median_latency": statistics.median(latency_list),
+         "first_run_memory_MB": first_run_memory,
+         "second_run_memory_MB": second_run_memory,
+         "enable_cuda_graph": use_cuda_graph,
+     }
+
+
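The two back-to-back warmup measurements above follow a common pattern: the first run carries one-time costs such as algorithm search and engine setup, so its peak memory is reported separately from a steady-state second run. The snippet below is only a minimal sketch of that measure-around-a-callable idea using plain PyTorch APIs; it is not this module's measure_gpu_memory helper (whose implementation is not shown in this section), and it assumes a CUDA-enabled PyTorch build.

    import torch

    def peak_cuda_memory_mb(func) -> float:
        """Run func and return the peak CUDA memory allocated during the call, in MB."""
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        func()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / (1024 * 1024)

    # First call typically includes algorithm search / lazy initialization:
    # first_mb = peak_cuda_memory_mb(warmup)
    # second_mb = peak_cuda_memory_mb(warmup)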
+ def run_torch(
+     model_name: str,
+     batch_size: int,
+     disable_safety_checker: bool,
+     enable_torch_compile: bool,
+     use_xformers: bool,
+     height: int,
+     width: int,
+     steps: int,
+     num_prompts: int,
+     batch_count: int,
+     start_memory,
+     memory_monitor_type,
+     skip_warmup: bool = True,
+ ):
+     torch.backends.cudnn.enabled = True
+     torch.backends.cudnn.benchmark = True
+
+     torch.set_grad_enabled(False)
+
+     load_start = time.time()
+     pipe = get_torch_pipeline(model_name, disable_safety_checker, enable_torch_compile, use_xformers)
+     load_end = time.time()
+     print(f"Model loading took {load_end - load_start} seconds")
+
+     image_filename_prefix = get_image_filename_prefix("torch", model_name, batch_size, steps, disable_safety_checker)
+
+     if not enable_torch_compile:
+         with torch.inference_mode():
+             result = run_torch_pipeline(
+                 pipe,
+                 batch_size,
+                 image_filename_prefix,
+                 height,
+                 width,
+                 steps,
+                 num_prompts,
+                 batch_count,
+                 start_memory,
+                 memory_monitor_type,
+                 skip_warmup=skip_warmup,
+             )
+     else:
+         result = run_torch_pipeline(
+             pipe,
+             batch_size,
+             image_filename_prefix,
+             height,
+             width,
+             steps,
+             num_prompts,
+             batch_count,
+             start_memory,
+             memory_monitor_type,
+             skip_warmup=skip_warmup,
+         )
+
+     result.update(
+         {
+             "model_name": model_name,
+             "directory": None,
+             "provider": "compile" if enable_torch_compile else "xformers" if use_xformers else "default",
+             "disable_safety_checker": disable_safety_checker,
+             "enable_cuda_graph": False,
+         }
+     )
+     return result
+
+
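run_torch disables autograd in two ways: globally via torch.set_grad_enabled(False), and, for the non-compiled branch only, with the stricter torch.inference_mode() context manager. The snippet below is just a minimal, standalone illustration of these two standard PyTorch no-grad patterns, independent of this benchmark script.

    import torch

    torch.set_grad_enabled(False)      # global switch: newly created tensors do not track gradients

    with torch.inference_mode():       # stricter: also skips view/version-counter bookkeeping
        y = torch.matmul(torch.rand(2, 3), torch.rand(3, 2))

    assert not y.requires_grad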
+ def parse_arguments():
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument(
+         "-e",
+         "--engine",
+         required=False,
+         type=str,
+         default="onnxruntime",
+         choices=["onnxruntime", "optimum", "torch", "tensorrt"],
+         help="Engines to benchmark. Default is onnxruntime.",
+     )
+
+     parser.add_argument(
+         "-r",
+         "--provider",
+         required=False,
+         type=str,
+         default="cuda",
+         choices=list(PROVIDERS.keys()),
+         help="Provider to benchmark. Default is CUDAExecutionProvider.",
+     )
+
+     parser.add_argument(
+         "-t",
+         "--tuning",
+         action="store_true",
+         help="Enable TunableOp and tuning. This will incur longer warmup latency.",
+     )
+
+     parser.add_argument(
+         "-v",
+         "--version",
+         required=False,
+         type=str,
+         choices=list(SD_MODELS.keys()),
+         default="1.5",
+         help="Stable diffusion version like 1.5, 2.0 or 2.1. Default is 1.5.",
+     )
+
+     parser.add_argument(
+         "-p",
+         "--pipeline",
+         required=False,
+         type=str,
+         default=None,
+         help="Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.",
+     )
+
+     parser.add_argument(
+         "-w",
+         "--work_dir",
+         required=False,
+         type=str,
+         default=".",
+         help="Root directory to save exported onnx models, built engines etc.",
+     )
+
+     parser.add_argument(
+         "--enable_safety_checker",
+         required=False,
+         action="store_true",
+         help="Enable safety checker",
+     )
+     parser.set_defaults(enable_safety_checker=False)
+
+     parser.add_argument(
+         "--enable_torch_compile",
+         required=False,
+         action="store_true",
+         help="Enable torch.compile for UNet (PyTorch 2.0)",
+     )
+     parser.set_defaults(enable_torch_compile=False)
+
+     parser.add_argument(
+         "--use_xformers",
+         required=False,
+         action="store_true",
+         help="Use xformers for PyTorch",
+     )
+     parser.set_defaults(use_xformers=False)
+
+     parser.add_argument(
+         "--use_io_binding",
+         required=False,
+         action="store_true",
+         help="Use I/O Binding for Optimum.",
+     )
+     parser.set_defaults(use_io_binding=False)
+
+     parser.add_argument(
+         "--skip_warmup",
+         required=False,
+         action="store_true",
+         help="No warmup.",
+     )
+     parser.set_defaults(skip_warmup=False)
+
+     parser.add_argument(
+         "-b",
+         "--batch_size",
+         type=int,
+         default=1,
+         choices=[1, 2, 3, 4, 8, 10, 16, 32],
+         help="Number of images per batch. Default is 1.",
+     )
+
+     parser.add_argument(
+         "--height",
+         required=False,
+         type=int,
+         default=512,
+         help="Output image height. Default is 512.",
+     )
+
+     parser.add_argument(
+         "--width",
+         required=False,
+         type=int,
+         default=512,
+         help="Output image width. Default is 512.",
+     )
+
+     parser.add_argument(
+         "-s",
+         "--steps",
+         required=False,
+         type=int,
+         default=50,
+         help="Number of steps. Default is 50.",
+     )
+
+     parser.add_argument(
+         "-n",
+         "--num_prompts",
+         required=False,
+         type=int,
+         default=10,
+         help="Number of prompts. Default is 10.",
+     )
+
+     parser.add_argument(
+         "-c",
+         "--batch_count",
+         required=False,
+         type=int,
+         choices=range(1, 11),
+         default=5,
+         help="Number of batches to test. Default is 5.",
+     )
+
+     parser.add_argument(
+         "-m",
+         "--max_trt_batch_size",
+         required=False,
+         type=int,
+         choices=range(1, 16),
+         default=4,
+         help="Maximum batch size for TensorRT. Changing the value may trigger a TensorRT engine rebuild. Default is 4.",
+     )
+
+     parser.add_argument(
+         "-g",
+         "--enable_cuda_graph",
+         required=False,
+         action="store_true",
+         help="Enable CUDA graph. Requires onnxruntime >= 1.16",
+     )
+     parser.set_defaults(enable_cuda_graph=False)
+
+     args = parser.parse_args()
+
+     return args
+
+
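The flags above feed the run_* dispatch in main() below. For illustration only, and assuming the script is invoked as benchmark.py (the file name is not visible in this excerpt) with a placeholder ./sd_onnx_pipeline directory, typical invocations might look like:

    # Benchmark a saved ONNX pipeline with the CUDA execution provider:
    python benchmark.py -e onnxruntime -r cuda -v 1.5 -p ./sd_onnx_pipeline -b 1 -s 50

    # Benchmark the ORT TensorRT EP path with CUDA graph (requires onnxruntime >= 1.16):
    python benchmark.py -e onnxruntime -r tensorrt -v 1.5 -w ./trt_work_dir -g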
+ def print_loaded_libraries(cuda_related_only=True):
+     import psutil  # noqa: PLC0415
+
+     p = psutil.Process(os.getpid())
+     for lib in p.memory_maps():
+         if (not cuda_related_only) or any(x in lib.path for x in ("libcu", "libnv", "tensorrt")):
+             print(lib.path)
+
+
+ def main():
+     args = parse_arguments()
+     print(args)
+
+     if args.engine == "onnxruntime":
+         if args.version in ["2.1"]:
+             # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model.
+             # The environment variables shall be set before the first run of Attention or MultiHeadAttention operator.
+             os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1"
+
+         from packaging import version  # noqa: PLC0415
+
+         from onnxruntime import __version__ as ort_version  # noqa: PLC0415
+
+         if version.parse(ort_version) == version.parse("1.16.0"):
+             # ORT 1.16 has a bug that might trigger Attention RuntimeError when latest fusion script is applied on clip model.
+             # The workaround is to enable fused causal attention, or disable Attention fusion for clip model.
+             os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"
+
+     if args.enable_cuda_graph:
+         if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None):
+             raise ValueError("The stable diffusion pipeline does not support CUDA graph.")
+
+         if version.parse(ort_version) < version.parse("1.16"):
+             raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later")
+
+     logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO, force=True)
+
+     memory_monitor_type = "cuda"
+
+     start_memory = measure_gpu_memory(memory_monitor_type, None)
+     print("GPU memory used before loading models:", start_memory)
+
+     sd_model = SD_MODELS[args.version]
+     provider = PROVIDERS[args.provider]
+     if args.engine == "onnxruntime" and args.provider == "tensorrt":
+         if "xl" in args.version:
+             print("Testing Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.")
+             result = run_ort_trt_xl(
+                 work_dir=args.work_dir,
+                 version=args.version,
+                 batch_size=args.batch_size,
+                 disable_safety_checker=True,
+                 height=args.height,
+                 width=args.width,
+                 steps=args.steps,
+                 num_prompts=args.num_prompts,
+                 batch_count=args.batch_count,
+                 start_memory=start_memory,
+                 memory_monitor_type=memory_monitor_type,
+                 max_batch_size=args.max_trt_batch_size,
+                 nvtx_profile=False,
+                 use_cuda_graph=args.enable_cuda_graph,
+                 skip_warmup=args.skip_warmup,
+             )
+         else:
+             print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.")
+             result = run_ort_trt_static(
+                 work_dir=args.work_dir,
+                 version=args.version,
+                 batch_size=args.batch_size,
+                 disable_safety_checker=not args.enable_safety_checker,
+                 height=args.height,
+                 width=args.width,
+                 steps=args.steps,
+                 num_prompts=args.num_prompts,
+                 batch_count=args.batch_count,
+                 start_memory=start_memory,
+                 memory_monitor_type=memory_monitor_type,
+                 max_batch_size=args.max_trt_batch_size,
+                 nvtx_profile=False,
+                 use_cuda_graph=args.enable_cuda_graph,
+                 skip_warmup=args.skip_warmup,
+             )
+     elif args.engine == "optimum" and provider == "CUDAExecutionProvider":
+         if "xl" in args.version:
+             os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"
+
+         result = run_optimum_ort(
+             model_name=sd_model,
+             directory=args.pipeline,
+             provider=provider,
+             batch_size=args.batch_size,
+             disable_safety_checker=not args.enable_safety_checker,
+             height=args.height,
+             width=args.width,
+             steps=args.steps,
+             num_prompts=args.num_prompts,
+             batch_count=args.batch_count,
+             start_memory=start_memory,
+             memory_monitor_type=memory_monitor_type,
+             use_io_binding=args.use_io_binding,
+             skip_warmup=args.skip_warmup,
+         )
+     elif args.engine == "onnxruntime":
+         assert args.pipeline and os.path.isdir(args.pipeline), (
+             "--pipeline should be specified for the directory of ONNX models"
+         )
+         print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}")
+         result = run_ort(
+             model_name=sd_model,
+             directory=args.pipeline,
+             provider=provider,
+             batch_size=args.batch_size,
+             disable_safety_checker=not args.enable_safety_checker,
+             height=args.height,
+             width=args.width,
+             steps=args.steps,
+             num_prompts=args.num_prompts,
+             batch_count=args.batch_count,
+             start_memory=start_memory,
+             memory_monitor_type=memory_monitor_type,
+             tuning=args.tuning,
+             skip_warmup=args.skip_warmup,
+         )
+     elif args.engine == "tensorrt" and "xl" in args.version:
+         print("Testing Txt2ImgXLPipeline with static input shape. Backend is TensorRT.")
+         result = run_tensorrt_static_xl(
+             work_dir=args.work_dir,
+             version=args.version,
+             batch_size=args.batch_size,
+             disable_safety_checker=True,
+             height=args.height,
+             width=args.width,
+             steps=args.steps,
+             num_prompts=args.num_prompts,
+             batch_count=args.batch_count,
+             start_memory=start_memory,
+             memory_monitor_type=memory_monitor_type,
+             max_batch_size=args.max_trt_batch_size,
+             nvtx_profile=False,
+             use_cuda_graph=args.enable_cuda_graph,
+             skip_warmup=args.skip_warmup,
+         )
+     elif args.engine == "tensorrt":
+         print("Testing Txt2ImgPipeline with static input shape. Backend is TensorRT.")
+         result = run_tensorrt_static(
+             work_dir=args.work_dir,
+             version=args.version,
+             model_name=sd_model,
+             batch_size=args.batch_size,
+             disable_safety_checker=True,
+             height=args.height,
+             width=args.width,
+             steps=args.steps,
+             num_prompts=args.num_prompts,
+             batch_count=args.batch_count,
+             start_memory=start_memory,
+             memory_monitor_type=memory_monitor_type,
+             max_batch_size=args.max_trt_batch_size,
+             nvtx_profile=False,
+             use_cuda_graph=args.enable_cuda_graph,
+             skip_warmup=args.skip_warmup,
+         )
+     else:
+         print(
+             f"Testing Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile={args.enable_torch_compile}, xformers={args.use_xformers}."
+         )
+         result = run_torch(
+             model_name=sd_model,
+             batch_size=args.batch_size,
+             disable_safety_checker=not args.enable_safety_checker,
+             enable_torch_compile=args.enable_torch_compile,
+             use_xformers=args.use_xformers,
+             height=args.height,
+             width=args.width,
+             steps=args.steps,
+             num_prompts=args.num_prompts,
+             batch_count=args.batch_count,
+             start_memory=start_memory,
+             memory_monitor_type=memory_monitor_type,
+             skip_warmup=args.skip_warmup,
+         )
+
+     print(result)
+
+     with open("benchmark_result.csv", mode="a", newline="") as csv_file:
+         column_names = [
+             "model_name",
+             "directory",
+             "engine",
+             "version",
+             "provider",
+             "disable_safety_checker",
+             "height",
+             "width",
+             "steps",
+             "batch_size",
+             "batch_count",
+             "num_prompts",
+             "average_latency",
+             "median_latency",
+             "first_run_memory_MB",
+             "second_run_memory_MB",
+             "enable_cuda_graph",
+         ]
+         csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
+         csv_writer.writeheader()
+         csv_writer.writerow(result)
+
+     # Show loaded DLLs when steps == 1 for debugging purposes.
+     if args.steps == 1:
+         print_loaded_libraries(args.provider in ["cuda", "tensorrt"])
+
+
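Note that main() opens benchmark_result.csv in append mode and writes a header row on every invocation, so repeated runs accumulate duplicate header lines; also, csv.DictWriter fills missing keys with blanks but raises on extra keys unless extrasaction="ignore" is set. The sketch below is only an illustration of writing the header once per file; it is not what the shipped script does.

    import csv
    import os

    def append_result(path: str, column_names: list, result: dict) -> None:
        """Append one result row, writing the header only if the file is new or empty."""
        write_header = not os.path.exists(path) or os.path.getsize(path) == 0
        with open(path, mode="a", newline="") as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=column_names, extrasaction="ignore")
            if write_header:
                writer.writeheader()
            writer.writerow(result)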
+ if __name__ == "__main__":
+     import traceback
+
+     try:
+         main()
+     except Exception:
+         traceback.print_exception(*sys.exc_info())