diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_ddim.py:

@@ -35,16 +35,16 @@ class DDIMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -98,11 +98,11 @@ def rescale_zero_terminal_snr(betas):
 
 
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.
 
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas
@@ -211,7 +211,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -233,19 +233,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -261,7 +261,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -341,13 +341,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
         generator=None,
-        variance_noise: Optional[torch.FloatTensor] = None,
+        variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
@@ -355,11 +355,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             eta (`float`):
                 The weight of noise for added noise in diffusion step.
@@ -370,7 +370,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
                 `use_clipped_model_output` has no effect.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            variance_noise (`torch.FloatTensor`):
+            variance_noise (`torch.Tensor`):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -470,10 +470,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -495,9 +495,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
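
The dominant change in this file (and repeated across the schedulers below) is the relaxation of type hints from `torch.FloatTensor` to `torch.Tensor`. `torch.FloatTensor` denotes fp32 specifically, so the old hints were misleading whenever a pipeline ran in fp16 or bf16; the annotations are documentation only, and no runtime behavior changes. A minimal sketch of the unchanged calling convention, with random tensors standing in for a real UNet and latents:

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()  # default config: 1000 training timesteps
scheduler.set_timesteps(50)  # 50 inference steps

sample = torch.randn(1, 3, 64, 64)  # stand-in for an initial noise latent
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for a UNet noise prediction
    # step() returns a DDIMSchedulerOutput; prev_sample feeds the next iteration
    sample = scheduler.step(model_output, t, sample).prev_sample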
diffusers/schedulers/scheduling_ddim_flax.py:

@@ -85,7 +85,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         trained_betas (`jnp.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
         clip_sample (`bool`, default `True`):
-            option to clip predicted sample between for numerical stability. The clip range is determined by `clip_sample_range`.
+            option to clip predicted sample between for numerical stability. The clip range is determined by
+            `clip_sample_range`.
         clip_sample_range (`float`, default `1.0`):
             the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
         set_alpha_to_one (`bool`, default `True`):
diffusers/schedulers/scheduling_ddim_inverse.py:

@@ -33,16 +33,16 @@ class DDIMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -80,7 +80,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -97,11 +97,11 @@ def rescale_zero_terminal_snr(betas):
 
 
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.
 
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas
@@ -207,7 +207,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -231,19 +231,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))
 
     # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -288,9 +288,9 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
@@ -298,11 +298,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             eta (`float`):
                 The weight of noise for added noise in diffusion step.
@@ -311,7 +311,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
                 because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
                 clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
                 `use_clipped_model_output` has no effect.
-            variance_noise (`torch.FloatTensor`):
+            variance_noise (`torch.Tensor`):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`CycleDiffusion`].
             return_dict (`bool`, *optional*, defaults to `True`):
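
The inverse scheduler's hunks mirror the same annotation change; note in the `@@ -231` context that its timesteps ascend (no `[::-1]` reversal), since inversion walks the chain from a clean latent toward noise. A rough usage sketch, with a random tensor standing in for the model's prediction:

import torch
from diffusers import DDIMInverseScheduler

inverse_scheduler = DDIMInverseScheduler()
inverse_scheduler.set_timesteps(50)  # timesteps ascend, roughly 0, 20, ..., 980

latent = torch.randn(1, 4, 64, 64)  # e.g. a VAE-encoded image latent
for t in inverse_scheduler.timesteps:
    noise_pred = torch.randn_like(latent)  # stand-in for the UNet's prediction
    latent = inverse_scheduler.step(noise_pred, t, latent).prev_sample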
diffusers/schedulers/scheduling_ddim_parallel.py:

@@ -35,16 +35,16 @@ class DDIMParallelSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -82,7 +82,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -99,11 +99,11 @@ def rescale_zero_terminal_snr(betas):
 
 
    Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.
 
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas
@@ -218,7 +218,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -241,19 +241,19 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
 
     # Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -283,7 +283,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -364,13 +364,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
         generator=None,
-        variance_noise: Optional[torch.FloatTensor] = None,
+        variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[DDIMParallelSchedulerOutput, Tuple]:
         """
@@ -378,9 +378,9 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            model_output (`torch.Tensor`): direct output from learned diffusion model.
             timestep (`int`): current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
@@ -388,7 +388,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
                 `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
                 coincide with the one provided as input and `use_clipped_model_output` will have not effect.
             generator: random number generator.
-            variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+            variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
                 can directly provide the noise for the variance itself. This is useful for methods such as
                 CycleDiffusion. (https://arxiv.org/abs/2210.05559)
             return_dict (`bool`): option for returning tuple rather than DDIMParallelSchedulerOutput class
@@ -486,12 +486,12 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
 
     def batch_step_no_noise(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timesteps: List[int],
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """
         Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
         Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
@@ -501,10 +501,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            model_output (`torch.Tensor`): direct output from learned diffusion model.
             timesteps (`List[int]`):
                 current discrete timesteps in the diffusion chain. This is now a list of integers.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
@@ -513,7 +513,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
                 coincide with the one provided as input and `use_clipped_model_output` will have not effect.
 
         Returns:
-            `torch.FloatTensor`: sample tensor at previous timestep.
+            `torch.Tensor`: sample tensor at previous timestep.
 
         """
         if self.num_inference_steps is None:
@@ -595,10 +595,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -620,9 +620,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
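
Beyond the shared annotation changes, this file's `batch_step_no_noise` is the parallel-sampling entry point: it reverses several timesteps in one call and deliberately omits the eta-noise term, which a ParaDiGMS-style pipeline re-adds itself. A rough sketch of the call shape implied by the signature above; note that, despite the `List[int]` hint, the ParaDiGMS pipeline passes a tensor window of timesteps, and the sketch does the same:

import torch
from diffusers import DDIMParallelScheduler

scheduler = DDIMParallelScheduler()
scheduler.set_timesteps(50)

window = scheduler.timesteps[:4]           # 4 consecutive timesteps as a tensor
samples = torch.randn(4, 3, 64, 64)        # one running sample per timestep
model_outputs = torch.randn_like(samples)  # stand-in UNet predictions
# Returns the samples at the previous timesteps, with no eta noise added
prev = scheduler.batch_step_no_noise(model_outputs, window, samples, eta=0.0)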
diffusers/schedulers/scheduling_ddpm.py:

@@ -33,16 +33,16 @@ class DDPMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 def betas_for_alpha_bar(
@@ -79,7 +79,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -96,11 +96,11 @@ def rescale_zero_terminal_snr(betas):
 
 
     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.
 
     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas
@@ -211,7 +211,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             betas = torch.linspace(-6, 6, num_train_timesteps)
             self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
 
         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -231,19 +231,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         self.variance_type = variance_type
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -363,7 +363,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         return variance
 
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -398,9 +398,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator=None,
         return_dict: bool = True,
     ) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -409,11 +409,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -498,10 +498,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -522,9 +522,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
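
The `add_noise` and `get_velocity` signatures above are the training-side helpers: `add_noise` builds the noisy input `x_t` from a clean sample, and `get_velocity` builds the regression target when a model is trained with `prediction_type="v_prediction"`. A short sketch of how they are typically paired in a training step, with random tensors standing in for a real image batch:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
clean_images = torch.randn(8, 3, 64, 64)  # stand-in for a batch of x_0
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (8,))

noisy_images = scheduler.add_noise(clean_images, noise, timesteps)  # x_t, the model input
v_target = scheduler.get_velocity(clean_images, noise, timesteps)   # target for v-prediction loss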
diffusers/schedulers/scheduling_ddpm_flax.py:

@@ -222,9 +222,13 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
         t = timestep
 
         if key is None:
-            key = jax.random.PRNGKey(0)
+            key = jax.random.key(0)
 
-        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
+        if (
+            len(model_output.shape) > 1
+            and model_output.shape[1] == sample.shape[1] * 2
+            and self.config.variance_type in ["learned", "learned_range"]
+        ):
             model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
         else:
             predicted_variance = None
@@ -264,7 +268,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
 
         # 6. Add noise
         def random_variance():
-            split_key = jax.random.split(key, num=1)
+            split_key = jax.random.split(key, num=1)[0]
             noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
             return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise
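
Both Flax hunks are PRNG correctness fixes: `jax.random.key` is the newer typed-key constructor, and `jax.random.split(key, num=1)` returns an array of one key rather than a single key, so the previous code handed `jax.random.normal` a batch of keys instead of a key. A small sketch of the corrected pattern:

import jax

key = jax.random.key(0)                 # typed PRNG key (replaces jax.random.PRNGKey(0))
subkeys = jax.random.split(key, num=1)  # an array holding 1 key, not a key itself
subkey = subkeys[0]                     # index before drawing samples
noise = jax.random.normal(subkey, shape=(4, 4))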