diffusers-0.23.1-py3-none-any.whl → diffusers-0.24.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (176)
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +0 -1
  4. diffusers/dependency_versions_table.py +4 -5
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +8 -7
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
  171. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
  172. diffusers/loaders.py +0 -3336
  173. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
  175. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
  176. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,14 @@ import inspect
18
18
  import itertools
19
19
  import os
20
20
  import re
21
+ from collections import OrderedDict
21
22
  from functools import partial
22
23
  from typing import Any, Callable, List, Optional, Tuple, Union
23
24
 
24
25
  import safetensors
25
26
  import torch
26
27
  from huggingface_hub import create_repo
27
- from torch import Tensor, device, nn
28
+ from torch import Tensor, nn
28
29
 
29
30
  from .. import __version__
30
31
  from ..utils import (
@@ -61,7 +62,7 @@ if is_accelerate_available():
61
62
  from accelerate.utils.versions import is_torch_version
62
63
 
63
64
 
64
- def get_parameter_device(parameter: torch.nn.Module):
65
+ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
65
66
  try:
66
67
  parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
67
68
  return next(parameters_and_buffers).device
@@ -77,7 +78,7 @@ def get_parameter_device(parameter: torch.nn.Module):
77
78
  return first_tuple[1].device
78
79
 
79
80
 
80
- def get_parameter_dtype(parameter: torch.nn.Module):
81
+ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
81
82
  try:
82
83
  params = tuple(parameter.parameters())
83
84
  if len(params) > 0:
@@ -130,7 +131,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
130
131
  )
131
132
 
132
133
 
133
- def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
134
+ def load_model_dict_into_meta(
135
+ model,
136
+ state_dict: OrderedDict,
137
+ device: Optional[Union[str, torch.device]] = None,
138
+ dtype: Optional[Union[str, torch.dtype]] = None,
139
+ model_name_or_path: Optional[str] = None,
140
+ ) -> List[str]:
134
141
  device = device or torch.device("cpu")
135
142
  dtype = dtype or torch.float32
136
143
 
@@ -156,7 +163,7 @@ def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_
156
163
  return unexpected_keys
157
164
 
158
165
 
159
- def _load_state_dict_into_model(model_to_load, state_dict):
166
+ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
160
167
  # Convert old format to new format if needed from a PyTorch state_dict
161
168
  # copy state_dict so _load_from_state_dict can modify it
162
169
  state_dict = state_dict.copy()
@@ -164,7 +171,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
164
171
 
165
172
  # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
166
173
  # so we need to apply the function recursively.
167
- def load(module: torch.nn.Module, prefix=""):
174
+ def load(module: torch.nn.Module, prefix: str = ""):
168
175
  args = (state_dict, prefix, {}, True, [], [], error_msgs)
169
176
  module._load_from_state_dict(*args)
170
177
 
@@ -186,6 +193,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
186
193
 
187
194
  - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
188
195
  """
196
+
189
197
  config_name = CONFIG_NAME
190
198
  _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
191
199
  _supports_gradient_checkpointing = False
@@ -220,7 +228,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
220
228
  """
221
229
  return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
222
230
 
223
- def enable_gradient_checkpointing(self):
231
+ def enable_gradient_checkpointing(self) -> None:
224
232
  """
225
233
  Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
226
234
  *checkpoint activations* in other frameworks).
@@ -229,7 +237,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
229
237
  raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
230
238
  self.apply(partial(self._set_gradient_checkpointing, value=True))
231
239
 
232
- def disable_gradient_checkpointing(self):
240
+ def disable_gradient_checkpointing(self) -> None:
233
241
  """
234
242
  Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
235
243
  *checkpoint activations* in other frameworks).
@@ -254,7 +262,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
254
262
  if isinstance(module, torch.nn.Module):
255
263
  fn_recursive_set_mem_eff(module)
256
264
 
257
- def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
265
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None) -> None:
258
266
  r"""
259
267
  Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
260
268
 
@@ -290,7 +298,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
290
298
  """
291
299
  self.set_use_memory_efficient_attention_xformers(True, attention_op)
292
300
 
293
- def disable_xformers_memory_efficient_attention(self):
301
+ def disable_xformers_memory_efficient_attention(self) -> None:
294
302
  r"""
295
303
  Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
296
304
  """
@@ -447,7 +455,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
447
455
  self,
448
456
  save_directory: Union[str, os.PathLike],
449
457
  is_main_process: bool = True,
450
- save_function: Callable = None,
458
+ save_function: Optional[Callable] = None,
451
459
  safe_serialization: bool = True,
452
460
  variant: Optional[str] = None,
453
461
  push_to_hub: bool = False,
@@ -910,10 +918,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
910
918
  def _load_pretrained_model(
911
919
  cls,
912
920
  model,
913
- state_dict,
921
+ state_dict: OrderedDict,
914
922
  resolved_archive_file,
915
- pretrained_model_name_or_path,
916
- ignore_mismatched_sizes=False,
923
+ pretrained_model_name_or_path: Union[str, os.PathLike],
924
+ ignore_mismatched_sizes: bool = False,
917
925
  ):
918
926
  # Retrieve missing & unexpected_keys
919
927
  model_state_dict = model.state_dict()
@@ -1011,7 +1019,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1011
1019
  return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1012
1020
 
1013
1021
  @property
1014
- def device(self) -> device:
1022
+ def device(self) -> torch.device:
1015
1023
  """
1016
1024
  `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
1017
1025
  device).
@@ -1063,7 +1071,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1063
1071
  else:
1064
1072
  return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
1065
1073
 
1066
- def _convert_deprecated_attention_blocks(self, state_dict):
1074
+ def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
1067
1075
  deprecated_attention_block_paths = []
1068
1076
 
1069
1077
  def recursive_find_attn_block(name, module):
@@ -1107,7 +1115,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1107
1115
  if f"{path}.proj_attn.bias" in state_dict:
1108
1116
  state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
1109
1117
 
1110
- def _temp_convert_self_to_deprecated_attention_blocks(self):
1118
+ def _temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1111
1119
  deprecated_attention_block_modules = []
1112
1120
 
1113
1121
  def recursive_find_attn_block(module):
@@ -1134,10 +1142,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
1134
1142
  del module.to_v
1135
1143
  del module.to_out
1136
1144
 
1137
- def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
1145
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self) -> None:
1138
1146
  deprecated_attention_block_modules = []
1139
1147
 
1140
- def recursive_find_attn_block(module):
1148
+ def recursive_find_attn_block(module) -> None:
1141
1149
  if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
1142
1150
  deprecated_attention_block_modules.append(module)
1143
1151
 
@@ -101,8 +101,8 @@ class AdaLayerNormSingle(nn.Module):
101
101
  def forward(
102
102
  self,
103
103
  timestep: torch.Tensor,
104
- added_cond_kwargs: Dict[str, torch.Tensor] = None,
105
- batch_size: int = None,
104
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
105
+ batch_size: Optional[int] = None,
106
106
  hidden_dtype: Optional[torch.dtype] = None,
107
107
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
108
108
  # No modulation happening here.