diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -180,7 +180,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
180
180
 
181
181
  if push_to_hub:
182
182
  commit_message = kwargs.pop("commit_message", None)
183
- private = kwargs.pop("private", False)
183
+ private = kwargs.pop("private", None)
184
184
  create_pr = kwargs.pop("create_pr", False)
185
185
  token = kwargs.pop("token", None)
186
186
  repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -198,10 +198,31 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
198
198
  variant_filename = f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"
199
199
  return variant_filename
200
200
 
201
- for f in non_variant_filenames:
202
- variant_filename = convert_to_variant(f)
203
- if variant_filename not in usable_filenames:
204
- usable_filenames.add(f)
201
+ def find_component(filename):
202
+ if not len(filename.split("/")) == 2:
203
+ return
204
+ component = filename.split("/")[0]
205
+ return component
206
+
207
+ def has_sharded_variant(component, variant, variant_filenames):
208
+ # If component exists check for sharded variant index filename
209
+ # If component doesn't exist check main dir for sharded variant index filename
210
+ component = component + "/" if component else ""
211
+ variant_index_re = re.compile(
212
+ rf"{component}({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
213
+ )
214
+ return any(f for f in variant_filenames if variant_index_re.match(f) is not None)
215
+
216
+ for filename in non_variant_filenames:
217
+ if convert_to_variant(filename) in variant_filenames:
218
+ continue
219
+
220
+ component = find_component(filename)
221
+ # If a sharded variant exists skip adding to allowed patterns
222
+ if has_sharded_variant(component, variant, variant_filenames):
223
+ continue
224
+
225
+ usable_filenames.add(filename)
205
226
 
206
227
  return usable_filenames, variant_filenames
207
228
 
@@ -13,6 +13,7 @@
13
13
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
  # See the License for the specific language governing permissions and
15
15
  # limitations under the License.
16
+ import enum
16
17
  import fnmatch
17
18
  import importlib
18
19
  import inspect
@@ -66,7 +67,6 @@ from ..utils.torch_utils import is_compiled_module
66
67
  if is_torch_npu_available():
67
68
  import torch_npu # noqa: F401
68
69
 
69
-
70
70
  from .pipeline_loading_utils import (
71
71
  ALL_IMPORTABLE_CLASSES,
72
72
  CONNECTED_PIPES_KEYS,
@@ -229,7 +229,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
229
229
 
230
230
  if push_to_hub:
231
231
  commit_message = kwargs.pop("commit_message", None)
232
- private = kwargs.pop("private", False)
232
+ private = kwargs.pop("private", None)
233
233
  create_pr = kwargs.pop("create_pr", False)
234
234
  token = kwargs.pop("token", None)
235
235
  repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -388,6 +388,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
388
388
  )
389
389
 
390
390
  device = device or device_arg
391
+ pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
391
392
 
392
393
  # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
393
394
  def module_is_sequentially_offloaded(module):
@@ -410,10 +411,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
410
411
  pipeline_is_sequentially_offloaded = any(
411
412
  module_is_sequentially_offloaded(module) for _, module in self.components.items()
412
413
  )
413
- if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
414
- raise ValueError(
415
- "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
416
- )
414
+ if device and torch.device(device).type == "cuda":
415
+ if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
416
+ raise ValueError(
417
+ "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
418
+ )
419
+ # PR: https://github.com/huggingface/accelerate/pull/3223/
420
+ elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
421
+ raise ValueError(
422
+ "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
423
+ )
417
424
 
418
425
  is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
419
426
  if is_pipeline_device_mapped:
@@ -805,6 +812,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
805
812
  # in this case they are already instantiated in `kwargs`
806
813
  # extract them here
807
814
  expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
815
+ expected_types = pipeline_class._get_signature_types()
808
816
  passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
809
817
  passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
810
818
  init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -827,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
827
835
 
828
836
  init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
829
837
 
838
+ for key in init_dict.keys():
839
+ if key not in passed_class_obj:
840
+ continue
841
+ if "scheduler" in key:
842
+ continue
843
+
844
+ class_obj = passed_class_obj[key]
845
+ _expected_class_types = []
846
+ for expected_type in expected_types[key]:
847
+ if isinstance(expected_type, enum.EnumMeta):
848
+ _expected_class_types.extend(expected_type.__members__.keys())
849
+ else:
850
+ _expected_class_types.append(expected_type.__name__)
851
+
852
+ _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
853
+ if not _is_valid_type:
854
+ logger.warning(
855
+ f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
856
+ )
857
+
830
858
  # Special case: safety_checker must be loaded separately when using `from_flax`
