diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings.py
CHANGED
@@ -84,15 +84,106 @@ def get_3d_sincos_pos_embed(
     temporal_size: int,
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+) -> torch.Tensor:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of postional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `torch.Tensor`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    if output_type == "np":
+        return _get_3d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            spatial_size=spatial_size,
+            temporal_size=temporal_size,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+        )
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
+
+    # 2. Temporal
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[None, :, :]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, None, :]
+    pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+        spatial_size[0] * spatial_size[1], dim=1
+    )  # [T, H*W, D // 4]
+
+    pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
+    return pos_embed
+
+
+def _get_3d_sincos_pos_embed_np(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
+    Creates 3D sinusoidal positional embeddings.
+
     Args:
         embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
         spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
         temporal_size (`int`):
+            The temporal dimension of postional embeddings (number of frames).
         spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
         temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
     """
+    deprecation_message = (
+        "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+        " `from_numpy` is no longer required."
+        " Pass `output_type='pt' to use the new version now."
+    )
+    deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
     if isinstance(spatial_size, int):
@@ -126,11 +217,164 @@ def get_3d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed(
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+):
+    """
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`torch.Tensor`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            grid_size=grid_size,
+            cls_token=cls_token,
+            extra_tokens=extra_tokens,
+            interpolation_scale=interpolation_scale,
+            base_size=base_size,
+        )
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
+    if cls_token and extra_tokens > 0:
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_from_grid_np(
+            embed_dim=embed_dim,
+            grid=grid,
+        )
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type)  # (H*W, D/2)
+
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+    """
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+    return emb
+
+
+def get_2d_sincos_pos_embed_np(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
-
-
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
@@ -141,27 +385,44 @@ def get_2d_sincos_pos_embed(
     grid = np.stack(grid, axis=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed =
+    pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
     if cls_token and extra_tokens > 0:
         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
     return pos_embed
 
 
-def
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
     # use half of dimensions to encode grid_h
-    emb_h =
-    emb_w =
+    emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1])  # (H*W, D/2)
 
     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
     return emb
 
 
-def
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
     """
-
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -181,7 +442,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """
+    """
+    2D Image to Patch Embedding with support for SD3 cropping.
+
+    Args:
+        height (`int`, defaults to `224`): The height of the image.
+        width (`int`, defaults to `224`): The width of the image.
+        patch_size (`int`, defaults to `16`): The size of the patches.
+        in_channels (`int`, defaults to `3`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+        flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+        interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+        pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+        pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+    """
 
     def __init__(
         self,
@@ -227,10 +503,14 @@ class PatchEmbed(nn.Module):
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim,
+                embed_dim,
+                grid_size,
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+                output_type="pt",
             )
             persistent = True if pos_embed_max_size else False
-            self.register_buffer("pos_embed",
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
@@ -262,7 +542,6 @@ class PatchEmbed(nn.Module):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
         latent = self.proj(latent)
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
@@ -280,8 +559,10 @@ class PatchEmbed(nn.Module):
                 grid_size=(height, width),
                 base_size=self.base_size,
                 interpolation_scale=self.interpolation_scale,
+                device=latent.device,
+                output_type="pt",
             )
-            pos_embed =
+            pos_embed = pos_embed.float().unsqueeze(0)
         else:
             pos_embed = self.pos_embed
 
@@ -289,7 +570,15 @@ class PatchEmbed(nn.Module):
 
 
 class LuminaPatchEmbed(nn.Module):
-    """
+    """
+    2D Image to Patch Embedding with support for Lumina-T2X
+
+    Args:
+        patch_size (`int`, defaults to `2`): The size of the patches.
+        in_channels (`int`, defaults to `4`): The number of input channels.
+        embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+        bias (`bool`, defaults to `True`): Whether or not to use bias.
+    """
 
     def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
         super().__init__()
@@ -338,6 +627,7 @@ class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
@@ -355,6 +645,7 @@ class CogVideoXPatchEmbed(nn.Module):
         super().__init__()
 
         self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width
@@ -366,9 +657,15 @@ class CogVideoXPatchEmbed(nn.Module):
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings
 
-
-
-
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
+
         self.text_proj = nn.Linear(text_embed_dim, embed_dim)
 
         if use_positional_embeddings or use_learned_positional_embeddings:
@@ -376,7 +673,9 @@ class CogVideoXPatchEmbed(nn.Module):
             pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
             self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
 
-    def _get_positional_embeddings(
+    def _get_positional_embeddings(
+        self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+    ) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
         post_patch_width = sample_width // self.patch_size
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -388,9 +687,11 @@ class CogVideoXPatchEmbed(nn.Module):
             post_time_compression_frames,
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
+            device=device,
+            output_type="pt",
         )
-        pos_embedding =
-        joint_pos_embedding =
+        pos_embedding = pos_embedding.flatten(0, 1)
+        joint_pos_embedding = pos_embedding.new_zeros(
             1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
         )
         joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -407,12 +708,24 @@ class CogVideoXPatchEmbed(nn.Module):
         """
         text_embeds = self.text_proj(text_embeds)
 
-
-
-
-
-
-
+        batch_size, num_frames, channels, height, width = image_embeds.shape
+
+        if self.patch_size_t is None:
+            image_embeds = image_embeds.reshape(-1, channels, height, width)
+            image_embeds = self.proj(image_embeds)
+            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        else:
+            p = self.patch_size
+            p_t = self.patch_size_t
+
+            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+            image_embeds = image_embeds.reshape(
+                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+            )
+            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+            image_embeds = self.proj(image_embeds)
 
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
@@ -432,11 +745,13 @@ class CogVideoXPatchEmbed(nn.Module):
                 or self.sample_width != width
                 or self.sample_frames != pre_time_compression_frames
             ):
-                pos_embedding = self._get_positional_embeddings(
-
+                pos_embedding = self._get_positional_embeddings(
+                    height, width, pre_time_compression_frames, device=embeds.device
+                )
             else:
                 pos_embedding = self.pos_embedding
 
+            pos_embedding = pos_embedding.to(dtype=embeds.dtype)
             embeds = embeds + pos_embedding
 
         return embeds
@@ -463,9 +778,11 @@ class CogView3PlusPatchEmbed(nn.Module):
         # Linear projection for text embeddings
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)
 
-        pos_embed = get_2d_sincos_pos_embed(
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed",
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape
@@ -497,7 +814,15 @@ class CogView3PlusPatchEmbed(nn.Module):
 
 
 def get_3d_rotary_pos_embed(
-    embed_dim,
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.
@@ -513,17 +838,36 @@ def get_3d_rotary_pos_embed(
         The size of the temporal dimension.
     theta (`float`):
         Scaling factor for frequency computation.
+    grid_type (`str`):
+        Whether to use "linspace" or "slice" to compute grids.
 
     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
     if use_real is not True:
         raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
-
-
-
-
-
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = torch.linspace(
+            start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+        )
+        grid_w = torch.linspace(
+            start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+        )
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+        grid_t = torch.linspace(
+            0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+        )
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+        grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")
 
     # Compute dimensions for each axis
     dim_t = embed_dim // 4
@@ -531,10 +875,10 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
-    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
 
     # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
     def combine_time_height_width(freqs_t, freqs_h, freqs_w):
@@ -559,12 +903,111 @@ def get_3d_rotary_pos_embed(
     t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
     h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
     w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
     cos = combine_time_height_width(t_cos, h_cos, w_cos)
     sin = combine_time_height_width(t_sin, h_sin, w_sin)
     return cos, sin
 
 
-def
+def get_3d_rotary_pos_embed_allegro(
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+    theta: int = 10000,
+    device: Optional[torch.device] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    # TODO(aryan): docs
+    start, stop = crops_coords
+    grid_size_h, grid_size_w = grid_size
+    interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
+    grid_t = torch.linspace(
+        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+    )
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+    )
+
+    # Compute dimensions for each axis
+    dim_t = embed_dim // 3
+    dim_h = embed_dim // 3
+    dim_w = embed_dim // 3
+
+    # Temporal frequencies
+    freqs_t = get_1d_rotary_pos_embed(
+        dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
+    )
+    # Spatial frequencies for height and width
+    freqs_h = get_1d_rotary_pos_embed(
+        dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
+    )
+    freqs_w = get_1d_rotary_pos_embed(
+        dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
+    )
+
+    return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
+
+
+def get_2d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
+):
+    """
+    RoPE for image tokens with 2d structure.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size
+    crops_coords (`Tuple[int]`)
+        The top-left and bottom-right coordinates of the crop.
+    grid_size (`Tuple[int]`):
+        The grid size of the positional embedding.
+    use_real (`bool`):
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+    device: (`torch.device`, **optional**):
+        The device used to create tensors.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return _get_2d_rotary_pos_embed_np(
+            embed_dim=embed_dim,
+            crops_coords=crops_coords,
+            grid_size=grid_size,
+            use_real=use_real,
+        )
+    start, stop = crops_coords
+    # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+    grid = torch.stack(grid, dim=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
 
@@ -593,6 +1036,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
 
 
 def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    """
+    Get 2D RoPE from grid.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size, corresponding to hidden_size_head.
+    grid (`np.ndarray`):
+        The grid of the positional embedding.
+    use_real (`bool`):
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
     assert embed_dim % 4 == 0
 
     # use half of dimensions to encode grid_h
@@ -613,6 +1070,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
 
 
 def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    """
+    Get 2D RoPE from grid.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size, corresponding to hidden_size_head.
+    grid (`np.ndarray`):
+        The grid of the positional embedding.
+    linear_factor (`float`):
+        The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
+        layer.
+    ntk_factor (`float`):
+        The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+    """
     assert embed_dim % 4 == 0
 
     emb_h = get_1d_rotary_pos_embed(
@@ -684,7 +1158,7 @@ def get_1d_rotary_pos_embed(
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        # stable audio
+        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
@@ -743,6 +1217,24 @@ def apply_rotary_emb(
     return x_out.type_as(x)
 
 
+def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
+    # TODO(aryan): rewrite
+    def apply_1d_rope(tokens, pos, cos, sin):
+        cos = F.embedding(pos, cos)[:, None, :, :]
+        sin = F.embedding(pos, sin)[:, None, :, :]
+        x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
+        tokens_rotated = torch.cat((-x2, x1), dim=-1)
+        return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
+
+    (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
+    t, h, w = x.chunk(3, dim=-1)
+    t = apply_1d_rope(t, positions[0], t_cos, t_sin)
+    h = apply_1d_rope(h, positions[1], h_cos, h_sin)
+    w = apply_1d_rope(w, positions[2], w_cos, w_sin)
+    x = torch.cat([t, h, w], dim=-1)
+    return x
+
+
 class FluxPosEmbed(nn.Module):
     # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
     def __init__(self, theta: int, axes_dim: List[int]):
@@ -759,7 +1251,12 @@ class FluxPosEmbed(nn.Module):
         freqs_dtype = torch.float32 if is_mps else torch.float64
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
             )
             cos_out.append(cos)
             sin_out.append(sin)
@@ -1038,7 +1535,7 @@ class ImageProjection(nn.Module):
         batch_size = image_embeds.shape[0]
 
         # image
-        image_embeds = self.image_embeds(image_embeds)
+        image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
         image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
         image_embeds = self.norm(image_embeds)
         return image_embeds
@@ -1302,6 +1799,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
         return conditioning
 
 
+class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        num_attention_heads: int = 8,
+    ) -> None:
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
+        self.pooler = MochiAttentionPool(
+            num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
+        )
+        self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
+
+    def forward(
+        self,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ):
+        time_proj = self.time_proj(timestep)
+        time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
+
+        pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
+        caption_proj = self.caption_proj(encoder_hidden_states)
+
+        conditioning = time_emb + pooled_projections
+        return conditioning, caption_proj
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
@@ -1430,6 +1962,88 @@ class AttentionPooling(nn.Module):
         return a[:, 0, :]  # cls_token
 
 
+class MochiAttentionPool(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        embed_dim: int,
+        output_dim: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self.output_dim = output_dim or embed_dim
+        self.num_attention_heads = num_attention_heads
+
+        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
+        self.to_q = nn.Linear(embed_dim, embed_dim)
+        self.to_out = nn.Linear(embed_dim, self.output_dim)
+
+    @staticmethod
+    def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
+        """
+        Pool tokens in x using mask.
+
+        NOTE: We assume x does not require gradients.
+
+        Args:
+            x: (B, L, D) tensor of tokens.
+            mask: (B, L) boolean tensor indicating which tokens are not padding.
+
+        Returns:
+            pooled: (B, D) tensor of pooled tokens.
+        """
+        assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
+        assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
+        mask = mask[:, :, None].to(dtype=x.dtype)
+        mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = (x * mask).sum(dim=1, keepdim=keepdim)
+        return pooled
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (`torch.Tensor`):
+                Tensor of shape `(B, S, D)` of input tokens.
+            mask (`torch.Tensor`):
+                Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
+
+        Returns:
+            `torch.Tensor`:
+                `(B, D)` tensor of pooled tokens.
+        """
+        D = x.size(2)
+
+        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
+        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).
+
+        # Average non-padding token features. These will be used as the query.
+        x_pool = self.pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
+
+        # Concat pooled features to input sequence.
+        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)
+
+        # Compute queries, keys, values. Only the mean token is used to create a query.
+        kv = self.to_kv(x)  # (B, L+1, 2 * D)
+        q = self.to_q(x[:, 0])  # (B, D)
+
+        # Extract heads.
+        head_dim = D // self.num_attention_heads
+        kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
+        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
+        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
+        q = q.unflatten(1, (self.num_attention_heads, head_dim))  # (B, H, head_dim)
+        q = q.unsqueeze(2)  # (B, H, 1, head_dim)
+
+        # Compute attention.
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)  # (B, H, 1, head_dim)
+
+        # Concatenate heads and run output.
+        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
+        x = self.to_out(x)
+        return x
+
+
 def get_fourier_embeds_from_boundingbox(embed_dim, box):
     """
     Args:
@@ -1782,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
         return out
 
 
+class IPAdapterTimeImageProjectionBlock(nn.Module):
+    """Block for IPAdapterTimeImageProjection.
+
+    Args:
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+    """
+
+    def __init__(
+        self,
+        hidden_dim: int = 1280,
+        dim_head: int = 64,
+        heads: int = 20,
+        ffn_ratio: int = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(hidden_dim)
+        self.ln1 = nn.LayerNorm(hidden_dim)
+        self.attn = Attention(
+            query_dim=hidden_dim,
+            cross_attention_dim=hidden_dim,
+            dim_head=dim_head,
+            heads=heads,
+            bias=False,
+            out_bias=False,
+        )
+        self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
+
+        # AdaLayerNorm
+        self.adaln_silu = nn.SiLU()
+        self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+        self.adaln_norm = nn.LayerNorm(hidden_dim)
+
+        # Set attention scale and fuse KV
+        self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+        self.attn.fuse_projections()
+        self.attn.to_k = None
+        self.attn.to_v = None
+
+    def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            latents (`torch.Tensor`):
+                Latent features.
+            timestep_emb (`torch.Tensor`):
+                Timestep embedding.
+
+        Returns:
+            `torch.Tensor`: Output latent features.
+        """
+
+        # Shift and scale for AdaLayerNorm
+        emb = self.adaln_proj(self.adaln_silu(timestep_emb))
+        shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
+
+        # Fused Attention
+        residual = latents
+        x = self.ln0(x)
+        latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+        batch_size = latents.shape[0]
+
+        query = self.attn.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.attn.heads
+
+        query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+        weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        latents = weight @ value
+
+        latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+        latents = self.attn.to_out[0](latents)
+        latents = self.attn.to_out[1](latents)
+        latents = latents + residual
+
+        ## FeedForward
+        residual = latents
+        latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return self.ff(latents) + residual
+
+
+# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+class IPAdapterTimeImageProjection(nn.Module):
+    """Resampler of SD3 IP-Adapter with timestep embedding.
+
+    Args:
+        embed_dim (`int`, defaults to 1152):
+            The feature dimension.
+        output_dim (`int`, defaults to 2432):
+            The number of output channels.
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        depth (`int`, defaults to 4):
+            The number of blocks.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        num_queries (`int`, defaults to 64):
+            The number of queries.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+        timestep_in_dim (`int`, defaults to 320):
+            The number of input channels for timestep embedding.
+        timestep_flip_sin_to_cos (`bool`, defaults to True):
+            Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
+        timestep_freq_shift (`int`, defaults to 0):
+            Controls the timestep delta between frequencies between dimensions.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 1152,
+        output_dim: int = 2432,
+        hidden_dim: int = 1280,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 20,
+        num_queries: int = 64,
+        ffn_ratio: int = 4,
+        timestep_in_dim: int = 320,
+        timestep_flip_sin_to_cos: bool = True,
+        timestep_freq_shift: int = 0,
+    ) -> None:
+        super().__init__()
+        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
+        self.proj_in = nn.Linear(embed_dim, hidden_dim)
+        self.proj_out = nn.Linear(hidden_dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+        self.layers = nn.ModuleList(
+            [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+        self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
+
+    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            timestep (`torch.Tensor`):
+                Timestep in denoising process.
+        Returns:
+            `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
+        """
+        timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
+        timestep_emb = self.time_embedding(timestep_emb)
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.proj_in(x)
+        x = x + timestep_emb[:, None]
+
+        for block in self.layers:
+            latents = block(x, latents, timestep_emb)
+
+        latents = self.proj_out(latents)
+        latents = self.norm_out(latents)
+
+        return latents, timestep_emb
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
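
The API change running through the sincos hunks above is that the NumPy code paths (`output_type="np"`) are deprecated for removal in 0.33.0 in favor of torch-native versions that accept a `device` argument. A minimal usage sketch, assuming direct imports from `diffusers.models.embeddings`; the embedding sizes and grid shapes below are illustrative values chosen for the example, not taken from the diff:

# Illustrative sketch only: argument values are assumptions, not from the package diff.
import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2D grid: returns a torch.Tensor created directly on `device`; no torch.from_numpy() round trip.
pos_2d = get_2d_sincos_pos_embed(
    embed_dim=768, grid_size=16, base_size=16, interpolation_scale=1.0, device=device, output_type="pt"
)

# 3D (spatial + temporal) grid, e.g. for video patch embeddings; embed_dim must be divisible by 16.
pos_3d = get_3d_sincos_pos_embed(
    embed_dim=64,
    spatial_size=(30, 45),
    temporal_size=13,
    device=device,
    output_type="pt",
)  # shape: [13, 30 * 45, 64]

Calling either function with the default `output_type="np"` still works in 0.32.0 but now routes through the `_np` helpers shown above and emits the deprecation warning.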