diffusers-0.27.1-py3-none-any.whl → diffusers-0.28.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -32,15 +32,15 @@ class SdeVeOutput(BaseOutput):
32
32
  Output class for the scheduler's `step` function output.
33
33
 
34
34
  Args:
35
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
35
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
36
36
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
37
37
  denoising loop.
38
- prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ prev_sample_mean (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39
39
  Mean averaged `prev_sample` over previous timesteps.
40
40
  """
41
41
 
42
- prev_sample: torch.FloatTensor
43
- prev_sample_mean: torch.FloatTensor
42
+ prev_sample: torch.Tensor
43
+ prev_sample_mean: torch.Tensor
44
44
 
45
45
 
46
46
  class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
@@ -86,19 +86,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
86
86
 
87
87
  self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
88
88
 
89
- def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
89
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
90
90
  """
91
91
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
92
92
  current timestep.
93
93
 
94
94
  Args:
95
- sample (`torch.FloatTensor`):
95
+ sample (`torch.Tensor`):
96
96
  The input sample.
97
97
  timestep (`int`, *optional*):
98
98
  The current timestep in the diffusion chain.
99
99
 
100
100
  Returns:
101
- `torch.FloatTensor`:
101
+ `torch.Tensor`:
102
102
  A scaled input sample.
103
103
  """
104
104
  return sample
@@ -159,9 +159,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
159
159
 
160
160
  def step_pred(
161
161
  self,
162
- model_output: torch.FloatTensor,
162
+ model_output: torch.Tensor,
163
163
  timestep: int,
164
- sample: torch.FloatTensor,
164
+ sample: torch.Tensor,
165
165
  generator: Optional[torch.Generator] = None,
166
166
  return_dict: bool = True,
167
167
  ) -> Union[SdeVeOutput, Tuple]:
@@ -170,11 +170,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
170
170
  process from the learned model outputs (most often the predicted noise).
171
171
 
172
172
  Args:
173
- model_output (`torch.FloatTensor`):
173
+ model_output (`torch.Tensor`):
174
174
  The direct output from learned diffusion model.
175
175
  timestep (`int`):
176
176
  The current discrete timestep in the diffusion chain.
177
- sample (`torch.FloatTensor`):
177
+ sample (`torch.Tensor`):
178
178
  A current instance of a sample created by the diffusion process.
179
179
  generator (`torch.Generator`, *optional*):
180
180
  A random number generator.
@@ -227,8 +227,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
227
227
 
228
228
  def step_correct(
229
229
  self,
230
- model_output: torch.FloatTensor,
231
- sample: torch.FloatTensor,
230
+ model_output: torch.Tensor,
231
+ sample: torch.Tensor,
232
232
  generator: Optional[torch.Generator] = None,
233
233
  return_dict: bool = True,
234
234
  ) -> Union[SchedulerOutput, Tuple]:
@@ -237,9 +237,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
237
237
  making the prediction for the previous timestep.
238
238
 
239
239
  Args:
240
- model_output (`torch.FloatTensor`):
240
+ model_output (`torch.Tensor`):
241
241
  The direct output from learned diffusion model.
242
- sample (`torch.FloatTensor`):
242
+ sample (`torch.Tensor`):
243
243
  A current instance of a sample created by the diffusion process.
244
244
  generator (`torch.Generator`, *optional*):
245
245
  A random number generator.
@@ -282,10 +282,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
282
282
 
283
283
  def add_noise(
284
284
  self,
285
- original_samples: torch.FloatTensor,
286
- noise: torch.FloatTensor,
287
- timesteps: torch.FloatTensor,
288
- ) -> torch.FloatTensor:
285
+ original_samples: torch.Tensor,
286
+ noise: torch.Tensor,
287
+ timesteps: torch.Tensor,
288
+ ) -> torch.Tensor:
289
289
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
290
290
  timesteps = timesteps.to(original_samples.device)
291
291
  sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
@@ -37,15 +37,15 @@ class TCDSchedulerOutput(BaseOutput):
37
37
  Output class for the scheduler's `step` function output.
38
38
 
39
39
  Args:
40
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
41
41
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
42
  denoising loop.
43
- pred_noised_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
43
+ pred_noised_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
44
44
  The predicted noised sample `(x_{s})` based on the model output from the current timestep.
45
45
  """
46
46
 
47
- prev_sample: torch.FloatTensor
48
- pred_noised_sample: Optional[torch.FloatTensor] = None
47
+ prev_sample: torch.Tensor
48
+ pred_noised_sample: Optional[torch.Tensor] = None
49
49
 
50
50
 
51
51
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -83,7 +83,7 @@ def betas_for_alpha_bar(
83
83
  return math.exp(t * -12.0)
84
84
 
85
85
  else:
86
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
86
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
87
87
 
88
88
  betas = []
89
89
  for i in range(num_diffusion_timesteps):
@@ -94,17 +94,17 @@ def betas_for_alpha_bar(
94
94
 
95
95
 
96
96
  # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
97
- def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
97
+ def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
98
98
  """
99
99
  Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
100
100
 
101
101
 
102
102
  Args:
103
- betas (`torch.FloatTensor`):
103
+ betas (`torch.Tensor`):
104
104
  the betas that the scheduler is being initialized with.
105
105
 
106
106
  Returns:
107
- `torch.FloatTensor`: rescaled betas with zero terminal SNR
107
+ `torch.Tensor`: rescaled betas with zero terminal SNR
108
108
  """
109
109
  # Convert betas to alphas_bar_sqrt
110
110
  alphas = 1.0 - betas
@@ -132,8 +132,8 @@ def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
132
132
 
133
133
  class TCDScheduler(SchedulerMixin, ConfigMixin):
134
134
  """
135
- `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency Distillation`,
136
- extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
135
+ `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency
136
+ Distillation`, extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
137
137
 
138
138
  This code is based on the official repo of TCD(https://github.com/jabir-zheng/TCD).
139
139
 
@@ -225,7 +225,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
225
225
  # Glide cosine schedule
226
226
  self.betas = betas_for_alpha_bar(num_train_timesteps)
227
227
  else:
228
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
228
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
229
229
 
230
230
  # Rescale for zero SNR
231
231
  if rescale_betas_zero_snr:
@@ -297,18 +297,19 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
297
297
  """
298
298
  self._begin_index = begin_index
299
299
 
300
- def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
300
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
301
301
  """
302
302
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
303
303
  current timestep.
304
304
 
305
305
  Args:
306
- sample (`torch.FloatTensor`):
306
+ sample (`torch.Tensor`):
307
307
  The input sample.
308
308
  timestep (`int`, *optional*):
309
309
  The current timestep in the diffusion chain.
310
+
310
311
  Returns:
311
- `torch.FloatTensor`:
312
+ `torch.Tensor`:
312
313
  A scaled input sample.
313
314
  """
314
315
  return sample
@@ -325,7 +326,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
325
326
  return variance
326
327
 
327
328
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
328
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
329
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
329
330
  """
330
331
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
331
332
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -364,7 +365,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
364
365
  device: Union[str, torch.device] = None,
365
366
  original_inference_steps: Optional[int] = None,
366
367
  timesteps: Optional[List[int]] = None,
367
- strength: int = 1.0,
368
+ strength: float = 1.0,
368
369
  ):
369
370
  """
370
371
  Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -384,6 +385,8 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
384
385
  Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
385
386
  timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
386
387
  schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
388
+ strength (`float`, *optional*, defaults to 1.0):
389
+ Used to determine the number of timesteps used for inference when using img2img, inpaint, etc.
387
390
  """
388
391
  # 0. Check inputs
389
392
  if num_inference_steps is None and timesteps is None:
@@ -521,9 +524,9 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
521
524
 
522
525
  def step(
523
526
  self,
524
- model_output: torch.FloatTensor,
527
+ model_output: torch.Tensor,
525
528
  timestep: int,
526
- sample: torch.FloatTensor,
529
+ sample: torch.Tensor,
527
530
  eta: float = 0.3,
528
531
  generator: Optional[torch.Generator] = None,
529
532
  return_dict: bool = True,
@@ -533,15 +536,16 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
533
536
  process from the learned model outputs (most often the predicted noise).
534
537
 
535
538
  Args:
536
- model_output (`torch.FloatTensor`):
539
+ model_output (`torch.Tensor`):
537
540
  The direct output from learned diffusion model.
538
541
  timestep (`int`):
539
542
  The current discrete timestep in the diffusion chain.
540
- sample (`torch.FloatTensor`):
543
+ sample (`torch.Tensor`):
541
544
  A current instance of a sample created by the diffusion process.
542
545
  eta (`float`):
543
- A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every step.
544
- When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
546
+ A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every
547
+ step. When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic
548
+ sampling.
545
549
  generator (`torch.Generator`, *optional*):
546
550
  A random number generator.
547
551
  return_dict (`bool`, *optional*, defaults to `True`):
@@ -624,14 +628,18 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
624
628
 
625
629
  return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
626
630
 
631
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
627
632
  def add_noise(
628
633
  self,
629
- original_samples: torch.FloatTensor,
630
- noise: torch.FloatTensor,
634
+ original_samples: torch.Tensor,
635
+ noise: torch.Tensor,
631
636
  timesteps: torch.IntTensor,
632
- ) -> torch.FloatTensor:
637
+ ) -> torch.Tensor:
633
638
  # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
634
- alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
639
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
640
+ # for the subsequent add_noise calls
641
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
642
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
635
643
  timesteps = timesteps.to(original_samples.device)
636
644
 
637
645
  sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -647,11 +655,11 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
647
655
  noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
648
656
  return noisy_samples
649
657
 
650
- def get_velocity(
651
- self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
652
- ) -> torch.FloatTensor:
658
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
659
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
653
660
  # Make sure alphas_cumprod and timestep have same device and dtype as sample
654
- alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
661
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
662
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
655
663
  timesteps = timesteps.to(sample.device)
656
664
 
657
665
  sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
@@ -670,6 +678,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
670
678
  def __len__(self):
671
679
  return self.config.num_train_timesteps
672
680
 
681
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
673
682
  def previous_timestep(self, timestep):
674
683
  if self.custom_timesteps:
675
684
  index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
@@ -32,16 +32,16 @@ class UnCLIPSchedulerOutput(BaseOutput):
32
32
  Output class for the scheduler's `step` function output.
33
33
 
34
34
  Args:
35
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
35
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
36
36
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
37
37
  denoising loop.
38
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
38
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
39
39
  The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
40
40
  `pred_original_sample` can be used to preview progress or for guidance.
41
41
  """
42
42
 
43
- prev_sample: torch.FloatTensor
44
- pred_original_sample: Optional[torch.FloatTensor] = None
43
+ prev_sample: torch.Tensor
44
+ pred_original_sample: Optional[torch.Tensor] = None
45
45
 
46
46
 
47
47
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -79,7 +79,7 @@ def betas_for_alpha_bar(
79
79
  return math.exp(t * -12.0)
80
80
 
81
81
  else:
82
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
82
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
83
83
 
84
84
  betas = []
85
85
  for i in range(num_diffusion_timesteps):
@@ -146,17 +146,17 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
146
146
 
147
147
  self.variance_type = variance_type
148
148
 
149
- def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
149
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
150
150
  """
151
151
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
152
152
  current timestep.
153
153
 
154
154
  Args:
155
- sample (`torch.FloatTensor`): input sample
155
+ sample (`torch.Tensor`): input sample
156
156
  timestep (`int`, optional): current timestep
157
157
 
158
158
  Returns:
159
- `torch.FloatTensor`: scaled input sample
159
+ `torch.Tensor`: scaled input sample
160
160
  """
161
161
  return sample
162
162
 
@@ -215,9 +215,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
215
215
 
216
216
  def step(
217
217
  self,
218
- model_output: torch.FloatTensor,
218
+ model_output: torch.Tensor,
219
219
  timestep: int,
220
- sample: torch.FloatTensor,
220
+ sample: torch.Tensor,
221
221
  prev_timestep: Optional[int] = None,
222
222
  generator=None,
223
223
  return_dict: bool = True,
@@ -227,9 +227,9 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
227
227
  process from the learned model outputs (most often the predicted noise).
228
228
 
229
229
  Args:
230
- model_output (`torch.FloatTensor`): direct output from learned diffusion model.
230
+ model_output (`torch.Tensor`): direct output from learned diffusion model.
231
231
  timestep (`int`): current discrete timestep in the diffusion chain.
232
- sample (`torch.FloatTensor`):
232
+ sample (`torch.Tensor`):
233
233
  current instance of sample being created by diffusion process.
234
234
  prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
235
235
  Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
@@ -327,10 +327,10 @@ class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
327
327
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
328
328
  def add_noise(
329
329
  self,
330
- original_samples: torch.FloatTensor,
331
- noise: torch.FloatTensor,
330
+ original_samples: torch.Tensor,
331
+ noise: torch.Tensor,
332
332
  timesteps: torch.IntTensor,
333
- ) -> torch.FloatTensor:
333
+ ) -> torch.Tensor:
334
334
  # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
335
335
  # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
336
336
  # for the subsequent add_noise calls