diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (176) hide show
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +0 -1
  4. diffusers/dependency_versions_table.py +4 -5
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +8 -7
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
  171. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
  172. diffusers/loaders.py +0 -3336
  173. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
  175. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
  176. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
@@ -56,9 +56,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
56
56
  Examples:
57
57
  ```py
58
58
  >>> import torch
59
- >>> from diffusers import DiffusionPipeline, ConsistencyDecoderVAE
59
+ >>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
60
60
 
61
- >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=pipe.torch_dtype)
61
+ >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
62
62
  >>> pipe = StableDiffusionPipeline.from_pretrained(
63
63
  ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
64
64
  ... ).to("cuda")
@@ -70,39 +70,39 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
70
70
  @register_to_config
71
71
  def __init__(
72
72
  self,
73
- scaling_factor=0.18215,
74
- latent_channels=4,
75
- encoder_act_fn="silu",
76
- encoder_block_out_channels=(128, 256, 512, 512),
77
- encoder_double_z=True,
78
- encoder_down_block_types=(
73
+ scaling_factor: float = 0.18215,
74
+ latent_channels: int = 4,
75
+ encoder_act_fn: str = "silu",
76
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
77
+ encoder_double_z: bool = True,
78
+ encoder_down_block_types: Tuple[str, ...] = (
79
79
  "DownEncoderBlock2D",
80
80
  "DownEncoderBlock2D",
81
81
  "DownEncoderBlock2D",
82
82
  "DownEncoderBlock2D",
83
83
  ),
84
- encoder_in_channels=3,
85
- encoder_layers_per_block=2,
86
- encoder_norm_num_groups=32,
87
- encoder_out_channels=4,
88
- decoder_add_attention=False,
89
- decoder_block_out_channels=(320, 640, 1024, 1024),
90
- decoder_down_block_types=(
84
+ encoder_in_channels: int = 3,
85
+ encoder_layers_per_block: int = 2,
86
+ encoder_norm_num_groups: int = 32,
87
+ encoder_out_channels: int = 4,
88
+ decoder_add_attention: bool = False,
89
+ decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
90
+ decoder_down_block_types: Tuple[str, ...] = (
91
91
  "ResnetDownsampleBlock2D",
92
92
  "ResnetDownsampleBlock2D",
93
93
  "ResnetDownsampleBlock2D",
94
94
  "ResnetDownsampleBlock2D",
95
95
  ),
96
- decoder_downsample_padding=1,
97
- decoder_in_channels=7,
98
- decoder_layers_per_block=3,
99
- decoder_norm_eps=1e-05,
100
- decoder_norm_num_groups=32,
101
- decoder_num_train_timesteps=1024,
102
- decoder_out_channels=6,
103
- decoder_resnet_time_scale_shift="scale_shift",
104
- decoder_time_embedding_type="learned",
105
- decoder_up_block_types=(
96
+ decoder_downsample_padding: int = 1,
97
+ decoder_in_channels: int = 7,
98
+ decoder_layers_per_block: int = 3,
99
+ decoder_norm_eps: float = 1e-05,
100
+ decoder_norm_num_groups: int = 32,
101
+ decoder_num_train_timesteps: int = 1024,
102
+ decoder_out_channels: int = 6,
103
+ decoder_resnet_time_scale_shift: str = "scale_shift",
104
+ decoder_time_embedding_type: str = "learned",
105
+ decoder_up_block_types: Tuple[str, ...] = (
106
106
  "ResnetUpsampleBlock2D",
107
107
  "ResnetUpsampleBlock2D",
108
108
  "ResnetUpsampleBlock2D",
@@ -138,6 +138,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
138
138
  )
139
139
  self.decoder_scheduler = ConsistencyDecoderScheduler()
140
140
  self.register_to_config(block_out_channels=encoder_block_out_channels)
141
+ self.register_to_config(force_upcast=False)
141
142
  self.register_buffer(
142
143
  "means",
143
144
  torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
@@ -304,8 +305,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
304
305
  z: torch.FloatTensor,
305
306
  generator: Optional[torch.Generator] = None,
306
307
  return_dict: bool = True,
307
- num_inference_steps=2,
308
- ) -> Union[DecoderOutput, torch.FloatTensor]:
308
+ num_inference_steps: int = 2,
309
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
309
310
  z = (z * self.config.scaling_factor - self.means) / self.stds
310
311
 
311
312
  scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
@@ -333,14 +334,14 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
333
334
  return DecoderOutput(sample=x_0)
334
335
 
335
336
  # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
336
- def blend_v(self, a, b, blend_extent):
337
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
337
338
  blend_extent = min(a.shape[2], b.shape[2], blend_extent)
338
339
  for y in range(blend_extent):
339
340
  b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
340
341
  return b
341
342
 
342
343
  # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
343
- def blend_h(self, a, b, blend_extent):
344
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
345
  blend_extent = min(a.shape[3], b.shape[3], blend_extent)
345
346
  for x in range(blend_extent):
346
347
  b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
@@ -407,7 +408,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
407
408
  sample_posterior: bool = False,
408
409
  return_dict: bool = True,
409
410
  generator: Optional[torch.Generator] = None,
410
- ) -> Union[DecoderOutput, torch.FloatTensor]:
411
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
411
412
  r"""
412
413
  Args:
413
414
  sample (`torch.FloatTensor`): Input sample.
@@ -415,6 +416,12 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
415
416
  Whether to sample from the posterior.
416
417
  return_dict (`bool`, *optional*, defaults to `True`):
417
418
  Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
419
+ generator (`torch.Generator`, *optional*, defaults to `None`):
420
+ Generator to use for sampling.
421
+
422
+ Returns:
423
+ [`DecoderOutput`] or `tuple`:
424
+ If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
418
425
  """
419
426
  x = sample
420
427
  posterior = self.encode(x).latent_dist
@@ -30,12 +30,7 @@ from .attention_processor import (
30
30
  )
31
31
  from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
32
  from .modeling_utils import ModelMixin
33
- from .unet_2d_blocks import (
34
- CrossAttnDownBlock2D,
35
- DownBlock2D,
36
- UNetMidBlock2DCrossAttn,
37
- get_down_block,
38
- )
33
+ from .unet_2d_blocks import CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2D, UNetMidBlock2DCrossAttn, get_down_block
39
34
  from .unet_2d_condition import UNet2DConditionModel
40
35
 
41
36
 
@@ -76,7 +71,7 @@ class ControlNetConditioningEmbedding(nn.Module):
76
71
  self,
77
72
  conditioning_embedding_channels: int,
78
73
  conditioning_channels: int = 3,
79
- block_out_channels: Tuple[int] = (16, 32, 96, 256),
74
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
80
75
  ):
81
76
  super().__init__()
82
77
 
@@ -171,6 +166,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
171
166
  conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
172
167
  The tuple of output channel for each block in the `conditioning_embedding` layer.
173
168
  global_pool_conditions (`bool`, defaults to `False`):
169
+ TODO(Patrick) - unused parameter.
170
+ addition_embed_type_num_heads (`int`, defaults to 64):
171
+ The number of heads to use for the `TextTimeEmbedding` layer.
174
172
  """
175
173
 
176
174
  _supports_gradient_checkpointing = True
@@ -182,14 +180,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
182
180
  conditioning_channels: int = 3,
183
181
  flip_sin_to_cos: bool = True,
184
182
  freq_shift: int = 0,
185
- down_block_types: Tuple[str] = (
183
+ down_block_types: Tuple[str, ...] = (
186
184
  "CrossAttnDownBlock2D",
187
185
  "CrossAttnDownBlock2D",
188
186
  "CrossAttnDownBlock2D",
189
187
  "DownBlock2D",
190
188
  ),
189
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
191
190
  only_cross_attention: Union[bool, Tuple[bool]] = False,
192
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
191
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
193
192
  layers_per_block: int = 2,
194
193
  downsample_padding: int = 1,
195
194
  mid_block_scale_factor: float = 1,
@@ -197,11 +196,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
197
196
  norm_num_groups: Optional[int] = 32,
198
197
  norm_eps: float = 1e-5,
199
198
  cross_attention_dim: int = 1280,
200
- transformer_layers_per_block: Union[int, Tuple[int]] = 1,
199
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
201
200
  encoder_hid_dim: Optional[int] = None,
202
201
  encoder_hid_dim_type: Optional[str] = None,
203
- attention_head_dim: Union[int, Tuple[int]] = 8,
204
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
202
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
203
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
205
204
  use_linear_projection: bool = False,
206
205
  class_embed_type: Optional[str] = None,
207
206
  addition_embed_type: Optional[str] = None,
@@ -211,9 +210,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
211
210
  resnet_time_scale_shift: str = "default",
212
211
  projection_class_embeddings_input_dim: Optional[int] = None,
213
212
  controlnet_conditioning_channel_order: str = "rgb",
214
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
213
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
215
214
  global_pool_conditions: bool = False,
216
- addition_embed_type_num_heads=64,
215
+ addition_embed_type_num_heads: int = 64,
217
216
  ):
218
217
  super().__init__()
219
218
 
@@ -406,28 +405,44 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
406
405
  controlnet_block = zero_module(controlnet_block)
407
406
  self.controlnet_mid_block = controlnet_block
408
407
 
409
- self.mid_block = UNetMidBlock2DCrossAttn(
410
- transformer_layers_per_block=transformer_layers_per_block[-1],
411
- in_channels=mid_block_channel,
412
- temb_channels=time_embed_dim,
413
- resnet_eps=norm_eps,
414
- resnet_act_fn=act_fn,
415
- output_scale_factor=mid_block_scale_factor,
416
- resnet_time_scale_shift=resnet_time_scale_shift,
417
- cross_attention_dim=cross_attention_dim,
418
- num_attention_heads=num_attention_heads[-1],
419
- resnet_groups=norm_num_groups,
420
- use_linear_projection=use_linear_projection,
421
- upcast_attention=upcast_attention,
422
- )
408
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
409
+ self.mid_block = UNetMidBlock2DCrossAttn(
410
+ transformer_layers_per_block=transformer_layers_per_block[-1],
411
+ in_channels=mid_block_channel,
412
+ temb_channels=time_embed_dim,
413
+ resnet_eps=norm_eps,
414
+ resnet_act_fn=act_fn,
415
+ output_scale_factor=mid_block_scale_factor,
416
+ resnet_time_scale_shift=resnet_time_scale_shift,
417
+ cross_attention_dim=cross_attention_dim,
418
+ num_attention_heads=num_attention_heads[-1],
419
+ resnet_groups=norm_num_groups,
420
+ use_linear_projection=use_linear_projection,
421
+ upcast_attention=upcast_attention,
422
+ )
423
+ elif mid_block_type == "UNetMidBlock2D":
424
+ self.mid_block = UNetMidBlock2D(
425
+ in_channels=block_out_channels[-1],
426
+ temb_channels=time_embed_dim,
427
+ num_layers=0,
428
+ resnet_eps=norm_eps,
429
+ resnet_act_fn=act_fn,
430
+ output_scale_factor=mid_block_scale_factor,
431
+ resnet_groups=norm_num_groups,
432
+ resnet_time_scale_shift=resnet_time_scale_shift,
433
+ add_attention=False,
434
+ )
435
+ else:
436
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
423
437
 
424
438
  @classmethod
425
439
  def from_unet(
426
440
  cls,
427
441
  unet: UNet2DConditionModel,
428
442
  controlnet_conditioning_channel_order: str = "rgb",
429
- conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
443
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
430
444
  load_weights_from_unet: bool = True,
445
+ conditioning_channels: int = 3,
431
446
  ):
432
447
  r"""
433
448
  Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
@@ -474,8 +489,10 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
474
489
  upcast_attention=unet.config.upcast_attention,
475
490
  resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
476
491
  projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
492
+ mid_block_type=unet.config.mid_block_type,
477
493
  controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
478
494
  conditioning_embedding_out_channels=conditioning_embedding_out_channels,
495
+ conditioning_channels=conditioning_channels,
479
496
  )
480
497
 
481
498
  if load_weights_from_unet:
@@ -570,7 +587,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
570
587
  self.set_attn_processor(processor, _remove_lora=True)
571
588
 
572
589
  # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
573
- def set_attention_slice(self, slice_size):
590
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
574
591
  r"""
575
592
  Enable sliced attention computation.
576
593
 
@@ -635,7 +652,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
635
652
  for module in self.children():
636
653
  fn_recursive_set_attention_slice(module, reversed_slice_size)
637
654
 
638
- def _set_gradient_checkpointing(self, module, value=False):
655
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
639
656
  if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
640
657
  module.gradient_checkpointing = value
641
658
 
@@ -653,7 +670,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
653
670
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
654
671
  guess_mode: bool = False,
655
672
  return_dict: bool = True,
656
- ) -> Union[ControlNetOutput, Tuple]:
673
+ ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
657
674
  """
658
675
  The [`ControlNetModel`] forward method.
659
676
 
@@ -794,13 +811,16 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
794
811
 
795
812
  # 4. mid
796
813
  if self.mid_block is not None:
797
- sample = self.mid_block(
798
- sample,
799
- emb,
800
- encoder_hidden_states=encoder_hidden_states,
801
- attention_mask=attention_mask,
802
- cross_attention_kwargs=cross_attention_kwargs,
803
- )
814
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
815
+ sample = self.mid_block(
816
+ sample,
817
+ emb,
818
+ encoder_hidden_states=encoder_hidden_states,
819
+ attention_mask=attention_mask,
820
+ cross_attention_kwargs=cross_attention_kwargs,
821
+ )
822
+ else:
823
+ sample = self.mid_block(sample, emb)
804
824
 
805
825
  # 5. Control net blocks
806
826
 
@@ -46,10 +46,10 @@ class FlaxControlNetOutput(BaseOutput):
46
46
 
47
47
  class FlaxControlNetConditioningEmbedding(nn.Module):
48
48
  conditioning_embedding_channels: int
49
- block_out_channels: Tuple[int] = (16, 32, 96, 256)
49
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
50
50
  dtype: jnp.dtype = jnp.float32
51
51
 
52
- def setup(self):
52
+ def setup(self) -> None:
53
53
  self.conv_in = nn.Conv(
54
54
  self.block_out_channels[0],
55
55
  kernel_size=(3, 3),
@@ -87,7 +87,7 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
87
87
  dtype=self.dtype,
88
88
  )
89
89
 
90
- def __call__(self, conditioning):
90
+ def __call__(self, conditioning: jnp.ndarray) -> jnp.ndarray:
91
91
  embedding = self.conv_in(conditioning)
92
92
  embedding = nn.silu(embedding)
93
93
 
@@ -146,19 +146,20 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
146
146
  conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`):
