cache-dit 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/block_adapters/__init__.py +4 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +126 -80
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
- cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +118 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +22 -0
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +63 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +20 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +18 -7
- cache_dit/quantize/quantize_ao.py +58 -17
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/METADATA +11 -10
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/RECORD +32 -27
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -19,6 +19,8 @@ from cache_dit.cache_factory import ParamsModifier
|
|
|
19
19
|
from cache_dit.cache_factory import ForwardPattern
|
|
20
20
|
from cache_dit.cache_factory import PatchFunctor
|
|
21
21
|
from cache_dit.cache_factory import BasicCacheConfig
|
|
22
|
+
from cache_dit.cache_factory import DBCacheConfig
|
|
23
|
+
from cache_dit.cache_factory import DBPruneConfig
|
|
22
24
|
from cache_dit.cache_factory import CalibratorConfig
|
|
23
25
|
from cache_dit.cache_factory import TaylorSeerCalibratorConfig
|
|
24
26
|
from cache_dit.cache_factory import FoCaCalibratorConfig
|
|
@@ -30,6 +32,7 @@ from cache_dit.quantize import quantize
|
|
|
30
32
|
|
|
31
33
|
NONE = CacheType.NONE
|
|
32
34
|
DBCache = CacheType.DBCache
|
|
35
|
+
DBPrune = CacheType.DBPrune
|
|
33
36
|
|
|
34
37
|
Pattern_0 = ForwardPattern.Pattern_0
|
|
35
38
|
Pattern_1 = ForwardPattern.Pattern_1
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '1.0.3'
|
|
32
|
-
__version_tuple__ = version_tuple = (1, 0, 3)
|
|
31
|
+
__version__ = version = '1.0.5'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 0, 5)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -9,14 +9,21 @@ from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
|
9
9
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
10
10
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
11
11
|
|
|
12
|
-
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
13
12
|
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
13
|
+
from cache_dit.cache_factory.cache_contexts import DBCacheConfig
|
|
14
|
+
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
14
15
|
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
16
|
+
from cache_dit.cache_factory.cache_contexts import DBPruneConfig
|
|
17
|
+
from cache_dit.cache_factory.cache_contexts import PrunedContext
|
|
18
|
+
from cache_dit.cache_factory.cache_contexts import PrunedContextManager
|
|
19
|
+
from cache_dit.cache_factory.cache_contexts import ContextManager
|
|
15
20
|
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
16
21
|
from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
|
|
17
22
|
from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
|
|
18
23
|
|
|
19
24
|
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
25
|
+
from cache_dit.cache_factory.cache_blocks import PrunedBlocks
|
|
26
|
+
from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
|
|
20
27
|
|
|
21
28
|
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
22
29
|
|
|
@@ -12,7 +12,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
12
12
|
from cache_dit.utils import is_diffusers_at_least_0_3_5
|
|
13
13
|
|
|
14
14
|
assert isinstance(pipe.transformer, FluxTransformer2DModel)
|
|
15
|
-
|
|
15
|
+
transformer_cls_name: str = pipe.transformer.__class__.__name__
|
|
16
|
+
if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
|
|
17
|
+
"Nunchaku"
|
|
18
|
+
):
|
|
16
19
|
return BlockAdapter(
|
|
17
20
|
pipe=pipe,
|
|
18
21
|
transformer=pipe.transformer,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import copy
|
|
1
2
|
import torch
|
|
2
3
|
import unittest
|
|
3
4
|
import functools
|
|
@@ -10,10 +11,10 @@ from cache_dit.cache_factory.cache_types import CacheType
|
|
|
10
11
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
11
12
|
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
12
13
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
13
|
-
from cache_dit.cache_factory.cache_contexts import
|
|
14
|
+
from cache_dit.cache_factory.cache_contexts import ContextManager
|
|
14
15
|
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
15
16
|
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
16
|
-
from cache_dit.cache_factory.cache_blocks import
|
|
17
|
+
from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
|
|
17
18
|
from cache_dit.logger import init_logger
|
|
18
19
|
|
|
19
20
|
logger = init_logger(__name__)
|
|
@@ -32,7 +33,7 @@ class CachedAdapter:
|
|
|
32
33
|
DiffusionPipeline,
|
|
33
34
|
BlockAdapter,
|
|
34
35
|
],
|
|
35
|
-
**
|
|
36
|
+
**context_kwargs,
|
|
36
37
|
) -> Union[
|
|
37
38
|
DiffusionPipeline,
|
|
38
39
|
BlockAdapter,
|
|
@@ -51,7 +52,7 @@ class CachedAdapter:
|
|
|
51
52
|
block_adapter = BlockAdapterRegistry.get_adapter(
|
|
52
53
|
pipe_or_adapter
|
|
53
54
|
)
|
|
54
|
-
if params_modifiers :=
|
|
55
|
+
if params_modifiers := context_kwargs.pop(
|
|
55
56
|
"params_modifiers",
|
|
56
57
|
None,
|
|
57
58
|
):
|
|
@@ -59,7 +60,7 @@ class CachedAdapter:
|
|
|
59
60
|
|
|
60
61
|
return cls.cachify(
|
|
61
62
|
block_adapter,
|
|
62
|
-
**
|
|
63
|
+
**context_kwargs,
|
|
63
64
|
).pipe
|
|
64
65
|
else:
|
|
65
66
|
raise ValueError(
|
|
@@ -72,21 +73,21 @@ class CachedAdapter:
|
|
|
72
73
|
"Adapting Cache Acceleration using custom BlockAdapter!"
|
|
73
74
|
)
|
|
74
75
|
if pipe_or_adapter.params_modifiers is None:
|
|
75
|
-
if params_modifiers :=
|
|
76
|
+
if params_modifiers := context_kwargs.pop(
|
|
76
77
|
"params_modifiers", None
|
|
77
78
|
):
|
|
78
79
|
pipe_or_adapter.params_modifiers = params_modifiers
|
|
79
80
|
|
|
80
81
|
return cls.cachify(
|
|
81
82
|
pipe_or_adapter,
|
|
82
|
-
**
|
|
83
|
+
**context_kwargs,
|
|
83
84
|
)
|
|
84
85
|
|
|
85
86
|
@classmethod
|
|
86
87
|
def cachify(
|
|
87
88
|
cls,
|
|
88
89
|
block_adapter: BlockAdapter,
|
|
89
|
-
**
|
|
90
|
+
**context_kwargs,
|
|
90
91
|
) -> BlockAdapter:
|
|
91
92
|
|
|
92
93
|
if block_adapter.auto:
|
|
@@ -103,14 +104,15 @@ class CachedAdapter:
|
|
|
103
104
|
|
|
104
105
|
# 1. Apply cache on pipeline: wrap cache context, must
|
|
105
106
|
# call create_context before mock_blocks.
|
|
106
|
-
cls.create_context(
|
|
107
|
+
_, contexts_kwargs = cls.create_context(
|
|
107
108
|
block_adapter,
|
|
108
|
-
**
|
|
109
|
+
**context_kwargs,
|
|
109
110
|
)
|
|
110
111
|
|
|
111
112
|
# 2. Apply cache on transformer: mock cached blocks
|
|
112
113
|
cls.mock_blocks(
|
|
113
114
|
block_adapter,
|
|
115
|
+
contexts_kwargs,
|
|
114
116
|
)
|
|
115
117
|
|
|
116
118
|
return block_adapter
|
|
@@ -119,12 +121,10 @@ class CachedAdapter:
|
|
|
119
121
|
def check_context_kwargs(
|
|
120
122
|
cls,
|
|
121
123
|
block_adapter: BlockAdapter,
|
|
122
|
-
**
|
|
124
|
+
**context_kwargs,
|
|
123
125
|
):
|
|
124
|
-
# Check
|
|
125
|
-
cache_config: BasicCacheConfig =
|
|
126
|
-
"cache_config"
|
|
127
|
-
] # ref
|
|
126
|
+
# Check context_kwargs
|
|
127
|
+
cache_config: BasicCacheConfig = context_kwargs["cache_config"] # ref
|
|
128
128
|
assert cache_config is not None, "cache_config can not be None."
|
|
129
129
|
if cache_config.enable_separate_cfg is None:
|
|
130
130
|
# Check cfg for some specific case if users don't set it as True
|
|
@@ -150,19 +150,23 @@ class CachedAdapter:
|
|
|
150
150
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
151
151
|
)
|
|
152
152
|
|
|
153
|
-
cache_type =
|
|
153
|
+
cache_type = context_kwargs.pop("cache_type", None)
|
|
154
154
|
if cache_type is not None:
|
|
155
|
-
assert (
|
|
156
|
-
cache_type
|
|
157
|
-
), "
|
|
155
|
+
assert isinstance(
|
|
156
|
+
cache_type, CacheType
|
|
157
|
+
), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
|
|
158
|
+
assert cache_type == cache_config.cache_type, (
|
|
159
|
+
f"cache_type from context_kwargs ({cache_type}) must be the same "
|
|
160
|
+
f"as that from cache_config ({cache_config.cache_type})."
|
|
161
|
+
)
|
|
158
162
|
|
|
159
|
-
return
|
|
163
|
+
return context_kwargs
|
|
160
164
|
|
|
161
165
|
@classmethod
|
|
162
166
|
def create_context(
|
|
163
167
|
cls,
|
|
164
168
|
block_adapter: BlockAdapter,
|
|
165
|
-
**
|
|
169
|
+
**context_kwargs,
|
|
166
170
|
) -> Tuple[List[str], List[Dict[str, Any]]]:
|
|
167
171
|
|
|
168
172
|
BlockAdapter.assert_normalized(block_adapter)
|
|
@@ -170,9 +174,9 @@ class CachedAdapter:
|
|
|
170
174
|
if BlockAdapter.is_cached(block_adapter.pipe):
|
|
171
175
|
return block_adapter.pipe
|
|
172
176
|
|
|
173
|
-
# Check
|
|
174
|
-
|
|
175
|
-
block_adapter, **
|
|
177
|
+
# Check context_kwargs
|
|
178
|
+
context_kwargs = cls.check_context_kwargs(
|
|
179
|
+
block_adapter, **context_kwargs
|
|
176
180
|
)
|
|
177
181
|
# Apply cache on pipeline: wrap cache context
|
|
178
182
|
pipe_cls_name = block_adapter.pipe.__class__.__name__
|
|
@@ -181,15 +185,19 @@ class CachedAdapter:
|
|
|
181
185
|
# Different transformers (Wan2.2, etc) should shared the same
|
|
182
186
|
# cache manager but with different cache context (according
|
|
183
187
|
# to their unique instance id).
|
|
184
|
-
|
|
188
|
+
cache_config: BasicCacheConfig = context_kwargs.get(
|
|
189
|
+
"cache_config", None
|
|
190
|
+
)
|
|
191
|
+
assert cache_config is not None, "cache_config can not be None."
|
|
192
|
+
context_manager = ContextManager(
|
|
185
193
|
name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
|
|
194
|
+
cache_type=cache_config.cache_type,
|
|
186
195
|
)
|
|
187
|
-
block_adapter.pipe.
|
|
196
|
+
block_adapter.pipe._context_manager = context_manager # instance level
|
|
188
197
|
|
|
189
198
|
flatten_contexts, contexts_kwargs = cls.modify_context_params(
|
|
190
|
-
block_adapter, **
|
|
199
|
+
block_adapter, **context_kwargs
|
|
191
200
|
)
|
|
192
|
-
|
|
193
201
|
original_call = block_adapter.pipe.__class__.__call__
|
|
194
202
|
|
|
195
203
|
@functools.wraps(original_call)
|
|
@@ -200,8 +208,8 @@ class CachedAdapter:
|
|
|
200
208
|
flatten_contexts, contexts_kwargs
|
|
201
209
|
):
|
|
202
210
|
stack.enter_context(
|
|
203
|
-
|
|
204
|
-
|
|
211
|
+
context_manager.enter_context(
|
|
212
|
+
context_manager.reset_context(
|
|
205
213
|
context_name,
|
|
206
214
|
**context_kwargs,
|
|
207
215
|
),
|
|
@@ -223,14 +231,14 @@ class CachedAdapter:
|
|
|
223
231
|
def modify_context_params(
|
|
224
232
|
cls,
|
|
225
233
|
block_adapter: BlockAdapter,
|
|
226
|
-
**
|
|
234
|
+
**context_kwargs,
|
|
227
235
|
) -> Tuple[List[str], List[Dict[str, Any]]]:
|
|
228
236
|
|
|
229
237
|
flatten_contexts = BlockAdapter.flatten(
|
|
230
238
|
block_adapter.unique_blocks_name
|
|
231
239
|
)
|
|
232
240
|
contexts_kwargs = [
|
|
233
|
-
|
|
241
|
+
copy.deepcopy(context_kwargs) # must deep copy
|
|
234
242
|
for _ in range(
|
|
235
243
|
len(flatten_contexts),
|
|
236
244
|
)
|
|
@@ -251,9 +259,41 @@ class CachedAdapter:
|
|
|
251
259
|
for i in range(
|
|
252
260
|
min(len(contexts_kwargs), len(flatten_modifiers)),
|
|
253
261
|
):
|
|
254
|
-
|
|
255
|
-
flatten_modifiers[
|
|
256
|
-
|
|
262
|
+
if "cache_config" in flatten_modifiers[i]._context_kwargs:
|
|
263
|
+
modifier_cache_config = flatten_modifiers[
|
|
264
|
+
i
|
|
265
|
+
]._context_kwargs.get("cache_config", None)
|
|
266
|
+
modifier_calibrator_config = flatten_modifiers[
|
|
267
|
+
i
|
|
268
|
+
]._context_kwargs.get("calibrator_config", None)
|
|
269
|
+
if modifier_cache_config is not None:
|
|
270
|
+
assert isinstance(
|
|
271
|
+
modifier_cache_config, BasicCacheConfig
|
|
272
|
+
), (
|
|
273
|
+
f"cache_config must be BasicCacheConfig, but got "
|
|
274
|
+
f"{type(modifier_cache_config)}."
|
|
275
|
+
)
|
|
276
|
+
contexts_kwargs[i]["cache_config"].update(
|
|
277
|
+
**modifier_cache_config.as_dict()
|
|
278
|
+
)
|
|
279
|
+
if modifier_calibrator_config is not None:
|
|
280
|
+
assert isinstance(
|
|
281
|
+
modifier_calibrator_config, CalibratorConfig
|
|
282
|
+
), (
|
|
283
|
+
f"calibrator_config must be CalibratorConfig, but got "
|
|
284
|
+
f"{type(modifier_calibrator_config)}."
|
|
285
|
+
)
|
|
286
|
+
if (
|
|
287
|
+
contexts_kwargs[i].get("calibrator_config", None)
|
|
288
|
+
is None
|
|
289
|
+
):
|
|
290
|
+
contexts_kwargs[i][
|
|
291
|
+
"calibrator_config"
|
|
292
|
+
] = modifier_calibrator_config
|
|
293
|
+
else:
|
|
294
|
+
contexts_kwargs[i]["calibrator_config"].update(
|
|
295
|
+
**modifier_calibrator_config.as_dict()
|
|
296
|
+
)
|
|
257
297
|
cls._config_messages(**contexts_kwargs[i])
|
|
258
298
|
|
|
259
299
|
return flatten_contexts, contexts_kwargs
|
|
@@ -267,7 +307,7 @@ class CachedAdapter:
|
|
|
267
307
|
"calibrator_config", None
|
|
268
308
|
)
|
|
269
309
|
if cache_config is not None:
|
|
270
|
-
message = f"Collected
|
|
310
|
+
message = f"Collected Context Config: {cache_config.strify()}"
|
|
271
311
|
if calibrator_config is not None:
|
|
272
312
|
message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
|
|
273
313
|
else:
|
|
@@ -278,6 +318,7 @@ class CachedAdapter:
|
|
|
278
318
|
def mock_blocks(
|
|
279
319
|
cls,
|
|
280
320
|
block_adapter: BlockAdapter,
|
|
321
|
+
contexts_kwargs: List[Dict],
|
|
281
322
|
) -> List[torch.nn.Module]:
|
|
282
323
|
|
|
283
324
|
BlockAdapter.assert_normalized(block_adapter)
|
|
@@ -287,20 +328,23 @@ class CachedAdapter:
|
|
|
287
328
|
|
|
288
329
|
# Apply cache on transformer: mock cached transformer blocks
|
|
289
330
|
for (
|
|
290
|
-
|
|
331
|
+
unified_blocks,
|
|
291
332
|
transformer,
|
|
292
333
|
blocks_name,
|
|
293
334
|
unique_blocks_name,
|
|
294
335
|
dummy_blocks_names,
|
|
295
336
|
) in zip(
|
|
296
|
-
cls.
|
|
337
|
+
cls.collect_unified_blocks(
|
|
338
|
+
block_adapter,
|
|
339
|
+
contexts_kwargs,
|
|
340
|
+
),
|
|
297
341
|
block_adapter.transformer,
|
|
298
342
|
block_adapter.blocks_name,
|
|
299
343
|
block_adapter.unique_blocks_name,
|
|
300
344
|
block_adapter.dummy_blocks_names,
|
|
301
345
|
):
|
|
302
346
|
cls.mock_transformer(
|
|
303
|
-
|
|
347
|
+
unified_blocks,
|
|
304
348
|
transformer,
|
|
305
349
|
blocks_name,
|
|
306
350
|
unique_blocks_name,
|
|
@@ -312,7 +356,7 @@ class CachedAdapter:
|
|
|
312
356
|
@classmethod
|
|
313
357
|
def mock_transformer(
|
|
314
358
|
cls,
|
|
315
|
-
|
|
359
|
+
unified_blocks: Dict[str, torch.nn.ModuleList],
|
|
316
360
|
transformer: torch.nn.Module,
|
|
317
361
|
blocks_name: List[str],
|
|
318
362
|
unique_blocks_name: List[str],
|
|
@@ -352,7 +396,7 @@ class CachedAdapter:
|
|
|
352
396
|
):
|
|
353
397
|
stack.enter_context(
|
|
354
398
|
unittest.mock.patch.object(
|
|
355
|
-
self, name,
|
|
399
|
+
self, name, unified_blocks[context_name]
|
|
356
400
|
)
|
|
357
401
|
)
|
|
358
402
|
for dummy_name in dummy_blocks_names:
|
|
@@ -388,46 +432,51 @@ class CachedAdapter:
|
|
|
388
432
|
return transformer
|
|
389
433
|
|
|
390
434
|
@classmethod
|
|
391
|
-
def
|
|
435
|
+
def collect_unified_blocks(
|
|
392
436
|
cls,
|
|
393
437
|
block_adapter: BlockAdapter,
|
|
438
|
+
contexts_kwargs: List[Dict],
|
|
394
439
|
) -> List[Dict[str, torch.nn.ModuleList]]:
|
|
395
440
|
|
|
396
441
|
BlockAdapter.assert_normalized(block_adapter)
|
|
397
442
|
|
|
398
443
|
total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
|
|
399
|
-
assert hasattr(block_adapter.pipe, "
|
|
444
|
+
assert hasattr(block_adapter.pipe, "_context_manager")
|
|
400
445
|
assert isinstance(
|
|
401
|
-
block_adapter.pipe.
|
|
402
|
-
|
|
446
|
+
block_adapter.pipe._context_manager,
|
|
447
|
+
ContextManager._supported_managers,
|
|
403
448
|
)
|
|
404
449
|
|
|
405
450
|
for i in range(len(block_adapter.transformer)):
|
|
406
451
|
|
|
407
|
-
|
|
452
|
+
unified_blocks_bind_context = {}
|
|
408
453
|
for j in range(len(block_adapter.blocks[i])):
|
|
409
|
-
|
|
454
|
+
cache_config: BasicCacheConfig = contexts_kwargs[
|
|
455
|
+
i * len(block_adapter.blocks[i]) + j
|
|
456
|
+
]["cache_config"]
|
|
457
|
+
unified_blocks_bind_context[
|
|
410
458
|
block_adapter.unique_blocks_name[i][j]
|
|
411
459
|
] = torch.nn.ModuleList(
|
|
412
460
|
[
|
|
413
|
-
|
|
461
|
+
UnifiedBlocks(
|
|
414
462
|
# 0. Transformer blocks configuration
|
|
415
463
|
block_adapter.blocks[i][j],
|
|
416
464
|
transformer=block_adapter.transformer[i],
|
|
417
465
|
forward_pattern=block_adapter.forward_pattern[i][j],
|
|
418
466
|
check_forward_pattern=block_adapter.check_forward_pattern,
|
|
419
467
|
check_num_outputs=block_adapter.check_num_outputs,
|
|
420
|
-
# 1. Cache context configuration
|
|
468
|
+
# 1. Cache/Prune context configuration
|
|
421
469
|
cache_prefix=block_adapter.blocks_name[i][j],
|
|
422
470
|
cache_context=block_adapter.unique_blocks_name[i][
|
|
423
471
|
j
|
|
424
472
|
],
|
|
425
|
-
|
|
473
|
+
context_manager=block_adapter.pipe._context_manager,
|
|
474
|
+
cache_type=cache_config.cache_type,
|
|
426
475
|
)
|
|
427
476
|
]
|
|
428
477
|
)
|
|
429
478
|
|
|
430
|
-
total_cached_blocks.append(
|
|
479
|
+
total_cached_blocks.append(unified_blocks_bind_context)
|
|
431
480
|
|
|
432
481
|
return total_cached_blocks
|
|
433
482
|
|
|
@@ -437,7 +486,7 @@ class CachedAdapter:
|
|
|
437
486
|
block_adapter: BlockAdapter,
|
|
438
487
|
contexts_kwargs: List[Dict],
|
|
439
488
|
):
|
|
440
|
-
block_adapter.pipe.
|
|
489
|
+
block_adapter.pipe._context_kwargs = contexts_kwargs[0]
|
|
441
490
|
|
|
442
491
|
params_shift = 0
|
|
443
492
|
for i in range(len(block_adapter.transformer)):
|
|
@@ -448,16 +497,14 @@ class CachedAdapter:
|
|
|
448
497
|
block_adapter.transformer[i]._has_separate_cfg = (
|
|
449
498
|
block_adapter.has_separate_cfg
|
|
450
499
|
)
|
|
451
|
-
block_adapter.transformer[i].
|
|
452
|
-
|
|
453
|
-
|
|
500
|
+
block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
|
|
501
|
+
params_shift
|
|
502
|
+
]
|
|
454
503
|
|
|
455
504
|
blocks = block_adapter.blocks[i]
|
|
456
505
|
for j in range(len(blocks)):
|
|
457
506
|
blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
|
|
458
|
-
blocks[j].
|
|
459
|
-
params_shift + j
|
|
460
|
-
]
|
|
507
|
+
blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]
|
|
461
508
|
|
|
462
509
|
params_shift += len(blocks)
|
|
463
510
|
|
|
@@ -467,25 +514,25 @@ class CachedAdapter:
|
|
|
467
514
|
block_adapter: BlockAdapter,
|
|
468
515
|
):
|
|
469
516
|
from cache_dit.cache_factory.cache_blocks import (
|
|
470
|
-
|
|
517
|
+
apply_stats,
|
|
471
518
|
)
|
|
472
519
|
|
|
473
|
-
|
|
520
|
+
context_manager = block_adapter.pipe._context_manager
|
|
474
521
|
|
|
475
522
|
for i in range(len(block_adapter.transformer)):
|
|
476
|
-
|
|
523
|
+
apply_stats(
|
|
477
524
|
block_adapter.transformer[i],
|
|
478
525
|
cache_context=block_adapter.unique_blocks_name[i][-1],
|
|
479
|
-
|
|
526
|
+
context_manager=context_manager,
|
|
480
527
|
)
|
|
481
528
|
for blocks, unique_name in zip(
|
|
482
529
|
block_adapter.blocks[i],
|
|
483
530
|
block_adapter.unique_blocks_name[i],
|
|
484
531
|
):
|
|
485
|
-
|
|
532
|
+
apply_stats(
|
|
486
533
|
blocks,
|
|
487
534
|
cache_context=unique_name,
|
|
488
|
-
|
|
535
|
+
context_manager=context_manager,
|
|
489
536
|
)
|
|
490
537
|
|
|
491
538
|
@classmethod
|
|
@@ -513,11 +560,13 @@ class CachedAdapter:
|
|
|
513
560
|
original_call = pipe.__class__._original_call
|
|
514
561
|
pipe.__class__.__call__ = original_call
|
|
515
562
|
del pipe.__class__._original_call
|
|
516
|
-
if hasattr(pipe, "
|
|
517
|
-
|
|
518
|
-
if isinstance(
|
|
519
|
-
|
|
520
|
-
|
|
563
|
+
if hasattr(pipe, "_context_manager"):
|
|
564
|
+
context_manager = pipe._context_manager
|
|
565
|
+
if isinstance(
|
|
566
|
+
context_manager, ContextManager._supported_managers
|
|
567
|
+
):
|
|
568
|
+
context_manager.clear_contexts()
|
|
569
|
+
del pipe._context_manager
|
|
521
570
|
if hasattr(pipe, "_is_cached"):
|
|
522
571
|
del pipe.__class__._is_cached
|
|
523
572
|
|
|
@@ -532,22 +581,22 @@ class CachedAdapter:
|
|
|
532
581
|
def _release_blocks_params(blocks):
|
|
533
582
|
if hasattr(blocks, "_forward_pattern"):
|
|
534
583
|
del blocks._forward_pattern
|
|
535
|
-
if hasattr(blocks, "
|
|
536
|
-
del blocks.
|
|
584
|
+
if hasattr(blocks, "_context_kwargs"):
|
|
585
|
+
del blocks._context_kwargs
|
|
537
586
|
|
|
538
587
|
def _release_transformer_params(transformer):
|
|
539
588
|
if hasattr(transformer, "_forward_pattern"):
|
|
540
589
|
del transformer._forward_pattern
|
|
541
590
|
if hasattr(transformer, "_has_separate_cfg"):
|
|
542
591
|
del transformer._has_separate_cfg
|
|
543
|
-
if hasattr(transformer, "
|
|
544
|
-
del transformer.
|
|
592
|
+
if hasattr(transformer, "_context_kwargs"):
|
|
593
|
+
del transformer._context_kwargs
|
|
545
594
|
for blocks in BlockAdapter.find_blocks(transformer):
|
|
546
595
|
_release_blocks_params(blocks)
|
|
547
596
|
|
|
548
597
|
def _release_pipeline_params(pipe):
|
|
549
|
-
if hasattr(pipe, "
|
|
550
|
-
del pipe.
|
|
598
|
+
if hasattr(pipe, "_context_kwargs"):
|
|
599
|
+
del pipe._context_kwargs
|
|
551
600
|
|
|
552
601
|
cls.release_hooks(
|
|
553
602
|
pipe_or_adapter,
|
|
@@ -558,14 +607,11 @@ class CachedAdapter:
|
|
|
558
607
|
|
|
559
608
|
# release stats hooks
|
|
560
609
|
from cache_dit.cache_factory.cache_blocks import (
|
|
561
|
-
|
|
610
|
+
remove_stats,
|
|
562
611
|
)
|
|
563
612
|
|
|
564
613
|
cls.release_hooks(
|
|
565
|
-
pipe_or_adapter,
|
|
566
|
-
remove_cached_stats,
|
|
567
|
-
remove_cached_stats,
|
|
568
|
-
remove_cached_stats,
|
|
614
|
+
pipe_or_adapter, remove_stats, remove_stats, remove_stats
|
|
569
615
|
)
|
|
570
616
|
|
|
571
617
|
@classmethod
|