diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
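For a quick sanity check after upgrading (a minimal sketch, not part of the published diff), the active wheel can be confirmed from Python before relying on any of the new 0.35.0 modules listed above:

```python
import diffusers

# Expect "0.35.0" once the new wheel is installed.
print(diffusers.__version__)
```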
diffusers/hooks/group_offloading.py

@@ -12,14 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import hashlib
 import os
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from enum import Enum
 from typing import Dict, List, Optional, Set, Tuple, Union
 
 import safetensors.torch
 import torch
 
 from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
 from .hooks import HookRegistry, ModelHook
 
 
@@ -35,17 +39,28 @@ logger = get_logger(__name__)  # pylint: disable=invalid-name
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
-_SUPPORTED_PYTORCH_LAYERS = (
-    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
-    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
-    torch.nn.Linear,
-    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
-    # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
 # fmt: on
 
 
+class GroupOffloadingType(str, Enum):
+    BLOCK_LEVEL = "block_level"
+    LEAF_LEVEL = "leaf_level"
+
+
+@dataclass
+class GroupOffloadingConfig:
+    onload_device: torch.device
+    offload_device: torch.device
+    offload_type: GroupOffloadingType
+    non_blocking: bool
+    record_stream: bool
+    low_cpu_mem_usage: bool
+    num_blocks_per_group: Optional[int] = None
+    offload_to_disk_path: Optional[str] = None
+    stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+
+
 class ModuleGroup:
     def __init__(
         self,
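For orientation, the new enum and dataclass above mirror the keyword arguments of `apply_group_offloading` further down in this file. A minimal sketch of constructing them directly, assuming the module path `diffusers.hooks.group_offloading` as in this diff (they are internal helpers and may not be re-exported elsewhere):

```python
import torch
from diffusers.hooks.group_offloading import GroupOffloadingConfig, GroupOffloadingType

# Mirrors what apply_group_offloading() builds internally in 0.35.0 before dispatching
# to the block-level or leaf-level implementation.
config = GroupOffloadingConfig(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type=GroupOffloadingType.BLOCK_LEVEL,
    non_blocking=True,
    record_stream=False,
    low_cpu_mem_usage=False,
    num_blocks_per_group=1,
)
print(config.offload_type.value)  # "block_level"
```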
@@ -62,6 +77,7 @@ class ModuleGroup:
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
+        group_id: Optional[int] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -79,8 +95,11 @@ class ModuleGroup:
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
 
-        if self.offload_to_disk_path:
-            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+        if self.offload_to_disk_path is not None:
+            # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+            self.group_id = group_id if group_id is not None else str(id(self))
+            short_hash = _compute_group_hash(self.group_id)
+            self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
 
         all_tensors = []
         for module in self.modules:
@@ -96,6 +115,12 @@ class ModuleGroup:
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -119,128 +144,100 @@ class ModuleGroup:
 
     @contextmanager
     def _pinned_memory_tensors(self):
-        pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
-
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
-
         finally:
             pinned_dict = None
 
-    @torch.compiler.disable()
-    def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())
 
-        if self.offload_to_disk_path:
-            if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
+    def _process_tensors_from_modules(self, pinned_memory=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source)
 
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source)
+
+    def _onload_from_disk(self):
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
+            if self.stream is not None:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
 
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+    def _onload_from_memory(self):
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
 
-    @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if the _all_ the desired
-            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-            # we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM which is still holding the up tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
-
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        with context:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
+            else:
+                self._process_tensors_from_modules(None)
+
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM which is still holding the up tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+    def _offload_to_memory(self):
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
+
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -248,14 +245,29 @@ class ModuleGroup:
                 param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
                 buffer.data = self.cpu_param_dict[buffer]
-
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()
+
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of parameters to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
 
 
 class GroupOffloadingHook(ModelHook):
@@ -268,9 +280,10 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
+        self.config = config
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -289,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
         if self.group.onload_leader == module:
             if self.group.onload_self:
                 self.group.onload_()
-            if self.next_group is not None and not self.next_group.onload_self:
+
+            should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+            if should_onload_next_group:
                 self.next_group.onload_()
 
+            should_synchronize = (
+                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+            )
+            if should_synchronize:
+                # If this group didn't onload itself, it means it was asynchronously onloaded by the
+                # previous group. We need to synchronize the side stream to ensure parameters
+                # are completely loaded to proceed with forward pass. Without this, uninitialized
+                # weights will be used in the computation, leading to incorrect results
+                # Also, we should only do this synchronization if we don't already do it from the sync call in
+                # self.next_group.onload_, hence the `not should_onload_next_group` check.
+                self.group.stream.synchronize()
+
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
         kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
         return args, kwargs
@@ -319,7 +346,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
     def initialize_hook(self, module):
        def make_execution_order_update_callback(current_name, current_submodule):
            def callback():
-                logger.debug(f"Adding {current_name} to the execution order")
+                if not torch.compiler.is_compiling():
+                    logger.debug(f"Adding {current_name} to the execution order")
                 self.execution_order.append((current_name, current_submodule))
 
             return callback
@@ -356,12 +384,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
         # if the missing layers end up being executed in the future.
         if execution_order_module_names != self._layer_execution_tracker_module_names:
             unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
-            logger.warning(
-                "It seems like some layers were not executed during the forward pass. This may lead to problems when "
-                "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
-                "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
-                f"{unexecuted_layers=}"
-            )
+            if not torch.compiler.is_compiling():
+                logger.warning(
+                    "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+                    "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+                    "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+                    f"{unexecuted_layers=}"
+                )
 
         # Remove the layer execution tracker hooks from the submodules
         base_module_registry = module._diffusers_hook
@@ -389,7 +418,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
         for i in range(num_executed - 1):
             name1, _ = self.execution_order[i]
             name2, _ = self.execution_order[i + 1]
-            logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+            if not torch.compiler.is_compiling():
+                logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
             group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
             group_offloading_hooks[i].next_group.onload_self = False
 
@@ -414,9 +444,9 @@ class LayerExecutionTrackerHook(ModelHook):
 
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
-    offload_type: str = "block_level",
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
+    offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
@@ -458,7 +488,7 @@ def apply_group_offloading(
            The device to which the group of modules are onloaded.
        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
            The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
-        offload_type (`str`, defaults to "block_level"):
+        offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
            The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
            "block_level".
        offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -501,6 +531,10 @@ def apply_group_offloading(
     ```
     """
 
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
+    offload_type = GroupOffloadingType(offload_type)
+
     stream = None
     if use_stream:
         if torch.cuda.is_available():
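The three added lines above let the public entry point accept plain strings for the devices and either a string or a `GroupOffloadingType` member for `offload_type`. A hedged, self-contained usage sketch (the `TinyModel` class is illustrative, not part of diffusers; only arguments that appear in this diff are used, and a CUDA device is assumed once the forward pass runs):

```python
import torch
from diffusers.hooks.group_offloading import apply_group_offloading


# Toy stand-in for a real diffusion transformer: block_level offloading groups the
# children of ModuleList/Sequential containers, so wrap a few blocks in a ModuleList.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
apply_group_offloading(
    model,
    onload_device="cuda",        # strings are now normalized to torch.device
    offload_device="cpu",
    offload_type="block_level",  # or GroupOffloadingType.BLOCK_LEVEL
    num_blocks_per_group=1,      # now validated up front for block_level (see the added check below)
)
```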
@@ -512,84 +546,45 @@ def apply_group_offloading(
 
     if not use_stream and record_stream:
         raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+    if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
+        raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
-    if offload_type == "block_level":
-        if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
-
-        _apply_group_offloading_block_level(
-            module=module,
-            num_blocks_per_group=num_blocks_per_group,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
-    elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module=module,
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-        )
+    config = GroupOffloadingConfig(
+        onload_device=onload_device,
+        offload_device=offload_device,
+        offload_type=offload_type,
+        num_blocks_per_group=num_blocks_per_group,
+        non_blocking=non_blocking,
+        stream=stream,
+        record_stream=record_stream,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+        offload_to_disk_path=offload_to_disk_path,
+    )
+    _apply_group_offloading(module, config)
+
+
+def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+    if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
+        _apply_group_offloading_block_level(module, config)
+    elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
+        _apply_group_offloading_leaf_level(module, config)
     else:
-        raise ValueError(f"Unsupported offload_type: {offload_type}")
+        assert False
 
 
-def _apply_group_offloading_block_level(
-    module: torch.nn.Module,
-    num_blocks_per_group: int,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-    offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
     the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        offload_to_disk_path (`str`, *optional*, defaults to `None`):
-            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
-            RAM environment settings where a reasonable speed-memory trade-off is desired.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
-    if stream is not None and num_blocks_per_group != 1:
+
+    if config.stream is not None and config.num_blocks_per_group != 1:
         logger.warning(
-            f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
+            f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
         )
-        num_blocks_per_group = 1
+        config.num_blocks_per_group = 1
 
     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
@@ -601,20 +596,22 @@ def _apply_group_offloading_block_level(
             modules_with_group_offloading.add(name)
             continue
 
-        for i in range(0, len(submodule), num_blocks_per_group):
-            current_modules = submodule[i : i + num_blocks_per_group]
+        for i in range(0, len(submodule), config.num_blocks_per_group):
+            current_modules = submodule[i : i + config.num_blocks_per_group]
+            group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
-                offload_device=offload_device,
-                onload_device=onload_device,
-                offload_to_disk_path=offload_to_disk_path,
+                offload_device=config.offload_device,
+                onload_device=config.onload_device,
+                offload_to_disk_path=config.offload_to_disk_path,
                 offload_leader=current_modules[-1],
                 onload_leader=current_modules[0],
-                non_blocking=non_blocking,
-                stream=stream,
-                record_stream=record_stream,
-                low_cpu_mem_usage=low_cpu_mem_usage,
+                non_blocking=config.non_blocking,
+                stream=config.stream,
+                record_stream=config.record_stream,
+                low_cpu_mem_usage=config.low_cpu_mem_usage,
                 onload_self=True,
+                group_id=group_id,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
@@ -623,7 +620,7 @@ def _apply_group_offloading_block_level(
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None)
+            _apply_group_offloading_hook(group_module, group, config=config)
 
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -638,9 +635,9 @@ def _apply_group_offloading_block_level(
     unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
     unmatched_group = ModuleGroup(
         modules=unmatched_modules,
-        offload_device=offload_device,
-        onload_device=onload_device,
-        offload_to_disk_path=offload_to_disk_path,
+        offload_device=config.offload_device,
+        onload_device=config.onload_device,
+        offload_to_disk_path=config.offload_to_disk_path,
         offload_leader=module,
         onload_leader=module,
         parameters=parameters,
@@ -649,74 +646,41 @@ def _apply_group_offloading_block_level(
         stream=None,
         record_stream=False,
         onload_self=True,
+        group_id=f"{module.__class__.__name__}_unmatched_group",
     )
-    if stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, None)
+    if config.stream is None:
+        _apply_group_offloading_hook(module, unmatched_group, config=config)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
-def _apply_group_offloading_leaf_level(
-    module: torch.nn.Module,
-    offload_device: torch.device,
-    onload_device: torch.device,
-    non_blocking: bool,
-    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
-    record_stream: Optional[bool] = False,
-    low_cpu_mem_usage: bool = False,
-    offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
     requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
     synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
     reduce memory usage without any performance degradation.
-
-    Args:
-        module (`torch.nn.Module`):
-            The module to which group offloading is applied.
-        offload_device (`torch.device`):
-            The device to which the group of modules are offloaded. This should typically be the CPU.
-        onload_device (`torch.device`):
-            The device to which the group of modules are onloaded.
-        offload_to_disk_path (`str`, *optional*, defaults to `None`):
-            The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
-            RAM environment settings where a reasonable speed-memory trade-off is desired.
-        non_blocking (`bool`):
-            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
-            and data transfer.
-        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
-            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
-            for overlapping computation and data transfer.
-        record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
-            as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
-            [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
-            details.
-        low_cpu_mem_usage (`bool`, defaults to `False`):
-            If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
-            option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
-            the CPU memory is a bottleneck but may counteract the benefits of using streams.
     """
-
     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+        if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
             modules=[submodule],
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=submodule,
             onload_leader=submodule,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -747,33 +711,33 @@ def _apply_group_offloading_leaf_level(
         parameters = parent_to_parameters.get(name, [])
         buffers = parent_to_buffers.get(name, [])
         parent_module = module_dict[name]
-        assert getattr(parent_module, "_diffusers_hook", None) is None
         group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
             offload_leader=parent_module,
             onload_leader=parent_module,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_to_disk_path=config.offload_to_disk_path,
             parameters=parameters,
             buffers=buffers,
-            non_blocking=non_blocking,
-            stream=stream,
-            record_stream=record_stream,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            non_blocking=config.non_blocking,
+            stream=config.stream,
+            record_stream=config.record_stream,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None)
+        _apply_group_offloading_hook(parent_module, group, config=config)
 
-    if stream is not None:
+    if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
         # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
         # execution order and apply prefetching in the correct order.
         unmatched_group = ModuleGroup(
             modules=[],
-            offload_device=offload_device,
-            onload_device=onload_device,
-            offload_to_disk_path=offload_to_disk_path,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
             offload_leader=module,
             onload_leader=module,
             parameters=None,
@@ -781,37 +745,40 @@ def _apply_group_offloading_leaf_level(
             non_blocking=False,
             stream=None,
             record_stream=False,
-            low_cpu_mem_usage=low_cpu_mem_usage,
+            low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
+            group_id=_GROUP_ID_LAZY_LEAF,
         )
-        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
 
 
 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
 
 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
+    *,
+    config: GroupOffloadingConfig,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)
 
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
 
     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -878,15 +845,54 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn.Module) -> None:
         )
 
 
-def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
     for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return True
-    return False
+        if hasattr(submodule, "_diffusers_hook"):
+            group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+            if group_offloading_hook is not None:
+                return group_offloading_hook
+    return None
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    return top_level_group_offload_hook is not None
 
 
 def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
-    for submodule in module.modules():
-        if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
-            return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+    if top_level_group_offload_hook is not None:
+        return top_level_group_offload_hook.config.onload_device
     raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _compute_group_hash(group_id):
+    hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+    # first 16 characters for a reasonably short but unique name
+    return hashed_id[:16]
+
+
+def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+    r"""
+    Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+    modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
+    modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
+
+    In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
+    and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
+    case where user has applied group offloading at multiple levels, this function will not work as expected.
+
+    There is some performance penalty associated with doing this when non-default streams are used, because we need to
+    retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+    """
+    top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+    if top_level_group_offload_hook is None:
+        return
+
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+    registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+    registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+    _apply_group_offloading(module, top_level_group_offload_hook.config)