147
147
  The tuple of output channel for each block in the `conditioning_embedding` layer.
148
148
  """
149
+
149
150
  sample_size: int = 32
150
151
  in_channels: int = 4
151
- down_block_types: Tuple[str] = (
152
+ down_block_types: Tuple[str, ...] = (
152
153
  "CrossAttnDownBlock2D",
153
154
  "CrossAttnDownBlock2D",
154
155
  "CrossAttnDownBlock2D",
155
156
  "DownBlock2D",
156
157
  )
157
- only_cross_attention: Union[bool, Tuple[bool]] = False
158
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
158
+ only_cross_attention: Union[bool, Tuple[bool, ...]] = False
159
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
159
160
  layers_per_block: int = 2
160
- attention_head_dim: Union[int, Tuple[int]] = 8
161
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None
161
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8
162
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
162
163
  cross_attention_dim: int = 1280
163
164
  dropout: float = 0.0
164
165
  use_linear_projection: bool = False
@@ -166,7 +167,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
166
167
  flip_sin_to_cos: bool = True
167
168
  freq_shift: int = 0
168
169
  controlnet_conditioning_channel_order: str = "rgb"
169
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
170
+ conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256)
170
171
 
171
172
  def init_weights(self, rng: jax.Array) -> FrozenDict:
172
173
  # init input tensors
@@ -182,7 +183,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
182
183
 
183
184
  return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
184
185
 
185
- def setup(self):
186
+ def setup(self) -> None:
186
187
  block_out_channels = self.block_out_channels
187
188
  time_embed_dim = block_out_channels[0] * 4
188
189
 
@@ -312,21 +313,21 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
312
313
 
313
314
  def __call__(
314
315
  self,
315
- sample,
316
- timesteps,
317
- encoder_hidden_states,
318
- controlnet_cond,
316
+ sample: jnp.ndarray,
317
+ timesteps: Union[jnp.ndarray, float, int],
318
+ encoder_hidden_states: jnp.ndarray,
319
+ controlnet_cond: jnp.ndarray,
319
320
  conditioning_scale: float = 1.0,
320
321
  return_dict: bool = True,
321
322
  train: bool = False,
322
- ) -> Union[FlaxControlNetOutput, Tuple]:
323
+ ) -> Union[FlaxControlNetOutput, Tuple[Tuple[jnp.ndarray, ...], jnp.ndarray]]:
323
324
  r"""
