diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1182 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from transformers import (
8
+ CLIPTextModel,
9
+ CLIPTokenizer,
10
+ T5EncoderModel,
11
+ T5TokenizerFast,
12
+ )
13
+
14
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
15
+ from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
16
+ from ...models.autoencoders import AutoencoderKL
17
+ from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
18
+ from ...models.transformers import FluxTransformer2DModel
19
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
20
+ from ...utils import (
21
+ USE_PEFT_BACKEND,
22
+ is_torch_xla_available,
23
+ logging,
24
+ replace_example_docstring,
25
+ scale_lora_layers,
26
+ unscale_lora_layers,
27
+ )
28
+ from ...utils.torch_utils import randn_tensor
29
+ from ..pipeline_utils import DiffusionPipeline
30
+ from .pipeline_output import FluxPipelineOutput
31
+
32
+
33
# Optional PyTorch/XLA support: `xm` is only bound when torch_xla is reported as available.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

# Module-level logger for this pipeline.
logger = logging.get_logger(__name__)
41
+
42
# Usage example for the pipeline's `__call__` docstring (presumably inserted via the imported
# `replace_example_docstring` decorator; that call site is outside this view).
# NOTE(review): the example pairs the "InstantX/FLUX.1-dev-controlnet-canny" ControlNet with the
# "black-forest-labs/FLUX.1-schnell" base model. Confirm this pairing is intended.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import FluxControlNetInpaintPipeline
>>> from diffusers.models import FluxControlNetModel
>>> from diffusers.utils import load_image

>>> controlnet = FluxControlNetModel.from_pretrained(
... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16
... )
>>> pipe = FluxControlNetInpaintPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.to("cuda")

>>> control_image = load_image(
... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
... )
>>> init_image = load_image(
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
... )
>>> mask_image = load_image(
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
... )

