cache-dit 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,14 +3,16 @@ import torch
  import unittest
  import functools

- from typing import Dict
  from contextlib import ExitStack
+ from typing import Dict, List, Tuple, Any
+
  from diffusers import DiffusionPipeline
+
  from cache_dit.cache_factory import CacheType
- from cache_dit.cache_factory import CachedContext
- from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory import BlockAdapter
+ from cache_dit.cache_factory import ParamsModifier
  from cache_dit.cache_factory import BlockAdapterRegistry
+ from cache_dit.cache_factory import CachedContextManager
  from cache_dit.cache_factory import CachedBlocks

  from cache_dit.logger import init_logger
@@ -29,7 +31,6 @@ class CachedAdapter:
  cls,
  pipe: DiffusionPipeline = None,
  block_adapter: BlockAdapter = None,
- # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
  **cache_context_kwargs,
  ) -> DiffusionPipeline:
  assert (
@@ -74,54 +75,67 @@ class CachedAdapter:
  )

  if BlockAdapter.check_block_adapter(block_adapter):
+
+ # 0. Must normalize block_adapter before apply cache
  block_adapter = BlockAdapter.normalize(block_adapter)
- # 0. Apply cache on pipeline: wrap cache context
+ if BlockAdapter.is_cached(block_adapter):
+ return block_adapter.pipe
+
+ # 1. Apply cache on pipeline: wrap cache context, must
+ # call create_context before mock_blocks.
  cls.create_context(
  block_adapter,
  **cache_context_kwargs,
  )
- # 1. Apply cache on transformer: mock cached transformer blocks
+
+ # 2. Apply cache on transformer: mock cached blocks
  cls.mock_blocks(
  block_adapter,
  )
- cls.patch_params(
- block_adapter,
- **cache_context_kwargs,
- )
+
  return block_adapter.pipe

  @classmethod
  def patch_params(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ contexts_kwargs: List[Dict],
  ):
- block_adapter.transformer._forward_pattern = (
- block_adapter.forward_pattern
- )
- block_adapter.transformer._has_separate_cfg = (
- block_adapter.has_separate_cfg
- )
- block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
- block_adapter.pipe.__class__._cache_context_kwargs = (
- cache_context_kwargs
- )
- for blocks, forward_pattern in zip(
- block_adapter.blocks, block_adapter.forward_pattern
- ):
- blocks._forward_pattern = forward_pattern
- blocks._cache_context_kwargs = cache_context_kwargs
+ block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+ params_shift = 0
+ for i in range(len(block_adapter.transformer)):
+
+ block_adapter.transformer[i]._forward_pattern = (
+ block_adapter.forward_pattern
+ )
+ block_adapter.transformer[i]._has_separate_cfg = (
+ block_adapter.has_separate_cfg
+ )
+ block_adapter.transformer[i]._cache_context_kwargs = (
+ contexts_kwargs[params_shift]
+ )
+
+ blocks = block_adapter.blocks[i]
+ for j in range(len(blocks)):
+ blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+ blocks[j]._cache_context_kwargs = contexts_kwargs[
+ params_shift + j
+ ]
+
+ params_shift += len(blocks)

  @classmethod
  def check_context_kwargs(cls, pipe, **cache_context_kwargs):
  # Check cache_context_kwargs
- if not cache_context_kwargs["do_separate_cfg"]:
+ if not cache_context_kwargs["enable_spearate_cfg"]:
  # Check cfg for some specific case if users don't set it as True
- cache_context_kwargs["do_separate_cfg"] = (
+ cache_context_kwargs["enable_spearate_cfg"] = (
  BlockAdapterRegistry.has_separate_cfg(pipe)
  )
  logger.info(
- f"Use default 'do_separate_cfg': {cache_context_kwargs['do_separate_cfg']}, "
+ f"Use default 'enable_spearate_cfg': "
+ f"{cache_context_kwargs['enable_spearate_cfg']}, "
  f"Pipeline: {pipe.__class__.__name__}."
  )
@@ -138,7 +152,10 @@ class CachedAdapter:
  block_adapter: BlockAdapter,
  **cache_context_kwargs,
  ) -> DiffusionPipeline:
- if getattr(block_adapter.pipe, "_is_cached", False):
+
+ BlockAdapter.assert_normalized(block_adapter)
+
+ if BlockAdapter.is_cached(block_adapter.pipe):
  return block_adapter.pipe

  # Check cache_context_kwargs
@@ -147,22 +164,35 @@ class CachedAdapter:
  **cache_context_kwargs,
  )
  # Apply cache on pipeline: wrap cache context
- cache_kwargs, _ = CachedContext.collect_cache_kwargs(
- default_attrs={},
- **cache_context_kwargs,
+ pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+ # Each Pipeline should have it's own context manager instance.
+ # Different transformers (Wan2.2, etc) should shared the same
+ # cache manager but with different cache context (according
+ # to their unique instance id).
+ cache_manager = CachedContextManager(
+ name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+ )
+ block_adapter.pipe._cache_manager = cache_manager # instance level
+
+ flatten_contexts, contexts_kwargs = cls.modify_context_params(
+ block_adapter, cache_manager, **cache_context_kwargs
  )
+
  original_call = block_adapter.pipe.__class__.__call__

  @functools.wraps(original_call)
  def new_call(self, *args, **kwargs):
  with ExitStack() as stack:
- # cache context will reset for each pipe inference
- for blocks_name in block_adapter.blocks_name:
+ # cache context will be reset for each pipe inference
+ for context_name, context_kwargs in zip(
+ flatten_contexts, contexts_kwargs
+ ):
  stack.enter_context(
- CachedContext.cache_context(
- CachedContext.reset_cache_context(
- blocks_name,
- **cache_kwargs,
+ cache_manager.enter_context(
+ cache_manager.reset_context(
+ context_name,
+ **context_kwargs,
  ),
  )
  )
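
Note: the hunk above wraps the pipeline's `__call__` so that one fresh cache context per blocks list is entered for every inference. A minimal sketch of that wrapping pattern follows; `make_cache_context` and `wrap_pipe_call` are hypothetical stand-ins for `cache_manager.enter_context(cache_manager.reset_context(...))`, not part of the cache-dit API:

import functools
from contextlib import ExitStack, contextmanager

@contextmanager
def make_cache_context(name, **context_kwargs):
    # Hypothetical stand-in: the real manager rebuilds the named cache
    # context so every pipeline call starts from a clean state.
    state = {"name": name, **context_kwargs}
    try:
        yield state
    finally:
        state.clear()

def wrap_pipe_call(pipe_cls, flatten_contexts, contexts_kwargs):
    original_call = pipe_cls.__call__

    @functools.wraps(original_call)
    def new_call(self, *args, **kwargs):
        # Enter one fresh context per blocks list; all of them stay
        # active for the duration of a single inference.
        with ExitStack() as stack:
            for context_name, context_kwargs in zip(flatten_contexts, contexts_kwargs):
                stack.enter_context(make_cache_context(context_name, **context_kwargs))
            return original_call(self, *args, **kwargs)

    pipe_cls.__call__ = new_call
    pipe_cls._original_call = original_call
    pipe_cls._is_cached = True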
@@ -171,109 +201,194 @@ class CachedAdapter:
  return outputs

  block_adapter.pipe.__class__.__call__ = new_call
+ block_adapter.pipe.__class__._original_call = original_call
  block_adapter.pipe.__class__._is_cached = True
+
+ cls.patch_params(block_adapter, contexts_kwargs)
+
  return block_adapter.pipe

  @classmethod
- def patch_stats(cls, block_adapter: BlockAdapter):
+ def modify_context_params(
+ cls,
+ block_adapter: BlockAdapter,
+ cache_manager: CachedContextManager,
+ **cache_context_kwargs,
+ ) -> Tuple[List[str], List[Dict[str, Any]]]:
+
+ flatten_contexts = BlockAdapter.flatten(
+ block_adapter.unique_blocks_name
+ )
+ contexts_kwargs = [
+ cache_context_kwargs.copy()
+ for _ in range(
+ len(flatten_contexts),
+ )
+ ]
+
+ for i in range(len(contexts_kwargs)):
+ contexts_kwargs[i]["name"] = flatten_contexts[i]
+
+ if block_adapter.params_modifiers is None:
+ return flatten_contexts, contexts_kwargs
+
+ flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
+ block_adapter.params_modifiers,
+ )
+
+ for i in range(
+ min(len(contexts_kwargs), len(flatten_modifiers)),
+ ):
+ contexts_kwargs[i].update(
+ flatten_modifiers[i]._context_kwargs,
+ )
+ contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
+ default_attrs={}, **contexts_kwargs[i]
+ )
+
+ return flatten_contexts, contexts_kwargs
+
+ @classmethod
+ def patch_stats(
+ cls,
+ block_adapter: BlockAdapter,
+ ):
  from cache_dit.cache_factory.cache_blocks.utils import (
  patch_cached_stats,
  )

- patch_cached_stats(block_adapter.transformer)
- for blocks, blocks_name in zip(
- block_adapter.blocks, block_adapter.blocks_name
- ):
- patch_cached_stats(blocks, blocks_name)
+ cache_manager = block_adapter.pipe._cache_manager
+
+ for i in range(len(block_adapter.transformer)):
+ patch_cached_stats(
+ block_adapter.transformer[i],
+ cache_context=block_adapter.unique_blocks_name[i][-1],
+ cache_manager=cache_manager,
+ )
+ for blocks, unique_name in zip(
+ block_adapter.blocks[i],
+ block_adapter.unique_blocks_name[i],
+ ):
+ patch_cached_stats(
+ blocks,
+ cache_context=unique_name,
+ cache_manager=cache_manager,
+ )

  @classmethod
  def mock_blocks(
  cls,
  block_adapter: BlockAdapter,
- ) -> torch.nn.Module:
+ ) -> List[torch.nn.Module]:
+
+ BlockAdapter.assert_normalized(block_adapter)

- if getattr(block_adapter.transformer, "_is_cached", False):
+ if BlockAdapter.is_cached(block_adapter.transformer):
  return block_adapter.transformer

- # Check block forward pattern matching
- block_adapter = BlockAdapter.normalize(block_adapter)
- for forward_pattern, blocks in zip(
- block_adapter.forward_pattern, block_adapter.blocks
+ # Apply cache on transformer: mock cached transformer blocks
+ for (
+ cached_blocks,
+ transformer,
+ blocks_name,
+ unique_blocks_name,
+ dummy_blocks_names,
+ ) in zip(
+ cls.collect_cached_blocks(block_adapter),
+ block_adapter.transformer,
+ block_adapter.blocks_name,
+ block_adapter.unique_blocks_name,
+ block_adapter.dummy_blocks_names,
  ):
- assert BlockAdapter.match_blocks_pattern(
- blocks,
- forward_pattern=forward_pattern,
- check_num_outputs=block_adapter.check_num_outputs,
- ), (
- "No block forward pattern matched, "
- f"supported lists: {ForwardPattern.supported_patterns()}"
+ cls.mock_transformer(
+ cached_blocks,
+ transformer,
+ blocks_name,
+ unique_blocks_name,
+ dummy_blocks_names,
  )

- # Apply cache on transformer: mock cached transformer blocks
- # TODO: Use blocks_name to spearate cached context for different
- # blocks list. For example, single_transformer_blocks and
- # transformer_blocks should have different cached context and
- # forward pattern.
- cached_blocks = cls.collect_cached_blocks(
- block_adapter=block_adapter,
- )
+ return block_adapter.transformer
+
+ @classmethod
+ def mock_transformer(
+ cls,
+ cached_blocks: Dict[str, torch.nn.ModuleList],
+ transformer: torch.nn.Module,
+ blocks_name: List[str],
+ unique_blocks_name: List[str],
+ dummy_blocks_names: List[str],
+ ) -> torch.nn.Module:
  dummy_blocks = torch.nn.ModuleList()

- original_forward = block_adapter.transformer.forward
+ original_forward = transformer.forward

- assert isinstance(block_adapter.dummy_blocks_names, list)
+ assert isinstance(dummy_blocks_names, list)

  @functools.wraps(original_forward)
  def new_forward(self, *args, **kwargs):
  with ExitStack() as stack:
- for blocks_name in block_adapter.blocks_name:
+ for name, context_name in zip(
+ blocks_name,
+ unique_blocks_name,
+ ):
  stack.enter_context(
  unittest.mock.patch.object(
- self,
- blocks_name,
- cached_blocks[blocks_name],
+ self, name, cached_blocks[context_name]
  )
  )
- for dummy_name in block_adapter.dummy_blocks_names:
+ for dummy_name in dummy_blocks_names:
  stack.enter_context(
  unittest.mock.patch.object(
- self,
- dummy_name,
- dummy_blocks,
+ self, dummy_name, dummy_blocks
  )
  )
  return original_forward(*args, **kwargs)

- block_adapter.transformer.forward = new_forward.__get__(
- block_adapter.transformer
- )
- block_adapter.transformer._is_cached = True
+ transformer.forward = new_forward.__get__(transformer)
+ transformer._original_forward = original_forward
+ transformer._is_cached = True

- return block_adapter.transformer
+ return transformer

  @classmethod
  def collect_cached_blocks(
  cls,
  block_adapter: BlockAdapter,
- ) -> Dict[str, torch.nn.ModuleList]:
- block_adapter = BlockAdapter.normalize(block_adapter)
+ ) -> List[Dict[str, torch.nn.ModuleList]]:
+
+ BlockAdapter.assert_normalized(block_adapter)

- cached_blocks_bind_context = {}
+ total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
+ assert hasattr(block_adapter.pipe, "_cache_manager")
+ assert isinstance(
+ block_adapter.pipe._cache_manager, CachedContextManager
+ )
+
+ for i in range(len(block_adapter.transformer)):

- for i in range(len(block_adapter.blocks)):
- cached_blocks_bind_context[block_adapter.blocks_name[i]] = (
- torch.nn.ModuleList(
+ cached_blocks_bind_context = {}
+ for j in range(len(block_adapter.blocks[i])):
+ cached_blocks_bind_context[
+ block_adapter.unique_blocks_name[i][j]
+ ] = torch.nn.ModuleList(
  [
  CachedBlocks(
- block_adapter.blocks[i],
- block_adapter.blocks_name[i],
- block_adapter.blocks_name[i], # context name
- transformer=block_adapter.transformer,
- forward_pattern=block_adapter.forward_pattern[i],
+ # 0. Transformer blocks configuration
+ block_adapter.blocks[i][j],
+ transformer=block_adapter.transformer[i],
+ forward_pattern=block_adapter.forward_pattern[i][j],
  check_num_outputs=block_adapter.check_num_outputs,
+ # 1. Cache context configuration
+ cache_prefix=block_adapter.blocks_name[i][j],
+ cache_context=block_adapter.unique_blocks_name[i][
+ j
+ ],
+ cache_manager=block_adapter.pipe._cache_manager,
  )
  ]
  )
- )

- return cached_blocks_bind_context
+ total_cached_blocks.append(cached_blocks_bind_context)
+
+ return total_cached_blocks
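
Note: the `mock_transformer` path above relies only on `unittest.mock.patch.object` and `ExitStack` from the standard library. A reduced sketch for a single blocks list is shown below; the function name `mock_single_blocks` is hypothetical, and the real adapter additionally swaps dummy blocks lists:

import functools
import unittest.mock
from contextlib import ExitStack

import torch

def mock_single_blocks(transformer: torch.nn.Module,
                       blocks_name: str,
                       cached_blocks: torch.nn.ModuleList) -> torch.nn.Module:
    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        # The cached blocks replace the original attribute only while
        # forward() runs; the original modules are left untouched.
        with ExitStack() as stack:
            stack.enter_context(
                unittest.mock.patch.object(self, blocks_name, cached_blocks)
            )
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._original_forward = original_forward
    transformer._is_cached = True
    return transformer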
@@ -1,3 +1,11 @@
+ import torch
+
+ from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+ from cache_dit.cache_factory.cache_contexts.cache_manager import (
+ CachedContextManager,
+ )
+
  from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
  CachedBlocks_Pattern_0_1_2,
  )
@@ -5,14 +13,57 @@ from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
  CachedBlocks_Pattern_3_4_5,
  )

+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+

  class CachedBlocks:
- def __new__(cls, *args, **kwargs):
- forward_pattern = kwargs.get("forward_pattern", None)
+ def __new__(
+ cls,
+ # 0. Transformer blocks configuration
+ transformer_blocks: torch.nn.ModuleList,
+ transformer: torch.nn.Module = None,
+ forward_pattern: ForwardPattern = None,
+ check_num_outputs: bool = True,
+ # 1. Cache context configuration
+ # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+ # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+ cache_prefix: str = None, # cache_prefix maybe un-need.
+ # Usually, blocks_name, etc.
+ cache_context: CachedContext | str = None,
+ cache_manager: CachedContextManager = None,
+ **kwargs,
+ ):
+ assert transformer is not None, "transformer can't be None."
  assert forward_pattern is not None, "forward_pattern can't be None."
+ assert cache_context is not None, "cache_context can't be None."
+ assert cache_manager is not None, "cache_manager can't be None."
  if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
- return CachedBlocks_Pattern_0_1_2(*args, **kwargs)
+ return CachedBlocks_Pattern_0_1_2(
+ # 0. Transformer blocks configuration
+ transformer_blocks,
+ transformer=transformer,
+ forward_pattern=forward_pattern,
+ check_num_outputs=check_num_outputs,
+ # 1. Cache context configuration
+ cache_prefix=cache_prefix,
+ cache_context=cache_context,
+ cache_manager=cache_manager,
+ **kwargs,
+ )
  elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
- return CachedBlocks_Pattern_3_4_5(*args, **kwargs)
+ return CachedBlocks_Pattern_3_4_5(
+ # 0. Transformer blocks configuration
+ transformer_blocks,
+ transformer=transformer,
+ forward_pattern=forward_pattern,
+ check_num_outputs=check_num_outputs,
+ # 1. Cache context configuration
+ cache_prefix=cache_prefix,
+ cache_context=cache_context,
+ cache_manager=cache_manager,
+ **kwargs,
+ )
  else:
  raise ValueError(f"Pattern {forward_pattern} is not supported now!")
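
Note: `CachedBlocks.__new__` above acts as a factory that never instantiates the class itself. The standalone sketch below reproduces that dispatch style; `Pattern`, `_PatternABlocks`, `_PatternBBlocks`, and `BlocksFactory` are hypothetical placeholders for `ForwardPattern` and the pattern-specific implementations:

from enum import Enum, auto

class Pattern(Enum):
    # Illustrative stand-ins for cache_dit's ForwardPattern members.
    Pattern_0 = auto()
    Pattern_3 = auto()

class _PatternABlocks:
    _supported_patterns = {Pattern.Pattern_0}

    def __init__(self, transformer_blocks, **kwargs):
        self.transformer_blocks = transformer_blocks

class _PatternBBlocks:
    _supported_patterns = {Pattern.Pattern_3}

    def __init__(self, transformer_blocks, **kwargs):
        self.transformer_blocks = transformer_blocks

class BlocksFactory:
    # __new__ never returns an instance of BlocksFactory itself; it picks
    # the variant whose _supported_patterns contains the requested pattern.
    def __new__(cls, transformer_blocks, forward_pattern=None, **kwargs):
        assert forward_pattern is not None, "forward_pattern can't be None."
        if forward_pattern in _PatternABlocks._supported_patterns:
            return _PatternABlocks(transformer_blocks, **kwargs)
        elif forward_pattern in _PatternBBlocks._supported_patterns:
            return _PatternBBlocks(transformer_blocks, **kwargs)
        raise ValueError(f"Pattern {forward_pattern} is not supported now!")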
@@ -1,6 +1,5 @@
  import torch

- from cache_dit.cache_factory import CachedContext
  from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory.cache_blocks.pattern_base import (
  CachedBlocks_Pattern_Base,
@@ -24,7 +23,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  **kwargs,
  ):
  # Use it's own cache context.
- CachedContext.set_cache_context(
+ self.cache_manager.set_context(
  self.cache_context,
  )
@@ -41,40 +40,40 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  Fn_hidden_states_residual = hidden_states - original_hidden_states
  del original_hidden_states

- CachedContext.mark_step_begin()
+ self.cache_manager.mark_step_begin()
  # Residual L1 diff or Hidden States L1 diff
- can_use_cache = CachedContext.get_can_use_cache(
+ can_use_cache = self.cache_manager.can_cache(
  (
  Fn_hidden_states_residual
- if not CachedContext.is_l1_diff_enabled()
+ if not self.cache_manager.is_l1_diff_enabled()
  else hidden_states
  ),
  parallelized=self._is_parallelized(),
  prefix=(
- f"{self.blocks_name}_Fn_residual"
- if not CachedContext.is_l1_diff_enabled()
- else f"{self.blocks_name}_Fn_hidden_states"
+ f"{self.cache_prefix}_Fn_residual"
+ if not self.cache_manager.is_l1_diff_enabled()
+ else f"{self.cache_prefix}_Fn_hidden_states"
  ),
  )

  torch._dynamo.graph_break()
  if can_use_cache:
- CachedContext.add_cached_step()
+ self.cache_manager.add_cached_step()
  del Fn_hidden_states_residual
  hidden_states, encoder_hidden_states = (
- CachedContext.apply_hidden_states_residual(
+ self.cache_manager.apply_cache(
  hidden_states,
  # None Pattern 3, else 4, 5
  encoder_hidden_states,
  prefix=(
- f"{self.blocks_name}_Bn_residual"
- if CachedContext.is_cache_residual()
- else f"{self.blocks_name}_Bn_hidden_states"
+ f"{self.cache_prefix}_Bn_residual"
+ if self.cache_manager.is_cache_residual()
+ else f"{self.cache_prefix}_Bn_hidden_states"
  ),
  encoder_prefix=(
- f"{self.blocks_name}_Bn_residual"
- if CachedContext.is_encoder_cache_residual()
- else f"{self.blocks_name}_Bn_hidden_states"
+ f"{self.cache_prefix}_Bn_residual"
+ if self.cache_manager.is_encoder_cache_residual()
+ else f"{self.cache_prefix}_Bn_hidden_states"
  ),
  )
  )
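
Note: the decision above compares the current first-n-blocks residual with the one stored on the last fully computed step, and on a hit adds the cached back-n residual instead of running the remaining blocks. The sketch below is a heavily simplified, hypothetical version of that test; the relative-L1 metric and threshold are assumptions, not the exact cache-dit criterion:

import torch

def maybe_apply_cache(hidden_states: torch.Tensor,
                      Fn_residual: torch.Tensor,
                      cached_Fn_residual: torch.Tensor | None,
                      cached_Bn_residual: torch.Tensor | None,
                      threshold: float = 0.08):
    # Returns (new_hidden_states, used_cache). When the Fn residual has
    # barely moved since the last computed step, the remaining blocks are
    # skipped and the cached Bn residual is added back instead.
    if cached_Fn_residual is None or cached_Bn_residual is None:
        return None, False
    diff = (Fn_residual - cached_Fn_residual).abs().mean()
    rel_l1 = diff / (cached_Fn_residual.abs().mean() + 1e-8)
    if rel_l1 < threshold:
        return hidden_states + cached_Bn_residual, True
    return None, False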
@@ -88,15 +87,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  **kwargs,
  )
  else:
- CachedContext.set_Fn_buffer(
+ self.cache_manager.set_Fn_buffer(
  Fn_hidden_states_residual,
- prefix=f"{self.blocks_name}_Fn_residual",
+ prefix=f"{self.cache_prefix}_Fn_residual",
  )
- if CachedContext.is_l1_diff_enabled():
+ if self.cache_manager.is_l1_diff_enabled():
  # for hidden states L1 diff
- CachedContext.set_Fn_buffer(
+ self.cache_manager.set_Fn_buffer(
  hidden_states,
- f"{self.blocks_name}_Fn_hidden_states",
+ f"{self.cache_prefix}_Fn_hidden_states",
  )
  del Fn_hidden_states_residual
  torch._dynamo.graph_break()
@@ -114,29 +113,29 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  **kwargs,
  )
  torch._dynamo.graph_break()
- if CachedContext.is_cache_residual():
- CachedContext.set_Bn_buffer(
+ if self.cache_manager.is_cache_residual():
+ self.cache_manager.set_Bn_buffer(
  hidden_states_residual,
- prefix=f"{self.blocks_name}_Bn_residual",
+ prefix=f"{self.cache_prefix}_Bn_residual",
  )
  else:
  # TaylorSeer
- CachedContext.set_Bn_buffer(
+ self.cache_manager.set_Bn_buffer(
  hidden_states,
- prefix=f"{self.blocks_name}_Bn_hidden_states",
+ prefix=f"{self.cache_prefix}_Bn_hidden_states",
  )
- if CachedContext.is_encoder_cache_residual():
- CachedContext.set_Bn_encoder_buffer(
+ if self.cache_manager.is_encoder_cache_residual():
+ self.cache_manager.set_Bn_encoder_buffer(
  # None Pattern 3, else 4, 5
  encoder_hidden_states_residual,
- prefix=f"{self.blocks_name}_Bn_residual",
+ prefix=f"{self.cache_prefix}_Bn_residual",
  )
  else:
  # TaylorSeer
- CachedContext.set_Bn_encoder_buffer(
+ self.cache_manager.set_Bn_encoder_buffer(
  # None Pattern 3, else 4, 5
  encoder_hidden_states,
- prefix=f"{self.blocks_name}_Bn_hidden_states",
+ prefix=f"{self.cache_prefix}_Bn_hidden_states",
  )
  torch._dynamo.graph_break()
  # Call last `n` blocks to further process the hidden states
@@ -167,10 +166,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  *args,
  **kwargs,
  ):
- assert CachedContext.Fn_compute_blocks() <= len(
+ assert self.cache_manager.Fn_compute_blocks() <= len(
  self.transformer_blocks
  ), (
- f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
+ f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
  f"the number of transformer blocks {len(self.transformer_blocks)}"
  )
  encoder_hidden_states = None # Pattern 3
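
Note: the Fn/Bn assertions exist because the blocks list is conceptually split into a head that is always computed, a middle that can be skipped on a cache hit, and a tail that is always computed. A small sketch of that split follows; the helper name `split_compute_blocks` is hypothetical:

import torch

def split_compute_blocks(transformer_blocks: torch.nn.ModuleList,
                         Fn_compute_blocks: int,
                         Bn_compute_blocks: int):
    # First n blocks always run (they produce the residual used for the
    # cache test); last n blocks always run (they refine cached states);
    # only the middle blocks are candidates for skipping.
    assert Fn_compute_blocks <= len(transformer_blocks)
    assert Bn_compute_blocks <= len(transformer_blocks)
    n = len(transformer_blocks)
    front = transformer_blocks[:Fn_compute_blocks]
    middle = transformer_blocks[Fn_compute_blocks : n - Bn_compute_blocks]
    back = transformer_blocks[n - Bn_compute_blocks :]
    return front, middle, back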
@@ -242,16 +241,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
  *args,
  **kwargs,
  ):
- if CachedContext.Bn_compute_blocks() == 0:
+ if self.cache_manager.Bn_compute_blocks() == 0:
  return hidden_states, encoder_hidden_states

- assert CachedContext.Bn_compute_blocks() <= len(
+ assert self.cache_manager.Bn_compute_blocks() <= len(
  self.transformer_blocks
  ), (
- f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
+ f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
  f"the number of transformer blocks {len(self.transformer_blocks)}"
  )
- if len(CachedContext.Bn_compute_blocks_ids()) > 0:
+ if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
  raise ValueError(
  f"Bn_compute_blocks_ids is not support for "
  f"patterns: {self._supported_patterns}."