diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -34,16 +34,16 @@ class DDPMParallelSchedulerOutput(BaseOutput):
34
34
  Output class for the scheduler's `step` function output.
35
35
 
36
36
  Args:
37
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
37
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
38
38
  Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
39
39
  denoising loop.
40
- pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40
+ pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
41
41
  The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
42
42
  `pred_original_sample` can be used to preview progress or for guidance.
43
43
  """
44
44
 
45
- prev_sample: torch.FloatTensor
46
- pred_original_sample: Optional[torch.FloatTensor] = None
45
+ prev_sample: torch.Tensor
46
+ pred_original_sample: Optional[torch.Tensor] = None
47
47
 
48
48
 
49
49
  # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
@@ -81,7 +81,7 @@ def betas_for_alpha_bar(
81
81
  return math.exp(t * -12.0)
82
82
 
83
83
  else:
84
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
84
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
85
85
 
86
86
  betas = []
87
87
  for i in range(num_diffusion_timesteps):
@@ -98,11 +98,11 @@ def rescale_zero_terminal_snr(betas):
98
98
 
99
99
 
100
100
  Args:
101
- betas (`torch.FloatTensor`):
101
+ betas (`torch.Tensor`):
102
102
  the betas that the scheduler is being initialized with.
103
103
 
104
104
  Returns:
105
- `torch.FloatTensor`: rescaled betas with zero terminal SNR
105
+ `torch.Tensor`: rescaled betas with zero terminal SNR
106
106
  """
107
107
  # Convert betas to alphas_bar_sqrt
108
108
  alphas = 1.0 - betas
@@ -219,7 +219,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
219
219
  betas = torch.linspace(-6, 6, num_train_timesteps)
220
220
  self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
221
221
  else:
222
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
222
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
223
223
 
224
224
  # Rescale for zero SNR
225
225
  if rescale_betas_zero_snr:
@@ -240,19 +240,19 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
240
240
  self.variance_type = variance_type
241
241
 
242
242
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.scale_model_input
243
- def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
243
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
244
244
  """
245
245
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
246
246
  current timestep.
247
247
 
248
248
  Args:
249
- sample (`torch.FloatTensor`):
249
+ sample (`torch.Tensor`):
250
250
  The input sample.
251
251
  timestep (`int`, *optional*):
252
252
  The current timestep in the diffusion chain.
253
253
 
254
254
  Returns:
255
- `torch.FloatTensor`:
255
+ `torch.Tensor`:
256
256
  A scaled input sample.
257
257
  """
258
258
  return sample
@@ -375,7 +375,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
375
375
  return variance
376
376
 
377
377
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
378
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
378
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
379
379
  """
380
380
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
381
381
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -410,9 +410,9 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
410
410
 
411
411
  def step(
412
412
  self,
413
- model_output: torch.FloatTensor,
413
+ model_output: torch.Tensor,
414
414
  timestep: int,
415
- sample: torch.FloatTensor,
415
+ sample: torch.Tensor,
416
416
  generator=None,
417
417
  return_dict: bool = True,
418
418
  ) -> Union[DDPMParallelSchedulerOutput, Tuple]:
@@ -421,9 +421,9 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
421
421
  process from the learned model outputs (most often the predicted noise).
422
422
 
423
423
  Args:
424
- model_output (`torch.FloatTensor`): direct output from learned diffusion model.
424
+ model_output (`torch.Tensor`): direct output from learned diffusion model.
425
425
  timestep (`int`): current discrete timestep in the diffusion chain.
426
- sample (`torch.FloatTensor`):
426
+ sample (`torch.Tensor`):
427
427
  current instance of sample being created by diffusion process.
428
428
  generator: random number generator.
429
429
  return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
@@ -506,10 +506,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
506
506
 
507
507
  def batch_step_no_noise(
508
508
  self,
509
- model_output: torch.FloatTensor,
509
+ model_output: torch.Tensor,
510
510
  timesteps: List[int],
511
- sample: torch.FloatTensor,
512
- ) -> torch.FloatTensor:
511
+ sample: torch.Tensor,
512
+ ) -> torch.Tensor:
513
513
  """
514
514
  Batched version of the `step` function, to be able to reverse the SDE for multiple samples/timesteps at once.
515
515
  Also, does not add any noise to the predicted sample, which is necessary for parallel sampling where the noise
@@ -519,14 +519,14 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
519
519
  process from the learned model outputs (most often the predicted noise).
520
520
 
521
521
  Args:
