diffusers 0.27.2__py3-none-any.whl → 0.28.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (278)
  1. diffusers/__init__.py +26 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +33 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +8 -0
  21. diffusers/models/activations.py +23 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +475 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +363 -32
  35. diffusers/models/model_loading_utils.py +177 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_outputs.py +14 -0
  39. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  40. diffusers/models/modeling_utils.py +175 -99
  41. diffusers/models/normalization.py +2 -1
  42. diffusers/models/resnet.py +18 -23
  43. diffusers/models/transformer_temporal.py +3 -3
  44. diffusers/models/transformers/__init__.py +3 -0
  45. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  46. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  47. diffusers/models/transformers/hunyuan_transformer_2d.py +427 -0
  48. diffusers/models/transformers/pixart_transformer_2d.py +336 -0
  49. diffusers/models/transformers/prior_transformer.py +7 -7
  50. diffusers/models/transformers/t5_film_transformer.py +17 -19
  51. diffusers/models/transformers/transformer_2d.py +292 -184
  52. diffusers/models/transformers/transformer_temporal.py +10 -10
  53. diffusers/models/unets/unet_1d.py +5 -5
  54. diffusers/models/unets/unet_1d_blocks.py +29 -29
  55. diffusers/models/unets/unet_2d.py +6 -6
  56. diffusers/models/unets/unet_2d_blocks.py +137 -128
  57. diffusers/models/unets/unet_2d_condition.py +19 -15
  58. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  59. diffusers/models/unets/unet_3d_blocks.py +79 -77
  60. diffusers/models/unets/unet_3d_condition.py +13 -9
  61. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  62. diffusers/models/unets/unet_kandinsky3.py +1 -1
  63. diffusers/models/unets/unet_motion_model.py +114 -14
  64. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  65. diffusers/models/unets/unet_stable_cascade.py +16 -13
  66. diffusers/models/upsampling.py +17 -20
  67. diffusers/models/vq_model.py +16 -15
  68. diffusers/pipelines/__init__.py +27 -3
  69. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  70. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  71. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  72. diffusers/pipelines/animatediff/__init__.py +2 -0
  73. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  74. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  75. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  76. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  77. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  78. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  79. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  80. diffusers/pipelines/auto_pipeline.py +21 -17
  81. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  82. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  83. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  84. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  85. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  86. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  87. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  88. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  89. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  90. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  91. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  92. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  93. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  94. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  95. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  96. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  97. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  98. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  99. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  100. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  101. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  102. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  103. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  104. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  105. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  106. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  107. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  108. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  109. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  110. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  111. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  112. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  113. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  114. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  115. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  116. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  117. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  118. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  119. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  120. diffusers/pipelines/free_init_utils.py +39 -38
  121. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  122. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +881 -0
  123. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  124. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  125. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  126. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  127. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  128. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  129. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  130. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  131. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  132. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  133. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  134. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  135. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  136. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  137. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  138. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  139. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  140. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  141. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  142. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  143. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  144. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  145. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  146. diffusers/pipelines/marigold/__init__.py +50 -0
  147. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  148. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  149. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  150. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  151. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  152. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  153. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  154. diffusers/pipelines/pipeline_loading_utils.py +269 -23
  155. diffusers/pipelines/pipeline_utils.py +266 -37
  156. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +69 -79
  158. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  159. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  160. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  161. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  162. diffusers/pipelines/shap_e/renderer.py +1 -1
  163. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  164. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  165. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  166. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  167. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  168. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  169. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  172. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  173. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  174. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  175. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  176. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  177. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  178. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  179. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  180. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  181. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  182. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  183. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  184. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  185. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  186. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  187. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  188. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  189. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  190. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  191. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  192. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  193. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  194. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  195. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  196. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  197. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  198. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  199. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  200. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  201. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  202. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  203. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  204. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  205. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  206. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  207. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  208. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  209. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  210. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  211. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  212. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  213. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  214. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  215. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  216. diffusers/schedulers/__init__.py +2 -2
  217. diffusers/schedulers/deprecated/__init__.py +1 -1
  218. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  219. diffusers/schedulers/scheduling_amused.py +5 -5
  220. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  221. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  222. diffusers/schedulers/scheduling_ddim.py +22 -24
  223. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  224. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  225. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  226. diffusers/schedulers/scheduling_ddpm.py +20 -22
  227. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  228. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  229. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  230. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  231. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  232. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  236. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  237. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  238. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  239. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  240. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  241. diffusers/schedulers/scheduling_ipndm.py +8 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  244. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  245. diffusers/schedulers/scheduling_lcm.py +21 -23
  246. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  247. diffusers/schedulers/scheduling_pndm.py +20 -20
  248. diffusers/schedulers/scheduling_repaint.py +20 -20
  249. diffusers/schedulers/scheduling_sasolver.py +55 -54
  250. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  251. diffusers/schedulers/scheduling_tcd.py +39 -30
  252. diffusers/schedulers/scheduling_unclip.py +15 -15
  253. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  254. diffusers/schedulers/scheduling_utils.py +14 -5
  255. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  256. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  257. diffusers/training_utils.py +56 -1
  258. diffusers/utils/__init__.py +7 -0
  259. diffusers/utils/doc_utils.py +1 -0
  260. diffusers/utils/dummy_pt_objects.py +75 -0
  261. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  262. diffusers/utils/dynamic_modules_utils.py +24 -11
  263. diffusers/utils/hub_utils.py +3 -2
  264. diffusers/utils/import_utils.py +91 -0
  265. diffusers/utils/loading_utils.py +2 -2
  266. diffusers/utils/logging.py +1 -1
  267. diffusers/utils/peft_utils.py +32 -5
  268. diffusers/utils/state_dict_utils.py +11 -2
  269. diffusers/utils/testing_utils.py +71 -6
  270. diffusers/utils/torch_utils.py +1 -0
  271. diffusers/video_processor.py +113 -0
  272. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/METADATA +7 -7
  273. diffusers-0.28.1.dist-info/RECORD +419 -0
  274. diffusers-0.27.2.dist-info/RECORD +0 -399
  275. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/LICENSE +0 -0
  276. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/WHEEL +0 -0
  277. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/entry_points.txt +0 -0
  278. {diffusers-0.27.2.dist-info → diffusers-0.28.1.dist-info}/top_level.txt +0 -0
diffusers/models/downsampling.py

@@ -102,7 +102,6 @@ class Downsample2D(nn.Module):
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ class Downsample2D(nn.Module):
             raise ValueError(f"unknown norm_type: {norm_type}")
 
         if use_conv:
-            conv = conv_cls(
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:
@@ -130,7 +129,7 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv
 
-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -181,24 +180,24 @@ class FirDownsample2D(nn.Module):
 
     def _downsample_2d(
         self,
-        hidden_states: torch.FloatTensor,
-        weight: Optional[torch.FloatTensor] = None,
-        kernel: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
         arbitrary order.
 
         Args:
-            hidden_states (`torch.FloatTensor`):
+            hidden_states (`torch.Tensor`):
                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight (`torch.FloatTensor`, *optional*):
+            weight (`torch.Tensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel (`torch.FloatTensor`, *optional*):
+            kernel (`torch.Tensor`, *optional*):
                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                 corresponds to average pooling.
             factor (`int`, *optional*, default to `2`):
@@ -207,7 +206,7 @@ class FirDownsample2D(nn.Module):
                 Scaling factor for signal magnitude.
 
         Returns:
-            output (`torch.FloatTensor`):
+            output (`torch.Tensor`):
                 Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                 datatype as `x`.
         """
@@ -245,7 +244,7 @@ class FirDownsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -287,11 +286,11 @@ class KDownsample2D(nn.Module):
 
 
 def downsample_2d(
-    hidden_states: torch.FloatTensor,
-    kernel: Optional[torch.FloatTensor] = None,
+    hidden_states: torch.Tensor,
+    kernel: Optional[torch.Tensor] = None,
     factor: int = 2,
     gain: float = 1,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
     r"""Downsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -299,9 +298,9 @@ def downsample_2d(
     shape is a multiple of the downsampling factor.
 
     Args:
-        hidden_states (`torch.FloatTensor`)
+        hidden_states (`torch.Tensor`)
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel (`torch.FloatTensor`, *optional*):
+        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
         factor (`int`, *optional*, default to `2`):
@@ -310,7 +309,7 @@ def downsample_2d(
             Scaling factor for signal magnitude.
 
     Returns:
-        output (`torch.FloatTensor`):
+        output (`torch.Tensor`):
         Tensor of the shape `[N, C, H // factor, W // factor]`
     """
 
diffusers/models/embeddings.py

@@ -16,10 +16,11 @@ from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from ..utils import deprecate
-from .activations import get_activation
+from .activations import FP32SiLU, get_activation
 from .attention_processor import Attention
 
 
@@ -135,6 +136,7 @@ class PatchEmbed(nn.Module):
         flatten=True,
         bias=True,
         interpolation_scale=1,
+        pos_embed_type="sincos",
     ):
         super().__init__()
 
@@ -156,10 +158,18 @@ class PatchEmbed(nn.Module):
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
         self.interpolation_scale = interpolation_scale
-        pos_embed = get_2d_sincos_pos_embed(
-            embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
-        )
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+        if pos_embed_type is None:
+            self.pos_embed = None
+        elif pos_embed_type == "sincos":
+            pos_embed = get_2d_sincos_pos_embed(
+                embed_dim,
+                int(num_patches**0.5),
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+            )
+            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+        else:
+            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
     def forward(self, latent):
         height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
@@ -169,6 +179,8 @@ class PatchEmbed(nn.Module):
         latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
         if self.layer_norm:
             latent = self.norm(latent)
+        if self.pos_embed is None:
+            return latent.to(latent.dtype)
 
         # Interpolate positional embeddings if needed.
         # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
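Usage note (illustration, not part of the diff): the `pos_embed_type` argument added to `PatchEmbed` above makes the sincos table optional. A minimal sketch, assuming diffusers 0.28.1 and the usual leading constructor arguments (height, width, patch_size, in_channels, embed_dim):

import torch
from diffusers.models.embeddings import PatchEmbed

# 64x64 latent, 8x8 patches of width 768; pos_embed_type=None skips the sincos positional table.
patchify = PatchEmbed(height=64, width=64, patch_size=8, in_channels=4, embed_dim=768, pos_embed_type=None)
latent = torch.randn(1, 4, 64, 64)
tokens = patchify(latent)  # [1, 64, 768]; with the default pos_embed_type="sincos" the embedding is added as before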
@@ -187,6 +199,113 @@ class PatchEmbed(nn.Module):
         return (latent + pos_embed).to(latent.dtype)
 
 
+def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+    """
+    RoPE for image tokens with 2d structure.
+
+    Args:
+    embed_dim: (`int`):
+        The embedding dimension size
+    crops_coords (`Tuple[int]`)
+        The top-left and bottom-right coordinates of the crop.
+    grid_size (`Tuple[int]`):
+        The grid size of the positional embedding.
+    use_real (`bool`):
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
+    """
+    start, stop = crops_coords
+    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    assert embed_dim % 4 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+
+    if use_real:
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+        return cos, sin
+    else:
+        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
+        return emb
+
+
+def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+    data type.
+
+    Args:
+        dim (`int`): Dimension of the frequency tensor.
+        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+        theta (`float`, *optional*, defaults to 10000.0):
+            Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (`bool`, *optional*):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns:
+        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+    """
+    if isinstance(pos, int):
+        pos = np.arange(pos)
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
+    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        return freqs_cis
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensors contain rotary embeddings and are returned as real tensors.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    cos, sin = freqs_cis  # [S, D]
+    cos = cos[None, None]
+    sin = sin[None, None]
+    cos, sin = cos.to(x.device), sin.to(x.device)
+
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+    return out
+
+
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
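Usage note (illustration, not part of the diff): the rotary helpers added above compose as follows. A minimal sketch, assuming diffusers 0.28.1; shapes follow directly from the code in this hunk:

import torch
from diffusers.models.embeddings import apply_rotary_emb, get_2d_rotary_pos_embed

head_dim, height, width = 64, 16, 16  # head_dim must be divisible by 4
# (cos, sin) frequencies for a 16x16 token grid, each of shape [height * width, head_dim]
rope = get_2d_rotary_pos_embed(head_dim, ((0, 0), (height, width)), (height, width), use_real=True)

query = torch.randn(2, 8, height * width, head_dim)  # [batch, heads, tokens, head_dim]
query = apply_rotary_emb(query, rope)  # same shape, with positions rotated into the channel pairs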
@@ -199,9 +318,8 @@ class TimestepEmbedding(nn.Module):
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear
 
-        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +332,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
 
         if post_act_fn is None:
             self.post_act = None
@@ -425,7 +543,7 @@ class TextImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
 
-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         batch_size = text_embeds.shape[0]
 
         # image
@@ -451,7 +569,7 @@ class ImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.norm = nn.LayerNorm(cross_attention_dim)
 
-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         batch_size = image_embeds.shape[0]
 
         # image
@@ -469,10 +587,26 @@ class IPAdapterFullImageProjection(nn.Module):
         self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
         self.norm = nn.LayerNorm(cross_attention_dim)
 
-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         return self.norm(self.ff(image_embeds))
 
 
+class IPAdapterFaceIDImageProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.cross_attention_dim = cross_attention_dim
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.Tensor):
+        x = self.ff(image_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return self.norm(x)
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
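Usage note (illustration, not part of the diff): `IPAdapterFaceIDImageProjection` above turns one ID embedding per image into `num_tokens` cross-attention tokens of width `cross_attention_dim`. A minimal sketch, assuming diffusers 0.28.1; the dimensions are illustrative:

import torch
from diffusers.models.embeddings import IPAdapterFaceIDImageProjection

proj = IPAdapterFaceIDImageProjection(image_embed_dim=512, cross_attention_dim=768, num_tokens=4)
face_embeds = torch.randn(2, 512)  # one face-ID embedding per image
tokens = proj(face_embeds)  # [2, 4, 768], usable as extra image tokens for cross-attention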
@@ -492,6 +626,88 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         return conditioning
 
 
+class HunyuanDiTAttentionPool(nn.Module):
+    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.permute(1, 0, 2)  # NLC -> LNC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1],
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False,
+        )
+        return x.squeeze(0)
+
+
+class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.pooler = HunyuanDiTAttentionPool(
+            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+        )
+        # Here we use a default learned embedder layer for future extension.
+        self.style_embedder = nn.Embedding(1, embedding_dim)
+        extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        self.extra_embedder = PixArtAlphaTextProjection(
+            in_features=extra_in_dim,
+            hidden_size=embedding_dim * 4,
+            out_features=embedding_dim,
+            act_fn="silu_fp32",
+        )
+
+    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, 256)
+
+        # extra condition1: text
+        pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)
+
+        # extra condition2: image meta size embdding
+        image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
+        image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+        image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+        # extra condition3: style embedding
+        style_embedding = self.style_embedder(style)  # (N, embedding_dim)
+
+        # Concatenate all extra vectors
+        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]
+
+        return conditioning
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
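Shape sketch (illustration, not part of the diff) for `HunyuanCombinedTimestepTextSizeStyleEmbedding` above, assuming diffusers 0.28.1; `embedding_dim=1408` is only an illustrative value, and the meaning of the six `image_meta_size` entries (original/target size and crop offsets) is an assumption taken from the HunyuanDiT pipeline rather than from this hunk:

import torch
from diffusers.models.embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding

emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(embedding_dim=1408)
timestep = torch.randint(0, 1000, (2,))
encoder_hidden_states = torch.randn(2, 256, 2048)  # seq_len=256 tokens, pooled by HunyuanDiTAttentionPool
image_meta_size = torch.tensor([[1024.0, 1024.0, 1024.0, 1024.0, 0.0, 0.0]] * 2)  # six values per sample (assumed semantics)
style = torch.zeros(2, dtype=torch.long)  # index into the single learned style embedding
conditioning = emb(timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=torch.float32)  # [2, 1408]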
@@ -515,7 +731,7 @@ class TextImageTimeEmbedding(nn.Module):
         self.text_norm = nn.LayerNorm(time_embed_dim)
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
 
-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         # text
         time_text_embeds = self.text_proj(text_embeds)
         time_text_embeds = self.text_norm(time_text_embeds)
@@ -532,7 +748,7 @@ class ImageTimeEmbedding(nn.Module):
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
         self.image_norm = nn.LayerNorm(time_embed_dim)
 
-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -562,7 +778,7 @@ class ImageHintTimeEmbedding(nn.Module):
             nn.Conv2d(256, 4, 3, padding=1),
         )
 
-    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -778,11 +994,18 @@ class PixArtAlphaTextProjection(nn.Module):
     Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
     """
 
-    def __init__(self, in_features, hidden_size, num_tokens=120):
+    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
         super().__init__()
+        if out_features is None:
+            out_features = hidden_size
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
-        self.act_1 = nn.GELU(approximate="tanh")
-        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+        if act_fn == "gelu_tanh":
+            self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu_fp32":
+            self.act_1 = FP32SiLU()
+        else:
+            raise ValueError(f"Unknown activation function: {act_fn}")
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
 
     def forward(self, caption):
         hidden_states = self.linear_1(caption)
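Usage note (illustration, not part of the diff): `PixArtAlphaTextProjection` above now takes `out_features` and an `act_fn` selector ("gelu_tanh" or "silu_fp32") in place of the old `num_tokens` argument. A minimal sketch, assuming diffusers 0.28.1 and T5-style encoder states of width 4096 (an illustrative choice):

import torch
from diffusers.models.embeddings import PixArtAlphaTextProjection

proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=1152)  # out_features defaults to hidden_size
caption = torch.randn(2, 120, 4096)  # e.g. T5 encoder hidden states
hidden = proj(caption)  # [2, 120, 1152], via linear_1 -> act_1 -> linear_2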
@@ -795,17 +1018,15 @@ class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
     Args:
-    ----
-    embed_dims (int): The feature dimension. Defaults to 768.
-    output_dims (int): The number of output channels, that is the same
-        number of the channels in the
-        `unet.config.cross_attention_dim`. Defaults to 1024.
-    hidden_dims (int): The number of hidden channels. Defaults to 1280.
-    depth (int): The number of blocks. Defaults to 8.
-    dim_head (int): The number of head channels. Defaults to 64.
-    heads (int): Parallel attention heads. Defaults to 16.
-    num_queries (int): The number of queries. Defaults to 8.
-    ffn_ratio (float): The expansion ratio of feedforward network hidden
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_queries (int):
+            The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
+        of feedforward network hidden
         layer channels. Defaults to 4.
     """
 
@@ -855,11 +1076,8 @@ class IPAdapterPlusImageProjection(nn.Module):
         """Forward pass.
 
         Args:
-        ----
            x (torch.Tensor): Input Tensor.
-
         Returns:
-        -------
            torch.Tensor: Output Tensor.
         """
         latents = self.latents.repeat(x.size(0), 1, 1)
@@ -879,12 +1097,125 @@ class IPAdapterPlusImageProjection(nn.Module):
         return self.norm_out(latents)
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+    """FacePerceiverResampler of IP-Adapter Plus.
+
+    Args:
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+        layer channels. Defaults to 4.
+        ffproj_ratio (float): The expansion ratio of feedforward network hidden
+        layer channels (for ID embeddings). Defaults to 4.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        output_dims: int = 768,
+        hidden_dims: int = 1280,
+        id_embeddings_dim: int = 512,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_tokens: int = 4,
+        num_queries: int = 8,
+        ffn_ratio: float = 4,
+        ffproj_ratio: int = 2,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.embed_dim = embed_dims
+        self.clip_embeds = None
+        self.shortcut = False
+        self.shortcut_scale = 1.0
+
+        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+        self.norm = nn.LayerNorm(embed_dims)
+
+        self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+        self.proj_out = nn.Linear(embed_dims, output_dims)
+        self.norm_out = nn.LayerNorm(output_dims)
+
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+
+    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            id_embeds (torch.Tensor): Input Tensor (ID embeds).
+        Returns:
+            torch.Tensor: Output Tensor.
+        """
+        id_embeds = id_embeds.to(self.clip_embeds.dtype)
+        id_embeds = self.proj(id_embeds)
+        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+        id_embeds = self.norm(id_embeds)
+        latents = id_embeds
+
+        clip_embeds = self.proj_in(self.clip_embeds)
+        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+        for block in self.layers:
+            residual = latents
+            latents = block(x, latents, residual)
+
+        latents = self.proj_out(latents)
+        out = self.norm_out(latents)
+        if self.shortcut:
+            out = id_embeds + self.shortcut_scale * out
+        return out
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
 
-    def forward(self, image_embeds: List[torch.FloatTensor]):
+    def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []
 
         # currently, we accept `image_embeds` as