diffusers-0.34.0-py3-none-any.whl → diffusers-0.35.0-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
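The largest single change below is the rework of `diffusers/hooks/group_offloading.py` (+292 −286): device arguments now accept plain strings, `block_level` offloading validates `num_blocks_per_group` up front, and disk offloading writes deterministically named safetensors files. A minimal usage sketch of the public entry point follows; it is not part of the diff, and the model checkpoint and directory names are illustrative assumptions only.

```python
# Sketch only: exercises apply_group_offloading() as its 0.35.0 signature in the
# hunks below suggests. Checkpoint and paths are placeholders, not from this diff.
import torch
from diffusers import AutoModel
from diffusers.hooks import apply_group_offloading

# Load one large sub-model (assumed checkpoint; any diffusers transformer works).
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

apply_group_offloading(
    transformer,
    onload_device="cuda",            # 0.35.0 accepts strings as well as torch.device
    offload_device="cpu",
    offload_type="block_level",      # or "leaf_level" / a GroupOffloadingType member
    num_blocks_per_group=1,          # now validated up front for block_level
    use_stream=True,                 # overlap host-to-device copies with compute
    offload_to_disk_path="offload",  # groups land in group_<hash>.safetensors files
)
```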
diffusers/hooks/group_offloading.py

@@ -12,14 +12,18 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ import hashlib
  import os
  from contextlib import contextmanager, nullcontext
+ from dataclasses import dataclass
+ from enum import Enum
  from typing import Dict, List, Optional, Set, Tuple, Union
 
  import safetensors.torch
  import torch
 
  from ..utils import get_logger, is_accelerate_available
+ from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
  from .hooks import HookRegistry, ModelHook
 
 
@@ -35,17 +39,28 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
  _GROUP_OFFLOADING = "group_offloading"
  _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
  _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
- _SUPPORTED_PYTORCH_LAYERS = (
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
- torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
- torch.nn.Linear,
- # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
- # because of double invocation of the same norm layer in CogVideoXLayerNorm
- )
+ _GROUP_ID_LAZY_LEAF = "lazy_leafs"
  # fmt: on
 
 
+ class GroupOffloadingType(str, Enum):
+ BLOCK_LEVEL = "block_level"
+ LEAF_LEVEL = "leaf_level"
+
+
+ @dataclass
+ class GroupOffloadingConfig:
+ onload_device: torch.device
+ offload_device: torch.device
+ offload_type: GroupOffloadingType
+ non_blocking: bool
+ record_stream: bool
+ low_cpu_mem_usage: bool
+ num_blocks_per_group: Optional[int] = None
+ offload_to_disk_path: Optional[str] = None
+ stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+
+
  class ModuleGroup:
  def __init__(
  self,
@@ -62,6 +77,7 @@ class ModuleGroup:
  low_cpu_mem_usage: bool = False,
  onload_self: bool = True,
  offload_to_disk_path: Optional[str] = None,
+ group_id: Optional[int] = None,
  ) -> None:
  self.modules = modules
  self.offload_device = offload_device
@@ -79,8 +95,11 @@ class ModuleGroup:
  self.offload_to_disk_path = offload_to_disk_path
  self._is_offloaded_to_disk = False
 
- if self.offload_to_disk_path:
- self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+ if self.offload_to_disk_path is not None:
+ # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+ self.group_id = group_id if group_id is not None else str(id(self))
+ short_hash = _compute_group_hash(self.group_id)
+ self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
 
  all_tensors = []
  for module in self.modules:
@@ -96,6 +115,12 @@ class ModuleGroup:
  else:
  self.cpu_param_dict = self._init_cpu_param_dict()
 
+ self._torch_accelerator_module = (
+ getattr(torch, torch.accelerator.current_accelerator().type)
+ if hasattr(torch, "accelerator")
+ else torch.cuda
+ )
+
  def _init_cpu_param_dict(self):
  cpu_param_dict = {}
  if self.stream is None:
@@ -119,128 +144,100 @@ class ModuleGroup:
 
  @contextmanager
  def _pinned_memory_tensors(self):
- pinned_dict = {}
  try:
- for param, tensor in self.cpu_param_dict.items():
- if not tensor.is_pinned():
- pinned_dict[param] = tensor.pin_memory()
- else:
- pinned_dict[param] = tensor
-
+ pinned_dict = {
+ param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+ for param, tensor in self.cpu_param_dict.items()
+ }
  yield pinned_dict
-
  finally:
  pinned_dict = None
 
- @torch.compiler.disable()
- def onload_(self):
- r"""Onloads the group of modules to the onload_device."""
- torch_accelerator_module = (
- getattr(torch, torch.accelerator.current_accelerator().type)
- if hasattr(torch, "accelerator")
- else torch.cuda
- )
- context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
- current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+ def _transfer_tensor_to_device(self, tensor, source_tensor):
+ tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+ if self.record_stream:
+ tensor.data.record_stream(self._torch_accelerator_module.current_stream())
 
- if self.offload_to_disk_path:
- if self.stream is not None:
- # Wait for previous Host->Device transfer to complete
- self.stream.synchronize()
-
- with context:
- if self.stream is not None:
- # Load to CPU, pin, and async copy to device for overlapping transfer and compute
- loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
- for key, tensor_obj in self.key_to_tensor.items():
- pinned_tensor = loaded_cpu_tensors[key].pin_memory()
- tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- tensor_obj.data.record_stream(current_stream)
- else:
- # Load directly to the target device (synchronous)
- onload_device = (
- self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
- )
- loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
- for key, tensor_obj in self.key_to_tensor.items():
- tensor_obj.data = loaded_tensors[key]
- return
+ def _process_tensors_from_modules(self, pinned_memory=None):
+ for group_module in self.modules:
+ for param in group_module.parameters():
+ source = pinned_memory[param] if pinned_memory else param.data
+ self._transfer_tensor_to_device(param, source)
+ for buffer in group_module.buffers():
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
+ self._transfer_tensor_to_device(buffer, source)
+
+ for param in self.parameters:
+ source = pinned_memory[param] if pinned_memory else param.data
+ self._transfer_tensor_to_device(param, source)
 
+ for buffer in self.buffers:
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
+ self._transfer_tensor_to_device(buffer, source)
+
+ def _onload_from_disk(self):
  if self.stream is not None:
  # Wait for previous Host->Device transfer to complete
  self.stream.synchronize()
 
+ context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+ current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
  with context:
- if self.stream is not None:
- with self._pinned_memory_tensors() as pinned_memory:
- for group_module in self.modules:
- for param in group_module.parameters():
- param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- param.data.record_stream(current_stream)
- for buffer in group_module.buffers():
- buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- buffer.data.record_stream(current_stream)
-
- for param in self.parameters:
- param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- param.data.record_stream(current_stream)
-
- for buffer in self.buffers:
- buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- buffer.data.record_stream(current_stream)
+ # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+ device = str(self.onload_device) if self.stream is None else "cpu"
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
+ if self.stream is not None:
+ for key, tensor_obj in self.key_to_tensor.items():
+ pinned_tensor = loaded_tensors[key].pin_memory()
+ tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+ if self.record_stream:
+ tensor_obj.data.record_stream(current_stream)
  else:
- for group_module in self.modules:
- for param in group_module.parameters():
- param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
- for buffer in group_module.buffers():
- buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
- for param in self.parameters:
- param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+ onload_device = (
+ self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+ )
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+ for key, tensor_obj in self.key_to_tensor.items():
+ tensor_obj.data = loaded_tensors[key]
 
- for buffer in self.buffers:
- buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- buffer.data.record_stream(current_stream)
+ def _onload_from_memory(self):
+ if self.stream is not None:
+ # Wait for previous Host->Device transfer to complete
+ self.stream.synchronize()
 
- @torch.compiler.disable()
- def offload_(self):
- r"""Offloads the group of modules to the offload_device."""
- if self.offload_to_disk_path:
- # TODO: we can potentially optimize this code path by checking if the _all_ the desired
- # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
- # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
- # we perform a write.
- # Check if the file has been saved in this session or if it already exists on disk.
- if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
- os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
- tensors_to_save = {
- key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
- }
- safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
- # The group is now considered offloaded to disk for the rest of the session.
- self._is_offloaded_to_disk = True
-
- # We do this to free up the RAM which is still holding the up tensor data.
- for tensor_obj in self.tensor_to_key.keys():
- tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
- return
-
- torch_accelerator_module = (
- getattr(torch, torch.accelerator.current_accelerator().type)
- if hasattr(torch, "accelerator")
- else torch.cuda
- )
+ context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+ with context:
+ if self.stream is not None:
+ with self._pinned_memory_tensors() as pinned_memory:
+ self._process_tensors_from_modules(pinned_memory)
+ else:
+ self._process_tensors_from_modules(None)
+
+ def _offload_to_disk(self):
+ # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+ # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+ # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+ # we perform a write.
+ # Check if the file has been saved in this session or if it already exists on disk.
+ if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+ os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+ tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+ safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+ # The group is now considered offloaded to disk for the rest of the session.
+ self._is_offloaded_to_disk = True
+
+ # We do this to free up the RAM which is still holding the up tensor data.
+ for tensor_obj in self.tensor_to_key.keys():
+ tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+ def _offload_to_memory(self):
  if self.stream is not None:
  if not self.record_stream:
- torch_accelerator_module.current_stream().synchronize()
+ self._torch_accelerator_module.current_stream().synchronize()
+
  for group_module in self.modules:
  for param in group_module.parameters():
  param.data = self.cpu_param_dict[param]
@@ -248,14 +245,29 @@ class ModuleGroup:
  param.data = self.cpu_param_dict[param]
  for buffer in self.buffers:
  buffer.data = self.cpu_param_dict[buffer]
-
  else:
  for group_module in self.modules:
- group_module.to(self.offload_device, non_blocking=self.non_blocking)
+ group_module.to(self.offload_device, non_blocking=False)
  for param in self.parameters:
- param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+ param.data = param.data.to(self.offload_device, non_blocking=False)
  for buffer in self.buffers:
- buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+ buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+ @torch.compiler.disable()
+ def onload_(self):
+ r"""Onloads the group of parameters to the onload_device."""
+ if self.offload_to_disk_path is not None:
+ self._onload_from_disk()
+ else:
+ self._onload_from_memory()
+
+ @torch.compiler.disable()
+ def offload_(self):
+ r"""Offloads the group of parameters to the offload_device."""
+ if self.offload_to_disk_path:
+ self._offload_to_disk()
+ else:
+ self._offload_to_memory()
 
 
  class GroupOffloadingHook(ModelHook):
@@ -268,9 +280,10 @@ class GroupOffloadingHook(ModelHook):
 
  _is_stateful = False
 
- def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
+ def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
  self.group = group
- self.next_group = next_group
+ self.next_group: Optional[ModuleGroup] = None
+ self.config = config
 
  def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
  if self.group.offload_leader == module:
@@ -289,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
  if self.group.onload_leader == module:
  if self.group.onload_self:
  self.group.onload_()
- if self.next_group is not None and not self.next_group.onload_self:
+
+ should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+ if should_onload_next_group:
  self.next_group.onload_()
 
+ should_synchronize = (
+ not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+ )
+ if should_synchronize:
+ # If this group didn't onload itself, it means it was asynchronously onloaded by the
+ # previous group. We need to synchronize the side stream to ensure parameters
+ # are completely loaded to proceed with forward pass. Without this, uninitialized
+ # weights will be used in the computation, leading to incorrect results
+ # Also, we should only do this synchronization if we don't already do it from the sync call in
+ # self.next_group.onload_, hence the `not should_onload_next_group` check.
+ self.group.stream.synchronize()
+
  args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
  kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
  return args, kwargs
@@ -319,7 +346,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
  def initialize_hook(self, module):
  def make_execution_order_update_callback(current_name, current_submodule):
  def callback():
- logger.debug(f"Adding {current_name} to the execution order")
+ if not torch.compiler.is_compiling():
+ logger.debug(f"Adding {current_name} to the execution order")
  self.execution_order.append((current_name, current_submodule))
 
  return callback
@@ -356,12 +384,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
  # if the missing layers end up being executed in the future.
  if execution_order_module_names != self._layer_execution_tracker_module_names:
  unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
- logger.warning(
- "It seems like some layers were not executed during the forward pass. This may lead to problems when "
- "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
- "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
- f"{unexecuted_layers=}"
- )
+ if not torch.compiler.is_compiling():
+ logger.warning(
+ "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+ "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+ "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+ f"{unexecuted_layers=}"
+ )
 
  # Remove the layer execution tracker hooks from the submodules
  base_module_registry = module._diffusers_hook
@@ -389,7 +418,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
  for i in range(num_executed - 1):
  name1, _ = self.execution_order[i]
  name2, _ = self.execution_order[i + 1]
- logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+ if not torch.compiler.is_compiling():
+ logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
  group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
  group_offloading_hooks[i].next_group.onload_self = False
 
@@ -414,9 +444,9 @@ class LayerExecutionTrackerHook(ModelHook):
 
  def apply_group_offloading(
  module: torch.nn.Module,
- onload_device: torch.device,
- offload_device: torch.device = torch.device("cpu"),
- offload_type: str = "block_level",
+ onload_device: Union[str, torch.device],
+ offload_device: Union[str, torch.device] = torch.device("cpu"),
+ offload_type: Union[str, GroupOffloadingType] = "block_level",
  num_blocks_per_group: Optional[int] = None,
  non_blocking: bool = False,
  use_stream: bool = False,
@@ -458,7 +488,7 @@ def apply_group_offloading(
  The device to which the group of modules are onloaded.
  offload_device (`torch.device`, defaults to `torch.device("cpu")`):
  The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
- offload_type (`str`, defaults to "block_level"):
+ offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
  The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
  "block_level".
  offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -501,6 +531,10 @@ def apply_group_offloading(
  ```
  """
 
+ onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+ offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
+ offload_type = GroupOffloadingType(offload_type)
+
  stream = None
  if use_stream:
  if torch.cuda.is_available():
@@ -512,84 +546,45 @@ def apply_group_offloading(
 
  if not use_stream and record_stream:
  raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+ if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
+ raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
 
  _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
- if offload_type == "block_level":
- if num_blocks_per_group is None:
- raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
-
- _apply_group_offloading_block_level(
- module=module,
- num_blocks_per_group=num_blocks_per_group,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
- elif offload_type == "leaf_level":
- _apply_group_offloading_leaf_level(
- module=module,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+ config = GroupOffloadingConfig(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type=offload_type,
+ num_blocks_per_group=num_blocks_per_group,
+ non_blocking=non_blocking,
+ stream=stream,
+ record_stream=record_stream,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ offload_to_disk_path=offload_to_disk_path,
+ )
+ _apply_group_offloading(module, config)
+
+
+ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+ if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
+ _apply_group_offloading_block_level(module, config)
+ elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
+ _apply_group_offloading_leaf_level(module, config)
  else:
- raise ValueError(f"Unsupported offload_type: {offload_type}")
+ assert False
 
 
- def _apply_group_offloading_block_level(
- module: torch.nn.Module,
- num_blocks_per_group: int,
- offload_device: torch.device,
- onload_device: torch.device,
- non_blocking: bool,
- stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
- record_stream: Optional[bool] = False,
- low_cpu_mem_usage: bool = False,
- offload_to_disk_path: Optional[str] = None,
- ) -> None:
+ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
  r"""
  This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
  the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-
- Args:
- module (`torch.nn.Module`):
- The module to which group offloading is applied.
- offload_device (`torch.device`):
- The device to which the group of modules are offloaded. This should typically be the CPU.
- offload_to_disk_path (`str`, *optional*, defaults to `None`):
- The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
- RAM environment settings where a reasonable speed-memory trade-off is desired.
- onload_device (`torch.device`):
- The device to which the group of modules are onloaded.
- non_blocking (`bool`):
- If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
- and data transfer.
- stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
- If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
- for overlapping computation and data transfer.
- record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
- as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
- [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
- details.
- low_cpu_mem_usage (`bool`, defaults to `False`):
- If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
- option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
- the CPU memory is a bottleneck but may counteract the benefits of using streams.
  """
- if stream is not None and num_blocks_per_group != 1:
+
+ if config.stream is not None and config.num_blocks_per_group != 1:
  logger.warning(
- f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
+ f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
  )
- num_blocks_per_group = 1
+ config.num_blocks_per_group = 1
 
  # Create module groups for ModuleList and Sequential blocks
  modules_with_group_offloading = set()
@@ -601,20 +596,22 @@ def _apply_group_offloading_block_level(
  modules_with_group_offloading.add(name)
  continue
 
- for i in range(0, len(submodule), num_blocks_per_group):
- current_modules = submodule[i : i + num_blocks_per_group]
+ for i in range(0, len(submodule), config.num_blocks_per_group):
+ current_modules = submodule[i : i + config.num_blocks_per_group]
+ group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
  group = ModuleGroup(
  modules=current_modules,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
  offload_leader=current_modules[-1],
  onload_leader=current_modules[0],
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
  onload_self=True,
+ group_id=group_id,
  )
  matched_module_groups.append(group)
  for j in range(i, i + len(current_modules)):
@@ -623,7 +620,7 @@ def _apply_group_offloading_block_level(
  # Apply group offloading hooks to the module groups
  for i, group in enumerate(matched_module_groups):
  for group_module in group.modules:
- _apply_group_offloading_hook(group_module, group, None)
+ _apply_group_offloading_hook(group_module, group, config=config)
 
  # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
  # when the forward pass of this module is called. This is because the top-level module is not
@@ -638,9 +635,9 @@ def _apply_group_offloading_block_level(
  unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
  unmatched_group = ModuleGroup(
  modules=unmatched_modules,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
  offload_leader=module,
  onload_leader=module,
  parameters=parameters,
@@ -649,74 +646,41 @@ def _apply_group_offloading_block_level(
  stream=None,
  record_stream=False,
  onload_self=True,
+ group_id=f"{module.__class__.__name__}_unmatched_group",
  )
- if stream is None:
- _apply_group_offloading_hook(module, unmatched_group, None)
+ if config.stream is None:
+ _apply_group_offloading_hook(module, unmatched_group, config=config)
  else:
- _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+ _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
- def _apply_group_offloading_leaf_level(
- module: torch.nn.Module,
- offload_device: torch.device,
- onload_device: torch.device,
- non_blocking: bool,
- stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
- record_stream: Optional[bool] = False,
- low_cpu_mem_usage: bool = False,
- offload_to_disk_path: Optional[str] = None,
- ) -> None:
+ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
  r"""
  This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
  requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
  synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
  reduce memory usage without any performance degradation.
-
- Args:
- module (`torch.nn.Module`):
- The module to which group offloading is applied.
- offload_device (`torch.device`):
- The device to which the group of modules are offloaded. This should typically be the CPU.
- onload_device (`torch.device`):
- The device to which the group of modules are onloaded.
- offload_to_disk_path (`str`, *optional*, defaults to `None`):
- The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
- RAM environment settings where a reasonable speed-memory trade-off is desired.
- non_blocking (`bool`):
- If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
- and data transfer.
- stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
- If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
- for overlapping computation and data transfer.
- record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
- as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
- [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
- details.
- low_cpu_mem_usage (`bool`, defaults to `False`):
- If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
- option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
- the CPU memory is a bottleneck but may counteract the benefits of using streams.
  """
-
  # Create module groups for leaf modules and apply group offloading hooks
  modules_with_group_offloading = set()
  for name, submodule in module.named_modules():
- if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
  continue
  group = ModuleGroup(
  modules=[submodule],
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
  offload_leader=submodule,
  onload_leader=submodule,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
  onload_self=True,
+ group_id=name,
  )
- _apply_group_offloading_hook(submodule, group, None)
+ _apply_group_offloading_hook(submodule, group, config=config)
  modules_with_group_offloading.add(name)
 
  # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -747,33 +711,33 @@ def _apply_group_offloading_leaf_level(
  parameters = parent_to_parameters.get(name, [])
  buffers = parent_to_buffers.get(name, [])
  parent_module = module_dict[name]
- assert getattr(parent_module, "_diffusers_hook", None) is None
  group = ModuleGroup(
  modules=[],
- offload_device=offload_device,
- onload_device=onload_device,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
  offload_leader=parent_module,
  onload_leader=parent_module,
- offload_to_disk_path=offload_to_disk_path,
+ offload_to_disk_path=config.offload_to_disk_path,
  parameters=parameters,
  buffers=buffers,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
  onload_self=True,
+ group_id=name,
  )
- _apply_group_offloading_hook(parent_module, group, None)
+ _apply_group_offloading_hook(parent_module, group, config=config)
 
- if stream is not None:
+ if config.stream is not None:
  # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
  # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
  # execution order and apply prefetching in the correct order.
  unmatched_group = ModuleGroup(
  modules=[],
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
  offload_leader=module,
  onload_leader=module,
  parameters=None,
@@ -781,37 +745,40 @@ def _apply_group_offloading_leaf_level(
  non_blocking=False,
  stream=None,
  record_stream=False,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
  onload_self=True,
+ group_id=_GROUP_ID_LAZY_LEAF,
  )
- _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+ _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
  def _apply_group_offloading_hook(
  module: torch.nn.Module,
  group: ModuleGroup,
- next_group: Optional[ModuleGroup] = None,
+ *,
+ config: GroupOffloadingConfig,
  ) -> None:
  registry = HookRegistry.check_if_exists_or_initialize(module)
 
  # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
  # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
  if registry.get_hook(_GROUP_OFFLOADING) is None:
- hook = GroupOffloadingHook(group, next_group)
+ hook = GroupOffloadingHook(group, config=config)
  registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
  def _apply_lazy_group_offloading_hook(
  module: torch.nn.Module,
  group: ModuleGroup,
- next_group: Optional[ModuleGroup] = None,
+ *,
+ config: GroupOffloadingConfig,
  ) -> None:
  registry = HookRegistry.check_if_exists_or_initialize(module)
 
  # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
  # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
  if registry.get_hook(_GROUP_OFFLOADING) is None:
- hook = GroupOffloadingHook(group, next_group)
+ hook = GroupOffloadingHook(group, config=config)
  registry.register_hook(hook, _GROUP_OFFLOADING)
 
  lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -878,15 +845,54 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
  )
 
 
- def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+ def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
  for submodule in module.modules():
- if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
- return True
- return False
+ if hasattr(submodule, "_diffusers_hook"):
+ group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+ if group_offloading_hook is not None:
+ return group_offloading_hook
+ return None
+
+
+ def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+ return top_level_group_offload_hook is not None
 
 
  def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
- for submodule in module.modules():
- if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
- return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+ if top_level_group_offload_hook is not None:
+ return top_level_group_offload_hook.config.onload_device
  raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+ def _compute_group_hash(group_id):
+ hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+ # first 16 characters for a reasonably short but unique name
+ return hashed_id[:16]
+
+
+ def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+ r"""
+ Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+ modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
+ modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
+
+ In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
+ and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
+ case where user has applied group offloading at multiple levels, this function will not work as expected.
+
+ There is some performance penalty associated with doing this when non-default streams are used, because we need to
+ retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+ """
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+ if top_level_group_offload_hook is None:
+ return
+
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+ registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+ registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+ _apply_group_offloading(module, top_level_group_offload_hook.config)
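The last hunk also explains the new disk-offload file naming: each group now carries a stable `group_id` (for example the ModuleList slice for block-level groups, or the submodule name for leaf-level groups), which is hashed so that repeated runs reuse the same `group_<hash>.safetensors` file instead of an `id(self)`-based name that changes every session. The sketch below mirrors that `_compute_group_hash` logic with the standard library only; the example group ids and paths are illustrative assumptions.

```python
# Sketch of the deterministic on-disk naming used by offload_to_disk_path in 0.35.0.
import hashlib
import os


def compute_group_hash(group_id: str) -> str:
    # Mirrors _compute_group_hash above: first 16 hex chars of the SHA-256 digest.
    return hashlib.sha256(group_id.encode("utf-8")).hexdigest()[:16]


offload_to_disk_path = "offload"  # illustrative directory
for group_id in ["transformer_blocks_0_0", "single_transformer_blocks_3_3"]:
    filename = f"group_{compute_group_hash(group_id)}.safetensors"
    print(group_id, "->", os.path.join(offload_to_disk_path, filename))
```

Because the hash depends only on the group id, a second run against the same `offload_to_disk_path` finds the files already present and can skip rewriting them, which is what the `_is_offloaded_to_disk` / `os.path.exists` check in `_offload_to_disk` relies on.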