cache-dit 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic; the original page linked to further details.

Files changed (29)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
  5. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +271 -36
  8. cache_dit/cache_factory/cache_blocks/pattern_base.py +286 -45
  9. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  10. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  11. cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
  12. cache_dit/cache_factory/cache_contexts/cache_context.py +26 -89
  13. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  14. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  15. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  16. cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
  17. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  18. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  19. cache_dit/cache_factory/cache_interface.py +23 -14
  20. cache_dit/cache_factory/cache_types.py +19 -2
  21. cache_dit/cache_factory/params_modifier.py +7 -7
  22. cache_dit/cache_factory/utils.py +38 -27
  23. cache_dit/utils.py +191 -54
  24. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/METADATA +14 -7
  25. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
  26. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
  27. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
@@ -1,28 +1,33 @@
1
1
  import torch
2
2
 
3
3
  from cache_dit.cache_factory import ForwardPattern
4
+ from cache_dit.cache_factory.cache_types import CacheType
4
5
  from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
6
+ from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
5
7
  from cache_dit.cache_factory.cache_contexts.cache_manager import (
6
8
  CachedContextManager,
7
9
  )
10
+ from cache_dit.cache_factory.cache_contexts.prune_manager import (
11
+ PrunedContextManager,
12
+ )
8
13
 
9
14
  from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
10
15
  CachedBlocks_Pattern_0_1_2,
16
+ PrunedBlocks_Pattern_0_1_2,
11
17
  )
12
18
  from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
13
19
  CachedBlocks_Pattern_3_4_5,
20
+ PrunedBlocks_Pattern_3_4_5,
14
21
  )
15
- from cache_dit.cache_factory.cache_blocks.pattern_utils import (
16
- patch_cached_stats,
17
- remove_cached_stats,
18
- )
22
+ from cache_dit.cache_factory.cache_blocks.pattern_utils import apply_stats
23
+ from cache_dit.cache_factory.cache_blocks.pattern_utils import remove_stats
19
24
 
20
25
  from cache_dit.logger import init_logger
21
26
 
22
27
  logger = init_logger(__name__)
23
28
 
24
29
 
25
- class CachedBlocks:
30
+ class UnifiedBlocks:
26
31
  def __new__(
27
32
  cls,
28
33
  # 0. Transformer blocks configuration
@@ -36,16 +41,13 @@ class CachedBlocks:
36
41
  # 'layers', 'single_stream_blocks', 'double_stream_blocks'
37
42
  cache_prefix: str = None, # cache_prefix maybe un-need.
38
43
  # Usually, blocks_name, etc.
39
- cache_context: CachedContext | str = None,
40
- cache_manager: CachedContextManager = None,
44
+ cache_context: CachedContext | PrunedContext | str = None,
45
+ context_manager: CachedContextManager | PrunedContextManager = None,
46
+ cache_type: CacheType = CacheType.DBCache,
41
47
  **kwargs,
42
48
  ):
