diffusers-0.23.0-py3-none-any.whl → diffusers-0.24.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177)
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +1 -14
  4. diffusers/dependency_versions_table.py +5 -4
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +11 -6
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. diffusers/utils/versions.py +117 -0
  171. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/METADATA +83 -64
  172. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/RECORD +176 -157
  173. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/WHEEL +1 -1
  174. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +1 -0
  175. diffusers/loaders.py +0 -3336
  176. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  177. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
@@ -425,10 +425,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
425
425
 
426
426
  if num_attention_heads is not None:
427
427
  raise ValueError(
428
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads`"
429
- " because of a naming issue as described in"
430
- " https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing"
431
- " `num_attention_heads` will only be supported in diffusers v0.19."
428
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
432
429
  )
433
430
 
434
431
  # If `num_attention_heads` is not defined (which is the case for most models)
@@ -442,44 +439,37 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
442
439
  # Check inputs
443
440
  if len(down_block_types) != len(up_block_types):
444
441
  raise ValueError(
445
- "Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:"
446
- f" {down_block_types}. `up_block_types`: {up_block_types}."
442
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
447
443
  )
448
444
 
449
445
  if len(block_out_channels) != len(down_block_types):
450
446
  raise ValueError(
451
- "Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:"
452
- f" {block_out_channels}. `down_block_types`: {down_block_types}."
447
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
453
448
  )
454
449
 
455
450
  if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
456
451
  raise ValueError(
457
- "Must provide the same number of `only_cross_attention` as `down_block_types`."
458
- f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
452
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
459
453
  )
460
454
 
461
455
  if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
462
456
  raise ValueError(
463
- "Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`:"
464
- f" {num_attention_heads}. `down_block_types`: {down_block_types}."
457
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
465
458
  )
466
459
 
467
460
  if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
468
461
  raise ValueError(
469
- "Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
470
- f" {attention_head_dim}. `down_block_types`: {down_block_types}."
462
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
471
463
  )
472
464
 
473
465
  if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
474
466
  raise ValueError(
475
- "Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`:"
476
- f" {cross_attention_dim}. `down_block_types`: {down_block_types}."
467
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
477
468
  )
478
469
 
479
470
  if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
480
471
  raise ValueError(
481
- "Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`:"
482
- f" {layers_per_block}. `down_block_types`: {down_block_types}."
472
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
483
473
  )
484
474
  if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
485
475
  for layer_number_per_block in transformer_layers_per_block:
@@ -897,8 +887,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
897
887
  processor = AttnProcessor()
898
888
  else:
899
889
  raise ValueError(
900
- "Cannot call `set_default_attn_processor` when attention processors are of type"
901
- f" {next(iter(self.attn_processors.values()))}"
890
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
902
891
  )
903
892
 
904
893
  self.set_attn_processor(processor, _remove_lora=True)
@@ -1166,8 +1155,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
1166
1155
  # Kandinsky 2.1 - style
1167
1156
  if "image_embeds" not in added_cond_kwargs:
1168
1157
  raise ValueError(
1169
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires"
1170
- " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1158
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1171
1159
  )
1172
1160
 
1173
1161
  image_embs = added_cond_kwargs.get("image_embeds")
@@ -1177,14 +1165,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
1177
1165
  # SDXL - style
1178
1166
  if "text_embeds" not in added_cond_kwargs:
1179
1167
  raise ValueError(
1180
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
1181
- " the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1168
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1182
1169
  )
1183
1170
  text_embeds = added_cond_kwargs.get("text_embeds")
1184
1171
  if "time_ids" not in added_cond_kwargs:
1185
1172
  raise ValueError(
1186
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires"
1187
- " the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1173
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1188
1174
  )
1189
1175
  time_ids = added_cond_kwargs.get("time_ids")
1190
1176
  time_embeds = self.add_time_proj(time_ids.flatten())
@@ -1196,8 +1182,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
1196
1182
  # Kandinsky 2.2 - style
1197
1183
  if "image_embeds" not in added_cond_kwargs:
1198
1184
  raise ValueError(
1199
- f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the"
1200
- " keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1185
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1201
1186
  )
1202
1187
  image_embs = added_cond_kwargs.get("image_embeds")
1203
1188
  aug_emb = self.add_embedding(image_embs)
@@ -1205,8 +1190,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
1205
1190
  # Kandinsky 2.2 - style
