cache-dit 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (29)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
  5. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +271 -36
  8. cache_dit/cache_factory/cache_blocks/pattern_base.py +286 -45
  9. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  10. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  11. cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
  12. cache_dit/cache_factory/cache_contexts/cache_context.py +26 -89
  13. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  14. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  15. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  16. cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
  17. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  18. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  19. cache_dit/cache_factory/cache_interface.py +23 -14
  20. cache_dit/cache_factory/cache_types.py +19 -2
  21. cache_dit/cache_factory/params_modifier.py +7 -7
  22. cache_dit/cache_factory/utils.py +38 -27
  23. cache_dit/utils.py +191 -54
  24. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/METADATA +14 -7
  25. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
  26. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
  27. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
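
The headline change from 1.0.2 to 1.0.4 is a block-pruning (DBPrune) path alongside the existing cache path: prune_config.py, prune_context.py and prune_manager.py are new modules, cache_types.py gains a DBPrune cache type, and the cache-block modules add PrunedBlocks_* subclasses. A minimal sketch of the new names, exactly as the updated pattern_3_4_5.py below imports them (only identifiers visible in this diff; their behavior is defined in the new modules, not shown here):

    # New in 1.0.4, as imported by cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py.
    from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
    from cache_dit.cache_factory.cache_contexts.prune_manager import PrunedContextManager
    from cache_dit.cache_factory.cache_types import CacheType

    # CacheType.DBPrune is the new member; it is the default cache_type of
    # PrunedBlocks_Pattern_3_4_5.__init__ further down in this diff.
    prune_type = CacheType.DBPrune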
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
@@ -2,11 +2,17 @@ import torch
 
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
-    CacheNotExistError,
+    ContextNotExistError,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
 )
+from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.cache_factory.cache_types import CacheType
+
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -33,14 +39,14 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )
 
         return hidden_states, new_encoder_hidden_states
 
     @torch.compiler.disable
-    def _process_outputs(
+    def _process_block_outputs(
         self, hidden_states: torch.Tensor | tuple
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         # Process the outputs for the block.
@@ -66,7 +72,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         return hidden_states, new_encoder_hidden_states
 
     @torch.compiler.disable
-    def _forward_outputs(
+    def _process_forward_outputs(
         self,
         hidden_states: torch.Tensor,
         new_encoder_hidden_states: torch.Tensor | None,
@@ -91,16 +97,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
     ):
         # Use it's own cache context.
         try:
-            self.cache_manager.set_context(self.cache_context)
+            self.context_manager.set_context(self.cache_context)
             self._check_cache_params()
-        except CacheNotExistError as e:
-            logger.warning(f"Cache context not exist: {e}, skip cache.")
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip cache.")
             hidden_states, new_encoder_hidden_states = self.call_blocks(
                 hidden_states,
                 *args,
                 **kwargs,
             )
-            return self._forward_outputs(
+            return self._process_forward_outputs(
                 hidden_states, new_encoder_hidden_states
             )
 
@@ -118,38 +124,38 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         )
         del original_hidden_states
 
-        self.cache_manager.mark_step_begin()
+        self.context_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.cache_manager.can_cache(
+        can_use_cache = self.context_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not self.cache_manager.is_l1_diff_enabled()
+                if not self.context_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
                 f"{self.cache_prefix}_Fn_residual"
-                if not self.cache_manager.is_l1_diff_enabled()
+                if not self.context_manager.is_l1_diff_enabled()
                 else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            self.cache_manager.add_cached_step()
+            self.context_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, new_encoder_hidden_states = (
-                self.cache_manager.apply_cache(
+                self.context_manager.apply_cache(
                     hidden_states,
                     new_encoder_hidden_states,  # encoder_hidden_states not use cache
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_cache_residual()
+                        if self.context_manager.is_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_encoder_cache_residual()
+                        if self.context_manager.is_encoder_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
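
In the hunks above, the cache decision now goes through self.context_manager (formerly self.cache_manager): can_cache() compares this step's first-n-blocks residual (or the hidden states themselves when L1-diff mode is enabled) against the previous step, and apply_cache() replays the stored Bn residuals when the difference is small enough. A generic sketch of that kind of relative-L1 gate, for orientation only; the threshold name and normalization below are assumptions, not cache-dit's implementation:

    import torch

    def can_reuse(curr: torch.Tensor, prev: torch.Tensor, threshold: float = 0.08) -> bool:
        # Relative L1 distance between this step's Fn residual and the one
        # recorded at the last fully computed step (hypothetical helper).
        diff = (curr - prev).abs().mean()
        norm = prev.abs().mean() + 1e-8
        return (diff / norm).item() < threshold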
@@ -157,20 +163,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            if self.cache_manager.Bn_compute_blocks() > 0:
+            if self.context_manager.Bn_compute_blocks() > 0:
                 hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                     hidden_states,
                     *args,
                     **kwargs,
                 )
         else:
-            self.cache_manager.set_Fn_buffer(
+            self.context_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
                 prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if self.cache_manager.is_l1_diff_enabled():
+            if self.context_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                self.cache_manager.set_Fn_buffer(
+                self.context_manager.set_Fn_buffer(
                     hidden_states,
                     f"{self.cache_prefix}_Fn_hidden_states",
                 )
@@ -188,13 +194,13 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )
 
             torch._dynamo.graph_break()
-            if self.cache_manager.is_cache_residual():
-                self.cache_manager.set_Bn_buffer(
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
                     hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                self.cache_manager.set_Bn_buffer(
+                self.context_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
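
On steps where the cache is not used, the freshly computed result is recorded for later reuse: set_Bn_buffer() stores either the residual contributed by the trailing blocks (when is_cache_residual() is true) or the raw hidden states. An illustrative sketch of residual caching in general, with hypothetical helpers rather than cache-dit's API:

    import torch

    _buffers: dict[str, torch.Tensor] = {}

    def store_residual(prefix: str, before: torch.Tensor, after: torch.Tensor) -> None:
        _buffers[prefix] = after - before  # what the skipped blocks added

    def approximate(prefix: str, before: torch.Tensor) -> torch.Tensor:
        # On a cached step, re-add the stored delta instead of running the blocks.
        return before + _buffers[prefix]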
@@ -203,22 +209,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 new_encoder_hidden_states_residual = (
                     new_encoder_hidden_states - old_encoder_hidden_states
                 )
-            if self.cache_manager.is_encoder_cache_residual():
+            if self.context_manager.is_encoder_cache_residual():
                 if new_encoder_hidden_states is not None:
-                    self.cache_manager.set_Bn_encoder_buffer(
+                    self.context_manager.set_Bn_encoder_buffer(
                         new_encoder_hidden_states_residual,
                         prefix=f"{self.cache_prefix}_Bn_residual",
                     )
             else:
                 if new_encoder_hidden_states is not None:
-                    self.cache_manager.set_Bn_encoder_buffer(
+                    self.context_manager.set_Bn_encoder_buffer(
                         new_encoder_hidden_states_residual,
                         prefix=f"{self.cache_prefix}_Bn_hidden_states",
                     )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-            if self.cache_manager.Bn_compute_blocks() > 0:
+            if self.context_manager.Bn_compute_blocks() > 0:
                 hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                     hidden_states,
                     *args,
@@ -227,7 +233,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
 
         torch._dynamo.graph_break()
 
-        return self._forward_outputs(hidden_states, new_encoder_hidden_states)
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
 
     def call_Fn_blocks(
         self,
@@ -242,8 +251,8 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )
 
         return hidden_states, new_encoder_hidden_states
@@ -263,8 +272,8 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
 
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
             )
 
         # compute hidden_states residual
@@ -286,7 +295,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         new_encoder_hidden_states = None
-        if self.cache_manager.Bn_compute_blocks() == 0:
+        if self.context_manager.Bn_compute_blocks() == 0:
             return hidden_states, new_encoder_hidden_states
 
         for block in self._Bn_blocks():
@@ -296,8 +305,234 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
 
-            hidden_states, new_encoder_hidden_states = self._process_outputs(
-                hidden_states
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(hidden_states)
+            )
+
+        return hidden_states, new_encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_3_4_5):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip prune.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
+
+        self.context_manager.mark_step_begin()
+
+        # Call all blocks with prune strategy to process the hidden states.
+        new_encoder_hidden_states = None
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, new_encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                new_encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = new_encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, new_encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    new_encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
             )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(
+                    hidden_states, new_encoder_hidden_states
+                )
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    new_encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    new_encoder_hidden_states = (
+                        new_encoder_hidden_states.contiguous()
+                    )
+                    new_encoder_hidden_states_residual = (
+                        new_encoder_hidden_states
+                        - original_encoder_hidden_states
+                    )
+                else:
+                    new_encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if new_encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()
 
         return hidden_states, new_encoder_hidden_states
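
Per step, PrunedBlocks_Pattern_3_4_5 resets pruned_blocks_step, lets compute_or_prune() decide block by block whether to reuse cached residuals, and then reports the counts through add_pruned_block() and add_actual_block(). A hedged sketch of how such per-step counters could be summarized into an overall prune ratio; the helper below is hypothetical, not part of cache-dit:

    def prune_ratio(pruned_per_step: list[int], actual_per_step: list[int]) -> float:
        # Fraction of transformer blocks skipped across all steps.
        total = sum(actual_per_step)
        return sum(pruned_per_step) / max(total, 1)

    # e.g. 19 blocks over 4 steps with 0, 7, 9, 11 pruned -> about 35.5% skipped
    print(prune_ratio([0, 7, 9, 11], [19, 19, 19, 19]))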