diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0 (usage sketch below, after this list)
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
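The new diffusers/quantizers package listed above adds bitsandbytes-backed quantized model loading. A minimal sketch of 4-bit NF4 loading, assuming the 0.31.0 `BitsAndBytesConfig` export and using `black-forest-labs/FLUX.1-dev` as an illustrative checkpoint (requires the bitsandbytes package):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# 4-bit NF4 quantization config (bitsandbytes backend)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Quantize the transformer weights at load time to cut memory use roughly 4x
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```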
diffusers/pipelines/ddpm/pipeline_ddpm.py

```diff
@@ -101,10 +101,10 @@ class DDPMPipeline(DiffusionPipeline):
 
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
-            image = randn_tensor(image_shape, generator=generator)
+            image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
             image = image.to(self.device)
         else:
-            image = randn_tensor(image_shape, generator=generator, device=self.device)
+            image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
```
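With the dtype now forwarded to `randn_tensor`, half-precision DDPM runs start from noise in the UNet's dtype rather than fp32. A minimal sketch, assuming the public `google/ddpm-cat-256` checkpoint:

```python
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256", torch_dtype=torch.float16).to("cuda")

# The initial noise tensor is now created directly in the UNet's dtype (fp16 here)
image = pipe(num_inference_steps=1000).images[0]
```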
diffusers/pipelines/deepfloyd_if/pipeline_output.py

```diff
@@ -9,16 +9,17 @@ from ...utils import BaseOutput
 
 @dataclass
 class IFPipelineOutput(BaseOutput):
-    """
-    Args:
+    r"""
     Output class for Stable Diffusion pipelines.
-
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`):
             List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
-        nsfw_detected (`List[bool]`)
+        nsfw_detected (`List[bool]`):
             List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content or a watermark. `None` if safety checking could not be performed.
-        watermark_detected (`List[bool]`)
+        watermark_detected (`List[bool]`):
             List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
             checking could not be performed.
     """
```
diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py

```diff
@@ -65,9 +65,21 @@ EXAMPLE_DOC_STRING = """
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
```
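For context, the function body continues past the two lines shown in the hunk. A sketch of the remaining (unchanged) logic, matching the Section 3.4 rescale formula the new docstring cites:

```python
    # rescale the guided noise so its per-sample std matches the text-conditioned prediction
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # blend with the original guided prediction by guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
```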
```diff
@@ -87,7 +99,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
```
diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py

```diff
@@ -127,7 +127,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
```
diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py

```diff
@@ -546,7 +546,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
-                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+                f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
             )
         else:
             self.encoder_hid_proj = None
```
diffusers/pipelines/flux/__init__.py

```diff
@@ -23,6 +23,11 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_flux"] = ["FluxPipeline"]
+    _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
+    _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
+    _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
+    _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
+    _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):
```
```diff
@@ -31,6 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_flux import FluxPipeline
+        from .pipeline_flux_controlnet import FluxControlNetPipeline
+        from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
+        from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
+        from .pipeline_flux_img2img import FluxImg2ImgPipeline
+        from .pipeline_flux_inpaint import FluxInpaintPipeline
 else:
     import sys
 
```
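Once registered here, the five new Flux pipelines become importable from `diffusers`. A usage sketch for the ControlNet variant; the repo ids (`InstantX/FLUX.1-dev-Controlnet-Canny`, `black-forest-labs/FLUX.1-dev`) and the conditioning image path are illustrative:

```python
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

control_image = load_image("canny_edges.png")  # placeholder conditioning image
image = pipe(
    "a futuristic city at dusk",
    control_image=control_image,
    controlnet_conditioning_scale=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
```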
diffusers/pipelines/flux/pipeline_flux.py

```diff
@@ -20,7 +20,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
```
```diff
@@ -86,7 +86,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
```
```diff
@@ -137,7 +137,12 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
+class FluxPipeline(
+    DiffusionPipeline,
+    FluxLoraLoaderMixin,
+    FromSingleFileMixin,
+    TextualInversionLoaderMixin,
+):
     r"""
     The Flux pipeline for text-to-image generation.
 
```
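With `FromSingleFileMixin` in the bases, `FluxPipeline.from_single_file` becomes available. A minimal sketch; the checkpoint filename is a placeholder, and in practice the text encoders and tokenizers may need to be supplied separately:

```python
import torch
from diffusers import FluxPipeline

# "flux1-dev.safetensors" is a placeholder path to a single-file Flux checkpoint
pipe = FluxPipeline.from_single_file("flux1-dev.safetensors", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```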
```diff
@@ -212,6 +217,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
         text_inputs = self.tokenizer_2(
             prompt,
             padding="max_length",
```
```diff
@@ -255,6 +263,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
 
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
         text_inputs = self.tokenizer(
             prompt,
             padding="max_length",
```
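The `maybe_convert_prompt` calls above expand any loaded textual-inversion trigger tokens into their learned embedding tokens before tokenization. A hedged usage sketch; the embedding repo and trigger token are hypothetical:

```python
# Load a hypothetical textual-inversion embedding, then use its trigger token in prompts
pipe.load_textual_inversion("example/flux-embedding", token="<my-style>")
image = pipe("a portrait in <my-style>", num_inference_steps=28).images[0]
```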
```diff
@@ -331,10 +342,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             scale_lora_layers(self.text_encoder_2, lora_scale)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_2 = prompt_2 or prompt
```
```diff
@@ -364,8 +371,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
             unscale_lora_layers(self.text_encoder_2, lora_scale)
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-        text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
```
```diff
@@ -425,9 +431,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
 
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
-        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
         latent_image_ids = latent_image_ids.reshape(
-            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+            latent_image_id_height * latent_image_id_width, latent_image_id_channels
         )
 
         return latent_image_ids.to(device=device, dtype=dtype)
```
```diff
@@ -454,6 +459,35 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
 
         return latents
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
     def prepare_latents(
         self,
         batch_size,
```
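These four helpers simply forward to the underlying AutoencoderKL. A short sketch of typical use when decoding large or batched outputs, continuing the earlier pipeline examples:

```python
prompt = "a photo of a forest"

pipe.enable_vae_slicing()  # decode one batch element at a time
pipe.enable_vae_tiling()   # decode high-resolution latents tile by tile

image = pipe(prompt, height=1536, width=1536).images[0]

pipe.disable_vae_slicing()
pipe.disable_vae_tiling()
```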
```diff
@@ -513,7 +547,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         width: Optional[int] = None,
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
-        guidance_scale: float = 7.0,
+        guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
```
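Since the default changed, callers that relied on the previous default should now pin the value explicitly:

```python
# Pass guidance_scale explicitly if you depended on the old default
image = pipe("a photo of a cat", guidance_scale=7.0, num_inference_steps=28).images[0]
```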
```diff
@@ -677,6 +711,13 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
+        # handle guidance
+        if self.transformer.config.guidance_embeds:
+            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+            guidance = guidance.expand(latents.shape[0])
+        else:
+            guidance = None
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
```
```diff
@@ -686,16 +727,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                # handle guidance
-                if self.transformer.config.guidance_embeds:
-                    guidance = torch.tensor([guidance_scale], device=device)
-                    guidance = guidance.expand(latents.shape[0])
-                else:
-                    guidance = None
-
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                     timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
```