diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -101,10 +101,10 @@ class DDPMPipeline(DiffusionPipeline):
101
101
 
102
102
  if self.device.type == "mps":
103
103
  # randn does not work reproducibly on mps
104
- image = randn_tensor(image_shape, generator=generator)
104
+ image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
105
105
  image = image.to(self.device)
106
106
  else:
107
- image = randn_tensor(image_shape, generator=generator, device=self.device)
107
+ image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
108
108
 
109
109
  # set step values
110
110
  self.scheduler.set_timesteps(num_inference_steps)
@@ -9,16 +9,17 @@ from ...utils import BaseOutput
9
9
 
10
10
  @dataclass
11
11
  class IFPipelineOutput(BaseOutput):
12
- """
13
- Args:
12
+ r"""
14
13
  Output class for Stable Diffusion pipelines.
15
- images (`List[PIL.Image.Image]` or `np.ndarray`)
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
16
17
  List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
17
18
  num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
18
- nsfw_detected (`List[bool]`)
19
+ nsfw_detected (`List[bool]`):
19
20
  List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
20
21
  (nsfw) content or a watermark. `None` if safety checking could not be performed.
21
- watermark_detected (`List[bool]`)
22
+ watermark_detected (`List[bool]`):
22
23
  List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
23
24
  checking could not be performed.
24
25
  """
@@ -65,9 +65,21 @@ EXAMPLE_DOC_STRING = """
65
65
 
66
66
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
67
67
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
68
- """
69
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
70
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
68
+ r"""
69
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
70
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
71
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
72
+
73
+ Args:
74
+ noise_cfg (`torch.Tensor`):
75
+ The predicted noise tensor for the guided diffusion process.
76
+ noise_pred_text (`torch.Tensor`):
77
+ The predicted noise tensor for the text-guided diffusion process.
78
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
79
+ A rescale factor applied to the noise predictions.
80
+
81
+ Returns:
82
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
71
83
  """
72
84
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
73
85
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -87,7 +99,7 @@ def retrieve_timesteps(
87
99
  sigmas: Optional[List[float]] = None,
88
100
  **kwargs,
89
101
  ):
90
- """
102
+ r"""
91
103
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
92
104
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
93
105
 
@@ -127,7 +127,7 @@ def retrieve_timesteps(
127
127
  sigmas: Optional[List[float]] = None,
128
128
  **kwargs,
129
129
  ):
130
- """
130
+ r"""
131
131
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
132
132
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
133
133
 
@@ -546,7 +546,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
546
546
  )
547
547
  elif encoder_hid_dim_type is not None:
548
548
  raise ValueError(
549
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
549
+ f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
550
550
  )
551
551
  else:
552
552
  self.encoder_hid_proj = None
@@ -23,6 +23,11 @@ except OptionalDependencyNotAvailable:
23
23
  _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
24
  else:
25
25
  _import_structure["pipeline_flux"] = ["FluxPipeline"]
26
+ _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
27
+ _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
28
+ _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
29
+ _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
30
+ _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
26
31
  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
27
32
  try:
28
33
  if not (is_transformers_available() and is_torch_available()):
@@ -31,6 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
31
36
  from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
32
37
  else:
33
38
  from .pipeline_flux import FluxPipeline
39
+ from .pipeline_flux_controlnet import FluxControlNetPipeline
40
+ from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
41
+ from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
42
+ from .pipeline_flux_img2img import FluxImg2ImgPipeline
43
+ from .pipeline_flux_inpaint import FluxInpaintPipeline
34
44
  else:
35
45
  import sys
36
46
 
@@ -20,7 +20,7 @@ import torch
20
20
  from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
21
 
22
22
  from ...image_processor import VaeImageProcessor
23
- from ...loaders import FluxLoraLoaderMixin
23
+ from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
24
24
  from ...models.autoencoders import AutoencoderKL
25
25
  from ...models.transformers import FluxTransformer2DModel
26
26
  from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -86,7 +86,7 @@ def retrieve_timesteps(
86
86
  sigmas: Optional[List[float]] = None,
87
87
  **kwargs,
88
88
  ):
89
- """
89
+ r"""
90
90
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
91
91
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
92
92
 
@@ -137,7 +137,12 @@ def retrieve_timesteps(
137
137
  return timesteps, num_inference_steps
138
138
 
139
139
 
140
- class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
140
+ class FluxPipeline(
141
+ DiffusionPipeline,
142
+ FluxLoraLoaderMixin,
143
+ FromSingleFileMixin,
144
+ TextualInversionLoaderMixin,
145
+ ):
141
146
  r"""
