diffusers-0.24.0-py3-none-any.whl → diffusers-0.25.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (174)
  1. diffusers/__init__.py +11 -1
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +12 -8
  4. diffusers/dependency_versions_table.py +2 -1
  5. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  6. diffusers/image_processor.py +286 -46
  7. diffusers/loaders/ip_adapter.py +11 -9
  8. diffusers/loaders/lora.py +198 -60
  9. diffusers/loaders/single_file.py +24 -18
  10. diffusers/loaders/textual_inversion.py +10 -14
  11. diffusers/loaders/unet.py +130 -37
  12. diffusers/models/__init__.py +18 -12
  13. diffusers/models/activations.py +9 -6
  14. diffusers/models/attention.py +137 -16
  15. diffusers/models/attention_processor.py +133 -46
  16. diffusers/models/autoencoders/__init__.py +5 -0
  17. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +4 -4
  18. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +45 -6
  19. diffusers/models/{autoencoder_kl_temporal_decoder.py → autoencoders/autoencoder_kl_temporal_decoder.py} +8 -8
  20. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +4 -4
  21. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +14 -14
  22. diffusers/models/{vae.py → autoencoders/vae.py} +9 -5
  23. diffusers/models/downsampling.py +338 -0
  24. diffusers/models/embeddings.py +112 -29
  25. diffusers/models/modeling_flax_utils.py +12 -7
  26. diffusers/models/modeling_utils.py +10 -10
  27. diffusers/models/normalization.py +108 -2
  28. diffusers/models/resnet.py +15 -699
  29. diffusers/models/transformer_2d.py +2 -2
  30. diffusers/models/unet_2d_condition.py +37 -0
  31. diffusers/models/{unet_kandi3.py → unet_kandinsky3.py} +105 -159
  32. diffusers/models/upsampling.py +454 -0
  33. diffusers/models/uvit_2d.py +471 -0
  34. diffusers/models/vq_model.py +9 -2
  35. diffusers/pipelines/__init__.py +81 -73
  36. diffusers/pipelines/amused/__init__.py +62 -0
  37. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  38. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  39. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +38 -10
  41. diffusers/pipelines/auto_pipeline.py +17 -13
  42. diffusers/pipelines/controlnet/pipeline_controlnet.py +27 -10
  43. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +47 -5
  44. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +25 -8
  45. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +4 -6
  46. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +26 -10
  47. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +4 -3
  48. diffusers/pipelines/deprecated/__init__.py +153 -0
  49. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  50. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +91 -18
  51. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +91 -18
  52. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  53. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  54. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  55. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  56. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  57. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  58. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  59. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  60. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  61. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  62. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  63. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +4 -4
  64. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  65. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  66. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  67. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  68. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +7 -7
  69. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  70. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +16 -11
  71. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +6 -6
  72. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +11 -11
  73. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +16 -11
  74. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +10 -10
  75. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +13 -13
  76. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  77. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  78. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  79. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +54 -11
  80. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  81. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +6 -6
  82. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +6 -6
  83. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +6 -6
  84. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  85. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  86. diffusers/pipelines/kandinsky3/__init__.py +4 -4
  87. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  88. diffusers/pipelines/kandinsky3/{kandinsky3_pipeline.py → pipeline_kandinsky3.py} +172 -35
  89. diffusers/pipelines/kandinsky3/{kandinsky3img2img_pipeline.py → pipeline_kandinsky3_img2img.py} +228 -34
  90. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +46 -5
  91. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +47 -6
  92. diffusers/pipelines/onnx_utils.py +8 -5
  93. diffusers/pipelines/pipeline_flax_utils.py +7 -6
  94. diffusers/pipelines/pipeline_utils.py +30 -29
  95. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +51 -2
  96. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  97. diffusers/pipelines/stable_diffusion/__init__.py +1 -72
  98. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +67 -75
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +92 -8
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -8
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +138 -10
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +57 -7
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +6 -0
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  107. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  108. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +5 -2
  109. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  110. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +2 -3
  111. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  112. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +2 -2
  113. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +3 -3
  114. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  115. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +6 -1
  116. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  117. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +50 -7
  118. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  119. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +56 -8
  120. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  121. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  122. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +67 -10
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +97 -15
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +97 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +7 -5
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +12 -9
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +6 -0
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -0
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +5 -0
  131. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +331 -9
  132. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +468 -9
  133. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  134. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  135. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  136. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +4 -0
  137. diffusers/schedulers/__init__.py +2 -0
  138. diffusers/schedulers/scheduling_amused.py +162 -0
  139. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  140. diffusers/schedulers/scheduling_ddim_inverse.py +1 -4
  141. diffusers/schedulers/scheduling_ddpm.py +46 -0
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +46 -0
  143. diffusers/schedulers/scheduling_deis_multistep.py +13 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +13 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +13 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -0
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -1
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -0
  149. diffusers/schedulers/scheduling_euler_discrete.py +62 -3
  150. diffusers/schedulers/scheduling_heun_discrete.py +2 -0
  151. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -0
  152. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -0
  153. diffusers/schedulers/scheduling_lms_discrete.py +2 -0
  154. diffusers/schedulers/scheduling_unipc_multistep.py +13 -1
  155. diffusers/schedulers/scheduling_utils.py +3 -1
  156. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  157. diffusers/training_utils.py +1 -1
  158. diffusers/utils/__init__.py +0 -2
  159. diffusers/utils/constants.py +2 -5
  160. diffusers/utils/dummy_pt_objects.py +30 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  162. diffusers/utils/dynamic_modules_utils.py +14 -18
  163. diffusers/utils/hub_utils.py +24 -36
  164. diffusers/utils/logging.py +1 -1
  165. diffusers/utils/state_dict_utils.py +8 -0
  166. diffusers/utils/testing_utils.py +199 -1
  167. diffusers/utils/torch_utils.py +3 -3
  168. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/METADATA +54 -53
  169. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/RECORD +174 -155
  170. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  172. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  173. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.24.0.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,162 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ from ..configuration_utils import ConfigMixin, register_to_config
