diffusers 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl
- diffusers/__init__.py +16 -2
- diffusers/configuration_utils.py +1 -0
- diffusers/dependency_versions_check.py +1 -14
- diffusers/dependency_versions_table.py +5 -4
- diffusers/image_processor.py +186 -14
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +157 -0
- diffusers/loaders/lora.py +1415 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +631 -0
- diffusers/loaders/textual_inversion.py +459 -0
- diffusers/loaders/unet.py +735 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +12 -1
- diffusers/models/attention.py +165 -14
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +286 -1
- diffusers/models/autoencoder_asym_kl.py +14 -9
- diffusers/models/autoencoder_kl.py +3 -18
- diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/autoencoder_tiny.py +20 -24
- diffusers/models/consistency_decoder_vae.py +37 -30
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +2 -1
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +27 -19
- diffusers/models/normalization.py +2 -2
- diffusers/models/resnet.py +390 -59
- diffusers/models/transformer_2d.py +20 -3
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +9 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandi3.py +589 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/vae.py +63 -13
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +3 -1
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +65 -12
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
- diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +6 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
- diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +4 -2
- diffusers/pipelines/pipeline_utils.py +33 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
- diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
- diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/__init__.py +64 -21
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
- diffusers/schedulers/__init__.py +2 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +1 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
- diffusers/schedulers/scheduling_deis_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
- diffusers/schedulers/scheduling_euler_discrete.py +40 -13
- diffusers/schedulers/scheduling_heun_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +1 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
- diffusers/utils/__init__.py +1 -0
- diffusers/utils/constants.py +11 -6
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
- diffusers/utils/dynamic_modules_utils.py +4 -4
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/logging.py +10 -10
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/torch_utils.py +2 -2
- diffusers/utils/versions.py +117 -0
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/METADATA +83 -64
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/RECORD +176 -157
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/WHEEL +1 -1
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +1 -0
- diffusers/loaders.py +0 -3336
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
diffusers/utils/dummy_torch_and_transformers_objects.py
CHANGED
@@ -242,6 +242,36 @@ class ImageTextPipelineOutput(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class Kandinsky3Pipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class KandinskyCombinedPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
@@ -1142,6 +1172,21 @@ class StableUnCLIPPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class StableVideoDiffusionPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class TextToVideoSDPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
@@ -1172,6 +1217,21 @@ class TextToVideoZeroPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class TextToVideoZeroSDXLPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class UnCLIPImageVariationPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
 
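All three additions follow the standard dummy-object pattern: the new 0.24 pipeline names stay importable even without torch/transformers installed, and any real use raises an informative ImportError instead of a NameError. A minimal sketch of the mechanism (the classes below are simplified stand-ins, not the actual diffusers implementations):

```python
# Minimal sketch of the dummy-object pattern (simplified, not diffusers' code).
class DummyObject(type):
    """Metaclass that makes a placeholder class fail loudly on any use."""

    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


def requires_backends(obj, backends):
    # In diffusers this checks whether each backend is importable; here we
    # assume they are absent to show the failure mode.
    name = obj.__name__ if isinstance(obj, type) else obj.__class__.__name__
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class Kandinsky3Pipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])


try:
    Kandinsky3Pipeline()  # raises ImportError naming the missing backends
except ImportError as e:
    print(e)
```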
diffusers/utils/dynamic_modules_utils.py
CHANGED
@@ -87,9 +87,9 @@ def get_relative_imports(module_file):
         content = f.read()
 
     # Imports of the form `import .xxx`
-    relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+    relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
     # Imports of the form `from .xxx import yyy`
-    relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+    relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
     # Unique-ify
     return list(set(relative_imports))
 
@@ -131,9 +131,9 @@ def check_imports(filename):
         content = f.read()
 
     # Imports of the form `import xxx`
-    imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+    imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
     # Imports of the form `from xxx import yyy`
-    imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+    imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
     # Only keep the top-level module
     imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
 
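The only change in these four lines is the `r` prefix. `\s` and `\S` are not valid string escapes, so the unprefixed patterns relied on Python passing unknown escapes through unchanged and emitted a DeprecationWarning (a SyntaxWarning on recent interpreters). Behaviour is otherwise identical:

```python
import re

content = "from .models import UNet2DModel\n"

# With the r prefix the regex engine receives \s and \S exactly as written.
relative_imports = re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
print(relative_imports)  # ['models']

# Without the prefix the pattern still works today because Python leaves
# unknown escapes like "\s" untouched, but the literal itself warns at
# compile time (run with `python -W error` to see it fail outright).
```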
diffusers/utils/export_utils.py
CHANGED
@@ -3,7 +3,7 @@ import random
 import struct
 import tempfile
 from contextlib import contextmanager
-from typing import List
+from typing import List, Union
 
 import numpy as np
 import PIL.Image
@@ -115,7 +115,9 @@ def export_to_obj(mesh, output_obj_path: str = None):
         f.writelines("\n".join(combined_data))
 
 
-def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
+def export_to_video(
+    video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
+) -> str:
     if is_opencv_available():
         import cv2
     else:
@@ -123,9 +125,12 @@ def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
     if output_video_path is None:
         output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
 
+    if isinstance(video_frames[0], PIL.Image.Image):
+        video_frames = [np.array(frame) for frame in video_frames]
+
     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
     h, w, c = video_frames[0].shape
-    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
+    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
     for i in range(len(video_frames)):
         img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
         video_writer.write(img)
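The reworked `export_to_video` accepts the PIL frames that video pipelines such as the new `StableVideoDiffusionPipeline` return, and exposes the previously hard-coded frame rate. A small usage sketch (requires opencv-python; the synthetic frames are only for illustration):

```python
import numpy as np
import PIL.Image

from diffusers.utils import export_to_video

# Frames may now be PIL images (what video pipelines return) or numpy arrays.
frames = [
    PIL.Image.fromarray(np.full((576, 1024, 3), fill_value=10 * i, dtype=np.uint8))
    for i in range(25)
]
path = export_to_video(frames, "fade.mp4", fps=7)  # fps is new in 0.24
print(path)  # fade.mp4
```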
diffusers/utils/logging.py
CHANGED
@@ -28,7 +28,7 @@ from logging import (
     WARN,  # NOQA
     WARNING,  # NOQA
 )
-from typing import Optional
+from typing import Dict, Optional
 
 from tqdm import auto as tqdm_lib
 
@@ -49,7 +49,7 @@ _default_log_level = logging.WARNING
 _tqdm_active = True
 
 
-def _get_default_logging_level():
+def _get_default_logging_level() -> int:
     """
     If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
     not - fall back to `_default_log_level`
@@ -104,7 +104,7 @@ def _reset_library_root_logger() -> None:
     _default_handler = None
 
 
-def get_log_levels_dict():
+def get_log_levels_dict() -> Dict[str, int]:
     return log_levels
 
 
@@ -161,22 +161,22 @@ def set_verbosity(verbosity: int) -> None:
     _get_library_root_logger().setLevel(verbosity)
 
 
-def set_verbosity_info():
+def set_verbosity_info() -> None:
     """Set the verbosity to the `INFO` level."""
     return set_verbosity(INFO)
 
 
-def set_verbosity_warning():
+def set_verbosity_warning() -> None:
     """Set the verbosity to the `WARNING` level."""
     return set_verbosity(WARNING)
 
 
-def set_verbosity_debug():
+def set_verbosity_debug() -> None:
     """Set the verbosity to the `DEBUG` level."""
     return set_verbosity(DEBUG)
 
 
-def set_verbosity_error():
+def set_verbosity_error() -> None:
     """Set the verbosity to the `ERROR` level."""
     return set_verbosity(ERROR)
 
@@ -263,7 +263,7 @@ def reset_format() -> None:
         handler.setFormatter(None)
 
 
-def warning_advice(self, *args, **kwargs):
+def warning_advice(self, *args, **kwargs) -> None:
     """
     This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
     warning will not be printed
@@ -327,13 +327,13 @@ def is_progress_bar_enabled() -> bool:
     return bool(_tqdm_active)
 
 
-def enable_progress_bar():
+def enable_progress_bar() -> None:
     """Enable tqdm progress bar."""
     global _tqdm_active
     _tqdm_active = True
 
 
-def disable_progress_bar():
+def disable_progress_bar() -> None:
     """Disable tqdm progress bar."""
     global _tqdm_active
     _tqdm_active = False
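These are annotation-only changes to the existing logging helpers; the public API behaves as before. For reference, typical usage of the functions being annotated:

```python
from diffusers.utils import logging

logging.set_verbosity_info()            # returns None, per the new annotations
logger = logging.get_logger("diffusers")
logger.info("verbose pipeline loading enabled")

print(logging.get_log_levels_dict())    # e.g. {'debug': 10, 'info': 20, ...}
logging.disable_progress_bar()          # silence tqdm bars during inference
```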
diffusers/utils/outputs.py
CHANGED
@@ -24,7 +24,7 @@ import numpy as np
 from .import_utils import is_torch_available
 
 
-def is_tensor(x):
+def is_tensor(x) -> bool:
     """
     Tests if `x` is a `torch.Tensor` or `np.ndarray`.
     """
@@ -66,7 +66,7 @@ class BaseOutput(OrderedDict):
                 lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
             )
 
-    def __post_init__(self):
+    def __post_init__(self) -> None:
         class_fields = fields(self)
 
         # Safety and consistency checks
@@ -97,14 +97,14 @@ class BaseOutput(OrderedDict):
     def update(self, *args, **kwargs):
         raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
 
-    def __getitem__(self, k):
+    def __getitem__(self, k: Any) -> Any:
         if isinstance(k, str):
             inner_dict = dict(self.items())
             return inner_dict[k]
         else:
             return self.to_tuple()[k]
 
-    def __setattr__(self, name, value):
+    def __setattr__(self, name: Any, value: Any) -> None:
         if name in self.keys() and value is not None:
             # Don't call self.__setitem__ to avoid recursion errors
             super().__setitem__(name, value)
@@ -123,7 +123,7 @@ class BaseOutput(OrderedDict):
         args = tuple(getattr(self, field.name) for field in fields(self))
         return callable, args, *remaining
 
-    def to_tuple(self) -> Tuple[Any]:
+    def to_tuple(self) -> Tuple[Any, ...]:
         """
         Convert self to a tuple containing all the attributes/keys that are not `None`.
         """
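Again mostly annotations, with one genuine fix: `to_tuple` returns a variable-length tuple, so `Tuple[Any, ...]` is the correct type (plain `Tuple[Any]` means exactly one element). The dual dict/attribute behaviour being annotated looks like this (the output subclass below is an illustrative stand-in, not a diffusers class):

```python
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import PIL.Image

from diffusers.utils import BaseOutput


@dataclass
class MyPipelineOutput(BaseOutput):
    images: List[PIL.Image.Image]
    extra: Optional[np.ndarray] = None


out = MyPipelineOutput(images=[PIL.Image.new("RGB", (8, 8))])
print(out.images is out["images"])  # True: attribute and key access agree
print(out[0] is out.images)         # integer indexing goes through to_tuple()
print(out.to_tuple())               # only non-None fields -> a 1-tuple here
```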
diffusers/utils/peft_utils.py
CHANGED
@@ -23,55 +23,77 @@ from packaging import version
 from .import_utils import is_peft_available, is_torch_available
 
 
-
-
-
+if is_torch_available():
+    import torch
+
 
+def recurse_remove_peft_layers(model):
     r"""
     Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
     """
-    from peft.tuners.lora import LoraLayer
-
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            ## compound module, go inside it
-            recurse_remove_peft_layers(module)
-
-        module_replaced = False
-
-        if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-            new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                module.weight.device
-            )
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-        elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
-            new_module = torch.nn.Conv2d(
-                module.in_channels,
-                module.out_channels,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                module.dilation,
-                module.groups,
-            ).to(module.weight.device)
-
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-
-        if module_replaced:
-            setattr(model, name, new_module)
-            del module
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    from peft.tuners.tuners_utils import BaseTunerLayer
 
+    has_base_layer_pattern = False
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            has_base_layer_pattern = hasattr(module, "base_layer")
+            break
+
+    if has_base_layer_pattern:
+        from peft.utils import _get_submodules
+
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            try:
+                parent, target, target_name = _get_submodules(model, key)
+            except AttributeError:
+                continue
+            if hasattr(target, "base_layer"):
+                setattr(parent, target_name, target.get_base_layer())
+    else:
+        # This is for backwards compatibility with PEFT <= 0.6.2.
+        # TODO can be removed once that PEFT version is no longer supported.
+        from peft.tuners.lora import LoraLayer
+
+        for name, module in model.named_children():
+            if len(list(module.children())) > 0:
+                ## compound module, go inside it
+                recurse_remove_peft_layers(module)
+
+            module_replaced = False
+
+            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
+                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
+                    module.weight.device
+                )
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
+                new_module = torch.nn.Conv2d(
+                    module.in_channels,
+                    module.out_channels,
+                    module.kernel_size,
+                    module.stride,
+                    module.padding,
+                    module.dilation,
+                    module.groups,
+                ).to(module.weight.device)
+
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+
+            if module_replaced:
+                setattr(model, name, new_module)
+                del module
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
     return model
 
 
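The new fast path relies on the `base_layer` pattern from PEFT >= 0.6.2: a tuner layer keeps the wrapped original module as `base_layer`, so unwrapping is just re-attaching that module to the parent instead of rebuilding Linear/Conv2d layers by hand. A toy version of the idea (the `Wrapper` class stands in for PEFT's tuner layers; it is not PEFT code):

```python
import torch


class Wrapper(torch.nn.Module):
    """Stand-in for a PEFT tuner layer: wraps the original module as base_layer."""

    def __init__(self, base_layer):
        super().__init__()
        self.base_layer = base_layer

    def get_base_layer(self):
        return self.base_layer

    def forward(self, x):
        return self.base_layer(x)  # real PEFT would add adapter deltas here


model = torch.nn.Sequential(torch.nn.Linear(4, 4))
model[0] = Wrapper(model[0])

# Unwrap the way the new code path does: put base_layer back on the parent.
for name, module in list(model.named_children()):
    if hasattr(module, "base_layer"):
        setattr(model, name, module.get_base_layer())

print(model)  # Sequential((0): Linear(in_features=4, out_features=4, bias=True))
```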
@@ -180,6 +202,28 @@ def set_adapter_layers(model, enabled=True):
             module.disable_adapters = not enabled
 
 
+def delete_adapter_layers(model, adapter_name):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            if hasattr(module, "delete_adapter"):
+                module.delete_adapter(adapter_name)
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
+                )
+
+    # For transformers integration - we need to pop the adapter from the config
+    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
+        model.peft_config.pop(adapter_name, None)
+        # In case all adapters are deleted, we need to delete the config
+        # and make sure to set the flag to False
+        if len(model.peft_config) == 0:
+            del model.peft_config
+            model._hf_peft_config_loaded = None
+
+
 def set_weights_and_activate_adapters(model, adapter_names, weights):
     from peft.tuners.tuners_utils import BaseTunerLayer
 
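`delete_adapter_layers` is the utility behind the pipeline-level LoRA management in this release. A hedged usage sketch, assuming the LoRA loader API that ships with 0.24 (the repo id and adapter name below are placeholders):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Placeholder LoRA repo and adapter name, for illustration only.
pipe.load_lora_weights("some-user/some-lora", adapter_name="toy")
pipe.delete_adapters("toy")  # strips the adapter via delete_adapter_layers
```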
diffusers/utils/torch_utils.py
CHANGED
@@ -82,14 +82,14 @@ def randn_tensor(
     return latents
 
 
-def is_compiled_module(module):
+def is_compiled_module(module) -> bool:
     """Check whether the module was compiled with torch.compile()"""
     if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
         return False
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
 
 
-def fourier_filter(x_in, threshold, scale):
+def fourier_filter(x_in: torch.Tensor, threshold: int, scale: int) -> torch.Tensor:
     """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
 
     This version of the method comes from here:
diffusers/utils/versions.py
ADDED
@@ -0,0 +1,117 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Utilities for working with package versions
+"""
+
+import importlib.metadata
+import operator
+import re
+import sys
+from typing import Optional
+
+from packaging import version
+
+
+ops = {
+    "<": operator.lt,
+    "<=": operator.le,
+    "==": operator.eq,
+    "!=": operator.ne,
+    ">=": operator.ge,
+    ">": operator.gt,
+}
+
+
+def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
+    if got_ver is None or want_ver is None:
+        raise ValueError(
+            f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider"
+            f" reinstalling {pkg}."
+        )
+    if not ops[op](version.parse(got_ver), version.parse(want_ver)):
+        raise ImportError(
+            f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
+        )
+
+
+def require_version(requirement: str, hint: Optional[str] = None) -> None:
+    """
+    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
+
+    The installed module version comes from the *site-packages* dir via *importlib.metadata*.
+
+    Args:
+        requirement (`str`): pip style definition, e.g.,  "tokenizers==0.9.4", "tqdm>=4.27", "numpy"
+        hint (`str`, *optional*): what suggestion to print in case of requirements not being met
+
+    Example:
+
+    ```python
+    require_version("pandas>1.1.2")
+    require_version("numpy>1.18.5", "this is important to have for whatever reason")
+    ```"""
+
+    hint = f"\n{hint}" if hint is not None else ""
+
+    # non-versioned check
+    if re.match(r"^[\w_\-\d]+$", requirement):
+        pkg, op, want_ver = requirement, None, None
+    else:
+        match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
+        if not match:
+            raise ValueError(
+                "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but"
+                f" got {requirement}"
+            )
+        pkg, want_full = match[0]
+        want_range = want_full.split(",")  # there could be multiple requirements
+        wanted = {}
+        for w in want_range:
+            match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
+            if not match:
+                raise ValueError(
+                    "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23,"
+                    f" but got {requirement}"
+                )
+            op, want_ver = match[0]
+            wanted[op] = want_ver
+            if op not in ops:
+                raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
+
+    # special case
+    if pkg == "python":
+        got_ver = ".".join([str(x) for x in sys.version_info[:3]])
+        for op, want_ver in wanted.items():
+            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
+        return
+
+    # check if any version is installed
+    try:
+        got_ver = importlib.metadata.version(pkg)
+    except importlib.metadata.PackageNotFoundError:
+        raise importlib.metadata.PackageNotFoundError(
+            f"The '{requirement}' distribution was not found and is required by this application. {hint}"
+        )
+
+    # check that the right version is installed if version number or a range was provided
+    if want_ver is not None:
+        for op, want_ver in wanted.items():
+            _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
+
+
+def require_version_core(requirement):
+    """require_version wrapper which emits a core-specific hint on failure"""
+    hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git main"
+    return require_version(requirement, hint)