diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
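The headline additions in 0.35.0 are the new QwenImage, SkyReels-V2, and Flux Kontext pipelines, the `diffusers.guiders` and `diffusers.modular_pipelines` packages, and the attention dispatcher in `diffusers/models/attention_dispatch.py`. As a rough orientation, the new pipelines load through the usual `DiffusionPipeline` entry point; the sketch below is a minimal smoke test, where the `Qwen/Qwen-Image` checkpoint id and the generation settings are assumptions rather than anything stated in this diff.

import torch
from diffusers import DiffusionPipeline

# Minimal sketch of trying the QwenImage pipeline added in 0.35.0.
# Assumptions: the "Qwen/Qwen-Image" checkpoint id and bfloat16 weights;
# loading and calling follow the standard DiffusionPipeline interface.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    num_inference_steps=50,
).images[0]
image.save("qwenimage_smoke_test.png")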
diffusers/modular_pipelines/flux/decoders.py (new file)
@@ -0,0 +1,109 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL
+from ...utils import logging
+from ...video_processor import VaeImageProcessor
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def _unpack_latents(latents, height, width, vae_scale_factor):
+    batch_size, num_patches, channels = latents.shape
+
+    # VAE applies 8x compression on images but we must also account for packing which requires
+    # latent height and width to be divisible by 2.
+    height = 2 * (int(height) // (vae_scale_factor * 2))
+    width = 2 * (int(width) // (vae_scale_factor * 2))
+
+    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+    latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+    latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+    return latents
+
+
+class FluxDecodeStep(ModularPipelineBlocks):
+    model_name = "flux"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKL),
+            ComponentSpec(
+                "image_processor",
+                VaeImageProcessor,
+                config=FrozenDict({"vae_scale_factor": 16}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def description(self) -> str:
+        return "Step that decodes the denoised latents into images"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("output_type", default="pil"),
+            InputParam("height", default=1024),
+            InputParam("width", default=1024),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The denoised latents from the denoising step",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[str]:
+        return [
+            OutputParam(
+                "images",
+                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
+                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
+            )
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        vae = components.vae
+
+        if not block_state.output_type == "latent":
+            latents = block_state.latents
+            latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
+            latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+            block_state.images = vae.decode(latents, return_dict=False)[0]
+            block_state.images = components.image_processor.postprocess(
+                block_state.images, output_type=block_state.output_type
+            )
+        else:
+            block_state.images = block_state.latents
+
+        self.set_block_state(state, block_state)
+
+        return components, state
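The packing arithmetic in `_unpack_latents` above is easiest to see with concrete shapes. The standalone sketch below mirrors the same reshape, assuming Flux's usual VAE scale factor of 8 and a 1024x1024 target (illustrative numbers, not values taken from this diff): 4096 packed patches with 64 channels become a (1, 16, 128, 128) latent that `vae.decode` can consume.

import torch

# Standalone sketch of the reshape performed by _unpack_latents above.
# Assumptions: vae_scale_factor = 8 (the Flux VAE) and a 1024x1024 image.
batch_size, height, width, vae_scale_factor = 1, 1024, 1024, 8
latent_h = 2 * (height // (vae_scale_factor * 2))  # 128
latent_w = 2 * (width // (vae_scale_factor * 2))   # 128
channels = 64                                      # 16 latent channels packed into 2x2 patches
num_patches = (latent_h // 2) * (latent_w // 2)    # 4096

packed = torch.randn(batch_size, num_patches, channels)
unpacked = packed.view(batch_size, latent_h // 2, latent_w // 2, channels // 4, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
unpacked = unpacked.reshape(batch_size, channels // (2 * 2), latent_h, latent_w)
print(unpacked.shape)  # torch.Size([1, 16, 128, 128]), ready for vae.decode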
diffusers/modular_pipelines/flux/denoise.py (new file)
@@ -0,0 +1,227 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple
+
+import torch
+
+from ...models import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import (
+    BlockState,
+    LoopSequentialPipelineBlocks,
+    ModularPipelineBlocks,
+    PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class FluxLoopDenoiser(ModularPipelineBlocks):
+    model_name = "flux"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step within the denoising loop that denoise the latents. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `FluxDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            InputParam("joint_attention_kwargs"),
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+            ),
+            InputParam(
+                "guidance",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Guidance scale as a tensor",
+            ),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Prompt embeddings",
+            ),
+            InputParam(
+                "pooled_prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Pooled prompt embeddings",
+            ),
+            InputParam(
+                "text_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from text sequence needed for RoPE",
+            ),
+            InputParam(
+                "latent_image_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from image sequence needed for RoPE",
+            ),
+            # TODO: guidance
+        ]
+
+    @torch.no_grad()
+    def __call__(
+        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+    ) -> PipelineState:
+        noise_pred = components.transformer(
+            hidden_states=block_state.latents,
+            timestep=t.flatten() / 1000,
+            guidance=block_state.guidance,
+            encoder_hidden_states=block_state.prompt_embeds,
+            pooled_projections=block_state.pooled_prompt_embeds,
+            joint_attention_kwargs=block_state.joint_attention_kwargs,
+            txt_ids=block_state.text_ids,
+            img_ids=block_state.latent_image_ids,
+            return_dict=False,
+        )[0]
+        block_state.noise_pred = noise_pred
+
+        return components, block_state
+
+
+class FluxLoopAfterDenoiser(ModularPipelineBlocks):
+    model_name = "flux"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+    @property
+    def description(self) -> str:
+        return (
+            "step within the denoising loop that update the latents. "
+            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+            "object (e.g. `FluxDenoiseLoopWrapper`)"
+        )
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return []
+
+    @property
+    def intermediate_inputs(self) -> List[str]:
+        return [InputParam("generator")]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+        # Perform scheduler step using the predicted output
+        latents_dtype = block_state.latents.dtype
+        block_state.latents = components.scheduler.step(
+            block_state.noise_pred,
+            t,
+            block_state.latents,
+            return_dict=False,
+        )[0]
+
+        if block_state.latents.dtype != latents_dtype:
+            block_state.latents = block_state.latents.to(latents_dtype)
+
+        return components, block_state
+
+
+class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+    model_name = "flux"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Pipeline block that iteratively denoise the latents over `timesteps`. "
+            "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+        )
+
+    @property
+    def loop_expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+            ComponentSpec("transformer", FluxTransformer2DModel),
+        ]
+
+    @property
+    def loop_inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.num_warmup_steps = max(
+            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+        )
+        # We set the index here to remove DtoH sync, helpful especially during compilation.
+        # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+        components.scheduler.set_begin_index(0)
+        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+            for i, t in enumerate(block_state.timesteps):
+                components, block_state = self.loop_step(components, block_state, i=i, t=t)
+                if i == len(block_state.timesteps) - 1 or (
+                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+                ):
+                    progress_bar.update()

+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
+class FluxDenoiseStep(FluxDenoiseLoopWrapper):
+    block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoise the latents. \n"
+            "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
+            "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+            " - `FluxLoopDenoiser`\n"
+            " - `FluxLoopAfterDenoiser`\n"
+            "This block supports both text2image and img2img tasks."