8
+ from ..utils import BaseOutput
9
+ from .scheduling_utils import SchedulerMixin
10
+
11
+
12
def gumbel_noise(t, generator=None):
    """
    Draw Gumbel noise with the same shape and dtype as `t`.

    Uniform samples in `[0, 1)` are drawn on the generator's device when a generator is supplied (otherwise on
    `t`'s device), moved to `t`'s device and mapped through `-log(-log(u))`.

    Args:
        t (`torch.Tensor`):
            Tensor whose shape, dtype and device the returned noise matches.
        generator (`torch.Generator`, *optional*):
            Random number generator used for the uniform draw.

    Returns:
        `torch.Tensor`: The noise tensor, located on `t.device`.
    """
    target_device = t.device
    sampling_device = target_device if generator is None else generator.device

    # Sample where the generator lives, then bring the result back next to `t`.
    uniform = torch.zeros_like(t, device=sampling_device)
    uniform.uniform_(0, 1, generator=generator)
    uniform = uniform.to(target_device)

    # Both logarithm arguments are clamped from below so that log(0) is never evaluated.
    neg_log_uniform = -torch.log(uniform.clamp(1e-20))
    return -torch.log(neg_log_uniform.clamp(1e-20))
16
+
17
+
18
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    """
    Select the lowest-confidence positions of each row after perturbing the log-probabilities with Gumbel noise.

    Args:
        mask_len (`torch.Tensor`):
            Index into the ascending-sorted confidences that picks the cut-off value; it is cast to long and
            gathered along dimension 1. Presumably of shape `(batch_size, 1)` -- confirm against the caller.
        probs (`torch.Tensor`):
            Per-position probabilities; clamped from below at 1e-20 before the logarithm is taken.
        temperature (`float` or `torch.Tensor`, defaults to 1.0):
            Scale applied to the Gumbel noise before it is added to the log-probabilities.
        generator (`torch.Generator`, *optional*):
            Random number generator forwarded to `gumbel_noise`.

    Returns:
        `torch.Tensor`: Boolean tensor that is `True` where the perturbed confidence is strictly below the cut-off.
    """
    log_probs = torch.log(probs.clamp(1e-20))
    confidence = log_probs + temperature * gumbel_noise(probs, generator=generator)

    # Sort ascending so that the entry at index `mask_len` is the confidence threshold.
    ascending = torch.sort(confidence, dim=-1).values
    threshold = torch.gather(ascending, 1, mask_len.long())

    return confidence < threshold
