diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/__init__.py CHANGED
@@ -1,10 +1,11 @@
- __version__ = "0.17.1"
+ __version__ = "0.18.2"

  from .configuration_utils import ConfigMixin
  from .utils import (
      OptionalDependencyNotAvailable,
      is_flax_available,
      is_inflect_available,
+     is_invisible_watermark_available,
      is_k_diffusion_available,
      is_k_diffusion_version,
      is_librosa_available,
@@ -58,6 +59,7 @@ else:
      )
      from .pipelines import (
          AudioPipelineOutput,
+         ConsistencyModelPipeline,
          DanceDiffusionPipeline,
          DDIMPipeline,
          DDPMPipeline,
@@ -72,8 +74,11 @@ else:
          ScoreSdeVePipeline,
      )
      from .schedulers import (
+         CMStochasticIterativeScheduler,
          DDIMInverseScheduler,
+         DDIMParallelScheduler,
          DDIMScheduler,
+         DDPMParallelScheduler,
          DDPMScheduler,
          DEISMultistepScheduler,
          DPMSolverMultistepInverseScheduler,
@@ -134,9 +139,18 @@ else:
          KandinskyInpaintPipeline,
          KandinskyPipeline,
          KandinskyPriorPipeline,
+         KandinskyV22ControlnetImg2ImgPipeline,
+         KandinskyV22ControlnetPipeline,
+         KandinskyV22Img2ImgPipeline,
+         KandinskyV22InpaintPipeline,
+         KandinskyV22Pipeline,
+         KandinskyV22PriorEmb2EmbPipeline,
+         KandinskyV22PriorPipeline,
          LDMTextToImagePipeline,
          PaintByExamplePipeline,
          SemanticStableDiffusionPipeline,
+         ShapEImg2ImgPipeline,
+         ShapEPipeline,
          StableDiffusionAttendAndExcitePipeline,
          StableDiffusionControlNetImg2ImgPipeline,
          StableDiffusionControlNetInpaintPipeline,
@@ -149,8 +163,10 @@ else:
          StableDiffusionInpaintPipelineLegacy,
          StableDiffusionInstructPix2PixPipeline,
          StableDiffusionLatentUpscalePipeline,
+         StableDiffusionLDM3DPipeline,
          StableDiffusionModelEditingPipeline,
          StableDiffusionPanoramaPipeline,
+         StableDiffusionParadigmsPipeline,
          StableDiffusionPipeline,
          StableDiffusionPipelineSafe,
          StableDiffusionPix2PixZeroPipeline,
@@ -169,9 +185,18 @@ else:
          VersatileDiffusionImageVariationPipeline,
          VersatileDiffusionPipeline,
          VersatileDiffusionTextToImagePipeline,
+         VideoToVideoSDPipeline,
          VQDiffusionPipeline,
      )

+ try:
+     if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()):
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403
+ else:
+     from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
+
  try:
      if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
          raise OptionalDependencyNotAvailable()
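
The new Stable Diffusion XL pipelines are only re-exported when the optional `invisible-watermark` package is installed, mirroring the guard added above. A minimal sketch of checking for the dependency before importing them (the commented-out checkpoint ID is an assumption for illustration, not taken from this diff):

```python
# Minimal sketch: import the SDXL pipelines newly exported in 0.18 only when the
# optional `invisible-watermark` dependency (see the guard added above) is installed.
from diffusers.utils import is_invisible_watermark_available

if is_invisible_watermark_available():
    from diffusers import StableDiffusionXLPipeline

    # Hypothetical usage; the checkpoint ID below is an assumption for illustration.
    # pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
    # image = pipe("an astronaut riding a horse").images[0]
else:
    print("Install `invisible-watermark` to enable the Stable Diffusion XL pipelines.")
```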
diffusers/configuration_utils.py CHANGED
@@ -81,10 +81,9 @@ class FrozenDict(OrderedDict):

  class ConfigMixin:
      r"""
-     Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
-     methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
-         - [`~ConfigMixin.from_config`]
-         - [`~ConfigMixin.save_config`]
+     Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
+     provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
+     saving classes that inherit from [`ConfigMixin`].

      Class attributes:
          - **config_name** (`str`) -- A filename under which the config should stored when calling
@@ -92,7 +91,7 @@ class ConfigMixin:
          - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
            overridden by subclass).
          - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
-         - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
+         - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
            should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
            subclass).
      """
@@ -139,12 +138,12 @@ class ConfigMixin:

      def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
          """
-         Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+         Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
          [`~ConfigMixin.from_config`] class method.

          Args:
              save_directory (`str` or `os.PathLike`):
-                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
+                 Directory where the configuration JSON file is saved (will be created if it does not exist).
          """
          if os.path.isfile(save_directory):
              raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -164,15 +163,14 @@ class ConfigMixin:

          Parameters:
              config (`Dict[str, Any]`):
-                 A config dictionary from which the Python class will be instantiated. Make sure to only load
-                 configuration files of compatible classes.
+                 A config dictionary from which the Python class is instantiated. Make sure to only load configuration
+                 files of compatible classes.
              return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                  Whether kwargs that are not consumed by the Python class should be returned or not.
-
              kwargs (remaining dictionary of keyword arguments, *optional*):
                  Can be used to update the configuration object (after it is loaded) and initiate the Python class.
-                 `**kwargs` are directly passed to the underlying scheduler/model's `__init__` method and eventually
-                 overwrite same named arguments in `config`.
+                 `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
+                 overwrite the same named arguments in `config`.

          Returns:
              [`ModelMixin`] or [`SchedulerMixin`]:
@@ -280,16 +278,16 @@ class ConfigMixin:
                  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                  cached versions if they exist.
              resume_download (`bool`, *optional*, defaults to `False`):
-                 Whether or not to resume downloading the model weights and configuration files. If set to False, any
+                 Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                  incompletely downloaded files are deleted.
              proxies (`Dict[str, str]`, *optional*):
                  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
              output_loading_info(`bool`, *optional*, defaults to `False`):
                  Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
-             local_files_only(`bool`, *optional*, defaults to `False`):
-                 Whether to only load local model weights and configuration files or not. If set to True, the model
-                 wont be downloaded from the Hub.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                 won't be downloaded from the Hub.
              use_auth_token (`str` or *bool*, *optional*):
                  The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                  `diffusers-cli login` (stored in `~/.huggingface`) is used.
@@ -307,14 +305,6 @@ class ConfigMixin:
              `dict`:
                  A dictionary of all the parameters stored in a JSON configuration file.

-         <Tip>
-
-         To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
-         `huggingface-cli login`. You can also activate the special
-         ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a
-         firewalled environment.
-
-         </Tip>
          """
          cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
          force_download = kwargs.pop("force_download", False)
@@ -433,6 +423,10 @@ class ConfigMixin:

      @classmethod
      def extract_init_dict(cls, config_dict, **kwargs):
+         # Skip keys that were not present in the original config, so default __init__ values were used
+         used_defaults = config_dict.get("_use_default_values", [])
+         config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
+
          # 0. Copy origin config dict
          original_dict = dict(config_dict.items())

@@ -536,10 +530,11 @@ class ConfigMixin:

      def to_json_string(self) -> str:
          """
-         Serializes this instance to a JSON string.
+         Serializes the configuration instance to a JSON string.

          Returns:
-             `str`: String containing all the attributes that make up this configuration instance in JSON format.
+             `str`:
+                 String containing all the attributes that make up the configuration instance in JSON format.
          """
          config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
          config_dict["_class_name"] = self.__class__.__name__
@@ -553,18 +548,19 @@ class ConfigMixin:
              return value

          config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
-         # Don't save "_ignore_files"
+         # Don't save "_ignore_files" or "_use_default_values"
          config_dict.pop("_ignore_files", None)
+         config_dict.pop("_use_default_values", None)

          return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

      def to_json_file(self, json_file_path: Union[str, os.PathLike]):
          """
-         Save this instance to a JSON file.
+         Save the configuration instance's parameters to a JSON file.

          Args:
              json_file_path (`str` or `os.PathLike`):
-                 Path to the JSON file in which this configuration instance's parameters will be saved.
+                 Path to the JSON file to save a configuration instance's parameters.
          """
          with open(json_file_path, "w", encoding="utf-8") as writer:
              writer.write(self.to_json_string())
@@ -608,6 +604,11 @@ def register_to_config(init):
                  if k not in ignore and k not in new_kwargs
              }
          )
+
+         # Take note of the parameters that were not present in the loaded config
+         if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+             new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
          new_kwargs = {**config_init_kwargs, **new_kwargs}
          getattr(self, "register_to_config")(**new_kwargs)
          init(self, *args, **init_kwargs)
@@ -652,6 +653,10 @@ def flax_register_to_config(cls):
              name = fields[i].name
              new_kwargs[name] = arg

+         # Take note of the parameters that were not present in the loaded config
+         if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+             new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
          getattr(self, "register_to_config")(**new_kwargs)
          original_init(self, *args, **kwargs)

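A short sketch of the `_use_default_values` bookkeeping introduced above: the decorated `__init__` records which config parameters fell back to their defaults, `to_json_string` strips the marker before saving, and `extract_init_dict` skips the recorded keys on reload so newer library defaults can take effect. Illustrative only; the exact keys printed depend on the scheduler's signature.

```python
# Illustrative sketch of the `_use_default_values` bookkeeping (assumes diffusers>=0.18).
from diffusers import DDIMScheduler

# Only `num_train_timesteps` is passed explicitly; every other argument falls back to
# its __init__ default, so it should be recorded under "_use_default_values".
scheduler = DDIMScheduler(num_train_timesteps=1000)
print(scheduler.config.get("_use_default_values", []))

# The marker lives only in the in-memory config: to_json_string() pops it before
# serialization, and extract_init_dict() drops the recorded keys when a saved
# config is loaded again, so they pick up the current library defaults.
assert "_use_default_values" not in scheduler.to_json_string()
```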
diffusers/dependency_versions_table.py CHANGED
@@ -13,11 +13,14 @@ deps = {
      "huggingface-hub": "huggingface-hub>=0.13.2",
      "requests-mock": "requests-mock==1.10.0",
      "importlib_metadata": "importlib_metadata",
+     "invisible-watermark": "invisible-watermark",
      "isort": "isort>=5.5.4",
      "jax": "jax>=0.2.8,!=0.3.2",
      "jaxlib": "jaxlib>=0.1.65",
      "Jinja2": "Jinja2",
      "k-diffusion": "k-diffusion>=0.0.12",
+     "torchsde": "torchsde",
+     "note_seq": "note_seq",
      "librosa": "librosa",
      "numpy": "numpy",
      "omegaconf": "omegaconf",
@@ -30,6 +33,7 @@ deps = {
      "safetensors": "safetensors",
      "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
      "scipy": "scipy",
+     "onnx": "onnx",
      "regex": "regex!=2019.12.17",
      "requests": "requests",
      "tensorboard": "tensorboard",
diffusers/image_processor.py CHANGED
@@ -26,19 +26,18 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate

  class VaeImageProcessor(ConfigMixin):
      """
-     Image Processor for VAE
+     Image processor for VAE.

      Args:
          do_resize (`bool`, *optional*, defaults to `True`):
              Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
-             `height` and `width` arguments from `preprocess` method
+             `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
          vae_scale_factor (`int`, *optional*, defaults to `8`):
-             VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
-             factor.
+             VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
          resample (`str`, *optional*, defaults to `lanczos`):
              Resampling filter to use when resizing the image.
          do_normalize (`bool`, *optional*, defaults to `True`):
-             Whether to normalize the image to [-1,1]
+             Whether to normalize the image to [-1,1].
          do_convert_rgb (`bool`, *optional*, defaults to be `False`):
              Whether to convert the images to RGB format.
      """
@@ -75,7 +74,7 @@ class VaeImageProcessor(ConfigMixin):
      @staticmethod
      def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
          """
-         Convert a PIL image or a list of PIL images to numpy arrays.
+         Convert a PIL image or a list of PIL images to NumPy arrays.
          """
          if not isinstance(images, list):
              images = [images]
@@ -87,7 +86,7 @@ class VaeImageProcessor(ConfigMixin):
      @staticmethod
      def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
          """
-         Convert a numpy image to a pytorch tensor
+         Convert a NumPy image to a PyTorch tensor.
          """
          if images.ndim == 3:
              images = images[..., None]
@@ -98,7 +97,7 @@ class VaeImageProcessor(ConfigMixin):
      @staticmethod
      def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
          """
-         Convert a pytorch tensor to a numpy image
+         Convert a PyTorch tensor to a NumPy image.
          """
          images = images.cpu().permute(0, 2, 3, 1).float().numpy()
          return images
@@ -106,14 +105,14 @@ class VaeImageProcessor(ConfigMixin):
      @staticmethod
      def normalize(images):
          """
-         Normalize an image array to [-1,1]
+         Normalize an image array to [-1,1].
          """
          return 2.0 * images - 1.0

      @staticmethod
      def denormalize(images):
          """
-         Denormalize an image array to [0,1]
+         Denormalize an image array to [0,1].
          """
          return (images / 2 + 0.5).clamp(0, 1)

@@ -132,7 +131,7 @@ class VaeImageProcessor(ConfigMixin):
          width: Optional[int] = None,
      ) -> PIL.Image.Image:
          """
-         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
+         Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
          """
          if height is None:
              height = image.height
@@ -152,7 +151,7 @@ class VaeImageProcessor(ConfigMixin):
          width: Optional[int] = None,
      ) -> torch.Tensor:
          """
-         Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors"
+         Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
          """
          supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
          if isinstance(image, supported_formats):
@@ -251,3 +250,117 @@ class VaeImageProcessor(ConfigMixin):

          if output_type == "pil":
              return self.numpy_to_pil(image)
+
+
+ class VaeImageProcessorLDM3D(VaeImageProcessor):
+     """
+     Image processor for VAE LDM3D.
+
+     Args:
+         do_resize (`bool`, *optional*, defaults to `True`):
+             Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+         vae_scale_factor (`int`, *optional*, defaults to `8`):
+             VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
+         resample (`str`, *optional*, defaults to `lanczos`):
+             Resampling filter to use when resizing the image.
+         do_normalize (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the image to [-1,1].
+     """
+
+     config_name = CONFIG_NAME
+
+     @register_to_config
+     def __init__(
+         self,
+         do_resize: bool = True,
+         vae_scale_factor: int = 8,
+         resample: str = "lanczos",
+         do_normalize: bool = True,
+     ):
+         super().__init__()
+
+     @staticmethod
+     def numpy_to_pil(images):
+         """
+         Convert a NumPy image or a batch of images to a PIL image.
+         """
+         if images.ndim == 3:
+             images = images[None, ...]
+         images = (images * 255).round().astype("uint8")
+         if images.shape[-1] == 1:
+             # special case for grayscale (single channel) images
+             pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+         else:
+             pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
+
+         return pil_images
+
+     @staticmethod
+     def rgblike_to_depthmap(image):
+         """
+         Args:
+             image: RGB-like depth image
+
+         Returns: depth map
+
+         """
+         return image[:, :, 1] * 2**8 + image[:, :, 2]
+
+     def numpy_to_depth(self, images):
+         """
+         Convert a NumPy depth image or a batch of images to a PIL image.
+         """
+         if images.ndim == 3:
+             images = images[None, ...]
+         images_depth = images[:, :, :, 3:]
+         if images.shape[-1] == 6:
+             images_depth = (images_depth * 255).round().astype("uint8")
+             pil_images = [
+                 Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
+             ]
+         elif images.shape[-1] == 4:
+             images_depth = (images_depth * 65535.0).astype(np.uint16)
+             pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
+         else:
+             raise Exception("Not supported")
+
+         return pil_images
+
+     def postprocess(
+         self,
+         image: torch.FloatTensor,
+         output_type: str = "pil",
+         do_denormalize: Optional[List[bool]] = None,
+     ):
+         if not isinstance(image, torch.Tensor):
+             raise ValueError(
+                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+             )
+         if output_type not in ["latent", "pt", "np", "pil"]:
+             deprecation_message = (
+                 f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+                 "`pil`, `np`, `pt`, `latent`"
+             )
+             deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+             output_type = "np"
+
+         if do_denormalize is None:
+             do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+         image = torch.stack(
+             [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+         )
+
+         image = self.pt_to_numpy(image)
+
+         if output_type == "np":
+             if image.shape[-1] == 6:
+                 image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
+             else:
+                 image_depth = image[:, :, :, 3:]
+             return image[:, :, :, :3], image_depth
+
+         if output_type == "pil":
+             return self.numpy_to_pil(image), self.numpy_to_depth(image)
+         else:
+             raise Exception(f"This type {output_type} is not supported")