831
859
  if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
832
860
  raise NotImplementedError(
@@ -1552,6 +1580,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
1552
1580
  """
1553
1581
  return numpy_to_pil(images)
1554
1582
 
1583
+ @torch.compiler.disable
1555
1584
  def progress_bar(self, iterable=None, total=None):
1556
1585
  if not hasattr(self, "_progress_bar_config"):
1557
1586
  self._progress_bar_config = {}
@@ -338,13 +338,6 @@ class PixArtAlphaPipeline(DiffusionPipeline):
338
338
  if device is None:
339
339
  device = self._execution_device
340
340
 
341
- if prompt is not None and isinstance(prompt, str):
342
- batch_size = 1
343
- elif prompt is not None and isinstance(prompt, list):
344
- batch_size = len(prompt)
345
- else:
346
- batch_size = prompt_embeds.shape[0]
347
-
348
341
  # See Section 3.1. of the paper.
349
342
  max_length = max_sequence_length
350
343
 
@@ -389,12 +382,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
389
382
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
390
383
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
391
384
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
392
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
393
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
385
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
386
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
394
387
 
395
388
  # get unconditional embeddings for classifier free guidance
396
389
  if do_classifier_free_guidance and negative_prompt_embeds is None:
397
- uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
390
+ uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
398
391
  uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
399
392
  max_length = prompt_embeds.shape[1]
400
393
  uncond_input = self.tokenizer(
@@ -421,10 +414,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
421
414
  negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
422
415
 
423
416
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
424
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
417
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
425
418
 
426
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
427
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
419
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
420
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
428
421
  else:
429
422
  negative_prompt_embeds = None
430
423
  negative_prompt_attention_mask = None
@@ -264,13 +264,6 @@ class PixArtSigmaPipeline(DiffusionPipeline):
264
264
  if device is None:
265
265
  device = self._execution_device
266
266
 
267
- if prompt is not None and isinstance(prompt, str):
268
- batch_size = 1
269
- elif prompt is not None and isinstance(prompt, list):
270
- batch_size = len(prompt)
271
- else:
272
- batch_size = prompt_embeds.shape[0]
273
-
274
267
  # See Section 3.1. of the paper.
275
268
  max_length = max_sequence_length
276
269
 
@@ -315,12 +308,12 @@ class PixArtSigmaPipeline(DiffusionPipeline):
315
308
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
316
309
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
317
310
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
318
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
319
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
311
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
312
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
320
313
 
321
314
  # get unconditional embeddings for classifier free guidance
322
315
  if do_classifier_free_guidance and negative_prompt_embeds is None:
323
- uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
316
+ uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
324
317
  uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
325
318
  max_length = prompt_embeds.shape[1]
326
319
  uncond_input = self.tokenizer(
@@ -347,10 +340,10 @@ class PixArtSigmaPipeline(DiffusionPipeline):
347
340
  negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
348
341
 
349
342
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
350
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
343
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
351
344
 
352
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
353
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
345
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
346
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
354
347
  else:
355
348
  negative_prompt_embeds = None
356
349
  negative_prompt_attention_mask = None
@@ -0,0 +1,47 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_torch_available,
9
+ is_transformers_available,
10
+ )
11
+
12
+
13
+ _dummy_objects = {}
14
+ _import_structure = {}
15
+
16
+
17
+ try:
18
+ if not (is_transformers_available() and is_torch_available()):
19
+ raise OptionalDependencyNotAvailable()
20
+ except OptionalDependencyNotAvailable:
21
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
22
+
23
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
+ else:
25
+ _import_structure["pipeline_sana"] = ["SanaPipeline"]
26
+
27
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
28
+ try:
29
+ if not (is_transformers_available() and is_torch_available()):
30
+ raise OptionalDependencyNotAvailable()
31
+
32
+ except OptionalDependencyNotAvailable:
33
+ from ...utils.dummy_torch_and_transformers_objects import *
34
+ else:
35
+ from .pipeline_sana import SanaPipeline
36
+ else:
37
+ import sys
38
+
39
+ sys.modules[__name__] = _LazyModule(
40
+ __name__,
41
+ globals()["__file__"],
42
+ _import_structure,
43
+ module_spec=__spec__,
44
+ )
45
+
46
+ for name, value in _dummy_objects.items():
47
+ setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
+ @dataclass
11
+ class SanaPipelineOutput(BaseOutput):
12
+ """
13
+ Output class for Sana pipelines.
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
17
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
19
+ """
20
+
21
+ images: Union[List[PIL.Image.Image], np.ndarray]