142
147
  The Flux pipeline for text-to-image generation.
143
148
 
@@ -212,6 +217,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
212
217
  prompt = [prompt] if isinstance(prompt, str) else prompt
213
218
  batch_size = len(prompt)
214
219
 
220
+ if isinstance(self, TextualInversionLoaderMixin):
221
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
222
+
215
223
  text_inputs = self.tokenizer_2(
216
224
  prompt,
217
225
  padding="max_length",
@@ -255,6 +263,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
255
263
  prompt = [prompt] if isinstance(prompt, str) else prompt
256
264
  batch_size = len(prompt)
257
265
 
266
+ if isinstance(self, TextualInversionLoaderMixin):
267
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
268
+
258
269
  text_inputs = self.tokenizer(
259
270
  prompt,
260
271
  padding="max_length",
@@ -331,10 +342,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
331
342
  scale_lora_layers(self.text_encoder_2, lora_scale)
332
343
 
333
344
  prompt = [prompt] if isinstance(prompt, str) else prompt
334
- if prompt is not None:
335
- batch_size = len(prompt)
336
- else:
337
- batch_size = prompt_embeds.shape[0]
338
345
 
339
346
  if prompt_embeds is None:
340
347
  prompt_2 = prompt_2 or prompt
@@ -364,8 +371,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
364
371
  unscale_lora_layers(self.text_encoder_2, lora_scale)
365
372
 
366
373
  dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
367
- text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
368
- text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
374
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
369
375
 
370
376
  return prompt_embeds, pooled_prompt_embeds, text_ids
371
377
 
@@ -425,9 +431,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
425
431
 
426
432
  latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
427
433
 
428
- latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
429
434
  latent_image_ids = latent_image_ids.reshape(
430
- batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
435
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
431
436
  )
432
437
 
433
438
  return latent_image_ids.to(device=device, dtype=dtype)
@@ -454,6 +459,35 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
454
459
 
455
460
  return latents
456
461
 
462
+ def enable_vae_slicing(self):
463
+ r"""
464
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
465
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
466
+ """
467
+ self.vae.enable_slicing()
468
+
469
+ def disable_vae_slicing(self):
470
+ r"""
471
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
472
+ computing decoding in one step.
473
+ """
474
+ self.vae.disable_slicing()
475
+
476
+ def enable_vae_tiling(self):
477
+ r"""
478
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
479
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
480
+ processing larger images.
481
+ """
482
+ self.vae.enable_tiling()
483
+
484
+ def disable_vae_tiling(self):
485
+ r"""
486
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
487
+ computing decoding in one step.
488
+ """
489
+ self.vae.disable_tiling()
490
+
457
491
  def prepare_latents(
458
492
  self,
459
493
  batch_size,
@@ -513,7 +547,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
513
547
  width: Optional[int] = None,
514
548
  num_inference_steps: int = 28,
515
549
  timesteps: List[int] = None,
516
- guidance_scale: float = 7.0,
550
+ guidance_scale: float = 3.5,
517
551
  num_images_per_prompt: Optional[int] = 1,
518
552
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
519
553
  latents: Optional[torch.FloatTensor] = None,
@@ -677,6 +711,13 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
677
711
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
678
712
  self._num_timesteps = len(timesteps)
679
713
 
714
+ # handle guidance
715
+ if self.transformer.config.guidance_embeds:
716
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
717
+ guidance = guidance.expand(latents.shape[0])
718
+ else:
719
+ guidance = None
720
+
680
721
  # 6. Denoising loop
681
722
  with self.progress_bar(total=num_inference_steps) as progress_bar:
682
723
  for i, t in enumerate(timesteps):
@@ -686,16 +727,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
686
727
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
687
728
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
688
729
 
689
- # handle guidance
690
- if self.transformer.config.guidance_embeds:
691
- guidance = torch.tensor([guidance_scale], device=device)
692
- guidance = guidance.expand(latents.shape[0])
693
- else:
694
- guidance = None
695
-
696
730
  noise_pred = self.transformer(
697
731
  hidden_states=latents,
698
- # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
699
732
  timestep=timestep / 1000,
700
733
  guidance=guidance,
701
734
  pooled_projections=pooled_prompt_embeds,