24
+
25
+
26
@dataclass
class AmusedSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.LongTensor` with the same shape as the `sample` passed to `step`):
            Computed sample `(x_{t-1})` of previous timestep, i.e. the sampled token ids with the least confident
            positions set back to the mask token id. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.LongTensor` with the same shape as the `sample` passed to `step`, *optional*):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep: token
            ids sampled for every masked position, with the already known positions kept unchanged.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.LongTensor
    pred_original_sample: Optional[torch.LongTensor] = None
42
+
43
+
44
class AmusedScheduler(SchedulerMixin, ConfigMixin):
    """
    Scheduler for iterative masked-token sampling.

    At every step, token ids are sampled from the model's per-position categorical distribution for all masked
    positions, and the least confident predictions are then masked again according to a cosine or linear masking
    schedule.

    Args:
        mask_token_id (`int`):
            Token id that marks a position as masked (unknown).
        masking_schedule (`str`, defaults to `"cosine"`):
            Schedule that maps inference progress to the fraction of tokens that stays masked. Either `"cosine"`
            or `"linear"`; any other value raises a `ValueError` when the schedule is evaluated.
    """

    order = 1

    temperatures: torch.Tensor

    @register_to_config
    def __init__(
        self,
        mask_token_id: int,
        masking_schedule: str = "cosine",
    ):
        # Both are populated by `set_timesteps`.
        self.timesteps = None
        self.temperatures = None

    def set_timesteps(
        self,
        num_inference_steps: int,
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
        device: Union[str, torch.device] = None,
    ):
        """
        Build the descending timestep schedule and the per-step Gumbel-noise temperatures.

        Args:
            num_inference_steps (`int`):
                Number of sampling steps; timesteps run from `num_inference_steps - 1` down to `0`.
            temperature (`int`, `tuple` or `list`, defaults to `(2, 0)`):
                Either a `(start, end)` pair that is linearly interpolated over the steps, or a single start
                value that is linearly annealed to `0.01`.
            device (`str` or `torch.device`, *optional*):
                Device on which the timesteps and temperatures are created.
        """
        self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)

        if isinstance(temperature, (tuple, list)):
            start, end = temperature[0], temperature[1]
        else:
            start, end = temperature, 0.01
        self.temperatures = torch.linspace(start, end, num_inference_steps, device=device)

    def _scheduled_mask_ratio(self, step_idx):
        # Fraction of tokens kept masked at schedule position `step_idx`, per the configured masking schedule.
        ratio = (step_idx + 1) / len(self.timesteps)

        schedule = self.config.masking_schedule
        if schedule == "cosine":
            return torch.cos(ratio * math.pi / 2)
        if schedule == "linear":
            return 1 - ratio
        raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: torch.long,
        sample: torch.LongTensor,
        starting_mask_ratio: int = 1,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[AmusedSchedulerOutput, Tuple]:
        """
        Sample token ids for the masked positions and re-mask the least confident ones.

        Args:
            model_output (`torch.FloatTensor`):
                Model logits. Either 4-dimensional `(batch_size, codebook_size, height, width)` together with a
                3-dimensional `sample`, or already laid out with the codebook dimension last.
            timestep (`torch.long`):
                Current timestep. At timestep `0` nothing is masked again.
            sample (`torch.LongTensor`):
                Current token ids, where `mask_token_id` marks unknown positions.
            starting_mask_ratio (`int`, defaults to 1):
                Factor multiplied onto the scheduled mask ratio.
            generator (`torch.Generator`, *optional*):
                Random number generator for the token sampling and the Gumbel noise.
            return_dict (`bool`, defaults to `True`):
                Whether to return an `AmusedSchedulerOutput` instead of a plain tuple.

        Returns:
            `AmusedSchedulerOutput` or `tuple`: `(prev_sample, pred_original_sample)`.
        """
        two_dim_input = sample.ndim == 3 and model_output.ndim == 4

        if two_dim_input:
            # Flatten the spatial grid to a sequence and move the codebook dimension last.
            batch_size, codebook_size, height, width = model_output.shape
            num_positions = height * width
            sample = sample.reshape(batch_size, num_positions)
            model_output = model_output.reshape(batch_size, codebook_size, num_positions).permute(0, 2, 1)

        unknown_map = sample == self.config.mask_token_id
        probs = model_output.softmax(dim=-1)

        # Sample on the generator's device (handles a CPU generator), then move the result back.
        sampling_probs = probs if generator is None else probs.to(generator.device)
        if sampling_probs.device.type == "cpu" and sampling_probs.dtype != torch.float32:
            # multinomial is not implemented for cpu half precision
            sampling_probs = sampling_probs.float()
        flat_probs = sampling_probs.reshape(-1, probs.size(-1))
        drawn = torch.multinomial(flat_probs, 1, generator=generator).to(device=probs.device)

        pred_original_sample = drawn[:, 0].view(*probs.shape[:-1])
        # Positions that were already known keep their input token.
        pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)

        if timestep == 0:
            prev_sample = pred_original_sample
        else:
            seq_len = sample.shape[1]
            step_idx = (self.timesteps == timestep).nonzero()
            mask_ratio = starting_mask_ratio * self._scheduled_mask_ratio(step_idx)

            mask_len = (seq_len * mask_ratio).floor()
            # do not mask more than amount previously masked
            num_unknown = unknown_map.sum(dim=-1, keepdim=True)
            mask_len = torch.min(num_unknown - 1, mask_len)
            # mask at least one
            mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)

            # Probability the model assigned to each sampled token.
            selected_probs = torch.gather(probs, -1, pred_original_sample.unsqueeze(-1)).squeeze(-1)
            # Ignores the tokens given in the input by overwriting their confidence.
            max_confidence = torch.finfo(selected_probs.dtype).max
            selected_probs = torch.where(unknown_map, selected_probs, max_confidence)

            masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)

            # Masks tokens with lower confidence.
            prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)

        if two_dim_input:
            # Restore the spatial layout of the input.
            prev_sample = prev_sample.reshape(batch_size, height, width)
            pred_original_sample = pred_original_sample.reshape(batch_size, height, width)

        if return_dict:
            return AmusedSchedulerOutput(prev_sample, pred_original_sample)

        return (prev_sample, pred_original_sample)

    def add_noise(self, sample, timesteps, generator=None):
        """
        Randomly replace tokens of `sample` with the mask token, using the scheduled mask ratio of `timesteps`.

        Args:
            sample (`torch.Tensor`):
                Token ids to corrupt; the input is cloned and left unmodified.
            timesteps:
                Timestep whose position in `self.timesteps` determines the mask ratio.
            generator (`torch.Generator`, *optional*):
                Random number generator for the masking draw.

        Returns:
            `torch.Tensor`: Copy of `sample` with the selected positions set to `mask_token_id`.
        """
        step_idx = (self.timesteps == timesteps).nonzero()
        mask_ratio = self._scheduled_mask_ratio(step_idx)

        # Draw on the generator's device when one is given, then compare next to `sample`.
        draw_device = sample.device if generator is None else generator.device
        uniform = torch.rand(sample.shape, device=draw_device, generator=generator).to(sample.device)
        mask_indices = uniform < mask_ratio

        masked_sample = sample.clone()
        masked_sample[mask_indices] = self.config.mask_token_id

        return masked_sample
