diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2222 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +1 -12
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +262 -2
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1795 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +319 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +1 -4
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +19 -16
  210. diffusers/utils/loading_utils.py +76 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -253,7 +253,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
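This `get_processor(return_deprecated_lora=True)` → `get_processor()` change repeats in the ConsistencyDecoderVAE and ControlNetModel hunks further down; it follows the LoRA loader rework visible in the file list (new `loaders/lora_base.py` and `loaders/lora_pipeline.py`, removal of `loaders/lora.py`). A minimal sketch of the caller-facing behaviour, assuming the default constructor of `AutoencoderKLTemporalDecoder` builds a valid toy model:

from diffusers import AutoencoderKLTemporalDecoder

# Small randomly initialized model; no pretrained weights are needed to inspect processors.
vae = AutoencoderKLTemporalDecoder()

# `attn_processors` wraps the recursive helper shown in the hunk above and returns a
# {"<module path>.processor": AttentionProcessor} mapping. In 0.30.0 each entry comes
# straight from `get_processor()`, with no deprecated-LoRA rewrapping step.
for name, processor in list(vae.attn_processors.items())[:3]:
    print(name, type(processor).__name__)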
@@ -0,0 +1,464 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.utils import weight_norm
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ...utils.torch_utils import randn_tensor
+from ..modeling_utils import ModelMixin
+
+
+class Snake1d(nn.Module):
+    """
+    A 1-dimensional Snake activation function module.
+    """
+
+    def __init__(self, hidden_dim, logscale=True):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
+        self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
+
+        self.alpha.requires_grad = True
+        self.beta.requires_grad = True
+        self.logscale = logscale
+
+    def forward(self, hidden_states):
+        shape = hidden_states.shape
+
+        alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
+        beta = self.beta if not self.logscale else torch.exp(self.beta)
+
+        hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
+        hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
+        hidden_states = hidden_states.reshape(shape)
+        return hidden_states
+
+
+class OobleckResidualUnit(nn.Module):
+    """
+    A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
+    """
+
+    def __init__(self, dimension: int = 16, dilation: int = 1):
+        super().__init__()
+        pad = ((7 - 1) * dilation) // 2
+
+        self.snake1 = Snake1d(dimension)
+        self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
+        self.snake2 = Snake1d(dimension)
+        self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
+
+    def forward(self, hidden_state):
+        """
+        Forward pass through the residual unit.
+
+        Args:
+            hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
+                Input tensor .
+
+        Returns:
+            output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`)
+                Input tensor after passing through the residual unit.
+        """
+        output_tensor = hidden_state
+        output_tensor = self.conv1(self.snake1(output_tensor))
+        output_tensor = self.conv2(self.snake2(output_tensor))
+
+        padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
+        if padding > 0:
+            hidden_state = hidden_state[..., padding:-padding]
+        output_tensor = hidden_state + output_tensor
+        return output_tensor
+
+
+class OobleckEncoderBlock(nn.Module):
+    """Encoder block used in Oobleck encoder."""
+
+    def __init__(self, input_dim, output_dim, stride: int = 1):
+        super().__init__()
+
+        self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
+        self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
+        self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
+        self.snake1 = Snake1d(input_dim)
+        self.conv1 = weight_norm(
+            nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
+        )
+
+    def forward(self, hidden_state):
+        hidden_state = self.res_unit1(hidden_state)
+        hidden_state = self.res_unit2(hidden_state)
+        hidden_state = self.snake1(self.res_unit3(hidden_state))
+        hidden_state = self.conv1(hidden_state)
+
+        return hidden_state
+
+
+class OobleckDecoderBlock(nn.Module):
+    """Decoder block used in Oobleck decoder."""
+
+    def __init__(self, input_dim, output_dim, stride: int = 1):
+        super().__init__()
+
+        self.snake1 = Snake1d(input_dim)
+        self.conv_t1 = weight_norm(
+            nn.ConvTranspose1d(
+                input_dim,
+                output_dim,
+                kernel_size=2 * stride,
+                stride=stride,
+                padding=math.ceil(stride / 2),
+            )
+        )
+        self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
+        self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
+        self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
+
+    def forward(self, hidden_state):
+        hidden_state = self.snake1(hidden_state)
+        hidden_state = self.conv_t1(hidden_state)
+        hidden_state = self.res_unit1(hidden_state)
+        hidden_state = self.res_unit2(hidden_state)
+        hidden_state = self.res_unit3(hidden_state)
+
+        return hidden_state
+
+
+class OobleckDiagonalGaussianDistribution(object):
+    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+        self.parameters = parameters
+        self.mean, self.scale = parameters.chunk(2, dim=1)
+        self.std = nn.functional.softplus(self.scale) + 1e-4
+        self.var = self.std * self.std
+        self.logvar = torch.log(self.var)
+        self.deterministic = deterministic
+
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = randn_tensor(
+            self.mean.shape,
+            generator=generator,
+            device=self.parameters.device,
+            dtype=self.parameters.dtype,
+        )
+        x = self.mean + self.std * sample
+        return x
+
+    def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
+            else:
+                normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
+                var_ratio = self.var / other.var
+                logvar_diff = self.logvar - other.logvar
+
+                kl = normalized_diff + var_ratio + logvar_diff - 1
+
+                kl = kl.sum(1).mean()
+            return kl
+
+    def mode(self) -> torch.Tensor:
+        return self.mean
+
+
+@dataclass
+class AutoencoderOobleckOutput(BaseOutput):
+    """
+    Output of AutoencoderOobleck encoding method.
+
+    Args:
+        latent_dist (`OobleckDiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and standard deviation of
+            `OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents
+            from the distribution.
+    """
+
+    latent_dist: "OobleckDiagonalGaussianDistribution"  # noqa: F821
+
+
+@dataclass
+class OobleckDecoderOutput(BaseOutput):
+    r"""
+    Output of decoding method.
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
+            The decoded output sample from the last layer of the model.
+    """
+
+    sample: torch.Tensor
+
+
+class OobleckEncoder(nn.Module):
+    """Oobleck Encoder"""
+
+    def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
+        super().__init__()
+
+        strides = downsampling_ratios
+        channel_multiples = [1] + channel_multiples
+
+        # Create first convolution
+        self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
+
+        self.block = []
+        # Create EncoderBlocks that double channels as they downsample by `stride`
+        for stride_index, stride in enumerate(strides):
+            self.block += [
+                OobleckEncoderBlock(
+                    input_dim=encoder_hidden_size * channel_multiples[stride_index],
+                    output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
+                    stride=stride,
+                )
+            ]
+
+        self.block = nn.ModuleList(self.block)
+        d_model = encoder_hidden_size * channel_multiples[-1]
+        self.snake1 = Snake1d(d_model)
+        self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
+
+    def forward(self, hidden_state):
+        hidden_state = self.conv1(hidden_state)
+
+        for module in self.block:
+            hidden_state = module(hidden_state)
+
+        hidden_state = self.snake1(hidden_state)
+        hidden_state = self.conv2(hidden_state)
+
+        return hidden_state
+
+
+class OobleckDecoder(nn.Module):
+    """Oobleck Decoder"""
+
+    def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
+        super().__init__()
+
+        strides = upsampling_ratios
+        channel_multiples = [1] + channel_multiples
+
+        # Add first conv layer
+        self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
+
+        # Add upsampling + MRF blocks
+        block = []
+        for stride_index, stride in enumerate(strides):
+            block += [
+                OobleckDecoderBlock(
+                    input_dim=channels * channel_multiples[len(strides) - stride_index],
+                    output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
+                    stride=stride,
+                )
+            ]
+
+        self.block = nn.ModuleList(block)
+        output_dim = channels
+        self.snake1 = Snake1d(output_dim)
+        self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
+
+    def forward(self, hidden_state):
+        hidden_state = self.conv1(hidden_state)
+
+        for layer in self.block:
+            hidden_state = layer(hidden_state)
+
+        hidden_state = self.snake1(hidden_state)
+        hidden_state = self.conv2(hidden_state)
+
+        return hidden_state
+
+
+class AutoencoderOobleck(ModelMixin, ConfigMixin):
+    r"""
+    An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
+    introduced in Stable Audio.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        encoder_hidden_size (`int`, *optional*, defaults to 128):
+            Intermediate representation dimension for the encoder.
+        downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
+            Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
+        channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
+            Multiples used to determine the hidden sizes of the hidden layers.
+        decoder_channels (`int`, *optional*, defaults to 128):
+            Intermediate representation dimension for the decoder.
+        decoder_input_channels (`int`, *optional*, defaults to 64):
+            Input dimension for the decoder. Corresponds to the latent dimension.
+        audio_channels (`int`, *optional*, defaults to 2):
+            Number of channels in the audio data. Either 1 for mono or 2 for stereo.
+        sampling_rate (`int`, *optional*, defaults to 44100):
+            The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
+    """
+
+    _supports_gradient_checkpointing = False
+
+    @register_to_config
+    def __init__(
+        self,
+        encoder_hidden_size=128,
+        downsampling_ratios=[2, 4, 4, 8, 8],
+        channel_multiples=[1, 2, 4, 8, 16],
+        decoder_channels=128,
+        decoder_input_channels=64,
+        audio_channels=2,
+        sampling_rate=44100,
+    ):
+        super().__init__()
+
+        self.encoder_hidden_size = encoder_hidden_size
+        self.downsampling_ratios = downsampling_ratios
+        self.decoder_channels = decoder_channels
+        self.upsampling_ratios = downsampling_ratios[::-1]
+        self.hop_length = int(np.prod(downsampling_ratios))
+        self.sampling_rate = sampling_rate
+
+        self.encoder = OobleckEncoder(
+            encoder_hidden_size=encoder_hidden_size,
+            audio_channels=audio_channels,
+            downsampling_ratios=downsampling_ratios,
+            channel_multiples=channel_multiples,
+        )
+
+        self.decoder = OobleckDecoder(
+            channels=decoder_channels,
+            input_channels=decoder_input_channels,
+            audio_channels=audio_channels,
+            upsampling_ratios=self.upsampling_ratios,
+            channel_multiples=channel_multiples,
+        )
+
+        self.use_slicing = False
+
+    def enable_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    @apply_forward_hook
+    def encode(
+        self, x: torch.Tensor, return_dict: bool = True
+    ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
+        """
+        Encode a batch of images into latents.
+
+        Args:
+            x (`torch.Tensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+        Returns:
+            The latent representations of the encoded images. If `return_dict` is True, a
+            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+        """
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self.encoder(x)
+
+        posterior = OobleckDiagonalGaussianDistribution(h)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderOobleckOutput(latent_dist=posterior)
+
+    def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
+        dec = self.decoder(z)
+
+        if not return_dict:
+            return (dec,)
+
+        return OobleckDecoderOutput(sample=dec)
+
+    @apply_forward_hook
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
+        """
+        Decode a batch of images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.OobleckDecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple`
+                is returned.
+
+        """
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self._decode(z).sample
+
+        if not return_dict:
+            return (decoded,)
+
+        return OobleckDecoderOutput(sample=decoded)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        sample_posterior: bool = False,
+        return_dict: bool = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> Union[OobleckDecoderOutput, torch.Tensor]:
+        r"""
+        Args:
+            sample (`torch.Tensor`): Input sample.
+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the posterior.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
+        """
+        x = sample
+        posterior = self.encode(x).latent_dist
+        if sample_posterior:
+            z = posterior.sample(generator=generator)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return OobleckDecoderOutput(sample=dec)
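`autoencoder_oobleck.py` adds the waveform VAE used by the new Stable Audio pipeline (`diffusers/pipelines/stable_audio/` in the file list). A minimal usage sketch inferred from the constructor defaults and the `encode`/`decode` methods in this hunk; the top-level import assumes the class is re-exported through the updated `diffusers/__init__.py`, and the shapes are approximate:

import torch
from diffusers import AutoencoderOobleck

# Defaults from the diff: stereo audio at 44.1 kHz, hop_length = 2 * 4 * 4 * 8 * 8 = 2048.
vae = AutoencoderOobleck()

# One second of stereo audio: (batch_size, audio_channels, num_samples)
waveform = torch.randn(1, 2, 44100)

posterior = vae.encode(waveform).latent_dist   # OobleckDiagonalGaussianDistribution
latents = posterior.sample()                   # roughly (1, 64, num_samples // 2048)
reconstruction = vae.decode(latents).sample    # roughly (1, 2, 44100) again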
@@ -111,6 +111,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
         latent_shift: float = 0.5,
         force_upcast: bool = False,
         scaling_factor: float = 1.0,
+        shift_factor: float = 0.0,
     ):
         super().__init__()
 
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -166,12 +166,12 @@ class VQModel(ModelMixin, ConfigMixin):
         Args:
             sample (`torch.Tensor`): Input sample.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
+                Whether or not to return a [`models.autoencoders.vq_model.VQEncoderOutput`] instead of a plain tuple.
 
         Returns:
-            [`~models.vq_model.VQEncoderOutput`] or `tuple`:
-                If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
-                is returned.
+            [`~models.autoencoders.vq_model.VQEncoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoders.vq_model.VQEncoderOutput`] is returned, otherwise a
+                plain `tuple` is returned.
 
         """
 
         h = self.encode(sample).latents
@@ -54,7 +54,7 @@ class ControlNetOutput(BaseOutput):
            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_down_block_re_sample (`torch.Tensor`):
-            The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
+            The activation of the middle block (the lowest sample resolution). Each tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """
@@ -530,7 +530,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
 
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -830,7 +830,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 sample = self.mid_block(sample, emb)
 
         # 5. Control net blocks
-
         controlnet_down_block_res_samples = ()
 
         for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):