onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,778 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ # Modified from TensorRT demo diffusion, which has the following license:
6
+ #
7
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
8
+ # SPDX-License-Identifier: Apache-2.0
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ # --------------------------------------------------------------------------
22
+ import argparse
23
+ import os
24
+ import sys
25
+ from importlib.metadata import PackageNotFoundError, version
26
+ from typing import Any
27
+
28
+ import controlnet_aux
29
+ import cv2
30
+ import numpy as np
31
+ import torch
32
+ from cuda import cudart
33
+ from diffusion_models import PipelineInfo
34
+ from engine_builder import EngineType, get_engine_paths, get_engine_type
35
+ from PIL import Image
36
+ from pipeline_stable_diffusion import StableDiffusionPipeline
37
+
38
+
39
class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
    """Help formatter that keeps raw text layout while also appending argument defaults."""


def arg_parser(description: str):
    """Build an ArgumentParser for a demo script.

    Args:
        description: Text shown at the top of ``--help`` output.

    Returns:
        An ``argparse.ArgumentParser`` using the combined raw-text/defaults formatter.
    """
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=RawTextArgumentDefaultsHelpFormatter,
    )
    return parser
48
+
49
+
50
def set_default_arguments(args):
    """Fill in version-appropriate defaults for any unset CLI arguments.

    Mutates ``args`` in place: height/width fall back to the pipeline's default
    resolution; denoising steps, scheduler and guidance scale get defaults that
    depend on whether an LCM UNet/LoRA or a turbo model variant is in use.
    """
    if args.height is None:
        args.height = PipelineInfo.default_resolution(args.version)

    if args.width is None:
        args.width = PipelineInfo.default_resolution(args.version)

    # NOTE: args.lcm only exists for XL pipelines; short-circuit keeps this safe.
    uses_lcm = (args.version == "xl-1.0" and args.lcm) or ("lcm" in args.lora_weights)
    uses_turbo = args.version in ("sd-turbo", "xl-turbo")

    if args.denoising_steps is None:
        if uses_turbo:
            args.denoising_steps = 4
        elif uses_lcm:
            args.denoising_steps = 8
        elif args.version == "xl-1.0":
            args.denoising_steps = 30
        else:
            args.denoising_steps = 50

    if args.scheduler is None:
        if uses_lcm or uses_turbo:
            args.scheduler = "LCM"
        elif args.version == "xl-1.0":
            args.scheduler = "EulerA"
        else:
            args.scheduler = "DDIM"

    if args.guidance is None:
        if uses_lcm or uses_turbo:
            args.guidance = 0.0
        elif args.version == "xl-1.0":
            args.guidance = 5.0
        else:
            args.guidance = 7.5