522
- model_output (`torch.FloatTensor`): direct output from learned diffusion model.
522
+ model_output (`torch.Tensor`): direct output from learned diffusion model.
523
523
  timesteps (`List[int]`):
524
524
  current discrete timesteps in the diffusion chain. This is now a list of integers.
525
- sample (`torch.FloatTensor`):
525
+ sample (`torch.Tensor`):
526
526
  current instance of sample being created by diffusion process.
527
527
 
528
528
  Returns:
529
- `torch.FloatTensor`: sample tensor at previous timestep.
529
+ `torch.Tensor`: sample tensor at previous timestep.
530
530
  """
531
531
  t = timesteps
532
532
  num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
@@ -587,10 +587,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
587
587
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
588
588
  def add_noise(
589
589
  self,
590
- original_samples: torch.FloatTensor,
591
- noise: torch.FloatTensor,
590
+ original_samples: torch.Tensor,
591
+ noise: torch.Tensor,
592
592
  timesteps: torch.IntTensor,
593
- ) -> torch.FloatTensor:
593
+ ) -> torch.Tensor:
594
594
  # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
595
595
  # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
596
596
  # for the subsequent add_noise calls
@@ -612,9 +612,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
612
612
  return noisy_samples
613
613
 
614
614
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
615
- def get_velocity(
616
- self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
617
- ) -> torch.FloatTensor:
615
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor:
618
616
  # Make sure alphas_cumprod and timestep have same device and dtype as sample
619
617
  self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
620
618
  alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
@@ -33,12 +33,12 @@ class DDPMWuerstchenSchedulerOutput(BaseOutput):
33
33
  Output class for the scheduler's step function output.
34
34
 
35
35
  Args:
36
- prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
36
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
37
37
  Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
38
38
  denoising loop.
39
39
  """
40
40
 
41
- prev_sample: torch.FloatTensor
41
+ prev_sample: torch.Tensor
42
42
 
43
43
 
44
44
  def betas_for_alpha_bar(
@@ -75,7 +75,7 @@ def betas_for_alpha_bar(
75
75
  return math.exp(t * -12.0)
76
76
 
77
77
  else:
78
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
78
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
79
79
 
80
80
  betas = []
81
81
  for i in range(num_diffusion_timesteps):
@@ -125,17 +125,17 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
125
125
  ) ** 2 / self._init_alpha_cumprod.to(device)
126
126
  return alpha_cumprod.clamp(0.0001, 0.9999)
127
127
 
128
- def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
128
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
129
129
  """
130
130
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
131
131
  current timestep.
132
132
 
133
133
  Args:
134
- sample (`torch.FloatTensor`): input sample
134
+ sample (`torch.Tensor`): input sample
135
135
  timestep (`int`, optional): current timestep
136
136
 
137
137
  Returns:
138
- `torch.FloatTensor`: scaled input sample
138
+ `torch.Tensor`: scaled input sample
139
139
  """
140
140
  return sample
141
141
 
@@ -163,9 +163,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
163
163
 
164
164
  def step(
165
165
  self,
166
- model_output: torch.FloatTensor,
166
+ model_output: torch.Tensor,
167
167
  timestep: int,
168
- sample: torch.FloatTensor,
168
+ sample: torch.Tensor,
169
169
  generator=None,
170
170
  return_dict: bool = True,
171
171
  ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]:
@@ -174,9 +174,9 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
174
174
  process from the learned model outputs (most often the predicted noise).
175
175
 
176
176
  Args:
177
- model_output (`torch.FloatTensor`): direct output from learned diffusion model.
177
+ model_output (`torch.Tensor`): direct output from learned diffusion model.
178
178
  timestep (`int`): current discrete timestep in the diffusion chain.
179
- sample (`torch.FloatTensor`):
179
+ sample (`torch.Tensor`):
180
180
  current instance of sample being created by diffusion process.
181
181
  generator: random number generator.
182
182
  return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class
@@ -209,10 +209,10 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
209
209
 
210
210
  def add_noise(
211
211
  self,
212
- original_samples: torch.FloatTensor,
213
- noise: torch.FloatTensor,
214
- timesteps: torch.FloatTensor,
215
- ) -> torch.FloatTensor:
212
+ original_samples: torch.Tensor,
213
+ noise: torch.Tensor,
214
+ timesteps: torch.Tensor,
215
+ ) -> torch.Tensor:
216
216
  device = original_samples.device
217
217
  dtype = original_samples.dtype
218
218
  alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