@@ -98,6 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
98
98
  self.custom_timesteps = False
99
99
  self.is_scale_input_called = False
100
100
  self._step_index = None
101
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
101
102
 
102
103
  def index_for_timestep(self, timestep, schedule_timesteps=None):
103
104
  if schedule_timesteps is None:
@@ -230,6 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
230
231
  self.timesteps = torch.from_numpy(timesteps).to(device=device)
231
232
 
232
233
  self._step_index = None
234
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
233
235
 
234
236
  # Modified _convert_to_karras implementation that takes in ramp as argument
235
237
  def _convert_to_karras(self, ramp):
@@ -293,9 +293,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
293
293
  model_output: torch.FloatTensor,
294
294
  timestep: int,
295
295
  sample: torch.FloatTensor,
296
- eta: float = 0.0,
297
- use_clipped_model_output: bool = False,
298
- variance_noise: Optional[torch.FloatTensor] = None,
299
296
  return_dict: bool = True,
300
297
  ) -> Union[DDIMSchedulerOutput, Tuple]:
301
298
  """
@@ -332,7 +329,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
332
329
  # 1. get previous step value (=t+1)
333
330
  prev_timestep = timestep
334
331
  timestep = min(
335
- timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1
332
+ timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
336
333
  )
337
334
 
338
335
  # 2. compute alphas, betas
@@ -89,6 +89,43 @@ def betas_for_alpha_bar(
89
89
  return torch.tensor(betas, dtype=torch.float32)
90
90
 
91
91
 
92
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
93
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`):
            The betas that the scheduler is being initialized with. The input tensor is not modified.

    Returns:
        `torch.FloatTensor`: Rescaled betas with zero terminal SNR.
    """
    # Work on sqrt(alpha_bar), where alpha_bar is the cumulative product of (1 - beta).
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Keep the original endpoints.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the last timestep becomes zero, then scale so the first timestep gets its old value back.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the square root, then undo the cumulative product (the first alpha equals the first alpha_bar).
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - step_alphas
127
+
128
+
92
129
  class DDPMScheduler(SchedulerMixin, ConfigMixin):
