cache-dit 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of cache-dit has been flagged as potentially problematic.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +271 -36
- cache_dit/cache_factory/cache_blocks/pattern_base.py +286 -45
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +26 -89
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +23 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +38 -27
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/METADATA +14 -7
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
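
The new prune_config.py, prune_context.py, and prune_manager.py modules, together with the changes to cache_types.py and cache_interface.py, indicate that 1.0.4 adds a dynamic block-pruning mode (CacheType.DBPrune) alongside the existing caching path. A minimal import sketch, assuming only the module paths and names visible in this diff (the public entry points of cache-dit 1.0.4 may differ):

    from cache_dit.cache_factory.cache_types import CacheType
    from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
    from cache_dit.cache_factory.cache_contexts.prune_manager import PrunedContextManager

    # CacheType.DBPrune is the new cache type referenced throughout this diff;
    # PrunedContext / PrunedContextManager back the new PrunedBlocks classes.
    print(CacheType.DBPrune)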
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

@@ -2,11 +2,17 @@ import torch

 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
-
+    ContextNotExistError,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
 )
+from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.cache_factory.cache_types import CacheType
+
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -33,14 +39,14 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            hidden_states, new_encoder_hidden_states =
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )

         return hidden_states, new_encoder_hidden_states

     @torch.compiler.disable
-    def
+    def _process_block_outputs(
         self, hidden_states: torch.Tensor | tuple
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         # Process the outputs for the block.
@@ -66,7 +72,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         return hidden_states, new_encoder_hidden_states

     @torch.compiler.disable
-    def
+    def _process_forward_outputs(
         self,
         hidden_states: torch.Tensor,
         new_encoder_hidden_states: torch.Tensor | None,
@@ -91,16 +97,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
     ):
         # Use it's own cache context.
         try:
-            self.
+            self.context_manager.set_context(self.cache_context)
             self._check_cache_params()
-        except
-            logger.warning(f"
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip cache.")
             hidden_states, new_encoder_hidden_states = self.call_blocks(
                 hidden_states,
                 *args,
                 **kwargs,
             )
-            return self.
+            return self._process_forward_outputs(
                 hidden_states, new_encoder_hidden_states
             )

@@ -118,38 +124,38 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         )
         del original_hidden_states

-        self.
+        self.context_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.
+        can_use_cache = self.context_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
                 f"{self.cache_prefix}_Fn_residual"
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                 else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )

         torch._dynamo.graph_break()
         if can_use_cache:
-            self.
+            self.context_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, new_encoder_hidden_states = (
-                self.
+                self.context_manager.apply_cache(
                     hidden_states,
                     new_encoder_hidden_states,  # encoder_hidden_states not use cache
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_encoder_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
@@ -157,20 +163,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            if self.
+            if self.context_manager.Bn_compute_blocks() > 0:
                 hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                     hidden_states,
                     *args,
                     **kwargs,
                 )
         else:
-            self.
+            self.context_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
                 prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if self.
+            if self.context_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                self.
+                self.context_manager.set_Fn_buffer(
                     hidden_states,
                     f"{self.cache_prefix}_Fn_hidden_states",
                 )
@@ -188,13 +194,13 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )

             torch._dynamo.graph_break()
-            if self.
-                self.
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
                     hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                self.
+                self.context_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
@@ -203,22 +209,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             new_encoder_hidden_states_residual = (
                 new_encoder_hidden_states - old_encoder_hidden_states
             )
-            if self.
+            if self.context_manager.is_encoder_cache_residual():
                 if new_encoder_hidden_states is not None:
-                    self.
+                    self.context_manager.set_Bn_encoder_buffer(
                         new_encoder_hidden_states_residual,
                         prefix=f"{self.cache_prefix}_Bn_residual",
                     )
             else:
                 if new_encoder_hidden_states is not None:
-                    self.
+                    self.context_manager.set_Bn_encoder_buffer(
                         new_encoder_hidden_states_residual,
                         prefix=f"{self.cache_prefix}_Bn_hidden_states",
                     )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            if self.
+            if self.context_manager.Bn_compute_blocks() > 0:
                 hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                     hidden_states,
                     *args,
@@ -227,7 +233,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):

         torch._dynamo.graph_break()

-        return self.
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )

     def call_Fn_blocks(
         self,
@@ -242,8 +251,8 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            hidden_states, new_encoder_hidden_states =
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )

         return hidden_states, new_encoder_hidden_states
@@ -263,8 +272,8 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )

-            hidden_states, new_encoder_hidden_states =
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )

         # compute hidden_states residual
@@ -286,7 +295,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         new_encoder_hidden_states = None
-        if self.
+        if self.context_manager.Bn_compute_blocks() == 0:
             return hidden_states, new_encoder_hidden_states

         for block in self._Bn_blocks():
@@ -296,8 +305,234 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )

-            hidden_states, new_encoder_hidden_states =
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_3_4_5):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip prune.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
+
+        self.context_manager.mark_step_begin()
+
+        # Call all blocks with prune strategy to process the hidden states.
+        new_encoder_hidden_states = None
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, new_encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                new_encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = new_encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, new_encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    new_encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
             )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(
+                    hidden_states, new_encoder_hidden_states
+                )
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    new_encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    new_encoder_hidden_states = (
+                        new_encoder_hidden_states.contiguous()
+                    )
+                    new_encoder_hidden_states_residual = (
+                        new_encoder_hidden_states
+                        - original_encoder_hidden_states
+                    )
+                else:
+                    new_encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if new_encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()

         return hidden_states, new_encoder_hidden_states