cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
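The headline structural change in this release is the rename of the `cache_dit/cache_factory` package to `cache_dit/caching` (and of `custom_ops` to `kernels`), alongside the new `parallelism` backends and the `summary` module. Below is a minimal sketch of how downstream imports would change across the rename; the `cache_dit.caching` paths appear in the new code in this diff, while the old `cache_dit.cache_factory` paths and the assumption that public symbols keep their names are inferred from the pre-rename layout:

```python
# cache-dit 1.0.3 layout (pre-rename), shown for comparison:
# from cache_dit.cache_factory import ForwardPattern
# from cache_dit.cache_factory.cache_types import CacheType

# cache-dit 1.0.14 layout, as used by the new caching modules in this diff:
from cache_dit.caching import ForwardPattern
from cache_dit.caching.cache_types import CacheType
from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager
```

Code importing from `cache_dit.cache_factory` would need to switch to the `cache_dit.caching` paths unless the package keeps backward-compatible re-exports, which this summary does not show.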
--- /dev/null
+++ b/cache_dit/caching/cache_blocks/pattern_base.py
@@ -0,0 +1,748 @@
+import inspect
+import logging
+import torch
+import torch.distributed as dist
+from diffusers.hooks import HookRegistry
+
+try:
+    from diffusers.hooks.context_parallel import ContextParallelSplitHook
+except ImportError:
+    ContextParallelSplitHook = None
+    raise UserWarning(
+        "Context parallelism requires the 'diffusers>=0.36.dev0'."
+        "Please install latest version of diffusers from source: \n"
+        "pip3 install git+https://github.com/huggingface/diffusers.git"
+    )
+from cache_dit.caching.cache_contexts.cache_context import CachedContext
+from cache_dit.caching.cache_contexts.prune_context import PrunedContext
+from cache_dit.caching.cache_contexts.cache_manager import (
+    CachedContextManager,
+    ContextNotExistError,
+)
+from cache_dit.caching.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.caching import ForwardPattern
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CachedBlocks_Pattern_Base(torch.nn.Module):
+    _supported_patterns = [
+        ForwardPattern.Pattern_0,
+        ForwardPattern.Pattern_1,
+        ForwardPattern.Pattern_2,
+    ]
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: CachedContext | str = None,
+        context_manager: CachedContextManager = None,
+        cache_type: CacheType = CacheType.DBCache,
+        **kwargs,
+    ):
+        super().__init__()
+
+        # 0. Transformer blocks configuration
+        self.transformer = transformer
+        self.transformer_blocks = transformer_blocks
+        self.forward_pattern = forward_pattern
+        self.check_forward_pattern = check_forward_pattern
+        self.check_num_outputs = check_num_outputs
+        # 1. Cache context configuration
+        self.cache_prefix = cache_prefix
+        self.cache_context = cache_context
+        self.context_manager = context_manager
+        self.cache_type = cache_type
+
+        self._check_forward_pattern()
+        self._check_cache_type()
+        logger.info(
+            f"Match Blocks: {self.__class__.__name__}, for "
+            f"{self.cache_prefix}, cache_context: {self.cache_context}, "
+            f"context_manager: {self.context_manager.name}."
+        )
+
+    def _check_forward_pattern(self):
+        if not self.check_forward_pattern:
+            logger.warning(
+                f"Skipped Forward Pattern Check: {self.forward_pattern}"
+            )
+            return
+
+        assert (
+            self.forward_pattern.Supported
+            and self.forward_pattern in self._supported_patterns
+        ), f"Pattern {self.forward_pattern} is not supported now!"
+
+        if self.transformer_blocks is not None:
+            for block in self.transformer_blocks:
+                # Special case for HiDreamBlock
+                if hasattr(block, "block"):
+                    if isinstance(block.block, torch.nn.Module):
+                        block = block.block
+
+                forward_parameters = set(
+                    inspect.signature(block.forward).parameters.keys()
+                )
+
+                if self.check_num_outputs:
+                    num_outputs = str(
+                        inspect.signature(block.forward).return_annotation
+                    ).count("torch.Tensor")
+
+                    if num_outputs > 0:
+                        assert len(self.forward_pattern.Out) == num_outputs, (
+                            f"The number of block's outputs is {num_outputs} don't not "
+                            f"match the number of the pattern: {self.forward_pattern}, "
+                            f"Out: {len(self.forward_pattern.Out)}."
+                        )
+
+                for required_param in self.forward_pattern.In:
+                    assert (
+                        required_param in forward_parameters
+                    ), f"The input parameters must contains: {required_param}."
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBCache
+        ), f"Cache type {self.cache_type} is not supported for CachedBlocks."
+
+    @torch.compiler.disable
+    def _check_cache_params(self):
+        self._check_cache_type()
+        assert self.context_manager.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {self.context_manager.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        assert self.context_manager.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {self.context_manager.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+
+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_block_outputs(
+        self,
+        hidden_states: torch.Tensor | tuple,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if not isinstance(hidden_states, torch.Tensor):
+            hidden_states, encoder_hidden_states = hidden_states
+            if not self.forward_pattern.Return_H_First:
+                hidden_states, encoder_hidden_states = (
+                    encoder_hidden_states,
+                    hidden_states,
+                )
+        return hidden_states, encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None] | torch.Tensor:
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
+    @torch.compiler.disable
+    def _check_if_context_parallel_enabled(
+        self,
+        module: torch.nn.Module,
+    ) -> bool:
+        if ContextParallelSplitHook is None:
+            return False
+        if hasattr(module, "_diffusers_hook"):
+            _diffusers_hook: HookRegistry = module._diffusers_hook
+            for hook in _diffusers_hook.hooks.values():
+                if isinstance(hook, ContextParallelSplitHook):
+                    return True
+        return False
+
+    def _get_Fn_residual(
+        self,
+        original_hidden_states: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # NOTE: Make cases compatible with context parallelism while using
+        # block level cp plan, e.g., WanTransformer3DModel. The shape of
+        # `original_hidden_states` and `hidden_states` after Fn maybe
+        # different due to seqlen split in context parallelism.
+        if self._check_if_context_parallel_enabled(
+            self.transformer_blocks[0]
+        ) and (original_hidden_states.shape != hidden_states.shape):
+            # Force use `hidden_states` as the Fn states residual for subsequent
+            # dynamic cache processing if the shape is different.
+            Fn_hidden_states_residual = hidden_states
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    f"Context parallelism is enabled in Fn blocks, and the shape of "
+                    f"original_hidden_states {original_hidden_states.shape} and "
+                    f"hidden_states {hidden_states.shape} are different after Fn blocks. "
+                    f"Use hidden_states as Fn_hidden_states_residual directly."
+                )
+        else:
+            Fn_hidden_states_residual = (
+                hidden_states - original_hidden_states.to(hidden_states.device)
+            )
+        return Fn_hidden_states_residual
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip cache.")
+            # Call all blocks to process the hidden states.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )
+
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            encoder_hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = self._get_Fn_residual(
+            original_hidden_states, hidden_states
+        )
+        del original_hidden_states
+
+        self.context_manager.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = self.context_manager.can_cache(
+            (
+                Fn_hidden_states_residual
+                if not self.context_manager.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.context_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            self.context_manager.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, encoder_hidden_states = (
+                self.context_manager.apply_cache(
+                    hidden_states,
+                    encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+        else:
+            self.context_manager.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.cache_prefix}_Fn_residual",
+            )
+            if self.context_manager.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                self.context_manager.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.cache_prefix}_Fn_hidden_states",
+                )
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            (
+                hidden_states,
+                encoder_hidden_states,
+                hidden_states_residual,
+                encoder_hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            torch._dynamo.graph_break()
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix=f"{self.cache_prefix}_Bn_residual",
+                )
+            else:
+                self.context_manager.set_Bn_buffer(
+                    hidden_states,
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                )
+
+            if self.context_manager.is_encoder_cache_residual():
+                self.context_manager.set_Bn_encoder_buffer(
+                    encoder_hidden_states_residual,
+                    prefix=f"{self.cache_prefix}_Bn_residual",
+                )
+            else:
+                self.context_manager.set_Bn_encoder_buffer(
+                    encoder_hidden_states,
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        # patch cached stats for blocks or remove it.
+        torch._dynamo.graph_break()
+
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
+        )
+
+    @torch.compiler.disable
+    def _is_parallelized(self):
+        # Compatible with distributed inference.
+        return any(
+            (
+                all(
+                    (
+                        self.transformer is not None,
+                        getattr(self.transformer, "_is_parallelized", False),
+                    )
+                ),
+                (dist.is_initialized() and dist.get_world_size() > 1),
+            )
+        )
+
+    @torch.compiler.disable
+    def _is_in_cache_step(self):
+        # Check if the current step is in cache steps.
+        # If so, we can skip some Bn blocks and directly
+        # use the cached values.
+        return (
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cached_steps()
+        ) or (
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cfg_cached_steps()
+        )
+
+    @torch.compiler.disable
+    def _Fn_blocks(self):
+        # Select first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        # Fn: [0,...,n-1]
+        selected_Fn_blocks = self.transformer_blocks[
+            : self.context_manager.Fn_compute_blocks()
+        ]
+        return selected_Fn_blocks
+
+    @torch.compiler.disable
+    def _Mn_blocks(self):  # middle blocks
+        # M(N-2n): only transformer_blocks [n,...,N-n], middle
+        if self.context_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+            selected_Mn_blocks = self.transformer_blocks[
+                self.context_manager.Fn_compute_blocks() :
+            ]
+        else:
+            selected_Mn_blocks = self.transformer_blocks[
+                self.context_manager.Fn_compute_blocks() : -self.context_manager.Bn_compute_blocks()
+            ]
+        return selected_Mn_blocks
+
+    @torch.compiler.disable
+    def _Bn_blocks(self):
+        # Bn: transformer_blocks [N-n+1,...,N-1]
+        selected_Bn_blocks = self.transformer_blocks[
+            -self.context_manager.Bn_compute_blocks() :
+        ]
+        return selected_Bn_blocks
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+
+        return hidden_states, encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+
+        hidden_states_residual = hidden_states - original_hidden_states
+
+        if (
+            encoder_hidden_states is not None
+            and original_encoder_hidden_states is not None
+        ):
+            encoder_hidden_states = encoder_hidden_states.contiguous()
+            encoder_hidden_states_residual = (
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+        else:
+            encoder_hidden_states_residual = None
+
+        return (
+            hidden_states,
+            encoder_hidden_states,
+            hidden_states_residual,
+            encoder_hidden_states_residual,
+        )
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        if self.context_manager.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        for block in self._Bn_blocks():
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+
+        return hidden_states, encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_Base(CachedBlocks_Pattern_Base):
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip prune.")
+            # Fallback to call all blocks to process the hidden states w/o prune.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )
+
+        self.context_manager.mark_step_begin()
+
+        if self._check_if_context_parallel_enabled(self.transformer_blocks[0]):
+            raise RuntimeError(
+                "Block level Context parallelism is not supported in PrunedBlocks."
+            )
+
+        # Call all blocks with prune strategy to process the hidden states.
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    encoder_hidden_states = encoder_hidden_states.contiguous()
+                    encoder_hidden_states_residual = (
+                        encoder_hidden_states - original_encoder_hidden_states
+                    )
+                else:
+                    encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+        torch._dynamo.graph_break()
+
+        return hidden_states, encoder_hidden_states