diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278)
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_asym_kl.py

@@ -112,9 +112,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
         self.register_to_config(force_upcast=False)
 
     @apply_forward_hook
-    def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[AutoencoderKLOutput, Tuple[torch.FloatTensor]]:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
@@ -126,11 +124,11 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
 
     def _decode(
         self,
-        z: torch.FloatTensor,
-        image: Optional[torch.FloatTensor] = None,
-        mask: Optional[torch.FloatTensor] = None,
+        z: torch.Tensor,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z, image, mask)
 
@@ -142,12 +140,12 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         generator: Optional[torch.Generator] = None,
-        image: Optional[torch.FloatTensor] = None,
-        mask: Optional[torch.FloatTensor] = None,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         decoded = self._decode(z, image, mask).sample
 
         if not return_dict:
@@ -157,16 +155,16 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
-        mask: Optional[torch.FloatTensor] = None,
+        sample: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
-            mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
+            sample (`torch.Tensor`): Input sample.
+            mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
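The hunks above are representative of a sweep across the models module: public tensor annotations move from `torch.FloatTensor` to `torch.Tensor`, so the hints no longer claim fp32-only inputs. A minimal sketch of the round trip, not from the diff (default config, random weights, illustrative only):

    import torch
    from diffusers import AsymmetricAutoencoderKL

    vae = AsymmetricAutoencoderKL()  # default config, randomly initialized
    x = torch.randn(1, 3, 32, 32)    # any torch.Tensor now matches the hints, not just fp32
    with torch.no_grad():
        latents = vae.encode(x).latent_dist.sample()
        # decode() also accepts optional image/mask tensors for inpainting
        recon = vae.decode(latents).sample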
diffusers/models/autoencoders/autoencoder_kl.py

@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import FromOriginalVAEMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils.accelerate_utils import apply_forward_hook
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -32,7 +32,7 @@ from ..modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
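`AutoencoderKL` now gets single-file loading from the new `FromOriginalModelMixin` (see the `diffusers/loaders/single_file_model.py` entry in the file list) instead of the VAE-specific mixin. A hedged sketch of the unchanged entry point; the checkpoint URL is illustrative:

    from diffusers import AutoencoderKL

    # Any original-format (non-diffusers) VAE checkpoint should work here.
    url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
    vae = AutoencoderKL.from_single_file(url)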
@@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
 
     @register_to_config
     def __init__(
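`_no_split_modules` marks blocks whose weights must stay on a single device when the model is dispatched with an accelerate-style device map; presumably this is groundwork for `device_map` support on models. A sketch that only reads the new class attribute:

    from diffusers import AutoencoderKL

    # Blocks that auto-sharding must keep on one device.
    print(AutoencoderKL._no_split_modules)  # ['BasicTransformerBlock', 'ResnetBlock2D']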
@@ -236,13 +237,13 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
@@ -267,7 +268,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)
 
@@ -280,14 +281,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         return DecoderOutput(sample=dec)
 
     @apply_forward_hook
-    def decode(
-        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images.
 
         Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
+            z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
 
@@ -301,7 +300,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
-            decoded = self._decode(z).sample
+            decoded = self._decode(z, return_dict=False)[0]
 
         if not return_dict:
             return (decoded,)
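`decode` now unwraps `_decode` with `return_dict=False` instead of reading `.sample`, skipping one intermediate `DecoderOutput` allocation; the public behavior is unchanged. Both call forms below are equivalent (sketch with random weights, for illustration only):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL()  # default config
    z = torch.randn(1, 4, 8, 8)
    with torch.no_grad():
        img_a = vae.decode(z).sample
        (img_b,) = vae.decode(z, return_dict=False)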
@@ -320,7 +319,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -330,7 +329,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         output, but they should be much less noticeable.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
@@ -374,12 +373,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
         return AutoencoderKLOutput(latent_dist=posterior)
 
-    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Decode a batch of images using a tiled decoder.
 
         Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
+            z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
 
@@ -424,14 +423,14 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -453,8 +452,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
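The `fuse_qkv_projections` change is a docstring rewrap only. For context, a usage sketch of the pre-existing API; `unfuse_qkv_projections` is assumed to be its documented counterpart:

    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae.fuse_qkv_projections()    # fuse q/k/v weights in self-attention blocks
    # ... run inference ...
    vae.unfuse_qkv_projections()  # restore the original projections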
diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py

@@ -86,10 +86,10 @@ class TemporalDecoder(nn.Module):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
+        sample: torch.Tensor,
+        image_only_indicator: torch.Tensor,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         r"""The forward method of the `Decoder` class."""
 
         sample = self.conv_in(sample)
@@ -315,13 +315,13 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
 
@@ -341,15 +341,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         num_frames: int,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         """
         Decode a batch of images.
 
         Args:
-            z (`torch.FloatTensor`): Input batch of latent vectors.
+            z (`torch.Tensor`): Input batch of latent vectors.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
 
@@ -370,15 +370,15 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
         num_frames: int = 1,
-    ) -> Union[DecoderOutput, torch.FloatTensor]:
+    ) -> Union[DecoderOutput, torch.Tensor]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
diffusers/models/autoencoders/autoencoder_tiny.py

@@ -102,6 +102,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
         decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
         act_fn: str = "relu",
+        upsample_fn: str = "nearest",
         latent_channels: int = 4,
         upsampling_scaling_factor: int = 2,
         num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
@@ -133,6 +134,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             block_out_channels=decoder_block_out_channels,
             upsampling_scaling_factor=upsampling_scaling_factor,
             act_fn=act_fn,
+            upsample_fn=upsample_fn,
         )
 
         self.latent_magnitude = latent_magnitude
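`upsample_fn` is a new config knob that is forwarded to `nn.Upsample(mode=...)` between decoder stages (see the `DecoderTiny` hunks near the end of this diff); `"nearest"` keeps the previous behavior. A sketch:

    from diffusers import AutoencoderTiny

    # "nearest" is the default and matches 0.27.x output; other nn.Upsample
    # modes such as "bilinear" are assumed to be accepted as well.
    taesd = AutoencoderTiny(upsample_fn="nearest")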
@@ -155,11 +157,11 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         if isinstance(module, (EncoderTiny, DecoderTiny)):
             module.gradient_checkpointing = value
 
-    def scale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
         """raw latents -> [0, 1]"""
         return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
 
-    def unscale_latents(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
         """[0, 1] -> raw latents"""
         return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
 
@@ -192,7 +194,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         """
         self.enable_tiling(False)
 
-    def _tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -200,10 +202,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         tiles overlap and are blended together to form a smooth output.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
 
         Returns:
-            `torch.FloatTensor`: Encoded batch of images.
+            `torch.Tensor`: Encoded batch of images.
         """
         # scale of encoder output relative to input
         sf = self.spatial_scale_factor
@@ -240,7 +242,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
             tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
         return out
 
-    def _tiled_decode(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -248,10 +250,10 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         tiles overlap and are blended together to form a smooth output.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
 
         Returns:
-            `torch.FloatTensor`: Encoded batch of images.
+            `torch.Tensor`: Encoded batch of images.
         """
         # scale of decoder output relative to input
         sf = self.spatial_scale_factor
@@ -288,9 +290,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         return out
 
     @apply_forward_hook
-    def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
-    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
             output = [
                 self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
@@ -306,8 +306,8 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def decode(
-        self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         if self.use_slicing and x.shape[0] > 1:
             output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
             output = torch.cat(output)
@@ -321,12 +321,12 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
         """
diffusers/models/autoencoders/consistency_decoder_vae.py

@@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         ...     "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
         ... ).to("cuda")
 
-        >>> pipe("horse", generator=torch.manual_seed(0)).images
+        >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
+        >>> image
         ```
     """
@@ -72,6 +73,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self,
         scaling_factor: float = 0.18215,
         latent_channels: int = 4,
+        sample_size: int = 32,
         encoder_act_fn: str = "silu",
         encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
         encoder_double_z: bool = True,
@@ -153,6 +155,16 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         self.use_slicing = False
         self.use_tiling = False
 
+        # only relevant if vae tiling is enabled
+        self.tile_sample_min_size = self.config.sample_size
+        sample_size = (
+            self.config.sample_size[0]
+            if isinstance(self.config.sample_size, (list, tuple))
+            else self.config.sample_size
+        )
+        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
     # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
     def enable_tiling(self, use_tiling: bool = True):
         r"""
@@ -264,15 +276,15 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
 
     @apply_forward_hook
     def encode(
-        self, x: torch.FloatTensor, return_dict: bool = True
+        self, x: torch.Tensor, return_dict: bool = True
     ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
         """
         Encode a batch of images into latents.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
+                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
                 tuple.
 
         Returns:
@@ -300,11 +312,24 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
     @apply_forward_hook
     def decode(
         self,
-        z: torch.FloatTensor,
+        z: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
         num_inference_steps: int = 2,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+        """
+        Decodes the input latent vector `z` using the consistency decoder VAE model.
+
+        Args:
+            z (torch.Tensor): The input latent vector.
+            generator (Optional[torch.Generator]): The random number generator. Default is None.
+            return_dict (bool): Whether to return the output as a dictionary. Default is True.
+            num_inference_steps (int): The number of inference steps. Default is 2.
+
+        Returns:
+            Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
+
+        """
         z = (z * self.config.scaling_factor - self.means) / self.stds
 
         scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
@@ -345,7 +370,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
             b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
         return b
 
-    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
+    def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
         r"""Encode a batch of images using a tiled encoder.
 
         When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
@@ -355,7 +380,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         output, but they should be much less noticeable.
 
         Args:
-            x (`torch.FloatTensor`): Input batch of images.
+            x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
                 plain tuple.
@@ -402,14 +427,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         sample_posterior: bool = False,
         return_dict: bool = True,
         generator: Optional[torch.Generator] = None,
-    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
         r"""
         Args:
-            sample (`torch.FloatTensor`): Input sample.
+            sample (`torch.Tensor`): Input sample.
             sample_posterior (`bool`, *optional*, defaults to `False`):
                 Whether to sample from the posterior.
             return_dict (`bool`, *optional*, defaults to `True`):
diffusers/models/autoencoders/vae.py

@@ -36,11 +36,12 @@ class DecoderOutput(BaseOutput):
     Output of decoding method.
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
             The decoded output sample from the last layer of the model.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.Tensor
+    commit_loss: Optional[torch.FloatTensor] = None
 
 
 class Encoder(nn.Module):
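`DecoderOutput` gains an optional `commit_loss` field, presumably populated by VQ decode paths (note the `vq_model.py` entry in the file list); callers that only read `.sample` are unaffected. A sketch:

    import torch
    from diffusers.models.autoencoders.vae import DecoderOutput

    out = DecoderOutput(sample=torch.zeros(1, 3, 8, 8))
    print(out.commit_loss)  # None unless a VQ decode path fills it in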
@@ -90,7 +91,6 @@ class Encoder(nn.Module):
             padding=1,
         )
 
-        self.mid_block = None
         self.down_blocks = nn.ModuleList([])
 
         # down
@@ -137,7 +137,7 @@ class Encoder(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, sample: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `Encoder` class."""
 
         sample = self.conv_in(sample)
@@ -228,7 +228,6 @@ class Decoder(nn.Module):
             padding=1,
         )
 
-        self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
         temb_channels = in_channels if norm_type == "spatial" else None
@@ -284,9 +283,9 @@ class Decoder(nn.Module):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
-        latent_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `Decoder` class."""
 
         sample = self.conv_in(sample)
@@ -369,7 +368,7 @@ class UpSample(nn.Module):
         self.out_channels = out_channels
         self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `UpSample` class."""
         x = torch.relu(x)
         x = self.deconv(x)
@@ -418,7 +417,7 @@ class MaskConditionEncoder(nn.Module):
 
         self.layers = nn.Sequential(*layers)
 
-    def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
         r"""The forward method of the `MaskConditionEncoder` class."""
         out = {}
         for l in range(len(self.layers)):
@@ -474,7 +473,6 @@ class MaskConditionDecoder(nn.Module):
             padding=1,
         )
 
-        self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
         temb_channels = in_channels if norm_type == "spatial" else None
@@ -536,11 +534,11 @@ class MaskConditionDecoder(nn.Module):
 
     def forward(
         self,
-        z: torch.FloatTensor,
-        image: Optional[torch.FloatTensor] = None,
-        mask: Optional[torch.FloatTensor] = None,
-        latent_embeds: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        z: torch.Tensor,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        latent_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         r"""The forward method of the `MaskConditionDecoder` class."""
         sample = z
         sample = self.conv_in(sample)
@@ -714,7 +712,7 @@ class VectorQuantizer(nn.Module):
             back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
         return back.reshape(ishape)
 
-    def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
+    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
         # reshape z -> (batch, height, width, channel) and flatten
         z = z.permute(0, 2, 3, 1).contiguous()
         z_flattened = z.view(-1, self.vq_embed_dim)
@@ -733,7 +731,7 @@ class VectorQuantizer(nn.Module):
         loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
 
         # preserve gradients
-        z_q: torch.FloatTensor = z + (z_q - z).detach()
+        z_q: torch.Tensor = z + (z_q - z).detach()
 
         # reshape back to match original input shape
         z_q = z_q.permute(0, 3, 1, 2).contiguous()
@@ -748,7 +746,7 @@ class VectorQuantizer(nn.Module):
 
         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
 
-    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
+    def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.Tensor:
         # shape specifying (batch, height, width, channel)
         if self.remap is not None:
             indices = indices.reshape(shape[0], -1)  # add batch axis
@@ -756,7 +754,7 @@ class VectorQuantizer(nn.Module):
             indices = indices.reshape(-1)  # flatten again
 
         # get quantized latent vectors
-        z_q: torch.FloatTensor = self.embedding(indices)
+        z_q: torch.Tensor = self.embedding(indices)
 
         if shape is not None:
             z_q = z_q.view(shape)
@@ -779,7 +777,7 @@ class DiagonalGaussianDistribution(object):
                 self.mean, device=self.parameters.device, dtype=self.parameters.dtype
             )
 
-    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
         # make sure sample is on the same device as the parameters and has same dtype
         sample = randn_tensor(
             self.mean.shape,
@@ -876,7 +874,7 @@ class EncoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `EncoderTiny` class."""
         if self.training and self.gradient_checkpointing:
 
@@ -926,6 +924,7 @@ class DecoderTiny(nn.Module):
         block_out_channels: Tuple[int, ...],
         upsampling_scaling_factor: int,
         act_fn: str,
+        upsample_fn: str,
     ):
         super().__init__()
 
@@ -942,7 +941,7 @@ class DecoderTiny(nn.Module):
                 layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
 
             if not is_final_block:
-                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
+                layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))
 
             conv_out_channel = num_channels if not is_final_block else out_channels
             layers.append(
@@ -958,7 +957,7 @@ class DecoderTiny(nn.Module):
         self.layers = nn.Sequential(*layers)
         self.gradient_checkpointing = False
 
-    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         r"""The forward method of the `DecoderTiny` class."""
         # Clamp.
         x = torch.tanh(x / 3) * 3