diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
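Notable additions visible in the list above include the ControlNet-XS model and pipelines (controlnet_xs), the Marigold depth and surface-normals pipelines (marigold), a PixArt-Sigma pipeline, an AnimateDiff SDXL pipeline, a new diffusers/models/model_loading_utils.py module, pipeline callbacks (diffusers/callbacks.py), and a shared diffusers/video_processor.py. As a quick orientation, the new Marigold depth pipeline can be driven roughly as below; this is a sketch based on the Marigold documentation, and the checkpoint id and the visualize_depth helper are assumptions, not part of this diff:

import torch
from diffusers import MarigoldDepthPipeline
from diffusers.utils import load_image

# Sketch of the Marigold depth pipeline added in 0.28.0 (checkpoint id assumed)
pipe = MarigoldDepthPipeline.from_pretrained(
    "prs-eth/marigold-depth-lcm-v1-0", torch_dtype=torch.float16
).to("cuda")
image = load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
depth = pipe(image)  # output carries a .prediction field
vis = pipe.image_processor.visualize_depth(depth.prediction)  # PIL visualization

Representative hunks from the diff follow.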
diffusers/models/downsampling.py

@@ -102,7 +102,6 @@ class Downsample2D(nn.Module):
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d
 
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ class Downsample2D(nn.Module):
             raise ValueError(f"unknown norm_type: {norm_type}")
 
         if use_conv:
-            conv = conv_cls(
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:
@@ -130,7 +129,7 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv
 
-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
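Two themes recur throughout this release: the vestigial `conv_cls`/`linear_cls` indirection is removed in favor of `nn.Conv2d`/`nn.Linear` directly, and per-module `scale` arguments are deprecated. Per the deprecation message above, a LoRA scale now belongs at the pipeline call site; a minimal sketch, assuming `pipe` is a pipeline with LoRA weights already loaded:

# Pass `scale` via cross_attention_kwargs at call time, not to individual blocks
image = pipe("a pot of gold", cross_attention_kwargs={"scale": 0.7}).images[0]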
@@ -181,24 +180,24 @@ class FirDownsample2D(nn.Module):
 
     def _downsample_2d(
         self,
-        hidden_states: torch.FloatTensor,
-        weight: Optional[torch.FloatTensor] = None,
-        kernel: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
         arbitrary order.
 
         Args:
-            hidden_states (`torch.FloatTensor`):
+            hidden_states (`torch.Tensor`):
                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight (`torch.FloatTensor`, *optional*):
+            weight (`torch.Tensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel (`torch.FloatTensor`, *optional*):
+            kernel (`torch.Tensor`, *optional*):
                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                 corresponds to average pooling.
             factor (`int`, *optional*, default to `2`):
@@ -207,7 +206,7 @@ class FirDownsample2D(nn.Module):
                 Scaling factor for signal magnitude.
 
         Returns:
-            output (`torch.FloatTensor`):
+            output (`torch.Tensor`):
                 Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                 datatype as `x`.
         """
@@ -245,7 +244,7 @@ class FirDownsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -287,11 +286,11 @@ class KDownsample2D(nn.Module):
 
 
 def downsample_2d(
-    hidden_states: torch.FloatTensor,
-    kernel: Optional[torch.FloatTensor] = None,
+    hidden_states: torch.Tensor,
+    kernel: Optional[torch.Tensor] = None,
     factor: int = 2,
     gain: float = 1,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
     r"""Downsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -299,9 +298,9 @@ def downsample_2d(
     shape is a multiple of the downsampling factor.
 
     Args:
-        hidden_states (`torch.FloatTensor`)
+        hidden_states (`torch.Tensor`)
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel (`torch.FloatTensor`, *optional*):
+        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor (`int`, *optional*, default to `2`):
@@ -310,7 +309,7 @@ def downsample_2d(
            Scaling factor for signal magnitude.
 
    Returns:
-        output (`torch.FloatTensor`):
+        output (`torch.Tensor`):
            Tensor of the shape `[N, C, H // factor, W // factor]`
    """
 
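The remaining downsampling hunks only relax annotations from `torch.FloatTensor` to the more general `torch.Tensor`; runtime behavior is unchanged. For reference, the module-level `downsample_2d` documented above can be exercised like this (a sketch; the import path follows this file's location in the package):

import torch
from diffusers.models.downsampling import downsample_2d

x = torch.randn(1, 3, 64, 64)
y = downsample_2d(x, factor=2)  # default kernel [1] * factor corresponds to average pooling
print(y.shape)                  # torch.Size([1, 3, 32, 32])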
diffusers/models/embeddings.py

@@ -199,9 +199,8 @@ class TimestepEmbedding(nn.Module):
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear
 
-        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
 
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +213,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
 
         if post_act_fn is None:
             self.post_act = None
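Again a pure cleanup: `linear_cls` was already bound to `nn.Linear`, so inlining it changes nothing observable. A quick sanity check (sketch, with SD-style dimensions assumed):

import torch
from diffusers.models.embeddings import TimestepEmbedding

emb = TimestepEmbedding(in_channels=320, time_embed_dim=1280)
print(emb(torch.randn(4, 320)).shape)  # torch.Size([4, 1280])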
@@ -425,7 +424,7 @@ class TextImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
 
-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         batch_size = text_embeds.shape[0]
 
         # image
@@ -451,7 +450,7 @@ class ImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.norm = nn.LayerNorm(cross_attention_dim)
 
-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         batch_size = image_embeds.shape[0]
 
         # image
@@ -469,10 +468,26 @@ class IPAdapterFullImageProjection(nn.Module):
         self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
         self.norm = nn.LayerNorm(cross_attention_dim)
 
-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         return self.norm(self.ff(image_embeds))
 
 
+class IPAdapterFaceIDImageProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.cross_attention_dim = cross_attention_dim
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.Tensor):
+        x = self.ff(image_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return self.norm(x)
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
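The newly added IPAdapterFaceIDImageProjection maps one flat Face ID embedding to `num_tokens` cross-attention tokens. A shape walk-through of the class exactly as added above (the dimensions are illustrative):

import torch
from diffusers.models.embeddings import IPAdapterFaceIDImageProjection

proj = IPAdapterFaceIDImageProjection(image_embed_dim=512, cross_attention_dim=768, mult=2, num_tokens=4)
id_embeds = torch.randn(2, 512)  # one Face ID embedding per image
tokens = proj(id_embeds)         # FeedForward (512 -> 768 * 4), reshape, LayerNorm
print(tokens.shape)              # torch.Size([2, 4, 768])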
@@ -515,7 +530,7 @@ class TextImageTimeEmbedding(nn.Module):
         self.text_norm = nn.LayerNorm(time_embed_dim)
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
 
-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         # text
         time_text_embeds = self.text_proj(text_embeds)
         time_text_embeds = self.text_norm(time_text_embeds)
@@ -532,7 +547,7 @@ class ImageTimeEmbedding(nn.Module):
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
         self.image_norm = nn.LayerNorm(time_embed_dim)
 
-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -562,7 +577,7 @@ class ImageHintTimeEmbedding(nn.Module):
             nn.Conv2d(256, 4, 3, padding=1),
         )
 
-    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -795,17 +810,15 @@ class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
     Args:
-    ----
-        embed_dims (int): The feature dimension. Defaults to 768.
-        output_dims (int): The number of output channels, that is the same
-            number of the channels in the
-            `unet.config.cross_attention_dim`. Defaults to 1024.
-        hidden_dims (int): The number of hidden channels. Defaults to 1280.
-        depth (int): The number of blocks. Defaults to 8.
-        dim_head (int): The number of head channels. Defaults to 64.
-        heads (int): Parallel attention heads. Defaults to 16.
-        num_queries (int): The number of queries. Defaults to 8.
-        ffn_ratio (float): The expansion ratio of feedforward network hidden
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_queries (int):
+            The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
+        of feedforward network hidden
             layer channels. Defaults to 4.
     """
 
@@ -855,11 +868,8 @@ class IPAdapterPlusImageProjection(nn.Module):
         """Forward pass.
 
         Args:
-        ----
            x (torch.Tensor): Input Tensor.
-
        Returns:
-        -------
            torch.Tensor: Output Tensor.
        """
        latents = self.latents.repeat(x.size(0), 1, 1)
@@ -879,12 +889,125 @@ class IPAdapterPlusImageProjection(nn.Module):
         return self.norm_out(latents)
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+    """FacePerceiverResampler of IP-Adapter Plus.
+
+    Args:
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        ffproj_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels (for ID embeddings). Defaults to 4.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        output_dims: int = 768,
+        hidden_dims: int = 1280,
+        id_embeddings_dim: int = 512,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_tokens: int = 4,
+        num_queries: int = 8,
+        ffn_ratio: float = 4,
+        ffproj_ratio: int = 2,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.embed_dim = embed_dims
+        self.clip_embeds = None
+        self.shortcut = False
+        self.shortcut_scale = 1.0
+
+        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+        self.norm = nn.LayerNorm(embed_dims)
+
+        self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+        self.proj_out = nn.Linear(embed_dims, output_dims)
+        self.norm_out = nn.LayerNorm(output_dims)
+
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+
+    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            id_embeds (torch.Tensor): Input Tensor (ID embeds).
+        Returns:
+            torch.Tensor: Output Tensor.
+        """
+        id_embeds = id_embeds.to(self.clip_embeds.dtype)
+        id_embeds = self.proj(id_embeds)
+        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+        id_embeds = self.norm(id_embeds)
+        latents = id_embeds
+
+        clip_embeds = self.proj_in(self.clip_embeds)
+        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+        for block in self.layers:
+            residual = latents
+            latents = block(x, latents, residual)
+
+        latents = self.proj_out(latents)
+        out = self.norm_out(latents)
+        if self.shortcut:
+            out = id_embeds + self.shortcut_scale * out
+        return out
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
 
-    def forward(self, image_embeds: List[torch.FloatTensor]):
+    def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []
 
         # currently, we accept `image_embeds` as
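Note that IPAdapterFaceIDPlusImageProjection reads `self.clip_embeds` inside `forward`, so callers must assign it before the call (otherwise `.to(self.clip_embeds.dtype)` fails on `None`). A shape sketch against the code as added above, with illustrative dimensions:

import torch
from diffusers.models.embeddings import IPAdapterFaceIDPlusImageProjection

proj = IPAdapterFaceIDPlusImageProjection(
    embed_dims=768, output_dims=768, hidden_dims=1280, id_embeddings_dim=512, depth=2, num_tokens=4
)
proj.clip_embeds = torch.randn(2, 1, 257, 1280)  # CLIP patch features; reshaped to (2, 257, 768) internally
out = proj(torch.randn(2, 512))                  # ID embeds -> 4 tokens of width 768 per image
print(out.shape)                                 # torch.Size([2, 4, 768])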
diffusers/models/model_loading_utils.py (new file)

@@ -0,0 +1,149 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+from collections import OrderedDict
+from typing import List, Optional, Union
+
+import safetensors
+import torch
+
+from ..utils import (
+    SAFETENSORS_FILE_EXTENSION,
+    is_accelerate_available,
+    is_torch_version,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_accelerate_available():
+    from accelerate import infer_auto_device_map
+    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+
+
+# Adapted from `transformers` (see modeling_utils.py)
+def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+    if isinstance(device_map, str):
+        no_split_modules = model._get_no_split_modules(device_map)
+        device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                dtype=torch_dtype,
+                low_zero=(device_map == "balanced_low_0"),
+                max_memory=max_memory,
+                **device_map_kwargs,
+            )
+        else:
+            max_memory = get_max_memory(max_memory)
+
+        device_map_kwargs["max_memory"] = max_memory
+        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+
+    return device_map
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+    """
+    Reads a checkpoint file, returning properly formatted errors if they arise.
+    """
+    try:
+        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+        if file_extension == SAFETENSORS_FILE_EXTENSION:
+            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+        else:
+            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
+            return torch.load(
+                checkpoint_file,
+                map_location="cpu",
+                **weights_only_kwarg,
+            )
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read().startswith("version"):
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+            )
+
+
+def load_model_dict_into_meta(
+    model,
+    state_dict: OrderedDict,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    model_name_or_path: Optional[str] = None,
+) -> List[str]:
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+    return unexpected_keys
+
+
+def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+    # Convert old format to new format if needed from a PyTorch state_dict
+    # copy state_dict so _load_from_state_dict can modify it
+    state_dict = state_dict.copy()
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: torch.nn.Module, prefix: str = ""):
+        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + ".")
+
+    load(model_to_load)
+
+    return error_msgs
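These helpers consolidate checkpoint loading that previously lived inside modeling_utils.py: `load_state_dict` dispatches on the file extension (safetensors, else `torch.load` with `weights_only=True` on torch >= 1.13), and `load_model_dict_into_meta` materializes tensors onto a target device/dtype via accelerate's `set_module_tensor_to_device`, returning any checkpoint keys the model does not expect. A sketch (the checkpoint path is hypothetical, and accelerate must be installed for the meta-loading path):

import torch
from diffusers.models.model_loading_utils import load_state_dict

state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")  # always loaded on CPU
print(len(state_dict))

# With a model whose parameters live on the "meta" device:
# unexpected = load_model_dict_into_meta(model, state_dict, device="cuda", dtype=torch.float16)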
diffusers/models/modeling_flax_pytorch_utils.py

@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch - Flax general utilities."""
+"""PyTorch - Flax general utilities."""
+
 import re
 
 import jax.numpy as jnp
diffusers/models/modeling_flax_utils.py

@@ -245,9 +245,9 @@ class FlaxModelMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +296,7 @@ class FlaxModelMixin(PushToHubMixin):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
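Callers that passed `resume_download` can simply drop it: the new default of `None` defers to huggingface_hub, which resumes interrupted downloads automatically. For example (a sketch; the repo id is hypothetical):

from diffusers import FlaxUNet2DConditionModel

repo_id = "some-org/some-flax-model"  # hypothetical repo that ships Flax weights

# <= 0.27.x callers often wrote:
#   model, params = FlaxUNet2DConditionModel.from_pretrained(repo_id, subfolder="unet", resume_download=True)
# From 0.28.0 the kwarg is deprecated and ignored, so omit it:
model, params = FlaxUNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")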
diffusers/models/modeling_pytorch_flax_utils.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch - Flax general utilities."""
+"""PyTorch - Flax general utilities."""
 
 from pickle import UnpicklingError
 