cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. cache_dit/__init__.py +8 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +17 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +555 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +262 -938
  8. cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
  12. cache_dit/cache_factory/cache_blocks/utils.py +16 -10
  13. cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
  14. cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
  16. cache_dit/cache_factory/cache_interface.py +31 -31
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +26 -26
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
  22. cache_dit-0.2.28.dist-info/RECORD +47 -0
  23. cache_dit/cache_factory/cache_context.py +0 -1155
  24. cache_dit-0.2.26.dist-info/RECORD +0 -42
  25. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
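The headline change in 0.2.28 is a reorganization: the monolithic cache_dit/cache_factory/cache_context.py (removed, -1155 lines) is split into a new cache_contexts package (cache_context.py plus cache_manager.py), a block_adapters package is introduced, taylorseer.py moves under cache_contexts, and the DBCachedBlocks* classes are renamed to CachedBlocks*. A minimal import sketch of the new layout; these paths and symbols are taken directly from the hunks below:

    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_blocks import CachedBlocks
    from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
    from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager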
cache_dit/cache_factory/cache_blocks/__init__.py
@@ -1,20 +1,69 @@
+import torch
+
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+
 from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
-    DBCachedBlocks_Pattern_0_1_2,
+    CachedBlocks_Pattern_0_1_2,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
-    DBCachedBlocks_Pattern_3_4_5,
+    CachedBlocks_Pattern_3_4_5,
 )
 
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
 
-class DBCachedBlocks:
-    def __new__(cls, *args, **kwargs):
-        forward_pattern = kwargs.get("forward_pattern", None)
+class CachedBlocks:
+    def __new__(
+        cls,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = None,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        cache_prefix: str = None,  # cache_prefix may be unneeded.
+        # Usually, blocks_name, etc.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
+    ):
+        assert transformer is not None, "transformer can't be None."
         assert forward_pattern is not None, "forward_pattern can't be None."
-        if forward_pattern in DBCachedBlocks_Pattern_0_1_2._supported_patterns:
-            return DBCachedBlocks_Pattern_0_1_2(*args, **kwargs)
-        elif (
-            forward_pattern in DBCachedBlocks_Pattern_3_4_5._supported_patterns
-        ):
-            return DBCachedBlocks_Pattern_3_4_5(*args, **kwargs)
+        assert cache_context is not None, "cache_context can't be None."
+        assert cache_manager is not None, "cache_manager can't be None."
+        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+            return CachedBlocks_Pattern_0_1_2(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
+        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+            return CachedBlocks_Pattern_3_4_5(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
         else:
             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
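CachedBlocks remains a factory: __new__ dispatches on forward_pattern to the matching pattern class, but the old *args/**kwargs signature is now explicit, and the transformer, cache context, and cache manager are mandatory. A hedged construction sketch under the new API; the stand-in modules and the bare CachedContextManager() constructor are assumptions, only the keyword names come from the hunk above:

    import torch

    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_blocks import CachedBlocks
    from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager

    blocks = torch.nn.ModuleList([torch.nn.Identity() for _ in range(4)])  # stand-in blocks
    transformer = torch.nn.Module()  # stand-in for the host transformer
    manager = CachedContextManager()  # assumed default constructor

    cached_blocks = CachedBlocks(
        blocks,
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_0,
        check_num_outputs=False,  # relaxed here because the blocks are stand-ins
        cache_prefix="transformer_blocks",  # usually the blocks attribute name
        cache_context="transformer_blocks",  # a CachedContext or its registered name
        cache_manager=manager,
    )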
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py
@@ -1,13 +1,13 @@
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-    DBCachedBlocks_Pattern_Base,
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedBlocks_Pattern_0_1_2(DBCachedBlocks_Pattern_Base):
+class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
@@ -1,19 +1,15 @@
 import torch
 
-from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-)
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-    DBCachedBlocks_Pattern_Base,
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
+class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_3,
         ForwardPattern.Pattern_4,
@@ -26,6 +22,11 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
+        # Use its own cache context.
+        self.cache_manager.set_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
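The forward pass now starts by binding this block group's own context on the shared manager, which is what lets several block lists (each with its own cache_prefix and cache_context) coexist on one CachedContextManager. A toy illustration of that switch, assuming the manager keeps a name-to-context registry; ToyContextManager and its internals are illustrative, not the package's code:

    class ToyContextManager:
        # Toy name -> context registry in the spirit of CachedContextManager.
        def __init__(self):
            self._contexts = {}   # registered contexts by name
            self._current = None  # context used by subsequent cache calls

        def register(self, name, context):
            self._contexts[name] = context

        def set_context(self, context):
            # Accept a context object or its registered name, mirroring
            # the `CachedContext | str` annotation in the diff above.
            if isinstance(context, str):
                context = self._contexts[context]
            self._current = context
            return context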
@@ -39,40 +40,40 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-        cache_context.mark_step_begin()
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = cache_context.get_can_use_cache(
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not cache_context.is_l1_diff_enabled()
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "Fn_residual"
-                if not cache_context.is_l1_diff_enabled()
-                else "Fn_hidden_states"
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            cache_context.add_cached_step()
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                cache_context.apply_hidden_states_residual(
+                self.cache_manager.apply_cache(
                     hidden_states,
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
                     prefix=(
-                        "Bn_residual"
-                        if cache_context.is_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "Bn_residual"
-                        if cache_context.is_encoder_cache_residual()
-                        else "Bn_hidden_states"
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
             )
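Two things change in this hunk: get_can_use_cache and apply_hidden_states_residual become can_cache and apply_cache on the manager, and every buffer name is namespaced with self.cache_prefix so the Fn/Bn buffers of different block lists no longer collide in a shared context. The gate itself still compares the first-n-blocks residual against the previous step's; a toy version of such a check, assuming a relative-L1 threshold (the function and the 0.08 default are illustrative, not the package's values):

    import torch

    def toy_can_cache(curr_residual, prev_residual, threshold=0.08):
        # Reuse the cached step when the current Fn residual is close
        # (in relative L1) to the residual buffered on the previous step.
        if prev_residual is None:
            return False  # nothing buffered yet; must compute
        rel_l1 = (curr_residual - prev_residual).abs().mean() / prev_residual.abs().mean()
        return rel_l1.item() < threshold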
@@ -86,12 +87,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
         else:
-            cache_context.set_Fn_buffer(
-                Fn_hidden_states_residual, prefix="Fn_residual"
+            self.cache_manager.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if cache_context.is_l1_diff_enabled():
+            if self.cache_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+                self.cache_manager.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.cache_prefix}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -108,29 +113,29 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if cache_context.is_cache_residual():
-                cache_context.set_Bn_buffer(
+            if self.cache_manager.is_cache_residual():
+                self.cache_manager.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_buffer(
+                self.cache_manager.set_Bn_buffer(
                     hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
-            if cache_context.is_encoder_cache_residual():
-                cache_context.set_Bn_encoder_buffer(
+            if self.cache_manager.is_encoder_cache_residual():
+                self.cache_manager.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states_residual,
-                    prefix="Bn_residual",
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-                cache_context.set_Bn_encoder_buffer(
+                self.cache_manager.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
-                    prefix="Bn_hidden_states",
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
         # Call last `n` blocks to further process the hidden states
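On the recompute path the Bn buffers are written under the same prefix scheme, and the storage mode depends on configuration: residual mode buffers hidden_states_residual, while the TaylorSeer branches buffer the raw hidden states (presumably because the series approximation works on absolute states rather than deltas). A toy of the two storage modes; the dict-based buffers and the function name are stand-ins:

    def store_bn(buffers, prefix, hidden_states, residual, cache_residual):
        if cache_residual:
            buffers[f"{prefix}_Bn_residual"] = residual
        else:
            # TaylorSeer path: keep the raw hidden states.
            buffers[f"{prefix}_Bn_hidden_states"] = hidden_states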
@@ -143,7 +148,6 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
             **kwargs,
         )
 
-        patch_cached_stats(self.transformer)
         torch._dynamo.graph_break()
 
         return (
@@ -162,10 +166,10 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        assert cache_context.Fn_compute_blocks() <= len(
+        assert self.cache_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         encoder_hidden_states = None  # Pattern 3
@@ -237,16 +241,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        if cache_context.Bn_compute_blocks() == 0:
+        if self.cache_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states
 
-        assert cache_context.Bn_compute_blocks() <= len(
+        assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             raise ValueError(
                 f"Bn_compute_blocks_ids is not support for "
                 f"patterns: {self._supported_patterns}."