1206
1191
  if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1207
1192
  raise ValueError(
1208
- f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires"
1209
- " the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1193
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1210
1194
  )
1211
1195
  image_embs = added_cond_kwargs.get("image_embeds")
1212
1196
  hint = added_cond_kwargs.get("hint")
@@ -1224,8 +1208,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
1224
1208
  # Kadinsky 2.1 - style
1225
1209
  if "image_embeds" not in added_cond_kwargs:
1226
1210
  raise ValueError(
1227
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which"
1228
- " requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1211
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1229
1212
  )
1230
1213
 
1231
1214
  image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1234,11 +1217,19 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
1234
1217
  # Kandinsky 2.2 - style
1235
1218
  if "image_embeds" not in added_cond_kwargs:
1236
1219
  raise ValueError(
1237
- f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires"
1238
- " the keyword argument `image_embeds` to be passed in `added_conditions`"
1220
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1239
1221
  )
1240
1222
  image_embeds = added_cond_kwargs.get("image_embeds")
1241
1223
  encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1224
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1225
+ if "image_embeds" not in added_cond_kwargs:
1226
+ raise ValueError(
1227
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1228
+ )
1229
+ image_embeds = added_cond_kwargs.get("image_embeds")
1230
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
1231
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
1232
+
1242
1233
  # 2. pre-process
1243
1234
  sample = self.conv_in(sample)
1244
1235
 
@@ -1264,10 +1255,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
1264
1255
  deprecate(
1265
1256
  "T2I should not use down_block_additional_residuals",
1266
1257
  "1.3.0",
1267
- "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
1268
- " and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only"
1269
- " be used for ControlNet. Please make sure use"
1270
- " `down_intrablock_additional_residuals` instead. ",
1258
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1259
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1260
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1271
1261
  standard_warn=False,
1272
1262
  )
1273
1263
  down_intrablock_additional_residuals = down_block_additional_residuals
@@ -2102,8 +2092,7 @@ class UNetMidBlockFlat(nn.Module):
2102
2092
 
2103
2093
  if attention_head_dim is None:
2104
2094
  logger.warn(
2105
- "It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to"
2106
- f" `in_channels`: {in_channels}."
2095
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
2107
2096
  )
2108
2097
  attention_head_dim = in_channels
2109
2098
 
@@ -58,6 +58,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
58
58
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
59
59
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
60
60
  """
61
+
61
62
  model_cpu_offload_seq = "bert->unet->vqvae"
62
63
 
63
64
  tokenizer: CLIPTokenizer
@@ -52,6 +52,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
52
52
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
53
53
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
54
54
  """
55
+
55
56
  model_cpu_offload_seq = "bert->unet->vqvae"
56
57
 
57
58
  image_feature_extractor: CLIPImageProcessor
