diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
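The hunks that follow are from diffusers/models/embeddings.py (item 48 above). Separately, items 214–225 introduce a new diffusers/quantizers/ package (bitsandbytes, gguf, and torchao backends plus a shared quantization config). A rough usage sketch is below; the checkpoint ID and keyword arguments are illustrative assumptions based on the new modules listed above, not taken from this diff.

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# Assumed API surface: BitsAndBytesConfig comes from the new quantization_config module
# and is consumed by ModelMixin.from_pretrained via `quantization_config`.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # hypothetical checkpoint chosen for illustration
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)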
@@ -84,15 +84,106 @@ def get_3d_sincos_pos_embed(
84
84
  temporal_size: int,
85
85
  spatial_interpolation_scale: float = 1.0,
86
86
  temporal_interpolation_scale: float = 1.0,
87
+ device: Optional[torch.device] = None,
88
+ output_type: str = "np",
89
+ ) -> torch.Tensor:
90
+ r"""
91
+ Creates 3D sinusoidal positional embeddings.
92
+
93
+ Args:
94
+ embed_dim (`int`):
95
+ The embedding dimension of inputs. It must be divisible by 16.
96
+ spatial_size (`int` or `Tuple[int, int]`):
97
+ The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
98
+ spatial dimensions (height and width).
99
+ temporal_size (`int`):
100
+ The temporal dimension of postional embeddings (number of frames).
101
+ spatial_interpolation_scale (`float`, defaults to 1.0):
102
+ Scale factor for spatial grid interpolation.
103
+ temporal_interpolation_scale (`float`, defaults to 1.0):
104
+ Scale factor for temporal grid interpolation.
105
+
106
+ Returns:
107
+ `torch.Tensor`:
108
+ The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
109
+ embed_dim]`.
110
+ """
111
+ if output_type == "np":
112
+ return _get_3d_sincos_pos_embed_np(
113
+ embed_dim=embed_dim,
114
+ spatial_size=spatial_size,
115
+ temporal_size=temporal_size,
116
+ spatial_interpolation_scale=spatial_interpolation_scale,
117
+ temporal_interpolation_scale=temporal_interpolation_scale,
118
+ )
119
+ if embed_dim % 4 != 0:
120
+ raise ValueError("`embed_dim` must be divisible by 4")
121
+ if isinstance(spatial_size, int):
122
+ spatial_size = (spatial_size, spatial_size)
123
+
124
+ embed_dim_spatial = 3 * embed_dim // 4
125
+ embed_dim_temporal = embed_dim // 4
126
+
127
+ # 1. Spatial
128
+ grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
129
+ grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
130
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
131
+ grid = torch.stack(grid, dim=0)
132
+
133
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
134
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
135
+
136
+ # 2. Temporal
137
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
138
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
139
+
140
+ # 3. Concat
141
+ pos_embed_spatial = pos_embed_spatial[None, :, :]
142
+ pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
143
+
144
+ pos_embed_temporal = pos_embed_temporal[:, None, :]
145
+ pos_embed_temporal = pos_embed_temporal.repeat_interleave(
146
+ spatial_size[0] * spatial_size[1], dim=1
147
+ ) # [T, H*W, D // 4]
148
+
149
+ pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
150
+ return pos_embed
151
+
152
+
153
+ def _get_3d_sincos_pos_embed_np(
154
+ embed_dim: int,
155
+ spatial_size: Union[int, Tuple[int, int]],
156
+ temporal_size: int,
157
+ spatial_interpolation_scale: float = 1.0,
158
+ temporal_interpolation_scale: float = 1.0,
87
159
  ) -> np.ndarray:
88
160
  r"""
161
+ Creates 3D sinusoidal positional embeddings.
162
+
89
163
  Args:
90
164
  embed_dim (`int`):
165
+ The embedding dimension of inputs. It must be divisible by 16.
91
166
  spatial_size (`int` or `Tuple[int, int]`):
167
+ The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
168
+ spatial dimensions (height and width).
92
169
  temporal_size (`int`):
170
+ The temporal dimension of postional embeddings (number of frames).
93
171
  spatial_interpolation_scale (`float`, defaults to 1.0):
172
+ Scale factor for spatial grid interpolation.
94
173
  temporal_interpolation_scale (`float`, defaults to 1.0):
174
+ Scale factor for temporal grid interpolation.
175
+
176
+ Returns:
177
+ `np.ndarray`:
178
+ The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
179
+ embed_dim]`.
95
180
  """
181
+ deprecation_message = (
182
+ "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
183
+ " `from_numpy` is no longer required."
184
+ " Pass `output_type='pt' to use the new version now."
185
+ )
186
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
96
187
  if embed_dim % 4 != 0:
97
188
  raise ValueError("`embed_dim` must be divisible by 4")
98
189
  if isinstance(spatial_size, int):
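A minimal usage sketch for the torch-native path added in this hunk (toy sizes of my own choosing; assumes the helper is imported from diffusers.models.embeddings):

import torch
from diffusers.models.embeddings import get_3d_sincos_pos_embed

# output_type="pt" builds the table directly on the requested device,
# so the old torch.from_numpy(...) round-trip is no longer needed.
pos_embed = get_3d_sincos_pos_embed(
    embed_dim=64,            # 3/4 of the channels are spatial, 1/4 temporal
    spatial_size=(16, 16),
    temporal_size=8,
    output_type="pt",
    device=torch.device("cpu"),
)
print(pos_embed.shape)  # torch.Size([8, 256, 64]) -> [temporal_size, H*W, embed_dim]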
@@ -126,11 +217,164 @@ def get_3d_sincos_pos_embed(
126
217
 
127
218
 
128
219
  def get_2d_sincos_pos_embed(
220
+ embed_dim,
221
+ grid_size,
222
+ cls_token=False,
223
+ extra_tokens=0,
224
+ interpolation_scale=1.0,
225
+ base_size=16,
226
+ device: Optional[torch.device] = None,
227
+ output_type: str = "np",
228
+ ):
229
+ """
230
+ Creates 2D sinusoidal positional embeddings.
231
+
232
+ Args:
233
+ embed_dim (`int`):
234
+ The embedding dimension.
235
+ grid_size (`int`):
236
+ The size of the grid height and width.
237
+ cls_token (`bool`, defaults to `False`):
238
+ Whether or not to add a classification token.
239
+ extra_tokens (`int`, defaults to `0`):
240
+ The number of extra tokens to add.
241
+ interpolation_scale (`float`, defaults to `1.0`):
242
+ The scale of the interpolation.
243
+
244
+ Returns:
245
+ pos_embed (`torch.Tensor`):
246
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
247
+ embed_dim]` if using cls_token
248
+ """
249
+ if output_type == "np":
250
+ deprecation_message = (
251
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
252
+ " `from_numpy` is no longer required."
253
+ " Pass `output_type='pt' to use the new version now."
254
+ )
255
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
256
+ return get_2d_sincos_pos_embed_np(
257
+ embed_dim=embed_dim,
258
+ grid_size=grid_size,
259
+ cls_token=cls_token,
260
+ extra_tokens=extra_tokens,
261
+ interpolation_scale=interpolation_scale,
262
+ base_size=base_size,
263
+ )
264
+ if isinstance(grid_size, int):
265
+ grid_size = (grid_size, grid_size)
266
+
267
+ grid_h = (
268
+ torch.arange(grid_size[0], device=device, dtype=torch.float32)
269
+ / (grid_size[0] / base_size)
270
+ / interpolation_scale
271
+ )
272
+ grid_w = (
273
+ torch.arange(grid_size[1], device=device, dtype=torch.float32)
274
+ / (grid_size[1] / base_size)
275
+ / interpolation_scale
276
+ )
277
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
278
+ grid = torch.stack(grid, dim=0)
279
+
280
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
281
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
282
+ if cls_token and extra_tokens > 0:
283
+ pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
284
+ return pos_embed
285
+
286
+
287
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
288
+ r"""
289
+ This function generates 2D sinusoidal positional embeddings from a grid.
290
+
291
+ Args:
292
+ embed_dim (`int`): The embedding dimension.
293
+ grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
294
+
295
+ Returns:
296
+ `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
297
+ """
298
+ if output_type == "np":
299
+ deprecation_message = (
300
+ "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
301
+ " `from_numpy` is no longer required."
302
+ " Pass `output_type='pt' to use the new version now."
303
+ )
304
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
305
+ return get_2d_sincos_pos_embed_from_grid_np(
306
+ embed_dim=embed_dim,
307
+ grid=grid,
308
+ )
309
+ if embed_dim % 2 != 0:
310
+ raise ValueError("embed_dim must be divisible by 2")
311
+
312
+ # use half of dimensions to encode grid_h
313
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
314
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)
315
+
316
+ emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
317
+ return emb
318
+
319
+
320
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
321
+ """
322
+ This function generates 1D positional embeddings from a grid.
323
+
324
+ Args:
325
+ embed_dim (`int`): The embedding dimension `D`
326
+ pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
327
+
328
+ Returns:
329
+ `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
330
+ """
331
+ if output_type == "np":
332
+ deprecation_message = (
333
+ "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
334
+ " `from_numpy` is no longer required."
335
+ " Pass `output_type='pt' to use the new version now."
336
+ )
337
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
338
+ return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
339
+ if embed_dim % 2 != 0:
340
+ raise ValueError("embed_dim must be divisible by 2")
341
+
342
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
343
+ omega /= embed_dim / 2.0
344
+ omega = 1.0 / 10000**omega # (D/2,)
345
+
346
+ pos = pos.reshape(-1) # (M,)
347
+ out = torch.outer(pos, omega) # (M, D/2), outer product
348
+
349
+ emb_sin = torch.sin(out) # (M, D/2)
350
+ emb_cos = torch.cos(out) # (M, D/2)
351
+
352
+ emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
353
+ return emb
354
+
355
+
356
+ def get_2d_sincos_pos_embed_np(
129
357
  embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
130
358
  ):
131
359
  """
132
- grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
133
- [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
360
+ Creates 2D sinusoidal positional embeddings.
361
+
362
+ Args:
363
+ embed_dim (`int`):
364
+ The embedding dimension.
365
+ grid_size (`int`):
366
+ The size of the grid height and width.
367
+ cls_token (`bool`, defaults to `False`):
368
+ Whether or not to add a classification token.
369
+ extra_tokens (`int`, defaults to `0`):
370
+ The number of extra tokens to add.
371
+ interpolation_scale (`float`, defaults to `1.0`):
372
+ The scale of the interpolation.
373
+
374
+ Returns:
375
+ pos_embed (`np.ndarray`):
376
+ Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
377
+ embed_dim]` if using cls_token
134
378
  """
135
379
  if isinstance(grid_size, int):
136
380
  grid_size = (grid_size, grid_size)
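The 2D helper mirrors the 3D one; a small sketch with assumed sizes:

import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(
    embed_dim=96,
    grid_size=(32, 32),
    base_size=16,
    interpolation_scale=1.0,
    output_type="pt",        # request a torch.Tensor instead of the deprecated np.ndarray
)
print(pos_embed.shape)  # torch.Size([1024, 96]) -> [grid_h * grid_w, embed_dim]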
@@ -141,27 +385,44 @@ def get_2d_sincos_pos_embed(
141
385
  grid = np.stack(grid, axis=0)
142
386
 
143
387
  grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
144
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
388
+ pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
145
389
  if cls_token and extra_tokens > 0:
146
390
  pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
147
391
  return pos_embed
148
392
 
149
393
 
150
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
394
+ def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
395
+ r"""
396
+ This function generates 2D sinusoidal positional embeddings from a grid.
397
+
398
+ Args:
399
+ embed_dim (`int`): The embedding dimension.
400
+ grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
401
+
402
+ Returns:
403
+ `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
404
+ """
151
405
  if embed_dim % 2 != 0:
152
406
  raise ValueError("embed_dim must be divisible by 2")
153
407
 
154
408
  # use half of dimensions to encode grid_h
155
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
156
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
409
+ emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
410
+ emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)
157
411
 
158
412
  emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
159
413
  return emb
160
414
 
161
415
 
162
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
416
+ def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
163
417
  """
164
- embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
418
+ This function generates 1D positional embeddings from a grid.
419
+
420
+ Args:
421
+ embed_dim (`int`): The embedding dimension `D`
422
+ pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
423
+
424
+ Returns:
425
+ `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
165
426
  """
166
427
  if embed_dim % 2 != 0:
167
428
  raise ValueError("embed_dim must be divisible by 2")
@@ -181,7 +442,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
181
442
 
182
443
 
183
444
  class PatchEmbed(nn.Module):
184
- """2D Image to Patch Embedding with support for SD3 cropping."""
445
+ """
446
+ 2D Image to Patch Embedding with support for SD3 cropping.
447
+
448
+ Args:
449
+ height (`int`, defaults to `224`): The height of the image.
450
+ width (`int`, defaults to `224`): The width of the image.
451
+ patch_size (`int`, defaults to `16`): The size of the patches.
452
+ in_channels (`int`, defaults to `3`): The number of input channels.
453
+ embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
454
+ layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
455
+ flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
456
+ bias (`bool`, defaults to `True`): Whether or not to use bias.
457
+ interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
458
+ pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
459
+ pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
460
+ """
185
461
 
186
462
  def __init__(
187
463
  self,
@@ -227,10 +503,14 @@ class PatchEmbed(nn.Module):
227
503
  self.pos_embed = None
228
504
  elif pos_embed_type == "sincos":
229
505
  pos_embed = get_2d_sincos_pos_embed(
230
- embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
506
+ embed_dim,
507
+ grid_size,
508
+ base_size=self.base_size,
509
+ interpolation_scale=self.interpolation_scale,
510
+ output_type="pt",
231
511
  )
232
512
  persistent = True if pos_embed_max_size else False
233
- self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
513
+ self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
234
514
  else:
235
515
  raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
236
516
 
@@ -262,7 +542,6 @@ class PatchEmbed(nn.Module):
262
542
  height, width = latent.shape[-2:]
263
543
  else:
264
544
  height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
265
-
266
545
  latent = self.proj(latent)
267
546
  if self.flatten:
268
547
  latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
@@ -280,8 +559,10 @@ class PatchEmbed(nn.Module):
280
559
  grid_size=(height, width),
281
560
  base_size=self.base_size,
282
561
  interpolation_scale=self.interpolation_scale,
562
+ device=latent.device,
563
+ output_type="pt",
283
564
  )
284
- pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
565
+ pos_embed = pos_embed.float().unsqueeze(0)
285
566
  else:
286
567
  pos_embed = self.pos_embed
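With the buffer and the dynamic path now built via output_type="pt", a PatchEmbed forward pass involves no numpy round-trip. A sketch using the documented defaults (the input tensor is illustrative):

import torch
from diffusers.models.embeddings import PatchEmbed

embedder = PatchEmbed()                  # height=224, width=224, patch_size=16, embed_dim=768
latent = torch.randn(1, 3, 224, 224)     # (batch, in_channels, height, width)
tokens = embedder(latent)
print(tokens.shape)  # torch.Size([1, 196, 768]) -> (batch, num_patches, embed_dim)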
287
568
 
@@ -289,7 +570,15 @@ class PatchEmbed(nn.Module):
289
570
 
290
571
 
291
572
  class LuminaPatchEmbed(nn.Module):
292
- """2D Image to Patch Embedding with support for Lumina-T2X"""
573
+ """
574
+ 2D Image to Patch Embedding with support for Lumina-T2X
575
+
576
+ Args:
577
+ patch_size (`int`, defaults to `2`): The size of the patches.
578
+ in_channels (`int`, defaults to `4`): The number of input channels.
579
+ embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
580
+ bias (`bool`, defaults to `True`): Whether or not to use bias.
581
+ """
293
582
 
294
583
  def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
295
584
  super().__init__()
@@ -338,6 +627,7 @@ class CogVideoXPatchEmbed(nn.Module):
338
627
  def __init__(
339
628
  self,
340
629
  patch_size: int = 2,
630
+ patch_size_t: Optional[int] = None,
341
631
  in_channels: int = 16,
342
632
  embed_dim: int = 1920,
343
633
  text_embed_dim: int = 4096,
@@ -355,6 +645,7 @@ class CogVideoXPatchEmbed(nn.Module):
355
645
  super().__init__()
356
646
 
357
647
  self.patch_size = patch_size
648
+ self.patch_size_t = patch_size_t
358
649
  self.embed_dim = embed_dim
359
650
  self.sample_height = sample_height
360
651
  self.sample_width = sample_width
@@ -366,9 +657,15 @@ class CogVideoXPatchEmbed(nn.Module):
366
657
  self.use_positional_embeddings = use_positional_embeddings
367
658
  self.use_learned_positional_embeddings = use_learned_positional_embeddings
368
659
 
369
- self.proj = nn.Conv2d(
370
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
371
- )
660
+ if patch_size_t is None:
661
+ # CogVideoX 1.0 checkpoints
662
+ self.proj = nn.Conv2d(
663
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
664
+ )
665
+ else:
666
+ # CogVideoX 1.5 checkpoints
667
+ self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
668
+
372
669
  self.text_proj = nn.Linear(text_embed_dim, embed_dim)
373
670
 
374
671
  if use_positional_embeddings or use_learned_positional_embeddings:
@@ -376,7 +673,9 @@ class CogVideoXPatchEmbed(nn.Module):
376
673
  pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
377
674
  self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
378
675
 
379
- def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
676
+ def _get_positional_embeddings(
677
+ self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
678
+ ) -> torch.Tensor:
380
679
  post_patch_height = sample_height // self.patch_size
381
680
  post_patch_width = sample_width // self.patch_size
382
681
  post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -388,9 +687,11 @@ class CogVideoXPatchEmbed(nn.Module):
388
687
  post_time_compression_frames,
389
688
  self.spatial_interpolation_scale,
390
689
  self.temporal_interpolation_scale,
690
+ device=device,
691
+ output_type="pt",
391
692
  )
392
- pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
393
- joint_pos_embedding = torch.zeros(
693
+ pos_embedding = pos_embedding.flatten(0, 1)
694
+ joint_pos_embedding = pos_embedding.new_zeros(
394
695
  1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
395
696
  )
396
697
  joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
@@ -407,12 +708,24 @@ class CogVideoXPatchEmbed(nn.Module):
407
708
  """
408
709
  text_embeds = self.text_proj(text_embeds)
409
710
 
410
- batch, num_frames, channels, height, width = image_embeds.shape
411
- image_embeds = image_embeds.reshape(-1, channels, height, width)
412
- image_embeds = self.proj(image_embeds)
413
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
414
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
415
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
711
+ batch_size, num_frames, channels, height, width = image_embeds.shape
712
+
713
+ if self.patch_size_t is None:
714
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
715
+ image_embeds = self.proj(image_embeds)
716
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
717
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
718
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
719
+ else:
720
+ p = self.patch_size
721
+ p_t = self.patch_size_t
722
+
723
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
724
+ image_embeds = image_embeds.reshape(
725
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
726
+ )
727
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
728
+ image_embeds = self.proj(image_embeds)
416
729
 
417
730
  embeds = torch.cat(
418
731
  [text_embeds, image_embeds], dim=1
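For reference, a standalone sketch of the 3D patchify the new patch_size_t branch performs for CogVideoX 1.5 checkpoints (toy shapes of my own choosing):

import torch

batch_size, num_frames, channels, height, width = 1, 4, 16, 8, 8
p, p_t = 2, 2                                   # spatial / temporal patch sizes

x = torch.randn(batch_size, num_frames, channels, height, width)
x = x.permute(0, 1, 3, 4, 2)                    # move channels last
x = x.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
print(x.shape)  # torch.Size([1, 32, 128]) -> (batch, patches, channels * p * p * p_t), ready for nn.Linear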
@@ -432,18 +745,84 @@ class CogVideoXPatchEmbed(nn.Module):
432
745
  or self.sample_width != width
433
746
  or self.sample_frames != pre_time_compression_frames
434
747
  ):
435
- pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
436
- pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
748
+ pos_embedding = self._get_positional_embeddings(
749
+ height, width, pre_time_compression_frames, device=embeds.device
750
+ )
437
751
  else:
438
752
  pos_embedding = self.pos_embedding
439
753
 
754
+ pos_embedding = pos_embedding.to(dtype=embeds.dtype)
440
755
  embeds = embeds + pos_embedding
441
756
 
442
757
  return embeds
443
758
 
444
759
 
760
+ class CogView3PlusPatchEmbed(nn.Module):
761
+ def __init__(
762
+ self,
763
+ in_channels: int = 16,
764
+ hidden_size: int = 2560,
765
+ patch_size: int = 2,
766
+ text_hidden_size: int = 4096,
767
+ pos_embed_max_size: int = 128,
768
+ ):
769
+ super().__init__()
770
+ self.in_channels = in_channels
771
+ self.hidden_size = hidden_size
772
+ self.patch_size = patch_size
773
+ self.text_hidden_size = text_hidden_size
774
+ self.pos_embed_max_size = pos_embed_max_size
775
+ # Linear projection for image patches
776
+ self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
777
+
778
+ # Linear projection for text embeddings
779
+ self.text_proj = nn.Linear(text_hidden_size, hidden_size)
780
+
781
+ pos_embed = get_2d_sincos_pos_embed(
782
+ hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
783
+ )
784
+ pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
785
+ self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
786
+
787
+ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
788
+ batch_size, channel, height, width = hidden_states.shape
789
+
790
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
791
+ raise ValueError("Height and width must be divisible by patch size")
792
+
793
+ height = height // self.patch_size
794
+ width = width // self.patch_size
795
+ hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
796
+ hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
797
+ hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)
798
+
799
+ # Project the patches
800
+ hidden_states = self.proj(hidden_states)
801
+ encoder_hidden_states = self.text_proj(encoder_hidden_states)
802
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
803
+
804
+ # Calculate text_length
805
+ text_length = encoder_hidden_states.shape[1]
806
+
807
+ image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
808
+ text_pos_embed = torch.zeros(
809
+ (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
810
+ )
811
+ pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]
812
+
813
+ return (hidden_states + pos_embed).to(hidden_states.dtype)
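A forward-pass sketch for the new CogView3PlusPatchEmbed, using small made-up sizes instead of the 2560-dim defaults (valid after upgrading to 0.32):

import torch
from diffusers.models.embeddings import CogView3PlusPatchEmbed

embed = CogView3PlusPatchEmbed(
    in_channels=16, hidden_size=64, patch_size=2, text_hidden_size=32, pos_embed_max_size=16
)
latents = torch.randn(1, 16, 8, 8)       # (batch, channels, height, width)
text_states = torch.randn(1, 5, 32)      # (batch, text_length, text_hidden_size)
out = embed(latents, text_states)
print(out.shape)  # torch.Size([1, 21, 64]) -> 5 text tokens + 16 image patches, each hidden_size wide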
814
+
815
+
445
816
  def get_3d_rotary_pos_embed(
446
- embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
817
+ embed_dim,
818
+ crops_coords,
819
+ grid_size,
820
+ temporal_size,
821
+ theta: int = 10000,
822
+ use_real: bool = True,
823
+ grid_type: str = "linspace",
824
+ max_size: Optional[Tuple[int, int]] = None,
825
+ device: Optional[torch.device] = None,
447
826
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
448
827
  """
449
828
  RoPE for video tokens with 3D structure.
@@ -459,16 +838,36 @@ def get_3d_rotary_pos_embed(
459
838
  The size of the temporal dimension.
460
839
  theta (`float`):
461
840
  Scaling factor for frequency computation.
462
- use_real (`bool`):
463
- If True, return real part and imaginary part separately. Otherwise, return complex numbers.
841
+ grid_type (`str`):
842
+ Whether to use "linspace" or "slice" to compute grids.
464
843
 
465
844
  Returns:
466
845
  `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
467
846
  """
468
- start, stop = crops_coords
469
- grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
470
- grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
471
- grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
847
+ if use_real is not True:
848
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
849
+
850
+ if grid_type == "linspace":
851
+ start, stop = crops_coords
852
+ grid_size_h, grid_size_w = grid_size
853
+ grid_h = torch.linspace(
854
+ start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
855
+ )
856
+ grid_w = torch.linspace(
857
+ start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
858
+ )
859
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
860
+ grid_t = torch.linspace(
861
+ 0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
862
+ )
863
+ elif grid_type == "slice":
864
+ max_h, max_w = max_size
865
+ grid_size_h, grid_size_w = grid_size
866
+ grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
867
+ grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
868
+ grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
869
+ else:
870
+ raise ValueError("Invalid value passed for `grid_type`.")
472
871
 
473
872
  # Compute dimensions for each axis
474
873
  dim_t = embed_dim // 4
@@ -476,57 +875,139 @@ def get_3d_rotary_pos_embed(
476
875
  dim_w = embed_dim // 8 * 3
477
876
 
478
877
  # Temporal frequencies
479
- freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
480
- grid_t = torch.from_numpy(grid_t).float()
481
- freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
482
- freqs_t = freqs_t.repeat_interleave(2, dim=-1)
878
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
879
+ # Spatial frequencies for height and width
880
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
881
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
882
+
883
+ # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
884
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
885
+ freqs_t = freqs_t[:, None, None, :].expand(
886
+ -1, grid_size_h, grid_size_w, -1
887
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
888
+ freqs_h = freqs_h[None, :, None, :].expand(
889
+ temporal_size, -1, grid_size_w, -1
890
+ ) # temporal_size, grid_size_h, grid_size_2, dim_h
891
+ freqs_w = freqs_w[None, None, :, :].expand(
892
+ temporal_size, grid_size_h, -1, -1
893
+ ) # temporal_size, grid_size_h, grid_size_2, dim_w
894
+
895
+ freqs = torch.cat(
896
+ [freqs_t, freqs_h, freqs_w], dim=-1
897
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
898
+ freqs = freqs.view(
899
+ temporal_size * grid_size_h * grid_size_w, -1
900
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
901
+ return freqs
902
+
903
+ t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
904
+ h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
905
+ w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
906
+
907
+ if grid_type == "slice":
908
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
909
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
910
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
911
+
912
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
913
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
914
+ return cos, sin
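A usage sketch for the rewritten 3D RoPE helper (sizes are illustrative; grid_type="slice" additionally requires max_size). Per the code above, the returned cos/sin span the full embed_dim per token:

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

cos, sin = get_3d_rotary_pos_embed(
    embed_dim=16,
    crops_coords=((0, 0), (4, 4)),
    grid_size=(4, 4),
    temporal_size=2,
    grid_type="linspace",
)
print(cos.shape, sin.shape)  # torch.Size([32, 16]) each -> (T * H * W, embed_dim)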
915
+
916
+
917
+ def get_3d_rotary_pos_embed_allegro(
918
+ embed_dim,
919
+ crops_coords,
920
+ grid_size,
921
+ temporal_size,
922
+ interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
923
+ theta: int = 10000,
924
+ device: Optional[torch.device] = None,
925
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
926
+ # TODO(aryan): docs
927
+ start, stop = crops_coords
928
+ grid_size_h, grid_size_w = grid_size
929
+ interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
930
+ grid_t = torch.linspace(
931
+ 0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
932
+ )
933
+ grid_h = torch.linspace(
934
+ start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
935
+ )
936
+ grid_w = torch.linspace(
937
+ start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
938
+ )
939
+
940
+ # Compute dimensions for each axis
941
+ dim_t = embed_dim // 3
942
+ dim_h = embed_dim // 3
943
+ dim_w = embed_dim // 3
483
944
 
945
+ # Temporal frequencies
946
+ freqs_t = get_1d_rotary_pos_embed(
947
+ dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False
948
+ )
484
949
  # Spatial frequencies for height and width
485
- freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
486
- freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
487
- grid_h = torch.from_numpy(grid_h).float()
488
- grid_w = torch.from_numpy(grid_w).float()
489
- freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
490
- freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
491
- freqs_h = freqs_h.repeat_interleave(2, dim=-1)
492
- freqs_w = freqs_w.repeat_interleave(2, dim=-1)
493
-
494
- # Broadcast and concatenate tensors along specified dimension
495
- def broadcast(tensors, dim=-1):
496
- num_tensors = len(tensors)
497
- shape_lens = {len(t.shape) for t in tensors}
498
- assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
499
- shape_len = list(shape_lens)[0]
500
- dim = (dim + shape_len) if dim < 0 else dim
501
- dims = list(zip(*(list(t.shape) for t in tensors)))
502
- expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
503
- assert all(
504
- [*(len(set(t[1])) <= 2 for t in expandable_dims)]
505
- ), "invalid dimensions for broadcastable concatenation"
506
- max_dims = [(t[0], max(t[1])) for t in expandable_dims]
507
- expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
508
- expanded_dims.insert(dim, (dim, dims[dim]))
509
- expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
510
- tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
511
- return torch.cat(tensors, dim=dim)
512
-
513
- freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
514
-
515
- t, h, w, d = freqs.shape
516
- freqs = freqs.view(t * h * w, d)
517
-
518
- # Generate sine and cosine components
519
- sin = freqs.sin()
520
- cos = freqs.cos()
950
+ freqs_h = get_1d_rotary_pos_embed(
951
+ dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False
952
+ )
953
+ freqs_w = get_1d_rotary_pos_embed(
954
+ dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False
955
+ )
521
956
 
522
- if use_real:
523
- return cos, sin
524
- else:
525
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
526
- return freqs_cis
957
+ return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
958
+
959
+
960
+ def get_2d_rotary_pos_embed(
961
+ embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
962
+ ):
963
+ """
964
+ RoPE for image tokens with 2d structure.
965
+
966
+ Args:
967
+ embed_dim: (`int`):
968
+ The embedding dimension size
969
+ crops_coords (`Tuple[int]`)
970
+ The top-left and bottom-right coordinates of the crop.
971
+ grid_size (`Tuple[int]`):
972
+ The grid size of the positional embedding.
973
+ use_real (`bool`):
974
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
975
+ device: (`torch.device`, **optional**):
976
+ The device used to create tensors.
977
+
978
+ Returns:
979
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
980
+ """
981
+ if output_type == "np":
982
+ deprecation_message = (
983
+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
984
+ " `from_numpy` is no longer required."
985
+ " Pass `output_type='pt' to use the new version now."
986
+ )
987
+ deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
988
+ return _get_2d_rotary_pos_embed_np(
989
+ embed_dim=embed_dim,
990
+ crops_coords=crops_coords,
991
+ grid_size=grid_size,
992
+ use_real=use_real,
993
+ )
994
+ start, stop = crops_coords
995
+ # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
996
+ grid_h = torch.linspace(
997
+ start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
998
+ )
999
+ grid_w = torch.linspace(
1000
+ start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
1001
+ )
1002
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
1003
+ grid = torch.stack(grid, dim=0) # [2, W, H]
1004
+
1005
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
1006
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
1007
+ return pos_embed
527
1008
 
528
1009
 
529
- def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
1010
+ def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
530
1011
  """
531
1012
  RoPE for image tokens with 2d structure.
532
1013
 
@@ -555,6 +1036,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
555
1036
 
556
1037
 
557
1038
  def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
1039
+ """
1040
+ Get 2D RoPE from grid.
1041
+
1042
+ Args:
1043
+ embed_dim: (`int`):
1044
+ The embedding dimension size, corresponding to hidden_size_head.
1045
+ grid (`np.ndarray`):
1046
+ The grid of the positional embedding.
1047
+ use_real (`bool`):
1048
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
1049
+
1050
+ Returns:
1051
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
1052
+ """
558
1053
  assert embed_dim % 4 == 0
559
1054
 
560
1055
  # use half of dimensions to encode grid_h
@@ -575,6 +1070,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
575
1070
 
576
1071
 
577
1072
  def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
1073
+ """
1074
+ Get 2D RoPE from grid.
1075
+
1076
+ Args:
1077
+ embed_dim: (`int`):
1078
+ The embedding dimension size, corresponding to hidden_size_head.
1079
+ grid (`np.ndarray`):
1080
+ The grid of the positional embedding.
1081
+ linear_factor (`float`):
1082
+ The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
1083
+ layer.
1084
+ ntk_factor (`float`):
1085
+ The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.
1086
+
1087
+ Returns:
1088
+ `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
1089
+ """
578
1090
  assert embed_dim % 4 == 0
579
1091
 
580
1092
  emb_h = get_1d_rotary_pos_embed(
@@ -598,6 +1110,7 @@ def get_1d_rotary_pos_embed(
598
1110
  linear_factor=1.0,
599
1111
  ntk_factor=1.0,
600
1112
  repeat_interleave_real=True,
1113
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
601
1114
  ):
602
1115
  """
603
1116
  Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -620,26 +1133,37 @@ def get_1d_rotary_pos_embed(
         repeat_interleave_real (`bool`, *optional*, defaults to `True`):
             If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
             Otherwise, they are concateanted with themselves.
+        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+            the dtype of the frequency tensor.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
     assert dim % 2 == 0
 
     if isinstance(pos, int):
-        pos = np.arange(pos)
+        pos = torch.arange(pos)
+    if isinstance(pos, np.ndarray):
+        pos = torch.from_numpy(pos)  # type: ignore  # [S]
+
     theta = theta * ntk_factor
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
-    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
-    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    freqs = (
+        1.0
+        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+        / linear_factor
+    )  # [D/2]
+    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
     if use_real and repeat_interleave_real:
-        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
-        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        # flux, hunyuan-dit, cogvideox
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
         return freqs_cos, freqs_sin
     elif use_real:
-        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
-        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        # stable audio, allegro
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
         return freqs_cos, freqs_sin
     else:
+        # lumina
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
         return freqs_cis
 
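Note: the hunk above reworks get_1d_rotary_pos_embed so that positions may be an int, a NumPy array, or a torch.Tensor, and adds a freqs_dtype argument so Flux can compute the frequencies in float64 before the cos/sin tables are cast back to float32. A minimal usage sketch, not part of the diff; the dimension and position count below are illustrative assumptions:

    import torch
    from diffusers.models.embeddings import get_1d_rotary_pos_embed

    dim = 64                      # per-head rotary dimension; must be even (asserted above)
    pos = torch.arange(16)        # positions can now be an int, an np.ndarray, or a torch.Tensor
    cos, sin = get_1d_rotary_pos_embed(
        dim,
        pos,
        use_real=True,                 # return (cos, sin) instead of complex frequencies
        repeat_interleave_real=True,   # interleaved layout used by Flux, HunyuanDiT, CogVideoX
        freqs_dtype=torch.float64,     # new: higher-precision frequency math, results cast to float32
    )
    print(cos.shape, sin.shape)        # torch.Size([16, 64]) torch.Size([16, 64])
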
@@ -671,11 +1195,11 @@ def apply_rotary_emb(
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
-            # Use for example in Lumina
+            # Used for flux, cogvideox, hunyuan-dit
             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
         elif use_real_unbind_dim == -2:
-            # Use for example in Stable Audio
+            # Used for Stable Audio
             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
         else:
@@ -685,6 +1209,7 @@ def apply_rotary_emb(
 
         return out
     else:
+        # used for lumina
         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
         freqs_cis = freqs_cis.unsqueeze(2)
         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
@@ -692,6 +1217,54 @@ def apply_rotary_emb(
         return x_out.type_as(x)
 
 
+def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
+    # TODO(aryan): rewrite
+    def apply_1d_rope(tokens, pos, cos, sin):
+        cos = F.embedding(pos, cos)[:, None, :, :]
+        sin = F.embedding(pos, sin)[:, None, :, :]
+        x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :]
+        tokens_rotated = torch.cat((-x2, x1), dim=-1)
+        return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype)
+
+    (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis
+    t, h, w = x.chunk(3, dim=-1)
+    t = apply_1d_rope(t, positions[0], t_cos, t_sin)
+    h = apply_1d_rope(h, positions[1], h_cos, h_sin)
+    w = apply_1d_rope(w, positions[2], w_cos, w_sin)
+    x = torch.cat([t, h, w], dim=-1)
+    return x
+
+
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        freqs_dtype = torch.float32 if is_mps else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
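Note: FluxPosEmbed (added in the hunk above) simply calls get_1d_rotary_pos_embed once per axis of the position ids and concatenates the per-axis cos/sin tables. A rough sketch, not part of the diff; theta=10000 and axes_dim=[16, 56, 56] mirror Flux's configuration but are assumptions here, as is the token count:

    import torch
    from diffusers.models.embeddings import FluxPosEmbed

    pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])  # sum(axes_dim) == attention head dim
    ids = torch.zeros(1024, 3)                 # (num_tokens, n_axes) positional ids, one column per axis
    freqs_cos, freqs_sin = pos_embed(ids)
    print(freqs_cos.shape, freqs_sin.shape)    # torch.Size([1024, 128]) for each
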
@@ -962,7 +1535,7 @@ class ImageProjection(nn.Module):
         batch_size = image_embeds.shape[0]
 
         # image
-        image_embeds = self.image_embeds(image_embeds)
+        image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
         image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
         image_embeds = self.norm(image_embeds)
         return image_embeds
@@ -1058,6 +1631,39 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
+class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
+    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
+        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        hidden_dtype: torch.dtype,
+    ) -> torch.Tensor:
+        timesteps_proj = self.time_proj(timestep)
+
+        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
+        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
+        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
+
+        # (B, 3 * condition_dim)
+        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)
+
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
+
+        conditioning = timesteps_emb + condition_emb
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
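Note: the new CogView3CombinedTimestepSizeEmbeddings sums a timestep embedding with an embedding of the (original_size, crop_coords, target_size) conditions. Because each of those tensors contributes two sinusoidally embedded components, the concatenation in forward() implies pooled_projection_dim == 6 * condition_dim. A shape sketch, not part of the diff, with assumed dimensions:

    import torch
    from diffusers.models.embeddings import CogView3CombinedTimestepSizeEmbeddings

    emb = CogView3CombinedTimestepSizeEmbeddings(
        embedding_dim=512,            # assumed value; the real model reads this from its config
        condition_dim=256,
        pooled_projection_dim=1536,   # 6 * condition_dim: three (h, w)-style pairs are concatenated
    )
    batch = 2
    conditioning = emb(
        timestep=torch.tensor([999, 500]),
        original_size=torch.tensor([[1024, 1024]] * batch),
        target_size=torch.tensor([[1024, 1024]] * batch),
        crop_coords=torch.tensor([[0, 0]] * batch),
        hidden_dtype=torch.float32,
    )
    print(conditioning.shape)         # torch.Size([2, 512])
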
@@ -1193,6 +1799,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
         return conditioning
 
 
+class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        num_attention_heads: int = 8,
+    ) -> None:
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
+        self.pooler = MochiAttentionPool(
+            num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
+        )
+        self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
+
+    def forward(
+        self,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ):
+        time_proj = self.time_proj(timestep)
+        time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
+
+        pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
+        caption_proj = self.caption_proj(encoder_hidden_states)
+
+        conditioning = time_emb + pooled_projections
+        return conditioning, caption_proj
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
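Note: the Mochi embedding added above returns both the combined timestep/pooled-text conditioning and a per-token caption projection. A shape sketch, not part of the diff; the sizes are assumptions (the released Mochi model uses larger dimensions):

    import torch
    from diffusers.models.embeddings import MochiCombinedTimestepCaptionEmbedding

    emb = MochiCombinedTimestepCaptionEmbedding(
        embedding_dim=768, pooled_projection_dim=768, text_embed_dim=512, num_attention_heads=8
    )
    text_hidden_states = torch.randn(2, 77, 512)              # (B, S, text_embed_dim)
    text_mask = torch.ones(2, 77, dtype=torch.bool)
    conditioning, caption_proj = emb(
        timestep=torch.tensor([999, 500]),
        encoder_hidden_states=text_hidden_states,
        encoder_attention_mask=text_mask,
        hidden_dtype=torch.float32,
    )
    print(conditioning.shape, caption_proj.shape)             # (2, 768) and (2, 77, 768)
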
@@ -1321,6 +1962,88 @@ class AttentionPooling(nn.Module):
         return a[:, 0, :]  # cls_token
 
 
+class MochiAttentionPool(nn.Module):
+    def __init__(
+        self,
+        num_attention_heads: int,
+        embed_dim: int,
+        output_dim: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+
+        self.output_dim = output_dim or embed_dim
+        self.num_attention_heads = num_attention_heads
+
+        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
+        self.to_q = nn.Linear(embed_dim, embed_dim)
+        self.to_out = nn.Linear(embed_dim, self.output_dim)
+
+    @staticmethod
+    def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
+        """
+        Pool tokens in x using mask.
+
+        NOTE: We assume x does not require gradients.
+
+        Args:
+            x: (B, L, D) tensor of tokens.
+            mask: (B, L) boolean tensor indicating which tokens are not padding.
+
+        Returns:
+            pooled: (B, D) tensor of pooled tokens.
+        """
+        assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
+        assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
+        mask = mask[:, :, None].to(dtype=x.dtype)
+        mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
+        pooled = (x * mask).sum(dim=1, keepdim=keepdim)
+        return pooled
+
+    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
+        r"""
+        Args:
+            x (`torch.Tensor`):
+                Tensor of shape `(B, S, D)` of input tokens.
+            mask (`torch.Tensor`):
+                Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
+
+        Returns:
+            `torch.Tensor`:
+                `(B, D)` tensor of pooled tokens.
+        """
+        D = x.size(2)
+
+        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
+        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
+        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).
+
+        # Average non-padding token features. These will be used as the query.
+        x_pool = self.pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
+
+        # Concat pooled features to input sequence.
+        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)
+
+        # Compute queries, keys, values. Only the mean token is used to create a query.
+        kv = self.to_kv(x)  # (B, L+1, 2 * D)
+        q = self.to_q(x[:, 0])  # (B, D)
+
+        # Extract heads.
+        head_dim = D // self.num_attention_heads
+        kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
+        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
+        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
+        q = q.unflatten(1, (self.num_attention_heads, head_dim))  # (B, H, head_dim)
+        q = q.unsqueeze(2)  # (B, H, 1, head_dim)
+
+        # Compute attention.
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)  # (B, H, 1, head_dim)
+
+        # Concatenate heads and run output.
+        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
+        x = self.to_out(x)
+        return x
+
+
 def get_fourier_embeds_from_boundingbox(embed_dim, box):
     """
     Args:
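Note: MochiAttentionPool (added above) builds a single query from the mask-weighted mean token and attends over that token plus the full sequence, producing one pooled vector per example. A small sketch, not part of the diff; sizes are assumptions:

    import torch
    from diffusers.models.embeddings import MochiAttentionPool

    pool = MochiAttentionPool(num_attention_heads=8, embed_dim=512, output_dim=768)
    x = torch.randn(2, 77, 512)                  # (B, S, D) text token features
    mask = torch.zeros(2, 77, dtype=torch.bool)
    mask[:, :10] = True                          # only the first 10 tokens are real, the rest is padding
    pooled = pool(x, mask)
    print(pooled.shape)                          # torch.Size([2, 768])
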
@@ -1673,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
         return out
 
 
+class IPAdapterTimeImageProjectionBlock(nn.Module):
+    """Block for IPAdapterTimeImageProjection.
+
+    Args:
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+    """
+
+    def __init__(
+        self,
+        hidden_dim: int = 1280,
+        dim_head: int = 64,
+        heads: int = 20,
+        ffn_ratio: int = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(hidden_dim)
+        self.ln1 = nn.LayerNorm(hidden_dim)
+        self.attn = Attention(
+            query_dim=hidden_dim,
+            cross_attention_dim=hidden_dim,
+            dim_head=dim_head,
+            heads=heads,
+            bias=False,
+            out_bias=False,
+        )
+        self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
+
+        # AdaLayerNorm
+        self.adaln_silu = nn.SiLU()
+        self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
+        self.adaln_norm = nn.LayerNorm(hidden_dim)
+
+        # Set attention scale and fuse KV
+        self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
+        self.attn.fuse_projections()
+        self.attn.to_k = None
+        self.attn.to_v = None
+
+    def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            latents (`torch.Tensor`):
+                Latent features.
+            timestep_emb (`torch.Tensor`):
+                Timestep embedding.
+
+        Returns:
+            `torch.Tensor`: Output latent features.
+        """
+
+        # Shift and scale for AdaLayerNorm
+        emb = self.adaln_proj(self.adaln_silu(timestep_emb))
+        shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
+
+        # Fused Attention
+        residual = latents
+        x = self.ln0(x)
+        latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
+
+        batch_size = latents.shape[0]
+
+        query = self.attn.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // self.attn.heads
+
+        query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
+
+        weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        latents = weight @ value
+
+        latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
+        latents = self.attn.to_out[0](latents)
+        latents = self.attn.to_out[1](latents)
+        latents = latents + residual
+
+        ## FeedForward
+        residual = latents
+        latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        return self.ff(latents) + residual
+
+
+
2498
+ # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2499
+ class IPAdapterTimeImageProjection(nn.Module):
2500
+ """Resampler of SD3 IP-Adapter with timestep embedding.
2501
+
2502
+ Args:
2503
+ embed_dim (`int`, defaults to 1152):
2504
+ The feature dimension.
2505
+ output_dim (`int`, defaults to 2432):
2506
+ The number of output channels.
2507
+ hidden_dim (`int`, defaults to 1280):
2508
+ The number of hidden channels.
2509
+ depth (`int`, defaults to 4):
2510
+ The number of blocks.
2511
+ dim_head (`int`, defaults to 64):
2512
+ The number of head channels.
2513
+ heads (`int`, defaults to 20):
2514
+ Parallel attention heads.
2515
+ num_queries (`int`, defaults to 64):
2516
+ The number of queries.
2517
+ ffn_ratio (`int`, defaults to 4):
2518
+ The expansion ratio of feedforward network hidden layer channels.
2519
+ timestep_in_dim (`int`, defaults to 320):
2520
+ The number of input channels for timestep embedding.
2521
+ timestep_flip_sin_to_cos (`bool`, defaults to True):
2522
+ Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
2523
+ timestep_freq_shift (`int`, defaults to 0):
2524
+ Controls the timestep delta between frequencies between dimensions.
2525
+ """
2526
+
2527
+ def __init__(
2528
+ self,
2529
+ embed_dim: int = 1152,
2530
+ output_dim: int = 2432,
2531
+ hidden_dim: int = 1280,
2532
+ depth: int = 4,
2533
+ dim_head: int = 64,
2534
+ heads: int = 20,
2535
+ num_queries: int = 64,
2536
+ ffn_ratio: int = 4,
2537
+ timestep_in_dim: int = 320,
2538
+ timestep_flip_sin_to_cos: bool = True,
2539
+ timestep_freq_shift: int = 0,
2540
+ ) -> None:
2541
+ super().__init__()
2542
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
2543
+ self.proj_in = nn.Linear(embed_dim, hidden_dim)
2544
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
2545
+ self.norm_out = nn.LayerNorm(output_dim)
2546
+ self.layers = nn.ModuleList(
2547
+ [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
2548
+ )
2549
+ self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
2550
+ self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
2551
+
2552
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
2553
+ """Forward pass.
2554
+
2555
+ Args:
2556
+ x (`torch.Tensor`):
2557
+ Image features.
2558
+ timestep (`torch.Tensor`):
2559
+ Timestep in denoising process.
2560
+ Returns:
2561
+ `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
2562
+ """
2563
+ timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
2564
+ timestep_emb = self.time_embedding(timestep_emb)
2565
+
2566
+ latents = self.latents.repeat(x.size(0), 1, 1)
2567
+
2568
+ x = self.proj_in(x)
2569
+ x = x + timestep_emb[:, None]
2570
+
2571
+ for block in self.layers:
2572
+ latents = block(x, latents, timestep_emb)
2573
+
2574
+ latents = self.proj_out(latents)
2575
+ latents = self.norm_out(latents)
2576
+
2577
+ return latents, timestep_emb
2578
+
2579
+
1676
2580
  class MultiIPAdapterImageProjection(nn.Module):
1677
2581
  def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
1678
2582
  super().__init__()
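Note: IPAdapterTimeImageProjection (added above) is the SD3 IP-Adapter resampler; it mixes the timestep embedding into the image features and into each block's AdaLayerNorm, then returns the resampled latents together with that timestep embedding. A usage sketch with the class defaults, not part of the diff; the number of image tokens below is an assumption:

    import torch
    from diffusers.models.embeddings import IPAdapterTimeImageProjection

    proj = IPAdapterTimeImageProjection()      # defaults: embed_dim=1152, hidden_dim=1280, output_dim=2432
    image_embeds = torch.randn(2, 729, 1152)   # (B, num_image_tokens, embed_dim) image encoder features
    timestep = torch.tensor([999, 500])
    latents, timestep_emb = proj(image_embeds, timestep)
    print(latents.shape, timestep_emb.shape)   # torch.Size([2, 64, 2432]) torch.Size([2, 1280])
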