68
+
69
+
70
def parse_arguments(is_xl: bool, parser):
    """Register demo CLI arguments on *parser*, parse, validate and return them.

    Args:
        is_xl: True for Stable Diffusion XL pipelines (adds refiner/LCM options
               and changes defaults/help text accordingly).
        parser: An ``argparse.ArgumentParser`` (typically from ``arg_parser``).

    Returns:
        The parsed and post-processed ``argparse.Namespace``.

    Raises:
        ValueError: If height or width is not divisible by 64.
    """
    engines = ["ORT_CUDA", "ORT_TRT", "TRT", "TORCH"]

    parser.add_argument(
        "-e",
        "--engine",
        type=str,
        default=engines[0],
        choices=engines,
        # Fix: was a plain string, so the literal "{engines}" appeared in --help.
        help=f"Backend engine in {engines}. "
        "ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT",
    )

    supported_versions = PipelineInfo.supported_versions(is_xl)
    parser.add_argument(
        "-v",
        "--version",
        type=str,
        default="xl-1.0" if is_xl else "1.5",
        choices=supported_versions,
        help="Version of Stable Diffusion" + (" XL." if is_xl else "."),
    )

    # Fix: help text claimed "multiple of 8" but validation below enforces 64.
    parser.add_argument(
        "-y",
        "--height",
        type=int,
        default=None,
        help="Height of image to generate (must be multiple of 64).",
    )
    # Fix: help text previously said "Height" for the width option.
    parser.add_argument(
        "-x", "--width", type=int, default=None, help="Width of image to generate (must be multiple of 64)."
    )

    parser.add_argument(
        "-s",
        "--scheduler",
        type=str,
        default=None,
        choices=["DDIM", "EulerA", "UniPC", "LCM"],
        # Fix: without parentheses, operator precedence made the whole help text
        # empty for non-XL ("a" + "b" if c else "" parses as ("a" + "b") if c else "").
        help="Scheduler for diffusion process" + (" of base" if is_xl else ""),
    )

    parser.add_argument(
        "-wd",
        "--work-dir",
        default=".",
        help="Root Directory to store torch or ONNX models, built engines and output images etc.",
    )

    parser.add_argument(
        "-i",
        "--engine-dir",
        default=None,
        help="Root Directory to store built engines or optimized ONNX models etc.",
    )

    parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.")

    parser.add_argument(
        "-n",
        "--negative-prompt",
        nargs="*",
        default=[""],
        help="Optional negative prompt(s) to guide the image generation.",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=1,
        choices=[1, 2, 4, 8, 16],
        help="Number of times to repeat the prompt (batch size multiplier).",
    )

    parser.add_argument(
        "-d",
        "--denoising-steps",
        type=int,
        default=None,
        help="Number of denoising steps" + (" in base." if is_xl else "."),
    )

    parser.add_argument(
        "-g",
        "--guidance",
        type=float,
        default=None,
        help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.",
    )

    parser.add_argument(
        "-ls", "--lora-scale", type=float, default=1, help="Scale of LoRA weights, default 1 (must between 0 and 1)"
    )
    parser.add_argument("-lw", "--lora-weights", type=str, default="", help="LoRA weights to apply in the base model")

    if is_xl:
        parser.add_argument(
            "--lcm",
            action="store_true",
            help="Use fine-tuned latent consistency model to replace the UNet in base.",
        )

        parser.add_argument(
            "-rs",
            "--refiner-scheduler",
            type=str,
            default="EulerA",
            choices=["DDIM", "EulerA", "UniPC"],
            help="Scheduler for diffusion process of refiner.",
        )

        parser.add_argument(
            "-rg",
            "--refiner-guidance",
            type=float,
            default=5.0,
            help="Guidance scale used in refiner.",
        )

        parser.add_argument(
            "-rd",
            "--refiner-denoising-steps",
            type=int,
            default=30,
            help="Number of denoising steps in refiner. Note that actual steps is refiner_denoising_steps * strength.",
        )

        parser.add_argument(
            "--strength",
            type=float,
            default=0.3,
            help="A value between 0 and 1. The higher the value less the final image similar to the seed image.",
        )

        parser.add_argument(
            "-r",
            "--enable-refiner",
            action="store_true",
            help="Enable SDXL refiner to refine image from base pipeline.",
        )

    # ONNX export
    parser.add_argument(
        "--onnx-opset",
        type=int,
        default=None,
        choices=range(14, 18),
        help="Select ONNX opset version to target for exported models.",
    )

    # Engine build options.
    parser.add_argument(
        "-db",
        "--build-dynamic-batch",
        action="store_true",
        help="Build TensorRT engines to support dynamic batch size.",
    )
    parser.add_argument(
        "-ds",
        "--build-dynamic-shape",
        action="store_true",
        help="Build TensorRT engines to support dynamic image sizes.",
    )
    parser.add_argument("--max-batch-size", type=int, default=None, choices=[1, 2, 4, 8, 16, 32], help="Max batch size")

    # Inference related options
    parser.add_argument(
        "-nw", "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance."
    )
    parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.")
    parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.")
    parser.add_argument("--deterministic", action="store_true", help="use deterministic algorithms.")
    parser.add_argument("-dc", "--disable-cuda-graph", action="store_true", help="Disable cuda graph.")

    parser.add_argument("--framework-model-dir", default=None, help="framework model directory")

    group = parser.add_argument_group("Options for ORT_CUDA engine only")
    group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")
    group.add_argument("--max-cuda-graphs", type=int, default=1, help="Max number of cuda graphs to use. Default 1.")
    group.add_argument("--user-compute-stream", action="store_true", help="Use user compute stream.")

    # TensorRT only options
    group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only")
    group.add_argument(
        "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources."
    )

    args = parser.parse_args()

    set_default_arguments(args)

    # Validate image dimensions
    if args.height % 64 != 0 or args.width % 64 != 0:
        raise ValueError(
            f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}."
        )

    # CUDA graphs require static input shapes, so dynamic engines must disable them.
    if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
        print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
        args.disable_cuda_graph = True

    if args.onnx_opset is None:
        args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17

    if is_xl:
        if args.version == "xl-turbo":
            if args.lcm:
                print("[I] sdxl-turbo cannot use with LCM.")
                args.lcm = False

        assert args.strength > 0.0 and args.strength < 1.0

        assert not (args.lcm and args.lora_weights), "it is not supported to use both lcm unet and Lora together"

    # LCM scheduler only works with low guidance and few denoising steps.
    if args.scheduler == "LCM":
        if args.guidance > 2.0:
            print("[I] Use --guidance=0.0 (no more than 2.0) when LCM scheduler is used.")
            args.guidance = 0.0
        if args.denoising_steps > 16:
            print("[I] Use --denoising_steps=8 (no more than 16) when LCM scheduler is used.")
            args.denoising_steps = 8

    print(args)

    return args