@@ -51,6 +51,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
51
51
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
52
52
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
53
53
  """
54
+
54
55
  model_cpu_offload_seq = "bert->unet->vqvae"
55
56
 
56
57
  tokenizer: CLIPTokenizer
@@ -17,6 +17,8 @@ import torch
17
17
  import torch.nn as nn
18
18
 
19
19
  from ...models.attention_processor import Attention
20
+ from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
21
+ from ...utils import USE_PEFT_BACKEND
20
22
 
21
23
 
22
24
  class WuerstchenLayerNorm(nn.LayerNorm):
@@ -32,7 +34,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
32
34
  class TimestepBlock(nn.Module):
33
35
  def __init__(self, c, c_timestep):
34
36
  super().__init__()
35
- self.mapper = nn.Linear(c_timestep, c * 2)
37
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
38
+ self.mapper = linear_cls(c_timestep, c * 2)
36
39
 
37
40
  def forward(self, x, t):
38
41
  a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -42,10 +45,14 @@ class TimestepBlock(nn.Module):
42
45
  class ResBlock(nn.Module):
43
46
  def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
44
47
  super().__init__()
45
- self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
48
+
49
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
50
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
51
+
52
+ self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
46
53
  self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
47
54
  self.channelwise = nn.Sequential(
48
- nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
55
+ linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
49
56
  )
50
57
 
51
58
  def forward(self, x, x_skip=None):
@@ -73,10 +80,13 @@ class GlobalResponseNorm(nn.Module):
73
80
  class AttnBlock(nn.Module):
74
81
  def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
75
82
  super().__init__()
83
+
84
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
85
+
76
86
  self.self_attn = self_attn
77
87
  self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
78
88
  self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
79
- self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
89
+ self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
80
90
 
81
91
  def forward(self, x, kv):
82
92
  kv = self.kv_mapper(kv)
@@ -28,8 +28,9 @@ from ...models.attention_processor import (
28
28
  AttnAddedKVProcessor,
29
29
  AttnProcessor,
30
30
  )
31
+ from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
32
  from ...models.modeling_utils import ModelMixin
32
- from ...utils import is_torch_version
33
+ from ...utils import USE_PEFT_BACKEND, is_torch_version
33
34
  from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
34
35
 
35
36
 
@@ -40,12 +41,15 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
40
41
  @register_to_config
41
42
  def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
42
43
  super().__init__()
44
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
45
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
46
+
43
47
  self.c_r = c_r
44
- self.projection = nn.Conv2d(c_in, c, kernel_size=1)
48
+ self.projection = conv_cls(c_in, c, kernel_size=1)
45
49
  self.cond_mapper = nn.Sequential(
46
- nn.Linear(c_cond, c),
50
+ linear_cls(c_cond, c),
47
51
  nn.LeakyReLU(0.2),
48
- nn.Linear(c, c),
52
+ linear_cls(c, c),
49
53
  )
50
54
 
51
55
  self.blocks = nn.ModuleList()
@@ -55,7 +59,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
55
59
  self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
56
60
  self.out = nn.Sequential(
57
61
  WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
58
- nn.Conv2d(c, c_in * 2, kernel_size=1),
62
+ conv_cls(c, c_in * 2, kernel_size=1),
59
63
  )
60
64
 
61
65
  self.gradient_checkpointing = False
@@ -269,7 +269,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
269
269
  callback_on_step_end_tensor_inputs (`List`, *optional*):
270
270
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
271
271
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
272
- `._callback_tensor_inputs` attribute of your pipeine class.
272
+ `._callback_tensor_inputs` attribute of your pipeline class.
273
273
 
274
274
  Examples:
275
275
 
@@ -234,7 +234,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
234
234
  prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
235
235
  The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
236
236
  list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
237
- the `._callback_tensor_inputs` attribute of your pipeine class.
237
+ the `._callback_tensor_inputs` attribute of your pipeline class.
238
238
  callback_on_step_end (`Callable`, *optional*):
239
239
  A function that calls at the end of each denoising steps during the inference. The function is called
240
240
  with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -243,7 +243,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
243
243
  callback_on_step_end_tensor_inputs (`List`, *optional*):
244
244
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
245
245
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
246
- `._callback_tensor_inputs` attribute of your pipeine class.
246
+ `._callback_tensor_inputs` attribute of your pipeline class.
247
247
 
248
248
  Examples:
249
249
 
@@ -349,7 +349,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
349
349
  callback_on_step_end_tensor_inputs (`List`, *optional*):
350
350
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
351
351
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
352
- `._callback_tensor_inputs` attribute of your pipeine class.
352
+ `._callback_tensor_inputs` attribute of your pipeline class.
353
353
 
354
354
  Examples:
355
355
 
@@ -38,6 +38,7 @@ except OptionalDependencyNotAvailable:
38
38
  _dummy_modules.update(get_objects_from_module(dummy_pt_objects))
39
39
 
40
40
  else:
41
+ _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
41
42
  _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
42
43
  _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
43
44
  _import_structure["scheduling_ddim"] = ["DDIMScheduler"]
@@ -56,12 +57,10 @@ else:
56
57
  _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
57
58
  _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
58
59
  _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
59
- _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
60
60
  _import_structure["scheduling_lcm"] = ["LCMScheduler"]
61
61
  _import_structure["scheduling_pndm"] = ["PNDMScheduler"]
62
62
  _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
63
63
  _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
64
- _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
65
64
  _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
66
65
  _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
67
66
  _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
@@ -129,6 +128,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
129
128
  except OptionalDependencyNotAvailable:
130
129
  from ..utils.dummy_pt_objects import * # noqa F403
131
130
  else:
131
+ from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
132
132
  from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
133
133
  from .scheduling_consistency_models import CMStochasticIterativeScheduler
134
134
  from .scheduling_ddim import DDIMScheduler
@@ -147,12 +147,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
147
147
  from .scheduling_ipndm import IPNDMScheduler
148
148
  from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
149
149
  from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
150
- from .scheduling_karras_ve import KarrasVeScheduler
151
150
  from .scheduling_lcm import LCMScheduler
152
151
  from .scheduling_pndm import PNDMScheduler
153
152
  from .scheduling_repaint import RePaintScheduler
154
153
  from .scheduling_sde_ve import ScoreSdeVeScheduler
155
- from .scheduling_sde_vp import ScoreSdeVpScheduler
156
154
  from .scheduling_unclip import UnCLIPScheduler
157
155
  from .scheduling_unipc_multistep import UniPCMultistepScheduler
158
156
  from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -0,0 +1,50 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_torch_available,
9
+ is_transformers_available,
10
+ )
11
+
12
+
13
+ _dummy_objects = {}
14
+ _import_structure = {}
15
+
16
+ try:
17
+ if not (is_transformers_available() and is_torch_available()):
18
+ raise OptionalDependencyNotAvailable()
19
+ except OptionalDependencyNotAvailable:
20
+ from ...utils import dummy_pt_objects # noqa F403
21
+
22
+ _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
23
+ else:
24
+ _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
25
+ _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
26
+
27
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
28
+ try:
29
+ if not is_torch_available():
30
+ raise OptionalDependencyNotAvailable()
31
+
32
+ except OptionalDependencyNotAvailable:
33
+ from ..utils.dummy_pt_objects import * # noqa F403
34
+ else:
35
+ from .scheduling_karras_ve import KarrasVeScheduler
36
+ from .scheduling_sde_vp import ScoreSdeVpScheduler
37
+
38
+
39
+ else:
40
+ import sys
41
+
42
+ sys.modules[__name__] = _LazyModule(
43
+ __name__,
44
+ globals()["__file__"],
45
+ _import_structure,
46
+ module_spec=__spec__,
47
+ )
48
+
49
+ for name, value in _dummy_objects.items():
50
+ setattr(sys.modules[__name__], name, value)
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
19
19
  import numpy as np
20
20
  import torch
21
21
 
22
- from ..configuration_utils import ConfigMixin, register_to_config
23
- from ..utils import BaseOutput
24
- from ..utils.torch_utils import randn_tensor
25
- from .scheduling_utils import SchedulerMixin
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils import BaseOutput
24
+ from ...utils.torch_utils import randn_tensor
25
+ from ..scheduling_utils import SchedulerMixin
26
26
 
27
27
 
28
28
  @dataclass
@@ -19,9 +19,9 @@ from typing import Union
19
19
 
20
20
  import torch
21
21
 
22
- from ..configuration_utils import ConfigMixin, register_to_config
23
- from ..utils.torch_utils import randn_tensor
24
- from .scheduling_utils import SchedulerMixin
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils.torch_utils import randn_tensor
24
+ from ..scheduling_utils import SchedulerMixin
25
25
 
26
26
 
27
27
  class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
@@ -79,9 +79,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
79
79
 
80
80
  # TODO(Patrick) better comments + non-PyTorch
81
81
  # postprocess model score
82
- log_mean_coeff = (
83
- -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
84
- )
82
+ log_mean_coeff = -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
85
83
  std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
86
84
  std = std.flatten()
87
85
  while len(std.shape) < len(score.shape):
@@ -208,9 +208,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
208
208
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
209
209
  elif beta_schedule == "scaled_linear":
210
210
  # this schedule is very specific to the latent diffusion model.
211
- self.betas = (
212
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
213
- )
211
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
214
212
  elif beta_schedule == "squaredcos_cap_v2":
215
213
  # Glide cosine schedule
216
214
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -204,9 +204,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
204
204
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
205
205
  elif beta_schedule == "scaled_linear":
206
206
  # this schedule is very specific to the latent diffusion model.
207
- self.betas = (
208
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
209
- )
207
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
210
208
  elif beta_schedule == "squaredcos_cap_v2":
211
209
  # Glide cosine schedule
212
210
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -215,9 +215,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
215
215
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
216
216
  elif beta_schedule == "scaled_linear":
217
217
  # this schedule is very specific to the latent diffusion model.
218
- self.betas = (
219
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
220
- )
218
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
221
219
  elif beta_schedule == "squaredcos_cap_v2":
222
220
  # Glide cosine schedule
223
221
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -160,9 +160,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
160
160
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
161
161
  elif beta_schedule == "scaled_linear":
162
162
  # this schedule is very specific to the latent diffusion model.
163
- self.betas = (
164
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
165
- )
163
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
166
164
  elif beta_schedule == "squaredcos_cap_v2":
167
165
  # Glide cosine schedule
168
166
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -170,9 +170,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
170
170
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
171
171
  elif beta_schedule == "scaled_linear":
172
172
  # this schedule is very specific to the latent diffusion model.
173
- self.betas = (
174
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
175
- )
173
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
176
174
  elif beta_schedule == "squaredcos_cap_v2":
177
175
  # Glide cosine schedule
178
176
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -149,9 +149,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
149
149
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
150
150
  elif beta_schedule == "scaled_linear":
151
151
  # this schedule is very specific to the latent diffusion model.
152
- self.betas = (
153
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
154
- )
152
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
155
153
  elif beta_schedule == "squaredcos_cap_v2":
156
154
  # Glide cosine schedule
157
155
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -325,8 +323,20 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
325
323
  def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
326
324
  """Constructs the noise schedule of Karras et al. (2022)."""
327
325
 
328
- sigma_min: float = in_sigmas[-1].item()
329
- sigma_max: float = in_sigmas[0].item()
326
+ # Hack to make sure that other schedulers which copy this function don't break
327
+ # TODO: Add this logic to the other schedulers
328
+ if hasattr(self.config, "sigma_min"):
329
+ sigma_min = self.config.sigma_min
330
+ else:
331
+ sigma_min = None
332
+
333
+ if hasattr(self.config, "sigma_max"):
334
+ sigma_max = self.config.sigma_max
335
+ else:
336
+ sigma_max = None
337
+
338
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
339
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
330
340
 
331
341
  rho = 7.0 # 7.0 is the value used in the paper
332
342
  ramp = np.linspace(0, 1, num_inference_steps)
@@ -176,9 +176,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
176
176
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
177
177
  elif beta_schedule == "scaled_linear":
178
178
  # this schedule is very specific to the latent diffusion model.
179
- self.betas = (
180
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
181
- )
179
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
182
180
  elif beta_schedule == "squaredcos_cap_v2":
183
181
  # Glide cosine schedule
184
182
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -360,8 +358,20 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
360
358
  def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
361
359
  """Constructs the noise schedule of Karras et al. (2022)."""
362
360
 
363
- sigma_min: float = in_sigmas[-1].item()
364
- sigma_max: float = in_sigmas[0].item()
361
+ # Hack to make sure that other schedulers which copy this function don't break
362
+ # TODO: Add this logic to the other schedulers
363
+ if hasattr(self.config, "sigma_min"):
364
+ sigma_min = self.config.sigma_min
365
+ else:
366
+ sigma_min = None
367
+
368
+ if hasattr(self.config, "sigma_max"):
369
+ sigma_max = self.config.sigma_max
370
+ else:
371
+ sigma_max = None
372
+
373
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
374
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
365
375
 
366
376
  rho = 7.0 # 7.0 is the value used in the paper
367
377
  ramp = np.linspace(0, 1, num_inference_steps)
@@ -171,9 +171,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
171
171
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
172
172
  elif beta_schedule == "scaled_linear":
173
173
  # this schedule is very specific to the latent diffusion model.
174
- self.betas = (
175
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
176
- )
174
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
177
175
  elif beta_schedule == "squaredcos_cap_v2":
178
176
  # Glide cosine schedule
179
177
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -360,8 +358,20 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
360
358
  def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
361
359
  """Constructs the noise schedule of Karras et al. (2022)."""
362
360
 
363
- sigma_min: float = in_sigmas[-1].item()
364
- sigma_max: float = in_sigmas[0].item()
361
+ # Hack to make sure that other schedulers which copy this function don't break
362
+ # TODO: Add this logic to the other schedulers
363
+ if hasattr(self.config, "sigma_min"):
364
+ sigma_min = self.config.sigma_min
365
+ else:
366
+ sigma_min = None
367
+
368
+ if hasattr(self.config, "sigma_max"):
369
+ sigma_max = self.config.sigma_max
370
+ else:
371
+ sigma_max = None
372
+
373
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
374
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
365
375
 
366
376
  rho = 7.0 # 7.0 is the value used in the paper
367
377
  ramp = np.linspace(0, 1, num_inference_steps)