>>> prompt = "A girl holding a sign that says InstantX"
>>> image = pipe(
... prompt,
... image=init_image,
... mask_image=mask_image,
... control_image=control_image,
... control_guidance_start=0.2,
... control_guidance_end=0.8,
... controlnet_conditioning_scale=0.7,
... strength=0.7,
... num_inference_steps=28,
... guidance_scale=3.5,
... ).images[0]
>>> image.save("flux_controlnet_inpaint.png")
```
"""
84
+
85
+
86
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
87
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly map an image token count to the timestep shift ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``.
    ``image_seq_len`` is evaluated on that line; values outside the range are extrapolated.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
101
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of an encoder output object.

    If ``encoder_output`` has a ``latent_dist``, then:
    - ``sample_mode == "sample"`` draws from it with ``generator``;
    - ``sample_mode == "argmax"`` returns its mode.

    Otherwise (or for any other ``sample_mode``) a ``latents`` attribute is returned when present.

    Raises:
        AttributeError: if no usable latents can be found.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
112
+
113
+
114
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
115
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure ``scheduler`` via ``set_timesteps`` and return the resulting schedule.

    Exactly one of three modes is used:
    * custom ``timesteps``;
    * custom ``sigmas``;
    * the scheduler's own spacing for ``num_inference_steps``.

    Extra ``kwargs`` are forwarded to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose ``set_timesteps`` is called and whose ``timesteps`` are read back.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps; used only when neither ``timesteps`` nor ``sigmas`` is given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to ``set_timesteps`` as the target device.
        timesteps (`List[int]`, *optional*):
            Explicit timestep schedule. Requires ``set_timesteps`` to accept a ``timesteps`` argument.
        sigmas (`List[float]`, *optional*):
            Explicit sigma schedule. Requires ``set_timesteps`` to accept a ``sigmas`` argument.

    Returns:
        `Tuple[torch.Tensor, int]`: ``scheduler.timesteps`` and the step count. For custom schedules the
        count is its length; otherwise it is ``num_inference_steps`` as passed.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given, or if the scheduler cannot accept the
        requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _set_timesteps_accepts(argument_name):
        return argument_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        if not _set_timesteps_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if not _set_timesteps_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                " sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
172
+
173
+
174
+ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
175
+ r"""
176
+ The Flux controlnet pipeline for inpainting.
177
+
178
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
179
+
180
+ Args:
181
+ transformer ([`FluxTransformer2DModel`]):
182
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
183
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
184
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
185
+ vae ([`AutoencoderKL`]):
186
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
187
+ text_encoder ([`CLIPTextModel`]):
188
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
189
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
190
+ text_encoder_2 ([`T5EncoderModel`]):
191
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
192
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
193
+ tokenizer (`CLIPTokenizer`):
194
+ Tokenizer of class
195
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
196
+ tokenizer_2 (`T5TokenizerFast`):
197
+ Second Tokenizer of class
198
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
199
+ """
200
+
201
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
202
+ _optional_components = []
203
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
204
+
205
def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    text_encoder_2: T5EncoderModel,
    tokenizer_2: T5TokenizerFast,
    transformer: FluxTransformer2DModel,
    controlnet: Union[
        FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
    ],
):
    """Register the pipeline components and build the image and mask pre-processors.

    Args:
        scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer:
            Components registered on the pipeline via ``register_modules``.
        controlnet: A single ``FluxControlNetModel``, a ``FluxMultiControlNetModel``, or a list/tuple of
            ControlNets. A list/tuple is wrapped in ``FluxMultiControlNetModel`` before registration.
    """
    super().__init__()
    if isinstance(controlnet, (list, tuple)):
        controlnet = FluxMultiControlNetModel(controlnet)

    self.register_modules(
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        controlnet=controlnet,
    )

    # Evaluate the VAE guard once and use it for every VAE-derived value. Previously only
    # `vae_scale_factor` was guarded, while the mask processor dereferenced `self.vae.config`
    # unconditionally. A missing VAE therefore crashed here despite the fallback.
    has_vae = hasattr(self, "vae") and self.vae is not None
    # NOTE(review): 2 ** len(block_out_channels) (fallback 16) presumably folds the VAE's spatial
    # downsampling together with Flux's 2x2 latent packing. Confirm against the packing code.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if has_vae else 16
    # NOTE(review): the fallback of 16 is assumed to match the Flux VAE's `latent_channels`; confirm.
    vae_latent_channels = self.vae.config.latent_channels if has_vae else 16

    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # Mask processor: no normalization, binarized, converted to grayscale.
    self.mask_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor,
        vae_latent_channels=vae_latent_channels,
        do_normalize=False,
        do_binarize=True,
        do_convert_grayscale=True,
    )
    # CLIP tokenizer length limit; 77 when no tokenizer is registered.
    self.tokenizer_max_length = (
        self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
    )
    self.default_sample_size = 64
248
+
249
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
# (the truncation warning slices at `max_sequence_length`, and an explicit `dtype` is honored).
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode ``prompt`` with the T5 encoder (``text_encoder_2``) into per-token embeddings.

    Args:
        prompt: A prompt or list of prompts.
        num_images_per_prompt: How many copies of each prompt's embeddings to produce.
        max_sequence_length: Padding/truncation length passed to ``tokenizer_2``.
        device: Target device; defaults to ``self._execution_device``.
        dtype: Target dtype; defaults to ``self.text_encoder_2.dtype``.

    Returns:
        Tensor of shape ``(batch_size * num_images_per_prompt, seq_len, hidden)``, where ``seq_len`` is
        the encoder output's second dimension.
    """
    device = device or self._execution_device
    # The embeddings come from text_encoder_2, so its dtype is the default. The old code overwrote
    # any caller-supplied `dtype` and also dereferenced `self.text_encoder` needlessly.
    dtype = dtype or self.text_encoder_2.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

    text_inputs = self.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        # Slice at the T5 limit, not the CLIP limit (`self.tokenizer_max_length`). Otherwise the warning
        # would report tokens that were kept. Assumes the tokenizer appends one closing special token,
        # so truncation keeps `max_sequence_length - 1` content tokens; the trailing `-1` drops that
        # special token from the untruncated ids.
        removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # Duplicate the embeddings for each generation per prompt (repeat + view, as in the original).
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
298
+
299
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
300
def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
):
    """Encode ``prompt`` with the CLIP text encoder and return its pooled output.

    Prompts are padded/truncated to ``self.tokenizer_max_length``. A warning lists any text lost to
    truncation.

    Returns:
        Tensor with ``batch_size * num_images_per_prompt`` rows, in ``self.text_encoder.dtype`` on
        ``device`` (default ``self._execution_device``).
    """
    device = device or self._execution_device
    prompts = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompts)

    if isinstance(self, TextualInversionLoaderMixin):
        prompts = self.maybe_convert_prompt(prompts, self.tokenizer)

    max_len = self.tokenizer_max_length
    input_ids = self.tokenizer(
        prompts,
        padding="max_length",
        max_length=max_len,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    ).input_ids
    full_ids = self.tokenizer(prompts, padding="longest", return_tensors="pt").input_ids

    was_truncated = full_ids.shape[-1] >= input_ids.shape[-1] and not torch.equal(input_ids, full_ids)
    if was_truncated:
        removed_text = self.tokenizer.batch_decode(full_ids[:, max_len - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {max_len} tokens: {removed_text}"
        )

    encoder_output = self.text_encoder(input_ids.to(device), output_hidden_states=False)
    # Only the pooled CLIP output is used.
    pooled = encoder_output.pooler_output.to(dtype=self.text_encoder.dtype, device=device)

    # Tile once per requested image using repeat + view.
    pooled = pooled.repeat(1, num_images_per_prompt)
    return pooled.view(batch_size * num_images_per_prompt, -1)
343
+
344
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
345
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]],
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
    lora_scale: Optional[float] = None,
):
    r"""
    Turn the prompt(s) into the text conditioning used by the Flux transformer.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) for the CLIP encoder (`tokenizer` / `text_encoder`).
        prompt_2 (`str` or `List[str]`, *optional*):
            Prompt(s) for the T5 encoder (`tokenizer_2` / `text_encoder_2`). Falls back to `prompt` when falsy.
        device (`torch.device`, *optional*):
            Target device; defaults to `self._execution_device`.
        num_images_per_prompt (`int`):
            Number of images to generate per prompt; embeddings are repeated accordingly.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Precomputed T5 embeddings. When given, no text encoding is performed and `pooled_prompt_embeds`
            is returned as passed.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Precomputed pooled CLIP embeddings; overwritten when `prompt_embeds` is None.
        max_sequence_length (`int`, defaults to 512):
            T5 padding/truncation length.
        lora_scale (`float`, *optional*):
            Scale temporarily applied to text-encoder LoRA layers (PEFT backend only).

    Returns:
        `(prompt_embeds, pooled_prompt_embeds, text_ids)`, where `text_ids` is a zero tensor of shape
        `(prompt_embeds.shape[1], 3)`.
    """
    device = device or self._execution_device

    # Record the LoRA scale and apply it to the text encoders for the duration of this call.
    if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
        self._lora_scale = lora_scale
        if USE_PEFT_BACKEND:
            for encoder in (self.text_encoder, self.text_encoder_2):
                if encoder is not None:
                    scale_lora_layers(encoder, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        if isinstance(prompt_2, str):
            prompt_2 = [prompt_2]

        # CLIP contributes only its pooled output; the per-token sequence comes from T5.
        pooled_prompt_embeds = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
        )
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt_2,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )

    # Scale the LoRA layers back to their original values.
    if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
        for encoder in (self.text_encoder, self.text_encoder_2):
            if encoder is not None:
                unscale_lora_layers(encoder, lora_scale)

    if self.text_encoder is not None:
        ids_dtype = self.text_encoder.dtype
    else:
        ids_dtype = self.transformer.dtype
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=ids_dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
423
+
424
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
425
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """Encode ``image`` with the VAE and normalize the latents.

    With a list of generators, each sample ``image[i : i + 1]`` is encoded separately with
    ``generator[i]``. Otherwise the whole batch is encoded at once. The result is
    ``(latents - vae.config.shift_factor) * vae.config.scaling_factor``.
    """
    vae = self.vae
    if isinstance(generator, list):
        per_sample_latents = [
            retrieve_latents(vae.encode(image[idx : idx + 1]), generator=generator[idx])
            for idx in range(image.shape[0])
        ]
        latents = torch.cat(per_sample_latents, dim=0)
    else:
        latents = retrieve_latents(vae.encode(image), generator=generator)

    config = vae.config
    return (latents - config.shift_factor) * config.scaling_factor
438
+
439
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
440
def get_timesteps(self, num_inference_steps, strength, device):
    """Keep only the tail of the scheduler's timesteps selected by ``strength``.

    Roughly ``num_inference_steps * strength`` steps (capped at ``num_inference_steps``) are kept. The
    skipped prefix, multiplied by ``scheduler.order``, becomes the slice offset. It is also passed to
    ``scheduler.set_begin_index`` when that method exists. ``device`` is unused.

    Returns:
        ``(remaining_timesteps, num_inference_steps - skipped_steps)``.
    """
    scheduler = self.scheduler
    steps_to_run = min(num_inference_steps * strength, num_inference_steps)
    skipped_steps = int(max(num_inference_steps - steps_to_run, 0))
    offset = skipped_steps * scheduler.order

    remaining = scheduler.timesteps[offset:]
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(offset)

    return remaining, num_inference_steps - skipped_steps
450
+
451
def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    mask_image,
    strength,
    height,
    width,
    output_type,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    padding_mask_crop=None,
    max_sequence_length=None,
):
    """Validate the arguments of ``__call__`` before any model work starts.

    Args:
        prompt, prompt_2: text prompt(s). Each may be given together with ``prompt_embeds`` only if
            the other is absent, and must be a ``str`` or ``list`` when given.
        image, mask_image: inpainting inputs; both must be ``PIL.Image.Image`` when
            ``padding_mask_crop`` is set.
        strength: denoising strength, must lie in ``[0, 1]``.
        height, width: requested output size in pixels; must be multiples of ``self.vae_scale_factor``.
        output_type: must be ``"pil"`` when ``padding_mask_crop`` is set.
        prompt_embeds, pooled_prompt_embeds: pre-computed embeddings; ``pooled_prompt_embeds`` is
            required whenever ``prompt_embeds`` is given.
        callback_on_step_end_tensor_inputs: every name must appear in ``self._callback_tensor_inputs``.
        padding_mask_crop: optional padding used when cropping around the mask.
        max_sequence_length: when given, at most 512.

    Raises:
        ValueError: if any of the constraints above is violated.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

    # `prepare_latents` / `prepare_mask_latents` build a latent grid of `2 * (size // vae_scale_factor)`
    # and pack it into 2x2 patches; `_unpack_latents` divides by `vae_scale_factor` again. Sizes that are
    # not multiples of `vae_scale_factor` therefore do not round-trip (NOTE(review): with the encoded image
    # this surfaces later as a shape mismatch), so reject them up front with a clear message.
    if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    # Mask cropping works on PIL images only and pastes the result back into a PIL image.
    if padding_mask_crop is not None:
        if not isinstance(image, PIL.Image.Image):
            raise ValueError(
                f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
            )
        if not isinstance(mask_image, PIL.Image.Image):
            raise ValueError(
                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                f" {type(mask_image)}."
            )
        if output_type != "pil":
            raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
520
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build positional ids for every 2x2 patch of a ``height`` x ``width`` latent grid.

    Returns a ``((height // 2) * (width // 2), 3)`` tensor whose rows are ``(0, row, col)`` in
    row-major patch order, moved to ``device``/``dtype``. ``batch_size`` is accepted for interface
    compatibility but not used.
    """
    rows, cols = height // 2, width // 2
    ids = torch.zeros(rows, cols, 3)
    # Channel 0 stays zero; channels 1 and 2 hold the patch row and column index.
    ids[..., 1] = torch.arange(rows).unsqueeze(1)
    ids[..., 2] = torch.arange(cols).unsqueeze(0)
    return ids.flatten(0, 1).to(device=device, dtype=dtype)
+
535
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Fold a ``(B, C, H, W)`` latent into a sequence of 2x2 patches.

    Returns a ``(B, (H // 2) * (W // 2), C * 4)`` tensor; each token holds the 2x2 neighbourhood of
    all channels, channel-major.
    """
    rows, cols = height // 2, width // 2
    patches = latents.view(batch_size, num_channels_latents, rows, 2, cols, 2).permute(0, 2, 4, 1, 3, 5)
    return patches.reshape(batch_size, rows * cols, num_channels_latents * 4)
+
544
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
    """Inverse of ``_pack_latents``: turn a patch sequence back into a spatial latent.

    ``latents`` is ``(B, num_patches, packed_channels)`` with ``num_patches`` equal to
    ``(height // vae_scale_factor) * (width // vae_scale_factor)``. Returns a tensor of shape
    ``(B, packed_channels // 4, 2 * (height // vae_scale_factor), 2 * (width // vae_scale_factor))``.
    """
    batch, _, packed_channels = latents.shape
    rows = height // vae_scale_factor
    cols = width // vae_scale_factor
    channels = packed_channels // 4

    grid = latents.view(batch, rows, cols, channels, 2, 2).permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch, channels, rows * 2, cols * 2)
+
559
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
560
def prepare_latents(
    self,
    image,
    timestep,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Encode the init image and build the packed starting latents for inpainting.

    Args:
        image: preprocessed init image batch; moved to ``device``/``dtype`` before VAE encoding.
        timestep: timestep used to noise the encoded image when ``latents`` is not given.
        batch_size: effective batch size (prompts x images per prompt).
        num_channels_latents: latent channels before packing.
        height, width: requested output size in pixels.
        dtype, device: target dtype/device for the generated tensors.
        generator: RNG(s) forwarded to the VAE encode and to the noise sampler.
        latents: optional user latents; when given they are used both as the starting latents and as
            the noise for re-noising the known region.

    Returns:
        ``(latents, noise, image_latents, latent_image_ids)``. The first three are packed with
        ``self._pack_latents``.

    Raises:
        ValueError: on a generator-list / batch-size mismatch, or if the image batch cannot be
            duplicated evenly to ``batch_size``.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    latent_h = 2 * (int(height) // self.vae_scale_factor)
    latent_w = 2 * (int(width) // self.vae_scale_factor)
    noise_shape = (batch_size, num_channels_latents, latent_h, latent_w)
    latent_image_ids = self._prepare_latent_image_ids(batch_size, latent_h, latent_w, device, dtype)

    # The VAE encode may draw from `generator`; it must run before the noise is sampled so that
    # seeded runs stay reproducible.
    image_latents = self._encode_vae_image(image=image.to(device=device, dtype=dtype), generator=generator)

    n_images = image_latents.shape[0]
    if batch_size > n_images:
        if batch_size % n_images != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # expand init_latents for batch_size
        image_latents = torch.cat([image_latents] * (batch_size // n_images), dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if latents is not None:
        noise = latents = latents.to(device)
    else:
        noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
        latents = self.scheduler.scale_noise(image_latents, timestep, noise)

    def pack(tensor):
        # Shared packing geometry for every latent-space tensor of this batch.
        return self._pack_latents(tensor, batch_size, num_channels_latents, latent_h, latent_w)

    packed_noise = pack(noise)
    packed_image_latents = pack(image_latents)
    packed_latents = pack(latents)
    return packed_latents, packed_noise, packed_image_latents, latent_image_ids
+
611
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
612
def prepare_mask_latents(
    self,
    mask,
    masked_image,
    batch_size,
    num_channels_latents,
    num_images_per_prompt,
    height,
    width,
    dtype,
    device,
    generator,
):
    """Resize the mask to latent size and encode the masked image, both packed for the transformer.

    Args:
        mask: preprocessed mask batch; resized to the latent grid.
        masked_image: pixel-space masked image, or latents already if it has 16 channels.
        batch_size: number of prompts; multiplied by ``num_images_per_prompt`` internally.
        num_channels_latents: latent channels before packing; the mask is repeated to this many channels.
        num_images_per_prompt: images generated per prompt.
        height, width: requested output size in pixels.
        dtype, device: target dtype/device.
        generator: RNG(s) forwarded to the VAE encode.

    Returns:
        ``(mask, masked_image_latents)``, both packed with ``self._pack_latents``.

    Raises:
        ValueError: if the mask or image batch cannot be duplicated evenly to the total batch size.
    """
    latent_h = 2 * (int(height) // self.vae_scale_factor)
    latent_w = 2 * (int(width) // self.vae_scale_factor)

    # Resize before changing dtype/device so that cpu offload with half precision does not break
    # the interpolation.
    mask = torch.nn.functional.interpolate(mask, size=(latent_h, latent_w)).to(device=device, dtype=dtype)

    total = batch_size * num_images_per_prompt

    masked_image = masked_image.to(device=device, dtype=dtype)
    # A 16-channel input is treated as already-encoded latents and skips the VAE.
    encoded = (
        masked_image
        if masked_image.shape[1] == 16
        else retrieve_latents(self.vae.encode(masked_image), generator=generator)
    )
    encoded = (encoded - self.vae.config.shift_factor) * self.vae.config.scaling_factor

    # Duplicate mask and latents to the full batch (mps-friendly `repeat`).
    if mask.shape[0] < total:
        if batch_size * num_images_per_prompt % mask.shape[0] != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {total}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(total // mask.shape[0], 1, 1, 1)
    if encoded.shape[0] < total:
        if total % encoded.shape[0] != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {total}, but {encoded.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        encoded = encoded.repeat(total // encoded.shape[0], 1, 1, 1)

    # Align device/dtype so concatenation with the latent model input cannot fail.
    encoded = encoded.to(device=device, dtype=dtype)

    packed_masked_latents = self._pack_latents(encoded, total, num_channels_latents, latent_h, latent_w)
    packed_mask = self._pack_latents(
        mask.repeat(1, num_channels_latents, 1, 1), total, num_channels_latents, latent_h, latent_w
    )
    return packed_mask, packed_masked_latents
+
682
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
683
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """Preprocess a ControlNet conditioning image and expand it to the generation batch.

    Non-tensor inputs go through ``self.image_processor.preprocess`` at ``height`` x ``width``.
    A single image is repeated ``batch_size`` times. Otherwise each image is repeated
    ``num_images_per_prompt`` times, on the assumption that the image batch already matches the
    prompt batch. With classifier-free guidance (and not ``guess_mode``) the result is doubled
    along dim 0.
    """
    if not isinstance(image, torch.Tensor):
        image = self.image_processor.preprocess(image, height=height, width=width)

    repeats = batch_size if image.shape[0] == 1 else num_images_per_prompt
    image = image.repeat_interleave(repeats, dim=0).to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image, image])
    return image
+
717
@property
def guidance_scale(self):
    """Guidance scale stored by the most recent ``__call__`` (read-only)."""
    return self._guidance_scale
+
721
@property
def joint_attention_kwargs(self):
    """Extra attention kwargs stored by the most recent ``__call__`` (may be ``None``)."""
    return self._joint_attention_kwargs
+
725
@property
def num_timesteps(self):
    """Number of timesteps in the denoising loop of the most recent ``__call__``."""
    return self._num_timesteps
+
729
@property
def interrupt(self):
    """Flag checked every denoising step; when truthy, the remaining steps are skipped."""
    return self._interrupt
+
733
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    image: PipelineImageInput = None,
    mask_image: PipelineImageInput = None,
    masked_image_latents: PipelineImageInput = None,
    control_image: PipelineImageInput = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    strength: float = 0.6,
    padding_mask_crop: Optional[int] = None,
    timesteps: List[int] = None,
    num_inference_steps: int = 28,
    guidance_scale: float = 7.0,
    control_guidance_start: Union[float, List[float]] = 0.0,
    control_guidance_end: Union[float, List[float]] = 1.0,
    control_mode: Optional[Union[int, List[int]]] = None,
    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    """
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
        image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
            The image(s) to inpaint.
        mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
            The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels
            will be preserved.
        masked_image_latents (`torch.FloatTensor`, *optional*):
            Pre-generated masked image latents.
        control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
            The ControlNet input condition. Image to control the generation. With a multi-ControlNet, pass one
            condition per ControlNet.
        height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        strength (`float`, *optional*, defaults to 0.6):
            Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1.
        padding_mask_crop (`int`, *optional*):
            The size of the padding to use when cropping the mask.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process.
        guidance_scale (`float`, *optional*, defaults to 7.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
            The percentage of total steps at which the ControlNet starts applying.
        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
            The percentage of total steps at which the ControlNet stops applying.
        control_mode (`int` or `List[int]`, *optional*):
            The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
            The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
            to the residual in the original transformer. With a multi-ControlNet, a single float is applied to
            every ControlNet.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
            make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between `PIL.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
        joint_attention_kwargs (`dict`, *optional*):
            Additional keyword arguments to be passed to the joint attention mechanism.
        callback_on_step_end (`Callable`, *optional*):
            A function that is called at the end of each denoising step during inference.
        callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function.
        max_sequence_length (`int`, *optional*, defaults to 512):
            The maximum length of the sequence to be generated.

    Examples:

    Returns:
        [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
        is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
        images.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # `height`/`width` are overwritten below with the prepared control image size, so keep the requested
    # output size for the latents, masks, timestep shift and decoding.
    global_height = height
    global_width = width

    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
        mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
        control_guidance_start, control_guidance_end = (
            mult * [control_guidance_start],
            mult * [control_guidance_end],
        )

    # With several ControlNets the denoising loop zips the scales with the per-net keep factors, so a
    # scalar scale (including the default 1.0) must be broadcast to one value per net.
    if isinstance(self.controlnet, FluxMultiControlNetModel) and not isinstance(
        controlnet_conditioning_scale, list
    ):
        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        image,
        mask_image,
        strength,
        height,
        width,
        output_type=output_type,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        padding_mask_crop=padding_mask_crop,
        max_sequence_length=max_sequence_length,
    )

    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    dtype = self.transformer.dtype

    # 3. Encode input prompt
    lora_scale = (
        self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
    )
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Preprocess mask and image
    if padding_mask_crop is not None:
        crops_coords = self.mask_processor.get_crop_region(
            mask_image, global_width, global_height, pad=padding_mask_crop
        )
        resize_mode = "fill"
    else:
        crops_coords = None
        resize_mode = "default"

    original_image = image
    init_image = self.image_processor.preprocess(
        image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode
    )
    init_image = init_image.to(dtype=torch.float32)

    # 5. Prepare control image
    num_channels_latents = self.transformer.config.in_channels // 4
    if isinstance(self.controlnet, FluxControlNetModel):
        control_image = self.prepare_image(
            image=control_image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=self.vae.dtype,
        )
        height, width = control_image.shape[-2:]

        # vae encode
        control_image = self.vae.encode(control_image).latent_dist.sample()
        control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

        # pack
        height_control_image, width_control_image = control_image.shape[2:]
        control_image = self._pack_latents(
            control_image,
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height_control_image,
            width_control_image,
        )

        # set control mode
        if control_mode is not None:
            control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
            control_mode = control_mode.reshape([-1, 1])

    elif isinstance(self.controlnet, FluxMultiControlNetModel):
        control_images = []

        for control_image_ in control_image:
            control_image_ = self.prepare_image(
                image=control_image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=self.vae.dtype,
            )
            height, width = control_image_.shape[-2:]

            # vae encode
            control_image_ = self.vae.encode(control_image_).latent_dist.sample()
            control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor

            # pack
            height_control_image, width_control_image = control_image_.shape[2:]
            control_image_ = self._pack_latents(
                control_image_,
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height_control_image,
                width_control_image,
            )

            control_images.append(control_image_)

        control_image = control_images

        # set control mode; `None` entries are encoded as -1
        control_mode_ = []
        if isinstance(control_mode, list):
            for cmode in control_mode:
                if cmode is None:
                    control_mode_.append(-1)
                else:
                    control_mode_.append(cmode)
        control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
        control_mode = control_mode.reshape([-1, 1])

    # 6. Prepare timesteps

    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)

    if num_inference_steps < 1:
        raise ValueError(
            f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
            f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
        )
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 7. Prepare latent variables

    latents, noise, image_latents, latent_image_ids = self.prepare_latents(
        init_image,
        latent_timestep,
        batch_size * num_images_per_prompt,
        num_channels_latents,
        global_height,
        global_width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 8. Prepare mask latents
    mask_condition = self.mask_processor.preprocess(
        mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
    )
    if masked_image_latents is None:
        masked_image = init_image * (mask_condition < 0.5)
    else:
        masked_image = masked_image_latents

    mask, masked_image_latents = self.prepare_mask_latents(
        mask_condition,
        masked_image,
        batch_size,
        num_channels_latents,
        num_images_per_prompt,
        global_height,
        global_width,
        prompt_embeds.dtype,
        device,
        generator,
    )

    # Per step, 1.0 if the step lies inside [control_guidance_start, control_guidance_end], else 0.0.
    controlnet_keep = []
    for i in range(len(timesteps)):
        keeps = [
            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
            for s, e in zip(control_guidance_start, control_guidance_end)
        ]
        controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)

    # 9. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            # predict the noise residual
            if self.controlnet.config.guidance_embeds:
                guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
                guidance = guidance.expand(latents.shape[0])
            else:
                guidance = None

            if isinstance(controlnet_keep[i], list):
                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
            else:
                controlnet_cond_scale = controlnet_conditioning_scale
                if isinstance(controlnet_cond_scale, list):
                    controlnet_cond_scale = controlnet_cond_scale[0]
                cond_scale = controlnet_cond_scale * controlnet_keep[i]

            controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
                hidden_states=latents,
                controlnet_cond=control_image,
                controlnet_mode=control_mode,
                conditioning_scale=cond_scale,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )

            if self.transformer.config.guidance_embeds:
                guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
                guidance = guidance.expand(latents.shape[0])
            else:
                guidance = None

            noise_pred = self.transformer(
                hidden_states=latents,
                timestep=timestep / 1000,
                guidance=guidance,
                pooled_projections=pooled_prompt_embeds,
                encoder_hidden_states=prompt_embeds,
                controlnet_block_samples=controlnet_block_samples,
                controlnet_single_block_samples=controlnet_single_block_samples,
                txt_ids=text_ids,
                img_ids=latent_image_ids,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # For inpainting, re-noise the original image latents to the next timestep and keep them
            # wherever the mask is 0 (preserved region).
            init_latents_proper = image_latents
            init_mask = mask

            if i < len(timesteps) - 1:
                noise_timestep = timesteps[i + 1]
                init_latents_proper = self.scheduler.scale_noise(
                    init_latents_proper, torch.tensor([noise_timestep]), noise
                )

            latents = (1 - init_mask) * init_latents_proper + init_mask * latents

            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    latents = latents.to(latents_dtype)

            # call the callback, if provided
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    # Post-processing
    if output_type == "latent":
        image = latents
    else:
        latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor)
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return FluxPipelineOutput(images=image)