cache-dit 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +2 -0
- cache_dit/cache_factory/block_adapters/__init__.py +22 -5
- cache_dit/cache_factory/block_adapters/block_adapters.py +230 -25
- cache_dit/cache_factory/cache_adapters.py +209 -94
- cache_dit/cache_factory/cache_blocks/__init__.py +55 -4
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +36 -37
- cache_dit/cache_factory/cache_blocks/pattern_base.py +83 -76
- cache_dit/cache_factory/cache_blocks/utils.py +10 -8
- cache_dit/cache_factory/cache_contexts/__init__.py +4 -1
- cache_dit/cache_factory/cache_contexts/cache_context.py +14 -876
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +10 -13
- cache_dit/utils.py +7 -10
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/METADATA +30 -24
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/RECORD +21 -21
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
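
Release 0.2.28 moves most of the old `cache_context.py` (-876 lines) into a new `cache_contexts/cache_manager.py` (+833 lines): each pipeline now owns a `CachedContextManager`, and every wrapped blocks list gets its own named cache context that is reset on each inference. The sketch below mirrors the `CachedAdapter.create_context` hunk further down; it assumes cache-dit 0.2.28 is installed, and `pipe`, `unique_names`, and `context_kwargs` are illustrative placeholders rather than cache-dit API.

```python
# A minimal sketch, assuming cache-dit 0.2.28 is installed. It mirrors the
# create_context hunk below: one CachedContextManager per pipeline, stored as
# `pipe._cache_manager`, with every named context reset before the call runs.
from contextlib import ExitStack

from cache_dit.cache_factory import CachedContextManager


def call_with_fresh_cache_contexts(pipe, unique_names, context_kwargs, *args, **kwargs):
    # Reuse the pipeline's manager if one was already attached (illustrative).
    manager = getattr(pipe, "_cache_manager", None)
    if manager is None:
        manager = CachedContextManager(
            name=f"{pipe.__class__.__name__}_{hash(id(pipe))}",
        )
        pipe._cache_manager = manager  # instance level, as in the diff

    with ExitStack() as stack:
        for name in unique_names:
            # reset_context(...) rebuilds the named context; enter_context(...)
            # makes it the active one while the pipeline runs.
            stack.enter_context(
                manager.enter_context(manager.reset_context(name, **context_kwargs))
            )
        return pipe(*args, **kwargs)
```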
cache_dit/cache_factory/cache_adapters.py (removed `-` lines are truncated where the upstream diff view cut them off):

```diff
@@ -3,14 +3,16 @@ import torch
 import unittest
 import functools
 
-from typing import Dict
 from contextlib import ExitStack
+from typing import Dict, List, Tuple, Any
+
 from diffusers import DiffusionPipeline
+
 from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import CachedContext
-from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import BlockAdapterRegistry
+from cache_dit.cache_factory import CachedContextManager
 from cache_dit.cache_factory import CachedBlocks
 
 from cache_dit.logger import init_logger
@@ -29,7 +31,6 @@ class CachedAdapter:
         cls,
         pipe: DiffusionPipeline = None,
         block_adapter: BlockAdapter = None,
-        # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
@@ -74,54 +75,67 @@ class CachedAdapter:
         )
 
         if BlockAdapter.check_block_adapter(block_adapter):
+
+            # 0. Must normalize block_adapter before apply cache
             block_adapter = BlockAdapter.normalize(block_adapter)
-
+            if BlockAdapter.is_cached(block_adapter):
+                return block_adapter.pipe
+
+            # 1. Apply cache on pipeline: wrap cache context, must
+            # call create_context before mock_blocks.
             cls.create_context(
                 block_adapter,
                 **cache_context_kwargs,
             )
-
+
+            # 2. Apply cache on transformer: mock cached blocks
             cls.mock_blocks(
                 block_adapter,
             )
-
-                block_adapter,
-                **cache_context_kwargs,
-            )
+
         return block_adapter.pipe
 
     @classmethod
     def patch_params(
         cls,
         block_adapter: BlockAdapter,
-
+        contexts_kwargs: List[Dict],
     ):
-        block_adapter.
-
-
-        block_adapter.transformer
-
-
-
-
-
-
-
-        block_adapter.
-
-
-
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)
 
     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
-        if not cache_context_kwargs["
+        if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["
+            cache_context_kwargs["enable_spearate_cfg"] = (
                 BlockAdapterRegistry.has_separate_cfg(pipe)
             )
             logger.info(
-                f"Use default '
+                f"Use default 'enable_spearate_cfg': "
+                f"{cache_context_kwargs['enable_spearate_cfg']}, "
                 f"Pipeline: {pipe.__class__.__name__}."
            )
 
@@ -138,7 +152,10 @@ class CachedAdapter:
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        if BlockAdapter.is_cached(block_adapter.pipe):
             return block_adapter.pipe
 
         # Check cache_context_kwargs
@@ -147,22 +164,35 @@ class CachedAdapter:
             **cache_context_kwargs,
         )
         # Apply cache on pipeline: wrap cache context
-
-
-
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+        # Each Pipeline should have it's own context manager instance.
+        # Different transformers (Wan2.2, etc) should shared the same
+        # cache manager but with different cache context (according
+        # to their unique instance id).
+        cache_manager = CachedContextManager(
+            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+        )
+        block_adapter.pipe._cache_manager = cache_manager  # instance level
+
+        flatten_contexts, contexts_kwargs = cls.modify_context_params(
+            block_adapter, cache_manager, **cache_context_kwargs
         )
+
         original_call = block_adapter.pipe.__class__.__call__
 
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
             with ExitStack() as stack:
-                # cache context will reset for each pipe inference
-                for
+                # cache context will be reset for each pipe inference
+                for context_name, context_kwargs in zip(
+                    flatten_contexts, contexts_kwargs
+                ):
                     stack.enter_context(
-
-
-
-                            **
+                        cache_manager.enter_context(
+                            cache_manager.reset_context(
+                                context_name,
+                                **context_kwargs,
                             ),
                         )
                     )
@@ -171,109 +201,194 @@ class CachedAdapter:
                 return outputs
 
         block_adapter.pipe.__class__.__call__ = new_call
+        block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True
+
+        cls.patch_params(block_adapter, contexts_kwargs)
+
         return block_adapter.pipe
 
     @classmethod
-    def
+    def modify_context_params(
+        cls,
+        block_adapter: BlockAdapter,
+        cache_manager: CachedContextManager,
+        **cache_context_kwargs,
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:
+
+        flatten_contexts = BlockAdapter.flatten(
+            block_adapter.unique_blocks_name
+        )
+        contexts_kwargs = [
+            cache_context_kwargs.copy()
+            for _ in range(
+                len(flatten_contexts),
+            )
+        ]
+
+        for i in range(len(contexts_kwargs)):
+            contexts_kwargs[i]["name"] = flatten_contexts[i]
+
+        if block_adapter.params_modifiers is None:
+            return flatten_contexts, contexts_kwargs
+
+        flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
+            block_adapter.params_modifiers,
+        )
+
+        for i in range(
+            min(len(contexts_kwargs), len(flatten_modifiers)),
+        ):
+            contexts_kwargs[i].update(
+                flatten_modifiers[i]._context_kwargs,
+            )
+            contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
+                default_attrs={}, **contexts_kwargs[i]
+            )
+
+        return flatten_contexts, contexts_kwargs
+
+    @classmethod
+    def patch_stats(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
         from cache_dit.cache_factory.cache_blocks.utils import (
             patch_cached_stats,
         )
 
-
-
-
-
-
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
 
     @classmethod
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
-    ) -> torch.nn.Module:
+    ) -> List[torch.nn.Module]:
+
+        BlockAdapter.assert_normalized(block_adapter)
 
-        if
+        if BlockAdapter.is_cached(block_adapter.transformer):
             return block_adapter.transformer
 
-        #
-
-
-
+        # Apply cache on transformer: mock cached transformer blocks
+        for (
+            cached_blocks,
+            transformer,
+            blocks_name,
+            unique_blocks_name,
+            dummy_blocks_names,
+        ) in zip(
+            cls.collect_cached_blocks(block_adapter),
+            block_adapter.transformer,
+            block_adapter.blocks_name,
+            block_adapter.unique_blocks_name,
+            block_adapter.dummy_blocks_names,
         ):
-
-
-
-
-
-
-                f"supported lists: {ForwardPattern.supported_patterns()}"
+            cls.mock_transformer(
+                cached_blocks,
+                transformer,
+                blocks_name,
+                unique_blocks_name,
+                dummy_blocks_names,
            )
 
-
-
-
-
-
-        cached_blocks
-
-
+        return block_adapter.transformer
+
+    @classmethod
+    def mock_transformer(
+        cls,
+        cached_blocks: Dict[str, torch.nn.ModuleList],
+        transformer: torch.nn.Module,
+        blocks_name: List[str],
+        unique_blocks_name: List[str],
+        dummy_blocks_names: List[str],
+    ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()
 
-        original_forward =
+        original_forward = transformer.forward
 
-        assert isinstance(
+        assert isinstance(dummy_blocks_names, list)
 
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
-                for
+                for name, context_name in zip(
+                    blocks_name,
+                    unique_blocks_name,
+                ):
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self,
-                            blocks_name,
-                            cached_blocks[blocks_name],
+                            self, name, cached_blocks[context_name]
                        )
                    )
-                for dummy_name in
+                for dummy_name in dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self,
-                            dummy_name,
-                            dummy_blocks,
+                            self, dummy_name, dummy_blocks
                        )
                    )
                 return original_forward(*args, **kwargs)
 
-
-
-
-        block_adapter.transformer._is_cached = True
+        transformer.forward = new_forward.__get__(transformer)
+        transformer._original_forward = original_forward
+        transformer._is_cached = True
 
-        return
+        return transformer
 
     @classmethod
     def collect_cached_blocks(
         cls,
         block_adapter: BlockAdapter,
-    ) -> Dict[str, torch.nn.ModuleList]:
-
+    ) -> List[Dict[str, torch.nn.ModuleList]]:
+
+        BlockAdapter.assert_normalized(block_adapter)
 
-
+        total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
+        assert hasattr(block_adapter.pipe, "_cache_manager")
+        assert isinstance(
+            block_adapter.pipe._cache_manager, CachedContextManager
+        )
+
+        for i in range(len(block_adapter.transformer)):
 
-
-
-
+            cached_blocks_bind_context = {}
+            for j in range(len(block_adapter.blocks[i])):
+                cached_blocks_bind_context[
+                    block_adapter.unique_blocks_name[i][j]
+                ] = torch.nn.ModuleList(
                     [
                         CachedBlocks(
-
-                            block_adapter.
-                            block_adapter.
-
-                            forward_pattern=block_adapter.forward_pattern[i],
+                            # 0. Transformer blocks configuration
+                            block_adapter.blocks[i][j],
+                            transformer=block_adapter.transformer[i],
+                            forward_pattern=block_adapter.forward_pattern[i][j],
                             check_num_outputs=block_adapter.check_num_outputs,
+                            # 1. Cache context configuration
+                            cache_prefix=block_adapter.blocks_name[i][j],
+                            cache_context=block_adapter.unique_blocks_name[i][
+                                j
+                            ],
+                            cache_manager=block_adapter.pipe._cache_manager,
                         )
                     ]
                )
-            )
 
-
+            total_cached_blocks.append(cached_blocks_bind_context)
+
+        return total_cached_blocks
```
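
The new `mock_transformer` helper above swaps a transformer's blocks attribute for the cached `ModuleList` only while `forward` runs, by stacking `unittest.mock.patch.object` context managers inside an `ExitStack` and binding a wrapped forward onto the instance. A self-contained toy version of that trick is sketched below; `ToyTransformer` and `mock_blocks_attr` are illustrative stand-ins, not cache-dit code.

```python
# Toy illustration of the per-call attribute mocking used by mock_transformer.
import functools
import unittest.mock
from contextlib import ExitStack

import torch


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Identity()])

    def forward(self, x):
        for blk in self.transformer_blocks:
            x = blk(x)
        return x


def mock_blocks_attr(transformer, blocks_name, replacement):
    original_forward = transformer.forward  # bound method of this instance

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        with ExitStack() as stack:
            # Temporarily replace the named blocks attribute; it is restored
            # automatically when the ExitStack unwinds.
            stack.enter_context(
                unittest.mock.patch.object(self, blocks_name, replacement)
            )
            return original_forward(*args, **kwargs)

    transformer.forward = new_forward.__get__(transformer)
    transformer._original_forward = original_forward
    transformer._is_cached = True
    return transformer


model = mock_blocks_attr(
    ToyTransformer(), "transformer_blocks", torch.nn.ModuleList([torch.nn.ReLU()])
)
print(model(torch.tensor([-1.0, 2.0])))  # ReLU is applied instead of Identity
```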
cache_dit/cache_factory/cache_blocks/__init__.py:

```diff
@@ -1,3 +1,11 @@
+import torch
+
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+
 from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
     CachedBlocks_Pattern_0_1_2,
 )
@@ -5,14 +13,57 @@ from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
     CachedBlocks_Pattern_3_4_5,
 )
 
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class CachedBlocks:
-    def __new__(
-
+    def __new__(
+        cls,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = None,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        cache_prefix: str = None,  # cache_prefix maybe un-need.
+        # Usually, blocks_name, etc.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
+    ):
+        assert transformer is not None, "transformer can't be None."
         assert forward_pattern is not None, "forward_pattern can't be None."
+        assert cache_context is not None, "cache_context can't be None."
+        assert cache_manager is not None, "cache_manager can't be None."
         if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
-            return CachedBlocks_Pattern_0_1_2(
+            return CachedBlocks_Pattern_0_1_2(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
         elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
-            return CachedBlocks_Pattern_3_4_5(
+            return CachedBlocks_Pattern_3_4_5(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
         else:
             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
```
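
`CachedBlocks` above is a factory: its `__new__` looks up the forward pattern in each concrete class's `_supported_patterns` and returns an instance of the matching implementation, now also requiring the owning transformer, a cache context, and a `CachedContextManager`. A standalone toy illustration of that dispatch style follows; the `Toy*` names are placeholders, not cache-dit classes.

```python
# Toy illustration of a factory-style __new__ that dispatches on a pattern enum.
from enum import Enum, auto


class ToyPattern(Enum):
    Pattern_0 = auto()
    Pattern_3 = auto()


class ToyBlocks_0(list):
    _supported_patterns = (ToyPattern.Pattern_0,)


class ToyBlocks_3(list):
    _supported_patterns = (ToyPattern.Pattern_3,)


class ToyBlocks:
    def __new__(cls, blocks, forward_pattern=None, **kwargs):
        assert forward_pattern is not None, "forward_pattern can't be None."
        if forward_pattern in ToyBlocks_0._supported_patterns:
            return ToyBlocks_0(blocks)
        elif forward_pattern in ToyBlocks_3._supported_patterns:
            return ToyBlocks_3(blocks)
        raise ValueError(f"Pattern {forward_pattern} is not supported now!")


# __new__ returns an instance of the concrete class, not of ToyBlocks itself.
print(type(ToyBlocks(["block"], forward_pattern=ToyPattern.Pattern_3)).__name__)
```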
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py:

```diff
@@ -1,6 +1,5 @@
 import torch
 
-from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -24,7 +23,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-
+        self.cache_manager.set_context(
             self.cache_context,
         )
 
@@ -41,40 +40,40 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                f"{self.
-                if not
-                else f"{self.
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                self.cache_manager.apply_cache(
                     hidden_states,
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
                     prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
             )
@@ -88,15 +87,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
         else:
-
+            self.cache_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if
+            if self.cache_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-
+                self.cache_manager.set_Fn_buffer(
                     hidden_states,
-                    f"{self.
+                    f"{self.cache_prefix}_Fn_hidden_states",
                 )
         del Fn_hidden_states_residual
         torch._dynamo.graph_break()
@@ -114,29 +113,29 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             **kwargs,
         )
         torch._dynamo.graph_break()
-        if
-
+        if self.cache_manager.is_cache_residual():
+            self.cache_manager.set_Bn_buffer(
                 hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_residual",
             )
         else:
             # TaylorSeer
-
+            self.cache_manager.set_Bn_buffer(
                 hidden_states,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_hidden_states",
             )
-        if
-
+        if self.cache_manager.is_encoder_cache_residual():
+            self.cache_manager.set_Bn_encoder_buffer(
                 # None Pattern 3, else 4, 5
                 encoder_hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_residual",
             )
         else:
             # TaylorSeer
-
+            self.cache_manager.set_Bn_encoder_buffer(
                 # None Pattern 3, else 4, 5
                 encoder_hidden_states,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_hidden_states",
             )
         torch._dynamo.graph_break()
         # Call last `n` blocks to further process the hidden states
@@ -167,10 +166,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        assert
+        assert self.cache_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         encoder_hidden_states = None  # Pattern 3
@@ -242,16 +241,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states
 
-        assert
+        assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             raise ValueError(
                 f"Bn_compute_blocks_ids is not support for "
                 f"patterns: {self._supported_patterns}."
```