diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)

         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents *
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # Scale the initial noise by the standard deviation required by the scheduler
-        latents = latents *
+        latents = latents * self.scheduler.init_noise_sigma

         # 5. Add noise to image
         noise_level = np.array([noise_level]).astype(np.int64)
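All three ONNX pipelines above now multiply the freshly sampled latents by `self.scheduler.init_noise_sigma`. A minimal sketch of that pattern outside the pipelines, assuming stock schedulers; the shapes and scheduler choices are illustrative, not taken from this diff:

```python
import numpy as np
from diffusers import DDIMScheduler, EulerDiscreteScheduler

# Gaussian latents are drawn first, then scaled by the scheduler's
# init_noise_sigma before the denoising loop (1.0 for DDIM, > 1 for Euler).
for scheduler in (DDIMScheduler(), EulerDiscreteScheduler()):
    scheduler.set_timesteps(50)
    latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
    latents = latents * float(scheduler.init_noise_sigma)
    print(type(scheduler).__name__, float(scheduler.init_noise_sigma))
```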
@@ -1034,7 +1034,8 @@ class StableDiffusionPipeline(

                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                if hasattr(self.scheduler, "scale_model_input"):
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
                 noise_pred = self.unet(
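The `StableDiffusionPipeline` change guards `scale_model_input` behind `hasattr`, so schedulers that do not expose that method still work. A small sketch of the same defensive pattern, with a hypothetical helper name:

```python
import torch
from diffusers import DDIMScheduler


def maybe_scale_model_input(scheduler, latent_model_input: torch.Tensor, t) -> torch.Tensor:
    # Hypothetical helper: call scale_model_input only when the scheduler provides it.
    if hasattr(scheduler, "scale_model_input"):
        return scheduler.scale_model_input(latent_model_input, t)
    return latent_model_input


scheduler = DDIMScheduler()
scheduler.set_timesteps(10)
sample = torch.randn(2, 4, 64, 64)
sample = maybe_scale_model_input(scheduler, sample, scheduler.timesteps[0])
print(sample.shape)
```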
@@ -25,6 +25,7 @@ from transformers import (
     T5TokenizerFast,
 )

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle

     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
     _optional_components = ["image_encoder", "feature_extractor"]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]

     def __init__(
         self,
@@ -923,6 +924,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
@@ -1109,10 +1113,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle

                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
+                    pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
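With `pooled_prompt_embeds` now listed in `_callback_tensor_inputs` and read back from `callback_outputs`, a step-end callback can inspect or modify it during sampling. A hedged sketch of a function-style callback; the checkpoint id is a placeholder and the zeroing logic is purely illustrative:

```python
import torch
from diffusers import StableDiffusion3Pipeline


def edit_pooled_embeds(pipe, step_index, timestep, callback_kwargs):
    # Illustrative only: zero the pooled prompt embedding from step 15 onward
    # (the modified tensor is fed back into the remaining steps).
    if step_index == 15:
        callback_kwargs["pooled_prompt_embeds"] = torch.zeros_like(
            callback_kwargs["pooled_prompt_embeds"]
        )
    return callback_kwargs


pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium",  # placeholder checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    callback_on_step_end=edit_pooled_embeds,
    callback_on_step_end_tensor_inputs=["latents", "prompt_embeds", "pooled_prompt_embeds"],
).images[0]
```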
@@ -112,18 +112,31 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+            two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+            stages. If not provided, only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer", "transformer_2"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: Optional[WanTransformer3DModel] = None,
+        transformer_2: Optional[WanTransformer3DModel] = None,
+        boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,  # Wan2.2 ti2v
     ):
         super().__init__()

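Together with the docstring above, the new constructor arguments describe the Wan 2.2 two-expert setup: `transformer` denoises the high-noise steps and `transformer_2` the low-noise steps past `boundary_ratio * num_train_timesteps`. A hedged usage sketch; the repository id is an assumption (the diff itself names no checkpoint) and it is assumed the checkpoint config sets `boundary_ratio`:

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

# Assumed repo id for a Wan 2.2 two-expert text-to-video checkpoint.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="a cat surfing a small wave at sunset, cinematic lighting",
    num_frames=81,
    num_inference_steps=40,
    guidance_scale=4.0,    # applied while the high-noise expert (transformer) is active
    guidance_scale_2=3.0,  # applied once the low-noise expert (transformer_2) takes over
).frames[0]
export_to_video(video, "wan22_t2v.mp4", fps=16)
```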
@@ -133,10 +146,12 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             tokenizer=tokenizer,
             transformer=transformer,
             scheduler=scheduler,
+            transformer_2=transformer_2,
         )
-
-        self.
-        self.
+        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(expand_timesteps=expand_timesteps)
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

     def _get_t5_prompt_embeds(
@@ -270,6 +285,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
     def prepare_latents(
         self,
         batch_size: int,
@@ -369,6 +388,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -461,6 +485,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
             num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -497,7 +526,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             device=device,
         )

-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
@@ -507,7 +536,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents =
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -520,36 +553,61 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents,
         )

+        mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])

-
-
-
-
-
-
-
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2

-
-
+                latent_model_input = latents.to(transformer_dtype)
+                if self.config.expand_timesteps:
+                    # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+                    temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    timestep = t.expand(latents.shape[0])
+
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -149,20 +149,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+            `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

-    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
-        image_processor: CLIPImageProcessor,
-        transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        image_processor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModel = None,
+        transformer: WanTransformer3DModel = None,
+        transformer_2: WanTransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()

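The image-to-video pipeline gains the same two-expert arguments. A hedged usage sketch; the repository id and the input image path are placeholders, not taken from this diff:

```python
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Assumed repo id for a Wan 2.2 two-expert image-to-video checkpoint.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("first_frame.png")  # placeholder input image
video = pipe(
    image=image,
    prompt="the subject slowly turns toward the camera",
    num_frames=81,
    guidance_scale=3.5,
    guidance_scale_2=3.5,  # low-noise expert; defaults to guidance_scale when omitted
).frames[0]
export_to_video(video, "wan22_i2v.mp4", fps=16)
```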
@@ -174,10 +187,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             transformer=transformer,
             scheduler=scheduler,
             image_processor=image_processor,
+            transformer_2=transformer_2,
         )
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)

-        self.vae_scale_factor_temporal =
-        self.vae_scale_factor_spatial =
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor

@@ -325,6 +340,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         negative_prompt_embeds=None,
         image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if image is not None and image_embeds is not None:
             raise ValueError(
@@ -368,6 +384,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+        if self.config.boundary_ratio is not None and image_embeds is not None:
+            raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
+
     def prepare_latents(
         self,
         image: PipelineImageInput,
@@ -398,8 +420,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         else:
             latents = latents.to(device=device, dtype=dtype)

-        image = image.unsqueeze(2)
-
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -432,6 +458,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std

+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)

         if last_image is None:
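In the `expand_timesteps` branch the returned `first_frame_mask` is zero on the first latent frame and one elsewhere, so the blend used later in the loop, `(1 - mask) * condition + mask * latents`, pins frame 0 to the encoded conditioning image while leaving the remaining frames as noisy latents. A small shape-level sketch (sizes are illustrative):

```python
import torch

num_latent_frames, latent_height, latent_width = 21, 60, 104  # illustrative sizes
latents = torch.randn(1, 16, num_latent_frames, latent_height, latent_width)
condition = torch.randn_like(latents)

first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0

blended = (1 - first_frame_mask) * condition + first_frame_mask * latents
assert torch.equal(blended[:, :, 0], condition[:, :, 0])  # frame 0 comes from the condition
assert torch.equal(blended[:, :, 1:], latents[:, :, 1:])  # later frames stay as latents
```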
@@ -483,6 +516,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -527,6 +561,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -589,6 +627,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             negative_prompt_embeds,
             image_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -598,7 +637,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
             num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -626,18 +669,20 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         )

         # Encode image embedding
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

-
-
-
-
-
-
-
+        # only wan 2.1 i2v transformer accepts image_embeds
+        if self.transformer is not None and self.transformer.config.image_dim is not None:
+            if image_embeds is None:
+                if last_image is None:
+                    image_embeds = self.encode_image(image, device)
+                else:
+                    image_embeds = self.encode_image([image, last_image], device)
+            image_embeds = image_embeds.repeat(batch_size, 1, 1)
+            image_embeds = image_embeds.to(transformer_dtype)

         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -650,7 +695,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
             )
-
+
+        latents_outputs = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -663,39 +709,70 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents,
             last_image,
         )
+        if self.config.expand_timesteps:
+            # wan 2.2 5b i2v use firt_frame_mask to mask timesteps
+            latents, condition, first_frame_mask = latents_outputs
+        else:
+            latents, condition = latents_outputs

         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]

-                if
-
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])
+
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=
+                        encoder_hidden_states=prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
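In the same branch the scalar timestep is expanded to one value per latent token, and tokens belonging to the masked first frame get timestep 0. A shape-level sketch of that construction, following the indexing used above (sizes are illustrative; the `::2` stride matches a 2x2 spatial patch size):

```python
import torch

batch_size, num_latent_frames, latent_height, latent_width = 2, 21, 60, 104
t = torch.tensor(875.0)

first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
first_frame_mask[:, :, 0] = 0

# seq_len = num_latent_frames * (latent_height // 2) * (latent_width // 2)
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
timestep = temp_ts.unsqueeze(0).expand(batch_size, -1)

print(timestep.shape)  # torch.Size([2, 32760])
print(timestep[0, 0].item(), timestep[0, -1].item())  # 0.0 for first-frame tokens, 875.0 elsewhere
```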
@@ -719,6 +796,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):

         self._current_timestep = None

+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
@@ -525,8 +525,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
             latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
         else:
-            mask = mask.to(dtype=vae_dtype)
-            mask = torch.where(mask > 0.5, 1.0, 0.0)
+            mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
             inactive = video * (1 - mask)
             reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")