diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/__init__.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1
|
-
__version__ = "0.
|
1
|
+
__version__ = "0.18.2"
|
2
2
|
|
3
3
|
from .configuration_utils import ConfigMixin
|
4
4
|
from .utils import (
|
5
5
|
OptionalDependencyNotAvailable,
|
6
6
|
is_flax_available,
|
7
7
|
is_inflect_available,
|
8
|
+
is_invisible_watermark_available,
|
8
9
|
is_k_diffusion_available,
|
9
10
|
is_k_diffusion_version,
|
10
11
|
is_librosa_available,
|
@@ -58,6 +59,7 @@ else:
|
|
58
59
|
)
|
59
60
|
from .pipelines import (
|
60
61
|
AudioPipelineOutput,
|
62
|
+
ConsistencyModelPipeline,
|
61
63
|
DanceDiffusionPipeline,
|
62
64
|
DDIMPipeline,
|
63
65
|
DDPMPipeline,
|
@@ -72,8 +74,11 @@ else:
|
|
72
74
|
ScoreSdeVePipeline,
|
73
75
|
)
|
74
76
|
from .schedulers import (
|
77
|
+
CMStochasticIterativeScheduler,
|
75
78
|
DDIMInverseScheduler,
|
79
|
+
DDIMParallelScheduler,
|
76
80
|
DDIMScheduler,
|
81
|
+
DDPMParallelScheduler,
|
77
82
|
DDPMScheduler,
|
78
83
|
DEISMultistepScheduler,
|
79
84
|
DPMSolverMultistepInverseScheduler,
|
@@ -134,9 +139,18 @@ else:
|
|
134
139
|
KandinskyInpaintPipeline,
|
135
140
|
KandinskyPipeline,
|
136
141
|
KandinskyPriorPipeline,
|
142
|
+
KandinskyV22ControlnetImg2ImgPipeline,
|
143
|
+
KandinskyV22ControlnetPipeline,
|
144
|
+
KandinskyV22Img2ImgPipeline,
|
145
|
+
KandinskyV22InpaintPipeline,
|
146
|
+
KandinskyV22Pipeline,
|
147
|
+
KandinskyV22PriorEmb2EmbPipeline,
|
148
|
+
KandinskyV22PriorPipeline,
|
137
149
|
LDMTextToImagePipeline,
|
138
150
|
PaintByExamplePipeline,
|
139
151
|
SemanticStableDiffusionPipeline,
|
152
|
+
ShapEImg2ImgPipeline,
|
153
|
+
ShapEPipeline,
|
140
154
|
StableDiffusionAttendAndExcitePipeline,
|
141
155
|
StableDiffusionControlNetImg2ImgPipeline,
|
142
156
|
StableDiffusionControlNetInpaintPipeline,
|
@@ -149,8 +163,10 @@ else:
|
|
149
163
|
StableDiffusionInpaintPipelineLegacy,
|
150
164
|
StableDiffusionInstructPix2PixPipeline,
|
151
165
|
StableDiffusionLatentUpscalePipeline,
|
166
|
+
StableDiffusionLDM3DPipeline,
|
152
167
|
StableDiffusionModelEditingPipeline,
|
153
168
|
StableDiffusionPanoramaPipeline,
|
169
|
+
StableDiffusionParadigmsPipeline,
|
154
170
|
StableDiffusionPipeline,
|
155
171
|
StableDiffusionPipelineSafe,
|
156
172
|
StableDiffusionPix2PixZeroPipeline,
|
@@ -169,9 +185,18 @@ else:
|
|
169
185
|
VersatileDiffusionImageVariationPipeline,
|
170
186
|
VersatileDiffusionPipeline,
|
171
187
|
VersatileDiffusionTextToImagePipeline,
|
188
|
+
VideoToVideoSDPipeline,
|
172
189
|
VQDiffusionPipeline,
|
173
190
|
)
|
174
191
|
|
192
|
+
try:
|
193
|
+
if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
|
194
|
+
raise OptionalDependencyNotAvailable()
|
195
|
+
except OptionalDependencyNotAvailable:
|
196
|
+
from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
|
197
|
+
else:
|
198
|
+
from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
|
199
|
+
|
175
200
|
try:
|
176
201
|
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
177
202
|
raise OptionalDependencyNotAvailable()
|
diffusers/configuration_utils.py
CHANGED
@@ -81,10 +81,9 @@ class FrozenDict(OrderedDict):
|
|
81
81
|
|
82
82
|
class ConfigMixin:
|
83
83
|
r"""
|
84
|
-
Base class for all configuration classes.
|
85
|
-
|
86
|
-
|
87
|
-
- [`~ConfigMixin.save_config`]
|
84
|
+
Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
|
85
|
+
provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
|
86
|
+
saving classes that inherit from [`ConfigMixin`].
|
88
87
|
|
89
88
|
Class attributes:
|
90
89
|
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
@@ -92,7 +91,7 @@ class ConfigMixin:
|
|
92
91
|
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
93
92
|
overridden by subclass).
|
94
93
|
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
95
|
-
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
|
94
|
+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
96
95
|
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
97
96
|
subclass).
|
98
97
|
"""
|
@@ -139,12 +138,12 @@ class ConfigMixin:
|
|
139
138
|
|
140
139
|
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
141
140
|
"""
|
142
|
-
Save a configuration object to the directory `save_directory
|
141
|
+
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
|
143
142
|
[`~ConfigMixin.from_config`] class method.
|
144
143
|
|
145
144
|
Args:
|
146
145
|
save_directory (`str` or `os.PathLike`):
|
147
|
-
Directory where the configuration JSON file
|
146
|
+
Directory where the configuration JSON file is saved (will be created if it does not exist).
|
148
147
|
"""
|
149
148
|
if os.path.isfile(save_directory):
|
150
149
|
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
@@ -164,15 +163,14 @@ class ConfigMixin:
|
|
164
163
|
|
165
164
|
Parameters:
|
166
165
|
config (`Dict[str, Any]`):
|
167
|
-
A config dictionary from which the Python class
|
168
|
-
|
166
|
+
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
|
167
|
+
files of compatible classes.
|
169
168
|
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
170
169
|
Whether kwargs that are not consumed by the Python class should be returned or not.
|
171
|
-
|
172
170
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
173
171
|
Can be used to update the configuration object (after it is loaded) and initiate the Python class.
|
174
|
-
`**kwargs` are directly
|
175
|
-
overwrite same named arguments in `config`.
|
172
|
+
`**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
|
173
|
+
overwrite the same named arguments in `config`.
|
176
174
|
|
177
175
|
Returns:
|
178
176
|
[`ModelMixin`] or [`SchedulerMixin`]:
|
@@ -280,16 +278,16 @@ class ConfigMixin:
|
|
280
278
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
281
279
|
cached versions if they exist.
|
282
280
|
resume_download (`bool`, *optional*, defaults to `False`):
|
283
|
-
Whether or not to resume downloading the model weights and configuration files. If set to False
|
281
|
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
284
282
|
incompletely downloaded files are deleted.
|
285
283
|
proxies (`Dict[str, str]`, *optional*):
|
286
284
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
287
285
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
288
286
|
output_loading_info(`bool`, *optional*, defaults to `False`):
|
289
287
|
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
290
|
-
local_files_only(`bool`, *optional*, defaults to `False`):
|
291
|
-
Whether to only load local model weights and configuration files or not. If set to True
|
292
|
-
won
|
288
|
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
289
|
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
290
|
+
won't be downloaded from the Hub.
|
293
291
|
use_auth_token (`str` or *bool*, *optional*):
|
294
292
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
295
293
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
@@ -307,14 +305,6 @@ class ConfigMixin:
|
|
307
305
|
`dict`:
|
308
306
|
A dictionary of all the parameters stored in a JSON configuration file.
|
309
307
|
|
310
|
-
<Tip>
|
311
|
-
|
312
|
-
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
313
|
-
`huggingface-cli login`. You can also activate the special
|
314
|
-
["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a
|
315
|
-
firewalled environment.
|
316
|
-
|
317
|
-
</Tip>
|
318
308
|
"""
|
319
309
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
320
310
|
force_download = kwargs.pop("force_download", False)
|
@@ -433,6 +423,10 @@ class ConfigMixin:
|
|
433
423
|
|
434
424
|
@classmethod
|
435
425
|
def extract_init_dict(cls, config_dict, **kwargs):
|
426
|
+
# Skip keys that were not present in the original config, so default __init__ values were used
|
427
|
+
used_defaults = config_dict.get("_use_default_values", [])
|
428
|
+
config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
|
429
|
+
|
436
430
|
# 0. Copy origin config dict
|
437
431
|
original_dict = dict(config_dict.items())
|
438
432
|
|
@@ -536,10 +530,11 @@ class ConfigMixin:
|
|
536
530
|
|
537
531
|
def to_json_string(self) -> str:
|
538
532
|
"""
|
539
|
-
Serializes
|
533
|
+
Serializes the configuration instance to a JSON string.
|
540
534
|
|
541
535
|
Returns:
|
542
|
-
`str`:
|
536
|
+
`str`:
|
537
|
+
String containing all the attributes that make up the configuration instance in JSON format.
|
543
538
|
"""
|
544
539
|
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
545
540
|
config_dict["_class_name"] = self.__class__.__name__
|
@@ -553,18 +548,19 @@ class ConfigMixin:
|
|
553
548
|
return value
|
554
549
|
|
555
550
|
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
556
|
-
# Don't save "_ignore_files"
|
551
|
+
# Don't save "_ignore_files" or "_use_default_values"
|
557
552
|
config_dict.pop("_ignore_files", None)
|
553
|
+
config_dict.pop("_use_default_values", None)
|
558
554
|
|
559
555
|
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
560
556
|
|
561
557
|
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
562
558
|
"""
|
563
|
-
Save
|
559
|
+
Save the configuration instance's parameters to a JSON file.
|
564
560
|
|
565
561
|
Args:
|
566
562
|
json_file_path (`str` or `os.PathLike`):
|
567
|
-
Path to the JSON file
|
563
|
+
Path to the JSON file to save a configuration instance's parameters.
|
568
564
|
"""
|
569
565
|
with open(json_file_path, "w", encoding="utf-8") as writer:
|
570
566
|
writer.write(self.to_json_string())
|
@@ -608,6 +604,11 @@ def register_to_config(init):
|
|
608
604
|
if k not in ignore and k not in new_kwargs
|
609
605
|
}
|
610
606
|
)
|
607
|
+
|
608
|
+
# Take note of the parameters that were not present in the loaded config
|
609
|
+
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
610
|
+
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
611
|
+
|
611
612
|
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
612
613
|
getattr(self, "register_to_config")(**new_kwargs)
|
613
614
|
init(self, *args, **init_kwargs)
|
@@ -652,6 +653,10 @@ def flax_register_to_config(cls):
|
|
652
653
|
name = fields[i].name
|
653
654
|
new_kwargs[name] = arg
|
654
655
|
|
656
|
+
# Take note of the parameters that were not present in the loaded config
|
657
|
+
if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
|
658
|
+
new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
|
659
|
+
|
655
660
|
getattr(self, "register_to_config")(**new_kwargs)
|
656
661
|
original_init(self, *args, **kwargs)
|
657
662
|
|
@@ -13,11 +13,14 @@ deps = {
|
|
13
13
|
"huggingface-hub": "huggingface-hub>=0.13.2",
|
14
14
|
"requests-mock": "requests-mock==1.10.0",
|
15
15
|
"importlib_metadata": "importlib_metadata",
|
16
|
+
"invisible-watermark": "invisible-watermark",
|
16
17
|
"isort": "isort>=5.5.4",
|
17
18
|
"jax": "jax>=0.2.8,!=0.3.2",
|
18
19
|
"jaxlib": "jaxlib>=0.1.65",
|
19
20
|
"Jinja2": "Jinja2",
|
20
21
|
"k-diffusion": "k-diffusion>=0.0.12",
|
22
|
+
"torchsde": "torchsde",
|
23
|
+
"note_seq": "note_seq",
|
21
24
|
"librosa": "librosa",
|
22
25
|
"numpy": "numpy",
|
23
26
|
"omegaconf": "omegaconf",
|
@@ -30,6 +33,7 @@ deps = {
|
|
30
33
|
"safetensors": "safetensors",
|
31
34
|
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
32
35
|
"scipy": "scipy",
|
36
|
+
"onnx": "onnx",
|
33
37
|
"regex": "regex!=2019.12.17",
|
34
38
|
"requests": "requests",
|
35
39
|
"tensorboard": "tensorboard",
|
diffusers/image_processor.py
CHANGED
@@ -26,19 +26,18 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
|
26
26
|
|
27
27
|
class VaeImageProcessor(ConfigMixin):
|
28
28
|
"""
|
29
|
-
Image
|
29
|
+
Image processor for VAE.
|
30
30
|
|
31
31
|
Args:
|
32
32
|
do_resize (`bool`, *optional*, defaults to `True`):
|
33
33
|
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
34
|
-
`height` and `width` arguments from `preprocess` method
|
34
|
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
35
35
|
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
36
|
-
VAE scale factor. If `do_resize` is True
|
37
|
-
factor.
|
36
|
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
38
37
|
resample (`str`, *optional*, defaults to `lanczos`):
|
39
38
|
Resampling filter to use when resizing the image.
|
40
39
|
do_normalize (`bool`, *optional*, defaults to `True`):
|
41
|
-
Whether to normalize the image to [-1,1]
|
40
|
+
Whether to normalize the image to [-1,1].
|
42
41
|
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
43
42
|
Whether to convert the images to RGB format.
|
44
43
|
"""
|
@@ -75,7 +74,7 @@ class VaeImageProcessor(ConfigMixin):
|
|
75
74
|
@staticmethod
|
76
75
|
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
77
76
|
"""
|
78
|
-
Convert a PIL image or a list of PIL images to
|
77
|
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
79
78
|
"""
|
80
79
|
if not isinstance(images, list):
|
81
80
|
images = [images]
|
@@ -87,7 +86,7 @@ class VaeImageProcessor(ConfigMixin):
|
|
87
86
|
@staticmethod
|
88
87
|
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
|
89
88
|
"""
|
90
|
-
Convert a
|
89
|
+
Convert a NumPy image to a PyTorch tensor.
|
91
90
|
"""
|
92
91
|
if images.ndim == 3:
|
93
92
|
images = images[..., None]
|
@@ -98,7 +97,7 @@ class VaeImageProcessor(ConfigMixin):
|
|
98
97
|
@staticmethod
|
99
98
|
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
|
100
99
|
"""
|
101
|
-
Convert a
|
100
|
+
Convert a PyTorch tensor to a NumPy image.
|
102
101
|
"""
|
103
102
|
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
104
103
|
return images
|
@@ -106,14 +105,14 @@ class VaeImageProcessor(ConfigMixin):
|
|
106
105
|
@staticmethod
|
107
106
|
def normalize(images):
|
108
107
|
"""
|
109
|
-
Normalize an image array to [-1,1]
|
108
|
+
Normalize an image array to [-1,1].
|
110
109
|
"""
|
111
110
|
return 2.0 * images - 1.0
|
112
111
|
|
113
112
|
@staticmethod
|
114
113
|
def denormalize(images):
|
115
114
|
"""
|
116
|
-
Denormalize an image array to [0,1]
|
115
|
+
Denormalize an image array to [0,1].
|
117
116
|
"""
|
118
117
|
return (images / 2 + 0.5).clamp(0, 1)
|
119
118
|
|
@@ -132,7 +131,7 @@ class VaeImageProcessor(ConfigMixin):
|
|
132
131
|
width: Optional[int] = None,
|
133
132
|
) -> PIL.Image.Image:
|
134
133
|
"""
|
135
|
-
Resize a PIL image. Both height and width
|
134
|
+
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
|
136
135
|
"""
|
137
136
|
if height is None:
|
138
137
|
height = image.height
|
@@ -152,7 +151,7 @@ class VaeImageProcessor(ConfigMixin):
|
|
152
151
|
width: Optional[int] = None,
|
153
152
|
) -> torch.Tensor:
|
154
153
|
"""
|
155
|
-
Preprocess the image input
|
154
|
+
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
|
156
155
|
"""
|
157
156
|
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
|
158
157
|
if isinstance(image, supported_formats):
|
@@ -251,3 +250,117 @@ class VaeImageProcessor(ConfigMixin):
|
|
251
250
|
|
252
251
|
if output_type == "pil":
|
253
252
|
return self.numpy_to_pil(image)
|
253
|
+
|
254
|
+
|
255
|
+
class VaeImageProcessorLDM3D(VaeImageProcessor):
|
256
|
+
"""
|
257
|
+
Image processor for VAE LDM3D.
|
258
|
+
|
259
|
+
Args:
|
260
|
+
do_resize (`bool`, *optional*, defaults to `True`):
|
261
|
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
262
|
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
263
|
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
264
|
+
resample (`str`, *optional*, defaults to `lanczos`):
|
265
|
+
Resampling filter to use when resizing the image.
|
266
|
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
267
|
+
Whether to normalize the image to [-1,1].
|
268
|
+
"""
|
269
|
+
|
270
|
+
config_name = CONFIG_NAME
|
271
|
+
|
272
|
+
@register_to_config
|
273
|
+
def __init__(
|
274
|
+
self,
|
275
|
+
do_resize: bool = True,
|
276
|
+
vae_scale_factor: int = 8,
|
277
|
+
resample: str = "lanczos",
|
278
|
+
do_normalize: bool = True,
|
279
|
+
):
|
280
|
+
super().__init__()
|
281
|
+
|
282
|
+
@staticmethod
|
283
|
+
def numpy_to_pil(images):
|
284
|
+
"""
|
285
|
+
Convert a NumPy image or a batch of images to a PIL image.
|
286
|
+
"""
|
287
|
+
if images.ndim == 3:
|
288
|
+
images = images[None, ...]
|
289
|
+
images = (images * 255).round().astype("uint8")
|
290
|
+
if images.shape[-1] == 1:
|
291
|
+
# special case for grayscale (single channel) images
|
292
|
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
293
|
+
else:
|
294
|
+
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
|
295
|
+
|
296
|
+
return pil_images
|
297
|
+
|
298
|
+
@staticmethod
|
299
|
+
def rgblike_to_depthmap(image):
|
300
|
+
"""
|
301
|
+
Args:
|
302
|
+
image: RGB-like depth image
|
303
|
+
|
304
|
+
Returns: depth map
|
305
|
+
|
306
|
+
"""
|
307
|
+
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
308
|
+
|
309
|
+
def numpy_to_depth(self, images):
|
310
|
+
"""
|
311
|
+
Convert a NumPy depth image or a batch of images to a PIL image.
|
312
|
+
"""
|
313
|
+
if images.ndim == 3:
|
314
|
+
images = images[None, ...]
|
315
|
+
images_depth = images[:, :, :, 3:]
|
316
|
+
if images.shape[-1] == 6:
|
317
|
+
images_depth = (images_depth * 255).round().astype("uint8")
|
318
|
+
pil_images = [
|
319
|
+
Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
|
320
|
+
]
|
321
|
+
elif images.shape[-1] == 4:
|
322
|
+
images_depth = (images_depth * 65535.0).astype(np.uint16)
|
323
|
+
pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
|
324
|
+
else:
|
325
|
+
raise Exception("Not supported")
|
326
|
+
|
327
|
+
return pil_images
|
328
|
+
|
329
|
+
def postprocess(
|
330
|
+
self,
|
331
|
+
image: torch.FloatTensor,
|
332
|
+
output_type: str = "pil",
|
333
|
+
do_denormalize: Optional[List[bool]] = None,
|
334
|
+
):
|
335
|
+
if not isinstance(image, torch.Tensor):
|
336
|
+
raise ValueError(
|
337
|
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
338
|
+
)
|
339
|
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
340
|
+
deprecation_message = (
|
341
|
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
342
|
+
"`pil`, `np`, `pt`, `latent`"
|
343
|
+
)
|
344
|
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
345
|
+
output_type = "np"
|
346
|
+
|
347
|
+
if do_denormalize is None:
|
348
|
+
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
349
|
+
|
350
|
+
image = torch.stack(
|
351
|
+
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
352
|
+
)
|
353
|
+
|
354
|
+
image = self.pt_to_numpy(image)
|
355
|
+
|
356
|
+
if output_type == "np":
|
357
|
+
if image.shape[-1] == 6:
|
358
|
+
image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
|
359
|
+
else:
|
360
|
+
image_depth = image[:, :, :, 3:]
|
361
|
+
return image[:, :, :, :3], image_depth
|
362
|
+
|
363
|
+
if output_type == "pil":
|
364
|
+
return self.numpy_to_pil(image), self.numpy_to_depth(image)
|
365
|
+
else:
|
366
|
+
raise Exception(f"This type {output_type} is not supported")
|