onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/diffusion_models.py
@@ -0,0 +1,1319 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# Modified from stable_diffusion_tensorrt_txt2img.py in diffusers and TensorRT demo diffusion,
+# which has the following license:
+#
+# Copyright 2023 The HuggingFace Inc. team.
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import tempfile
+from typing import Dict, List, Optional
+
+import onnx
+import onnx_graphsurgeon as gs
+import torch
+from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from onnx import GraphProto, ModelProto, shape_inference
+from ort_optimizer import OrtStableDiffusionOptimizer
+from polygraphy.backend.onnx.loader import fold_constants
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from onnxruntime.transformers.onnx_model import OnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class TrtOptimizer:
+    def __init__(self, onnx_graph):
+        self.graph = gs.import_onnx(onnx_graph)
+
+    def cleanup(self):
+        self.graph.cleanup().toposort()
+
+    def get_optimized_onnx_graph(self):
+        return gs.export_onnx(self.graph)
+
+    def select_outputs(self, keep, names=None):
+        self.graph.outputs = [self.graph.outputs[o] for o in keep]
+        if names:
+            for i, name in enumerate(names):
+                self.graph.outputs[i].name = name
+
+    def fold_constants(self):
+        onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
+        self.graph = gs.import_onnx(onnx_graph)
+
+    def infer_shapes(self):
+        onnx_graph = gs.export_onnx(self.graph)
+        if onnx_graph.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                input_onnx_path = os.path.join(temp_dir, "model.onnx")
+                onnx.save_model(
+                    onnx_graph,
+                    input_onnx_path,
+                    save_as_external_data=True,
+                    all_tensors_to_one_file=True,
+                    convert_attribute=False,
+                )
+                output_onnx_path = os.path.join(temp_dir, "model_with_shape.onnx")
+                onnx.shape_inference.infer_shapes_path(input_onnx_path, output_onnx_path)
+                onnx_graph = onnx.load(output_onnx_path)
+        else:
+            onnx_graph = shape_inference.infer_shapes(onnx_graph)
+
+        self.graph = gs.import_onnx(onnx_graph)
+
+
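TrtOptimizer above applies a fixed pass order for the TensorRT path: cleanup, constant folding, shape inference, cleanup. A minimal standalone sketch of the same round trip, assuming a model.onnx that fits under the 2 GB protobuf limit (otherwise the external-data branch above applies) and that onnx, onnx-graphsurgeon, and polygraphy are installed:

```python
# Sketch of the TrtOptimizer pass order on a standalone model (path is illustrative).
import onnx
import onnx_graphsurgeon as gs
from onnx import shape_inference
from polygraphy.backend.onnx.loader import fold_constants

graph = gs.import_onnx(onnx.load("model.onnx"))
graph.cleanup().toposort()                      # drop dead nodes, topological order
folded = fold_constants(gs.export_onnx(graph))  # precompute constant subgraphs
graph = gs.import_onnx(shape_inference.infer_shapes(folded))  # annotate shapes
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "model_opt.onnx")
```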
+class PipelineInfo:
+    def __init__(
+        self,
+        version: str,
+        is_inpaint: bool = False,
+        is_refiner: bool = False,
+        use_vae=True,  # TODO: this is coupled with the output type of the pipeline
+        min_image_size=256,
+        max_image_size=1024,
+        use_fp16_vae=True,
+        use_lcm=False,
+        do_classifier_free_guidance=True,
+        controlnet=None,
+        lora_weights=None,
+        lora_scale=1.0,
+    ):
+        self.version = version
+        self._is_inpaint = is_inpaint
+        self._is_refiner = is_refiner
+        self._use_vae = use_vae
+        self._min_image_size = min_image_size
+        self._max_image_size = max_image_size
+        self._use_fp16_vae = use_fp16_vae
+        self._use_lcm = use_lcm
+        self.do_classifier_free_guidance = do_classifier_free_guidance and not use_lcm
+        self.controlnet = controlnet  # A list of ControlNet types
+        self.lora_weights = lora_weights
+        self.lora_scale = lora_scale
+
+        if is_refiner:
+            assert not use_lcm
+            assert self.is_xl()
+
+    def is_inpaint(self) -> bool:
+        return self._is_inpaint
+
+    def is_xl(self) -> bool:
+        return "xl" in self.version
+
+    def is_xl_turbo(self) -> bool:
+        return self.version == "xl-turbo"
+
+    def is_xl_base(self) -> bool:
+        return self.version == "xl-1.0" and not self._is_refiner
+
+    def is_xl_base_or_turbo(self) -> bool:
+        return self.is_xl_base() or self.is_xl_turbo()
+
+    def is_xl_refiner(self) -> bool:
+        return self.version == "xl-1.0" and self._is_refiner
+
+    def use_safetensors(self) -> bool:
+        return self.is_xl() or self.version in ["sd-turbo"]
+
+    def stages(self) -> List[str]:
+        if self.is_xl_base_or_turbo():
+            return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else [])
+
+        if self.is_xl_refiner():
+            return ["clip2", "unetxl", "vae"]
+
+        return ["clip", "unet", "vae"]
+
+    def vae_scaling_factor(self) -> float:
+        return 0.13025 if self.is_xl() else 0.18215
+
+    def vae_torch_fallback(self) -> bool:
+        return self.is_xl() and not self._use_fp16_vae
+
+    def custom_fp16_vae(self) -> Optional[str]:
+        # For SD XL, use a VAE that is fine-tuned to run in fp16 precision without generating NaNs.
+        return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None
+
+    def custom_unet(self) -> Optional[str]:
+        return "latent-consistency/lcm-sdxl" if self._use_lcm and self.is_xl_base() else None
+
+    @staticmethod
+    def supported_versions(is_xl: bool):
+        return ["xl-1.0", "xl-turbo"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base", "sd-turbo"]
+
+    @staticmethod
+    def supported_models():
+        return {
+            "CompVis/stable-diffusion-v1-4": "1.4",
+            "runwayml/stable-diffusion-v1-5": "1.5",
+            "stabilityai/stable-diffusion-2-base": "2.0-base",
+            "stabilityai/stable-diffusion-2": "2.0",
+            "stabilityai/stable-diffusion-2-1": "2.1",
+            "stabilityai/stable-diffusion-2-1-base": "2.1",
+            "stabilityai/stable-diffusion-xl-base-1.0": "xl-1.0",
+            "stabilityai/stable-diffusion-xl-refiner-1.0": "xl-1.0",
+            "stabilityai/sdxl-turbo": "xl-turbo",
+            "stabilityai/sd-turbo": "sd-turbo",
+            # "runwayml/stable-diffusion-inpainting": "1.5",
+            # "stabilityai/stable-diffusion-2-inpainting": "2.0",
+        }
+
+    def name(self) -> str:
+        if self.version == "1.4":
+            if self.is_inpaint():
+                return "runwayml/stable-diffusion-inpainting"
+            else:
+                return "CompVis/stable-diffusion-v1-4"
+        elif self.version == "1.5":
+            if self.is_inpaint():
+                return "runwayml/stable-diffusion-inpainting"
+            else:
+                return "runwayml/stable-diffusion-v1-5"
+        elif self.version == "2.0-base":
+            if self.is_inpaint():
+                return "stabilityai/stable-diffusion-2-inpainting"
+            else:
+                return "stabilityai/stable-diffusion-2-base"
+        elif self.version == "2.0":
+            if self.is_inpaint():
+                return "stabilityai/stable-diffusion-2-inpainting"
+            else:
+                return "stabilityai/stable-diffusion-2"
+        elif self.version == "2.1":
+            return "stabilityai/stable-diffusion-2-1"
+        elif self.version == "2.1-base":
+            return "stabilityai/stable-diffusion-2-1-base"
+        elif self.version == "xl-1.0":
+            if self.is_xl_refiner():
+                return "stabilityai/stable-diffusion-xl-refiner-1.0"
+            else:
+                return "stabilityai/stable-diffusion-xl-base-1.0"
+        elif self.version == "xl-turbo":
+            return "stabilityai/sdxl-turbo"
+        elif self.version == "sd-turbo":
+            return "stabilityai/sd-turbo"
+
+        raise ValueError(f"Incorrect version {self.version}")
+
+    def short_name(self) -> str:
+        return self.name().split("/")[-1].replace("stable-diffusion", "sd")
+
+    def clip_embedding_dim(self):
+        # TODO: can we read this from the model config instead?
+        if self.version in ("1.4", "1.5"):
+            return 768
+        elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base", "sd-turbo"):
+            return 1024
+        elif self.is_xl_base_or_turbo():
+            return 768
+        else:
+            raise ValueError(f"Invalid version {self.version}")
+
+    def clipwithproj_embedding_dim(self):
+        if self.is_xl():
+            return 1280
+        else:
+            raise ValueError(f"Invalid version {self.version}")
+
+    def unet_embedding_dim(self):
+        if self.version in ("1.4", "1.5"):
+            return 768
+        elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base", "sd-turbo"):
+            return 1024
+        elif self.is_xl_base_or_turbo():
+            return 2048
+        elif self.is_xl_refiner():
+            return 1280
+        else:
+            raise ValueError(f"Invalid version {self.version}")
+
+    def min_image_size(self):
+        return self._min_image_size
+
+    def max_image_size(self):
+        return self._max_image_size
+
+    @staticmethod
+    def default_resolution(version: str) -> int:
+        if version == "xl-1.0":
+            return 1024
+        if version in ("2.0", "2.1"):
+            return 768
+        return 512
+
+    def default_image_size(self) -> int:
+        return PipelineInfo.default_resolution(self.version)
+
+    @staticmethod
+    def supported_controlnet(version="1.5"):
+        if version in ("xl-1.0", "xl-turbo"):
+            return {
+                "canny": "diffusers/controlnet-canny-sdxl-1.0",
+                "depth": "diffusers/controlnet-depth-sdxl-1.0",
+            }
+        elif version == "1.5":
+            return {
+                "canny": "lllyasviel/control_v11p_sd15_canny",
+                "depth": "lllyasviel/control_v11f1p_sd15_depth",
+                "openpose": "lllyasviel/control_v11p_sd15_openpose",
+                # "tile": "lllyasviel/control_v11f1e_sd15_tile",
+                # "lineart": "lllyasviel/control_v11p_sd15_lineart",
+                # "inpaint": "lllyasviel/control_v11p_sd15_inpaint",
+                # "softedge": "lllyasviel/control_v11p_sd15_softedge",
+                "mlsd": "lllyasviel/control_v11p_sd15_mlsd",
+                "scribble": "lllyasviel/control_v11p_sd15_scribble",
+                # "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
+                "normalbae": "lllyasviel/control_v11p_sd15_normalbae",
+                "seg": "lllyasviel/control_v11p_sd15_seg",
+                # "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
+                # "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
+            }
+        return None
+
+    def controlnet_name(self):
+        """Return a list of ControlNet model names."""
+        if not self.controlnet:
+            return None
+        controlnet_map = PipelineInfo.supported_controlnet(self.version)
+        if controlnet_map is None:
+            return None
+        return [controlnet_map[controlnet] for controlnet in self.controlnet]
+
+
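PipelineInfo centralizes the per-version constants (model names, embedding dims, VAE scaling, default resolution) that the model classes below consume. A short usage sketch, assuming the module import path matches the file location shown in this diff:

```python
# Hedged usage sketch; the import path assumes the wheel layout listed above.
from onnxruntime.transformers.models.stable_diffusion.diffusion_models import PipelineInfo

info = PipelineInfo("xl-1.0")
print(info.name())                # stabilityai/stable-diffusion-xl-base-1.0
print(info.stages())              # ['clip', 'clip2', 'unetxl', 'vae']
print(info.vae_scaling_factor())  # 0.13025 (SD XL); 0.18215 otherwise
print(info.default_image_size())  # 1024
```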
+class BaseModel:
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        model,
+        device,
+        fp16: bool = False,
+        max_batch_size: int = 16,
+        embedding_dim: int = 768,
+        text_maxlen: int = 77,
+    ):
+        self.name = self.__class__.__name__
+
+        self.pipeline_info = pipeline_info
+
+        self.model = model
+        self.fp16 = fp16
+        self.device = device
+
+        self.min_batch = 1
+        self.max_batch = max_batch_size
+        self.min_image_shape = pipeline_info.min_image_size()
+        self.max_image_shape = pipeline_info.max_image_size()
+        self.min_latent_shape = self.min_image_shape // 8
+        self.max_latent_shape = self.max_image_shape // 8
+
+        self.embedding_dim = embedding_dim
+        self.text_maxlen = text_maxlen
+
+    def get_batch_multiplier(self):
+        return 2 if self.pipeline_info.do_classifier_free_guidance else 1
+
+    def get_ort_optimizer(self):
+        model_name_to_model_type = {
+            "CLIP": "clip",
+            "UNet": "unet",
+            "VAE": "vae",
+            "UNetXL": "unet",
+            "CLIPWithProj": "clip",
+        }
+        model_type = model_name_to_model_type[self.name]
+        return OrtStableDiffusionOptimizer(model_type)
+
+    def get_model(self):
+        return self.model
+
+    def from_pretrained(self, model_class, framework_model_dir, subfolder=None, model_name=None, **kwargs):
+        if model_name is None:
+            model_name = self.pipeline_info.name()
+
+        if subfolder:
+            model_dir = os.path.join(framework_model_dir, model_name, subfolder)
+        else:
+            model_dir = os.path.join(framework_model_dir, model_name)
+
+        if not os.path.exists(model_dir):
+            model = model_class.from_pretrained(
+                model_name,
+                subfolder=subfolder,
+                use_safetensors=self.pipeline_info.use_safetensors(),
+                **kwargs,
+            ).to(self.device)
+            model.save_pretrained(model_dir)
+        else:
+            print(f"Load {self.name} pytorch model from: {model_dir}")
+
+            model = model_class.from_pretrained(model_dir).to(self.device)
+        return model
+
+    def load_model(self, framework_model_dir: str, subfolder: str):
+        pass
+
+    def get_input_names(self) -> List[str]:
+        pass
+
+    def get_output_names(self) -> List[str]:
+        pass
+
+    def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]:
+        pass
+
+    def get_sample_input(self, batch_size, image_height, image_width) -> tuple:
+        pass
+
+    def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+        """For TensorRT EP"""
+        (
+            min_batch,
+            max_batch,
+            min_image_height,
+            max_image_height,
+            min_image_width,
+            max_image_width,
+            _,
+            _,
+            _,
+            _,
+        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+
+        if (self.name in ["UNet", "UNetXL"]) and (self.get_batch_multiplier() == 1):
+            profile_id = f"_b1_{batch_size}" if static_batch else f"_b1_{min_batch}_{max_batch}"
+        else:
+            profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"
+
+        if self.name != "CLIP":
+            if static_image_shape:
+                profile_id += f"_h_{image_height}_w_{image_width}"
+            else:
+                profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"
+
+        return profile_id
+
+    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+        """For TensorRT"""
+
+    def get_shape_dict(self, batch_size, image_height, image_width):
+        pass
+
+    def fp32_input_output_names(self) -> List[str]:
+        """For the CUDA EP, we export the ONNX model in FP32 first, then convert it to a mixed precision model.
+        This is a list of input or output names that are kept as float32 in the optimized model.
+        """
+        return []
+
+    def optimize_ort(
+        self,
+        input_onnx_path,
+        optimized_onnx_path,
+        to_fp16=True,
+        fp32_op_list=None,
+        optimize_by_ort=True,
+        optimize_by_fusion=True,
+        tmp_dir=None,
+    ):
+        optimizer = self.get_ort_optimizer()
+        optimizer.optimize(
+            input_onnx_path,
+            optimized_onnx_path,
+            float16=to_fp16,
+            keep_io_types=self.fp32_input_output_names(),
+            fp32_op_list=fp32_op_list,
+            optimize_by_ort=optimize_by_ort,
+            optimize_by_fusion=optimize_by_fusion,
+            tmp_dir=tmp_dir,
+        )
+
+    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
+        onnx_graph = onnx.load(input_onnx_path)
+        opt = TrtOptimizer(onnx_graph)
+        opt.cleanup()
+        opt.fold_constants()
+        opt.infer_shapes()
+        opt.cleanup()
+        onnx_opt_graph = opt.get_optimized_onnx_graph()
+
+        if onnx_opt_graph.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
+            onnx.save_model(
+                onnx_opt_graph,
+                optimized_onnx_path,
+                save_as_external_data=True,
+                all_tensors_to_one_file=True,
+                convert_attribute=False,
+            )
+        else:
+            onnx.save(onnx_opt_graph, optimized_onnx_path)
+
+    def check_dims(self, batch_size, image_height, image_width):
+        assert batch_size >= self.min_batch and batch_size <= self.max_batch
+        assert image_height % 8 == 0 or image_width % 8 == 0
+        latent_height = image_height // 8
+        latent_width = image_width // 8
+        assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape
+        assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape
+        return (latent_height, latent_width)
+
+    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+        min_batch = batch_size if static_batch else self.min_batch
+        max_batch = batch_size if static_batch else self.max_batch
+        latent_height = image_height // 8
+        latent_width = image_width // 8
+        min_image_height = image_height if static_image_shape else self.min_image_shape
+        max_image_height = image_height if static_image_shape else self.max_image_shape
+        min_image_width = image_width if static_image_shape else self.min_image_shape
+        max_image_width = image_width if static_image_shape else self.max_image_shape
+        min_latent_height = latent_height if static_image_shape else self.min_latent_shape
+        max_latent_height = latent_height if static_image_shape else self.max_latent_shape
+        min_latent_width = latent_width if static_image_shape else self.min_latent_shape
+        max_latent_width = latent_width if static_image_shape else self.max_latent_shape
+        return (
+            min_batch,
+            max_batch,
+            min_image_height,
+            max_image_height,
+            min_image_width,
+            max_image_width,
+            min_latent_height,
+            max_latent_height,
+            min_latent_width,
+            max_latent_width,
+        )
+
+
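Two conventions in BaseModel are worth calling out: latent tensors are 8x smaller than the image (the VAE downsampling factor), and get_minmax_dims returns (min, max) ranges that collapse to the requested value when the batch or image shape is static. A self-contained sketch of that arithmetic:

```python
# Self-contained sketch of BaseModel's dimension bookkeeping (not the class itself).
def minmax_dims(batch, height, width, static_batch, static_shape,
                min_batch=1, max_batch=16, min_image=256, max_image=1024):
    lo_b = batch if static_batch else min_batch
    hi_b = batch if static_batch else max_batch
    lo_h = height if static_shape else min_image
    hi_h = height if static_shape else max_image
    return lo_b, hi_b, lo_h, hi_h, height // 8, width // 8  # latents are image // 8

print(minmax_dims(1, 512, 512, static_batch=False, static_shape=True))
# (1, 16, 512, 512, 64, 64)
```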
+class CLIP(BaseModel):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        model,
+        device,
+        max_batch_size,
+        embedding_dim: int = 0,
+        clip_skip=0,
+    ):
+        super().__init__(
+            pipeline_info,
+            model=model,
+            device=device,
+            max_batch_size=max_batch_size,
+            embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(),
+        )
+        self.output_hidden_state = pipeline_info.is_xl()
+
+        # See https://github.com/huggingface/diffusers/pull/5057 for more information about clip_skip.
+        # clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings.
+        self.clip_skip = clip_skip
+
+    def get_input_names(self):
+        return ["input_ids"]
+
+    def get_output_names(self):
+        # The exported onnx model has no hidden_state. For SD-XL, we will add hidden_state to the optimized onnx model.
+        return ["text_embeddings"]
+
+    def get_dynamic_axes(self):
+        return {"input_ids": {0: "B", 1: "S"}, "text_embeddings": {0: "B", 1: "S"}}
+
+    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+        self.check_dims(batch_size, image_height, image_width)
+        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
+            batch_size, image_height, image_width, static_batch, static_image_shape
+        )
+        return {
+            "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
+        }
+
+    def get_shape_dict(self, batch_size, image_height, image_width):
+        self.check_dims(batch_size, image_height, image_width)
+        output = {
+            "input_ids": (batch_size, self.text_maxlen),
+            "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
+        }
+
+        if self.output_hidden_state:
+            output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)
+
+        return output
+
+    def get_sample_input(self, batch_size, image_height, image_width):
+        self.check_dims(batch_size, image_height, image_width)
+        return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),)
+
+    def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, use_external_data_format=False):
+        graph: GraphProto = model.graph
+        hidden_layers = -1
+        for i in range(len(graph.node)):
+            for j in range(len(graph.node[i].output)):
+                name = graph.node[i].output[j]
+                if "layers" in name:
+                    hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers)
+
+        assert self.clip_skip >= 0 and self.clip_skip < hidden_layers
+
+        node_output_name = f"/text_model/encoder/layers.{hidden_layers - 1 - self.clip_skip}/Add_1_output_0"
+
+        # Search for the name in the outputs of all nodes.
+        found = False
+        for i in range(len(graph.node)):
+            for j in range(len(graph.node[i].output)):
+                if graph.node[i].output[j] == node_output_name:
+                    found = True
+                    break
+            if found:
+                break
+        if not found:
+            raise RuntimeError("Failed to find hidden_states graph output in clip")
+
+        # Insert a Cast (fp32 -> fp16) node so that hidden_states has the same data type as the first graph output.
+        graph_output_name = "hidden_states"
+        cast_node = onnx.helper.make_node("Cast", inputs=[node_output_name], outputs=[graph_output_name])
+        cast_node.attribute.extend([onnx.helper.make_attribute("to", graph.output[0].type.tensor_type.elem_type)])
+
+        hidden_state = graph.output.add()
+        hidden_state.CopyFrom(
+            onnx.helper.make_tensor_value_info(
+                graph_output_name,
+                graph.output[0].type.tensor_type.elem_type,
+                ["B", "S", self.embedding_dim],
+            )
+        )
+
+        onnx_model = OnnxModel(model)
+        onnx_model.add_node(cast_node)
+        onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
+
+    def optimize_ort(
+        self,
+        input_onnx_path,
+        optimized_onnx_path,
+        to_fp16=True,
+        fp32_op_list=None,
+        optimize_by_ort=True,
+        optimize_by_fusion=True,
+        tmp_dir=None,
+    ):
+        optimizer = self.get_ort_optimizer()
+
+        if not self.output_hidden_state:
+            optimizer.optimize(
+                input_onnx_path,
+                optimized_onnx_path,
+                float16=to_fp16,
+                keep_io_types=[],
+                fp32_op_list=fp32_op_list,
+                keep_outputs=["text_embeddings"],
+                optimize_by_ort=optimize_by_ort,
+                optimize_by_fusion=optimize_by_fusion,
+                tmp_dir=tmp_dir,
+            )
+        elif optimize_by_fusion:
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                # Save to a temporary file so that we can load it with Onnx Runtime.
+                logger.info("Saving a temporary model to add hidden_states to graph output ...")
+                tmp_model_path = os.path.join(tmp_dir, "model.onnx")
+
+                model = onnx.load(input_onnx_path)
+                self.add_hidden_states_graph_output(model, tmp_model_path, use_external_data_format=True)
+                optimizer.optimize(
+                    tmp_model_path,
+                    optimized_onnx_path,
+                    float16=to_fp16,
+                    keep_io_types=[],
+                    fp32_op_list=fp32_op_list,
+                    keep_outputs=["text_embeddings", "hidden_states"],
+                    optimize_by_ort=optimize_by_ort,
+                    optimize_by_fusion=optimize_by_fusion,
+                    tmp_dir=tmp_dir,
+                )
+        else:  # The input is an optimized model; there is no need to add hidden states.
+            optimizer.optimize(
+                input_onnx_path,
+                optimized_onnx_path,
+                float16=to_fp16,
+                keep_io_types=[],
+                fp32_op_list=fp32_op_list,
+                keep_outputs=["text_embeddings", "hidden_states"],
+                optimize_by_ort=optimize_by_ort,
+                optimize_by_fusion=optimize_by_fusion,
+                tmp_dir=tmp_dir,
+            )
+
+    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
+        onnx_graph = onnx.load(input_onnx_path)
+        opt = TrtOptimizer(onnx_graph)
+        opt.select_outputs([0])  # delete graph output#1
+        opt.cleanup()
+        opt.fold_constants()
+        opt.infer_shapes()
+        opt.select_outputs([0], names=["text_embeddings"])  # rename network output
+        opt.cleanup()
+        onnx_opt_graph = opt.get_optimized_onnx_graph()
+        if self.output_hidden_state:
+            self.add_hidden_states_graph_output(onnx_opt_graph, optimized_onnx_path)
+        else:
+            onnx.save(onnx_opt_graph, optimized_onnx_path)
+
+    def load_model(self, framework_model_dir, subfolder="text_encoder"):
+        return self.from_pretrained(CLIPTextModel, framework_model_dir, subfolder)
+
+
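add_hidden_states_graph_output above exposes an internal encoder tensor as a new graph output by appending a Cast node whose element type matches the first existing output. A generic sketch of that pattern with the onnx helper API (the function and tensor names here are illustrative, not from this file):

```python
# Generic sketch: expose an internal tensor as a new graph output via a Cast node.
import onnx
from onnx import helper

def expose_as_output(model: onnx.ModelProto, tensor_name: str, output_name: str) -> onnx.ModelProto:
    # Match the element type of the first existing graph output.
    elem_type = model.graph.output[0].type.tensor_type.elem_type
    cast = helper.make_node("Cast", inputs=[tensor_name], outputs=[output_name], to=elem_type)
    model.graph.node.append(cast)
    model.graph.output.append(
        helper.make_tensor_value_info(output_name, elem_type, ["B", "S", None])
    )
    return model
```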
+class CLIPWithProj(CLIP):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        model,
+        device,
+        max_batch_size=16,
+        clip_skip=0,
+    ):
+        super().__init__(
+            pipeline_info,
+            model,
+            device=device,
+            max_batch_size=max_batch_size,
+            embedding_dim=pipeline_info.clipwithproj_embedding_dim(),
+            clip_skip=clip_skip,
+        )
+
+    def load_model(self, framework_model_dir, subfolder="text_encoder_2"):
+        return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, subfolder)
+
+    def get_shape_dict(self, batch_size, image_height, image_width):
+        self.check_dims(batch_size, image_height, image_width)
+        output = {
+            "input_ids": (batch_size, self.text_maxlen),
+            "text_embeddings": (batch_size, self.embedding_dim),
+        }
+
+        if self.output_hidden_state:
+            output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)
+
+        return output
+
+
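Note the output-shape difference from CLIP: CLIPWithProj returns the pooled, projected embedding, so text_embeddings drops the sequence axis. Illustrative shapes for SD XL at batch 2, derived from the two get_shape_dict implementations above:

```python
# Illustrative get_shape_dict results for SD XL (batch=2, text_maxlen=77).
clip_shapes = {"text_embeddings": (2, 77, 768), "hidden_states": (2, 77, 768)}
clip_with_proj_shapes = {"text_embeddings": (2, 1280), "hidden_states": (2, 77, 1280)}
```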
+class UNet2DConditionControlNetModel(torch.nn.Module):
+    def __init__(self, unet, controlnets: ControlNetModel):
+        super().__init__()
+        self.unet = unet
+        self.controlnets = controlnets
+
+    def forward(self, sample, timestep, encoder_hidden_states, controlnet_images, controlnet_scales):
+        for i, (controlnet_image, conditioning_scale, controlnet) in enumerate(
+            zip(controlnet_images, controlnet_scales, self.controlnets)
+        ):
+            down_samples, mid_sample = controlnet(
+                sample,
+                timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                controlnet_cond=controlnet_image,
+                return_dict=False,
+            )
+
+            down_samples = [down_sample * conditioning_scale for down_sample in down_samples]
+            mid_sample *= conditioning_scale
+
+            # merge samples
+            if i == 0:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        noise_pred = self.unet(
+            sample,
+            timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            down_block_additional_residuals=down_block_res_samples,
+            mid_block_additional_residual=mid_block_res_sample,
+        )
+        return noise_pred[0]
+
+
+# Modified from convert_stable_diffusion_controlnet_to_onnx.py in diffusers
+class UNet2DConditionXLControlNetModel(torch.nn.Module):
+    def __init__(self, unet, controlnets: ControlNetModel):
+        super().__init__()
+        self.unet = unet
+        self.controlnets = controlnets
+
+    def forward(
+        self,
+        sample,
+        timestep,
+        encoder_hidden_states,
+        text_embeds,
+        time_ids,
+        controlnet_images,
+        controlnet_scales,
+    ):
+        added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
+        for i, (controlnet_image, conditioning_scale, controlnet) in enumerate(
+            zip(controlnet_images, controlnet_scales, self.controlnets)
+        ):
+            down_samples, mid_sample = controlnet(
+                sample,
+                timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                controlnet_cond=controlnet_image,
+                conditioning_scale=conditioning_scale,
+                added_cond_kwargs=added_cond_kwargs,
+                return_dict=False,
+            )
+
+            # merge samples
+            if i == 0:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        noise_pred = self.unet(
+            sample,
+            timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            down_block_additional_residuals=down_block_res_samples,
+            mid_block_additional_residual=mid_block_res_sample,
+            added_cond_kwargs=added_cond_kwargs,
+            return_dict=False,
+        )
+        return noise_pred[0]
+
+
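Both wrappers merge multi-ControlNet outputs the same way: each net's residuals are scaled by its conditioning scale, then summed elementwise across nets before being fed to the UNet. A small numeric sketch of that merge:

```python
# Numeric sketch of the residual merge: scale per net, then sum across nets.
import torch

down_per_net = [[torch.ones(2, 4)], [torch.full((2, 4), 3.0)]]  # one block, two nets
scales = [0.5, 1.0]

merged = None
for residuals, scale in zip(down_per_net, scales):
    scaled = [r * scale for r in residuals]
    merged = scaled if merged is None else [a + b for a, b in zip(merged, scaled)]

print(merged[0])  # 0.5*1 + 1.0*3 = 3.5 everywhere
```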
+class UNet(BaseModel):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        model,
+        device,
+        fp16=False,  # used by TRT
+        max_batch_size=16,
+        text_maxlen=77,
+        unet_dim=4,
+    ):
+        super().__init__(
+            pipeline_info,
+            model=model,
+            device=device,
+            fp16=fp16,
+            max_batch_size=max_batch_size,
+            embedding_dim=pipeline_info.unet_embedding_dim(),
+            text_maxlen=text_maxlen,
+        )
+
+        self.unet_dim = unet_dim
+        self.controlnet = pipeline_info.controlnet_name()
+
+    def load_model(self, framework_model_dir, subfolder="unet"):
+        options = {"variant": "fp16", "torch_dtype": torch.float16}
+
+        model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, subfolder, **options)
+
+        if self.controlnet:
+            controlnet_list = []
+            for name in self.controlnet:
+                controlnet = self.from_pretrained(
+                    ControlNetModel,
+                    framework_model_dir,
+                    subfolder=None,
+                    model_name=name,
+                    torch_dtype=torch.float16,
+                )
+                controlnet_list.append(controlnet)
+
+            model = UNet2DConditionControlNetModel(model, torch.nn.ModuleList(controlnet_list))
+
+        if not self.fp16:
+            model = model.to(torch.float32)
+
+        return model
+
+    def get_input_names(self):
+        if not self.controlnet:
+            return ["sample", "timestep", "encoder_hidden_states"]
+        else:
+            return ["sample", "timestep", "encoder_hidden_states", "controlnet_images", "controlnet_scales"]
+
+    def get_output_names(self):
+        return ["latent"]
+
+    def get_dynamic_axes(self):
+        b = "2B" if self.get_batch_multiplier() == 2 else "B"
+        output = {
+            "sample": {0: b, 2: "H", 3: "W"},
+            "encoder_hidden_states": {0: b},
+            "latent": {0: b, 2: "H", 3: "W"},
+        }
+        if self.controlnet:
+            output.update(
+                {
+                    "controlnet_images": {1: b, 3: "8H", 4: "8W"},
+                }
+            )
+        return output
+
+    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+        (
+            min_batch,
+            max_batch,
+            min_image_height,
+            max_image_height,
+            min_image_width,
+            max_image_width,
+            min_latent_height,
+            max_latent_height,
+            min_latent_width,
+            max_latent_width,
+        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+        m = self.get_batch_multiplier()
+        output = {
+            "sample": [
+                (m * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+                (m * batch_size, self.unet_dim, latent_height, latent_width),
+                (m * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+            ],
+            "encoder_hidden_states": [
+                (m * min_batch, self.text_maxlen, self.embedding_dim),
+                (m * batch_size, self.text_maxlen, self.embedding_dim),
+                (m * max_batch, self.text_maxlen, self.embedding_dim),
+            ],
+        }
+
+        if self.controlnet:
+            output.update(
+                {
+                    "controlnet_images": [
+                        (len(self.controlnet), m * min_batch, 3, min_image_height, min_image_width),
+                        (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                        (len(self.controlnet), m * max_batch, 3, max_image_height, max_image_width),
+                    ]
+                }
+            )
+        return output
+
+    def get_shape_dict(self, batch_size, image_height, image_width):
+        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+        m = self.get_batch_multiplier()
+        output = {
+            "sample": (m * batch_size, self.unet_dim, latent_height, latent_width),
+            "timestep": [1],
+            "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim),
+            "latent": (m * batch_size, 4, latent_height, latent_width),
+        }
+
+        if self.controlnet:
+            output.update(
+                {
+                    "controlnet_images": (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                    "controlnet_scales": [len(self.controlnet)],
+                }
+            )
+        return output
+
+    def get_sample_input(self, batch_size, image_height, image_width):
+        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+        dtype = torch.float16 if self.fp16 else torch.float32
+        m = self.get_batch_multiplier()
+        output = (
+            torch.randn(m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device),
+            torch.tensor([1.0], dtype=dtype, device=self.device),
+            torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+        )
+
+        if self.controlnet:
+            output = (
+                *output,
+                torch.randn(
+                    len(self.controlnet), m * batch_size, 3, image_height, image_width, dtype=dtype, device=self.device
+                ),
+                torch.randn(len(self.controlnet), dtype=dtype, device=self.device),
+            )
+        return output
+
+
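For the TensorRT EP, get_input_profile returns a (min, opt, max) shape triple per input; with classifier-free guidance the batch axis carries the 2x multiplier. The values below show what UNet.get_input_profile would produce for SD 1.5 at batch 1, 512x512, dynamic batch 1..16, static image shape (derived from the code above, not captured output):

```python
# Expected profile for SD 1.5: batch axis = 2 * batch (classifier-free guidance),
# latents 512 // 8 = 64, embedding_dim 768, text_maxlen 77.
profile = {
    "sample": [(2, 4, 64, 64), (2, 4, 64, 64), (32, 4, 64, 64)],
    "encoder_hidden_states": [(2, 77, 768), (2, 77, 768), (32, 77, 768)],
}
for name, (lo, opt, hi) in profile.items():
    print(f"{name}: min={lo} opt={opt} max={hi}")
```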
961
+ class UNetXL(BaseModel):
962
+ def __init__(
963
+ self,
964
+ pipeline_info: PipelineInfo,
965
+ model,
966
+ device,
967
+ fp16=False, # used by TRT
968
+ max_batch_size=16,
969
+ text_maxlen=77,
970
+ unet_dim=4,
971
+ time_dim=6,
972
+ ):
973
+ super().__init__(
974
+ pipeline_info,
975
+ model,
976
+ device=device,
977
+ fp16=fp16,
978
+ max_batch_size=max_batch_size,
979
+ embedding_dim=pipeline_info.unet_embedding_dim(),
980
+ text_maxlen=text_maxlen,
981
+ )
982
+ self.unet_dim = unet_dim
983
+ self.time_dim = time_dim
984
+
985
+ self.custom_unet = pipeline_info.custom_unet()
986
+ self.controlnet = pipeline_info.controlnet_name()
987
+
988
+ def load_model(self, framework_model_dir, subfolder="unet", always_download_fp16=True):
989
+ options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 or always_download_fp16 else {}
990
+
991
+ if self.custom_unet:
992
+ model_dir = os.path.join(framework_model_dir, self.custom_unet, subfolder)
993
+ if not os.path.exists(model_dir):
994
+ unet = UNet2DConditionModel.from_pretrained(self.custom_unet, **options)
995
+ unet.save_pretrained(model_dir)
996
+ else:
997
+ unet = UNet2DConditionModel.from_pretrained(model_dir, **options)
998
+ model = unet.to(self.device)
999
+ else:
1000
+ model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, subfolder, **options)
1001
+
1002
+ if always_download_fp16 and not self.fp16:
1003
+ model = model.to(torch.float32)
1004
+
1005
+ if self.controlnet:
1006
+ cnet_model_opts = {"torch_dtype": torch.float16} if self.fp16 or always_download_fp16 else {}
1007
+ controlnets = torch.nn.ModuleList(
1008
+ [ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) for path in self.controlnet]
1009
+ )
1010
+ model = UNet2DConditionXLControlNetModel(model, controlnets)
1011
+
1012
+ if always_download_fp16 and not self.fp16:
1013
+ model = model.to(torch.float32)
1014
+
1015
+ return model
+
+     def get_input_names(self):
+         input_names = ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"]
+         if self.controlnet:
+             return [*input_names, "controlnet_images", "controlnet_scales"]
+         return input_names
+
+     def get_output_names(self):
+         return ["latent"]
+
+     def get_dynamic_axes(self):
+         b = "2B" if self.get_batch_multiplier() == 2 else "B"
+         output = {
+             "sample": {0: b, 2: "H", 3: "W"},
+             "encoder_hidden_states": {0: b},
+             "text_embeds": {0: b},
+             "time_ids": {0: b},
+             "latent": {0: b, 2: "H", 3: "W"},
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": {1: b, 3: "8H", 4: "8W"},
+                 }
+             )
+         return output
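+
+     # Axis labels: "2B" marks batch axes doubled when positive and negative prompts are
+     # batched together for classifier-free guidance, while "8H"/"8W" mark pixel-space
+     # axes that are 8x the latent "H"/"W", since the VAE downsamples by a factor of 8.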
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+         m = self.get_batch_multiplier()
+         output = {
+             "sample": [
+                 (m * min_batch, self.unet_dim, min_latent_height, min_latent_width),
+                 (m * batch_size, self.unet_dim, latent_height, latent_width),
+                 (m * max_batch, self.unet_dim, max_latent_height, max_latent_width),
+             ],
+             "encoder_hidden_states": [
+                 (m * min_batch, self.text_maxlen, self.embedding_dim),
+                 (m * batch_size, self.text_maxlen, self.embedding_dim),
+                 (m * max_batch, self.text_maxlen, self.embedding_dim),
+             ],
+             "text_embeds": [(m * min_batch, 1280), (m * batch_size, 1280), (m * max_batch, 1280)],
+             "time_ids": [
+                 (m * min_batch, self.time_dim),
+                 (m * batch_size, self.time_dim),
+                 (m * max_batch, self.time_dim),
+             ],
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": [
+                         (len(self.controlnet), m * min_batch, 3, min_image_height, min_image_width),
+                         (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                         (len(self.controlnet), m * max_batch, 3, max_image_height, max_image_width),
+                     ],
+                 }
+             )
+         return output
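+
+     # Each profile entry is a [min, opt, max] shape triple, the format TensorRT expects
+     # when building an engine with dynamic shapes; "opt" is the shape the engine is
+     # tuned for, and inputs outside [min, max] are rejected at run time.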
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         m = self.get_batch_multiplier()
+         output = {
+             "sample": (m * batch_size, self.unet_dim, latent_height, latent_width),
+             "timestep": (1,),
+             "encoder_hidden_states": (m * batch_size, self.text_maxlen, self.embedding_dim),
+             "text_embeds": (m * batch_size, 1280),
+             "time_ids": (m * batch_size, self.time_dim),
+             "latent": (m * batch_size, 4, latent_height, latent_width),
+         }
+
+         if self.controlnet:
+             output.update(
+                 {
+                     "controlnet_images": (len(self.controlnet), m * batch_size, 3, image_height, image_width),
+                     "controlnet_scales": [len(self.controlnet)],
+                 }
+             )
+         return output
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         dtype = torch.float16 if self.fp16 else torch.float32
+         m = self.get_batch_multiplier()
+         if not self.controlnet:
+             return (
+                 torch.randn(
+                     m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device
+                 ),
+                 torch.tensor([1.0], dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+                 {
+                     "added_cond_kwargs": {
+                         "text_embeds": torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device),
+                         "time_ids": torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device),
+                     }
+                 },
+             )
+         else:
+             # sample, timestep, encoder_hidden_states, text_embeds, time_ids, controlnet_images, controlnet_scales
+             return (
+                 torch.randn(
+                     m * batch_size, self.unet_dim, latent_height, latent_width, dtype=dtype, device=self.device
+                 ),
+                 torch.tensor([1.0], dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, 1280, dtype=dtype, device=self.device),
+                 torch.randn(m * batch_size, self.time_dim, dtype=dtype, device=self.device),
+                 torch.randn(
+                     len(self.controlnet), m * batch_size, 3, image_height, image_width, dtype=dtype, device=self.device
+                 ),
+                 torch.randn(len(self.controlnet), dtype=dtype, device=self.device),
+             )
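+
+     # In the non-ControlNet branch, the trailing dict supplies added_cond_kwargs, matching
+     # how diffusers' UNet2DConditionModel.forward receives SDXL's text_embeds/time_ids;
+     # the ControlNet-wrapped model takes the same conditioning as positional tensors instead.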
+
+
+ # VAE Decoder
+ class VAE(BaseModel):
+     def __init__(
+         self,
+         pipeline_info: PipelineInfo,
+         model,
+         device,
+         max_batch_size,
+         fp16: bool = False,
+         custom_fp16_vae: Optional[str] = None,
+     ):
+         super().__init__(
+             pipeline_info,
+             model=model,
+             device=device,
+             fp16=fp16,
+             max_batch_size=max_batch_size,
+         )
+
+         # For SD XL, a custom fp16-trained VAE is needed to get the fp16 speedup while avoiding overflow.
+         self.custom_fp16_vae = custom_fp16_vae
+
+     def load_model(self, framework_model_dir, subfolder: str = "vae_decoder"):
+         model_name = self.custom_fp16_vae or self.pipeline_info.name()
+
+         model_dir = os.path.join(framework_model_dir, model_name, subfolder)
+         if not os.path.exists(model_dir):
+             if self.custom_fp16_vae:
+                 vae = AutoencoderKL.from_pretrained(self.custom_fp16_vae, torch_dtype=torch.float16).to(self.device)
+             else:
+                 vae = AutoencoderKL.from_pretrained(
+                     self.pipeline_info.name(),
+                     subfolder="vae",
+                     use_safetensors=self.pipeline_info.use_safetensors(),
+                 ).to(self.device)
+             vae.save_pretrained(model_dir)
+         else:
+             print(f"Load {self.name} pytorch model from: {model_dir}")
+             if self.custom_fp16_vae:
+                 vae = AutoencoderKL.from_pretrained(model_dir, torch_dtype=torch.float16).to(self.device)
+             else:
+                 vae = AutoencoderKL.from_pretrained(model_dir).to(self.device)
+
+         vae.forward = vae.decode
+         return vae
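+
+     # Re-pointing forward at decode means a later torch.onnx.export traces only the
+     # latent-to-image decoder path of the autoencoder, which is all the pipeline needs
+     # after denoising.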
+
+     def get_input_names(self):
+         return ["latent"]
+
+     def get_output_names(self):
+         return ["images"]
+
+     def get_dynamic_axes(self):
+         return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}}
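+
+     # "images" is 8x the latent spatial size, mirroring the VAE's 8x upsampling: e.g. a
+     # (1, 4, 64, 64) latent decodes to a (1, 3, 512, 512) image.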
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         (
+             min_batch,
+             max_batch,
+             _,
+             _,
+             _,
+             _,
+             min_latent_height,
+             max_latent_height,
+             min_latent_width,
+             max_latent_width,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+         return {
+             "latent": [
+                 (min_batch, 4, min_latent_height, min_latent_width),
+                 (batch_size, 4, latent_height, latent_width),
+                 (max_batch, 4, max_latent_height, max_latent_width),
+             ]
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return {
+             "latent": (batch_size, 4, latent_height, latent_width),
+             "images": (batch_size, 3, image_height, image_width),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         dtype = torch.float16 if self.fp16 else torch.float32
+         return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=dtype, device=self.device),)
+
+     def fp32_input_output_names(self) -> List[str]:
+         return []
+
+
+ def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, subfolder="tokenizer"):
+     tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder)
+
+     if not os.path.exists(tokenizer_dir):
+         model = CLIPTokenizer.from_pretrained(
+             pipeline_info.name(),
+             subfolder=subfolder,
+             use_safetensors=pipeline_info.is_xl(),
+         )
+         model.save_pretrained(tokenizer_dir)
+     else:
+         print(f"[I] Load tokenizer pytorch model from: {tokenizer_dir}")
+         model = CLIPTokenizer.from_pretrained(tokenizer_dir)
+     return model
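+
+ # Typical use of the returned tokenizer (a hedged sketch; prompt is illustrative):
+ #
+ #     input_ids = tokenizer(
+ #         prompt,
+ #         padding="max_length",
+ #         max_length=tokenizer.model_max_length,
+ #         truncation=True,
+ #         return_tensors="pt",
+ #     ).input_ids
+ #
+ # which produces the token ids consumed by the CLIP text encoder.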
+
+
+ class TorchVAEEncoder(torch.nn.Module):
+     def __init__(self, vae_encoder):
+         super().__init__()
+         self.vae_encoder = vae_encoder
+
+     def forward(self, x):
+         return self.vae_encoder.encode(x).latent_dist.sample()
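+
+ # Wrapping encode(x).latent_dist.sample() in a forward() gives torch.onnx.export a single
+ # callable to trace, since export only follows a module's forward; note that sample() is
+ # stochastic, so the exported encoder includes the VAE's reparameterized noise draw.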
+
+
+ class VAEEncoder(BaseModel):
+     def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size):
+         super().__init__(
+             pipeline_info,
+             model=model,
+             device=device,
+             max_batch_size=max_batch_size,
+         )
+
+     def load_model(self, framework_model_dir, subfolder="vae_encoder"):
+         vae = self.from_pretrained(AutoencoderKL, framework_model_dir, subfolder)
+         return TorchVAEEncoder(vae)
+
+     def get_input_names(self):
+         return ["images"]
+
+     def get_output_names(self):
+         return ["latent"]
+
+     def get_dynamic_axes(self):
+         return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}}
+
+     def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
+         self.check_dims(batch_size, image_height, image_width)
+
+         (
+             min_batch,
+             max_batch,
+             min_image_height,
+             max_image_height,
+             min_image_width,
+             max_image_width,
+             _,
+             _,
+             _,
+             _,
+         ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
+
+         return {
+             "images": [
+                 (min_batch, 3, min_image_height, min_image_width),
+                 (batch_size, 3, image_height, image_width),
+                 (max_batch, 3, max_image_height, max_image_width),
+             ],
+         }
+
+     def get_shape_dict(self, batch_size, image_height, image_width):
+         latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
+         return {
+             "images": (batch_size, 3, image_height, image_width),
+             "latent": (batch_size, 4, latent_height, latent_width),
+         }
+
+     def get_sample_input(self, batch_size, image_height, image_width):
+         self.check_dims(batch_size, image_height, image_width)
+         return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device)