diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Files changed (238)
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
+from PIL import Image, ImageFilter, ImageOps
 
 from .configuration_utils import ConfigMixin, register_to_config
 from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
@@ -33,6 +33,15 @@ PipelineImageInput = Union[
     List[torch.FloatTensor],
 ]
 
+PipelineDepthInput = Union[
+    PIL.Image.Image,
+    np.ndarray,
+    torch.FloatTensor,
+    List[PIL.Image.Image],
+    List[np.ndarray],
+    List[torch.FloatTensor],
+]
+
 
 class VaeImageProcessor(ConfigMixin):
     """
@@ -79,7 +88,7 @@ class VaeImageProcessor(ConfigMixin):
             self.config.do_convert_rgb = False
 
     @staticmethod
-    def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
+    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
         """
         Convert a numpy image or a batch of images to a PIL image.
         """
@@ -126,14 +135,14 @@ class VaeImageProcessor(ConfigMixin):
         return images
 
     @staticmethod
-    def normalize(images):
+    def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Normalize an image array to [-1,1].
         """
         return 2.0 * images - 1.0
 
     @staticmethod
-    def denormalize(images):
+    def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Denormalize an image array to [0,1].
         """
@@ -157,12 +166,250 @@ class VaeImageProcessor(ConfigMixin):
 
         return image
 
+    @staticmethod
+    def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
+        """
+        Blurs an image.
+        """
+        image = image.filter(ImageFilter.GaussianBlur(blur_factor))
+
+        return image
+
+    @staticmethod
+    def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
+        """
+        Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
+        for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
+
+        Args:
+            mask_image (PIL.Image.Image): Mask image.
+            width (int): Width of the image to be processed.
+            height (int): Height of the image to be processed.
+            pad (int, optional): Padding to be added to the crop region. Defaults to 0.
+
+        Returns:
+            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
+        """
+
+        mask_image = mask_image.convert("L")
+        mask = np.array(mask_image)
+
+        # 1. find a rectangular region that contains all masked ares in an image
+        h, w = mask.shape
+        crop_left = 0
+        for i in range(w):
+            if not (mask[:, i] == 0).all():
+                break
+            crop_left += 1
+
+        crop_right = 0
+        for i in reversed(range(w)):
+            if not (mask[:, i] == 0).all():
+                break
+            crop_right += 1
+
+        crop_top = 0
+        for i in range(h):
+            if not (mask[i] == 0).all():
+                break
+            crop_top += 1
+
+        crop_bottom = 0
+        for i in reversed(range(h)):
+            if not (mask[i] == 0).all():
+                break
+            crop_bottom += 1
+
+        # 2. add padding to the crop region
+        x1, y1, x2, y2 = (
+            int(max(crop_left - pad, 0)),
+            int(max(crop_top - pad, 0)),
+            int(min(w - crop_right + pad, w)),
+            int(min(h - crop_bottom + pad, h)),
+        )
+
+        # 3. expands crop region to match the aspect ratio of the image to be processed
+        ratio_crop_region = (x2 - x1) / (y2 - y1)
+        ratio_processing = width / height
+
+        if ratio_crop_region > ratio_processing:
+            desired_height = (x2 - x1) / ratio_processing
+            desired_height_diff = int(desired_height - (y2 - y1))
+            y1 -= desired_height_diff // 2
+            y2 += desired_height_diff - desired_height_diff // 2
+            if y2 >= mask_image.height:
+                diff = y2 - mask_image.height
+                y2 -= diff
+                y1 -= diff
+            if y1 < 0:
+                y2 -= y1
+                y1 -= y1
+            if y2 >= mask_image.height:
+                y2 = mask_image.height
+        else:
+            desired_width = (y2 - y1) * ratio_processing
+            desired_width_diff = int(desired_width - (x2 - x1))
+            x1 -= desired_width_diff // 2
+            x2 += desired_width_diff - desired_width_diff // 2
+            if x2 >= mask_image.width:
+                diff = x2 - mask_image.width
+                x2 -= diff
+                x1 -= diff
+            if x1 < 0:
+                x2 -= x1
+                x1 -= x1
+            if x2 >= mask_image.width:
+                x2 = mask_image.width
+
+        return x1, y1, x2, y2
+
+    def _resize_and_fill(
+        self,
+        image: PIL.Image.Image,
+        width: int,
+        height: int,
+    ) -> PIL.Image.Image:
+        """
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+
+        Args:
+            image: The image to resize.
+            width: The width to resize the image to.
+            height: The height to resize the image to.
+        """
+
+        ratio = width / height
+        src_ratio = image.width / image.height
+
+        src_w = width if ratio < src_ratio else image.width * height // image.height
+        src_h = height if ratio >= src_ratio else image.height * width // image.width
+
+        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+        if ratio < src_ratio:
+            fill_height = height // 2 - src_h // 2
+            if fill_height > 0:
+                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+                res.paste(
+                    resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
+                    box=(0, fill_height + src_h),
+                )
+        elif ratio > src_ratio:
+            fill_width = width // 2 - src_w // 2
+            if fill_width > 0:
+                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+                res.paste(
+                    resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
+                    box=(fill_width + src_w, 0),
+                )
+
+        return res
+
+    def _resize_and_crop(
+        self,
+        image: PIL.Image.Image,
+        width: int,
+        height: int,
+    ) -> PIL.Image.Image:
+        """
+        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+
+        Args:
+            image: The image to resize.
+            width: The width to resize the image to.
+            height: The height to resize the image to.
+        """
+        ratio = width / height
+        src_ratio = image.width / image.height
+
+        src_w = width if ratio > src_ratio else image.width * height // image.height
+        src_h = height if ratio <= src_ratio else image.height * width // image.width
+
+        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
+        res = Image.new("RGB", (width, height))
+        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+        return res
+
+    def resize(
+        self,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
+        height: int,
+        width: int,
+        resize_mode: str = "default",  # "defalt", "fill", "crop"
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
+        """
+        Resize image.
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, can be a PIL image, numpy array or pytorch tensor.
+            height (`int`):
+                The height to resize to.
+            width (`int`):
+                The width to resize to.
+            resize_mode (`str`, *optional*, defaults to `default`):
+                The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
+                within the specified width and height, and it may not maintaining the original aspect ratio.
+                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
+                within the dimensions, filling empty with data from image.
+                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
+                within the dimensions, cropping the excess.
+                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
+                The resized image.
+        """
+        if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
+            raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
+        if isinstance(image, PIL.Image.Image):
+            if resize_mode == "default":
+                image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
+            elif resize_mode == "fill":
+                image = self._resize_and_fill(image, width, height)
+            elif resize_mode == "crop":
+                image = self._resize_and_crop(image, width, height)
+            else:
+                raise ValueError(f"resize_mode {resize_mode} is not supported")
+
+        elif isinstance(image, torch.Tensor):
+            image = torch.nn.functional.interpolate(
+                image,
+                size=(height, width),
+            )
+        elif isinstance(image, np.ndarray):
+            image = self.numpy_to_pt(image)
+            image = torch.nn.functional.interpolate(
+                image,
+                size=(height, width),
+            )
+            image = self.pt_to_numpy(image)
+        return image
+
+    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Create a mask.
+
+        Args:
+            image (`PIL.Image.Image`):
+                The image input, should be a PIL image.
+
+        Returns:
+            `PIL.Image.Image`:
+                The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
+        """
+        image[image < 0.5] = 0
+        image[image >= 0.5] = 1
+        return image
+
     def get_default_height_width(
         self,
-        image: [PIL.Image.Image, np.ndarray, torch.Tensor],
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ):
+    ) -> Tuple[int, int]:
         """
         This function return the height and width that are downscaled to the next integer multiple of
         `vae_scale_factor`.
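
For orientation, a minimal usage sketch of the mask and resize helpers added above; the processor settings, file names, 512x512 target size, and pad value are illustrative assumptions rather than part of this diff:

from PIL import Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

# Assumed inputs: an RGB photo and a single-channel inpainting mask.
init_image = Image.open("input.png").convert("RGB")
mask_image = Image.open("mask.png").convert("L")

# Soften the mask edges with the new Gaussian-blur helper.
blurred_mask = processor.blur(mask_image, blur_factor=4)

# Locate the masked area, padded and expanded to match the 512x512 processing aspect ratio.
x1, y1, x2, y2 = processor.get_crop_region(mask_image, width=512, height=512, pad=32)

# The new resize modes: "fill" pads the borders with image content, "crop" trims the excess.
filled = processor.resize(init_image, height=512, width=512, resize_mode="fill")
cropped = processor.resize(init_image, height=512, width=512, resize_mode="crop")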
@@ -200,47 +447,34 @@ class VaeImageProcessor(ConfigMixin):
 
         return height, width
 
-    def resize(
-        self,
-        image: [PIL.Image.Image, np.ndarray, torch.Tensor],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-    ) -> [PIL.Image.Image, np.ndarray, torch.Tensor]:
-        """
-        Resize image.
-        """
-        if isinstance(image, PIL.Image.Image):
-            image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
-        elif isinstance(image, torch.Tensor):
-            image = torch.nn.functional.interpolate(
-                image,
-                size=(height, width),
-            )
-        elif isinstance(image, np.ndarray):
-            image = self.numpy_to_pt(image)
-            image = torch.nn.functional.interpolate(
-                image,
-                size=(height, width),
-            )
-            image = self.pt_to_numpy(image)
-        return image
-
-    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
-        """
-        create a mask
-        """
-        image[image < 0.5] = 0
-        image[image >= 0.5] = 1
-        return image
-
     def preprocess(
         self,
-        image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        image: PipelineImageInput,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        resize_mode: str = "default",  # "defalt", "fill", "crop"
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
     ) -> torch.Tensor:
         """
-        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        Preprocess the image input.
+
+        Args:
+            image (`pipeline_image_input`):
+                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
+            height (`int`, *optional*, defaults to `None`):
+                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
+            width (`int`, *optional*`, defaults to `None`):
+                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
+            resize_mode (`str`, *optional*, defaults to `default`):
+                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
+                within the specified width and height, and it may not maintaining the original aspect ratio.
+                If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
+                within the dimensions, filling empty with data from image.
+                If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
+                within the dimensions, cropping the excess.
+                Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+            crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
+                The crop coordinates for each image in the batch. If `None`, will not crop the image.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
 
@@ -270,13 +504,15 @@ class VaeImageProcessor(ConfigMixin):
             )
 
         if isinstance(image[0], PIL.Image.Image):
+            if crops_coords is not None:
+                image = [i.crop(crops_coords) for i in image]
+            if self.config.do_resize:
+                height, width = self.get_default_height_width(image[0], height, width)
+                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
             if self.config.do_convert_rgb:
                 image = [self.convert_to_rgb(i) for i in image]
             elif self.config.do_convert_grayscale:
                 image = [self.convert_to_grayscale(i) for i in image]
-            if self.config.do_resize:
-                height, width = self.get_default_height_width(image[0], height, width)
-                image = [self.resize(i, height, width) for i in image]
             image = self.pil_to_numpy(image)  # to np
             image = self.numpy_to_pt(image)  # to pt
 
@@ -306,7 +542,7 @@ class VaeImageProcessor(ConfigMixin):
 
         # expected range [0,1], normalize to [-1,1]
         do_normalize = self.config.do_normalize
-        if image.min() < 0 and do_normalize:
+        if do_normalize and image.min() < 0:
             warnings.warn(
                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                 f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
@@ -327,7 +563,23 @@ class VaeImageProcessor(ConfigMixin):
         image: torch.FloatTensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ):
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.FloatTensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+                The postprocessed image.
+        """
         if not isinstance(image, torch.Tensor):
             raise ValueError(
                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
@@ -361,6 +613,39 @@ class VaeImageProcessor(ConfigMixin):
         if output_type == "pil":
             return self.numpy_to_pil(image)
 
+    def apply_overlay(
+        self,
+        mask: PIL.Image.Image,
+        init_image: PIL.Image.Image,
+        image: PIL.Image.Image,
+        crop_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> PIL.Image.Image:
+        """
+        overlay the inpaint output to the original image
+        """
+
+        width, height = image.width, image.height
+
+        init_image = self.resize(init_image, width=width, height=height)
+        mask = self.resize(mask, width=width, height=height)
+
+        init_image_masked = PIL.Image.new("RGBa", (width, height))
+        init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
+        init_image_masked = init_image_masked.convert("RGBA")
+
+        if crop_coords is not None:
+            x, y, w, h = crop_coords
+            base_image = PIL.Image.new("RGBA", (width, height))
+            image = self.resize(image, height=h, width=w, resize_mode="crop")
+            base_image.paste(image, (x, y))
+            image = base_image.convert("RGB")
+
+        image = image.convert("RGBA")
+        image.alpha_composite(init_image_masked)
+        image = image.convert("RGB")
+
+        return image
+
 
 class VaeImageProcessorLDM3D(VaeImageProcessor):
     """
@@ -390,7 +675,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         super().__init__()
 
     @staticmethod
-    def numpy_to_pil(images):
+    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
         """
         Convert a NumPy image or a batch of images to a PIL image.
         """
@@ -406,7 +691,19 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         return pil_images
 
     @staticmethod
-    def rgblike_to_depthmap(image):
+    def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
+        """
+        Convert a PIL image or a list of PIL images to NumPy arrays.
+        """
+        if not isinstance(images, list):
+            images = [images]
+
+        images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
+        images = np.stack(images, axis=0)
+        return images
+
+    @staticmethod
+    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
         """
         Args:
             image: RGB-like depth image
@@ -416,7 +713,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         """
         return image[:, :, 1] * 2**8 + image[:, :, 2]
 
-    def numpy_to_depth(self, images):
+    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
         """
         Convert a NumPy depth image or a batch of images to a PIL image.
         """
@@ -441,7 +738,23 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
         image: torch.FloatTensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ):
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+        """
+        Postprocess the image output from tensor to `output_type`.
+
+        Args:
+            image (`torch.FloatTensor`):
+                The image input, should be a pytorch tensor with shape `B x C x H x W`.
+            output_type (`str`, *optional*, defaults to `pil`):
+                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
+            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
+                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
+                `VaeImageProcessor` config.
+
+        Returns:
+            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+                The postprocessed image.
+        """
         if not isinstance(image, torch.Tensor):
             raise ValueError(
                 f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
@@ -474,3 +787,102 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
             return self.numpy_to_pil(image), self.numpy_to_depth(image)
         else:
             raise Exception(f"This type {output_type} is not supported")
+
+    def preprocess(
+        self,
+        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        target_res: Optional[int] = None,
+    ) -> torch.Tensor:
+        """
+        Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
+        """
+        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
+
+        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
+        if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
+            raise Exception("This is not yet supported")
+
+        if isinstance(rgb, supported_formats):
+            rgb = [rgb]
+            depth = [depth]
+        elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
+            raise ValueError(
+                f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
+            )
+
+        if isinstance(rgb[0], PIL.Image.Image):
+            if self.config.do_convert_rgb:
+                raise Exception("This is not yet supported")
+                # rgb = [self.convert_to_rgb(i) for i in rgb]
+                # depth = [self.convert_to_depth(i) for i in depth]  #TODO define convert_to_depth
+            if self.config.do_resize or target_res:
+                height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
+                rgb = [self.resize(i, height, width) for i in rgb]
+                depth = [self.resize(i, height, width) for i in depth]
+            rgb = self.pil_to_numpy(rgb)  # to np
+            rgb = self.numpy_to_pt(rgb)  # to pt
+
+            depth = self.depth_pil_to_numpy(depth)  # to np
+            depth = self.numpy_to_pt(depth)  # to pt
+
+        elif isinstance(rgb[0], np.ndarray):
+            rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
+            rgb = self.numpy_to_pt(rgb)
+            height, width = self.get_default_height_width(rgb, height, width)
+            if self.config.do_resize:
+                rgb = self.resize(rgb, height, width)
+
+            depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
+            depth = self.numpy_to_pt(depth)
+            height, width = self.get_default_height_width(depth, height, width)
+            if self.config.do_resize:
+                depth = self.resize(depth, height, width)
+
+        elif isinstance(rgb[0], torch.Tensor):
+            raise Exception("This is not yet supported")
+            # rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
+
+            # if self.config.do_convert_grayscale and rgb.ndim == 3:
+            #     rgb = rgb.unsqueeze(1)
+
+            # channel = rgb.shape[1]
+
+            # height, width = self.get_default_height_width(rgb, height, width)
+            # if self.config.do_resize:
+            #     rgb = self.resize(rgb, height, width)
+
+            # depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
+
+            # if self.config.do_convert_grayscale and depth.ndim == 3:
+            #     depth = depth.unsqueeze(1)
+
+            # channel = depth.shape[1]
+            # # don't need any preprocess if the image is latents
+            # if depth == 4:
+            #     return rgb, depth
+
+            # height, width = self.get_default_height_width(depth, height, width)
+            # if self.config.do_resize:
+            #     depth = self.resize(depth, height, width)
+        # expected range [0,1], normalize to [-1,1]
+        do_normalize = self.config.do_normalize
+        if rgb.min() < 0 and do_normalize:
+            warnings.warn(
+                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
+                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
+                FutureWarning,
+            )
+            do_normalize = False
+
+        if do_normalize:
+            rgb = self.normalize(rgb)
+            depth = self.normalize(depth)
+
+        if self.config.do_binarize:
+            rgb = self.binarize(rgb)
+            depth = self.binarize(depth)
+
+        return rgb, depth
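
Finally, a small sketch of the new paired RGB/depth preprocessing on VaeImageProcessorLDM3D; the 16-bit depth file is an assumed input format, matching the 2**16 - 1 scaling in depth_pil_to_numpy():

from PIL import Image

from diffusers.image_processor import VaeImageProcessorLDM3D

ldm3d_processor = VaeImageProcessorLDM3D(vae_scale_factor=8)

rgb_image = Image.open("photo.png").convert("RGB")
depth_image = Image.open("photo_depth_16bit.png")  # assumed 16-bit depth map

# Returns an (rgb, depth) pair of tensors shaped (1, C, H, W), resized to a multiple
# of the VAE scale factor and normalized to [-1, 1] by default.
rgb_tensor, depth_tensor = ldm3d_processor.preprocess(rgb_image, depth_image)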