296
+
297
+
298
def max_batch(args):
    """Return the maximum batch size allowed for the selected engine and options.

    An explicit --max-batch-size always wins; otherwise the cap is derived
    from guidance mode and engine capabilities.
    """
    if args.max_batch_size:
        return args.max_batch_size

    # With classifier-free guidance each image occupies two slots (cond + uncond).
    multiplier = 2 if args.guidance > 1.0 else 1
    allowed = 32 // multiplier
    # TensorRT-based engines need a smaller cap for dynamic shapes or large images.
    if args.engine != "ORT_CUDA" and (args.build_dynamic_shape or args.height > 512 or args.width > 512):
        allowed = 8 // multiplier
    return allowed
308
+
309
+
310
def get_metadata(args, is_xl: bool = False) -> dict[str, Any]:
    """Build a metadata dictionary describing this generation run.

    Captures the command line, key generation settings, installed package
    versions, and the CUDA device, so that results can be reproduced later.
    """
    # Re-quote argv entries containing spaces so the command is copy-pastable.
    command = " ".join('"' + part + '"' if " " in part else part for part in sys.argv)

    metadata = {
        "command": command,
        "args.prompt": args.prompt,
        "args.negative_prompt": args.negative_prompt,
        "args.batch_size": args.batch_size,
        "height": args.height,
        "width": args.width,
        "cuda_graph": not args.disable_cuda_graph,
        "vae_slicing": args.enable_vae_slicing,
        "engine": args.engine,
    }

    if args.lora_weights:
        metadata["lora_weights"] = args.lora_weights
        metadata["lora_scale"] = args.lora_scale

    if args.controlnet_type:
        metadata["controlnet_type"] = args.controlnet_type
        metadata["controlnet_scale"] = args.controlnet_scale

    if is_xl and args.enable_refiner:
        # Base and refiner stages carry independent scheduler/denoising settings.
        metadata["base.scheduler"] = args.scheduler
        metadata["base.denoising_steps"] = args.denoising_steps
        metadata["base.guidance"] = args.guidance
        metadata["refiner.strength"] = args.strength
        metadata["refiner.scheduler"] = args.refiner_scheduler
        metadata["refiner.denoising_steps"] = args.refiner_denoising_steps
        metadata["refiner.guidance"] = args.refiner_guidance
    else:
        metadata["scheduler"] = args.scheduler
        metadata["denoising_steps"] = args.denoising_steps
        metadata["guidance"] = args.guidance

    # Record versions of relevant installed python packages; skip missing ones.
    found = []
    for name in [
        "onnxruntime-gpu",
        "torch",
        "tensorrt",
        "transformers",
        "diffusers",
        "onnx",
        "onnx-graphsurgeon",
        "polygraphy",
        "controlnet_aux",
    ]:
        try:
            found.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            continue
    metadata["packages"] = " ".join(found)
    metadata["device"] = torch.cuda.get_device_name()
    metadata["torch.version.cuda"] = torch.version.cuda

    return metadata
