onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,831 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ # Modified from TensorRT demo diffusion, which has the following license:
+ #
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # --------------------------------------------------------------------------
+
+ import os
+ import pathlib
+ import random
+ import time
+ from typing import Any
+
+ import numpy as np
+ import nvtx
+ import torch
+ from cuda import cudart
+ from diffusion_models import PipelineInfo, get_tokenizer
+ from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler
+ from engine_builder import EngineType
+ from engine_builder_ort_cuda import OrtCudaEngineBuilder
+ from engine_builder_ort_trt import OrtTensorrtEngineBuilder
+ from engine_builder_tensorrt import TensorrtEngineBuilder
+ from engine_builder_torch import TorchEngineBuilder
+ from PIL import Image
+
+
+ class StableDiffusionPipeline:
+     """
+     Stable Diffusion pipeline using TensorRT.
+     """
+
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         max_batch_size=16,
+         scheduler="DDIM",
+         device="cuda",
+         output_dir=".",
+         verbose=False,
+         nvtx_profile=False,
+         use_cuda_graph=False,
+         framework_model_dir="pytorch_model",
+         engine_type: EngineType = EngineType.ORT_CUDA,
+     ):
+         """
+         Initializes the Diffusion pipeline.
+
+         Args:
+             pipeline_info (PipelineInfo):
+                 Version and Type of pipeline.
+             max_batch_size (int):
+                 Maximum batch size for dynamic batch engine.
+             scheduler (str):
+                 The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM].
+             device (str):
+                 PyTorch device to run inference. Default: 'cuda'
+             output_dir (str):
+                 Output directory for log files and image artifacts.
+             verbose (bool):
+                 Enable verbose logging.
+             nvtx_profile (bool):
+                 Insert NVTX profiling markers.
+             use_cuda_graph (bool):
+                 Use CUDA graph to capture engine execution and then launch inference.
+             framework_model_dir (str):
+                 Cache directory for framework checkpoints.
+             engine_type (EngineType):
+                 Backend engine type, like ORT_TRT or TRT.
+         """
+
+         self.pipeline_info = pipeline_info
+         self.version = pipeline_info.version
+
+         self.vae_scaling_factor = pipeline_info.vae_scaling_factor()
+
+         self.max_batch_size = max_batch_size
+
+         self.framework_model_dir = framework_model_dir
+         self.output_dir = output_dir
+         for directory in [self.framework_model_dir, self.output_dir]:
+             if not os.path.exists(directory):
+                 print(f"[I] Create directory: {directory}")
+                 pathlib.Path(directory).mkdir(parents=True)
+
+         self.device = device
+         self.torch_device = torch.device(device, torch.cuda.current_device())
+         self.verbose = verbose
+         self.nvtx_profile = nvtx_profile
+
+         self.use_cuda_graph = use_cuda_graph
+
+         self.tokenizer = None
+         self.tokenizer2 = None
+
+         self.generator = torch.Generator(device="cuda")
+         self.actual_steps = None
+
+         self.current_scheduler = None
+         self.set_scheduler(scheduler)
+
+         # backend engine
+         self.engine_type = engine_type
+         if engine_type == EngineType.TRT:
+             self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
+         elif engine_type == EngineType.ORT_TRT:
+             self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
+         elif engine_type == EngineType.ORT_CUDA:
+             self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
+         elif engine_type == EngineType.TORCH:
+             self.backend = TorchEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
+         else:
+             raise RuntimeError(f"Backend engine type {engine_type.name} is not supported")
+
+         # Load text tokenizer
+         if not self.pipeline_info.is_xl_refiner():
+             self.tokenizer = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer")
+
+         if self.pipeline_info.is_xl():
+             self.tokenizer2 = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer_2")
+
+         self.control_image_processor = None
+         if self.pipeline_info.is_xl() and self.pipeline_info.controlnet:
+             from diffusers.image_processor import VaeImageProcessor  # noqa: PLC0415
+
+             self.control_image_processor = VaeImageProcessor(
+                 vae_scale_factor=8, do_convert_rgb=True, do_normalize=False
+             )
+
+         # Create CUDA events
+         self.events = {}
+         for stage in ["clip", "denoise", "vae", "vae_encoder", "pil"]:
+             for marker in ["start", "stop"]:
+                 self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
+         self.markers = {}
+
+     def is_backend_tensorrt(self):
+         return self.engine_type == EngineType.TRT
+
+     def set_scheduler(self, scheduler: str):
+         if scheduler == self.current_scheduler:
+             return
+
+         # Scheduler options
+         sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012}
+         if self.version in ("2.0", "2.1"):
+             sched_opts["prediction_type"] = "v_prediction"
+         else:
+             sched_opts["prediction_type"] = "epsilon"
+
+         if scheduler == "DDIM":
+             self.scheduler = DDIMScheduler(device=self.device, **sched_opts)
+         elif scheduler == "EulerA":
+             self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts)
+         elif scheduler == "UniPC":
+             self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts)
+         elif scheduler == "LCM":
+             self.scheduler = LCMScheduler(device=self.device, **sched_opts)
+         else:
+             raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM")
+
+         self.current_scheduler = scheduler
+         self.denoising_steps = None
+
+     def set_denoising_steps(self, denoising_steps: int):
+         if not (self.denoising_steps == denoising_steps and isinstance(self.scheduler, DDIMScheduler)):
+             self.scheduler.set_timesteps(denoising_steps)
+             self.scheduler.configure()
+         self.denoising_steps = denoising_steps
+
+     def load_resources(self, image_height, image_width, batch_size):
+         # If engine is built with static input shape, call this only once after engine build.
+         # Otherwise, it needs to be called before every inference run.
+         self.backend.load_resources(image_height, image_width, batch_size)
+
+     def set_random_seed(self, seed):
+         if isinstance(seed, int):
+             self.generator.manual_seed(seed)
+         else:
+             self.generator.seed()
+
+     def get_current_seed(self):
+         return self.generator.initial_seed()
+
+     def teardown(self):
+         for e in self.events.values():
+             cudart.cudaEventDestroy(e)
+
+         if self.backend:
+             self.backend.teardown()
+
+     def run_engine(self, model_name, feed_dict):
+         return self.backend.run_engine(model_name, feed_dict)
+
+     def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width):
+         latents_dtype = torch.float16
+         latents_shape = (batch_size, unet_channels, latent_height, latent_width)
+         latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator)
+         # Scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * self.scheduler.init_noise_sigma
+         return latents
+
+     def initialize_timesteps(self, timesteps, strength):
+         """Initialize timesteps for refiner."""
+         self.scheduler.set_timesteps(timesteps)
+         offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
+         init_timestep = int(timesteps * strength) + offset
+         init_timestep = min(init_timestep, timesteps)
+         t_start = max(timesteps - init_timestep + offset, 0)
+         timesteps = self.scheduler.timesteps[t_start:].to(self.device)
+         return timesteps, t_start
+
+     def initialize_refiner(self, batch_size, image, strength):
+         """Add noise to a reference image."""
+         # Initialize timesteps
+         timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength)
+
+         latent_timestep = timesteps[:1].repeat(batch_size)
+
+         # Pre-process input image
+         image = self.preprocess_images(batch_size, (image,))[0]
+
+         # VAE encode init image
+         if image.shape[1] == 4:
+             init_latents = image
+         else:
+             init_latents = self.encode_image(image)
+
+         # Add noise to latents using timesteps
+         noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float16, generator=self.generator)
+
+         latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep)
+
+         return timesteps, t_start, latents
+
+     def _get_add_time_ids(
+         self,
+         original_size,
+         crops_coords_top_left,
+         target_size,
+         aesthetic_score,
+         negative_aesthetic_score,
+         dtype,
+         requires_aesthetics_score,
+     ):
+         if requires_aesthetics_score:
+             add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+             add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
+         else:
+             add_time_ids = list(original_size + crops_coords_top_left + target_size)
+             add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+         add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+         return add_time_ids, add_neg_time_ids
+
+     def start_profile(self, name, color="blue"):
+         if self.nvtx_profile:
+             self.markers[name] = nvtx.start_range(message=name, color=color)
+         event_name = name + "-start"
+         if event_name in self.events:
+             cudart.cudaEventRecord(self.events[event_name], 0)
+
+     def stop_profile(self, name):
+         event_name = name + "-stop"
+         if event_name in self.events:
+             cudart.cudaEventRecord(self.events[event_name], 0)
+         if self.nvtx_profile:
+             nvtx.end_range(self.markers[name])
+
+     def preprocess_images(self, batch_size, images=()):
+         self.start_profile("preprocess", color="pink")
+         init_images = []
+         for i in images:
+             image = i.to(self.device)
+             if image.shape[0] != batch_size:
+                 image = image.repeat(batch_size, 1, 1, 1)
+             init_images.append(image)
+         self.stop_profile("preprocess")
+         return tuple(init_images)
+
+     def preprocess_controlnet_images(
+         self, batch_size, images=None, do_classifier_free_guidance=True, height=1024, width=1024
+     ):
+         """
+         Process a list of PIL.Image.Image as control images, and return a torch tensor.
+         """
+         if images is None:
+             return None
+         self.start_profile("preprocess", color="pink")
+
+         if not self.pipeline_info.is_xl():
+             images = [
+                 torch.from_numpy(
+                     (np.array(image.convert("RGB")).astype(np.float32) / 255.0)[..., None].transpose(3, 2, 0, 1)
+                 )
+                 .to(device=self.device, dtype=torch.float16)
+                 .repeat_interleave(batch_size, dim=0)
+                 for image in images
+             ]
+         else:
+             images = [
+                 self.control_image_processor.preprocess(image, height=height, width=width)
+                 .to(device=self.device, dtype=torch.float16)
+                 .repeat_interleave(batch_size, dim=0)
+                 for image in images
+             ]
+
+         if do_classifier_free_guidance:
+             images = [torch.cat([i] * 2) for i in images]
+         images = torch.cat([image[None, ...] for image in images], dim=0)
+
+         self.stop_profile("preprocess")
+         return images
+
+     def encode_prompt(
+         self,
+         prompt,
+         negative_prompt,
+         encoder="clip",
+         tokenizer=None,
+         pooled_outputs=False,
+         output_hidden_states=False,
+         force_zeros_for_empty_prompt=False,
+         do_classifier_free_guidance=True,
+         dtype=torch.float16,
+     ):
+         if tokenizer is None:
+             tokenizer = self.tokenizer
+
+         self.start_profile("clip", color="green")
+
+         def tokenize(prompt, output_hidden_states):
+             text_input_ids = (
+                 tokenizer(
+                     prompt,
+                     padding="max_length",
+                     max_length=tokenizer.model_max_length,
+                     truncation=True,
+                     return_tensors="pt",
+                 )
+                 .input_ids.type(torch.int32)
+                 .to(self.device)
+             )
+
+             hidden_states = None
+             if self.engine_type == EngineType.TORCH:
+                 outputs = self.backend.engines[encoder](text_input_ids)
+                 text_embeddings = outputs[0]
+                 if output_hidden_states:
+                     hidden_states = outputs["last_hidden_state"]
+             else:
+                 outputs = self.run_engine(encoder, {"input_ids": text_input_ids})
+                 text_embeddings = outputs["text_embeddings"]
+                 if output_hidden_states:
+                     hidden_states = outputs["hidden_states"]
+             return text_embeddings, hidden_states
+
+         # Tokenize prompt
+         text_embeddings, hidden_states = tokenize(prompt, output_hidden_states)
+
+         # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
+         text_embeddings = text_embeddings.clone()
+         if hidden_states is not None:
+             hidden_states = hidden_states.clone()
+
+         # Note: negative prompt embedding is not needed for SD XL when guidance <= 1
+         if do_classifier_free_guidance:
+             # For SD XL base, handle force_zeros_for_empty_prompt
+             is_empty_negative_prompt = all(not i for i in negative_prompt)
+             if force_zeros_for_empty_prompt and is_empty_negative_prompt:
+                 uncond_embeddings = torch.zeros_like(text_embeddings)
+                 if output_hidden_states:
+                     uncond_hidden_states = torch.zeros_like(hidden_states)
+             else:
+                 # Tokenize negative prompt
+                 uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states)
+
+             # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+             if output_hidden_states:
+                 hidden_states = torch.cat([uncond_hidden_states, hidden_states])
+
+         self.stop_profile("clip")
+
+         if pooled_outputs:
+             # For text encoder in sdxl base
+             return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)
+
+         if output_hidden_states:
+             # For text encoder 2 in sdxl base or refiner
+             return hidden_states.to(dtype=dtype)
+
+         # For text encoder in sd 1.5
+         return text_embeddings.to(dtype=dtype)
+
+     def denoise_latent(
+         self,
+         latents,
+         text_embeddings,
+         denoiser="unet",
+         timesteps=None,
+         step_offset=0,
+         guidance=7.5,
+         add_kwargs=None,
+     ):
+         do_classifier_free_guidance = guidance > 1.0
+
+         self.start_profile("denoise", color="blue")
+
+         if not isinstance(timesteps, torch.Tensor):
+             timesteps = self.scheduler.timesteps
+
+         for step_index, timestep in enumerate(timesteps):
+             # Expand the latents if we are doing classifier free guidance
+             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+             latent_model_input = self.scheduler.scale_model_input(
+                 latent_model_input, step_offset + step_index, timestep
+             )
+
+             # Predict the noise residual
+             if self.nvtx_profile:
+                 nvtx_unet = nvtx.start_range(message="unet", color="blue")
+
+             params = {
+                 "sample": latent_model_input,
+                 "timestep": timestep.to(latents.dtype),
+                 "encoder_hidden_states": text_embeddings,
+             }
+
+             if add_kwargs:
+                 params.update(add_kwargs)
+
+             noise_pred = self.run_engine(denoiser, params)["latent"]
+
+             if self.nvtx_profile:
+                 nvtx.end_range(nvtx_unet)
+
+             # perform guidance
+             if do_classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
+
+             if type(self.scheduler) is UniPCMultistepScheduler:
+                 latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
+             elif type(self.scheduler) is LCMScheduler:
+                 latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0]
+             else:
+                 latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep)
+
+         # The actual number of steps. It might be different from denoising_steps.
+         self.actual_steps = len(timesteps)
+
+         self.stop_profile("denoise")
+         return latents
+
+     def encode_image(self, image):
+         self.start_profile("vae_encoder", color="red")
+         init_latents = self.run_engine("vae_encoder", {"images": image})["latent"]
+         init_latents = self.vae_scaling_factor * init_latents
+         self.stop_profile("vae_encoder")
+         return init_latents
+
+     def decode_latent(self, latents):
+         self.start_profile("vae", color="red")
+         images = self.backend.vae_decode(latents)
+         self.stop_profile("vae")
+         return images
+
+     def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> dict[str, Any]:
+         throughput = batch_size / (toc - tic)
+         latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1]
+         latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1]
+         latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1]
+         latency_vae_encoder = (
+             cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1]
+             if vae_enc
+             else None
+         )
+         latency_pil = cudart.cudaEventElapsedTime(self.events["pil-start"], self.events["pil-stop"])[1] if pil else None
+
+         latency = (toc - tic) * 1000.0
+
+         print("|----------------|--------------|")
+         print("| {:^14} | {:^12} |".format("Module", "Latency"))
+         print("|----------------|--------------|")
+         if vae_enc:
+             print("| {:^14} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder))
+         print("| {:^14} | {:>9.2f} ms |".format("CLIP", latency_clip))
+         print(
+             "| {:^14} | {:>9.2f} ms |".format(
+                 "UNet" + ("+CNet" if self.pipeline_info.controlnet else "") + " x " + str(self.actual_steps),
+                 latency_unet,
+             )
+         )
+         print("| {:^14} | {:>9.2f} ms |".format("VAE-Dec", latency_vae))
+         pipeline = "Refiner" if self.pipeline_info.is_xl_refiner() else "Pipeline"
+         if pil:
+             print("| {:^14} | {:>9.2f} ms |".format("PIL", latency_pil))
+         print("|----------------|--------------|")
+         print(f"| {pipeline:^14} | {latency:>9.2f} ms |")
+         print("|----------------|--------------|")
+         print(f"Throughput: {throughput:.2f} image/s")
+
+         perf_data = {
+             "latency_clip": latency_clip,
+             "latency_unet": latency_unet,
+             "latency_vae": latency_vae,
+             "latency_pil": latency_pil,
+             "latency": latency,
+             "throughput": throughput,
+         }
+         if vae_enc:
+             perf_data["latency_vae_encoder"] = latency_vae_encoder
+         return perf_data
+
+     @staticmethod
+     def pt_to_pil(images):
+         images = (
+             ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
+         )
+         return [Image.fromarray(images[i]) for i in range(images.shape[0])]
+
+     @staticmethod
+     def pt_to_numpy(images: torch.FloatTensor):
+         """
+         Convert a PyTorch tensor to a NumPy image.
+         """
+         return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy()
+
+     def metadata(self) -> dict[str, Any]:
+         data = {
+             "actual_steps": self.actual_steps,
+             "seed": self.get_current_seed(),
+             "name": self.pipeline_info.name(),
+             "custom_vae": self.pipeline_info.custom_fp16_vae(),
+             "custom_unet": self.pipeline_info.custom_unet(),
+         }
+
+         if self.engine_type == EngineType.ORT_CUDA:
+             for engine_name, engine in self.backend.engines.items():
+                 data.update(engine.metadata(engine_name))
+
+         return data
+
+     def save_images(self, images: list, prompt: list[str], negative_prompt: list[str], metadata: dict[str, Any]):
+         session_id = str(random.randint(1000, 9999))
+         for i, image in enumerate(images):
+             seed = str(self.get_current_seed())
+             prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20]
+             parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)]
+             image_path = os.path.join(self.output_dir, "-".join(parts) + ".png")
+             print(f"Saving image {i + 1} / {len(images)} to: {image_path}")
+
+             from PIL import PngImagePlugin  # noqa: PLC0415
+
+             info = PngImagePlugin.PngInfo()
+             for k, v in metadata.items():
+                 info.add_text(k, str(v))
+             info.add_text("prompt", prompt[i])
+             info.add_text("negative_prompt", negative_prompt[i])
+
+             image.save(image_path, "PNG", pnginfo=info)
+
+     def _infer(
+         self,
+         prompt,
+         negative_prompt,
+         image_height,
+         image_width,
+         denoising_steps=30,
+         guidance=5.0,
+         seed=None,
+         image=None,
+         strength=0.3,
+         controlnet_images=None,
+         controlnet_scales=None,
+         show_latency=False,
+         output_type="pil",
+     ):
+         if show_latency:
+             torch.cuda.synchronize()
+             start_time = time.perf_counter()
+
+         assert len(prompt) == len(negative_prompt)
+         batch_size = len(prompt)
+
+         self.set_denoising_steps(denoising_steps)
+         self.set_random_seed(seed)
+
+         timesteps = None
+         step_offset = 0
+         with torch.inference_mode(), torch.autocast("cuda"):
+             if image is not None:
+                 timesteps, step_offset, latents = self.initialize_refiner(
+                     batch_size=batch_size,
+                     image=image,
+                     strength=strength,
+                 )
+             else:
+                 # Pre-initialize latents
+                 latents = self.initialize_latents(
+                     batch_size=batch_size,
+                     unet_channels=4,
+                     latent_height=(image_height // 8),
+                     latent_width=(image_width // 8),
+                 )
+
+             do_classifier_free_guidance = guidance > 1.0
+             if not self.pipeline_info.is_xl():
+                 denoiser = "unet"
+                 text_embeddings = self.encode_prompt(
+                     prompt,
+                     negative_prompt,
+                     do_classifier_free_guidance=do_classifier_free_guidance,
+                     dtype=latents.dtype,
+                 )
+                 add_kwargs = {}
+             else:
+                 denoiser = "unetxl"
+
+                 # Time embeddings
+                 original_size = (image_height, image_width)
+                 crops_coords_top_left = (0, 0)
+                 target_size = (image_height, image_width)
+                 aesthetic_score = 6.0
+                 negative_aesthetic_score = 2.5
+                 add_time_ids, add_negative_time_ids = self._get_add_time_ids(
+                     original_size,
+                     crops_coords_top_left,
+                     target_size,
+                     aesthetic_score,
+                     negative_aesthetic_score,
+                     dtype=latents.dtype,
+                     requires_aesthetics_score=self.pipeline_info.is_xl_refiner(),
+                 )
+                 if do_classifier_free_guidance:
+                     add_time_ids = torch.cat([add_negative_time_ids, add_time_ids], dim=0)
+                 add_time_ids = add_time_ids.to(device=self.device).repeat(batch_size, 1)
+
+                 if self.pipeline_info.is_xl_refiner():
+                     # CLIP text encoder 2
+                     text_embeddings, pooled_embeddings2 = self.encode_prompt(
+                         prompt,
+                         negative_prompt,
+                         encoder="clip2",
+                         tokenizer=self.tokenizer2,
+                         pooled_outputs=True,
+                         output_hidden_states=True,
+                         dtype=latents.dtype,
+                     )
+                     add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
+                 else:  # XL Base
+                     # CLIP text encoder
+                     text_embeddings = self.encode_prompt(
+                         prompt,
+                         negative_prompt,
+                         encoder="clip",
+                         tokenizer=self.tokenizer,
+                         output_hidden_states=True,
+                         force_zeros_for_empty_prompt=True,
+                         do_classifier_free_guidance=do_classifier_free_guidance,
+                         dtype=latents.dtype,
+                     )
+                     # CLIP text encoder 2
+                     text_embeddings2, pooled_embeddings2 = self.encode_prompt(
+                         prompt,
+                         negative_prompt,
+                         encoder="clip2",
+                         tokenizer=self.tokenizer2,
+                         pooled_outputs=True,
+                         output_hidden_states=True,
+                         force_zeros_for_empty_prompt=True,
+                         do_classifier_free_guidance=do_classifier_free_guidance,
+                         dtype=latents.dtype,
+                     )
+
+                     # Merged text embeddings
+                     text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1)
+
+                     add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
+
+             if self.pipeline_info.controlnet:
+                 controlnet_images = self.preprocess_controlnet_images(
+                     latents.shape[0],
+                     controlnet_images,
+                     do_classifier_free_guidance=do_classifier_free_guidance,
+                     height=image_height,
+                     width=image_width,
+                 )
+                 add_kwargs.update(
+                     {
+                         "controlnet_images": controlnet_images,
+                         "controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device),
+                     }
+                 )
+
+             # UNet denoiser
+             latents = self.denoise_latent(
+                 latents,
+                 text_embeddings,
+                 timesteps=timesteps,
+                 step_offset=step_offset,
+                 denoiser=denoiser,
+                 guidance=guidance,
+                 add_kwargs=add_kwargs,
+             )
+
+         with torch.inference_mode():
+             # VAE decode latent
+             if output_type == "latent":
+                 images = latents
+             else:
+                 images = self.decode_latent(latents / self.vae_scaling_factor)
+                 if output_type == "pil":
+                     self.start_profile("pil", color="green")
+                     images = self.pt_to_pil(images)
+                     self.stop_profile("pil")
+
+         perf_data = None
+         if show_latency:
+             torch.cuda.synchronize()
+             end_time = time.perf_counter()
+             perf_data = self.print_summary(
+                 start_time, end_time, batch_size, vae_enc=self.pipeline_info.is_xl_refiner(), pil=(output_type == "pil")
+             )
+
+         return images, perf_data
+
+     def run(
+         self,
+         prompt: list[str],
+         negative_prompt: list[str],
+         image_height: int,
+         image_width: int,
+         denoising_steps: int = 30,
+         guidance: float = 5.0,
+         seed: int | None = None,
+         image: torch.Tensor | None = None,
+         strength: float = 0.3,
+         controlnet_images: torch.Tensor | None = None,
+         controlnet_scales: torch.Tensor | None = None,
+         show_latency: bool = False,
+         output_type: str = "pil",
+         deterministic: bool = False,
+     ):
+         """
+         Run the diffusion pipeline.
+
+         Args:
+             prompt (List[str]):
+                 The text prompt to guide image generation.
+             negative_prompt (List[str]):
+                 The prompt not to guide the image generation.
+             image_height (int):
+                 Height (in pixels) of the image to be generated. Must be a multiple of 8.
+             image_width (int):
+                 Width (in pixels) of the image to be generated. Must be a multiple of 8.
+             denoising_steps (int):
+                 Number of denoising steps. More steps usually lead to a higher quality image at the expense of slower inference.
+             guidance (float):
+                 A higher guidance scale encourages generating images that are closely linked to the text prompt.
+             seed (int):
+                 Seed for the random generator.
+             image (tuple[torch.Tensor]):
+                 Reference image.
+             strength (float):
+                 Indicates the extent to transform the reference image, which is used as a starting point;
+                 more noise is added the higher the strength.
+             show_latency (bool):
+                 Whether to return latency data.
+             output_type (str):
+                 It can be "latent", "pt" or "pil".
+         """
+         if deterministic:
+             torch.use_deterministic_algorithms(True)
+
+         if self.is_backend_tensorrt():
+             import tensorrt as trt  # noqa: PLC0415
+             from trt_utilities import TRT_LOGGER  # noqa: PLC0415
+
+             with trt.Runtime(TRT_LOGGER):
+                 return self._infer(
+                     prompt,
+                     negative_prompt,
+                     image_height,
+                     image_width,
+                     denoising_steps=denoising_steps,
+                     guidance=guidance,
+                     seed=seed,
+                     image=image,
+                     strength=strength,
+                     controlnet_images=controlnet_images,
+                     controlnet_scales=controlnet_scales,
+                     show_latency=show_latency,
+                     output_type=output_type,
+                 )
+         else:
+             return self._infer(
+                 prompt,
+                 negative_prompt,
+                 image_height,
+                 image_width,
+                 denoising_steps=denoising_steps,
+                 guidance=guidance,
+                 seed=seed,
+                 image=image,
+                 strength=strength,
+                 controlnet_images=controlnet_images,
+                 controlnet_scales=controlnet_scales,
+                 show_latency=show_latency,
+                 output_type=output_type,
+             )
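
For reference, a minimal text-to-image call against the pipeline above might look like the sketch below. It assumes the sibling demo modules shipped in this package (diffusion_models, engine_builder) are importable from the stable_diffusion directory, that the backend engines have already been built or exported (a step handled elsewhere in the demo, e.g. demo_utils.py and the engine_builder_* modules, and not shown here), and that PipelineInfo can be constructed from a version string; that constructor signature is not part of this diff.

# Usage sketch only: PipelineInfo("1.5") and the prior engine-build step are assumptions,
# not shown in this diff.
from diffusion_models import PipelineInfo
from engine_builder import EngineType
from pipeline_stable_diffusion import StableDiffusionPipeline

pipeline_info = PipelineInfo("1.5")  # hypothetical constructor call; see diffusion_models.py

pipe = StableDiffusionPipeline(
    pipeline_info,
    scheduler="EulerA",
    output_dir="./output",
    engine_type=EngineType.ORT_CUDA,
)

# Bind buffers for the requested image size and batch size (engines must exist already).
pipe.load_resources(image_height=512, image_width=512, batch_size=1)

prompt = ["a photo of an astronaut riding a horse on mars"]
negative_prompt = [""]
images, perf_data = pipe.run(
    prompt,
    negative_prompt,
    image_height=512,
    image_width=512,
    denoising_steps=30,
    guidance=7.5,
    seed=42,
    show_latency=True,
)

pipe.save_images(images, prompt, negative_prompt, pipe.metadata())
pipe.teardown()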