diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/schedulers/scheduling_heun_discrete.py
@@ -57,7 +57,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)

     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

     betas = []
     for i in range(num_diffusion_timesteps):
@@ -135,7 +135,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         elif beta_schedule == "exp":
             self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp")
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -174,7 +174,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index

@@ -198,21 +198,21 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):

     def scale_model_input(
         self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
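The `torch.FloatTensor` → `torch.Tensor` changes that recur throughout this release are annotation-only; behavior is unchanged, and the broader hint now matches the non-float32 tensors (for example fp16 latents) that pipelines routinely pass through these methods. A minimal sketch under that assumption (diffusers 0.28.0; the shape is an arbitrary placeholder):

import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

# fp16 latents with a placeholder shape; the sigma scale inside
# scale_model_input is a 0-dim tensor, so the input dtype is kept in practice.
latents = torch.randn(1, 4, 8, 8, dtype=torch.float16)
scaled = scheduler.scale_model_input(latents, scheduler.timesteps[0])
print(scaled.dtype)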
@@ -224,9 +224,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):

     def set_timesteps(
         self,
-        num_inference_steps: int,
+        num_inference_steps: Optional[int] = None,
         device: Union[str, torch.device] = None,
         num_train_timesteps: Optional[int] = None,
+        timesteps: Optional[List[int]] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -236,30 +237,47 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            num_train_timesteps (`int`, *optional*):
+                The number of diffusion steps used when training the model. If `None`, the default
+                `num_train_timesteps` attribute is used.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be
+                generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps`
+                must be `None`, and `timestep_spacing` attribute will be ignored.
         """
+        if num_inference_steps is None and timesteps is None:
+            raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
+        if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
+        if timesteps is not None and self.config.use_karras_sigmas:
+            raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
+
+        num_inference_steps = num_inference_steps or len(timesteps)
         self.num_inference_steps = num_inference_steps
-
         num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

-        # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-        if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
-        elif self.config.timestep_spacing == "leading":
-            step_ratio = num_train_timesteps // self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
-            timesteps += self.config.steps_offset
-        elif self.config.timestep_spacing == "trailing":
-            step_ratio = num_train_timesteps / self.num_inference_steps
-            # creates integer timesteps by multiplying by ratio
-            # casting to int to avoid issues when num_inference_step is power of 3
-            timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
-            timesteps -= 1
+        if timesteps is not None:
+            timesteps = np.array(timesteps, dtype=np.float32)
         else:
-            raise ValueError(
-                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
-            )
+            # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+            if self.config.timestep_spacing == "linspace":
+                timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
+            elif self.config.timestep_spacing == "leading":
+                step_ratio = num_train_timesteps // self.num_inference_steps
+                # creates integer timesteps by multiplying by ratio
+                # casting to int to avoid issues when num_inference_step is power of 3
+                timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                timesteps += self.config.steps_offset
+            elif self.config.timestep_spacing == "trailing":
+                step_ratio = num_train_timesteps / self.num_inference_steps
+                # creates integer timesteps by multiplying by ratio
+                # casting to int to avoid issues when num_inference_step is power of 3
+                timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                timesteps -= 1
+            else:
+                raise ValueError(
+                    f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                )

         sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         log_sigmas = np.log(sigmas)
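With this change, `set_timesteps` accepts either a step count or an explicit list of timesteps, but not both, and custom timesteps are rejected when `config.use_karras_sigmas` is enabled. A minimal usage sketch (assuming diffusers 0.28.0; the timestep values are arbitrary placeholders):

from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()

# Arbitrarily spaced, descending training-timestep indices replace
# `num_inference_steps`; passing both (or neither) raises the ValueError above.
scheduler.set_timesteps(timesteps=[999, 749, 499, 249, 0], device="cpu")
print(scheduler.timesteps)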
@@ -311,7 +329,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

         # Hack to make sure that other schedulers which copy this function don't break
@@ -351,9 +369,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        timestep: Union[float, torch.FloatTensor],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: Union[torch.Tensor, np.ndarray],
+        timestep: Union[float, torch.Tensor],
+        sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -361,11 +379,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -451,10 +469,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -468,7 +486,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
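The new `elif` separates the two pipeline call sites: img2img-style pipelines call `add_noise` once before any denoising step (only `begin_index` is set), while inpainting-style pipelines call it again mid-loop (by then `step_index` has advanced). A rough sketch of the img2img case (diffusers 0.28.0 assumed; tensors are placeholders):

import torch
from diffusers import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 8, 8)  # placeholder image latents
noise = torch.randn_like(sample)

# Declare the starting step before the loop, as img2img pipelines do;
# add_noise then resolves sigma from begin_index instead of searching.
scheduler.set_begin_index(3)
noisy = scheduler.add_noise(sample, noise, scheduler.timesteps[3:4])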
diffusers/schedulers/scheduling_ipndm.py
@@ -61,7 +61,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index

@@ -137,9 +137,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -147,11 +147,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         the linear multistep method. It performs one forward pass multiple times to approximate the solution.

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`int`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -193,17 +193,17 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):

         return SchedulerOutput(prev_sample=prev_sample)

-    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -58,7 +58,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)

     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

     betas = []
     for i in range(num_diffusion_timesteps):
@@ -129,7 +129,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -151,7 +151,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index

@@ -175,21 +175,21 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

     def scale_model_input(
         self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
@@ -321,7 +321,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

         # Hack to make sure that other schedulers which copy this function don't break
@@ -376,9 +376,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        timestep: Union[float, torch.FloatTensor],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: Union[torch.Tensor, np.ndarray],
+        timestep: Union[float, torch.Tensor],
+        sample: Union[torch.Tensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
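Unlike the deterministic Heun sampler above, the ancestral variant draws fresh noise inside `step`, hence the `generator` argument in its signature. A small sketch of seeding it for reproducible sampling (diffusers 0.28.0 assumed; tensors are placeholders):

import torch
from diffusers import KDPM2AncestralDiscreteScheduler

scheduler = KDPM2AncestralDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 8, 8)         # placeholder latents
model_output = torch.randn_like(sample)  # placeholder model prediction

# A seeded generator makes the ancestral noise draw in step() repeatable.
generator = torch.Generator().manual_seed(0)
out = scheduler.step(model_output, scheduler.timesteps[0], sample, generator=generator)
prev_sample = out.prev_sample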
@@ -387,11 +387,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -477,10 +477,10 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -494,7 +494,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -57,7 +57,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)

     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

     betas = []
     for i in range(num_diffusion_timesteps):
@@ -128,7 +128,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -151,7 +151,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index

@@ -175,21 +175,21 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):

     def scale_model_input(
         self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-    ) -> torch.FloatTensor:
+        sample: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         if self.step_index is None:
@@ -334,7 +334,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         return t

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
         """Constructs the noise schedule of Karras et al. (2022)."""

         # Hack to make sure that other schedulers which copy this function don't break
@@ -361,9 +361,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: Union[torch.FloatTensor, np.ndarray],
-        timestep: Union[float, torch.FloatTensor],
-        sample: Union[torch.FloatTensor, np.ndarray],
+        model_output: Union[torch.Tensor, np.ndarray],
+        timestep: Union[float, torch.Tensor],
+        sample: Union[torch.Tensor, np.ndarray],
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -371,11 +371,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             return_dict (`bool`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
@@ -452,10 +452,10 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -469,7 +469,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
diffusers/schedulers/scheduling_karras_ve_flax.py
@@ -176,10 +176,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):

         Args:
             state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+            sample_hat (`torch.Tensor` or `np.ndarray`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
@@ -213,12 +213,12 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):

         Args:
             state (`KarrasVeSchedulerState`): the `FlaxKarrasVeScheduler` state data class.
-            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            model_output (`torch.Tensor` or `np.ndarray`): direct output from learned diffusion model.
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
-            sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
-            derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+            sample_hat (`torch.Tensor` or `np.ndarray`): TODO
+            sample_prev (`torch.Tensor` or `np.ndarray`): TODO
+            derivative (`torch.Tensor` or `np.ndarray`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxKarrasVeOutput class

         Returns:
diffusers/schedulers/scheduling_lcm.py
@@ -37,16 +37,16 @@ class LCMSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.

     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
             denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
             The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
             `pred_original_sample` can be used to preview progress or for guidance.
     """

-    prev_sample: torch.FloatTensor
-    denoised: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    denoised: Optional[torch.Tensor] = None


 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -84,7 +84,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)

     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

     betas = []
     for i in range(num_diffusion_timesteps):
@@ -95,17 +95,17 @@ def betas_for_alpha_bar(


 # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
-def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
+def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
     """
     Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)


     Args:
-        betas (`torch.FloatTensor`):
+        betas (`torch.Tensor`):
             the betas that the scheduler is being initialized with.

     Returns:
-        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+        `torch.Tensor`: rescaled betas with zero terminal SNR
     """
     # Convert betas to alphas_bar_sqrt
     alphas = 1.0 - betas
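`rescale_zero_terminal_snr` implements Algorithm 1 of arXiv:2305.08891: it shifts and rescales sqrt(alphas_cumprod) so the final timestep reaches zero SNR (pure noise). A quick numeric check, sketched against diffusers 0.28.0 (the linear beta schedule is an arbitrary placeholder):

import torch
from diffusers.schedulers.scheduling_lcm import rescale_zero_terminal_snr

betas = torch.linspace(0.0001, 0.02, 1000)  # placeholder linear schedule
rescaled = rescale_zero_terminal_snr(betas)

# After rescaling, the terminal cumulative alpha is ~0, i.e. zero SNR.
alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
print(alphas_cumprod[-1].item())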
@@ -224,7 +224,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
             # Glide cosine schedule
             self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

         # Rescale for zero SNR
         if rescale_betas_zero_snr:
@@ -296,24 +296,24 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index

-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.

         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
-    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+    def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         """
         "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
         prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -497,9 +497,9 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):

     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: int,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[LCMSchedulerOutput, Tuple]:
@@ -508,11 +508,11 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).

         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             timestep (`float`):
                 The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -594,10 +594,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
         timesteps: torch.IntTensor,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         # for the subsequent add_noise calls
@@ -619,9 +619,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
-    def get_velocity(
-        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
-    ) -> torch.FloatTensor:
+    def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as sample
         self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
         alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
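For context, `get_velocity` computes the v-prediction training target of Salimans & Ho (2022), v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample; only its signature is reflowed here. A minimal sketch (diffusers 0.28.0 assumed; tensors are placeholders):

import torch
from diffusers import LCMScheduler

scheduler = LCMScheduler()

sample = torch.randn(2, 4, 8, 8)     # placeholder clean latents x_0
noise = torch.randn_like(sample)     # placeholder epsilon
timesteps = torch.tensor([10, 500])  # one timestep per batch element

v_target = scheduler.get_velocity(sample, noise, timesteps)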