diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings.py

@@ -84,15 +84,106 @@ def get_3d_sincos_pos_embed(
  temporal_size: int,
  spatial_interpolation_scale: float = 1.0,
  temporal_interpolation_scale: float = 1.0,
+ device: Optional[torch.device] = None,
+ output_type: str = "np",
+ ) -> torch.Tensor:
+ r"""
+ Creates 3D sinusoidal positional embeddings.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension of inputs. It must be divisible by 16.
+ spatial_size (`int` or `Tuple[int, int]`):
+ The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+ spatial dimensions (height and width).
+ temporal_size (`int`):
+ The temporal dimension of postional embeddings (number of frames).
+ spatial_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for spatial grid interpolation.
+ temporal_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for temporal grid interpolation.
+
+ Returns:
+ `torch.Tensor`:
+ The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+ embed_dim]`.
+ """
+ if output_type == "np":
+ return _get_3d_sincos_pos_embed_np(
+ embed_dim=embed_dim,
+ spatial_size=spatial_size,
+ temporal_size=temporal_size,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ )
+ if embed_dim % 4 != 0:
+ raise ValueError("`embed_dim` must be divisible by 4")
+ if isinstance(spatial_size, int):
+ spatial_size = (spatial_size, spatial_size)
+
+ embed_dim_spatial = 3 * embed_dim // 4
+ embed_dim_temporal = embed_dim // 4
+
+ # 1. Spatial
+ grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+ grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
+ grid = torch.stack(grid, dim=0)
+
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
+
+ # 2. Temporal
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
+
+ # 3. Concat
+ pos_embed_spatial = pos_embed_spatial[None, :, :]
+ pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
+
+ pos_embed_temporal = pos_embed_temporal[:, None, :]
+ pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+ spatial_size[0] * spatial_size[1], dim=1
+ ) # [T, H*W, D // 4]
+
+ pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
+ return pos_embed
+
+
+ def _get_3d_sincos_pos_embed_np(
+ embed_dim: int,
+ spatial_size: Union[int, Tuple[int, int]],
+ temporal_size: int,
+ spatial_interpolation_scale: float = 1.0,
+ temporal_interpolation_scale: float = 1.0,
  ) -> np.ndarray:
  r"""
+ Creates 3D sinusoidal positional embeddings.
+
  Args:
  embed_dim (`int`):
+ The embedding dimension of inputs. It must be divisible by 16.
  spatial_size (`int` or `Tuple[int, int]`):
+ The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+ spatial dimensions (height and width).
  temporal_size (`int`):
+ The temporal dimension of postional embeddings (number of frames).
  spatial_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for spatial grid interpolation.
  temporal_interpolation_scale (`float`, defaults to 1.0):
+ Scale factor for temporal grid interpolation.
+
+ Returns:
+ `np.ndarray`:
+ The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+ embed_dim]`.
  """
+ deprecation_message = (
+ "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
  if embed_dim % 4 != 0:
  raise ValueError("`embed_dim` must be divisible by 4")
  if isinstance(spatial_size, int):
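For callers that previously wrapped the NumPy result in torch.from_numpy, a minimal migration sketch of the new torch path (the sizes below are arbitrary illustration values, not taken from this diff):

import torch
from diffusers.models.embeddings import get_3d_sincos_pos_embed

# Ask for a torch tensor directly; output_type="np" keeps working but now emits a deprecation warning.
pos = get_3d_sincos_pos_embed(
    embed_dim=64,
    spatial_size=16,
    temporal_size=8,
    device=torch.device("cpu"),
    output_type="pt",
)
print(pos.shape)  # torch.Size([8, 256, 64]) -> [temporal_size, height * width, embed_dim]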
@@ -126,11 +217,164 @@ def get_3d_sincos_pos_embed(


  def get_2d_sincos_pos_embed(
+ embed_dim,
+ grid_size,
+ cls_token=False,
+ extra_tokens=0,
+ interpolation_scale=1.0,
+ base_size=16,
+ device: Optional[torch.device] = None,
+ output_type: str = "np",
+ ):
+ """
+ Creates 2D sinusoidal positional embeddings.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension.
+ grid_size (`int`):
+ The size of the grid height and width.
+ cls_token (`bool`, defaults to `False`):
+ Whether or not to add a classification token.
+ extra_tokens (`int`, defaults to `0`):
+ The number of extra tokens to add.
+ interpolation_scale (`float`, defaults to `1.0`):
+ The scale of the interpolation.
+
+ Returns:
+ pos_embed (`torch.Tensor`):
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+ embed_dim]` if using cls_token
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return get_2d_sincos_pos_embed_np(
+ embed_dim=embed_dim,
+ grid_size=grid_size,
+ cls_token=cls_token,
+ extra_tokens=extra_tokens,
+ interpolation_scale=interpolation_scale,
+ base_size=base_size,
+ )
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = (
+ torch.arange(grid_size[0], device=device, dtype=torch.float32)
+ / (grid_size[0] / base_size)
+ / interpolation_scale
+ )
+ grid_w = (
+ torch.arange(grid_size[1], device=device, dtype=torch.float32)
+ / (grid_size[1] / base_size)
+ / interpolation_scale
+ )
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
+ grid = torch.stack(grid, dim=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
+ if cls_token and extra_tokens > 0:
+ pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
+ return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
+ r"""
+ This function generates 2D sinusoidal positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension.
+ grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
+
+ Returns:
+ `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return get_2d_sincos_pos_embed_from_grid_np(
+ embed_dim=embed_dim,
+ grid=grid,
+ )
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)
+
+ emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
+ return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+ """
+ This function generates 1D positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension `D`
+ pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+ Returns:
+ `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.outer(pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb
+
+
+ def get_2d_sincos_pos_embed_np(
  embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
  ):
  """
- grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
- [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ Creates 2D sinusoidal positional embeddings.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension.
+ grid_size (`int`):
+ The size of the grid height and width.
+ cls_token (`bool`, defaults to `False`):
+ Whether or not to add a classification token.
+ extra_tokens (`int`, defaults to `0`):
+ The number of extra tokens to add.
+ interpolation_scale (`float`, defaults to `1.0`):
+ The scale of the interpolation.
+
+ Returns:
+ pos_embed (`np.ndarray`):
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+ embed_dim]` if using cls_token
  """
  if isinstance(grid_size, int):
  grid_size = (grid_size, grid_size)
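The 2D helper gains the same torch switch; a short sketch with made-up sizes:

import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(
    embed_dim=96,
    grid_size=32,
    base_size=32,
    device=torch.device("cpu"),
    output_type="pt",
)
print(pos.shape)  # torch.Size([1024, 96]) -> [grid_size * grid_size, embed_dim]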
@@ -141,27 +385,44 @@ def get_2d_sincos_pos_embed(
  grid = np.stack(grid, axis=0)

  grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
  if cls_token and extra_tokens > 0:
  pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
  return pos_embed


- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
+ r"""
+ This function generates 2D sinusoidal positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension.
+ grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+ Returns:
+ `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+ """
  if embed_dim % 2 != 0:
  raise ValueError("embed_dim must be divisible by 2")

  # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+ emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)

  emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
  return emb


- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
  """
- embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+ This function generates 1D positional embeddings from a grid.
+
+ Args:
+ embed_dim (`int`): The embedding dimension `D`
+ pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+
+ Returns:
+ `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
  """
  if embed_dim % 2 != 0:
  raise ValueError("embed_dim must be divisible by 2")
@@ -181,7 +442,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):


  class PatchEmbed(nn.Module):
- """2D Image to Patch Embedding with support for SD3 cropping."""
+ """
+ 2D Image to Patch Embedding with support for SD3 cropping.
+
+ Args:
+ height (`int`, defaults to `224`): The height of the image.
+ width (`int`, defaults to `224`): The width of the image.
+ patch_size (`int`, defaults to `16`): The size of the patches.
+ in_channels (`int`, defaults to `3`): The number of input channels.
+ embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+ layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
+ flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
+ bias (`bool`, defaults to `True`): Whether or not to use bias.
+ interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
+ pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
+ pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
+ """

  def __init__(
  self,
@@ -227,10 +503,14 @@ class PatchEmbed(nn.Module):
  self.pos_embed = None
  elif pos_embed_type == "sincos":
  pos_embed = get_2d_sincos_pos_embed(
- embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ embed_dim,
+ grid_size,
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ output_type="pt",
  )
  persistent = True if pos_embed_max_size else False
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+ self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
  else:
  raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

@@ -262,7 +542,6 @@ class PatchEmbed(nn.Module):
  height, width = latent.shape[-2:]
  else:
  height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
  latent = self.proj(latent)
  if self.flatten:
  latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
@@ -280,8 +559,10 @@ class PatchEmbed(nn.Module):
  grid_size=(height, width),
  base_size=self.base_size,
  interpolation_scale=self.interpolation_scale,
+ device=latent.device,
+ output_type="pt",
  )
- pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ pos_embed = pos_embed.float().unsqueeze(0)
  else:
  pos_embed = self.pos_embed

@@ -289,7 +570,15 @@ class PatchEmbed(nn.Module):


  class LuminaPatchEmbed(nn.Module):
- """2D Image to Patch Embedding with support for Lumina-T2X"""
+ """
+ 2D Image to Patch Embedding with support for Lumina-T2X
+
+ Args:
+ patch_size (`int`, defaults to `2`): The size of the patches.
+ in_channels (`int`, defaults to `4`): The number of input channels.
+ embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
+ bias (`bool`, defaults to `True`): Whether or not to use bias.
+ """

  def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
  super().__init__()
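PatchEmbed now builds its sincos buffer on the torch path shown above; its external behaviour should be unchanged. A small usage sketch, assuming the constructor defaults documented in the new docstring (argument values are illustrative only):

import torch
from diffusers.models.embeddings import PatchEmbed

# 32x32 latent with 2x2 patches -> 256 tokens of width embed_dim.
patch_embed = PatchEmbed(height=32, width=32, patch_size=2, in_channels=4, embed_dim=64)
tokens = patch_embed(torch.randn(1, 4, 32, 32))
print(tokens.shape)  # expected torch.Size([1, 256, 64])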
@@ -338,6 +627,7 @@ class CogVideoXPatchEmbed(nn.Module):
  def __init__(
  self,
  patch_size: int = 2,
+ patch_size_t: Optional[int] = None,
  in_channels: int = 16,
  embed_dim: int = 1920,
  text_embed_dim: int = 4096,
@@ -355,6 +645,7 @@ class CogVideoXPatchEmbed(nn.Module):
  super().__init__()

  self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
  self.embed_dim = embed_dim
  self.sample_height = sample_height
  self.sample_width = sample_width
@@ -366,9 +657,15 @@ class CogVideoXPatchEmbed(nn.Module):
  self.use_positional_embeddings = use_positional_embeddings
  self.use_learned_positional_embeddings = use_learned_positional_embeddings

- self.proj = nn.Conv2d(
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
- )
+ if patch_size_t is None:
+ # CogVideoX 1.0 checkpoints
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ else:
+ # CogVideoX 1.5 checkpoints
+ self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
+
  self.text_proj = nn.Linear(text_embed_dim, embed_dim)

  if use_positional_embeddings or use_learned_positional_embeddings:
@@ -376,7 +673,9 @@ class CogVideoXPatchEmbed(nn.Module):
  pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
  self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

- def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+ def _get_positional_embeddings(
+ self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+ ) -> torch.Tensor:
  post_patch_height = sample_height // self.patch_size
  post_patch_width = sample_width // self.patch_size
  post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -388,9 +687,11 @@ class CogVideoXPatchEmbed(nn.Module):
  post_time_compression_frames,
  self.spatial_interpolation_scale,
  self.temporal_interpolation_scale,
+ device=device,
+ output_type="pt",
  )
- pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
- joint_pos_embedding = torch.zeros(
+ pos_embedding = pos_embedding.flatten(0, 1)
+ joint_pos_embedding = pos_embedding.new_zeros(
  1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
  )
  joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -407,12 +708,24 @@ class CogVideoXPatchEmbed(nn.Module):
  """
  text_embeds = self.text_proj(text_embeds)

- batch, num_frames, channels, height, width = image_embeds.shape
- image_embeds = image_embeds.reshape(-1, channels, height, width)
- image_embeds = self.proj(image_embeds)
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+ batch_size, num_frames, channels, height, width = image_embeds.shape
+
+ if self.patch_size_t is None:
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+ else:
+ p = self.patch_size
+ p_t = self.patch_size_t
+
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+ image_embeds = image_embeds.reshape(
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+ )
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+ image_embeds = self.proj(image_embeds)

  embeds = torch.cat(
  [text_embeds, image_embeds], dim=1
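As a standalone restatement of the new CogVideoX 1.5 branch above (pure torch with made-up shapes; it mirrors the reshape/permute sequence rather than calling the class):

import torch

batch_size, num_frames, channels, height, width = 1, 8, 16, 60, 90
p, p_t = 2, 2  # patch_size, patch_size_t

x = torch.randn(batch_size, num_frames, channels, height, width)
x = x.permute(0, 1, 3, 4, 2)  # [B, T, H, W, C]
x = x.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
print(x.shape)  # torch.Size([1, 5400, 128]); 128 = in_channels * p * p * p_t, the nn.Linear input width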
@@ -432,11 +745,13 @@ class CogVideoXPatchEmbed(nn.Module):
  or self.sample_width != width
  or self.sample_frames != pre_time_compression_frames
  ):
- pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
- pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+ pos_embedding = self._get_positional_embeddings(
+ height, width, pre_time_compression_frames, device=embeds.device
+ )
  else:
  pos_embedding = self.pos_embedding

+ pos_embedding = pos_embedding.to(dtype=embeds.dtype)
  embeds = embeds + pos_embedding

  return embeds
@@ -463,9 +778,11 @@ class CogView3PlusPatchEmbed(nn.Module):
  # Linear projection for text embeddings
  self.text_proj = nn.Linear(text_hidden_size, hidden_size)

- pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+ pos_embed = get_2d_sincos_pos_embed(
+ hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+ )
  pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+ self.register_buffer("pos_embed", pos_embed.float(), persistent=False)

  def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
  batch_size, channel, height, width = hidden_states.shape
@@ -497,7 +814,15 @@ class CogView3PlusPatchEmbed(nn.Module):


  def get_3d_rotary_pos_embed(
- embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+ embed_dim,
+ crops_coords,
+ grid_size,
+ temporal_size,
+ theta: int = 10000,
+ use_real: bool = True,
+ grid_type: str = "linspace",
+ max_size: Optional[Tuple[int, int]] = None,
+ device: Optional[torch.device] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
  """
  RoPE for video tokens with 3D structure.
@@ -513,17 +838,36 @@ def get_3d_rotary_pos_embed(
  The size of the temporal dimension.
  theta (`float`):
  Scaling factor for frequency computation.
+ grid_type (`str`):
+ Whether to use "linspace" or "slice" to compute grids.

  Returns:
  `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
  """
  if use_real is not True:
  raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
- start, stop = crops_coords
- grid_size_h, grid_size_w = grid_size
- grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
- grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
- grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ if grid_type == "linspace":
+ start, stop = crops_coords
+ grid_size_h, grid_size_w = grid_size
+ grid_h = torch.linspace(
+ start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+ )
+ grid_w = torch.linspace(
+ start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+ )
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+ grid_t = torch.linspace(
+ 0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+ )
+ elif grid_type == "slice":
+ max_h, max_w = max_size
+ grid_size_h, grid_size_w = grid_size
+ grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+ grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+ else:
+ raise ValueError("Invalid value passed for `grid_type`.")

  # Compute dimensions for each axis
  dim_t = embed_dim // 4
@@ -531,10 +875,10 @@ def get_3d_rotary_pos_embed(
  dim_w = embed_dim // 8 * 3

  # Temporal frequencies
- freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
  # Spatial frequencies for height and width
- freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
- freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

  # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
  def combine_time_height_width(freqs_t, freqs_h, freqs_w):
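A hedged usage sketch of the extended signature (all values are arbitrary; per the code above, the result is a (cos, sin) pair with one row per (frame, row, column) position):

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),
    grid_size=(30, 45),
    temporal_size=13,
    grid_type="linspace",  # "slice" additionally requires max_size=(max_h, max_w) and ignores crops_coords
    device=torch.device("cpu"),
)
print(cos.shape, sin.shape)  # expected torch.Size([17550, 64]) each, since 13 * 30 * 45 = 17550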
@@ -559,12 +903,111 @@ def get_3d_rotary_pos_embed(
  t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
  h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
  w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
+
+ if grid_type == "slice":
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
  cos = combine_time_height_width(t_cos, h_cos, w_cos)
  sin = combine_time_height_width(t_sin, h_sin, w_sin)
  return cos, sin


- def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+ def get_3d_rotary_pos_embed_allegro(
+ embed_dim,
+ crops_coords,
+ grid_size,
+ temporal_size,
+ interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ theta: int = 10000,
+ device: Optional[torch.device] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ # TODO(aryan): docs
+ start, stop = crops_coords
+ grid_size_h, grid_size_w = grid_size
+ interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
+ grid_t = torch.linspace(
+ 0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+ )
+ grid_h = torch.linspace(
+ start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+ )
+ grid_w = torch.linspace(
+ start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+ )
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 3
+ dim_h = embed_dim // 3
+ dim_w = embed_dim // 3
+
+ # Temporal frequencies
+ freqs_t = get_1d_rotary_pos_embed(
+ dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
+ )
+ # Spatial frequencies for height and width
+ freqs_h = get_1d_rotary_pos_embed(
+ dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
+ )
+ freqs_w = get_1d_rotary_pos_embed(
+ dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
+ )
+
+ return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
+
+
+ def get_2d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
+ ):
+ """
+ RoPE for image tokens with 2d structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size
+ crops_coords (`Tuple[int]`)
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ device: (`torch.device`, **optional**):
+ The device used to create tensors.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+ """
+ if output_type == "np":
+ deprecation_message = (
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+ " `from_numpy` is no longer required."
+ " Pass `output_type='pt' to use the new version now."
+ )
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+ return _get_2d_rotary_pos_embed_np(
+ embed_dim=embed_dim,
+ crops_coords=crops_coords,
+ grid_size=grid_size,
+ use_real=use_real,
+ )
+ start, stop = crops_coords
+ # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
+ grid_h = torch.linspace(
+ start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
+ )
+ grid_w = torch.linspace(
+ start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
+ )
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0) # [2, W, H]
+
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+ def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
  """
  RoPE for image tokens with 2d structure.

@@ -593,6 +1036,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):


  def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+ """
+ Get 2D RoPE from grid.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ grid (`np.ndarray`):
+ The grid of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+ """
  assert embed_dim % 4 == 0

  # use half of dimensions to encode grid_h
@@ -613,6 +1070,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):


  def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+ """
+ Get 2D RoPE from grid.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ grid (`np.ndarray`):
+ The grid of the positional embedding.
+ linear_factor (`float`):
+ The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
+ layer.
+ ntk_factor (`float`):
+ The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
+ """
  assert embed_dim % 4 == 0

  emb_h = get_1d_rotary_pos_embed(
@@ -684,7 +1158,7 @@ def get_1d_rotary_pos_embed(
  freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
  return freqs_cos, freqs_sin
  elif use_real:
- # stable audio
+ # stable audio, allegro
  freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
  freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
  return freqs_cos, freqs_sin
@@ -743,6 +1217,24 @@ def apply_rotary_emb(
  return x_out.type_as(x)


+ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
+ # TODO(aryan): rewrite
+ def apply_1d_rope(tokens, pos, cos, sin):
+ cos = F.embedding(pos, cos)[:, None, :, :]
+ sin = F.embedding(pos, sin)[:, None, :, :]
+ x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
+ tokens_rotated = torch.cat((-x2, x1), dim=-1)
+ return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
+
+ (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
+ t, h, w = x.chunk(3, dim=-1)
+ t = apply_1d_rope(t, positions[0], t_cos, t_sin)
+ h = apply_1d_rope(h, positions[1], h_cos, h_sin)
+ w = apply_1d_rope(w, positions[2], w_cos, w_sin)
+ x = torch.cat([t, h, w], dim=-1)
+ return x
+
+
  class FluxPosEmbed(nn.Module):
  # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
  def __init__(self, theta: int, axes_dim: List[int]):
@@ -759,7 +1251,12 @@ class FluxPosEmbed(nn.Module):
  freqs_dtype = torch.float32 if is_mps else torch.float64
  for i in range(n_axes):
  cos, sin = get_1d_rotary_pos_embed(
- self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
  )
  cos_out.append(cos)
  sin_out.append(sin)
@@ -1038,7 +1535,7 @@ class ImageProjection(nn.Module):
  batch_size = image_embeds.shape[0]

  # image
- image_embeds = self.image_embeds(image_embeds)
+ image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
  image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
  image_embeds = self.norm(image_embeds)
  return image_embeds
@@ -1302,6 +1799,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
  return conditioning


+ class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ pooled_projection_dim: int,
+ text_embed_dim: int,
+ time_embed_dim: int = 256,
+ num_attention_heads: int = 8,
+ ) -> None:
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
+ self.pooler = MochiAttentionPool(
+ num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
+ )
+ self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
+
+ def forward(
+ self,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ hidden_dtype: Optional[torch.dtype] = None,
+ ):
+ time_proj = self.time_proj(timestep)
+ time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
+
+ pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
+ caption_proj = self.caption_proj(encoder_hidden_states)
+
+ conditioning = time_emb + pooled_projections
+ return conditioning, caption_proj
+
+
  class TextTimeEmbedding(nn.Module):
  def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
  super().__init__()
@@ -1430,6 +1962,88 @@ class AttentionPooling(nn.Module):
  return a[:, 0, :] # cls_token


+ class MochiAttentionPool(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ embed_dim: int,
+ output_dim: Optional[int] = None,
+ ) -> None:
+ super().__init__()
+
+ self.output_dim = output_dim or embed_dim
+ self.num_attention_heads = num_attention_heads
+
+ self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
+ self.to_q = nn.Linear(embed_dim, embed_dim)
+ self.to_out = nn.Linear(embed_dim, self.output_dim)
+
+ @staticmethod
+ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
+ """
+ Pool tokens in x using mask.
+
+ NOTE: We assume x does not require gradients.
+
+ Args:
+ x: (B, L, D) tensor of tokens.
+ mask: (B, L) boolean tensor indicating which tokens are not padding.
+
+ Returns:
+ pooled: (B, D) tensor of pooled tokens.
+ """
+ assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
+ assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
+ mask = mask[:, :, None].to(dtype=x.dtype)
+ mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
+ pooled = (x * mask).sum(dim=1, keepdim=keepdim)
+ return pooled
+
+ def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+ r"""
+ Args:
+ x (`torch.Tensor`):
+ Tensor of shape `(B, S, D)` of input tokens.
+ mask (`torch.Tensor`):
+ Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
+
+ Returns:
+ `torch.Tensor`:
+ `(B, D)` tensor of pooled tokens.
+ """
+ D = x.size(2)
+
+ # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+ attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
+ attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
+
+ # Average non-padding token features. These will be used as the query.
+ x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D)
+
+ # Concat pooled features to input sequence.
+ x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
+
+ # Compute queries, keys, values. Only the mean token is used to create a query.
+ kv = self.to_kv(x) # (B, L+1, 2 * D)
+ q = self.to_q(x[:, 0]) # (B, D)
+
+ # Extract heads.
+ head_dim = D // self.num_attention_heads
+ kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
+ kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
+ k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
+ q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim)
+ q = q.unsqueeze(2) # (B, H, 1, head_dim)
+
+ # Compute attention.
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim)
+
+ # Concatenate heads and run output.
+ x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
+ x = self.to_out(x)
+ return x
+
+
  def get_fourier_embeds_from_boundingbox(embed_dim, box):
  """
  Args:
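The masked mean used by MochiAttentionPool.pool_tokens can be restated in a few lines of plain torch (toy shapes, for illustration only):

import torch

x = torch.randn(2, 5, 4)                                                    # (B, L, D) token features
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)   # (B, L) non-padding mask

w = mask[:, :, None].to(x.dtype)
w = w / w.sum(dim=1, keepdim=True).clamp(min=1)  # normalize so weights over real tokens sum to 1
pooled = (x * w).sum(dim=1)                      # (B, D) mean over non-padding tokens only
print(torch.allclose(pooled[0], x[0, :3].mean(dim=0)))  # True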
@@ -1782,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
  return out


+ class IPAdapterTimeImageProjectionBlock(nn.Module):
+ """Block for IPAdapterTimeImageProjection.
+
+ Args:
+ hidden_dim (`int`, defaults to 1280):
+ The number of hidden channels.
+ dim_head (`int`, defaults to 64):
+ The number of head channels.
+ heads (`int`, defaults to 20):
+ Parallel attention heads.
+ ffn_ratio (`int`, defaults to 4):
+ The expansion ratio of feedforward network hidden layer channels.
+ """
+
+ def __init__(
+ self,
+ hidden_dim: int = 1280,
+ dim_head: int = 64,
+ heads: int = 20,
+ ffn_ratio: int = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ln0 = nn.LayerNorm(hidden_dim)
+ self.ln1 = nn.LayerNorm(hidden_dim)
+ self.attn = Attention(
+ query_dim=hidden_dim,
+ cross_attention_dim=hidden_dim,
+ dim_head=dim_head,
+ heads=heads,
+ bias=False,
+ out_bias=False,
+ )
+ self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
+
+ # AdaLayerNorm
+ self.adaln_silu = nn.SiLU()
+ self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+ self.adaln_norm = nn.LayerNorm(hidden_dim)
+
+ # Set attention scale and fuse KV
+ self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+ self.attn.fuse_projections()
+ self.attn.to_k = None
+ self.attn.to_v = None
+
+ def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ x (`torch.Tensor`):
+ Image features.
+ latents (`torch.Tensor`):
+ Latent features.
+ timestep_emb (`torch.Tensor`):
+ Timestep embedding.
+
+ Returns:
+ `torch.Tensor`: Output latent features.
+ """
+
+ # Shift and scale for AdaLayerNorm
+ emb = self.adaln_proj(self.adaln_silu(timestep_emb))
+ shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
+
+ # Fused Attention
+ residual = latents
+ x = self.ln0(x)
+ latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+ batch_size = latents.shape[0]
+
+ query = self.attn.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // self.attn.heads
+
+ query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+ weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ latents = weight @ value
+
+ latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+ latents = self.attn.to_out[0](latents)
+ latents = self.attn.to_out[1](latents)
+ latents = latents + residual
+
+ ## FeedForward
+ residual = latents
+ latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ return self.ff(latents) + residual
+
+
+ # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+ class IPAdapterTimeImageProjection(nn.Module):
+ """Resampler of SD3 IP-Adapter with timestep embedding.
+
+ Args:
+ embed_dim (`int`, defaults to 1152):
+ The feature dimension.
+ output_dim (`int`, defaults to 2432):
+ The number of output channels.
+ hidden_dim (`int`, defaults to 1280):
+ The number of hidden channels.
+ depth (`int`, defaults to 4):
+ The number of blocks.
+ dim_head (`int`, defaults to 64):
+ The number of head channels.
+ heads (`int`, defaults to 20):
+ Parallel attention heads.
+ num_queries (`int`, defaults to 64):
+ The number of queries.
+ ffn_ratio (`int`, defaults to 4):
+ The expansion ratio of feedforward network hidden layer channels.
+ timestep_in_dim (`int`, defaults to 320):
+ The number of input channels for timestep embedding.
+ timestep_flip_sin_to_cos (`bool`, defaults to True):
+ Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
+ timestep_freq_shift (`int`, defaults to 0):
+ Controls the timestep delta between frequencies between dimensions.
+ """
+
+ def __init__(
+ self,
+ embed_dim: int = 1152,
+ output_dim: int = 2432,
+ hidden_dim: int = 1280,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 20,
+ num_queries: int = 64,
+ ffn_ratio: int = 4,
+ timestep_in_dim: int = 320,
+ timestep_flip_sin_to_cos: bool = True,
+ timestep_freq_shift: int = 0,
+ ) -> None:
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
+ self.proj_in = nn.Linear(embed_dim, hidden_dim)
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+ self.layers = nn.ModuleList(
+ [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+ self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
+ self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
+
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward pass.
+
+ Args:
+ x (`torch.Tensor`):
+ Image features.
+ timestep (`torch.Tensor`):
+ Timestep in denoising process.
+ Returns:
+ `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
+ """
+ timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
+ timestep_emb = self.time_embedding(timestep_emb)
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+ x = x + timestep_emb[:, None]
+
+ for block in self.layers:
+ latents = block(x, latents, timestep_emb)
+
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents)
+
+ return latents, timestep_emb
+
+
  class MultiIPAdapterImageProjection(nn.Module):
  def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
  super().__init__()