43
- assert transformer is not None, "transformer can't be None."
44
- assert forward_pattern is not None, "forward_pattern can't be None."
45
- assert cache_context is not None, "cache_context can't be None."
46
- assert cache_manager is not None, "cache_manager can't be None."
47
- if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
48
- return CachedBlocks_Pattern_0_1_2(
49
+ if cache_type == CacheType.DBCache:
50
+ return CachedBlocks(
49
51
  # 0. Transformer blocks configuration
50
52
  transformer_blocks,
51
53
  transformer=transformer,
@@ -55,11 +57,12 @@ class CachedBlocks:
55
57
  # 1. Cache context configuration
56
58
  cache_prefix=cache_prefix,
57
59
  cache_context=cache_context,
58
- cache_manager=cache_manager,
60
+ context_manager=context_manager,
61
+ cache_type=cache_type,
59
62
  **kwargs,
60
63
  )
61
- elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
62
- return CachedBlocks_Pattern_3_4_5(
64
+ elif cache_type == CacheType.DBPrune:
65
+ return PrunedBlocks(
63
66
  # 0. Transformer blocks configuration
64
67
  transformer_blocks,
65
68
  transformer=transformer,
@@ -69,8 +72,155 @@ class CachedBlocks:
69
72
  # 1. Cache context configuration
70
73
  cache_prefix=cache_prefix,
71
74
  cache_context=cache_context,
72
- cache_manager=cache_manager,
75
+ context_manager=context_manager,
76
+ cache_type=cache_type,
73
77
  **kwargs,
74
78
  )
79
+ else:
80
+ raise ValueError(f"Cache type {cache_type} is not supported now!")
81
+
82
+
83
class CachedBlocks:
    """Factory for DBCache block wrappers.

    ``CachedBlocks(...)`` does not produce a ``CachedBlocks`` object: ``__new__``
    picks ``CachedBlocks_Pattern_0_1_2`` or ``CachedBlocks_Pattern_3_4_5``
    according to which one lists ``forward_pattern`` in its
    ``_supported_patterns``, and returns an instance of that class.

    Raises:
        AssertionError: if ``transformer``, ``forward_pattern``,
            ``cache_context`` or ``context_manager`` is None, or if
            ``context_manager`` is not a ``CachedContextManager``.
        ValueError: if ``forward_pattern`` is not supported by either
            implementation, or if ``cache_type`` is not ``CacheType.DBCache``.
    """

    def __new__(
        cls,
        # 0. Transformer blocks configuration
        transformer_blocks: torch.nn.ModuleList,
        transformer: torch.nn.Module = None,
        forward_pattern: ForwardPattern = None,
        check_forward_pattern: bool = True,
        check_num_outputs: bool = True,
        # 1. Cache context configuration
        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
        cache_prefix: str = None,  # cache_prefix maybe un-need.
        # Usually, blocks_name, etc.
        cache_context: CachedContext | PrunedContext | str = None,
        context_manager: CachedContextManager | PrunedContextManager = None,
        cache_type: CacheType = CacheType.DBCache,
        **kwargs,
    ):
        assert transformer is not None, "transformer can't be None."
        assert forward_pattern is not None, "forward_pattern can't be None."
        assert cache_context is not None, "cache_context can't be None."
        assert context_manager is not None, "context_manager can't be None."

        # Resolve the implementation first: an unsupported pattern must be
        # reported before an unsupported cache type.
        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
            impl_cls = CachedBlocks_Pattern_0_1_2
        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
            impl_cls = CachedBlocks_Pattern_3_4_5
        else:
            raise ValueError(f"Pattern {forward_pattern} is not supported now!")

        # Only DBCache is handled by this factory.
        if cache_type != CacheType.DBCache:
            raise ValueError(f"Cache type {cache_type} is not supported now!")

        assert isinstance(
            context_manager, CachedContextManager
        ), "context_manager must be CachedContextManager for DBCache."

        return impl_cls(
            # 0. Transformer blocks configuration
            transformer_blocks,
            transformer=transformer,
            forward_pattern=forward_pattern,
            check_forward_pattern=check_forward_pattern,
            check_num_outputs=check_num_outputs,
            # 1. Cache context configuration
            cache_prefix=cache_prefix,
            cache_context=cache_context,
            context_manager=context_manager,
            cache_type=cache_type,
            **kwargs,
        )
154
+
155
+
156
class PrunedBlocks:
    """Factory for DBPrune block wrappers.

    ``PrunedBlocks(...)`` does not produce a ``PrunedBlocks`` object: ``__new__``
    picks ``PrunedBlocks_Pattern_0_1_2`` or ``PrunedBlocks_Pattern_3_4_5``
    according to which one lists ``forward_pattern`` in its
    ``_supported_patterns``, and returns an instance of that class.

    ``cache_type`` defaults to ``CacheType.DBPrune``. It previously defaulted
    to ``CacheType.DBCache``, which this factory rejects, so every call that
    relied on the default raised ``ValueError``.

    Raises:
        AssertionError: if ``transformer``, ``forward_pattern``,
            ``cache_context`` or ``context_manager`` is None, or if
            ``context_manager`` is not a ``PrunedContextManager``.
        ValueError: if ``forward_pattern`` is not supported by either
            implementation, or if ``cache_type`` is not ``CacheType.DBPrune``.
    """

    def __new__(
        cls,
        # 0. Transformer blocks configuration
        transformer_blocks: torch.nn.ModuleList,
        transformer: torch.nn.Module = None,
        forward_pattern: ForwardPattern = None,
        check_forward_pattern: bool = True,
        check_num_outputs: bool = True,
        # 1. Cache context configuration
        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
        cache_prefix: str = None,  # cache_prefix maybe un-need.
        # Usually, blocks_name, etc.
        cache_context: CachedContext | PrunedContext | str = None,
        context_manager: CachedContextManager | PrunedContextManager = None,
        # DBPrune is the only cache type this factory accepts, so it is the
        # only meaningful default.
        cache_type: CacheType = CacheType.DBPrune,
        **kwargs,
    ):
        assert transformer is not None, "transformer can't be None."
        assert forward_pattern is not None, "forward_pattern can't be None."
        assert cache_context is not None, "cache_context can't be None."
        assert context_manager is not None, "context_manager can't be None."

        # Resolve the implementation first: an unsupported pattern must be
        # reported before an unsupported cache type.
        if forward_pattern in PrunedBlocks_Pattern_0_1_2._supported_patterns:
            impl_cls = PrunedBlocks_Pattern_0_1_2
        elif forward_pattern in PrunedBlocks_Pattern_3_4_5._supported_patterns:
            impl_cls = PrunedBlocks_Pattern_3_4_5
        else:
            raise ValueError(f"Pattern {forward_pattern} is not supported now!")

        if cache_type != CacheType.DBPrune:
            raise ValueError(f"Cache type {cache_type} is not supported now!")

        assert isinstance(
            context_manager, PrunedContextManager
        ), "context_manager must be PrunedContextManager for DBPrune."

        return impl_cls(
            # 0. Transformer blocks configuration
            transformer_blocks,
            transformer=transformer,
            forward_pattern=forward_pattern,
            check_forward_pattern=check_forward_pattern,
            check_num_outputs=check_num_outputs,
            # 1. Cache context configuration
            cache_prefix=cache_prefix,
            cache_context=cache_context,
            context_manager=context_manager,
            cache_type=cache_type,
            **kwargs,
        )
@@ -1,6 +1,7 @@
1
1
  from cache_dit.cache_factory import ForwardPattern
2
2
  from cache_dit.cache_factory.cache_blocks.pattern_base import (
3
3
  CachedBlocks_Pattern_Base,
4
+ PrunedBlocks_Pattern_Base,
4
5
  )
5
6
  from cache_dit.logger import init_logger
6
7
 
@@ -14,3 +15,12 @@ class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
14
15
  ForwardPattern.Pattern_2,
15
16
  ]
16
17
  ...
18
+
19
+
20
class PrunedBlocks_Pattern_0_1_2(PrunedBlocks_Pattern_Base):
    """DBPrune block wrapper for forward patterns 0, 1 and 2.

    All behavior comes from ``PrunedBlocks_Pattern_Base``; this subclass only
    declares which ``ForwardPattern`` values it handles.
    """

    # The ``PrunedBlocks`` factory checks membership in this list to route a
    # forward pattern to this class.
    _supported_patterns = [
        ForwardPattern.Pattern_0,
        ForwardPattern.Pattern_1,
        ForwardPattern.Pattern_2,
    ]