diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +16 -2
- diffusers/configuration_utils.py +1 -0
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +4 -5
- diffusers/image_processor.py +186 -14
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +157 -0
- diffusers/loaders/lora.py +1415 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +631 -0
- diffusers/loaders/textual_inversion.py +459 -0
- diffusers/loaders/unet.py +735 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +12 -1
- diffusers/models/attention.py +165 -14
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +286 -1
- diffusers/models/autoencoder_asym_kl.py +14 -9
- diffusers/models/autoencoder_kl.py +3 -18
- diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/autoencoder_tiny.py +20 -24
- diffusers/models/consistency_decoder_vae.py +37 -30
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +2 -1
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +27 -19
- diffusers/models/normalization.py +2 -2
- diffusers/models/resnet.py +390 -59
- diffusers/models/transformer_2d.py +20 -3
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +9 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandi3.py +589 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/vae.py +63 -13
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +3 -1
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +65 -12
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
- diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +6 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
- diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +4 -2
- diffusers/pipelines/pipeline_utils.py +33 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
- diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
- diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/__init__.py +64 -21
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
- diffusers/schedulers/__init__.py +2 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +1 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
- diffusers/schedulers/scheduling_deis_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
- diffusers/schedulers/scheduling_euler_discrete.py +40 -13
- diffusers/schedulers/scheduling_heun_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +1 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
- diffusers/utils/__init__.py +1 -0
- diffusers/utils/constants.py +8 -7
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
- diffusers/utils/dynamic_modules_utils.py +4 -4
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/logging.py +10 -10
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/torch_utils.py +2 -2
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
- diffusers/loaders.py +0 -3336
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,14 @@ import inspect
|
|
18
18
|
import itertools
|
19
19
|
import os
|
20
20
|
import re
|
21
|
+
from collections import OrderedDict
|
21
22
|
from functools import partial
|
22
23
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
23
24
|
|
24
25
|
import safetensors
|
25
26
|
import torch
|
26
27
|
from huggingface_hub import create_repo
|
27
|
-
from torch import Tensor, device, nn
|
28
|
+
from torch import Tensor, nn
|
28
29
|
|
29
30
|
from .. import __version__
|
30
31
|
from ..utils import (
|
@@ -61,7 +62,7 @@ if is_accelerate_available():
|
|
61
62
|
from accelerate.utils.versions import is_torch_version
|
62
63
|
|
63
64
|
|
64
|
-
def get_parameter_device(parameter: torch.nn.Module):
|
65
|
+
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
65
66
|
try:
|
66
67
|
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
67
68
|
return next(parameters_and_buffers).device
|
@@ -77,7 +78,7 @@ def get_parameter_device(parameter: torch.nn.Module):
|
|
77
78
|
return first_tuple[1].device
|
78
79
|
|
79
80
|
|
80
|
-
def get_parameter_dtype(parameter: torch.nn.Module):
|
81
|
+
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
81
82
|
try:
|
82
83
|
params = tuple(parameter.parameters())
|
83
84
|
if len(params) > 0:
|
@@ -130,7 +131,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
|
130
131
|
)
|
131
132
|
|
132
133
|
|
133
|
-
def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
|
134
|
+
def load_model_dict_into_meta(
|
135
|
+
model,
|
136
|
+
state_dict: OrderedDict,
|
137
|
+
device: Optional[Union[str, torch.device]] = None,
|
138
|
+
dtype: Optional[Union[str, torch.dtype]] = None,
|
139
|
+
model_name_or_path: Optional[str] = None,
|
140
|
+
) -> List[str]:
|
134
141
|
device = device or torch.device("cpu")
|
135
142
|
dtype = dtype or torch.float32
|
136
143
|
|
@@ -156,7 +163,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_
|
|
156
163
|
return unexpected_keys
|
157
164
|
|
158
165
|
|
159
|
-
def _load_state_dict_into_model(model_to_load, state_dict):
|
166
|
+
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
|
160
167
|
# Convert old format to new format if needed from a PyTorch state_dict
|
161
168
|
# copy state_dict so _load_from_state_dict can modify it
|
162
169
|
state_dict = state_dict.copy()
|
@@ -164,7 +171,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
|
|
164
171
|
|
165
172
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
166
173
|
# so we need to apply the function recursively.
|
167
|
-
def load(module: torch.nn.Module, prefix=""):
|
174
|
+
def load(module: torch.nn.Module, prefix: str = ""):
|
168
175
|
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
169
176
|
module._load_from_state_dict(*args)
|
170
177
|
|
@@ -186,6 +193,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
186
193
|
|
187
194
|
- **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
|
188
195
|
"""
|
196
|
+
|
189
197
|
config_name = CONFIG_NAME
|
190
198
|
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
191
199
|
_supports_gradient_checkpointing = False
|
@@ -220,7 +228,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
220
228
|
"""
|
221
229
|
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
222
230
|
|
223
|
-
def enable_gradient_checkpointing(self):
|
231
|
+
def enable_gradient_checkpointing(self) -> None:
|
224
232
|
"""
|
225
233
|
Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
226
234
|
*checkpoint activations* in other frameworks).
|
@@ -229,7 +237,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
229
237
|
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
230
238
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
231
239
|
|
232
|
-
def disable_gradient_checkpointing(self):
|
240
|
+
def disable_gradient_checkpointing(self) -> None:
|
233
241
|
"""
|
234
242
|
Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
|
235
243
|
*checkpoint activations* in other frameworks).
|
@@ -254,7 +262,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
254
262
|
if isinstance(module, torch.nn.Module):
|
255
263
|
fn_recursive_set_mem_eff(module)
|
256
264
|
|
257
|
-
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
265
|
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
|
258
266
|
r"""
|
259
267
|
Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
260
268
|
|
@@ -290,7 +298,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
290
298
|
"""
|
291
299
|
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
292
300
|
|
293
|
-
def disable_xformers_memory_efficient_attention(self):
|
301
|
+
def disable_xformers_memory_efficient_attention(self) -> None:
|
294
302
|
r"""
|
295
303
|
Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
|
296
304
|
"""
|
@@ -447,7 +455,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
447
455
|
self,
|
448
456
|
save_directory: Union[str, os.PathLike],
|
449
457
|
is_main_process: bool = True,
|
450
|
-
save_function: Callable = None,
|
458
|
+
save_function: Optional[Callable] = None,
|
451
459
|
safe_serialization: bool = True,
|
452
460
|
variant: Optional[str] = None,
|
453
461
|
push_to_hub: bool = False,
|
@@ -910,10 +918,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
910
918
|
def _load_pretrained_model(
|
911
919
|
cls,
|
912
920
|
model,
|
913
|
-
state_dict,
|
921
|
+
state_dict: OrderedDict,
|
914
922
|
resolved_archive_file,
|
915
|
-
pretrained_model_name_or_path,
|
916
|
-
ignore_mismatched_sizes=False,
|
923
|
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
924
|
+
ignore_mismatched_sizes: bool = False,
|
917
925
|
):
|
918
926
|
# Retrieve missing & unexpected_keys
|
919
927
|
model_state_dict = model.state_dict()
|
@@ -1011,7 +1019,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1011
1019
|
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
1012
1020
|
|
1013
1021
|
@property
|
1014
|
-
def device(self) -> device:
|
1022
|
+
def device(self) -> torch.device:
|
1015
1023
|
"""
|
1016
1024
|
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
1017
1025
|
device).
|
@@ -1063,7 +1071,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1063
1071
|
else:
|
1064
1072
|
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
1065
1073
|
|
1066
|
-
def _convert_deprecated_attention_blocks(self, state_dict):
|
1074
|
+
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
|
1067
1075
|
deprecated_attention_block_paths = []
|
1068
1076
|
|
1069
1077
|
def recursive_find_attn_block(name, module):
|
@@ -1107,7 +1115,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1107
1115
|
if f"{path}.proj_attn.bias" in state_dict:
|
1108
1116
|
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
1109
1117
|
|
1110
|
-
def _temp_convert_self_to_deprecated_attention_blocks(self):
|
1118
|
+
def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
1111
1119
|
deprecated_attention_block_modules = []
|
1112
1120
|
|
1113
1121
|
def recursive_find_attn_block(module):
|
@@ -1134,10 +1142,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|
1134
1142
|
del module.to_v
|
1135
1143
|
del module.to_out
|
1136
1144
|
|
1137
|
-
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
|
1145
|
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
|
1138
1146
|
deprecated_attention_block_modules = []
|
1139
1147
|
|
1140
|
-
def recursive_find_attn_block(module):
|
1148
|
+
def recursive_find_attn_block(module) -> None:
|
1141
1149
|
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
1142
1150
|
deprecated_attention_block_modules.append(module)
|
1143
1151
|
|
@@ -101,8 +101,8 @@ class AdaLayerNormSingle(nn.Module):
|
|
101
101
|
def forward(
|
102
102
|
self,
|
103
103
|
timestep: torch.Tensor,
|
104
|
-
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
105
|
-
batch_size: int = None,
|
104
|
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
105
|
+
batch_size: Optional[int] = None,
|
106
106
|
hidden_dtype: Optional[torch.dtype] = None,
|
107
107
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
108
108
|
# No modulation happening here.
|