cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- cache_dit/__init__.py +2 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -0
- cache_dit/cache_factory/block_adapters/__init__.py +105 -111
- cache_dit/cache_factory/block_adapters/block_adapters.py +314 -41
- cache_dit/cache_factory/block_adapters/block_registers.py +15 -6
- cache_dit/cache_factory/cache_adapters.py +244 -116
- cache_dit/cache_factory/cache_blocks/__init__.py +55 -4
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +36 -37
- cache_dit/cache_factory/cache_blocks/pattern_base.py +83 -76
- cache_dit/cache_factory/cache_blocks/utils.py +26 -8
- cache_dit/cache_factory/cache_contexts/__init__.py +4 -1
- cache_dit/cache_factory/cache_contexts/cache_context.py +14 -876
- cache_dit/cache_factory/cache_contexts/cache_manager.py +847 -0
- cache_dit/cache_factory/cache_interface.py +91 -24
- cache_dit/cache_factory/patch_functors/functor_chroma.py +1 -1
- cache_dit/cache_factory/patch_functors/functor_flux.py +1 -1
- cache_dit/utils.py +164 -58
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/METADATA +59 -34
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/RECORD +24 -24
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/top_level.txt +0 -0
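The headline change is in cache_dit/cache_factory/cache_adapters.py: CachedAdapter.apply now takes a single positional pipe_or_adapter (a DiffusionPipeline or a BlockAdapter) and returns the normalized BlockAdapter instead of the pipeline. A minimal usage sketch against the new signature shown in the hunks below; the checkpoint id is illustrative only:

    from diffusers import DiffusionPipeline

    from cache_dit.cache_factory.cache_adapters import CachedAdapter

    # Illustrative checkpoint; any pipeline registered with
    # BlockAdapterRegistry resolves to its pre-defined BlockAdapter.
    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

    # 0.2.29: one positional argument, dispatched on its type; extra
    # keyword arguments flow through as cache_context_kwargs.
    adapter = CachedAdapter.apply(pipe)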
cache_dit/cache_factory/cache_adapters.py

@@ -3,14 +3,16 @@ import torch
 import unittest
 import functools
 
-from typing import Dict
 from contextlib import ExitStack
+from typing import Dict, List, Tuple, Any
+
 from diffusers import DiffusionPipeline
+
 from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import CachedContext
-from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import BlockAdapterRegistry
+from cache_dit.cache_factory import CachedContextManager
 from cache_dit.cache_factory import CachedBlocks
 
 from cache_dit.logger import init_logger
@@ -27,37 +29,39 @@ class CachedAdapter:
     @classmethod
     def apply(
         cls,
-        pipe: DiffusionPipeline = None,
-        block_adapter: BlockAdapter = None,
-        # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter,
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> BlockAdapter:
         assert (
-            pipe is not None or block_adapter is not None
+            pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
 
-        if pipe is not None:
-            if BlockAdapterRegistry.is_supported(pipe):
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
-                    f"{pipe.__class__.__name__} is officially supported by cache-dit. "
-                    "Use it's pre-defined BlockAdapter directly!"
+                    f"{pipe_or_adapter.__class__.__name__} is officially "
+                    "supported by cache-dit. Use it's pre-defined BlockAdapter "
+                    "directly!"
+                )
+                block_adapter = BlockAdapterRegistry.get_adapter(
+                    pipe_or_adapter
                 )
-                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
                     block_adapter,
                     **cache_context_kwargs,
                 )
             else:
                 raise ValueError(
-                    f"{pipe.__class__.__name__} is not officially supported "
+                    f"{pipe_or_adapter.__class__.__name__} is not officially supported "
                     "by cache-dit, please set BlockAdapter instead!"
                 )
         else:
+            assert isinstance(pipe_or_adapter, BlockAdapter)
             logger.info(
-                "Adapting
+                "Adapting Cache Acceleration using custom BlockAdapter!"
             )
             return cls.cachify(
-                block_adapter,
+                pipe_or_adapter,
                 **cache_context_kwargs,
             )
 
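For pipelines the registry does not recognize, apply raises immediately; the supported path is a custom BlockAdapter, which the else-branch now asserts on. A sketch of that branch; the BlockAdapter field names here are assumptions for illustration, not verified against block_adapters.py:

    from cache_dit.cache_factory import BlockAdapter, ForwardPattern
    from cache_dit.cache_factory.cache_adapters import CachedAdapter

    # `pipe` as in the previous sketch; field names are illustrative.
    adapter = BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        forward_pattern=ForwardPattern.Pattern_0,
    )
    CachedAdapter.apply(adapter)  # takes the custom-BlockAdapter branch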
@@ -66,7 +70,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> BlockAdapter:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
@@ -74,56 +78,80 @@ class CachedAdapter:
         )
 
         if BlockAdapter.check_block_adapter(block_adapter):
+
+            # 0. Must normalize block_adapter before apply cache
             block_adapter = BlockAdapter.normalize(block_adapter)
-
+            if BlockAdapter.is_cached(block_adapter):
+                return block_adapter.pipe
+
+            # 1. Apply cache on pipeline: wrap cache context, must
+            # call create_context before mock_blocks.
             cls.create_context(
                 block_adapter,
                 **cache_context_kwargs,
             )
-
+
+            # 2. Apply cache on transformer: mock cached blocks
             cls.mock_blocks(
                 block_adapter,
             )
-            cls.patch_params(
-                block_adapter,
-                **cache_context_kwargs,
-            )
-            return block_adapter.pipe
+
+            return block_adapter
 
     @classmethod
     def patch_params(
         cls,
         block_adapter: BlockAdapter,
-        **cache_context_kwargs,
+        contexts_kwargs: List[Dict],
     ):
-        block_adapter.
-
-
-        block_adapter.transformer
-
-
-
-
-
-
-
-        block_adapter.
-
-
-
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)
 
     @classmethod
-    def check_context_kwargs(
+    def check_context_kwargs(
+        cls,
+        block_adapter: BlockAdapter,
+        **cache_context_kwargs,
+    ):
         # Check cache_context_kwargs
-        if not cache_context_kwargs["
+        if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-
-
-
-
-
-
-
+            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
+                cache_context_kwargs["enable_spearate_cfg"] = True
+                logger.info(
+                    f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
+            else:
+                cache_context_kwargs["enable_spearate_cfg"] = (
+                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
+                )
+                logger.info(
+                    f"Use default 'enable_spearate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
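patch_params now stamps bookkeeping attributes onto the pipe, each transformer, and each block, with contexts_kwargs aligned to the flattened context names. A read-back sketch; the attribute names come from the hunk above and `pipe` is as in the first sketch:

    adapter = CachedAdapter.apply(pipe)

    print(adapter.pipe._cache_context_kwargs)  # kwargs bound to the first context
    for tf in adapter.transformer:             # normalized to a list in 0.2.29
        print(tf._forward_pattern, tf._has_separate_cfg)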
@@ -138,31 +166,46 @@ class CachedAdapter:
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        if BlockAdapter.is_cached(block_adapter.pipe):
             return block_adapter.pipe
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            block_adapter,
-            **cache_context_kwargs,
+            block_adapter, **cache_context_kwargs
         )
         # Apply cache on pipeline: wrap cache context
-
-
-
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+        # Each Pipeline should have it's own context manager instance.
+        # Different transformers (Wan2.2, etc) should shared the same
+        # cache manager but with different cache context (according
+        # to their unique instance id).
+        cache_manager = CachedContextManager(
+            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
         )
+        block_adapter.pipe._cache_manager = cache_manager  # instance level
+
+        flatten_contexts, contexts_kwargs = cls.modify_context_params(
+            block_adapter, cache_manager, **cache_context_kwargs
+        )
+
         original_call = block_adapter.pipe.__class__.__call__
 
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
             with ExitStack() as stack:
-                # cache context will reset for each pipe inference
-                for
+                # cache context will be reset for each pipe inference
+                for context_name, context_kwargs in zip(
+                    flatten_contexts, contexts_kwargs
+                ):
                     stack.enter_context(
-
-
-
-                                **
+                        cache_manager.enter_context(
+                            cache_manager.reset_context(
+                                context_name,
+                                **context_kwargs,
                             ),
                         )
                     )
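The wrapped __call__ re-enters one cache context per cached blocks group on every pipeline invocation, so cache state resets between inferences. A self-contained illustration of the ExitStack pattern; every name below is a generic stand-in, not a cache-dit API:

    import functools
    from contextlib import ExitStack, contextmanager

    @contextmanager
    def reset_context(name, **kwargs):
        # Stand-in for cache_manager.enter_context(cache_manager.reset_context(...)).
        print("reset cache context:", name, kwargs)
        yield

    class Pipe:
        def __call__(self, *args, **kwargs):
            return "output"

    contexts = [("transformer_blocks_0", {"max_cached_steps": 8})]
    original_call = Pipe.__call__

    @functools.wraps(original_call)
    def new_call(self, *args, **kwargs):
        with ExitStack() as stack:
            # One fresh context per cached blocks group, per call.
            for name, ctx_kwargs in contexts:
                stack.enter_context(reset_context(name, **ctx_kwargs))
            return original_call(self, *args, **kwargs)

    Pipe.__call__ = new_call
    print(Pipe()())  # prints the reset line, then "output"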
@@ -171,109 +214,194 @@ class CachedAdapter:
                 return outputs
 
         block_adapter.pipe.__class__.__call__ = new_call
+        block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True
+
+        cls.patch_params(block_adapter, contexts_kwargs)
+
         return block_adapter.pipe
 
     @classmethod
-    def
+    def modify_context_params(
+        cls,
+        block_adapter: BlockAdapter,
+        cache_manager: CachedContextManager,
+        **cache_context_kwargs,
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:
+
+        flatten_contexts = BlockAdapter.flatten(
+            block_adapter.unique_blocks_name
+        )
+        contexts_kwargs = [
+            cache_context_kwargs.copy()
+            for _ in range(
+                len(flatten_contexts),
+            )
+        ]
+
+        for i in range(len(contexts_kwargs)):
+            contexts_kwargs[i]["name"] = flatten_contexts[i]
+
+        if block_adapter.params_modifiers is None:
+            return flatten_contexts, contexts_kwargs
+
+        flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
+            block_adapter.params_modifiers,
+        )
+
+        for i in range(
+            min(len(contexts_kwargs), len(flatten_modifiers)),
+        ):
+            contexts_kwargs[i].update(
+                flatten_modifiers[i]._context_kwargs,
+            )
+            contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
+                default_attrs={}, **contexts_kwargs[i]
+            )
+
+        return flatten_contexts, contexts_kwargs
+
+    @classmethod
+    def patch_stats(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
         from cache_dit.cache_factory.cache_blocks.utils import (
             patch_cached_stats,
         )
 
-
-
-
-
-
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
 
     @classmethod
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
-    ) -> torch.nn.Module:
+    ) -> List[torch.nn.Module]:
+
+        BlockAdapter.assert_normalized(block_adapter)
 
-        if
+        if BlockAdapter.is_cached(block_adapter.transformer):
             return block_adapter.transformer
 
-        #
-
-
-
+        # Apply cache on transformer: mock cached transformer blocks
+        for (
+            cached_blocks,
+            transformer,
+            blocks_name,
+            unique_blocks_name,
+            dummy_blocks_names,
+        ) in zip(
+            cls.collect_cached_blocks(block_adapter),
+            block_adapter.transformer,
+            block_adapter.blocks_name,
+            block_adapter.unique_blocks_name,
+            block_adapter.dummy_blocks_names,
         ):
-
-
-
-
-
-
-                f"supported lists: {ForwardPattern.supported_patterns()}"
+            cls.mock_transformer(
+                cached_blocks,
+                transformer,
+                blocks_name,
+                unique_blocks_name,
+                dummy_blocks_names,
             )
 
-
-
-
-
-
-        cached_blocks
-
-
+        return block_adapter.transformer
+
+    @classmethod
+    def mock_transformer(
+        cls,
+        cached_blocks: Dict[str, torch.nn.ModuleList],
+        transformer: torch.nn.Module,
+        blocks_name: List[str],
+        unique_blocks_name: List[str],
+        dummy_blocks_names: List[str],
+    ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()
 
-        original_forward =
+        original_forward = transformer.forward
 
-        assert isinstance(
+        assert isinstance(dummy_blocks_names, list)
 
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
-                for
+                for name, context_name in zip(
+                    blocks_name,
+                    unique_blocks_name,
+                ):
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self,
-                            blocks_name,
-                            cached_blocks[blocks_name],
+                            self, name, cached_blocks[context_name]
                         )
                     )
-                for dummy_name in
+                for dummy_name in dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self,
-                            dummy_name,
-                            dummy_blocks,
+                            self, dummy_name, dummy_blocks
                         )
                     )
                 return original_forward(*args, **kwargs)
 
-
-
-
-        block_adapter.transformer._is_cached = True
+        transformer.forward = new_forward.__get__(transformer)
+        transformer._original_forward = original_forward
+        transformer._is_cached = True
 
-        return
+        return transformer
 
     @classmethod
     def collect_cached_blocks(
         cls,
         block_adapter: BlockAdapter,
-    ) -> Dict[str, torch.nn.ModuleList]:
-        block_adapter = BlockAdapter.normalize(block_adapter)
+    ) -> List[Dict[str, torch.nn.ModuleList]]:
 
-
+        BlockAdapter.assert_normalized(block_adapter)
 
-
-
-
+        total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
+        assert hasattr(block_adapter.pipe, "_cache_manager")
+        assert isinstance(
+            block_adapter.pipe._cache_manager, CachedContextManager
+        )
+
+        for i in range(len(block_adapter.transformer)):
+
+            cached_blocks_bind_context = {}
+            for j in range(len(block_adapter.blocks[i])):
+                cached_blocks_bind_context[
+                    block_adapter.unique_blocks_name[i][j]
+                ] = torch.nn.ModuleList(
                     [
                         CachedBlocks(
-
-                            block_adapter.
-                            block_adapter.
-
-                            forward_pattern=block_adapter.forward_pattern[i],
+                            # 0. Transformer blocks configuration
+                            block_adapter.blocks[i][j],
+                            transformer=block_adapter.transformer[i],
+                            forward_pattern=block_adapter.forward_pattern[i][j],
                             check_num_outputs=block_adapter.check_num_outputs,
+                            # 1. Cache context configuration
+                            cache_prefix=block_adapter.blocks_name[i][j],
+                            cache_context=block_adapter.unique_blocks_name[i][
+                                j
+                            ],
+                            cache_manager=block_adapter.pipe._cache_manager,
                         )
                     ]
                 )
-            )
 
-
+            total_cached_blocks.append(cached_blocks_bind_context)
+
+        return total_cached_blocks
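mock_transformer swaps the transformer's real block lists for cached ones only while forward() runs, using unittest.mock.patch.object, and restores them afterwards. A self-contained illustration of the trick with a toy module; nothing below is cache-dit API:

    import unittest.mock
    from contextlib import ExitStack

    import torch

    class ToyTransformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer_blocks = torch.nn.ModuleList([torch.nn.Identity()])

        def forward(self, x):
            for blk in self.transformer_blocks:
                x = blk(x)
            return x

    tf = ToyTransformer()
    cached = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
    original_forward = tf.forward

    def new_forward(self, *args, **kwargs):
        with ExitStack() as stack:
            # The attribute is swapped only inside this call, then restored.
            stack.enter_context(
                unittest.mock.patch.object(self, "transformer_blocks", cached)
            )
            return original_forward(*args, **kwargs)

    tf.forward = new_forward.__get__(tf)  # bind as an instance method
    print(tf(torch.randn(1, 4)).shape)    # torch.Size([1, 4]), via the Linear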
cache_dit/cache_factory/cache_blocks/__init__.py

@@ -1,3 +1,11 @@
+import torch
+
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+
 from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
     CachedBlocks_Pattern_0_1_2,
 )
@@ -5,14 +13,57 @@ from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
     CachedBlocks_Pattern_3_4_5,
 )
 
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class CachedBlocks:
-    def __new__(
-
+    def __new__(
+        cls,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = None,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        cache_prefix: str = None,  # cache_prefix maybe un-need.
+        # Usually, blocks_name, etc.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
+    ):
+        assert transformer is not None, "transformer can't be None."
         assert forward_pattern is not None, "forward_pattern can't be None."
+        assert cache_context is not None, "cache_context can't be None."
+        assert cache_manager is not None, "cache_manager can't be None."
         if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
-            return CachedBlocks_Pattern_0_1_2(
+            return CachedBlocks_Pattern_0_1_2(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
         elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
-            return CachedBlocks_Pattern_3_4_5(
+            return CachedBlocks_Pattern_3_4_5(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
         else:
             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
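CachedBlocks.__new__ now requires the owning transformer plus a cache_context and cache_manager before dispatching on forward_pattern. A direct-construction sketch with a dummy module; real usage passes a DiT's block list, and the string names are illustrative:

    import torch

    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_blocks import CachedBlocks
    from cache_dit.cache_factory.cache_contexts.cache_manager import (
        CachedContextManager,
    )

    class DummyTransformer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.transformer_blocks = torch.nn.ModuleList([torch.nn.Identity()])

    transformer = DummyTransformer()
    manager = CachedContextManager(name="demo_manager")

    cached = CachedBlocks(
        transformer.transformer_blocks,
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_0,
        cache_prefix="transformer_blocks",
        cache_context="transformer_blocks_0",  # context name, illustrative
        cache_manager=manager,
    )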