93
130
  """
94
131
  `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
@@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
131
168
  An offset added to the inference steps. You can use a combination of `offset=1` and
132
169
  `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
133
170
  Diffusion.
171
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
172
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
173
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
174
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
134
175
  """
135
176
 
136
177
  _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -153,6 +194,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
153
194
  sample_max_value: float = 1.0,
154
195
  timestep_spacing: str = "leading",
155
196
  steps_offset: int = 0,
197
+ rescale_betas_zero_snr: int = False,
156
198
  ):
157
199
  if trained_betas is not None:
158
200
  self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -171,6 +213,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
171
213
  else:
172
214
  raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
173
215
 
216
+ # Rescale for zero SNR
217
+ if rescale_betas_zero_snr:
218
+ self.betas = rescale_zero_terminal_snr(self.betas)
219
+
174
220
  self.alphas = 1.0 - self.betas
175
221
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
176
222
  self.one = torch.tensor(1.0)
@@ -91,6 +91,43 @@ def betas_for_alpha_bar(
91
91
  return torch.tensor(betas, dtype=torch.float32)
92
92
 
93
93
 
94
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
95
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`torch.FloatTensor`):
            The betas that the scheduler is being initialized with. The input tensor is not modified.

    Returns:
        `torch.FloatTensor`: Rescaled betas with zero terminal SNR.
    """
    # Work on sqrt(alpha_bar), where alpha_bar is the cumulative product of (1 - beta).
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Keep the original endpoints.
    first = sqrt_alpha_bar[0].clone()
    last = sqrt_alpha_bar[-1].clone()

    # Shift so the last timestep becomes zero, then scale so the first timestep gets its old value back.
    sqrt_alpha_bar = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the square root, then undo the cumulative product (the first alpha equals the first alpha_bar).
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = torch.cat([alpha_bar[0:1], alpha_bar[1:] / alpha_bar[:-1]])

    return 1 - step_alphas
129
+
130
+
94
131
  class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
95
132
  """
