cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
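Most of the renames in the list above come from two package-level moves: `cache_dit/cache_factory` becomes `cache_dit/caching`, and `cache_dit/custom_ops` becomes `cache_dit/kernels`. Downstream code importing from the old paths will need updating. A minimal, version-tolerant import sketch, assuming `ForwardPattern` is re-exported from the package `__init__` in both releases (the new path is taken from the file list and the diff below; the 0.3.x fallback is illustrative only):

    # Hedged sketch: handle the cache_factory -> caching rename across versions.
    try:
        from cache_dit.caching import ForwardPattern  # cache-dit >= 1.0 (this diff)
    except ImportError:
        from cache_dit.cache_factory import ForwardPattern  # cache-dit 0.3.x (assumed re-export)

    print(ForwardPattern.Pattern_3)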
cache_dit/caching/cache_blocks/pattern_3_4_5.py (added, +543 lines)
@@ -0,0 +1,543 @@
+import torch
+
+from cache_dit.caching import ForwardPattern
+from cache_dit.caching.cache_contexts.cache_manager import (
+    ContextNotExistError,
+)
+from cache_dit.caching.cache_blocks.pattern_base import (
+    CachedBlocks_Pattern_Base,
+)
+from cache_dit.caching.cache_contexts.prune_context import PrunedContext
+from cache_dit.caching.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.caching.cache_types import CacheType
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+
+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Call all blocks to process the hidden states without cache.
+        new_encoder_hidden_states = None
+        for block in self.transformer_blocks:
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_block_outputs(
+        self, hidden_states: torch.Tensor | tuple
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # Process the outputs for the block.
+        new_encoder_hidden_states = None
+        if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+            if len(hidden_states) == 2:
+                if isinstance(hidden_states[1], torch.Tensor):
+                    hidden_states, new_encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, new_encoder_hidden_states = (
+                            new_encoder_hidden_states,
+                            hidden_states,
+                        )
+                elif isinstance(hidden_states[0], torch.Tensor):
+                    hidden_states = hidden_states[0]
+                else:
+                    raise ValueError("Unexpected hidden_states format.")
+            else:
+                assert (
+                    len(hidden_states) == 1
+                ), f"Unexpected output length: {len(hidden_states)}"
+                hidden_states = hidden_states[0]
+        return hidden_states, new_encoder_hidden_states
+
+    @torch.compiler.disable
+    def _process_forward_outputs(
+        self,
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+    ) -> (
+        torch.Tensor
+        | tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, None]
+    ):
+        if self.forward_pattern.Return_H_Only:
+            return hidden_states
+        else:
+            if self.forward_pattern.Return_H_First:
+                return (hidden_states, new_encoder_hidden_states)
+            else:
+                return (new_encoder_hidden_states, hidden_states)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip cache.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
+
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = self._get_Fn_residual(
+            original_hidden_states, hidden_states
+        )
+        del original_hidden_states
+
+        self.context_manager.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = self.context_manager.can_cache(
+            (
+                Fn_hidden_states_residual
+                if not self.context_manager.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.context_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            self.context_manager.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, new_encoder_hidden_states = (
+                self.context_manager.apply_cache(
+                    hidden_states,
+                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
+                    prefix=(
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            if self.context_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+        else:
+            self.context_manager.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.cache_prefix}_Fn_residual",
+            )
+            if self.context_manager.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                self.context_manager.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.cache_prefix}_Fn_hidden_states",
+                )
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            old_encoder_hidden_states = new_encoder_hidden_states
+            (
+                hidden_states,
+                new_encoder_hidden_states,
+                hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+
+            torch._dynamo.graph_break()
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix=f"{self.cache_prefix}_Bn_residual",
+                )
+            else:
+                self.context_manager.set_Bn_buffer(
+                    hidden_states,
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                )
+
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
+            if self.context_manager.is_encoder_cache_residual():
+                if new_encoder_hidden_states is not None:
+                    self.context_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_residual",
+                    )
+            else:
+                if new_encoder_hidden_states is not None:
+                    self.context_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                    )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            if self.context_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+
+        torch._dynamo.graph_break()
+
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        new_encoder_hidden_states = None
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        new_encoder_hidden_states = None
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+        hidden_states_residual = hidden_states - original_hidden_states.to(
+            hidden_states.device
+        )
+
+        return (
+            hidden_states,
+            new_encoder_hidden_states,
+            hidden_states_residual,
+        )
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        new_encoder_hidden_states = None
+        if self.context_manager.Bn_compute_blocks() == 0:
+            return hidden_states, new_encoder_hidden_states
+
+        for block in self._Bn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_3_4_5):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip prune.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
+
+        self.context_manager.mark_step_begin()
+
+        if self._check_if_context_parallel_enabled(self.transformer_blocks[0]):
+            raise RuntimeError(
+                "Block level Context parallelism is not supported in PrunedBlocks."
+            )
+
+        # Call all blocks with prune strategy to process the hidden states.
+        new_encoder_hidden_states = None
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, new_encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                new_encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+            self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = new_encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, new_encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    new_encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(
+                    hidden_states, new_encoder_hidden_states
+                )
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    new_encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    new_encoder_hidden_states = (
+                        new_encoder_hidden_states.contiguous()
+                    )
+                    new_encoder_hidden_states_residual = (
+                        new_encoder_hidden_states
+                        - original_encoder_hidden_states
+                    )
+                else:
+                    new_encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if new_encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()
+
+        return hidden_states, new_encoder_hidden_states
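The `_process_block_outputs` helper in the hunk above is what lets a single cache-blocks class cover forward Patterns 3, 4 and 5: a bare tensor return is treated as `hidden_states` only, while a 2-tuple is unpacked in the order indicated by `Return_H_First`. A self-contained sketch of that normalization rule, for illustration only (the function name `unpack_block_output` is made up and is not part of cache-dit's API; only the branching mirrors the code above):

    import torch

    # Illustrative helper mirroring the branching of _process_block_outputs above.
    def unpack_block_output(output, return_h_first: bool = True):
        # Normalize a block return value to (hidden_states, encoder_hidden_states | None).
        if isinstance(output, torch.Tensor):  # Pattern 3: tensor only
            return output, None
        if len(output) == 2 and isinstance(output[1], torch.Tensor):  # Patterns 4/5
            h, e = output
            return (h, e) if return_h_first else (e, h)
        return output[0], None  # single-element tuple fallback

    # Usage: both call conventions normalize to the same (hidden, encoder) shape.
    h, e = torch.randn(1, 16, 64), torch.randn(1, 4, 64)
    out_h, out_e = unpack_block_output(h)
    assert out_h is h and out_e is None
    assert unpack_block_output((e, h), return_h_first=False)[0] is h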