@@ -61,7 +61,7 @@ def betas_for_alpha_bar(
61
61
  return math.exp(t * -12.0)
62
62
 
63
63
  else:
64
- raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
64
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
65
65
 
66
66
  betas = []
67
67
  for i in range(num_diffusion_timesteps):
@@ -152,7 +152,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
152
152
  # Glide cosine schedule
153
153
  self.betas = betas_for_alpha_bar(num_train_timesteps)
154
154
  else:
155
- raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
155
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
156
156
 
157
157
  self.alphas = 1.0 - self.betas
158
158
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@@ -170,13 +170,13 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
170
170
  if algorithm_type in ["dpmsolver", "dpmsolver++"]:
171
171
  self.register_to_config(algorithm_type="deis")
172
172
  else:
173
- raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}")
173
+ raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}")
174
174
 
175
175
  if solver_type not in ["logrho"]:
176
176
  if solver_type in ["midpoint", "heun", "bh1", "bh2"]:
177
177
  self.register_to_config(solver_type="logrho")
178
178
  else:
179
- raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}")
179
+ raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}")
180
180
 
181
181
  # setable values
182
182
  self.num_inference_steps = None
@@ -191,7 +191,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
191
191
  @property
192
192
  def step_index(self):
193
193
  """
194
- The index counter for current timestep. It will increae 1 after each scheduler step.
194
+ The index counter for current timestep. It will increase 1 after each scheduler step.
195
195
  """
196
196
  return self._step_index
197
197
 
@@ -276,7 +276,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
276
276
  self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
277
277
 
278
278
  # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
279
- def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
279
+ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
280
280
  """
281
281
  "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
282
282
  prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
@@ -341,7 +341,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
341
341
  return alpha_t, sigma_t
342
342
 
343
343
  # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
344
- def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
344
+ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
345
345
  """Constructs the noise schedule of Karras et al. (2022)."""
346
346
 
347
347
  # Hack to make sure that other schedulers which copy this function don't break
@@ -368,24 +368,24 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
368
368
 
369
369
  def convert_model_output(
370
370
  self,
371
- model_output: torch.FloatTensor,
371
+ model_output: torch.Tensor,
372
372
  *args,
373
- sample: torch.FloatTensor = None,
373
+ sample: torch.Tensor = None,
374
374
  **kwargs,
375
- ) -> torch.FloatTensor:
375
+ ) -> torch.Tensor:
376
376
  """
377
377
  Convert the model output to the corresponding type the DEIS algorithm needs.
378
378
 
379
379
  Args:
380
- model_output (`torch.FloatTensor`):
380
+ model_output (`torch.Tensor`):
381
381
  The direct output from the learned diffusion model.
382
382
  timestep (`int`):
383
383
  The current discrete timestep in the diffusion chain.
384
- sample (`torch.FloatTensor`):
384
+ sample (`torch.Tensor`):
385
385
  A current instance of a sample created by the diffusion process.
386
386
 
387
387
  Returns:
388
- `torch.FloatTensor`:
388
+ `torch.Tensor`:
389
389
  The converted model output.
390
390
  """
391
391
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -425,26 +425,26 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
425
425
 
426
426
  def deis_first_order_update(
427
427
  self,
428
- model_output: torch.FloatTensor,
428
+ model_output: torch.Tensor,
429
429
  *args,
430
- sample: torch.FloatTensor = None,
430
+ sample: torch.Tensor = None,
431
431
  **kwargs,
432
- ) -> torch.FloatTensor:
432
+ ) -> torch.Tensor:
433
433
  """
434
434
  One step for the first-order DEIS (equivalent to DDIM).
435
435
 
436
436
  Args:
437
- model_output (`torch.FloatTensor`):
437
+ model_output (`torch.Tensor`):
438
438
  The direct output from the learned diffusion model.
439
439
  timestep (`int`):
440
440
  The current discrete timestep in the diffusion chain.
441
441
  prev_timestep (`int`):
442
442
  The previous discrete timestep in the diffusion chain.
443
- sample (`torch.FloatTensor`):
443
+ sample (`torch.Tensor`):
444
444
  A current instance of a sample created by the diffusion process.
445
445
 
446
446
  Returns:
447
- `torch.FloatTensor`:
447
+ `torch.Tensor`:
448
448
  The sample tensor at the previous timestep.
449
449
  """
450
450
  timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
@@ -483,22 +483,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
483
483
 
484
484
  def multistep_deis_second_order_update(
485
485
  self,
486
- model_output_list: List[torch.FloatTensor],
486
+ model_output_list: List[torch.Tensor],
487
487
  *args,
488
- sample: torch.FloatTensor = None,
488
+ sample: torch.Tensor = None,
489
489
  **kwargs,
490
- ) -> torch.FloatTensor:
490
+ ) -> torch.Tensor:
491
491
  """
