cache-dit 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic; see the registry page for more details.

Files changed (32)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/block_adapters/__init__.py +4 -1
  5. cache_dit/cache_factory/cache_adapters/cache_adapter.py +126 -80
  6. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  7. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  8. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
  9. cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
  10. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  11. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  12. cache_dit/cache_factory/cache_contexts/cache_config.py +118 -0
  13. cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
  14. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  15. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +22 -0
  16. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  17. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  18. cache_dit/cache_factory/cache_contexts/prune_config.py +63 -0
  19. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  20. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  21. cache_dit/cache_factory/cache_interface.py +20 -14
  22. cache_dit/cache_factory/cache_types.py +19 -2
  23. cache_dit/cache_factory/params_modifier.py +7 -7
  24. cache_dit/cache_factory/utils.py +18 -7
  25. cache_dit/quantize/quantize_ao.py +58 -17
  26. cache_dit/utils.py +191 -54
  27. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/METADATA +11 -10
  28. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/RECORD +32 -27
  29. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/WHEEL +0 -0
  30. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_blocks/__init__.py

@@ -1,28 +1,33 @@
  import torch

  from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_types import CacheType
  from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+ from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
  from cache_dit.cache_factory.cache_contexts.cache_manager import (
      CachedContextManager,
  )
+ from cache_dit.cache_factory.cache_contexts.prune_manager import (
+     PrunedContextManager,
+ )

  from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
      CachedBlocks_Pattern_0_1_2,
+     PrunedBlocks_Pattern_0_1_2,
  )
  from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
      CachedBlocks_Pattern_3_4_5,
+     PrunedBlocks_Pattern_3_4_5,
  )
- from cache_dit.cache_factory.cache_blocks.pattern_utils import (
-     patch_cached_stats,
-     remove_cached_stats,
- )
+ from cache_dit.cache_factory.cache_blocks.pattern_utils import apply_stats
+ from cache_dit.cache_factory.cache_blocks.pattern_utils import remove_stats

  from cache_dit.logger import init_logger

  logger = init_logger(__name__)


- class CachedBlocks:
+ class UnifiedBlocks:
      def __new__(
          cls,
          # 0. Transformer blocks configuration
@@ -36,16 +41,13 @@ class CachedBlocks:
          # 'layers', 'single_stream_blocks', 'double_stream_blocks'
          cache_prefix: str = None,  # cache_prefix maybe un-need.
          # Usually, blocks_name, etc.
-         cache_context: CachedContext | str = None,
-         cache_manager: CachedContextManager = None,
+         cache_context: CachedContext | PrunedContext | str = None,
+         context_manager: CachedContextManager | PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBCache,
          **kwargs,
      ):
-         assert transformer is not None, "transformer can't be None."
-         assert forward_pattern is not None, "forward_pattern can't be None."
-         assert cache_context is not None, "cache_context can't be None."
-         assert cache_manager is not None, "cache_manager can't be None."
-         if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
-             return CachedBlocks_Pattern_0_1_2(
+         if cache_type == CacheType.DBCache:
+             return CachedBlocks(
                  # 0. Transformer blocks configuration
                  transformer_blocks,
                  transformer=transformer,
@@ -55,11 +57,12 @@ class CachedBlocks:
                  # 1. Cache context configuration
                  cache_prefix=cache_prefix,
                  cache_context=cache_context,
-                 cache_manager=cache_manager,
+                 context_manager=context_manager,
+                 cache_type=cache_type,
                  **kwargs,
              )
-         elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
-             return CachedBlocks_Pattern_3_4_5(
+         elif cache_type == CacheType.DBPrune:
+             return PrunedBlocks(
                  # 0. Transformer blocks configuration
                  transformer_blocks,
                  transformer=transformer,
@@ -69,8 +72,155 @@ class CachedBlocks:
                  # 1. Cache context configuration
                  cache_prefix=cache_prefix,
                  cache_context=cache_context,
-                 cache_manager=cache_manager,
+                 context_manager=context_manager,
+                 cache_type=cache_type,
                  **kwargs,
              )
+         else:
+             raise ValueError(f"Cache type {cache_type} is not supported now!")
+
+
+ class CachedBlocks:
+     def __new__(
+         cls,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = None,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Cache context configuration
+         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+         # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+         cache_prefix: str = None,  # cache_prefix maybe un-need.
+         # Usually, blocks_name, etc.
+         cache_context: CachedContext | PrunedContext | str = None,
+         context_manager: CachedContextManager | PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBCache,
+         **kwargs,
+     ):
+         assert transformer is not None, "transformer can't be None."
+         assert forward_pattern is not None, "forward_pattern can't be None."
+         assert cache_context is not None, "cache_context can't be None."
+         assert context_manager is not None, "context_manager can't be None."
+         if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+             if cache_type == CacheType.DBCache:
+                 assert isinstance(
+                     context_manager, CachedContextManager
+                 ), "context_manager must be CachedContextManager for DBCache."
+                 return CachedBlocks_Pattern_0_1_2(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported now!"
+                 )
+         elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+             if cache_type == CacheType.DBCache:
+                 assert isinstance(
+                     context_manager, CachedContextManager
+                 ), "context_manager must be CachedContextManager for DBCache."
+                 return CachedBlocks_Pattern_3_4_5(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported now!"
+                 )
+         else:
+             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
+
+
+ class PrunedBlocks:
+     def __new__(
+         cls,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = None,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Cache context configuration
+         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+         # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+         cache_prefix: str = None,  # cache_prefix maybe un-need.
+         # Usually, blocks_name, etc.
+         cache_context: CachedContext | PrunedContext | str = None,
+         context_manager: CachedContextManager | PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBCache,
+         **kwargs,
+     ):
+         assert transformer is not None, "transformer can't be None."
+         assert forward_pattern is not None, "forward_pattern can't be None."
+         assert cache_context is not None, "cache_context can't be None."
+         assert context_manager is not None, "context_manager can't be None."
+         if forward_pattern in PrunedBlocks_Pattern_0_1_2._supported_patterns:
+             if cache_type == CacheType.DBPrune:
+                 assert isinstance(
+                     context_manager, PrunedContextManager
+                 ), "context_manager must be PrunedContextManager for DBPrune."
+                 return PrunedBlocks_Pattern_0_1_2(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported now!"
+                 )
+         elif forward_pattern in PrunedBlocks_Pattern_3_4_5._supported_patterns:
+             if cache_type == CacheType.DBPrune:
+                 assert isinstance(
+                     context_manager, PrunedContextManager
+                 ), "context_manager must be PrunedContextManager for DBPrune."
+                 return PrunedBlocks_Pattern_3_4_5(
+                     # 0. Transformer blocks configuration
+                     transformer_blocks,
+                     transformer=transformer,
+                     forward_pattern=forward_pattern,
+                     check_forward_pattern=check_forward_pattern,
+                     check_num_outputs=check_num_outputs,
+                     # 1. Cache context configuration
+                     cache_prefix=cache_prefix,
+                     cache_context=cache_context,
+                     context_manager=context_manager,
+                     cache_type=cache_type,
+                     **kwargs,
+                 )
+             else:
+                 raise ValueError(
+                     f"Cache type {cache_type} is not supported now!"
+                 )
          else:
              raise ValueError(f"Pattern {forward_pattern} is not supported now!")
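
The hunks above replace the old CachedBlocks entry point with a UnifiedBlocks dispatcher that branches first on cache_type (DBCache vs. the new DBPrune) and then on the forward pattern. As a rough, hedged illustration of how that new surface is reached (inside cache-dit this wiring is done by the cache adapter, not by user code), the sketch below passes in an already-constructed PrunedContextManager and prune context; their construction is not part of this diff and is deliberately left as parameters here.

import torch

from cache_dit.cache_factory import ForwardPattern
from cache_dit.cache_factory.cache_types import CacheType
from cache_dit.cache_factory.cache_blocks import UnifiedBlocks


def wrap_blocks_with_dbprune(
    transformer: torch.nn.Module,
    blocks: torch.nn.ModuleList,
    context_manager,  # expected: a PrunedContextManager (construction not shown in this diff)
    cache_context,  # expected: a PrunedContext instance or its registered name (str)
):
    # cache_type=DBPrune routes through PrunedBlocks, which then selects
    # PrunedBlocks_Pattern_0_1_2 or PrunedBlocks_Pattern_3_4_5 by forward pattern.
    return UnifiedBlocks(
        blocks,
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_1,
        cache_prefix="transformer_blocks",
        cache_context=cache_context,
        context_manager=context_manager,
        cache_type=CacheType.DBPrune,
    )
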

cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py

@@ -1,6 +1,7 @@
  from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory.cache_blocks.pattern_base import (
      CachedBlocks_Pattern_Base,
+     PrunedBlocks_Pattern_Base,
  )
  from cache_dit.logger import init_logger

@@ -14,3 +15,12 @@ class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
          ForwardPattern.Pattern_2,
      ]
      ...
+
+
+ class PrunedBlocks_Pattern_0_1_2(PrunedBlocks_Pattern_Base):
+     _supported_patterns = [
+         ForwardPattern.Pattern_0,
+         ForwardPattern.Pattern_1,
+         ForwardPattern.Pattern_2,
+     ]
+     ...

cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

@@ -2,11 +2,17 @@ import torch

  from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory.cache_contexts.cache_manager import (
-     CacheNotExistError,
+     ContextNotExistError,
  )
  from cache_dit.cache_factory.cache_blocks.pattern_base import (
      CachedBlocks_Pattern_Base,
  )
+ from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
+ from cache_dit.cache_factory.cache_contexts.prune_manager import (
+     PrunedContextManager,
+ )
+ from cache_dit.cache_factory.cache_types import CacheType
+
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)
@@ -91,10 +97,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
      ):
          # Use it's own cache context.
          try:
-             self.cache_manager.set_context(self.cache_context)
+             self.context_manager.set_context(self.cache_context)
              self._check_cache_params()
-         except CacheNotExistError as e:
-             logger.warning(f"Cache context not exist: {e}, skip cache.")
+         except ContextNotExistError as e:
+             logger.warning(f"context not exist: {e}, skip cache.")
              hidden_states, new_encoder_hidden_states = self.call_blocks(
                  hidden_states,
                  *args,
@@ -118,38 +124,38 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
          )
          del original_hidden_states

-         self.cache_manager.mark_step_begin()
+         self.context_manager.mark_step_begin()
          # Residual L1 diff or Hidden States L1 diff
-         can_use_cache = self.cache_manager.can_cache(
+         can_use_cache = self.context_manager.can_cache(
              (
                  Fn_hidden_states_residual
-                 if not self.cache_manager.is_l1_diff_enabled()
+                 if not self.context_manager.is_l1_diff_enabled()
                  else hidden_states
              ),
              parallelized=self._is_parallelized(),
              prefix=(
                  f"{self.cache_prefix}_Fn_residual"
-                 if not self.cache_manager.is_l1_diff_enabled()
+                 if not self.context_manager.is_l1_diff_enabled()
                  else f"{self.cache_prefix}_Fn_hidden_states"
              ),
          )

          torch._dynamo.graph_break()
          if can_use_cache:
-             self.cache_manager.add_cached_step()
+             self.context_manager.add_cached_step()
              del Fn_hidden_states_residual
              hidden_states, new_encoder_hidden_states = (
-                 self.cache_manager.apply_cache(
+                 self.context_manager.apply_cache(
                      hidden_states,
                      new_encoder_hidden_states,  # encoder_hidden_states not use cache
                      prefix=(
                          f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_cache_residual()
+                         if self.context_manager.is_cache_residual()
                          else f"{self.cache_prefix}_Bn_hidden_states"
                      ),
                      encoder_prefix=(
                          f"{self.cache_prefix}_Bn_residual"
-                         if self.cache_manager.is_encoder_cache_residual()
+                         if self.context_manager.is_encoder_cache_residual()
                          else f"{self.cache_prefix}_Bn_hidden_states"
                      ),
                  )
@@ -157,20 +163,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            torch._dynamo.graph_break()
            # Call last `n` blocks to further process the hidden states
            # for higher precision.
-           if self.cache_manager.Bn_compute_blocks() > 0:
+           if self.context_manager.Bn_compute_blocks() > 0:
                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                    hidden_states,
                    *args,
                    **kwargs,
                )
        else:
-           self.cache_manager.set_Fn_buffer(
+           self.context_manager.set_Fn_buffer(
                Fn_hidden_states_residual,
                prefix=f"{self.cache_prefix}_Fn_residual",
            )
-           if self.cache_manager.is_l1_diff_enabled():
+           if self.context_manager.is_l1_diff_enabled():
                # for hidden states L1 diff
-               self.cache_manager.set_Fn_buffer(
+               self.context_manager.set_Fn_buffer(
                    hidden_states,
                    f"{self.cache_prefix}_Fn_hidden_states",
                )
@@ -188,13 +194,13 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            )

            torch._dynamo.graph_break()
-           if self.cache_manager.is_cache_residual():
-               self.cache_manager.set_Bn_buffer(
+           if self.context_manager.is_cache_residual():
+               self.context_manager.set_Bn_buffer(
                    hidden_states_residual,
                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
-               self.cache_manager.set_Bn_buffer(
+               self.context_manager.set_Bn_buffer(
                    hidden_states,
                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
@@ -203,22 +209,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            new_encoder_hidden_states_residual = (
                new_encoder_hidden_states - old_encoder_hidden_states
            )
-           if self.cache_manager.is_encoder_cache_residual():
+           if self.context_manager.is_encoder_cache_residual():
                if new_encoder_hidden_states is not None:
-                   self.cache_manager.set_Bn_encoder_buffer(
+                   self.context_manager.set_Bn_encoder_buffer(
                        new_encoder_hidden_states_residual,
                        prefix=f"{self.cache_prefix}_Bn_residual",
                    )
            else:
                if new_encoder_hidden_states is not None:
-                   self.cache_manager.set_Bn_encoder_buffer(
+                   self.context_manager.set_Bn_encoder_buffer(
                        new_encoder_hidden_states_residual,
                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
                    )
            torch._dynamo.graph_break()
            # Call last `n` blocks to further process the hidden states
            # for higher precision.
-           if self.cache_manager.Bn_compute_blocks() > 0:
+           if self.context_manager.Bn_compute_blocks() > 0:
                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                    hidden_states,
                    *args,
@@ -289,7 +295,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
          **kwargs,
      ):
          new_encoder_hidden_states = None
-         if self.cache_manager.Bn_compute_blocks() == 0:
+         if self.context_manager.Bn_compute_blocks() == 0:
              return hidden_states, new_encoder_hidden_states

          for block in self._Bn_blocks():
@@ -304,3 +310,229 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
              )

          return hidden_states, new_encoder_hidden_states
+
+
+ class PrunedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_3_4_5):
+     _supported_patterns = [
+         ForwardPattern.Pattern_3,
+         ForwardPattern.Pattern_4,
+         ForwardPattern.Pattern_5,
+     ]
+     pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+     def __init__(
+         self,
+         # 0. Transformer blocks configuration
+         transformer_blocks: torch.nn.ModuleList,
+         transformer: torch.nn.Module = None,
+         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+         check_forward_pattern: bool = True,
+         check_num_outputs: bool = True,
+         # 1. Prune context configuration
+         cache_prefix: str = None,  # maybe un-need.
+         cache_context: PrunedContext | str = None,
+         context_manager: PrunedContextManager = None,
+         cache_type: CacheType = CacheType.DBPrune,
+         **kwargs,
+     ):
+         super().__init__(
+             # 0. Transformer blocks configuration
+             transformer_blocks,
+             transformer=transformer,
+             forward_pattern=forward_pattern,
+             check_forward_pattern=check_forward_pattern,
+             check_num_outputs=check_num_outputs,
+             # 1. Cache context configuration
+             cache_prefix=cache_prefix,
+             cache_context=cache_context,
+             context_manager=context_manager,
+             cache_type=cache_type,
+             **kwargs,
+         )
+         assert isinstance(
+             self.context_manager, PrunedContextManager
+         ), "context_manager must be PrunedContextManager for PrunedBlocks."
+         self.context_manager: PrunedContextManager = (
+             self.context_manager
+         )  # For type hint
+
+     @torch.compiler.disable
+     def _check_cache_type(self):
+         assert (
+             self.cache_type == CacheType.DBPrune
+         ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         self.pruned_blocks_step: int = 0  # reset for each step
+
+         # Use it's own cache context.
+         try:
+             self.context_manager.set_context(self.cache_context)
+             self._check_cache_params()
+         except ContextNotExistError as e:
+             logger.warning(f"context not exist: {e}, skip prune.")
+             hidden_states, new_encoder_hidden_states = self.call_blocks(
+                 hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             return self._process_forward_outputs(
+                 hidden_states, new_encoder_hidden_states
+             )
+
+         self.context_manager.mark_step_begin()
+
+         # Call all blocks with prune strategy to process the hidden states.
+         new_encoder_hidden_states = None
+         for i, block in enumerate(self.transformer_blocks):
+             hidden_states, new_encoder_hidden_states = self.compute_or_prune(
+                 i,
+                 block,
+                 hidden_states,
+                 new_encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+         self.context_manager.add_pruned_block(self.pruned_blocks_step)
+         self.context_manager.add_actual_block(self.num_blocks)
+
+         return self._process_forward_outputs(
+             hidden_states,
+             new_encoder_hidden_states,
+         )
+
+     @property
+     @torch.compiler.disable
+     def num_blocks(self):
+         return len(self.transformer_blocks)
+
+     @torch.compiler.disable
+     def _skip_prune(self, block_id: int) -> bool:
+         # Wrap for non compiled mode.
+         return block_id in self.context_manager.get_non_prune_blocks_ids(
+             self.num_blocks
+         )
+
+     @torch.compiler.disable
+     def _maybe_prune(
+         self,
+         block_id: int,  # Block index in the transformer blocks
+         hidden_states: torch.Tensor,  # hidden_states or residual
+         prefix: str = "Bn_original",  # prev step name for single blocks
+     ):
+         # Wrap for non compiled mode.
+         can_use_prune = False
+         if not self._skip_prune(block_id):
+             can_use_prune = self.context_manager.can_prune(
+                 hidden_states,  # curr step
+                 parallelized=self._is_parallelized(),
+                 prefix=prefix,  # prev step
+             )
+         self.pruned_blocks_step += int(can_use_prune)
+         return can_use_prune
+
+     def compute_or_prune(
+         self,
+         block_id: int,  # Block index in the transformer blocks
+         # Below are the inputs to the block
+         block,  # The transformer block to be executed
+         hidden_states: torch.Tensor,
+         new_encoder_hidden_states: torch.Tensor | None,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = new_encoder_hidden_states
+
+         can_use_prune = self._maybe_prune(
+             block_id,
+             hidden_states,
+             prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+         )
+
+         # Prune steps: Prune current block and reuse the cached
+         # residuals for hidden states approximate.
+         torch._dynamo.graph_break()
+         if can_use_prune:
+             self.context_manager.add_pruned_step()
+             hidden_states, new_encoder_hidden_states = (
+                 self.context_manager.apply_prune(
+                     hidden_states,
+                     new_encoder_hidden_states,
+                     prefix=(
+                         f"{self.cache_prefix}_{block_id}_Bn_residual"
+                         if self.context_manager.is_cache_residual()
+                         else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                     ),
+                     encoder_prefix=(
+                         f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                         if self.context_manager.is_encoder_cache_residual()
+                         else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                     ),
+                 )
+             )
+             torch._dynamo.graph_break()
+         else:
+             # Normal steps: Compute the block and cache the residuals.
+             hidden_states = block(
+                 hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             hidden_states, new_encoder_hidden_states = (
+                 self._process_block_outputs(
+                     hidden_states, new_encoder_hidden_states
+                 )
+             )
+             if not self._skip_prune(block_id):
+                 hidden_states = hidden_states.contiguous()
+                 hidden_states_residual = hidden_states - original_hidden_states
+
+                 if (
+                     new_encoder_hidden_states is not None
+                     and original_encoder_hidden_states is not None
+                 ):
+                     new_encoder_hidden_states = (
+                         new_encoder_hidden_states.contiguous()
+                     )
+                     new_encoder_hidden_states_residual = (
+                         new_encoder_hidden_states
+                         - original_encoder_hidden_states
+                     )
+                 else:
+                     new_encoder_hidden_states_residual = None
+
+                 self.context_manager.set_Fn_buffer(
+                     original_hidden_states,
+                     prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                 )
+                 if self.context_manager.is_cache_residual():
+                     self.context_manager.set_Bn_buffer(
+                         hidden_states_residual,
+                         prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                     )
+                 else:
+                     self.context_manager.set_Bn_buffer(
+                         hidden_states,
+                         prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                     )
+                 if new_encoder_hidden_states_residual is not None:
+                     if self.context_manager.is_encoder_cache_residual():
+                         self.context_manager.set_Bn_encoder_buffer(
+                             new_encoder_hidden_states_residual,
+                             prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                         )
+                     else:
+                         self.context_manager.set_Bn_encoder_buffer(
+                             new_encoder_hidden_states_residual,
+                             prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                         )
+             torch._dynamo.graph_break()
+
+         return hidden_states, new_encoder_hidden_states
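
The new compute_or_prune path decides per block whether to run the block or to reuse a residual cached on a previous step, delegating the actual decision to PrunedContextManager.can_prune against the `{prefix}_{block_id}_Fn_original` buffer. can_prune itself lives in prune_manager.py and is not part of this diff; the sketch below is only a plausible reading of that call, a relative L1 comparison against a threshold, where the metric and the threshold value are assumptions made for illustration.

import torch


def can_prune_sketch(
    curr_input: torch.Tensor,
    prev_input: torch.Tensor | None,
    threshold: float = 0.05,  # assumed illustrative knob, not a confirmed cache-dit option
) -> bool:
    """Hedged stand-in for PrunedContextManager.can_prune (not the real implementation)."""
    if prev_input is None:
        # Nothing buffered yet (first step / non-prunable block): must compute.
        return False
    # Relative L1 difference between this step's block input and the
    # `*_Fn_original` tensor buffered on the previous step.
    diff = (curr_input - prev_input).abs().mean()
    norm = prev_input.abs().mean().clamp_min(1e-8)
    return (diff / norm).item() < threshold


# Mirrors _maybe_prune: when a block's input barely moved between steps, the block
# is skipped and the cached `*_Bn_residual` is re-applied by apply_prune.
prev = torch.randn(2, 16, 64)
curr = prev + 1e-3 * torch.randn_like(prev)
print(can_prune_sketch(curr, prev))  # True for near-identical inputs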