366
+
367
+
368
def repeat_prompt(args):
    """Duplicate prompts by batch size and align negative prompts to them.

    Returns a (prompt, negative_prompt) pair of equal-length lists.
    Raises ValueError when either argument is not already a list.
    """
    if not isinstance(args.prompt, list):
        raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")

    if not isinstance(args.negative_prompt, list):
        raise ValueError(
            f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}"
        )

    prompt = args.prompt * args.batch_size
    # A single negative prompt is broadcast to every positive prompt.
    negative_prompt = args.negative_prompt * len(prompt) if len(args.negative_prompt) == 1 else args.negative_prompt

    return prompt, negative_prompt
384
+
385
+
386
def initialize_pipeline(
    version="xl-turbo",
    is_refiner: bool = False,
    is_inpaint: bool = False,
    engine_type=EngineType.ORT_CUDA,
    work_dir: str = ".",
    engine_dir=None,
    onnx_opset: int = 17,
    scheduler="EulerA",
    height=512,
    width=512,
    nvtx_profile=False,
    use_cuda_graph=True,
    build_dynamic_batch=False,
    build_dynamic_shape=False,
    min_image_size: int = 512,
    max_image_size: int = 1024,
    max_batch_size: int = 16,
    opt_batch_size: int = 1,
    build_all_tactics: bool = False,
    do_classifier_free_guidance: bool = False,
    lcm: bool = False,
    controlnet=None,
    lora_weights=None,
    lora_scale: float = 1.0,
    use_fp16_vae: bool = True,
    use_vae: bool = True,
    framework_model_dir: str | None = None,
    max_cuda_graphs: int = 1,
):
    """Create a StableDiffusionPipeline and build or load its backend engines.

    Returns the pipeline with engines ready for inference.
    Raises RuntimeError when engine_dir does not exist or engine_type is unknown.
    """
    pipeline_info = PipelineInfo(
        version,
        is_refiner=is_refiner,
        is_inpaint=is_inpaint,
        use_vae=use_vae,
        min_image_size=min_image_size,
        max_image_size=max_image_size,
        use_fp16_vae=use_fp16_vae,
        use_lcm=lcm,
        do_classifier_free_guidance=do_classifier_free_guidance,
        controlnet=controlnet,
        lora_weights=lora_weights,
        lora_scale=lora_scale,
    )

    # Remember the user-provided engine directory; engine_dir is overwritten next.
    user_engine_dir = engine_dir

    onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
        work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type, framework_model_dir=framework_model_dir
    )

    pipeline = StableDiffusionPipeline(
        pipeline_info,
        scheduler=scheduler,
        output_dir=output_dir,
        verbose=False,
        nvtx_profile=nvtx_profile,
        max_batch_size=max_batch_size,
        use_cuda_graph=use_cuda_graph,
        framework_model_dir=framework_model_dir,
        engine_type=engine_type,
    )

    import_engine_dir = None
    if user_engine_dir:
        if not os.path.exists(user_engine_dir):
            raise RuntimeError(f"--engine_dir directory does not exist: {user_engine_dir}")

        # A model_index.json marks an optimized diffusers ONNX pipeline layout,
        # which ORT_CUDA can import instead of rebuilding engines from scratch.
        if engine_type == EngineType.ORT_CUDA and os.path.exists(os.path.join(user_engine_dir, "model_index.json")):
            import_engine_dir = user_engine_dir
        else:
            engine_dir = user_engine_dir

    # With dynamic shapes, optimize for the model's default resolution instead
    # of the requested one.
    opt_image_height = pipeline_info.default_image_size() if build_dynamic_shape else height
    opt_image_width = pipeline_info.default_image_size() if build_dynamic_shape else width

    if engine_type == EngineType.ORT_CUDA:
        pipeline.backend.build_engines(
            engine_dir=engine_dir,
            framework_model_dir=framework_model_dir,
            onnx_dir=onnx_dir,
            tmp_dir=os.path.join(work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"),
            device_id=torch.cuda.current_device(),
            import_engine_dir=import_engine_dir,
            max_cuda_graphs=max_cuda_graphs,
        )
    elif engine_type == EngineType.ORT_TRT:
        pipeline.backend.build_engines(
            engine_dir,
            framework_model_dir,
            onnx_dir,
            onnx_opset,
            opt_image_height=opt_image_height,
            opt_image_width=opt_image_width,
            opt_batch_size=opt_batch_size,
            static_batch=not build_dynamic_batch,
            static_image_shape=not build_dynamic_shape,
            max_workspace_size=0,
            device_id=torch.cuda.current_device(),
            timing_cache=timing_cache,
        )
    elif engine_type == EngineType.TRT:
        pipeline.backend.load_engines(
            engine_dir,
            framework_model_dir,
            onnx_dir,
            onnx_opset,
            opt_batch_size=opt_batch_size,
            opt_image_height=opt_image_height,
            opt_image_width=opt_image_width,
            static_batch=not build_dynamic_batch,
            static_shape=not build_dynamic_shape,
            enable_all_tactics=build_all_tactics,
            timing_cache=timing_cache,
        )
    elif engine_type == EngineType.TORCH:
        pipeline.backend.build_engines(framework_model_dir)
    else:
        raise RuntimeError("invalid engine type")

    return pipeline
508
+
509
+
510
def load_pipelines(args, batch_size=None):
    """Create the base pipeline (and optional XL refiner) from parsed arguments.

    Returns a (base, refiner) tuple; refiner is None unless --enable-refiner
    is used with an XL model. Raises ValueError when the requested batch size
    exceeds the engine's maximum.
    """
    engine_type = get_engine_type(args.engine)

    # TensorRT plugins must be registered before engines are deserialized.
    if engine_type == EngineType.TRT:
        from trt_utilities import init_trt_plugins  # noqa: PLC0415

        init_trt_plugins()

    max_batch_size = max_batch(args)

    if batch_size is None:
        assert isinstance(args.prompt, list)
        batch_size = len(args.prompt) * args.batch_size

    if batch_size > max_batch_size:
        raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")

    is_ort_cuda = args.engine == "ORT_CUDA"

    # For TensorRT, performance of engines built with dynamic shape is very
    # sensitive to the image size range, so use a narrower range there to
    # trade flexibility for performance.
    if args.version == "xl-turbo":
        min_image_size, max_image_size = 512, (1024 if is_ort_cuda else 768)
    elif args.version == "xl-1.0":
        # Covers landscape (832x1216), portrait (1216x832) and square (1024x1024).
        min_image_size = 512 if is_ort_cuda else 832
        max_image_size = 2048 if is_ort_cuda else 1216
    else:
        # Covers landscape 512x768, portrait 768x512, and squares 512/768.
        min_image_size = 256 if is_ort_cuda else 512
        max_image_size = 1024 if is_ort_cuda else 768

    params = {
        "version": args.version,
        "is_refiner": False,
        "is_inpaint": False,
        "engine_type": engine_type,
        "work_dir": args.work_dir,
        "engine_dir": args.engine_dir,
        "onnx_opset": args.onnx_opset,
        "scheduler": args.scheduler,
        "height": args.height,
        "width": args.width,
        "nvtx_profile": args.nvtx_profile,
        "use_cuda_graph": not args.disable_cuda_graph,
        "build_dynamic_batch": args.build_dynamic_batch,
        "build_dynamic_shape": args.build_dynamic_shape,
        "min_image_size": min_image_size,
        "max_image_size": max_image_size,
        "max_batch_size": max_batch_size,
        "opt_batch_size": 1 if args.build_dynamic_batch else batch_size,
        "build_all_tactics": args.build_all_tactics,
        "do_classifier_free_guidance": args.guidance > 1.0,
        "controlnet": args.controlnet_type,
        "lora_weights": args.lora_weights,
        "lora_scale": args.lora_scale,
        "use_fp16_vae": "xl" in args.version,
        "use_vae": True,
        "framework_model_dir": args.framework_model_dir,
        "max_cuda_graphs": args.max_cuda_graphs,
    }

    if "xl" in args.version:
        params["lcm"] = args.lcm
        # With a refiner the base stage skips its VAE; the refiner decodes.
        params["use_vae"] = not args.enable_refiner
    base = initialize_pipeline(**params)

    refiner = None
    if "xl" in args.version and args.enable_refiner:
        params["version"] = "xl-1.0"  # Allow SDXL Turbo to use refiner.
        params["is_refiner"] = True
        params["scheduler"] = args.refiner_scheduler
        params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0
        params["lcm"] = False
        params["controlnet"] = None
        params["lora_weights"] = None
        params["use_vae"] = True
        params["use_fp16_vae"] = True
        refiner = initialize_pipeline(**params)

    if engine_type == EngineType.TRT:
        # Base and refiner share one device memory arena sized for the larger.
        max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
        _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
        base.backend.activate_engines(shared_device_memory)
        if refiner:
            refiner.backend.activate_engines(shared_device_memory)

    if engine_type == EngineType.ORT_CUDA:
        vae_slicing = args.enable_vae_slicing
        if batch_size > 4 and not vae_slicing and (args.height >= 1024 and args.width >= 1024):
            print(
                "Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4 and resolution >= 1024."
            )
            vae_slicing = True
        if vae_slicing:
            # Slicing applies to whichever pipeline runs the VAE decode.
            (refiner or base).backend.enable_vae_slicing()
    return base, refiner
607
+
608
+
609
def get_depth_image(image):
    """
    Create a depth-map control image for the SDXL depth ControlNet.
    """
    from transformers import DPTFeatureExtractor, DPTForDepthEstimation  # noqa: PLC0415

    model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
    extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")

    pixels = extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth = model(pixels).predicted_depth

    # The depth map is 384x384 by default; interpolate to the default output
    # size here. It is resized again to the final image size later, so the
    # target here could be changed to avoid interpolating twice.
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    # Normalize each image to [0, 1], then replicate to three channels.
    lo = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    hi = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    depth = (depth - lo) / (hi - lo)
    rgb = torch.cat([depth, depth, depth], dim=1)

    array = rgb.permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((array * 255.0).clip(0, 255).astype(np.uint8))
638
+
639
+
640
def get_canny_image(image) -> Image.Image:
    """
    Create a canny-edge control image for the SDXL ControlNet.
    """
    edges = cv2.Canny(np.array(image), 100, 200)
    # Replicate the single edge channel into a 3-channel RGB image.
    rgb = np.concatenate([edges[:, :, None]] * 3, axis=2)
    return Image.fromarray(rgb)
650
+
651
+
652
def process_controlnet_images_xl(args) -> list[Image.Image]:
    """
    Prepare the single control image used by the SDXL/Turbo ControlNet.
    """
    # SDXL supports exactly one ControlNet image in this demo.
    assert len(args.controlnet_image) == 1
    source = Image.open(args.controlnet_image[0]).convert("RGB")

    kind = args.controlnet_type[0]
    if kind == "canny":
        return [get_canny_image(source)]
    if kind == "depth":
        return [get_depth_image(source)]
    raise ValueError(f"This controlnet type is not supported for SDXL or Turbo: {args.controlnet_type}.")
668
+
669
+
670
def add_controlnet_arguments(parser, is_xl: bool = False):
    """
    Register ControlNet command-line options on the given parser.
    """
    cnet = parser.add_argument_group("Options for ControlNet (supports 1.5, sd-turbo, xl-turbo, xl-1.0).")

    cnet.add_argument(
        "-ci",
        "--controlnet-image",
        nargs="*",
        type=str,
        default=[],
        help="Path to the input regular RGB image/images for controlnet",
    )
    # Allowed types depend on the model family (SDXL vs SD 1.5).
    cnet.add_argument(
        "-ct",
        "--controlnet-type",
        nargs="*",
        type=str,
        default=[],
        choices=list(PipelineInfo.supported_controlnet("xl-1.0" if is_xl else "1.5").keys()),
        help="A list of controlnet type",
    )
    cnet.add_argument(
        "-cs",
        "--controlnet-scale",
        nargs="*",
        type=float,
        default=[],
        help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original unet. Default is 0.5 for SDXL, or 1.0 for SD 1.5",
    )
701
+
702
+
703
def process_controlnet_image(controlnet_type: str, image: Image.Image, height, width):
    """
    Process a control image for ControlNet v1.1 with Stable Diffusion 1.5.

    Runs the detector matching controlnet_type and resizes the result.
    Raises ValueError for an unsupported controlnet_type.
    """
    # Lazy dispatch table: each entry builds/loads its detector only when chosen.
    detectors = {
        "canny": lambda img: controlnet_aux.CannyDetector()(img),
        "normalbae": lambda img: controlnet_aux.NormalBaeDetector.from_pretrained("lllyasviel/Annotators")(img),
        "depth": lambda img: controlnet_aux.LeresDetector.from_pretrained("lllyasviel/Annotators")(img),
        "mlsd": lambda img: controlnet_aux.MLSDdetector.from_pretrained("lllyasviel/Annotators")(img),
        "openpose": lambda img: controlnet_aux.OpenposeDetector.from_pretrained("lllyasviel/Annotators")(img),
        "scribble": lambda img: controlnet_aux.HEDdetector.from_pretrained("lllyasviel/Annotators")(img, scribble=True),
        "seg": lambda img: controlnet_aux.SamDetector.from_pretrained(
            "ybelkada/segment-anything", subfolder="checkpoints"
        )(img),
    }
    if controlnet_type not in detectors:
        raise ValueError(f"There is no demo image of this controlnet_type: {controlnet_type}")

    detected = detectors[controlnet_type](image.convert("RGB"))
    # NOTE(review): resize target is (height, width), matching the original code;
    # PIL's resize normally expects (width, height) — verify against callers.
    return detected.resize((height, width))
736
+
737
+
738
def process_controlnet_arguments(args):
    """
    Validate ControlNet arguments and return (control images, scale tensor).

    Returns (None, None) when no ControlNet is requested.
    Raises ValueError on inconsistent argument lengths or unsupported versions.
    """
    assert isinstance(args.controlnet_type, list)
    assert isinstance(args.controlnet_scale, list)
    assert isinstance(args.controlnet_image, list)

    num_types = len(args.controlnet_type)

    if len(args.controlnet_image) != num_types:
        raise ValueError(
            f"Numbers of controlnet_image {len(args.controlnet_image)} should be equal to number of controlnet_type {num_types}."
        )

    if num_types == 0:
        return None, None

    if args.version not in ["1.5", "xl-1.0", "xl-turbo", "sd-turbo"]:
        raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.")

    is_xl = "xl" in args.version
    if is_xl and num_types > 1:
        raise ValueError("This demo only support one ControlNet for Stable Diffusion XL or Turbo.")

    if not args.controlnet_scale:
        # Default scale differs per family: 0.5 for SDXL, 1.0 for SD 1.5.
        args.controlnet_scale = [0.5 if is_xl else 1.0] * num_types
    elif num_types != len(args.controlnet_scale):
        raise ValueError(
            f"Numbers of controlnet_type {num_types} should be equal to number of controlnet_scale {len(args.controlnet_scale)}."
        )

    # Convert controlnet scales to tensor.
    controlnet_scale = torch.FloatTensor(args.controlnet_scale)

    if is_xl:
        images = process_controlnet_images_xl(args)
    else:
        images = [
            process_controlnet_image(args.controlnet_type[i], Image.open(path), args.height, args.width)
            for i, path in enumerate(args.controlnet_image)
        ]

    return images, controlnet_scale