492
492
  One step for the second-order multistep DEIS.
493
493
 
494
494
  Args:
495
- model_output_list (`List[torch.FloatTensor]`):
495
+ model_output_list (`List[torch.Tensor]`):
496
496
  The direct outputs from learned diffusion model at current and latter timesteps.
497
- sample (`torch.FloatTensor`):
497
+ sample (`torch.Tensor`):
498
498
  A current instance of a sample created by the diffusion process.
499
499
 
500
500
  Returns:
501
- `torch.FloatTensor`:
501
+ `torch.Tensor`:
502
502
  The sample tensor at the previous timestep.
503
503
  """
504
504
  timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None)
@@ -552,22 +552,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
552
552
 
553
553
  def multistep_deis_third_order_update(
554
554
  self,
555
- model_output_list: List[torch.FloatTensor],
555
+ model_output_list: List[torch.Tensor],
556
556
  *args,
557
- sample: torch.FloatTensor = None,
557
+ sample: torch.Tensor = None,
558
558
  **kwargs,
559
- ) -> torch.FloatTensor:
559
+ ) -> torch.Tensor:
560
560
  """
561
561
  One step for the third-order multistep DEIS.
562
562
 
563
563
  Args:
564
- model_output_list (`List[torch.FloatTensor]`):
564
+ model_output_list (`List[torch.Tensor]`):
565
565
  The direct outputs from learned diffusion model at current and latter timesteps.
566
- sample (`torch.FloatTensor`):
566
+ sample (`torch.Tensor`):
567
567
  A current instance of a sample created by diffusion process.
568
568
 
569
569
  Returns:
570
- `torch.FloatTensor`:
570
+ `torch.Tensor`:
571
571
  The sample tensor at the previous timestep.
572
572
  """
573
573
 
@@ -673,9 +673,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
673
673
 
674
674
  def step(
675
675
  self,
676
- model_output: torch.FloatTensor,
676
+ model_output: torch.Tensor,
677
677
  timestep: int,
678
- sample: torch.FloatTensor,
678
+ sample: torch.Tensor,
679
679
  return_dict: bool = True,
680
680
  ) -> Union[SchedulerOutput, Tuple]:
681
681
  """
@@ -683,11 +683,11 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
683
683
  the multistep DEIS.
684
684
 
685
685
  Args:
686
- model_output (`torch.FloatTensor`):
686
+ model_output (`torch.Tensor`):
687
687
  The direct output from learned diffusion model.
688
688
  timestep (`float`):
689
689
  The current discrete timestep in the diffusion chain.
690
- sample (`torch.FloatTensor`):
690
+ sample (`torch.Tensor`):
691
691
  A current instance of a sample created by the diffusion process.
692
692
  return_dict (`bool`):
693
693
  Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
@@ -736,17 +736,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
736
736
 
737
737
  return SchedulerOutput(prev_sample=prev_sample)
738
738
 
739
- def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
739
+ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor:
740
740
  """
741
741
  Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
742
742
  current timestep.
743
743
 
744
744
  Args:
745
- sample (`torch.FloatTensor`):
745
+ sample (`torch.Tensor`):
746
746
  The input sample.
747
747
 
748
748
  Returns:
749
- `torch.FloatTensor`:
749
+ `torch.Tensor`:
750
750
  A scaled input sample.
751
751
  """
752
752
  return sample
@@ -754,10 +754,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
754
754
  # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
755
755
  def add_noise(
756
756
  self,
757
- original_samples: torch.FloatTensor,
758
- noise: torch.FloatTensor,
757
+ original_samples: torch.Tensor,
758
+ noise: torch.Tensor,
759
759
  timesteps: torch.IntTensor,
760
- ) -> torch.FloatTensor:
760
+ ) -> torch.Tensor:
761
761
  # Make sure sigmas and timesteps have the same device and dtype as original_samples
762
762
  sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
763
763
  if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -775,7 +775,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
775
775
  # add_noise is called after first denoising step (for inpainting)
776
776
  step_indices = [self.step_index] * timesteps.shape[0]
777
777
  else:
778
- # add noise is called bevore first denoising step to create inital latent(img2img)
778
+ # add noise is called before first denoising step to create initial latent(img2img)
779
779
  step_indices = [self.begin_index] * timesteps.shape[0]
780
780
 
781
781
  sigma = sigmas[step_indices].flatten()