onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/diffusion_models.py
@@ -0,0 +1,1318 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ # Modified from stable_diffusion_tensorrt_txt2img.py in diffusers and TensorRT demo diffusion,
+ # which has the following license:
+ #
+ # Copyright 2023 The HuggingFace Inc. team.
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import os
+ import tempfile
+
+ import onnx
+ import onnx_graphsurgeon as gs
+ import torch
+ from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+ from onnx import GraphProto, ModelProto, shape_inference
+ from ort_optimizer import OrtStableDiffusionOptimizer
+ from polygraphy.backend.onnx.loader import fold_constants
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+ from onnxruntime.transformers.onnx_model import OnnxModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class TrtOptimizer:
+     def __init__(self, onnx_graph):
+         self.graph = gs.import_onnx(onnx_graph)
+
+     def cleanup(self):
+         self.graph.cleanup().toposort()
+
+     def get_optimized_onnx_graph(self):
+         return gs.export_onnx(self.graph)
+
+     def select_outputs(self, keep, names=None):
+         self.graph.outputs = [self.graph.outputs[o] for o in keep]
+         if names:
+             for i, name in enumerate(names):
+                 self.graph.outputs[i].name = name
+
+     def fold_constants(self):
+         onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+         self.graph = gs.import_onnx(onnx_graph)
+
+     def infer_shapes(self):
+         onnx_graph = gs.export_onnx(self.graph)
+         if onnx_graph.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
+             with tempfile.TemporaryDirectory() as temp_dir:
+                 input_onnx_path = os.path.join(temp_dir, "model.onnx")
+                 onnx.save_model(
+                     onnx_graph,
+                     input_onnx_path,
+                     save_as_external_data=True,
+                     all_tensors_to_one_file=True,
+                     convert_attribute=False,
+                 )
+                 output_onnx_path = os.path.join(temp_dir, "model_with_shape.onnx")
+                 onnx.shape_inference.infer_shapes_path(input_onnx_path, output_onnx_path)
+                 onnx_graph = onnx.load(output_onnx_path)
+         else:
+             onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+         self.graph = gs.import_onnx(onnx_graph)
+
+
+ class PipelineInfo:
+     def __init__(
+         self,
+         version: str,
+         is_inpaint: bool = False,
+         is_refiner: bool = False,
+         use_vae=True,  # TODO: this has couple with output type of pipeline
+         min_image_size=256,
+         max_image_size=1024,
+         use_fp16_vae=True,
+         use_lcm=False,
+         do_classifier_free_guidance=True,
+         controlnet=None,
+         lora_weights=None,
+         lora_scale=1.0,
+     ):
+         self.version = version
+         self._is_inpaint = is_inpaint
+         self._is_refiner = is_refiner
+         self._use_vae = use_vae
+         self._min_image_size = min_image_size
+         self._max_image_size = max_image_size
+         self._use_fp16_vae = use_fp16_vae
+         self._use_lcm = use_lcm
+         self.do_classifier_free_guidance = do_classifier_free_guidance and not use_lcm
+         self.controlnet = controlnet  # A list of control net type
+         self.lora_weights = lora_weights
+         self.lora_scale = lora_scale
+
+         if is_refiner:
+             assert not use_lcm
+             assert self.is_xl()
+
+     def is_inpaint(self) -> bool:
+         return self._is_inpaint
+
+     def is_xl(self) -> bool:
+         return "xl" in self.version
+
+     def is_xl_turbo(self) -> bool:
+         return self.version == "xl-turbo"
+
+     def is_xl_base(self) -> bool:
+         return self.version == "xl-1.0" and not self._is_refiner
+
+     def is_xl_base_or_turbo(self) -> bool:
+         return self.is_xl_base() or self.is_xl_turbo()
+
+     def is_xl_refiner(self) -> bool:
+         return self.version == "xl-1.0" and self._is_refiner
+
+     def use_safetensors(self) -> bool:
+         return self.is_xl() or self.version in ["sd-turbo"]
+
+     def stages(self) -> list[str]:
+         if self.is_xl_base_or_turbo():
+             return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else [])
+
+         if self.is_xl_refiner():
+             return ["clip2", "unetxl", "vae"]
+
+         return ["clip", "unet", "vae"]
+
+     def vae_scaling_factor(self) -> float:
+         return 0.13025 if self.is_xl() else 0.18215
+
+     def vae_torch_fallback(self) -> bool:
+         return self.is_xl() and not self._use_fp16_vae
+
+     def custom_fp16_vae(self) -> str | None:
+         # For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs
+         return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None
+
+     def custom_unet(self) -> str | None:
+         return "latent-consistency/lcm-sdxl" if self._use_lcm and self.is_xl_base() else None
+
+     @staticmethod
+     def supported_versions(is_xl: bool):
+         return ["xl-1.0", "xl-turbo"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base", "sd-turbo"]
+
+     @staticmethod
+     def supported_models():
+         return {
+             "CompVis/stable-diffusion-v1-4": "1.4",
+             "runwayml/stable-diffusion-v1-5": "1.5",
+             "stabilityai/stable-diffusion-2-base": "2.0-base",
+             "stabilityai/stable-diffusion-2": "2.0",
+             "stabilityai/stable-diffusion-2-1": "2.1",
+             "stabilityai/stable-diffusion-2-1-base": "2.1",
+             "stabilityai/stable-diffusion-xl-base-1.0": "xl-1.0",
+             "stabilityai/stable-diffusion-xl-refiner-1.0": "xl-1.0",
+             "stabilityai/sdxl-turbo": "xl-turbo",
+             "stabilityai/sd-turbo": "sd-turbo",
+             # "runwayml/stable-diffusion-inpainting": "1.5",
+             # "stabilityai/stable-diffusion-2-inpainting": "2.0",
+         }
+
+     def name(self) -> str:
+         if self.version == "1.4":
+             if self.is_inpaint():
+                 return "runwayml/stable-diffusion-inpainting"
+             else:
+                 return "CompVis/stable-diffusion-v1-4"
+         elif self.version == "1.5":
+             if self.is_inpaint():
+                 return "runwayml/stable-diffusion-inpainting"
+             else:
+                 return "runwayml/stable-diffusion-v1-5"
+         elif self.version == "2.0-base":
+             if self.is_inpaint():
+                 return "stabilityai/stable-diffusion-2-inpainting"
+             else:
+                 return "stabilityai/stable-diffusion-2-base"
+         elif self.version == "2.0":
+             if self.is_inpaint():
+                 return "stabilityai/stable-diffusion-2-inpainting"
+             else:
+                 return "stabilityai/stable-diffusion-2"
+         elif self.version == "2.1":
+             return "stabilityai/stable-diffusion-2-1"
+         elif self.version == "2.1-base":
+             return "stabilityai/stable-diffusion-2-1-base"
+         elif self.version == "xl-1.0":
+             if self.is_xl_refiner():
+                 return "stabilityai/stable-diffusion-xl-refiner-1.0"
+             else:
+                 return "stabilityai/stable-diffusion-xl-base-1.0"
+         elif self.version == "xl-turbo":
+             return "stabilityai/sdxl-turbo"
+         elif self.version == "sd-turbo":
+             return "stabilityai/sd-turbo"
+
+         raise ValueError(f"Incorrect version {self.version}")
+
+     def short_name(self) -> str:
+         return self.name().split("/")[-1].replace("stable-diffusion", "sd")
+
+     def clip_embedding_dim(self):
+         # TODO: can we read from config instead
+         if self.version in ("1.4", "1.5"):
+             return 768
+         elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base", "sd-turbo"):
+             return 1024
+         elif self.is_xl_base_or_turbo():
+             return 768
+         else:
+             raise ValueError(f"Invalid version {self.version}")
+
+     def clipwithproj_embedding_dim(self):
+         if self.is_xl():
+             return 1280
+         else:
+             raise ValueError(f"Invalid version {self.version}")
+
+     def unet_embedding_dim(self):
+         if self.version in ("1.4", "1.5"):
+             return 768
+         elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base", "sd-turbo"):
+             return 1024
+         elif self.is_xl_base_or_turbo():
+             return 2048
+         elif self.is_xl_refiner():
+             return 1280
+         else:
+             raise ValueError(f"Invalid version {self.version}")
+
+     def min_image_size(self):
+         return self._min_image_size
+
+     def max_image_size(self):
+         return self._max_image_size
+
+     @staticmethod
+     def default_resolution(version: str) -> int:
+         if version == "xl-1.0":
+             return 1024
+         if version in ("2.0", "2.1"):
+             return 768
+         return 512
+
+     def default_image_size(self) -> int:
+         return PipelineInfo.default_resolution(self.version)
+
+     @staticmethod
+     def supported_controlnet(version="1.5"):
+         if version in ("xl-1.0", "xl-turbo"):
+             return {
+                 "canny": "diffusers/controlnet-canny-sdxl-1.0",
+                 "depth": "diffusers/controlnet-depth-sdxl-1.0",
+             }
+         elif version == "1.5":
+             return {
+                 "canny": "lllyasviel/control_v11p_sd15_canny",
+                 "depth": "lllyasviel/control_v11f1p_sd15_depth",
+                 "openpose": "lllyasviel/control_v11p_sd15_openpose",
+                 # "tile": "lllyasviel/control_v11f1e_sd15_tile",
+                 # "lineart": "lllyasviel/control_v11p_sd15_lineart",
+                 # "inpaint": "lllyasviel/control_v11p_sd15_inpaint",
+                 # "softedge": "lllyasviel/control_v11p_sd15_softedge",
+                 "mlsd": "lllyasviel/control_v11p_sd15_mlsd",
+                 "scribble": "lllyasviel/control_v11p_sd15_scribble",
+                 # "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
+                 "normalbae": "lllyasviel/control_v11p_sd15_normalbae",
+                 "seg": "lllyasviel/control_v11p_sd15_seg",
+                 # "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
+                 # "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
+             }
+         return None
+
+     def controlnet_name(self):
+         """Return a list of controlnet name"""
+         if not self.controlnet:
+             return None
+         controlnet_map = PipelineInfo.supported_controlnet(self.version)
+         if controlnet_map is None:
+             return None
+         return [controlnet_map[controlnet] for controlnet in self.controlnet]
+
+
+ class BaseModel:
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         model,
+         device,
+         fp16: bool = False,
+         max_batch_size: int = 16,
+         embedding_dim: int = 768,
+         text_maxlen: int = 77,
+     ):
+         self.name = self.__class__.__name__
+
+         self.pipeline_info = pipeline_info
+
+         self.model = model
+         self.fp16 = fp16
+         self.device = device
+
+         self.min_batch = 1
+         self.max_batch = max_batch_size
+         self.min_image_shape = pipeline_info.min_image_size()
+         self.max_image_shape = pipeline_info.max_image_size()
+         self.min_latent_shape = self.min_image_shape // 8
+         self.max_latent_shape = self.max_image_shape // 8
+
+         self.embedding_dim = embedding_dim
+         self.text_maxlen = text_maxlen
+
+     def get_batch_multiplier(self):
+         return 2 if self.pipeline_info.do_classifier_free_guidance else 1
+
+     def get_ort_optimizer(self):
+         model_name_to_model_type = {
+             "CLIP": "clip",
+             "UNet": "unet",
+             "VAE": "vae",
+             "UNetXL": "unet",
+             "CLIPWithProj": "clip",
+         }
+         model_type = model_name_to_model_type[self.name]
+         return OrtStableDiffusionOptimizer(model_type)
+
+     def get_model(self):
+         return self.model
+
+     def from_pretrained(self, model_class, framework_model_dir, subfolder=None, model_name=None, **kwargs):
+         if model_name is None:
+             model_name = self.pipeline_info.name()
+
+         if subfolder:
+             model_dir = os.path.join(framework_model_dir, model_name, subfolder)
+         else:
+             model_dir = os.path.join(framework_model_dir, model_name)
+
+         if not os.path.exists(model_dir):
+             model = model_class.from_pretrained(
+                 model_name,
+                 subfolder=subfolder,
+                 use_safetensors=self.pipeline_info.use_safetensors(),
+                 **kwargs,
+             ).to(self.device)
+             model.save_pretrained(model_dir)
+         else:
+             print(f"Load {self.name} pytorch model from: {model_dir}")
+
+             model = model_class.from_pretrained(model_dir).to(self.device)
+         return model
+
+     def load_model(self, framework_model_dir: str, subfolder: str):
+         pass
+
+     def get_input_names(self) -> list[str]:
+         pass
+
+     def get_output_names(self) -> list[str]:
+         pass
+
+     def get_dynamic_axes(self) -> dict[str, dict[int, str]]:
+         pass
+
+     def get_sample_input(self, batch_size, image_height, image_width) -> tuple:
+         pass
+
+     def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         """For TensorRT EP"""
+         (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             _,
+             _,
+             _,
+             _,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+
+         if (self.name in ["UNet", "UNetXL"]) and (self.get_batch_multiplier() == 1):
+             profile_id = f"_b1_{batch_size}" if static_batch else f"_b1_{min_batch}_{max_batch}"
+         else:
+             profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"
+
+         if self.name != "CLIP":
+             if static_image_shape:
+                 profile_id += f"_h_{image_height}_w_{image_width}"
+             else:
+                 profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"
+
+         return profile_id
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         """For TensorRT"""
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         pass
+
+     def fp32_input_output_names(self) -> list[str]:
+         """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model.
+         This is a list of input or output names that are kept as float32 in optimized model.
+         """
+         return []
+
+     def optimize_ort(
+         self,
+         input_onnx_path,
+         optimized_onnx_path,
+         to_fp16=True,
+         fp32_op_list=None,
+         optimize_by_ort=True,
+         optimize_by_fusion=True,
+         tmp_dir=None,
+     ):
+         optimizer = self.get_ort_optimizer()
+         optimizer.optimize(
+             input_onnx_path,
+             optimized_onnx_path,
+             float16=to_fp16,
+             keep_io_types=self.fp32_input_output_names(),
+             fp32_op_list=fp32_op_list,
+             optimize_by_ort=optimize_by_ort,
+             optimize_by_fusion=optimize_by_fusion,
+             tmp_dir=tmp_dir,
+         )
+
+     def optimize_trt(self, input_onnx_path, optimized_onnx_path):
+         onnx_graph = onnx.load(input_onnx_path)
+         opt = TrtOptimizer(onnx_graph)
+         opt.cleanup()
+         opt.fold_constants()
+         opt.infer_shapes()
+         opt.cleanup()
+         onnx_opt_graph = opt.get_optimized_onnx_graph()
+
+         if onnx_opt_graph.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
+             onnx.save_model(
+                 onnx_opt_graph,
+                 optimized_onnx_path,
+                 save_as_external_data=True,
+                 all_tensors_to_one_file=True,
+                 convert_attribute=False,
+             )
+         else:
+             onnx.save(onnx_opt_graph, optimized_onnx_path)
+
+     def check_dims(self, batch_size, image_height, image_width):
+         assert batch_size >= self.min_batch and batch_size <= self.max_batch
+         assert image_height % 8 == 0 or image_width % 8 == 0
+         latent_height = image_height // 8
+         latent_width = image_width // 8
+         assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+         assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+         return (latent_height, latent_width)
+
+     def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         min_batch = batch_size if static_batch else self.min_batch
+         max_batch = batch_size if static_batch else self.max_batch
+         latent_height = image_height // 8
+         latent_width = image_width // 8
+         min_image_height = image_height if static_image_shape else self.min_image_shape
+         max_image_height = image_height if static_image_shape else self.max_image_shape
+         min_image_width = image_width if static_image_shape else self.min_image_shape
+         max_image_width = image_width if static_image_shape else self.max_image_shape
+         min_latent_height = latent_height if static_image_shape else self.min_latent_shape
+         max_latent_height = latent_height if static_image_shape else self.max_latent_shape
+         min_latent_width = latent_width if static_image_shape else self.min_latent_shape
+         max_latent_width = latent_width if static_image_shape else self.max_latent_shape
+         return (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         )
+
+
+ class CLIP(BaseModel):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         model,
+         device,
+         max_batch_size,
+         embedding_dim: int = 0,
+         clip_skip=0,
+     ):
+         super().__init__(
+             pipeline_info,
+             model=model,
+             device=device,
+             max_batch_size=max_batch_size,
+             embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(),
+         )
+         self.output_hidden_state = pipeline_info.is_xl()
+
+         # see https://github.com/huggingface/diffusers/pull/5057 for more information of clip_skip.
+         # Clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings.
+         self.clip_skip = clip_skip
+
+     def get_input_names(self):
+         return ["input_ids"]
+
+     def get_output_names(self):
+         # The exported onnx model has no hidden_state. For SD-XL, We will add hidden_state to optimized onnx model.
+         return ["text_embeddings"]
+
+     def get_dynamic_axes(self):
+         return {"input_ids": {0: "B", 1: "S"}, "text_embeddings": {0: "B", 1: "S"}}
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         self.check_dims(batch_size, image_height, image_width)
+         min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+             batch_size, image_height, image_width, static_batch, static_image_shape
+         )
+         return {
+             "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         output = {
+             "input_ids": (batch_size, self.text_maxlen),
+             "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+         }
+
+         if self.output_hidden_state:
+             output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)
+
+         return output
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),)
+
+     def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, use_external_data_format=False):
+         graph: GraphProto = model.graph
+         hidden_layers = -1
+         for i in range(len(graph.node)):
+             for j in range(len(graph.node[i].output)):
+                 name = graph.node[i].output[j]
+                 if "layers" in name:
+                     hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers)
+
+         assert self.clip_skip >= 0 and self.clip_skip < hidden_layers
+
+         node_output_name = f"/text_model/encoder/layers.{hidden_layers - 1 - self.clip_skip}/Add_1_output_0"
+
+         # search the name in outputs of all node
+         found = False
+         for i in range(len(graph.node)):
+             for j in range(len(graph.node[i].output)):
+                 if graph.node[i].output[j] == node_output_name:
+                     found = True
+                     break
+             if found:
+                 break
+         if not found:
+             raise RuntimeError("Failed to find hidden_states graph output in clip")
+
+         # Insert a Cast (fp32 -> fp16) node so that hidden_states has same data type as the first graph output.
+         graph_output_name = "hidden_states"
+         cast_node = onnx.helper.make_node("Cast", inputs=[node_output_name], outputs=[graph_output_name])
+         cast_node.attribute.extend([onnx.helper.make_attribute("to", graph.output[0].type.tensor_type.elem_type)])
+
+         hidden_state = graph.output.add()
+         hidden_state.CopyFrom(
+             onnx.helper.make_tensor_value_info(
+                 graph_output_name,
+                 graph.output[0].type.tensor_type.elem_type,
+                 ["B", "S", self.embedding_dim],
+             )
+         )
+
+         onnx_model = OnnxModel(model)
+         onnx_model.add_node(cast_node)
+         onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
+
+     def optimize_ort(
+         self,
+         input_onnx_path,
+         optimized_onnx_path,
+         to_fp16=True,
+         fp32_op_list=None,
+         optimize_by_ort=True,
+         optimize_by_fusion=True,
+         tmp_dir=None,
+     ):
+         optimizer = self.get_ort_optimizer()
+
+         if not self.output_hidden_state:
+             optimizer.optimize(
+                 input_onnx_path,
+                 optimized_onnx_path,
+                 float16=to_fp16,
+                 keep_io_types=[],
+                 fp32_op_list=fp32_op_list,
+                 keep_outputs=["text_embeddings"],
+                 optimize_by_ort=optimize_by_ort,
+                 optimize_by_fusion=optimize_by_fusion,
+                 tmp_dir=tmp_dir,
+             )
+         elif optimize_by_fusion:
+             with tempfile.TemporaryDirectory() as tmp_dir:
+                 # Save to a temporary file so that we can load it with Onnx Runtime.
+                 logger.info("Saving a temporary model to add hidden_states to graph output ...")
+                 tmp_model_path = os.path.join(tmp_dir, "model.onnx")
+
+                 model = onnx.load(input_onnx_path)
+                 self.add_hidden_states_graph_output(model, tmp_model_path, use_external_data_format=True)
+                 optimizer.optimize(
+                     tmp_model_path,
+                     optimized_onnx_path,
+                     float16=to_fp16,
+                     keep_io_types=[],
+                     fp32_op_list=fp32_op_list,
+                     keep_outputs=["text_embeddings", "hidden_states"],
+                     optimize_by_ort=optimize_by_ort,
+                     optimize_by_fusion=optimize_by_fusion,
+                     tmp_dir=tmp_dir,
+                 )
+         else:  # input is optimized model, there is no need to add hidden states.
+             optimizer.optimize(
+                 input_onnx_path,
+                 optimized_onnx_path,
+                 float16=to_fp16,
+                 keep_io_types=[],
+                 fp32_op_list=fp32_op_list,
+                 keep_outputs=["text_embeddings", "hidden_states"],
+                 optimize_by_ort=optimize_by_ort,
+                 optimize_by_fusion=optimize_by_fusion,
+                 tmp_dir=tmp_dir,
+             )
+
+     def optimize_trt(self, input_onnx_path, optimized_onnx_path):
+         onnx_graph = onnx.load(input_onnx_path)
+         opt = TrtOptimizer(onnx_graph)
+         opt.select_outputs([0])  # delete graph output#1
+         opt.cleanup()
+         opt.fold_constants()
+         opt.infer_shapes()
+         opt.select_outputs([0], names=["text_embeddings"])  # rename network output
+         opt.cleanup()
+         onnx_opt_graph = opt.get_optimized_onnx_graph()
+         if self.output_hidden_state:
+             self.add_hidden_states_graph_output(onnx_opt_graph, optimized_onnx_path)
+         else:
+             onnx.save(onnx_opt_graph, optimized_onnx_path)
+
+     def load_model(self, framework_model_dir, subfolder="text_encoder"):
+         return self.from_pretrained(CLIPTextModel, framework_model_dir, subfolder)
+
+
+ class CLIPWithProj(CLIP):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         model,
+         device,
+         max_batch_size=16,
+         clip_skip=0,
+     ):
+         super().__init__(
+             pipeline_info,
+             model,
+             device=device,
+             max_batch_size=max_batch_size,
+             embedding_dim=pipeline_info.clipwithproj_embedding_dim(),
+             clip_skip=clip_skip,
+         )
+
+     def load_model(self, framework_model_dir, subfolder="text_encoder_2"):
+         return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, subfolder)
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         output = {
+             "input_ids": (batch_size, self.text_maxlen),
+             "text_embeddings": (batch_size, self.embedding_dim),
+         }
+
+         if self.output_hidden_state:
+             output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)
+
+         return output
+
+
+ class UNet2DConditionControlNetModel(torch.nn.Module):
+     def __init__(self, unet, controlnets: ControlNetModel):
+         super().__init__()
+         self.unet = unet
+         self.controlnets = controlnets
+
+     def forward(self, sample, timestep, encoder_hidden_states, controlnet_images, controlnet_scales):
+         for i, (controlnet_image, conditioning_scale, controlnet) in enumerate(
+             zip(controlnet_images, controlnet_scales, self.controlnets, strict=False)
+         ):
+             down_samples, mid_sample = controlnet(
+                 sample,
+                 timestep,
+                 encoder_hidden_states=encoder_hidden_states,
+                 controlnet_cond=controlnet_image,
+                 return_dict=False,
+             )
+
+             down_samples = [down_sample * conditioning_scale for down_sample in down_samples]
+             mid_sample *= conditioning_scale
+
+             # merge samples
+             if i == 0:
+                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+             else:
+                 down_block_res_samples = [
+                     samples_prev + samples_curr
+                     for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=False)
+                 ]
+                 mid_block_res_sample += mid_sample
+
+         noise_pred = self.unet(
+             sample,
+             timestep,
+             encoder_hidden_states=encoder_hidden_states,
+             down_block_additional_residuals=down_block_res_samples,
+             mid_block_additional_residual=mid_block_res_sample,
+         )
+         return noise_pred[0]
+
+
+ # Modified from convert_stable_diffusion_controlnet_to_onnx.py in diffusers
+ class UNet2DConditionXLControlNetModel(torch.nn.Module):
+     def __init__(self, unet, controlnets: ControlNetModel):
+         super().__init__()
+         self.unet = unet
+         self.controlnets = controlnets
+
+     def forward(
+         self,
+         sample,
+         timestep,
+         encoder_hidden_states,
+         text_embeds,
+         time_ids,
+         controlnet_images,
+         controlnet_scales,
+     ):
+         added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
+         for i, (controlnet_image, conditioning_scale, controlnet) in enumerate(
+             zip(controlnet_images, controlnet_scales, self.controlnets, strict=False)
+         ):
+             down_samples, mid_sample = controlnet(
+                 sample,
+                 timestep,
+                 encoder_hidden_states=encoder_hidden_states,
+                 controlnet_cond=controlnet_image,
+                 conditioning_scale=conditioning_scale,
+                 added_cond_kwargs=added_cond_kwargs,
+                 return_dict=False,
+             )
+
+             # merge samples
+             if i == 0:
+                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+             else:
+                 down_block_res_samples = [
+                     samples_prev + samples_curr
+                     for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=False)
+                 ]
+                 mid_block_res_sample += mid_sample
+
+         noise_pred = self.unet(
+             sample,
+             timestep,
+             encoder_hidden_states=encoder_hidden_states,
+             down_block_additional_residuals=down_block_res_samples,
+             mid_block_additional_residual=mid_block_res_sample,
+             added_cond_kwargs=added_cond_kwargs,
+             return_dict=False,
+         )
+         return noise_pred[0]
+
+
+ class UNet(BaseModel):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         model,
+         device,
+         fp16=False,  # used by TRT
+         max_batch_size=16,
+         text_maxlen=77,
+         unet_dim=4,
+     ):
+         super().__init__(
+             pipeline_info,
+             model=model,
+             device=device,
+             fp16=fp16,
+             max_batch_size=max_batch_size,
+             embedding_dim=pipeline_info.unet_embedding_dim(),
+             text_maxlen=text_maxlen,
+         )
+
+         self.unet_dim = unet_dim
+         self.controlnet = pipeline_info.controlnet_name()
+
+     def load_model(self, framework_model_dir, subfolder="unet"):
+         options = {"variant": "fp16", "torch_dtype": torch.float16}
+
+         model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, subfolder, **options)
+
+         if self.controlnet:
+             controlnet_list = []
+             for name in self.controlnet:
+                 controlnet = self.from_pretrained(
+                     ControlNetModel,
+                     framework_model_dir,
+                     subfolder=None,
+                     model_name=name,
+                     torch_dtype=torch.float16,
+                 )
+                 controlnet_list.append(controlnet)
+
+             model = UNet2DConditionControlNetModel(model, torch.nn.ModuleList(controlnet_list))
+
+         if not self.fp16:
+             model = model.to(torch.float32)
+
+         return model
+
+     def get_input_names(self):
+         if not self.controlnet:
+             return ["sample", "timestep", "encoder_hidden_states"]
+         else:
+             return ["sample", "timestep", "encoder_hidden_states", "controlnet_images", "controlnet_scales"]
+
+     def get_output_names(self):
+         return ["latent"]
+
+     def get_dynamic_axes(self):
+         b = "2B" if self.get_batch_multiplier() == 2 else "B"
+         output = {
+             "sample": {0: b, 2: "H", 3: "W"},
+             "encoder_hidden_states": {0: b},
+             "latent": {0: b, 2: "H", 3: "W"},
+         }
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": {1: b, 3: "8H", 4: "8W"},
+                 }
+             )
+         return output
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+         m = self.get_batch_multiplier()
+         output = {
+             "sample": [
+                 (m * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+                 (m * batch_size, self.unet_dim, latent_height, latent_width),
+                 (m * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+             ],
+             "encoder_hidden_states": [
+                 (m * min_batch, self.text_maxlen, self.embedding_dim),
+                 (m * batch_size, self.text_maxlen, self.embedding_dim),
+                 (m * max_batch, self.text_maxlen, self.embedding_dim),
+             ],
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": [
+                         (len(self.controlnet), m * min_batch, 3, min_image_height, min_image_width),
+                         (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                         (len(self.controlnet), m * max_batch, 3, max_image_height, max_image_width),
+                     ]
+                 }
+             )
+         return output
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         m = self.get_batch_multiplier()
+         output = {
+             "sample": (m * batch_size, self.unet_dim, latent_height, latent_width),
+             "timestep": [1],
+             "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim),
+             "latent": (m * batch_size, 4, latent_height, latent_width),
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                     "controlnet_scales": [len(self.controlnet)],
+                 }
+             )
+         return output
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         dtype = torch.float16 if self.fp16 else torch.float32
+         m = self.get_batch_multiplier()
+         output = (
+             torch.randn(m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device),
+             torch.tensor([1.0], dtype=dtype, device=self.device),
+             torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+         )
+
+         if self.controlnet:
+             output = (
+                 *output,
+                 torch.randn(
+                     len(self.controlnet), m * batch_size, 3, image_height, image_width, dtype=dtype, device=self.device
+                 ),
+                 torch.randn(len(self.controlnet), dtype=dtype, device=self.device),
+             )
+         return output
+
+
960
+ class UNetXL(BaseModel):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         model,
+         device,
+         fp16=False,  # used by TRT
+         max_batch_size=16,
+         text_maxlen=77,
+         unet_dim=4,
+         time_dim=6,
+     ):
+         super().__init__(
+             pipeline_info,
+             model,
+             device=device,
+             fp16=fp16,
+             max_batch_size=max_batch_size,
+             embedding_dim=pipeline_info.unet_embedding_dim(),
+             text_maxlen=text_maxlen,
+         )
+         self.unet_dim = unet_dim
+         self.time_dim = time_dim
+
+         self.custom_unet = pipeline_info.custom_unet()
+         self.controlnet = pipeline_info.controlnet_name()
+
+     def load_model(self, framework_model_dir, subfolder="unet", always_download_fp16=True):
+         options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 or always_download_fp16 else {}
+
+         if self.custom_unet:
+             model_dir = os.path.join(framework_model_dir, self.custom_unet, subfolder)
+             if not os.path.exists(model_dir):
+                 unet = UNet2DConditionModel.from_pretrained(self.custom_unet, **options)
+                 unet.save_pretrained(model_dir)
+             else:
+                 unet = UNet2DConditionModel.from_pretrained(model_dir, **options)
+             model = unet.to(self.device)
+         else:
+             model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, subfolder, **options)
+
+         if always_download_fp16 and not self.fp16:
+             model = model.to(torch.float32)
+
+         if self.controlnet:
+             cnet_model_opts = {"torch_dtype": torch.float16} if self.fp16 or always_download_fp16 else {}
+             controlnets = torch.nn.ModuleList(
+                 [ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) for path in self.controlnet]
+             )
+             model = UNet2DConditionXLControlNetModel(model, controlnets)
+
+             if always_download_fp16 and not self.fp16:
+                 model = model.to(torch.float32)
+
+         return model
+
+     def get_input_names(self):
+         input_names = ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"]
+         if self.controlnet:
+             return [*input_names, "controlnet_images", "controlnet_scales"]
+         return input_names
+
+     def get_output_names(self):
+         return ["latent"]
+
+     def get_dynamic_axes(self):
+         b = "2B" if self.get_batch_multiplier() == 2 else "B"
+         output = {
+             "sample": {0: b, 2: "H", 3: "W"},
+             "encoder_hidden_states": {0: b},
+             "text_embeds": {0: b},
+             "time_ids": {0: b},
+             "latent": {0: b, 2: "H", 3: "W"},
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": {1: b, 3: "8H", 4: "8W"},
+                 }
+             )
+         return output
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+         m = self.get_batch_multiplier()
+         output = {
+             "sample": [
+                 (m * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+                 (m * batch_size, self.unet_dim, latent_height, latent_width),
+                 (m * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+             ],
+             "encoder_hidden_states": [
+                 (m * min_batch, self.text_maxlen, self.embedding_dim),
+                 (m * batch_size, self.text_maxlen, self.embedding_dim),
+                 (m * max_batch, self.text_maxlen, self.embedding_dim),
+             ],
+             "text_embeds": [(m * min_batch, 1280), (m * batch_size, 1280), (m * max_batch, 1280)],
+             "time_ids": [
+                 (m * min_batch, self.time_dim),
+                 (m * batch_size, self.time_dim),
+                 (m * max_batch, self.time_dim),
+             ],
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": [
+                         (len(self.controlnet), m * min_batch, 3, min_image_height, min_image_width),
+                         (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                         (len(self.controlnet), m * max_batch, 3, max_image_height, max_image_width),
+                     ],
+                 }
+             )
+         return output
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         m = self.get_batch_multiplier()
+         output = {
+             "sample": (m * batch_size, self.unet_dim, latent_height, latent_width),
+             "timestep": (1,),
+             "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim),
+             "text_embeds": (m * batch_size, 1280),
+             "time_ids": (m * batch_size, self.time_dim),
+             "latent": (m * batch_size, 4, latent_height, latent_width),
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                     "controlnet_scales": [len(self.controlnet)],
+                 }
+             )
+         return output
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         dtype = torch.float16 if self.fp16 else torch.float32
+         m = self.get_batch_multiplier()
+         if not self.controlnet:
+             return (
+                 torch.randn(
+                     m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device
+                 ),
+                 torch.tensor([1.0], dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+                 {
+                     "added_cond_kwargs": {
+                         "text_embeds": torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device),
+                         "time_ids": torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device),
+                     }
+                 },
+             )
+         else:
+             # sample, timestep, encoder_hidden_states, text_embeds, time_ids, controlnet_images, controlnet_scales
+             return (
+                 torch.randn(
+                     m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device
+                 ),
+                 torch.tensor([1.0], dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device),
+                 torch.randn(
+                     len(self.controlnet), m * batch_size, 3, image_height, image_width, dtype=dtype, device=self.device
+                 ),
+                 torch.randn(len(self.controlnet), dtype=dtype, device=self.device),
+             )
+
+
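The helpers above (get_sample_input, get_input_names, get_output_names, get_dynamic_axes) line up with what an ONNX export needs. A rough sketch of that wiring, assuming a UNetXL wrapper instance `unet_xl` with its PyTorch model loaded and a local output path (both hypothetical); note that torch.onnx.export treats a trailing dict in the args tuple as keyword arguments, which is how the added_cond_kwargs entry would be passed:

    import torch

    # Illustrative sketch only; `unet_xl` and the model directory are assumptions.
    model = unet_xl.load_model("./framework_models")
    sample_inputs = unet_xl.get_sample_input(batch_size=1, image_height=1024, image_width=1024)
    torch.onnx.export(
        model,
        sample_inputs,
        "unet_xl.onnx",
        input_names=unet_xl.get_input_names(),
        output_names=unet_xl.get_output_names(),
        dynamic_axes=unet_xl.get_dynamic_axes(),
        opset_version=17,
    )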
+ # VAE Decoder
+ class VAE(BaseModel):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         model,
+         device,
+         max_batch_size,
+         fp16: bool = False,
+         custom_fp16_vae: str | None = None,
+     ):
+         super().__init__(
+             pipeline_info,
+             model=model,
+             device=device,
+             fp16=fp16,
+             max_batch_size=max_batch_size,
+         )
+
+         # For SD XL, need custom trained fp16 model to speed up, and avoid overflow at the same time.
+         self.custom_fp16_vae = custom_fp16_vae
+
+     def load_model(self, framework_model_dir, subfolder: str = "vae_decoder"):
+         model_name = self.custom_fp16_vae or self.pipeline_info.name()
+
+         model_dir = os.path.join(framework_model_dir, model_name, subfolder)
+         if not os.path.exists(model_dir):
+             if self.custom_fp16_vae:
+                 vae = AutoencoderKL.from_pretrained(self.custom_fp16_vae, torch_dtype=torch.float16).to(self.device)
+             else:
+                 vae = AutoencoderKL.from_pretrained(
+                     self.pipeline_info.name(),
+                     subfolder="vae",
+                     use_safetensors=self.pipeline_info.use_safetensors(),
+                 ).to(self.device)
+             vae.save_pretrained(model_dir)
+         else:
+             print(f"Load {self.name} pytorch model from: {model_dir}")
+             if self.custom_fp16_vae:
+                 vae = AutoencoderKL.from_pretrained(model_dir, torch_dtype=torch.float16).to(self.device)
+             else:
+                 vae = AutoencoderKL.from_pretrained(model_dir).to(self.device)
+
+         vae.forward = vae.decode
+         return vae
+
+     def get_input_names(self):
+         return ["latent"]
+
+     def get_output_names(self):
+         return ["images"]
+
+     def get_dynamic_axes(self):
+         return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             _,
+             _,
+             _,
+             _,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+         return {
+             "latent": [
+                 (min_batch, 4, min_latent_height, min_latent_width),
+                 (batch_size, 4, latent_height, latent_width),
+                 (max_batch, 4, max_latent_height, max_latent_width),
+             ]
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return {
+             "latent": (batch_size, 4, latent_height, latent_width),
+             "images": (batch_size, 3, image_height, image_width),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         dtype = torch.float16 if self.fp16 else torch.float32
+         return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=dtype, device=self.device),)
+
+     def fp32_input_output_names(self) -> list[str]:
+         return []
+
+
+ def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, subfolder="tokenizer"):
+     tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder)
+
+     if not os.path.exists(tokenizer_dir):
+         model = CLIPTokenizer.from_pretrained(
+             pipeline_info.name(),
+             subfolder=subfolder,
+             use_safetensors=pipeline_info.is_xl(),
+         )
+         model.save_pretrained(tokenizer_dir)
+     else:
+         print(f"[I] Load tokenizer pytorch model from: {tokenizer_dir}")
+         model = CLIPTokenizer.from_pretrained(tokenizer_dir)
+     return model
+
+
+ class TorchVAEEncoder(torch.nn.Module):
+     def __init__(self, vae_encoder):
+         super().__init__()
+         self.vae_encoder = vae_encoder
+
+     def forward(self, x):
+         return self.vae_encoder.encode(x).latent_dist.sample()
+
+
+ class VAEEncoder(BaseModel):
+     def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size):
+         super().__init__(
+             pipeline_info,
+             model=model,
+             device=device,
+             max_batch_size=max_batch_size,
+         )
+
+     def load_model(self, framework_model_dir, subfolder="vae_encoder"):
+         vae = self.from_pretrained(AutoencoderKL, framework_model_dir, subfolder)
+         return TorchVAEEncoder(vae)
+
+     def get_input_names(self):
+         return ["images"]
+
+     def get_output_names(self):
+         return ["latent"]
+
+     def get_dynamic_axes(self):
+         return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}}
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         self.check_dims(batch_size, image_height, image_width)
+
+         (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             _,
+             _,
+             _,
+             _,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+
+         return {
+             "images": [
+                 (min_batch, 3, min_image_height, min_image_width),
+                 (batch_size, 3, image_height, image_width),
+                 (max_batch, 3, max_image_height, max_image_width),
+             ],
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return {
+             "images": (batch_size, 3, image_height, image_width),
+             "latent": (batch_size, 4, latent_height, latent_width),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)
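As the dynamic axes above indicate ("8H"/"8W" for images versus "H"/"W" for latents), the VAE encoder and decoder relate image and latent shapes by a factor of 8 per spatial dimension. A small illustrative check, assuming wrapper instances `vae_encoder` (VAEEncoder) and `vae_decoder` (VAE) built from the classes above:

    # Illustrative sketch only; the two wrapper instances are assumptions.
    enc = vae_encoder.get_shape_dict(batch_size=1, image_height=512, image_width=512)
    dec = vae_decoder.get_shape_dict(batch_size=1, image_height=512, image_width=512)
    assert enc["images"] == dec["images"] == (1, 3, 512, 512)
    assert enc["latent"] == dec["latent"] == (1, 4, 64, 64)  # 512 / 8 = 64 per spatial axis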