cache-dit 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (28)
  1. cache_dit/__init__.py +7 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +15 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +120 -911
  8. cache_dit/cache_factory/cache_blocks/__init__.py +7 -9
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +46 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +98 -79
  12. cache_dit/cache_factory/cache_blocks/utils.py +13 -9
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +89 -55
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +0 -0
  16. cache_dit/cache_factory/cache_interface.py +21 -18
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +19 -16
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/METADATA +42 -12
  22. cache_dit-0.2.27.dist-info/RECORD +47 -0
  23. cache_dit-0.2.26.dist-info/RECORD +0 -42
  24. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  25. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_blocks/__init__.py

@@ -1,20 +1,18 @@
 from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
-    DBCachedBlocks_Pattern_0_1_2,
+    CachedBlocks_Pattern_0_1_2,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
-    DBCachedBlocks_Pattern_3_4_5,
+    CachedBlocks_Pattern_3_4_5,
 )


-class DBCachedBlocks:
+class CachedBlocks:
     def __new__(cls, *args, **kwargs):
         forward_pattern = kwargs.get("forward_pattern", None)
         assert forward_pattern is not None, "forward_pattern can't be None."
-        if forward_pattern in DBCachedBlocks_Pattern_0_1_2._supported_patterns:
-            return DBCachedBlocks_Pattern_0_1_2(*args, **kwargs)
-        elif (
-            forward_pattern in DBCachedBlocks_Pattern_3_4_5._supported_patterns
-        ):
-            return DBCachedBlocks_Pattern_3_4_5(*args, **kwargs)
+        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+            return CachedBlocks_Pattern_0_1_2(*args, **kwargs)
+        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+            return CachedBlocks_Pattern_3_4_5(*args, **kwargs)
         else:
             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py

@@ -1,13 +1,13 @@
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-    DBCachedBlocks_Pattern_Base,
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-class DBCachedBlocks_Pattern_0_1_2(DBCachedBlocks_Pattern_Base):
+class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

@@ -1,19 +1,16 @@
 import torch

-from cache_dit.cache_factory import cache_context
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-)
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-    DBCachedBlocks_Pattern_Base,
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
+class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_3,
         ForwardPattern.Pattern_4,
@@ -26,6 +23,11 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
+        # Use it's own cache context.
+        CachedContext.set_cache_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -39,40 +41,40 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-        cache_context.mark_step_begin()
+        CachedContext.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = cache_context.get_can_use_cache(
+        can_use_cache = CachedContext.get_can_use_cache(
             (
                 Fn_hidden_states_residual
-                if not cache_context.is_l1_diff_enabled()
+                if not CachedContext.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "Fn_residual"
-                if not cache_context.is_l1_diff_enabled()
-                else "Fn_hidden_states"
+                f"{self.blocks_name}_Fn_residual"
+                if not CachedContext.is_l1_diff_enabled()
+                else f"{self.blocks_name}_Fn_hidden_states"
             ),
         )

         torch._dynamo.graph_break()
         if can_use_cache:
-            cache_context.add_cached_step()
+            CachedContext.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                cache_context.apply_hidden_states_residual(
+                CachedContext.apply_hidden_states_residual(
                     hidden_states,
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
                     prefix=(
-                        "Bn_residual"
-                        if cache_context.is_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "Bn_residual"
-                        if cache_context.is_encoder_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_encoder_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                 )
             )
@@ -86,12 +88,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
         else:
-            cache_context.set_Fn_buffer(
-                Fn_hidden_states_residual, prefix="Fn_residual"
+            CachedContext.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.blocks_name}_Fn_residual",
             )
-            if cache_context.is_l1_diff_enabled():
+            if CachedContext.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+                CachedContext.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.blocks_name}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -108,29 +114,29 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if cache_context.is_cache_residual():
-                cache_context.set_Bn_buffer(
+            if CachedContext.is_cache_residual():
+                CachedContext.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_buffer(
+                CachedContext.set_Bn_buffer(
                     hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
-            if cache_context.is_encoder_cache_residual():
-                cache_context.set_Bn_encoder_buffer(
+            if CachedContext.is_encoder_cache_residual():
+                CachedContext.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_encoder_buffer(
+                CachedContext.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
@@ -143,7 +149,6 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )

-        patch_cached_stats(self.transformer)
         torch._dynamo.graph_break()

         return (
@@ -162,10 +167,10 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        assert cache_context.Fn_compute_blocks() <= len(
+        assert CachedContext.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         encoder_hidden_states = None  # Pattern 3
@@ -237,16 +242,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        if cache_context.Bn_compute_blocks() == 0:
+        if CachedContext.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states

-        assert cache_context.Bn_compute_blocks() <= len(
+        assert CachedContext.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
             raise ValueError(
                 f"Bn_compute_blocks_ids is not support for "
                 f"patterns: {self._supported_patterns}."
cache_dit/cache_factory/cache_blocks/pattern_base.py

@@ -2,17 +2,14 @@ import inspect
 import torch
 import torch.distributed as dist

-from cache_dit.cache_factory import cache_context
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-)
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-class DBCachedBlocks_Pattern_Base(torch.nn.Module):
+class CachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -22,17 +19,29 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     def __init__(
         self,
         transformer_blocks: torch.nn.ModuleList,
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        blocks_name: str,
+        # Usually, blocks_name, etc.
+        cache_context: str,
         *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_num_outputs: bool = True,
     ):
         super().__init__()

         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
+        self.blocks_name = blocks_name
+        self.cache_context = cache_context
         self.forward_pattern = forward_pattern
+        self.check_num_outputs = check_num_outputs
         self._check_forward_pattern()
-        logger.info(f"Match Cached Blocks: {self.__class__.__name__}")
+        logger.info(
+            f"Match Cached Blocks: {self.__class__.__name__}, for "
+            f"{self.blocks_name}, context: {self.cache_context}"
+        )

     def _check_forward_pattern(self):
         assert (
@@ -45,16 +54,18 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             forward_parameters = set(
                 inspect.signature(block.forward).parameters.keys()
             )
-            num_outputs = str(
-                inspect.signature(block.forward).return_annotation
-            ).count("torch.Tensor")
-
-            if num_outputs > 0:
-                assert len(self.forward_pattern.Out) == num_outputs, (
-                    f"The number of block's outputs is {num_outputs} don't not "
-                    f"match the number of the pattern: {self.forward_pattern}, "
-                    f"Out: {len(self.forward_pattern.Out)}."
-                )
+
+            if self.check_num_outputs:
+                num_outputs = str(
+                    inspect.signature(block.forward).return_annotation
+                ).count("torch.Tensor")
+
+                if num_outputs > 0:
+                    assert len(self.forward_pattern.Out) == num_outputs, (
+                        f"The number of block's outputs is {num_outputs} don't not "
+                        f"match the number of the pattern: {self.forward_pattern}, "
+                        f"Out: {len(self.forward_pattern.Out)}."
+                    )

             for required_param in self.forward_pattern.In:
                 assert (
@@ -68,6 +79,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        CachedContext.set_cache_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -81,39 +96,39 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-        cache_context.mark_step_begin()
+        CachedContext.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = cache_context.get_can_use_cache(
+        can_use_cache = CachedContext.get_can_use_cache(
             (
                 Fn_hidden_states_residual
-                if not cache_context.is_l1_diff_enabled()
+                if not CachedContext.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "Fn_residual"
-                if not cache_context.is_l1_diff_enabled()
-                else "Fn_hidden_states"
+                f"{self.blocks_name}_Fn_residual"
+                if not CachedContext.is_l1_diff_enabled()
+                else f"{self.blocks_name}_Fn_hidden_states"
             ),
         )

         torch._dynamo.graph_break()
         if can_use_cache:
-            cache_context.add_cached_step()
+            CachedContext.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                cache_context.apply_hidden_states_residual(
+                CachedContext.apply_hidden_states_residual(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
-                        "Bn_residual"
-                        if cache_context.is_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "Bn_residual"
-                        if cache_context.is_encoder_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_encoder_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                 )
             )
@@ -127,12 +142,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-            cache_context.set_Fn_buffer(
-                Fn_hidden_states_residual, prefix="Fn_residual"
+            CachedContext.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.blocks_name}_Fn_residual",
             )
-            if cache_context.is_l1_diff_enabled():
+            if CachedContext.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+                CachedContext.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.blocks_name}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -147,27 +166,27 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
            )
             torch._dynamo.graph_break()
-            if cache_context.is_cache_residual():
-                cache_context.set_Bn_buffer(
+            if CachedContext.is_cache_residual():
+                CachedContext.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_buffer(
+                CachedContext.set_Bn_buffer(
                     hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
-            if cache_context.is_encoder_cache_residual():
-                cache_context.set_Bn_encoder_buffer(
+            if CachedContext.is_encoder_cache_residual():
+                CachedContext.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_encoder_buffer(
+                CachedContext.set_Bn_encoder_buffer(
                     encoder_hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
@@ -179,7 +198,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )

-        patch_cached_stats(self.transformer)
+        # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()

         return (
@@ -213,10 +232,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-            cache_context.get_current_step() in cache_context.get_cached_steps()
+            CachedContext.get_current_step() in CachedContext.get_cached_steps()
         ) or (
-            cache_context.get_current_step()
-            in cache_context.get_cfg_cached_steps()
+            CachedContext.get_current_step()
+            in CachedContext.get_cfg_cached_steps()
         )

     @torch.compiler.disable
@@ -225,20 +244,20 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            : cache_context.Fn_compute_blocks()
+            : CachedContext.Fn_compute_blocks()
         ]
         return selected_Fn_blocks

     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if cache_context.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+        if CachedContext.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-                cache_context.Fn_compute_blocks() :
+                CachedContext.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-                cache_context.Fn_compute_blocks() : -cache_context.Bn_compute_blocks()
+                CachedContext.Fn_compute_blocks() : -CachedContext.Bn_compute_blocks()
             ]
         return selected_Mn_blocks

@@ -246,7 +265,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -cache_context.Bn_compute_blocks() :
+            -CachedContext.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks

@@ -257,10 +276,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert cache_context.Fn_compute_blocks() <= len(
+        assert CachedContext.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         for block in self._Fn_blocks():
@@ -357,7 +376,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
            )
            # Cache residuals for the non-compute Bn blocks for
            # subsequent cache steps.
-            if block_id not in cache_context.Bn_compute_blocks_ids():
+            if block_id not in CachedContext.Bn_compute_blocks_ids():
                Bn_i_hidden_states_residual = (
                    hidden_states - Bn_i_original_hidden_states
                )
@@ -366,22 +385,22 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                )

                # Save original_hidden_states for diff calculation.
-                cache_context.set_Bn_buffer(
+                CachedContext.set_Bn_buffer(
                    Bn_i_original_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
                )
-                cache_context.set_Bn_encoder_buffer(
+                CachedContext.set_Bn_encoder_buffer(
                    Bn_i_original_encoder_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
                )

-                cache_context.set_Bn_buffer(
+                CachedContext.set_Bn_buffer(
                    Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
                )
-                cache_context.set_Bn_encoder_buffer(
+                CachedContext.set_Bn_encoder_buffer(
                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
                )
                del Bn_i_hidden_states_residual
                del Bn_i_encoder_hidden_states_residual
@@ -392,7 +411,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
        else:
            # Cache steps: Reuse the cached residuals.
            # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in cache_context.Bn_compute_blocks_ids():
+            if block_id in CachedContext.Bn_compute_blocks_ids():
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
@@ -410,25 +429,25 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                # Skip the block if it is not in the Bn_compute_blocks_ids.
                # Use the cached residuals instead.
                # Check if can use the cached residuals.
-                if cache_context.get_can_use_cache(
+                if CachedContext.get_can_use_cache(
                    hidden_states,  # curr step
                    parallelized=self._is_parallelized(),
-                    threshold=cache_context.non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{block_id}_original",  # prev step
+                    threshold=CachedContext.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",  # prev step
                ):
                    hidden_states, encoder_hidden_states = (
-                        cache_context.apply_hidden_states_residual(
+                        CachedContext.apply_hidden_states_residual(
                            hidden_states,
                            encoder_hidden_states,
                            prefix=(
-                                f"Bn_{block_id}_residual"
-                                if cache_context.is_cache_residual()
-                                else f"Bn_{block_id}_original"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                            ),
                            encoder_prefix=(
-                                f"Bn_{block_id}_residual"
-                                if cache_context.is_encoder_cache_residual()
-                                else f"Bn_{block_id}_original"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_encoder_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                            ),
                        )
                    )
@@ -455,16 +474,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if cache_context.Bn_compute_blocks() == 0:
+        if CachedContext.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states

-        assert cache_context.Bn_compute_blocks() <= len(
+        assert CachedContext.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
             for i, block in enumerate(self._Bn_blocks()):
                 hidden_states, encoder_hidden_states = (
                     self._compute_or_cache_block(
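The recurring change in pattern_base.py (and pattern_3_4_5.py) is that every cached buffer key is now namespaced by `blocks_name`, and each instance activates its own named cache context at the start of forward. A tiny illustration of the key scheme visible in the hunks above (plain Python, no cache-dit imports), showing why two block lists cached in the same transformer no longer collide; the helper name `buffer_key` is hypothetical and exists only for this example:

# Illustration of the "{blocks_name}_{kind}" prefix scheme shown above.
def buffer_key(blocks_name: str, kind: str) -> str:
    return f"{blocks_name}_{kind}"

# Distinct block lists now write to distinct cache keys:
assert buffer_key("transformer_blocks", "Fn_residual") != buffer_key(
    "single_transformer_blocks", "Fn_residual"
)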
cache_dit/cache_factory/cache_blocks/utils.py

@@ -1,19 +1,23 @@
 import torch

-from cache_dit.cache_factory import cache_context
+from typing import Any
+from cache_dit.cache_factory import CachedContext


 @torch.compiler.disable
 def patch_cached_stats(
-    transformer,
+    module: torch.nn.Module | Any, cache_context: str = None
 ):
-    # Patch the cached stats to the transformer, the cached stats
+    # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
+    if module is None:
         return

-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = cache_context.get_cached_steps()
-    transformer._residual_diffs = cache_context.get_residual_diffs()
-    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
+    if cache_context is not None:
+        CachedContext.set_cache_context(cache_context)
+
+    # TODO: Patch more cached stats to the module
+    module._cached_steps = CachedContext.get_cached_steps()
+    module._residual_diffs = CachedContext.get_residual_diffs()
+    module._cfg_cached_steps = CachedContext.get_cfg_cached_steps()
+    module._cfg_residual_diffs = CachedContext.get_cfg_residual_diffs()
cache_dit/cache_factory/cache_contexts/__init__.py

@@ -0,0 +1,2 @@
+# namespace alias: for _CachedContext and many others' cache context funcs.
+import cache_dit.cache_factory.cache_contexts.cache_context as CachedContext
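The new cache_contexts/__init__.py re-exports the cache_context module under the CachedContext alias used throughout the rewritten blocks. A hedged sketch of how the named-context functions seen above could be driven, using only functions that appear in this diff (`set_cache_context`, `get_cached_steps`, `get_residual_diffs`); it assumes a context named `"cache_ctx_0"` has already been created by cache-dit when caching was enabled on a pipeline:

# Hedged sketch, assuming a cache context named "cache_ctx_0" was already
# created by cache-dit when caching was enabled on a pipeline.
import cache_dit.cache_factory.cache_contexts.cache_context as CachedContext

CachedContext.set_cache_context("cache_ctx_0")  # activate the named context
print(CachedContext.get_cached_steps())         # steps that reused the cache
print(CachedContext.get_residual_diffs())       # recorded per-step residual diffs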