324
325
  Args:
325
326
  sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
326
327
  timestep (`jnp.ndarray` or `float` or `int`): timesteps
327
328
  encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
328
329
  controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
329
- conditioning_scale: (`float`) the scale factor for controlnet outputs
330
+ conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
330
331
  return_dict (`bool`, *optional*, defaults to `True`):
331
332
  Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
332
333
  plain tuple.
@@ -335,8 +336,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
335
336
 
336
337
  Returns:
337
338
  [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
338
- [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
339
- When returning a tuple, the first element is the sample tensor.
339
+ [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
340
+ `tuple`. When returning a tuple, the first element is the sample tensor.
340
341
  """
341
342
  channel_order = self.controlnet_conditioning_channel_order
342
343
  if channel_order == "bgr":
@@ -65,6 +65,7 @@ class FlaxTimestepEmbedding(nn.Module):
65
65
  dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
66
66
  Parameters `dtype`
67
67
  """
68
+
68
69
  time_embed_dim: int = 32
69
70
  dtype: jnp.dtype = jnp.float32
70
71
 
@@ -84,6 +85,7 @@ class FlaxTimesteps(nn.Module):
84
85
  dim (`int`, *optional*, defaults to `32`):
85
86
  Time step embedding dimension
86
87
  """
88
+
87
89
  dim: int = 32
88
90
  flip_sin_to_cos: bool = False
89
91
  freq_shift: float = 1
diffusers/models/lora.py CHANGED
@@ -12,19 +12,60 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
16
+ # IMPORTANT: #
17
+ ###################################################################
18
+ # ----------------------------------------------------------------#
19
+ # This file is deprecated and will be removed soon #
20
+ # (as soon as PEFT becomes a required dependency for LoRA) #
21
+ # ----------------------------------------------------------------#
22
+ ###################################################################
23
+
15
24
  from typing import Optional, Tuple, Union
16
25
 
17
26
  import torch
18
27
  import torch.nn.functional as F
19
28
  from torch import nn
20
29
 
21
- from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
22
30
  from ..utils import logging
31
+ from ..utils.import_utils import is_transformers_available
32
+
33
+
34
+ if is_transformers_available():
35
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
23
36
 
24
37
 
25
38
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
39
 
27
40
 
41
def text_encoder_attn_modules(text_encoder):
    """
    Collect the self-attention modules of a CLIP text encoder.

    Args:
        text_encoder: A `CLIPTextModel` or `CLIPTextModelWithProjection` instance.

    Returns:
        `list` of `(name, module)` tuples, one per entry of `text_encoder.text_model.encoder.layers`, where `name`
        is `"text_model.encoder.layers.{i}.self_attn"` and `module` is that layer's `self_attn` attribute.

    Raises:
        ValueError: If `text_encoder` is not an instance of a supported class, which includes the case where
            `transformers` cannot be imported.
    """
    # Resolve the supported classes locally: at module level they are only imported when `transformers` is
    # available, so referencing them unconditionally would raise a `NameError` instead of the `ValueError` below.
    try:
        from transformers import CLIPTextModel, CLIPTextModelWithProjection
    except ImportError:
        supported_classes = ()
    else:
        supported_classes = (CLIPTextModel, CLIPTextModelWithProjection)

    if not isinstance(text_encoder, supported_classes):
        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

    return [
        (f"text_model.encoder.layers.{i}.self_attn", layer.self_attn)
        for i, layer in enumerate(text_encoder.text_model.encoder.layers)
    ]
53
+
54
+
55
def text_encoder_mlp_modules(text_encoder):
    """
    Collect the MLP modules of a CLIP text encoder.

    Args:
        text_encoder: A `CLIPTextModel` or `CLIPTextModelWithProjection` instance.

    Returns:
        `list` of `(name, module)` tuples, one per entry of `text_encoder.text_model.encoder.layers`, where `name`
        is `"text_model.encoder.layers.{i}.mlp"` and `module` is that layer's `mlp` attribute.

    Raises:
        ValueError: If `text_encoder` is not an instance of a supported class, which includes the case where
            `transformers` cannot be imported.
    """
    # Resolve the supported classes locally: at module level they are only imported when `transformers` is
    # available, so referencing them unconditionally would raise a `NameError` instead of the `ValueError` below.
    try:
        from transformers import CLIPTextModel, CLIPTextModelWithProjection
    except ImportError:
        supported_classes = ()
    else:
        supported_classes = (CLIPTextModel, CLIPTextModelWithProjection)

    if not isinstance(text_encoder, supported_classes):
        raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")

    return [
        (f"text_model.encoder.layers.{i}.mlp", layer.mlp)
        for i, layer in enumerate(text_encoder.text_model.encoder.layers)
    ]
67
+
68
+
28
69
  def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
29
70
  for _, attn_module in text_encoder_attn_modules(text_encoder):
30
71
  if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -39,6 +80,95 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
39
80
  mlp_module.fc2.lora_scale = lora_scale
40
81
 
41
82
 
83
class PatchedLoraProjection(torch.nn.Module):
    r"""
    Linear projection that adds the scaled output of a LoRA layer to the output of a wrapped linear layer.

    The LoRA weights can be merged into the wrapped layer's weight with `_fuse_lora` and removed again with
    `_unfuse_lora`.

    Args:
        regular_linear_layer: The linear layer to wrap. Its `weight`, `in_features` and `out_features` are read.
        lora_scale (defaults to `1`): Factor applied to the LoRA layer's output in `forward`.
        network_alpha (defaults to `None`): Passed on to `LoRALinearLayer`. When not `None`, the up-projection
            weight is multiplied by `network_alpha / rank` during fusing.
        rank (defaults to `4`): Passed on to `LoRALinearLayer`.
        dtype (defaults to `None`): dtype passed on to `LoRALinearLayer`. Falls back to the dtype of the wrapped
            layer's weight.
    """

    def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
        super().__init__()
        # Deferred import, only resolved when a projection is constructed.
        from ..models.lora import LoRALinearLayer

        self.regular_linear_layer = regular_linear_layer

        device = self.regular_linear_layer.weight.device

        if dtype is None:
            dtype = self.regular_linear_layer.weight.dtype

        self.lora_linear_layer = LoRALinearLayer(
            self.regular_linear_layer.in_features,
            self.regular_linear_layer.out_features,
            network_alpha=network_alpha,
            device=device,
            dtype=dtype,
            rank=rank,
        )

        self.lora_scale = lora_scale

    # overwrite PyTorch's `state_dict` to be sure that only the 'regular_linear_layer' weights are saved
    # when saving the whole text encoder model and when LoRA is unloaded or fused
    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        """
        Return the state dict of the wrapped linear layer alone once the LoRA layer has been dropped; otherwise
        return the regular `torch.nn.Module` state dict of this module.
        """
        if self.lora_linear_layer is None:
            return self.regular_linear_layer.state_dict(
                *args, destination=destination, prefix=prefix, keep_vars=keep_vars
            )

        return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)

    def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        """
        Merge the LoRA weights into the wrapped layer's weight and drop the LoRA layer.

        Args:
            lora_scale (defaults to `1.0`): Factor applied to the LoRA weight delta before it is added.
            safe_fusing (defaults to `False`): When `True`, check the fused weight for NaN values before it is
                written back.

        Raises:
            ValueError: If `safe_fusing` is `True` and the fused weight contains NaN values. The wrapped layer is
                left unmodified in that case.
        """
        if self.lora_linear_layer is None:
            return

        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device

        # The arithmetic is done in float32; the result is cast back to the original dtype below.
        w_orig = self.regular_linear_layer.weight.data.float()
        w_up = self.lora_linear_layer.up.weight.data.float()
        w_down = self.lora_linear_layer.down.weight.data.float()

        if self.lora_linear_layer.network_alpha is not None:
            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

        fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])

        if safe_fusing and torch.isnan(fused_weight).any().item():
            raise ValueError(
                "This LoRA weight seems to be broken. "
                f"Encountered NaN values when trying to fuse LoRA weights for {self}. "
                "LoRA weights will not be fused."
            )

        self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_linear_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()
        # Remember the scale that was fused so that `_unfuse_lora` subtracts the same delta.
        self.lora_scale = lora_scale

    def _unfuse_lora(self):
        """
        Subtract the previously fused LoRA delta from the wrapped layer's weight.

        Does nothing unless both `w_up` and `w_down` are stored on the module, which `_fuse_lora` does. Both are
        reset to `None` afterwards.
        """
        if getattr(self, "w_up", None) is None or getattr(self, "w_down", None) is None:
            return

        fused_weight = self.regular_linear_layer.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        w_up = self.w_up.to(device=device).float()
        w_down = self.w_down.to(device).float()

        unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, input):
        """
        Apply the wrapped linear layer and, while a LoRA layer is attached, add its output scaled by `lora_scale`.

        A `lora_scale` of `None` is replaced by `1.0` on the module before use.
        """
        if self.lora_scale is None:
            self.lora_scale = 1.0
        if self.lora_linear_layer is None:
            return self.regular_linear_layer(input)
        return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
170
+
171
+
42
172
  class LoRALinearLayer(nn.Module):
43
173
  r"""
44
174
  A linear layer that is used with LoRA.
@@ -52,6 +52,7 @@ class FlaxModelMixin(PushToHubMixin):
52
52
 
53
53
  - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`].
54
54
  """
55
+
55
56
  config_name = CONFIG_NAME
56
57
  _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
57
58
  _flax_internal_args = ["name", "parent", "dtype"]
@@ -436,7 +437,7 @@ class FlaxModelMixin(PushToHubMixin):
436
437
  # make sure all arrays are stored as jnp.ndarray
437
438
  # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
438
439
  # https://github.com/google/flax/issues/1261
439
- state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)
440
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
440
441
 
441
442
  # flatten dicts
442
443
  state = flatten_dict(state)
@@ -0,0 +1,17 @@
1
+ from dataclasses import dataclass
2
+
3
+ from ..utils import BaseOutput
4
+
5
+
6
@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Container for the result of the AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            The `Encoder` output, expressed as the mean and logvar of a `DiagonalGaussianDistribution`. Latents
            can be sampled from this distribution.
    """

    latent_dist: "DiagonalGaussianDistribution"  # noqa: F821