cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



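The diff below tracks a single refactor: the module-level `CachedContext` helpers used by the cached-blocks classes in 0.2.27 are replaced in 0.2.29 by methods on a per-instance `CachedContextManager` (held as `self.cache_manager`), buffer keys are built from `self.cache_prefix` instead of `self.blocks_name`, and several helpers are renamed (`set_cache_context` → `set_context`, `get_can_use_cache` → `can_cache`, `apply_hidden_states_residual` → `apply_cache`). The toy class below is only a sketch of that shape, not the cache-dit implementation; the method names and the old names they replace come from the hunks, the bodies are placeholders.

```python
# Toy stand-in (not cache-dit's implementation) sketching the 0.2.29 shape:
# stateful cache helpers live on a manager instance instead of a module.
class ToyCacheManager:
    def __init__(self, name: str = "default"):
        self.name = name  # the blocks classes below log cache_manager.name
        self._current_context = None
        self._step = -1

    def set_context(self, cache_context) -> None:
        # replaces module-level CachedContext.set_cache_context(...)
        self._current_context = cache_context

    def mark_step_begin(self) -> None:
        # replaces CachedContext.mark_step_begin()
        self._step += 1

    def can_cache(self, residual_or_hidden_states, prefix: str = "", **kwargs) -> bool:
        # replaces CachedContext.get_can_use_cache(...); toy: always recompute
        return False
```

In the real classes the manager is consulted with buffer keys such as `f"{self.cache_prefix}_Fn_residual"`, as the hunks below show.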
@@ -1,6 +1,5 @@
 import torch
 
-from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -24,7 +23,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-        CachedContext.set_cache_context(
+        self.cache_manager.set_context(
             self.cache_context,
         )
 
@@ -41,40 +40,40 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-        CachedContext.mark_step_begin()
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = CachedContext.get_can_use_cache(
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not CachedContext.is_l1_diff_enabled()
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                f"{self.blocks_name}_Fn_residual"
-                if not CachedContext.is_l1_diff_enabled()
-                else f"{self.blocks_name}_Fn_hidden_states"
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            CachedContext.add_cached_step()
+            self.cache_manager.add_cached_step()
            del Fn_hidden_states_residual
            hidden_states, encoder_hidden_states = (
-                CachedContext.apply_hidden_states_residual(
+                self.cache_manager.apply_cache(
                    hidden_states,
                    # None Pattern 3, else 4, 5
                    encoder_hidden_states,
                    prefix=(
-                        f"{self.blocks_name}_Bn_residual"
-                        if CachedContext.is_cache_residual()
-                        else f"{self.blocks_name}_Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                    encoder_prefix=(
-                        f"{self.blocks_name}_Bn_residual"
-                        if CachedContext.is_encoder_cache_residual()
-                        else f"{self.blocks_name}_Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                )
            )
@@ -88,15 +87,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                **kwargs,
            )
        else:
-            CachedContext.set_Fn_buffer(
+            self.cache_manager.set_Fn_buffer(
                Fn_hidden_states_residual,
-                prefix=f"{self.blocks_name}_Fn_residual",
+                prefix=f"{self.cache_prefix}_Fn_residual",
            )
-            if CachedContext.is_l1_diff_enabled():
+            if self.cache_manager.is_l1_diff_enabled():
                # for hidden states L1 diff
-                CachedContext.set_Fn_buffer(
+                self.cache_manager.set_Fn_buffer(
                    hidden_states,
-                    f"{self.blocks_name}_Fn_hidden_states",
+                    f"{self.cache_prefix}_Fn_hidden_states",
                )
            del Fn_hidden_states_residual
            torch._dynamo.graph_break()
@@ -114,29 +113,29 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                **kwargs,
            )
            torch._dynamo.graph_break()
-            if CachedContext.is_cache_residual():
-                CachedContext.set_Bn_buffer(
+            if self.cache_manager.is_cache_residual():
+                self.cache_manager.set_Bn_buffer(
                    hidden_states_residual,
-                    prefix=f"{self.blocks_name}_Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
                # TaylorSeer
-                CachedContext.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                    hidden_states,
-                    prefix=f"{self.blocks_name}_Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
-            if CachedContext.is_encoder_cache_residual():
-                CachedContext.set_Bn_encoder_buffer(
+            if self.cache_manager.is_encoder_cache_residual():
+                self.cache_manager.set_Bn_encoder_buffer(
                    # None Pattern 3, else 4, 5
                    encoder_hidden_states_residual,
-                    prefix=f"{self.blocks_name}_Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
                # TaylorSeer
-                CachedContext.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                    # None Pattern 3, else 4, 5
                    encoder_hidden_states,
-                    prefix=f"{self.blocks_name}_Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
            torch._dynamo.graph_break()
            # Call last `n` blocks to further process the hidden states
@@ -167,10 +166,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
        *args,
        **kwargs,
    ):
-        assert CachedContext.Fn_compute_blocks() <= len(
+        assert self.cache_manager.Fn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
        encoder_hidden_states = None  # Pattern 3
@@ -242,16 +241,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
        *args,
        **kwargs,
    ):
-        if CachedContext.Bn_compute_blocks() == 0:
+        if self.cache_manager.Bn_compute_blocks() == 0:
            return hidden_states, encoder_hidden_states
 
-        assert CachedContext.Bn_compute_blocks() <= len(
+        assert self.cache_manager.Bn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
-        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
            raise ValueError(
                f"Bn_compute_blocks_ids is not support for "
                f"patterns: {self._supported_patterns}."
@@ -2,7 +2,10 @@ import inspect
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
 
@@ -18,29 +21,34 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
     def __init__(
         self,
+        # 0. Transformer blocks configuration
         transformer_blocks: torch.nn.ModuleList,
-        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
-        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
-        blocks_name: str,
-        # Usually, blocks_name, etc.
-        cache_context: str,
-        *,
        transformer: torch.nn.Module = None,
        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
    ):
        super().__init__()
 
+        # 0. Transformer blocks configuration
        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
-        self.blocks_name = blocks_name
-        self.cache_context = cache_context
        self.forward_pattern = forward_pattern
        self.check_num_outputs = check_num_outputs
+        # 1. Cache context configuration
+        self.cache_prefix = cache_prefix
+        self.cache_context = cache_context
+        self.cache_manager = cache_manager
+
        self._check_forward_pattern()
        logger.info(
            f"Match Cached Blocks: {self.__class__.__name__}, for "
-            f"{self.blocks_name}, context: {self.cache_context}"
+            f"{self.cache_prefix}, cache_context: {self.cache_context}, "
+            f"cache_manager: {self.cache_manager.name}."
        )
 
    def _check_forward_pattern(self):
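The remaining hunks belong to the base class, `CachedBlocks_Pattern_Base`. Its constructor drops the positional `blocks_name`/`cache_context` pair and instead takes the cache wiring as keyword arguments alongside the transformer blocks. A sketch of how call sites change; the parameter names come from the hunk above, the values are placeholders:

```python
# Placeholder values; only the parameter names come from the diff above.
kwargs_0_2_27 = dict(
    blocks_name="transformer_blocks",
    cache_context="transformer_blocks",   # a plain str in 0.2.27
)

kwargs_0_2_29 = dict(
    cache_prefix="transformer_blocks",    # renamed from blocks_name
    cache_context="transformer_blocks",   # now CachedContext | str
    cache_manager=None,                   # a CachedContextManager instance in real use
)
```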
@@ -79,9 +87,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        CachedContext.set_cache_context(
-            self.cache_context,
-        )
+        self.cache_manager.set_context(self.cache_context)
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -96,39 +102,39 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-        CachedContext.mark_step_begin()
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = CachedContext.get_can_use_cache(
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not CachedContext.is_l1_diff_enabled()
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                f"{self.blocks_name}_Fn_residual"
-                if not CachedContext.is_l1_diff_enabled()
-                else f"{self.blocks_name}_Fn_hidden_states"
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            CachedContext.add_cached_step()
+            self.cache_manager.add_cached_step()
            del Fn_hidden_states_residual
            hidden_states, encoder_hidden_states = (
-                CachedContext.apply_hidden_states_residual(
+                self.cache_manager.apply_cache(
                    hidden_states,
                    encoder_hidden_states,
                    prefix=(
-                        f"{self.blocks_name}_Bn_residual"
-                        if CachedContext.is_cache_residual()
-                        else f"{self.blocks_name}_Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                    encoder_prefix=(
-                        f"{self.blocks_name}_Bn_residual"
-                        if CachedContext.is_encoder_cache_residual()
-                        else f"{self.blocks_name}_Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                )
            )
@@ -142,15 +148,15 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                **kwargs,
            )
        else:
-            CachedContext.set_Fn_buffer(
+            self.cache_manager.set_Fn_buffer(
                Fn_hidden_states_residual,
-                prefix=f"{self.blocks_name}_Fn_residual",
+                prefix=f"{self.cache_prefix}_Fn_residual",
            )
-            if CachedContext.is_l1_diff_enabled():
+            if self.cache_manager.is_l1_diff_enabled():
                # for hidden states L1 diff
-                CachedContext.set_Fn_buffer(
+                self.cache_manager.set_Fn_buffer(
                    hidden_states,
-                    f"{self.blocks_name}_Fn_hidden_states",
+                    f"{self.cache_prefix}_Fn_hidden_states",
                )
            del Fn_hidden_states_residual
            torch._dynamo.graph_break()
@@ -166,27 +172,27 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                **kwargs,
            )
            torch._dynamo.graph_break()
-            if CachedContext.is_cache_residual():
-                CachedContext.set_Bn_buffer(
+            if self.cache_manager.is_cache_residual():
+                self.cache_manager.set_Bn_buffer(
                    hidden_states_residual,
-                    prefix=f"{self.blocks_name}_Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
                # TaylorSeer
-                CachedContext.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                    hidden_states,
-                    prefix=f"{self.blocks_name}_Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
-            if CachedContext.is_encoder_cache_residual():
-                CachedContext.set_Bn_encoder_buffer(
+            if self.cache_manager.is_encoder_cache_residual():
+                self.cache_manager.set_Bn_encoder_buffer(
                    encoder_hidden_states_residual,
-                    prefix=f"{self.blocks_name}_Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
                # TaylorSeer
-                CachedContext.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                    encoder_hidden_states,
-                    prefix=f"{self.blocks_name}_Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
            torch._dynamo.graph_break()
            # Call last `n` blocks to further process the hidden states
@@ -232,10 +238,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        # If so, we can skip some Bn blocks and directly
        # use the cached values.
        return (
-            CachedContext.get_current_step() in CachedContext.get_cached_steps()
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cached_steps()
        ) or (
-            CachedContext.get_current_step()
-            in CachedContext.get_cfg_cached_steps()
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cfg_cached_steps()
        )
 
    @torch.compiler.disable
@@ -244,20 +251,20 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        # more stable diff calculation.
        # Fn: [0,...,n-1]
        selected_Fn_blocks = self.transformer_blocks[
-            : CachedContext.Fn_compute_blocks()
+            : self.cache_manager.Fn_compute_blocks()
        ]
        return selected_Fn_blocks
 
    @torch.compiler.disable
    def _Mn_blocks(self):  # middle blocks
        # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if CachedContext.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+        if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
            selected_Mn_blocks = self.transformer_blocks[
-                CachedContext.Fn_compute_blocks() :
+                self.cache_manager.Fn_compute_blocks() :
            ]
        else:
            selected_Mn_blocks = self.transformer_blocks[
-                CachedContext.Fn_compute_blocks() : -CachedContext.Bn_compute_blocks()
+                self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
            ]
        return selected_Mn_blocks
 
@@ -265,7 +272,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
    def _Bn_blocks(self):
        # Bn: transformer_blocks [N-n+1,...,N-1]
        selected_Bn_blocks = self.transformer_blocks[
-            -CachedContext.Bn_compute_blocks() :
+            -self.cache_manager.Bn_compute_blocks() :
        ]
        return selected_Bn_blocks
 
@@ -276,10 +283,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        *args,
        **kwargs,
    ):
-        assert CachedContext.Fn_compute_blocks() <= len(
+        assert self.cache_manager.Fn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
        for block in self._Fn_blocks():
@@ -376,7 +383,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
            )
            # Cache residuals for the non-compute Bn blocks for
            # subsequent cache steps.
-            if block_id not in CachedContext.Bn_compute_blocks_ids():
+            if block_id not in self.cache_manager.Bn_compute_blocks_ids():
                Bn_i_hidden_states_residual = (
                    hidden_states - Bn_i_original_hidden_states
                )
@@ -385,22 +392,22 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                )
 
                # Save original_hidden_states for diff calculation.
-                CachedContext.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                    Bn_i_original_hidden_states,
-                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                )
-                CachedContext.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                    Bn_i_original_encoder_hidden_states,
-                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                )
 
-                CachedContext.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                    Bn_i_hidden_states_residual,
-                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                )
-                CachedContext.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                )
                del Bn_i_hidden_states_residual
                del Bn_i_encoder_hidden_states_residual
@@ -411,7 +418,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        else:
            # Cache steps: Reuse the cached residuals.
            # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in CachedContext.Bn_compute_blocks_ids():
+            if block_id in self.cache_manager.Bn_compute_blocks_ids():
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
@@ -429,25 +436,25 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                # Skip the block if it is not in the Bn_compute_blocks_ids.
                # Use the cached residuals instead.
                # Check if can use the cached residuals.
-                if CachedContext.get_can_use_cache(
+                if self.cache_manager.can_cache(
                    hidden_states,  # curr step
                    parallelized=self._is_parallelized(),
-                    threshold=CachedContext.non_compute_blocks_diff_threshold(),
-                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",  # prev step
+                    threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",  # prev step
                ):
                    hidden_states, encoder_hidden_states = (
-                        CachedContext.apply_hidden_states_residual(
+                        self.cache_manager.apply_cache(
                            hidden_states,
                            encoder_hidden_states,
                            prefix=(
-                                f"{self.blocks_name}_Bn_{block_id}_residual"
-                                if CachedContext.is_cache_residual()
-                                else f"{self.blocks_name}_Bn_{block_id}_original"
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                            ),
                            encoder_prefix=(
-                                f"{self.blocks_name}_Bn_{block_id}_residual"
-                                if CachedContext.is_encoder_cache_residual()
-                                else f"{self.blocks_name}_Bn_{block_id}_original"
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_encoder_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                            ),
                        )
                    )
@@ -474,16 +481,16 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        *args,
        **kwargs,
    ):
-        if CachedContext.Bn_compute_blocks() == 0:
+        if self.cache_manager.Bn_compute_blocks() == 0:
            return hidden_states, encoder_hidden_states
 
-        assert CachedContext.Bn_compute_blocks() <= len(
+        assert self.cache_manager.Bn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
-        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
            for i, block in enumerate(self._Bn_blocks()):
                hidden_states, encoder_hidden_states = (
                    self._compute_or_cache_block(
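The slicing hunks above split `self.transformer_blocks` into the first `Fn_compute_blocks()` blocks, a middle span, and the last `Bn_compute_blocks()` blocks, with a special case because `x[:-0]` is the empty list. A self-contained toy of that partitioning logic (not cache-dit code; a list of ints stands in for the ModuleList, and `fn`/`bn` stand for the two config values):

```python
# Toy partitioning mirroring the Fn/Mn/Bn slicing above.
def split_blocks(blocks, fn, bn):
    first = blocks[:fn]                                   # Fn: [0, ..., n-1]
    middle = blocks[fn:] if bn == 0 else blocks[fn:-bn]   # WARN: x[:-0] == []
    last = blocks[-bn:] if bn > 0 else []                 # Bn: last n blocks
    return first, middle, last

blocks = list(range(10))
assert split_blocks(blocks, 2, 0) == ([0, 1], [2, 3, 4, 5, 6, 7, 8, 9], [])
assert split_blocks(blocks, 2, 3) == ([0, 1], [2, 3, 4, 5, 6], [7, 8, 9])
```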
@@ -2,22 +2,40 @@ import torch
 
 from typing import Any
 from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory import CachedContextManager
 
 
-@torch.compiler.disable
 def patch_cached_stats(
-    module: torch.nn.Module | Any, cache_context: str = None
+    module: torch.nn.Module | Any,
+    cache_context: CachedContext | str = None,
+    cache_manager: CachedContextManager = None,
 ):
     # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if module is None:
+    if module is None or cache_manager is None:
         return
 
     if cache_context is not None:
-        CachedContext.set_cache_context(cache_context)
+        cache_manager.set_context(cache_context)
 
     # TODO: Patch more cached stats to the module
-    module._cached_steps = CachedContext.get_cached_steps()
-    module._residual_diffs = CachedContext.get_residual_diffs()
-    module._cfg_cached_steps = CachedContext.get_cfg_cached_steps()
-    module._cfg_residual_diffs = CachedContext.get_cfg_residual_diffs()
+    module._cached_steps = cache_manager.get_cached_steps()
+    module._residual_diffs = cache_manager.get_residual_diffs()
+    module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+
+
+def remove_cached_stats(
+    module: torch.nn.Module | Any,
+):
+    if module is None:
+        return
+
+    if hasattr(module, "_cached_steps"):
+        del module._cached_steps
+    if hasattr(module, "_residual_diffs"):
+        del module._residual_diffs
+    if hasattr(module, "_cfg_cached_steps"):
+        del module._cfg_cached_steps
+    if hasattr(module, "_cfg_residual_diffs"):
+        del module._cfg_residual_diffs
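`patch_cached_stats` now reads its statistics from an explicit `CachedContextManager` (and returns early when none is given), and the new `remove_cached_stats` deletes the same four attributes again. A self-contained toy of that attribute lifecycle, without importing cache-dit; only the attribute names come from the diff, the stand-in module and values are made up:

```python
from types import SimpleNamespace

module = SimpleNamespace()  # stands in for a transformer module

# What patch_cached_stats(module, ..., cache_manager=...) attaches (toy values):
module._cached_steps = [3, 5, 7]
module._residual_diffs = {"3": 0.01}
module._cfg_cached_steps = []
module._cfg_residual_diffs = {}

# What remove_cached_stats(module) undoes:
for attr in ("_cached_steps", "_residual_diffs",
             "_cfg_cached_steps", "_cfg_residual_diffs"):
    if hasattr(module, attr):
        delattr(module, attr)

assert not hasattr(module, "_cached_steps")
```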
@@ -1,2 +1,5 @@
 # namespace alias: for _CachedContext and many others' cache context funcs.
-import cache_dit.cache_factory.cache_contexts.cache_context as CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
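This last hunk explains why every call site changed: in 0.2.27 the re-exported `CachedContext` name was a module alias, so `CachedContext.set_cache_context(...)` and friends were module-level functions; in 0.2.29 it is the `CachedContext` class itself, and the stateful helpers are reached through a `CachedContextManager` instead. A sketch of the import-level consequence; only the import paths shown in the diff are assumed to exist:

```python
# 0.2.27: `CachedContext` was a *module* alias, enabling module-level calls
# such as CachedContext.set_cache_context(...).
# import cache_dit.cache_factory.cache_contexts.cache_context as CachedContext

# 0.2.29: `CachedContext` is the class, and per-call state lives on
# CachedContextManager instances instead.
from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager,
)
```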