96
133
  Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
@@ -139,6 +176,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
139
176
  an offset added to the inference steps. You can use a combination of `offset=1` and
140
177
  `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
141
178
  stable diffusion.
179
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
180
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
181
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
182
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
142
183
  """
143
184
 
144
185
  _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -163,6 +204,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
163
204
  sample_max_value: float = 1.0,
164
205
  timestep_spacing: str = "leading",
165
206
  steps_offset: int = 0,
207
+ rescale_betas_zero_snr: int = False,
166
208
  ):
167
209
  if trained_betas is not None:
168
210
  self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -181,6 +223,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
181
223
  else:
182
224
  raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
183
225
 
226
+ # Rescale for zero SNR
227
+ if rescale_betas_zero_snr:
228
+ self.betas = rescale_zero_terminal_snr(self.betas)
229
+
184
230
  self.alphas = 1.0 - self.betas
185
231
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
186
232
  self.one = torch.tensor(1.0)
@@ -162,6 +162,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
162
162
  self.alpha_t = torch.sqrt(self.alphas_cumprod)
163
163
  self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
164
164
  self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
165
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
165
166
 
166
167
  # standard deviation of the initial noise distribution
167
168
  self.init_noise_sigma = 1.0
@@ -186,6 +187,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
186
187
  self.model_outputs = [None] * solver_order
187
188
  self.lower_order_nums = 0
188
189
  self._step_index = None
190
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
189
191
 
190
192
  @property
191
193
  def step_index(self):
@@ -253,6 +255,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
253
255
 
254
256
  # add an index counter for schedulers that allow duplicated timesteps
255
257
  self._step_index = None
258
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
256
259
 
257
260
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
258
261
  def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -733,7 +736,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
733
736
  schedule_timesteps = self.timesteps.to(original_samples.device)
734
737
  timesteps = timesteps.to(original_samples.device)
735
738
 
736
- step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
739
+ step_indices = []
740
+ for timestep in timesteps:
741
+ index_candidates = (schedule_timesteps == timestep).nonzero()
742
+ if len(index_candidates) == 0:
743
+ step_index = len(schedule_timesteps) - 1
744
+ elif len(index_candidates) > 1:
745
+ step_index = index_candidates[1].item()
746
+ else:
747
+ step_index = index_candidates[0].item()
748
+ step_indices.append(step_index)
737
749
 
738
750
  sigma = sigmas[step_indices].flatten()
739
751
  while len(sigma.shape) < len(original_samples.shape):
@@ -189,6 +189,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
189
189
  self.alpha_t = torch.sqrt(self.alphas_cumprod)
190
190
  self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
191
191
  self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
192
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
192
193
 
193
194
  # standard deviation of the initial noise distribution
194
195
  self.init_noise_sigma = 1.0
@@ -213,6 +214,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
213
214
  self.model_outputs = [None] * solver_order
214
215
  self.lower_order_nums = 0
215
216
  self._step_index = None
217
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
216
218
 
217
219
  @property
218
220
  def step_index(self):
@@ -289,6 +291,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
289
291
 
290
292
  # add an index counter for schedulers that allow duplicated timesteps
291
293
  self._step_index = None
294
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
292
295
 
293
296
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
294
297
  def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -895,7 +898,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
895
898
  schedule_timesteps = self.timesteps.to(original_samples.device)
896
899
  timesteps = timesteps.to(original_samples.device)
897
900
 
898
- step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
901
+ step_indices = []
902
+ for timestep in timesteps:
903
+ index_candidates = (schedule_timesteps == timestep).nonzero()
904
+ if len(index_candidates) == 0:
905
+ step_index = len(schedule_timesteps) - 1
906
+ elif len(index_candidates) > 1:
907
+ step_index = index_candidates[1].item()
908
+ else:
909
+ step_index = index_candidates[0].item()
910
+ step_indices.append(step_index)
899
911
 
900
912
  sigma = sigmas[step_indices].flatten()
901
913
  while len(sigma.shape) < len(original_samples.shape):
@@ -184,6 +184,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
184
184
  self.alpha_t = torch.sqrt(self.alphas_cumprod)
185
185
  self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
186
186
  self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
187
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
187
188
 
188
189
  # standard deviation of the initial noise distribution
189
190
  self.init_noise_sigma = 1.0
@@ -208,6 +209,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
208
209
  self.model_outputs = [None] * solver_order
209
210
  self.lower_order_nums = 0
210
211
  self._step_index = None
212
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
211
213
  self.use_karras_sigmas = use_karras_sigmas
212
214
 
213
215
  @property
@@ -288,6 +290,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
288
290
 
289
291
  # add an index counter for schedulers that allow duplicated timesteps
290
292
  self._step_index = None
293
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
291
294
 
292
295
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
293
296
  def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -890,7 +893,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
890
893
  schedule_timesteps = self.timesteps.to(original_samples.device)
891
894
  timesteps = timesteps.to(original_samples.device)
892
895
 
893
- step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
896
+ step_indices = []
897
+ for timestep in timesteps:
898
+ index_candidates = (schedule_timesteps == timestep).nonzero()
899
+ if len(index_candidates) == 0:
900
+ step_index = len(schedule_timesteps) - 1
901
+ elif len(index_candidates) > 1:
902
+ step_index = index_candidates[1].item()
903
+ else:
904
+ step_index = index_candidates[0].item()
905
+ step_indices.append(step_index)
894
906
 
895
907
  sigma = sigmas[step_indices].flatten()
896
908
  while len(sigma.shape) < len(original_samples.shape):
@@ -198,6 +198,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
198
198
  self.noise_sampler = None
199
199
  self.noise_sampler_seed = noise_sampler_seed
200
200
  self._step_index = None
201
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
201
202
 
202
203
  # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
203
204
  def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -347,6 +348,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
347
348
  self.mid_point_sigma = None
348
349
 
349
350
  self._step_index = None
351
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
350
352
  self.noise_sampler = None
351
353
 
352
354
  # for exp beta schedules, such as the one for `pipeline_shap_e.py`
@@ -172,6 +172,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
172
172
  self.alpha_t = torch.sqrt(self.alphas_cumprod)
173
173
  self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
174
174
  self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
175
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
175
176
 
176
177
  # standard deviation of the initial noise distribution
177
178
  self.init_noise_sigma = 1.0
@@ -196,6 +197,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
196
197
  self.sample = None
197
198
  self.order_list = self.get_order_list(num_train_timesteps)
198
199
  self._step_index = None
200
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
199
201
 
200
202
  def get_order_list(self, num_inference_steps: int) -> List[int]:
201
203
  """
@@ -287,6 +289,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
287
289
 
288
290
  # add an index counter for schedulers that allow duplicated timesteps
289
291
  self._step_index = None
292
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
290
293
 
291
294
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
292
295
  def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
@@ -896,7 +899,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
896
899
  schedule_timesteps = self.timesteps.to(original_samples.device)
897
900
  timesteps = timesteps.to(original_samples.device)
898
901
 
899
- step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
902
+ step_indices = []
903
+ for timestep in timesteps:
904
+ index_candidates = (schedule_timesteps == timestep).nonzero()
905
+ if len(index_candidates) == 0:
906
+ step_index = len(schedule_timesteps) - 1
907
+ elif len(index_candidates) > 1:
908
+ step_index = index_candidates[1].item()
909
+ else:
910
+ step_index = index_candidates[0].item()
911
+ step_indices.append(step_index)
900
912
 
901
913
  sigma = sigmas[step_indices].flatten()
902
914
  while len(sigma.shape) < len(original_samples.shape):
@@ -92,6 +92,43 @@ def betas_for_alpha_bar(
92
92
  return torch.tensor(betas, dtype=torch.float32)
93
93
 
94
94
 
95
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
96
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that the final timestep has zero signal-to-noise ratio. Follows Algorithm 1 of
    https://arxiv.org/pdf/2305.08891.pdf.


    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # betas -> sqrt of the cumulative alpha product.
    sqrt_alpha_bar = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original schedule. Nothing below mutates `sqrt_alpha_bar`,
    # so plain indexing is enough to capture the old values.
    first = sqrt_alpha_bar[0]
    last = sqrt_alpha_bar[-1]

    # Shift so the last entry becomes exactly zero, then scale so the first
    # entry returns to its original value.
    rescaled = (sqrt_alpha_bar - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product: each alpha is the ratio
    # of neighbouring cumulative products, and the first alpha is the first
    # cumulative product itself.
    alpha_bar = rescaled**2
    step_ratios = alpha_bar[1:] / alpha_bar[:-1]
    rescaled_alphas = torch.cat([alpha_bar[0:1], step_ratios])

    return 1 - rescaled_alphas
130
+
131
+
95
132
  class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
96
133
  """
97
134
  Ancestral sampling with Euler method steps.
@@ -122,6 +159,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
122
159
  An offset added to the inference steps. You can use a combination of `offset=1` and
123
160
  `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
