diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +16 -2
- diffusers/configuration_utils.py +1 -0
- diffusers/dependency_versions_check.py +0 -1
- diffusers/dependency_versions_table.py +4 -5
- diffusers/image_processor.py +186 -14
- diffusers/loaders/__init__.py +82 -0
- diffusers/loaders/ip_adapter.py +157 -0
- diffusers/loaders/lora.py +1415 -0
- diffusers/loaders/lora_conversion_utils.py +284 -0
- diffusers/loaders/single_file.py +631 -0
- diffusers/loaders/textual_inversion.py +459 -0
- diffusers/loaders/unet.py +735 -0
- diffusers/loaders/utils.py +59 -0
- diffusers/models/__init__.py +12 -1
- diffusers/models/attention.py +165 -14
- diffusers/models/attention_flax.py +9 -1
- diffusers/models/attention_processor.py +286 -1
- diffusers/models/autoencoder_asym_kl.py +14 -9
- diffusers/models/autoencoder_kl.py +3 -18
- diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
- diffusers/models/autoencoder_tiny.py +20 -24
- diffusers/models/consistency_decoder_vae.py +37 -30
- diffusers/models/controlnet.py +59 -39
- diffusers/models/controlnet_flax.py +19 -18
- diffusers/models/embeddings_flax.py +2 -0
- diffusers/models/lora.py +131 -1
- diffusers/models/modeling_flax_utils.py +2 -1
- diffusers/models/modeling_outputs.py +17 -0
- diffusers/models/modeling_utils.py +27 -19
- diffusers/models/normalization.py +2 -2
- diffusers/models/resnet.py +390 -59
- diffusers/models/transformer_2d.py +20 -3
- diffusers/models/transformer_temporal.py +183 -1
- diffusers/models/unet_2d_blocks_flax.py +5 -0
- diffusers/models/unet_2d_condition.py +9 -0
- diffusers/models/unet_2d_condition_flax.py +13 -13
- diffusers/models/unet_3d_blocks.py +957 -173
- diffusers/models/unet_3d_condition.py +16 -8
- diffusers/models/unet_kandi3.py +589 -0
- diffusers/models/unet_motion_model.py +48 -33
- diffusers/models/unet_spatio_temporal_condition.py +489 -0
- diffusers/models/vae.py +63 -13
- diffusers/models/vae_flax.py +7 -0
- diffusers/models/vq_model.py +3 -1
- diffusers/optimization.py +16 -9
- diffusers/pipelines/__init__.py +65 -12
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
- diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
- diffusers/pipelines/auto_pipeline.py +6 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
- diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
- diffusers/pipelines/kandinsky3/__init__.py +49 -0
- diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
- diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
- diffusers/pipelines/pipeline_flax_utils.py +4 -2
- diffusers/pipelines/pipeline_utils.py +33 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
- diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
- diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/__init__.py +64 -21
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
- diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
- diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
- diffusers/schedulers/__init__.py +2 -4
- diffusers/schedulers/deprecated/__init__.py +50 -0
- diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
- diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
- diffusers/schedulers/scheduling_ddim.py +1 -3
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
- diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
- diffusers/schedulers/scheduling_ddpm.py +1 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
- diffusers/schedulers/scheduling_deis_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
- diffusers/schedulers/scheduling_euler_discrete.py +40 -13
- diffusers/schedulers/scheduling_heun_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
- diffusers/schedulers/scheduling_lcm.py +123 -29
- diffusers/schedulers/scheduling_lms_discrete.py +1 -3
- diffusers/schedulers/scheduling_pndm.py +1 -3
- diffusers/schedulers/scheduling_repaint.py +1 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
- diffusers/utils/__init__.py +1 -0
- diffusers/utils/constants.py +8 -7
- diffusers/utils/dummy_pt_objects.py +45 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
- diffusers/utils/dynamic_modules_utils.py +4 -4
- diffusers/utils/export_utils.py +8 -3
- diffusers/utils/logging.py +10 -10
- diffusers/utils/outputs.py +5 -5
- diffusers/utils/peft_utils.py +88 -44
- diffusers/utils/torch_utils.py +2 -2
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
- diffusers/loaders.py +0 -3336
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
diffusers/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
__version__ = "0.23.1"
|
1
|
+
__version__ = "0.24.0"
|
2
2
|
|
3
3
|
from typing import TYPE_CHECKING
|
4
4
|
|
@@ -76,9 +76,11 @@ else:
|
|
76
76
|
[
|
77
77
|
"AsymmetricAutoencoderKL",
|
78
78
|
"AutoencoderKL",
|
79
|
+
"AutoencoderKLTemporalDecoder",
|
79
80
|
"AutoencoderTiny",
|
80
81
|
"ConsistencyDecoderVAE",
|
81
82
|
"ControlNetModel",
|
83
|
+
"Kandinsky3UNet",
|
82
84
|
"ModelMixin",
|
83
85
|
"MotionAdapter",
|
84
86
|
"MultiAdapter",
|
@@ -91,9 +93,11 @@ else:
|
|
91
93
|
"UNet2DModel",
|
92
94
|
"UNet3DConditionModel",
|
93
95
|
"UNetMotionModel",
|
96
|
+
"UNetSpatioTemporalConditionModel",
|
94
97
|
"VQModel",
|
95
98
|
]
|
96
99
|
)
|
100
|
+
|
97
101
|
_import_structure["optimization"] = [
|
98
102
|
"get_constant_schedule",
|
99
103
|
"get_constant_schedule_with_warmup",
|
@@ -103,7 +107,6 @@ else:
|
|
103
107
|
"get_polynomial_decay_schedule_with_warmup",
|
104
108
|
"get_scheduler",
|
105
109
|
]
|
106
|
-
|
107
110
|
_import_structure["pipelines"].extend(
|
108
111
|
[
|
109
112
|
"AudioPipelineOutput",
|
@@ -214,6 +217,8 @@ else:
|
|
214
217
|
"IFPipeline",
|
215
218
|
"IFSuperResolutionPipeline",
|
216
219
|
"ImageTextPipelineOutput",
|
220
|
+
"Kandinsky3Img2ImgPipeline",
|
221
|
+
"Kandinsky3Pipeline",
|
217
222
|
"KandinskyCombinedPipeline",
|
218
223
|
"KandinskyImg2ImgCombinedPipeline",
|
219
224
|
"KandinskyImg2ImgPipeline",
|
@@ -274,8 +279,10 @@ else:
|
|
274
279
|
"StableDiffusionXLPipeline",
|
275
280
|
"StableUnCLIPImg2ImgPipeline",
|
276
281
|
"StableUnCLIPPipeline",
|
282
|
+
"StableVideoDiffusionPipeline",
|
277
283
|
"TextToVideoSDPipeline",
|
278
284
|
"TextToVideoZeroPipeline",
|
285
|
+
"TextToVideoZeroSDXLPipeline",
|
279
286
|
"UnCLIPImageVariationPipeline",
|
280
287
|
"UnCLIPPipeline",
|
281
288
|
"UniDiffuserModel",
|
@@ -443,9 +450,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
443
450
|
from .models import (
|
444
451
|
AsymmetricAutoencoderKL,
|
445
452
|
AutoencoderKL,
|
453
|
+
AutoencoderKLTemporalDecoder,
|
446
454
|
AutoencoderTiny,
|
447
455
|
ConsistencyDecoderVAE,
|
448
456
|
ControlNetModel,
|
457
|
+
Kandinsky3UNet,
|
449
458
|
ModelMixin,
|
450
459
|
MotionAdapter,
|
451
460
|
MultiAdapter,
|
@@ -458,6 +467,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
458
467
|
UNet2DModel,
|
459
468
|
UNet3DConditionModel,
|
460
469
|
UNetMotionModel,
|
470
|
+
UNetSpatioTemporalConditionModel,
|
461
471
|
VQModel,
|
462
472
|
)
|
463
473
|
from .optimization import (
|
@@ -560,6 +570,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
560
570
|
IFPipeline,
|
561
571
|
IFSuperResolutionPipeline,
|
562
572
|
ImageTextPipelineOutput,
|
573
|
+
Kandinsky3Img2ImgPipeline,
|
574
|
+
Kandinsky3Pipeline,
|
563
575
|
KandinskyCombinedPipeline,
|
564
576
|
KandinskyImg2ImgCombinedPipeline,
|
565
577
|
KandinskyImg2ImgPipeline,
|
@@ -620,8 +632,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
620
632
|
StableDiffusionXLPipeline,
|
621
633
|
StableUnCLIPImg2ImgPipeline,
|
622
634
|
StableUnCLIPPipeline,
|
635
|
+
StableVideoDiffusionPipeline,
|
623
636
|
TextToVideoSDPipeline,
|
624
637
|
TextToVideoZeroPipeline,
|
638
|
+
TextToVideoZeroSDXLPipeline,
|
625
639
|
UnCLIPImageVariationPipeline,
|
626
640
|
UnCLIPPipeline,
|
627
641
|
UniDiffuserModel,
|
diffusers/configuration_utils.py
CHANGED
@@ -11,7 +11,6 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
import sys
|
15
14
|
|
16
15
|
from .dependency_versions_table import deps
|
17
16
|
from .utils.versions import require_version, require_version_core
|
@@ -1,16 +1,15 @@
|
|
1
1
|
# THIS FILE HAS BEEN AUTOGENERATED. To update:
|
2
2
|
# 1. modify the `_deps` dict in setup.py
|
3
|
-
# 2. run `make deps_table_update
|
3
|
+
# 2. run `make deps_table_update`
|
4
4
|
deps = {
|
5
5
|
"Pillow": "Pillow",
|
6
6
|
"accelerate": "accelerate>=0.11.0",
|
7
7
|
"compel": "compel==0.1.8",
|
8
|
-
"black": "black~=23.1",
|
9
8
|
"datasets": "datasets",
|
10
9
|
"filelock": "filelock",
|
11
10
|
"flax": "flax>=0.4.1",
|
12
11
|
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
13
|
-
"huggingface-hub": "huggingface-hub>=0.
|
12
|
+
"huggingface-hub": "huggingface-hub>=0.19.4",
|
14
13
|
"requests-mock": "requests-mock==1.10.0",
|
15
14
|
"importlib_metadata": "importlib_metadata",
|
16
15
|
"invisible-watermark": "invisible-watermark>=0.2.0",
|
@@ -25,13 +24,13 @@ deps = {
|
|
25
24
|
"numpy": "numpy",
|
26
25
|
"omegaconf": "omegaconf",
|
27
26
|
"parameterized": "parameterized",
|
28
|
-
"peft": "peft
|
27
|
+
"peft": "peft>=0.6.0",
|
29
28
|
"protobuf": "protobuf>=3.20.3,<4",
|
30
29
|
"pytest": "pytest",
|
31
30
|
"pytest-timeout": "pytest-timeout",
|
32
31
|
"pytest-xdist": "pytest-xdist",
|
33
32
|
"python": "python>=3.8.0",
|
34
|
-
"ruff": "ruff
|
33
|
+
"ruff": "ruff>=0.1.5,<=0.2",
|
35
34
|
"safetensors": "safetensors>=0.3.1",
|
36
35
|
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
37
36
|
"scipy": "scipy",
|
diffusers/image_processor.py
CHANGED
@@ -13,7 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import warnings
|
16
|
-
from typing import List, Optional, Union
|
16
|
+
from typing import List, Optional, Tuple, Union
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
import PIL.Image
|
@@ -33,6 +33,15 @@ PipelineImageInput = Union[
|
|
33
33
|
List[torch.FloatTensor],
|
34
34
|
]
|
35
35
|
|
36
|
+
PipelineDepthInput = Union[
|
37
|
+
PIL.Image.Image,
|
38
|
+
np.ndarray,
|
39
|
+
torch.FloatTensor,
|
40
|
+
List[PIL.Image.Image],
|
41
|
+
List[np.ndarray],
|
42
|
+
List[torch.FloatTensor],
|
43
|
+
]
|
44
|
+
|
36
45
|
|
37
46
|
class VaeImageProcessor(ConfigMixin):
|
38
47
|
"""
|
@@ -126,14 +135,14 @@ class VaeImageProcessor(ConfigMixin):
|
|
126
135
|
return images
|
127
136
|
|
128
137
|
@staticmethod
|
129
|
-
def normalize(images):
|
138
|
+
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
130
139
|
"""
|
131
140
|
Normalize an image array to [-1,1].
|
132
141
|
"""
|
133
142
|
return 2.0 * images - 1.0
|
134
143
|
|
135
144
|
@staticmethod
|
136
|
-
def denormalize(images):
|
145
|
+
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
137
146
|
"""
|
138
147
|
Denormalize an image array to [0,1].
|
139
148
|
"""
|
@@ -159,10 +168,10 @@ class VaeImageProcessor(ConfigMixin):
|
|
159
168
|
|
160
169
|
def get_default_height_width(
|
161
170
|
self,
|
162
|
-
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
|
171
|
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
163
172
|
height: Optional[int] = None,
|
164
173
|
width: Optional[int] = None,
|
165
|
-
):
|
174
|
+
) -> Tuple[int, int]:
|
166
175
|
"""
|
167
176
|
This function return the height and width that are downscaled to the next integer multiple of
|
168
177
|
`vae_scale_factor`.
|
@@ -202,12 +211,24 @@ class VaeImageProcessor(ConfigMixin):
|
|
202
211
|
|
203
212
|
def resize(
|
204
213
|
self,
|
205
|
-
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
|
214
|
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
206
215
|
height: Optional[int] = None,
|
207
216
|
width: Optional[int] = None,
|
208
|
-
) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
|
217
|
+
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
209
218
|
"""
|
210
219
|
Resize image.
|
220
|
+
|
221
|
+
Args:
|
222
|
+
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
|
223
|
+
The image input, can be a PIL image, numpy array or pytorch tensor.
|
224
|
+
height (`int`, *optional*, defaults to `None`):
|
225
|
+
The height to resize to.
|
226
|
+
width (`int`, *optional*`, defaults to `None`):
|
227
|
+
The width to resize to.
|
228
|
+
|
229
|
+
Returns:
|
230
|
+
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
231
|
+
The resized image.
|
211
232
|
"""
|
212
233
|
if isinstance(image, PIL.Image.Image):
|
213
234
|
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
|
@@ -227,7 +248,15 @@ class VaeImageProcessor(ConfigMixin):
|
|
227
248
|
|
228
249
|
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
229
250
|
"""
|
230
|
-
|
251
|
+
Create a mask.
|
252
|
+
|
253
|
+
Args:
|
254
|
+
image (`PIL.Image.Image`):
|
255
|
+
The image input, should be a PIL image.
|
256
|
+
|
257
|
+
Returns:
|
258
|
+
`PIL.Image.Image`:
|
259
|
+
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
231
260
|
"""
|
232
261
|
image[image < 0.5] = 0
|
233
262
|
image[image >= 0.5] = 1
|
@@ -306,7 +335,7 @@ class VaeImageProcessor(ConfigMixin):
|
|
306
335
|
|
307
336
|
# expected range [0,1], normalize to [-1,1]
|
308
337
|
do_normalize = self.config.do_normalize
|
309
|
-
if image.min() < 0 and do_normalize:
|
338
|
+
if do_normalize and image.min() < 0:
|
310
339
|
warnings.warn(
|
311
340
|
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
312
341
|
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
|
@@ -327,7 +356,23 @@ class VaeImageProcessor(ConfigMixin):
|
|
327
356
|
image: torch.FloatTensor,
|
328
357
|
output_type: str = "pil",
|
329
358
|
do_denormalize: Optional[List[bool]] = None,
|
330
|
-
):
|
359
|
+
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
|
360
|
+
"""
|
361
|
+
Postprocess the image output from tensor to `output_type`.
|
362
|
+
|
363
|
+
Args:
|
364
|
+
image (`torch.FloatTensor`):
|
365
|
+
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
366
|
+
output_type (`str`, *optional*, defaults to `pil`):
|
367
|
+
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
368
|
+
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
369
|
+
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
370
|
+
`VaeImageProcessor` config.
|
371
|
+
|
372
|
+
Returns:
|
373
|
+
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
|
374
|
+
The postprocessed image.
|
375
|
+
"""
|
331
376
|
if not isinstance(image, torch.Tensor):
|
332
377
|
raise ValueError(
|
333
378
|
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
@@ -390,7 +435,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
|
390
435
|
super().__init__()
|
391
436
|
|
392
437
|
@staticmethod
|
393
|
-
def numpy_to_pil(images):
|
438
|
+
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
394
439
|
"""
|
395
440
|
Convert a NumPy image or a batch of images to a PIL image.
|
396
441
|
"""
|
@@ -406,7 +451,19 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
|
406
451
|
return pil_images
|
407
452
|
|
408
453
|
@staticmethod
|
409
|
-
def rgblike_to_depthmap(image):
|
454
|
+
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
455
|
+
"""
|
456
|
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
457
|
+
"""
|
458
|
+
if not isinstance(images, list):
|
459
|
+
images = [images]
|
460
|
+
|
461
|
+
images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
|
462
|
+
images = np.stack(images, axis=0)
|
463
|
+
return images
|
464
|
+
|
465
|
+
@staticmethod
|
466
|
+
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
410
467
|
"""
|
411
468
|
Args:
|
412
469
|
image: RGB-like depth image
|
@@ -416,7 +473,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
|
416
473
|
"""
|
417
474
|
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
418
475
|
|
419
|
-
def numpy_to_depth(self, images):
|
476
|
+
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
420
477
|
"""
|
421
478
|
Convert a NumPy depth image or a batch of images to a PIL image.
|
422
479
|
"""
|
@@ -441,7 +498,23 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
|
441
498
|
image: torch.FloatTensor,
|
442
499
|
output_type: str = "pil",
|
443
500
|
do_denormalize: Optional[List[bool]] = None,
|
444
|
-
):
|
501
|
+
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
|
502
|
+
"""
|
503
|
+
Postprocess the image output from tensor to `output_type`.
|
504
|
+
|
505
|
+
Args:
|
506
|
+
image (`torch.FloatTensor`):
|
507
|
+
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
508
|
+
output_type (`str`, *optional*, defaults to `pil`):
|
509
|
+
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
510
|
+
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
511
|
+
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
512
|
+
`VaeImageProcessor` config.
|
513
|
+
|
514
|
+
Returns:
|
515
|
+
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
|
516
|
+
The postprocessed image.
|
517
|
+
"""
|
445
518
|
if not isinstance(image, torch.Tensor):
|
446
519
|
raise ValueError(
|
447
520
|
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
@@ -474,3 +547,102 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
|
474
547
|
return self.numpy_to_pil(image), self.numpy_to_depth(image)
|
475
548
|
else:
|
476
549
|
raise Exception(f"This type {output_type} is not supported")
|
550
|
+
|
551
|
+
def preprocess(
|
552
|
+
self,
|
553
|
+
rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
554
|
+
depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
|
555
|
+
height: Optional[int] = None,
|
556
|
+
width: Optional[int] = None,
|
557
|
+
target_res: Optional[int] = None,
|
558
|
+
) -> torch.Tensor:
|
559
|
+
"""
|
560
|
+
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
561
|
+
"""
|
562
|
+
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
563
|
+
|
564
|
+
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
|
565
|
+
if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
|
566
|
+
raise Exception("This is not yet supported")
|
567
|
+
|
568
|
+
if isinstance(rgb, supported_formats):
|
569
|
+
rgb = [rgb]
|
570
|
+
depth = [depth]
|
571
|
+
elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
|
572
|
+
raise ValueError(
|
573
|
+
f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
|
574
|
+
)
|
575
|
+
|
576
|
+
if isinstance(rgb[0], PIL.Image.Image):
|
577
|
+
if self.config.do_convert_rgb:
|
578
|
+
raise Exception("This is not yet supported")
|
579
|
+
# rgb = [self.convert_to_rgb(i) for i in rgb]
|
580
|
+
# depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth
|
581
|
+
if self.config.do_resize or target_res:
|
582
|
+
height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
|
583
|
+
rgb = [self.resize(i, height, width) for i in rgb]
|
584
|
+
depth = [self.resize(i, height, width) for i in depth]
|
585
|
+
rgb = self.pil_to_numpy(rgb) # to np
|
586
|
+
rgb = self.numpy_to_pt(rgb) # to pt
|
587
|
+
|
588
|
+
depth = self.depth_pil_to_numpy(depth) # to np
|
589
|
+
depth = self.numpy_to_pt(depth) # to pt
|
590
|
+
|
591
|
+
elif isinstance(rgb[0], np.ndarray):
|
592
|
+
rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
|
593
|
+
rgb = self.numpy_to_pt(rgb)
|
594
|
+
height, width = self.get_default_height_width(rgb, height, width)
|
595
|
+
if self.config.do_resize:
|
596
|
+
rgb = self.resize(rgb, height, width)
|
597
|
+
|
598
|
+
depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
|
599
|
+
depth = self.numpy_to_pt(depth)
|
600
|
+
height, width = self.get_default_height_width(depth, height, width)
|
601
|
+
if self.config.do_resize:
|
602
|
+
depth = self.resize(depth, height, width)
|
603
|
+
|
604
|
+
elif isinstance(rgb[0], torch.Tensor):
|
605
|
+
raise Exception("This is not yet supported")
|
606
|
+
# rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
|
607
|
+
|
608
|
+
# if self.config.do_convert_grayscale and rgb.ndim == 3:
|
609
|
+
# rgb = rgb.unsqueeze(1)
|
610
|
+
|
611
|
+
# channel = rgb.shape[1]
|
612
|
+
|
613
|
+
# height, width = self.get_default_height_width(rgb, height, width)
|
614
|
+
# if self.config.do_resize:
|
615
|
+
# rgb = self.resize(rgb, height, width)
|
616
|
+
|
617
|
+
# depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
|
618
|
+
|
619
|
+
# if self.config.do_convert_grayscale and depth.ndim == 3:
|
620
|
+
# depth = depth.unsqueeze(1)
|
621
|
+
|
622
|
+
# channel = depth.shape[1]
|
623
|
+
# # don't need any preprocess if the image is latents
|
624
|
+
# if depth == 4:
|
625
|
+
# return rgb, depth
|
626
|
+
|
627
|
+
# height, width = self.get_default_height_width(depth, height, width)
|
628
|
+
# if self.config.do_resize:
|
629
|
+
# depth = self.resize(depth, height, width)
|
630
|
+
# expected range [0,1], normalize to [-1,1]
|
631
|
+
do_normalize = self.config.do_normalize
|
632
|
+
if rgb.min() < 0 and do_normalize:
|
633
|
+
warnings.warn(
|
634
|
+
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
|
635
|
+
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
|
636
|
+
FutureWarning,
|
637
|
+
)
|
638
|
+
do_normalize = False
|
639
|
+
|
640
|
+
if do_normalize:
|
641
|
+
rgb = self.normalize(rgb)
|
642
|
+
depth = self.normalize(depth)
|
643
|
+
|
644
|
+
if self.config.do_binarize:
|
645
|
+
rgb = self.binarize(rgb)
|
646
|
+
depth = self.binarize(depth)
|
647
|
+
|
648
|
+
return rgb, depth
|
@@ -0,0 +1,82 @@
|
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
|
3
|
+
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
|
4
|
+
from ..utils.import_utils import is_torch_available, is_transformers_available
|
5
|
+
|
6
|
+
|
7
|
+
def text_encoder_lora_state_dict(text_encoder):
|
8
|
+
deprecate(
|
9
|
+
"text_encoder_load_state_dict in `models`",
|
10
|
+
"0.27.0",
|
11
|
+
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
|
12
|
+
)
|
13
|
+
state_dict = {}
|
14
|
+
|
15
|
+
for name, module in text_encoder_attn_modules(text_encoder):
|
16
|
+
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
17
|
+
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
18
|
+
|
19
|
+
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
20
|
+
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
21
|
+
|
22
|
+
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
23
|
+
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
24
|
+
|
25
|
+
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
26
|
+
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
27
|
+
|
28
|
+
return state_dict
|
29
|
+
|
30
|
+
|
31
|
+
if is_transformers_available():
|
32
|
+
|
33
|
+
def text_encoder_attn_modules(text_encoder):
|
34
|
+
deprecate(
|
35
|
+
"text_encoder_attn_modules in `models`",
|
36
|
+
"0.27.0",
|
37
|
+
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
|
38
|
+
)
|
39
|
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
40
|
+
|
41
|
+
attn_modules = []
|
42
|
+
|
43
|
+
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
44
|
+
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
45
|
+
name = f"text_model.encoder.layers.{i}.self_attn"
|
46
|
+
mod = layer.self_attn
|
47
|
+
attn_modules.append((name, mod))
|
48
|
+
else:
|
49
|
+
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
|
50
|
+
|
51
|
+
return attn_modules
|
52
|
+
|
53
|
+
|
54
|
+
# Maps each submodule name to the public names it exposes; handed to `_LazyModule` below.
_import_structure = {}

if is_torch_available():
    _import_structure.update(
        {
            "single_file": ["FromOriginalControlnetMixin", "FromOriginalVAEMixin"],
            "unet": ["UNet2DConditionLoadersMixin"],
            "utils": ["AttnProcsLayers"],
        }
    )

    # These loaders additionally need `transformers`; they are only registered when torch is present too,
    # because the "single_file" entry extended here is created in the torch branch above.
    if is_transformers_available():
        _import_structure["single_file"].append("FromSingleFileMixin")
        _import_structure.update(
            {
                "lora": ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"],
                "textual_inversion": ["TextualInversionLoaderMixin"],
                "ip_adapter": ["IPAdapterMixin"],
            }
        )


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: perform the real imports, mirroring the entries registered in `_import_structure`.
    if is_torch_available():
        from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers

        if is_transformers_available():
            from .ip_adapter import IPAdapterMixin
            from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
            from .single_file import FromSingleFileMixin
            from .textual_inversion import TextualInversionLoaderMixin
else:
    import sys

    # Lazy path: replace this module in `sys.modules` with a `_LazyModule` built from `_import_structure`.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
|
@@ -0,0 +1,157 @@
|
|
1
|
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
import os
|
15
|
+
from typing import Dict, Union
|
16
|
+
|
17
|
+
import torch
|
18
|
+
from safetensors import safe_open
|
19
|
+
|
20
|
+
from ..utils import (
|
21
|
+
DIFFUSERS_CACHE,
|
22
|
+
HF_HUB_OFFLINE,
|
23
|
+
_get_model_file,
|
24
|
+
is_transformers_available,
|
25
|
+
logging,
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
if is_transformers_available():
|
30
|
+
from transformers import (
|
31
|
+
CLIPImageProcessor,
|
32
|
+
CLIPVisionModelWithProjection,
|
33
|
+
)
|
34
|
+
|
35
|
+
from ..models.attention_processor import (
|
36
|
+
IPAdapterAttnProcessor,
|
37
|
+
IPAdapterAttnProcessor2_0,
|
38
|
+
)
|
39
|
+
|
40
|
+
logger = logging.get_logger(__name__)
|
41
|
+
|
42
|
+
|
43
|
+
class IPAdapterMixin:
    """Mixin for handling IP Adapters."""

    def load_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        subfolder: str,
        weight_name: str,
        **kwargs,
    ):
        """
        Load IP-Adapter weights and pass them to `self.unet._load_ip_adapter_weights`.

        When the pipeline has an `image_encoder` attribute that is `None`, a `CLIPVisionModelWithProjection` is
        loaded from the `image_encoder` folder below `subfolder`. When it has a `feature_extractor` attribute that is
        `None`, a default `CLIPImageProcessor` is created.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict)
                      that contains the keys `image_proj` and `ip_adapter`.

            subfolder (`str`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`):
                The name of the weight file to load. Files ending in `.safetensors` are read with `safetensors`; any
                other file is read with `torch.load`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        Raises:
            ValueError: If the state dict lacks `image_proj` or `ip_adapter`, or if a state dict is passed directly
                while the pipeline's `image_encoder` is `None` (there is then no location to load the encoder from).
        """
        # Download-related kwargs only matter when a file has to be resolved, so they are handled in the helper.
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            state_dict = pretrained_model_name_or_path_or_dict
        else:
            state_dict = self._fetch_ip_adapter_state_dict(
                pretrained_model_name_or_path_or_dict, subfolder, weight_name, **kwargs
            )

        # Check for presence rather than comparing against an exact ordered key list, so a valid state dict with a
        # different key order (or additional entries) is not rejected.
        missing_keys = [key for key in ("image_proj", "ip_adapter") if key not in state_dict]
        if missing_keys:
            raise ValueError(
                f"Required keys (`image_proj` and `ip_adapter`) are missing from the state dict: {missing_keys}."
            )

        # load CLIP image encoder here if it has not been registered to the pipeline yet
        if hasattr(self, "image_encoder") and self.image_encoder is None:
            if isinstance(pretrained_model_name_or_path_or_dict, dict):
                raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

            logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
            # Normalize to forward slashes: `os.path.join` yields backslashes on Windows, which do not form a
            # valid subfolder for a Hub repository.
            image_encoder_subfolder = os.path.join(subfolder, "image_encoder").replace(os.sep, "/")
            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                pretrained_model_name_or_path_or_dict,
                subfolder=image_encoder_subfolder,
            ).to(self.device, dtype=self.dtype)

        # create feature extractor if it has not been registered to the pipeline yet
        if hasattr(self, "feature_extractor") and self.feature_extractor is None:
            self.feature_extractor = CLIPImageProcessor()

        # load ip-adapter into unet
        self.unet._load_ip_adapter_weights(state_dict)

    @staticmethod
    def _fetch_ip_adapter_state_dict(pretrained_model_name_or_path, subfolder, weight_name, **kwargs):
        """
        Resolve the IP-Adapter weight file through `_get_model_file` and read it into a state dict.

        `.safetensors` files are split into `{"image_proj": {...}, "ip_adapter": {...}}` by key prefix (keys with
        neither prefix are ignored); any other file is returned exactly as `torch.load` produces it.
        """
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        model_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=weight_name,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
        )

        if not weight_name.endswith(".safetensors"):
            # NOTE(security): `torch.load` unpickles the file, which can execute arbitrary code. Only load
            # non-safetensors checkpoints from sources you trust.
            return torch.load(model_file, map_location="cpu")

        state_dict = {"image_proj": {}, "ip_adapter": {}}
        with safe_open(model_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                for group in ("image_proj", "ip_adapter"):
                    prefix = f"{group}."
                    if key.startswith(prefix):
                        # Strip only the leading prefix; `str.replace` would also remove later occurrences.
                        state_dict[group][key[len(prefix) :]] = f.get_tensor(key)
                        break
        return state_dict

    def set_ip_adapter_scale(self, scale):
        """
        Set `scale` on every IP-Adapter attention processor of `self.unet`.

        Parameters:
            scale:
                Value assigned to the `scale` attribute of each `IPAdapterAttnProcessor` /
                `IPAdapterAttnProcessor2_0` found in `self.unet.attn_processors`. Other processors are left untouched.
        """
        for attn_processor in self.unet.attn_processors.values():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                attn_processor.scale = scale
|