124
161
  Diffusion.
162
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
163
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
164
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
165
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
125
166
  """
126
167
 
127
168
  _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -138,6 +179,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
138
179
  prediction_type: str = "epsilon",
139
180
  timestep_spacing: str = "linspace",
140
181
  steps_offset: int = 0,
182
+ rescale_betas_zero_snr: bool = False,
141
183
  ):
142
184
  if trained_betas is not None:
143
185
  self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -152,9 +194,17 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
152
194
  else:
153
195
  raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
154
196
 
197
+ if rescale_betas_zero_snr:
198
+ self.betas = rescale_zero_terminal_snr(self.betas)
199
+
155
200
  self.alphas = 1.0 - self.betas
156
201
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
157
202
 
203
+ if rescale_betas_zero_snr:
204
+ # Close to 0 without being 0 so first sigma is not inf
205
+ # FP16 smallest positive subnormal works well here
206
+ self.alphas_cumprod[-1] = 2**-24
207
+
158
208
  sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
159
209
  sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
160
210
  self.sigmas = torch.from_numpy(sigmas)
@@ -166,6 +216,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
166
216
  self.is_scale_input_called = False
167
217
 
168
218
  self._step_index = None
219
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
169
220
 
170
221
  @property
171
222
  def init_noise_sigma(self):
@@ -249,6 +300,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
249
300
 
250
301
  self.timesteps = torch.from_numpy(timesteps).to(device=device)
251
302
  self._step_index = None
303
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
252
304
 
253
305
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
254
306
  def _init_step_index(self, timestep):
@@ -325,6 +377,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
325
377
 
326
378
  sigma = self.sigmas[self.step_index]
327
379
 
380
+ # Upcast to avoid precision issues when computing prev_sample
381
+ sample = sample.to(torch.float32)
382
+
328
383
  # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
329
384
  if self.config.prediction_type == "epsilon":
330
385
  pred_original_sample = sample - sigma * model_output
@@ -355,6 +410,9 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
355
410
 
356
411
  prev_sample = prev_sample + noise * sigma_up
357
412
 
413
+ # Cast sample back to model compatible dtype
414
+ prev_sample = prev_sample.to(model_output.dtype)
415
+
358
416
  # upon completion increase step index by one
359
417
  self._step_index += 1
360
418