cache-dit 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
- cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +20 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +18 -7
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/METADATA +9 -9
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -19,6 +19,8 @@ from cache_dit.cache_factory import ParamsModifier
|
|
|
19
19
|
from cache_dit.cache_factory import ForwardPattern
|
|
20
20
|
from cache_dit.cache_factory import PatchFunctor
|
|
21
21
|
from cache_dit.cache_factory import BasicCacheConfig
|
|
22
|
+
from cache_dit.cache_factory import DBCacheConfig
|
|
23
|
+
from cache_dit.cache_factory import DBPruneConfig
|
|
22
24
|
from cache_dit.cache_factory import CalibratorConfig
|
|
23
25
|
from cache_dit.cache_factory import TaylorSeerCalibratorConfig
|
|
24
26
|
from cache_dit.cache_factory import FoCaCalibratorConfig
|
|
@@ -30,6 +32,7 @@ from cache_dit.quantize import quantize
|
|
|
30
32
|
|
|
31
33
|
NONE = CacheType.NONE
|
|
32
34
|
DBCache = CacheType.DBCache
|
|
35
|
+
DBPrune = CacheType.DBPrune
|
|
33
36
|
|
|
34
37
|
Pattern_0 = ForwardPattern.Pattern_0
|
|
35
38
|
Pattern_1 = ForwardPattern.Pattern_1
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '1.0.3'
|
|
32
|
-
__version_tuple__ = version_tuple = (1, 0, 3)
|
|
31
|
+
__version__ = version = '1.0.4'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 0, 4)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
cache_dit/cache_factory/__init__.py
CHANGED
|
@@ -9,14 +9,21 @@ from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
|
9
9
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
10
10
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
11
11
|
|
|
12
|
-
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
13
12
|
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
13
|
+
from cache_dit.cache_factory.cache_contexts import DBCacheConfig
|
|
14
|
+
from cache_dit.cache_factory.cache_contexts import CachedContext
|
|
14
15
|
from cache_dit.cache_factory.cache_contexts import CachedContextManager
|
|
16
|
+
from cache_dit.cache_factory.cache_contexts import DBPruneConfig
|
|
17
|
+
from cache_dit.cache_factory.cache_contexts import PrunedContext
|
|
18
|
+
from cache_dit.cache_factory.cache_contexts import PrunedContextManager
|
|
19
|
+
from cache_dit.cache_factory.cache_contexts import ContextManager
|
|
15
20
|
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
16
21
|
from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
|
|
17
22
|
from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
|
|
18
23
|
|
|
19
24
|
from cache_dit.cache_factory.cache_blocks import CachedBlocks
|
|
25
|
+
from cache_dit.cache_factory.cache_blocks import PrunedBlocks
|
|
26
|
+
from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
|
|
20
27
|
|
|
21
28
|
from cache_dit.cache_factory.cache_adapters import CachedAdapter
|
|
22
29
|
|
cache_dit/cache_factory/cache_adapters/cache_adapter.py
CHANGED
|
@@ -10,10 +10,10 @@ from cache_dit.cache_factory.cache_types import CacheType
|
|
|
10
10
|
from cache_dit.cache_factory.block_adapters import BlockAdapter
|
|
11
11
|
from cache_dit.cache_factory.block_adapters import ParamsModifier
|
|
12
12
|
from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
|
|
13
|
-
from cache_dit.cache_factory.cache_contexts import
|
|
13
|
+
from cache_dit.cache_factory.cache_contexts import ContextManager
|
|
14
14
|
from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
|
|
15
15
|
from cache_dit.cache_factory.cache_contexts import CalibratorConfig
|
|
16
|
-
from cache_dit.cache_factory.cache_blocks import
|
|
16
|
+
from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
|
|
17
17
|
from cache_dit.logger import init_logger
|
|
18
18
|
|
|
19
19
|
logger = init_logger(__name__)
|
|
@@ -32,7 +32,7 @@ class CachedAdapter:
|
|
|
32
32
|
DiffusionPipeline,
|
|
33
33
|
BlockAdapter,
|
|
34
34
|
],
|
|
35
|
-
**
|
|
35
|
+
**context_kwargs,
|
|
36
36
|
) -> Union[
|
|
37
37
|
DiffusionPipeline,
|
|
38
38
|
BlockAdapter,
|
|
@@ -51,7 +51,7 @@ class CachedAdapter:
|
|
|
51
51
|
block_adapter = BlockAdapterRegistry.get_adapter(
|
|
52
52
|
pipe_or_adapter
|
|
53
53
|
)
|
|
54
|
-
if params_modifiers :=
|
|
54
|
+
if params_modifiers := context_kwargs.pop(
|
|
55
55
|
"params_modifiers",
|
|
56
56
|
None,
|
|
57
57
|
):
|
|
@@ -59,7 +59,7 @@ class CachedAdapter:
|
|
|
59
59
|
|
|
60
60
|
return cls.cachify(
|
|
61
61
|
block_adapter,
|
|
62
|
-
**
|
|
62
|
+
**context_kwargs,
|
|
63
63
|
).pipe
|
|
64
64
|
else:
|
|
65
65
|
raise ValueError(
|
|
@@ -72,21 +72,21 @@ class CachedAdapter:
|
|
|
72
72
|
"Adapting Cache Acceleration using custom BlockAdapter!"
|
|
73
73
|
)
|
|
74
74
|
if pipe_or_adapter.params_modifiers is None:
|
|
75
|
-
if params_modifiers :=
|
|
75
|
+
if params_modifiers := context_kwargs.pop(
|
|
76
76
|
"params_modifiers", None
|
|
77
77
|
):
|
|
78
78
|
pipe_or_adapter.params_modifiers = params_modifiers
|
|
79
79
|
|
|
80
80
|
return cls.cachify(
|
|
81
81
|
pipe_or_adapter,
|
|
82
|
-
**
|
|
82
|
+
**context_kwargs,
|
|
83
83
|
)
|
|
84
84
|
|
|
85
85
|
@classmethod
|
|
86
86
|
def cachify(
|
|
87
87
|
cls,
|
|
88
88
|
block_adapter: BlockAdapter,
|
|
89
|
-
**
|
|
89
|
+
**context_kwargs,
|
|
90
90
|
) -> BlockAdapter:
|
|
91
91
|
|
|
92
92
|
if block_adapter.auto:
|
|
@@ -103,14 +103,15 @@ class CachedAdapter:
|
|
|
103
103
|
|
|
104
104
|
# 1. Apply cache on pipeline: wrap cache context, must
|
|
105
105
|
# call create_context before mock_blocks.
|
|
106
|
-
cls.create_context(
|
|
106
|
+
_, contexts_kwargs = cls.create_context(
|
|
107
107
|
block_adapter,
|
|
108
|
-
**
|
|
108
|
+
**context_kwargs,
|
|
109
109
|
)
|
|
110
110
|
|
|
111
111
|
# 2. Apply cache on transformer: mock cached blocks
|
|
112
112
|
cls.mock_blocks(
|
|
113
113
|
block_adapter,
|
|
114
|
+
contexts_kwargs,
|
|
114
115
|
)
|
|
115
116
|
|
|
116
117
|
return block_adapter
|
|
@@ -119,12 +120,10 @@ class CachedAdapter:
|
|
|
119
120
|
def check_context_kwargs(
|
|
120
121
|
cls,
|
|
121
122
|
block_adapter: BlockAdapter,
|
|
122
|
-
**
|
|
123
|
+
**context_kwargs,
|
|
123
124
|
):
|
|
124
|
-
# Check
|
|
125
|
-
cache_config: BasicCacheConfig =
|
|
126
|
-
"cache_config"
|
|
127
|
-
] # ref
|
|
125
|
+
# Check context_kwargs
|
|
126
|
+
cache_config: BasicCacheConfig = context_kwargs["cache_config"] # ref
|
|
128
127
|
assert cache_config is not None, "cache_config can not be None."
|
|
129
128
|
if cache_config.enable_separate_cfg is None:
|
|
130
129
|
# Check cfg for some specific case if users don't set it as True
|
|
@@ -150,19 +149,23 @@ class CachedAdapter:
|
|
|
150
149
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
151
150
|
)
|
|
152
151
|
|
|
153
|
-
cache_type =
|
|
152
|
+
cache_type = context_kwargs.pop("cache_type", None)
|
|
154
153
|
if cache_type is not None:
|
|
155
|
-
assert (
|
|
156
|
-
cache_type
|
|
157
|
-
), "
|
|
154
|
+
assert isinstance(
|
|
155
|
+
cache_type, CacheType
|
|
156
|
+
), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
|
|
157
|
+
assert cache_type == cache_config.cache_type, (
|
|
158
|
+
f"cache_type from context_kwargs ({cache_type}) must be the same "
|
|
159
|
+
f"as that from cache_config ({cache_config.cache_type})."
|
|
160
|
+
)
|
|
158
161
|
|
|
159
|
-
return
|
|
162
|
+
return context_kwargs
|
|
160
163
|
|
|
161
164
|
@classmethod
|
|
162
165
|
def create_context(
|
|
163
166
|
cls,
|
|
164
167
|
block_adapter: BlockAdapter,
|
|
165
|
-
**
|
|
168
|
+
**context_kwargs,
|
|
166
169
|
) -> Tuple[List[str], List[Dict[str, Any]]]:
|
|
167
170
|
|
|
168
171
|
BlockAdapter.assert_normalized(block_adapter)
|
|
@@ -170,9 +173,9 @@ class CachedAdapter:
|
|
|
170
173
|
if BlockAdapter.is_cached(block_adapter.pipe):
|
|
171
174
|
return block_adapter.pipe
|
|
172
175
|
|
|
173
|
-
# Check
|
|
174
|
-
|
|
175
|
-
block_adapter, **
|
|
176
|
+
# Check context_kwargs
|
|
177
|
+
context_kwargs = cls.check_context_kwargs(
|
|
178
|
+
block_adapter, **context_kwargs
|
|
176
179
|
)
|
|
177
180
|
# Apply cache on pipeline: wrap cache context
|
|
178
181
|
pipe_cls_name = block_adapter.pipe.__class__.__name__
|
|
@@ -181,13 +184,18 @@ class CachedAdapter:
|
|
|
181
184
|
# Different transformers (Wan2.2, etc) should shared the same
|
|
182
185
|
# cache manager but with different cache context (according
|
|
183
186
|
# to their unique instance id).
|
|
184
|
-
|
|
187
|
+
cache_config: BasicCacheConfig = context_kwargs.get(
|
|
188
|
+
"cache_config", None
|
|
189
|
+
)
|
|
190
|
+
assert cache_config is not None, "cache_config can not be None."
|
|
191
|
+
context_manager = ContextManager(
|
|
185
192
|
name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
|
|
193
|
+
cache_type=cache_config.cache_type,
|
|
186
194
|
)
|
|
187
|
-
block_adapter.pipe.
|
|
195
|
+
block_adapter.pipe._context_manager = context_manager # instance level
|
|
188
196
|
|
|
189
197
|
flatten_contexts, contexts_kwargs = cls.modify_context_params(
|
|
190
|
-
block_adapter, **
|
|
198
|
+
block_adapter, **context_kwargs
|
|
191
199
|
)
|
|
192
200
|
|
|
193
201
|
original_call = block_adapter.pipe.__class__.__call__
|
|
@@ -200,8 +208,8 @@ class CachedAdapter:
|
|
|
200
208
|
flatten_contexts, contexts_kwargs
|
|
201
209
|
):
|
|
202
210
|
stack.enter_context(
|
|
203
|
-
|
|
204
|
-
|
|
211
|
+
context_manager.enter_context(
|
|
212
|
+
context_manager.reset_context(
|
|
205
213
|
context_name,
|
|
206
214
|
**context_kwargs,
|
|
207
215
|
),
|
|
@@ -223,14 +231,14 @@ class CachedAdapter:
|
|
|
223
231
|
def modify_context_params(
|
|
224
232
|
cls,
|
|
225
233
|
block_adapter: BlockAdapter,
|
|
226
|
-
**
|
|
234
|
+
**context_kwargs,
|
|
227
235
|
) -> Tuple[List[str], List[Dict[str, Any]]]:
|
|
228
236
|
|
|
229
237
|
flatten_contexts = BlockAdapter.flatten(
|
|
230
238
|
block_adapter.unique_blocks_name
|
|
231
239
|
)
|
|
232
240
|
contexts_kwargs = [
|
|
233
|
-
|
|
241
|
+
context_kwargs.copy()
|
|
234
242
|
for _ in range(
|
|
235
243
|
len(flatten_contexts),
|
|
236
244
|
)
|
|
@@ -267,7 +275,7 @@ class CachedAdapter:
|
|
|
267
275
|
"calibrator_config", None
|
|
268
276
|
)
|
|
269
277
|
if cache_config is not None:
|
|
270
|
-
message = f"Collected
|
|
278
|
+
message = f"Collected Context Config: {cache_config.strify()}"
|
|
271
279
|
if calibrator_config is not None:
|
|
272
280
|
message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
|
|
273
281
|
else:
|
|
@@ -278,6 +286,7 @@ class CachedAdapter:
|
|
|
278
286
|
def mock_blocks(
|
|
279
287
|
cls,
|
|
280
288
|
block_adapter: BlockAdapter,
|
|
289
|
+
contexts_kwargs: List[Dict],
|
|
281
290
|
) -> List[torch.nn.Module]:
|
|
282
291
|
|
|
283
292
|
BlockAdapter.assert_normalized(block_adapter)
|
|
@@ -287,20 +296,23 @@ class CachedAdapter:
|
|
|
287
296
|
|
|
288
297
|
# Apply cache on transformer: mock cached transformer blocks
|
|
289
298
|
for (
|
|
290
|
-
|
|
299
|
+
unified_blocks,
|
|
291
300
|
transformer,
|
|
292
301
|
blocks_name,
|
|
293
302
|
unique_blocks_name,
|
|
294
303
|
dummy_blocks_names,
|
|
295
304
|
) in zip(
|
|
296
|
-
cls.
|
|
305
|
+
cls.collect_unified_blocks(
|
|
306
|
+
block_adapter,
|
|
307
|
+
contexts_kwargs,
|
|
308
|
+
),
|
|
297
309
|
block_adapter.transformer,
|
|
298
310
|
block_adapter.blocks_name,
|
|
299
311
|
block_adapter.unique_blocks_name,
|
|
300
312
|
block_adapter.dummy_blocks_names,
|
|
301
313
|
):
|
|
302
314
|
cls.mock_transformer(
|
|
303
|
-
|
|
315
|
+
unified_blocks,
|
|
304
316
|
transformer,
|
|
305
317
|
blocks_name,
|
|
306
318
|
unique_blocks_name,
|
|
@@ -312,7 +324,7 @@ class CachedAdapter:
|
|
|
312
324
|
@classmethod
|
|
313
325
|
def mock_transformer(
|
|
314
326
|
cls,
|
|
315
|
-
|
|
327
|
+
unified_blocks: Dict[str, torch.nn.ModuleList],
|
|
316
328
|
transformer: torch.nn.Module,
|
|
317
329
|
blocks_name: List[str],
|
|
318
330
|
unique_blocks_name: List[str],
|
|
@@ -352,7 +364,7 @@ class CachedAdapter:
|
|
|
352
364
|
):
|
|
353
365
|
stack.enter_context(
|
|
354
366
|
unittest.mock.patch.object(
|
|
355
|
-
self, name,
|
|
367
|
+
self, name, unified_blocks[context_name]
|
|
356
368
|
)
|
|
357
369
|
)
|
|
358
370
|
for dummy_name in dummy_blocks_names:
|
|
@@ -388,46 +400,51 @@ class CachedAdapter:
|
|
|
388
400
|
return transformer
|
|
389
401
|
|
|
390
402
|
@classmethod
|
|
391
|
-
def
|
|
403
|
+
def collect_unified_blocks(
|
|
392
404
|
cls,
|
|
393
405
|
block_adapter: BlockAdapter,
|
|
406
|
+
contexts_kwargs: List[Dict],
|
|
394
407
|
) -> List[Dict[str, torch.nn.ModuleList]]:
|
|
395
408
|
|
|
396
409
|
BlockAdapter.assert_normalized(block_adapter)
|
|
397
410
|
|
|
398
411
|
total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
|
|
399
|
-
assert hasattr(block_adapter.pipe, "
|
|
412
|
+
assert hasattr(block_adapter.pipe, "_context_manager")
|
|
400
413
|
assert isinstance(
|
|
401
|
-
block_adapter.pipe.
|
|
402
|
-
|
|
414
|
+
block_adapter.pipe._context_manager,
|
|
415
|
+
ContextManager._supported_managers,
|
|
403
416
|
)
|
|
404
417
|
|
|
405
418
|
for i in range(len(block_adapter.transformer)):
|
|
406
419
|
|
|
407
|
-
|
|
420
|
+
unified_blocks_bind_context = {}
|
|
408
421
|
for j in range(len(block_adapter.blocks[i])):
|
|
409
|
-
|
|
422
|
+
cache_config: BasicCacheConfig = contexts_kwargs[
|
|
423
|
+
i * len(block_adapter.blocks[i]) + j
|
|
424
|
+
]["cache_config"]
|
|
425
|
+
unified_blocks_bind_context[
|
|
410
426
|
block_adapter.unique_blocks_name[i][j]
|
|
411
427
|
] = torch.nn.ModuleList(
|
|
412
428
|
[
|
|
413
|
-
|
|
429
|
+
UnifiedBlocks(
|
|
414
430
|
# 0. Transformer blocks configuration
|
|
415
431
|
block_adapter.blocks[i][j],
|
|
416
432
|
transformer=block_adapter.transformer[i],
|
|
417
433
|
forward_pattern=block_adapter.forward_pattern[i][j],
|
|
418
434
|
check_forward_pattern=block_adapter.check_forward_pattern,
|
|
419
435
|
check_num_outputs=block_adapter.check_num_outputs,
|
|
420
|
-
# 1. Cache context configuration
|
|
436
|
+
# 1. Cache/Prune context configuration
|
|
421
437
|
cache_prefix=block_adapter.blocks_name[i][j],
|
|
422
438
|
cache_context=block_adapter.unique_blocks_name[i][
|
|
423
439
|
j
|
|
424
440
|
],
|
|
425
|
-
|
|
441
|
+
context_manager=block_adapter.pipe._context_manager,
|
|
442
|
+
cache_type=cache_config.cache_type,
|
|
426
443
|
)
|
|
427
444
|
]
|
|
428
445
|
)
|
|
429
446
|
|
|
430
|
-
total_cached_blocks.append(
|
|
447
|
+
total_cached_blocks.append(unified_blocks_bind_context)
|
|
431
448
|
|
|
432
449
|
return total_cached_blocks
|
|
433
450
|
|
|
@@ -437,7 +454,7 @@ class CachedAdapter:
|
|
|
437
454
|
block_adapter: BlockAdapter,
|
|
438
455
|
contexts_kwargs: List[Dict],
|
|
439
456
|
):
|
|
440
|
-
block_adapter.pipe.
|
|
457
|
+
block_adapter.pipe._context_kwargs = contexts_kwargs[0]
|
|
441
458
|
|
|
442
459
|
params_shift = 0
|
|
443
460
|
for i in range(len(block_adapter.transformer)):
|
|
@@ -448,16 +465,14 @@ class CachedAdapter:
|
|
|
448
465
|
block_adapter.transformer[i]._has_separate_cfg = (
|
|
449
466
|
block_adapter.has_separate_cfg
|
|
450
467
|
)
|
|
451
|
-
block_adapter.transformer[i].
|
|
452
|
-
|
|
453
|
-
|
|
468
|
+
block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
|
|
469
|
+
params_shift
|
|
470
|
+
]
|
|
454
471
|
|
|
455
472
|
blocks = block_adapter.blocks[i]
|
|
456
473
|
for j in range(len(blocks)):
|
|
457
474
|
blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
|
|
458
|
-
blocks[j].
|
|
459
|
-
params_shift + j
|
|
460
|
-
]
|
|
475
|
+
blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]
|
|
461
476
|
|
|
462
477
|
params_shift += len(blocks)
|
|
463
478
|
|
|
@@ -467,25 +482,25 @@ class CachedAdapter:
|
|
|
467
482
|
block_adapter: BlockAdapter,
|
|
468
483
|
):
|
|
469
484
|
from cache_dit.cache_factory.cache_blocks import (
|
|
470
|
-
|
|
485
|
+
apply_stats,
|
|
471
486
|
)
|
|
472
487
|
|
|
473
|
-
|
|
488
|
+
context_manager = block_adapter.pipe._context_manager
|
|
474
489
|
|
|
475
490
|
for i in range(len(block_adapter.transformer)):
|
|
476
|
-
|
|
491
|
+
apply_stats(
|
|
477
492
|
block_adapter.transformer[i],
|
|
478
493
|
cache_context=block_adapter.unique_blocks_name[i][-1],
|
|
479
|
-
|
|
494
|
+
context_manager=context_manager,
|
|
480
495
|
)
|
|
481
496
|
for blocks, unique_name in zip(
|
|
482
497
|
block_adapter.blocks[i],
|
|
483
498
|
block_adapter.unique_blocks_name[i],
|
|
484
499
|
):
|
|
485
|
-
|
|
500
|
+
apply_stats(
|
|
486
501
|
blocks,
|
|
487
502
|
cache_context=unique_name,
|
|
488
|
-
|
|
503
|
+
context_manager=context_manager,
|
|
489
504
|
)
|
|
490
505
|
|
|
491
506
|
@classmethod
|
|
@@ -513,11 +528,13 @@ class CachedAdapter:
|
|
|
513
528
|
original_call = pipe.__class__._original_call
|
|
514
529
|
pipe.__class__.__call__ = original_call
|
|
515
530
|
del pipe.__class__._original_call
|
|
516
|
-
if hasattr(pipe, "
|
|
517
|
-
|
|
518
|
-
if isinstance(
|
|
519
|
-
|
|
520
|
-
|
|
531
|
+
if hasattr(pipe, "_context_manager"):
|
|
532
|
+
context_manager = pipe._context_manager
|
|
533
|
+
if isinstance(
|
|
534
|
+
context_manager, ContextManager._supported_managers
|
|
535
|
+
):
|
|
536
|
+
context_manager.clear_contexts()
|
|
537
|
+
del pipe._context_manager
|
|
521
538
|
if hasattr(pipe, "_is_cached"):
|
|
522
539
|
del pipe.__class__._is_cached
|
|
523
540
|
|
|
@@ -532,22 +549,22 @@ class CachedAdapter:
|
|
|
532
549
|
def _release_blocks_params(blocks):
|
|
533
550
|
if hasattr(blocks, "_forward_pattern"):
|
|
534
551
|
del blocks._forward_pattern
|
|
535
|
-
if hasattr(blocks, "
|
|
536
|
-
del blocks.
|
|
552
|
+
if hasattr(blocks, "_context_kwargs"):
|
|
553
|
+
del blocks._context_kwargs
|
|
537
554
|
|
|
538
555
|
def _release_transformer_params(transformer):
|
|
539
556
|
if hasattr(transformer, "_forward_pattern"):
|
|
540
557
|
del transformer._forward_pattern
|
|
541
558
|
if hasattr(transformer, "_has_separate_cfg"):
|
|
542
559
|
del transformer._has_separate_cfg
|
|
543
|
-
if hasattr(transformer, "
|
|
544
|
-
del transformer.
|
|
560
|
+
if hasattr(transformer, "_context_kwargs"):
|
|
561
|
+
del transformer._context_kwargs
|
|
545
562
|
for blocks in BlockAdapter.find_blocks(transformer):
|
|
546
563
|
_release_blocks_params(blocks)
|
|
547
564
|
|
|
548
565
|
def _release_pipeline_params(pipe):
|
|
549
|
-
if hasattr(pipe, "
|
|
550
|
-
del pipe.
|
|
566
|
+
if hasattr(pipe, "_context_kwargs"):
|
|
567
|
+
del pipe._context_kwargs
|
|
551
568
|
|
|
552
569
|
cls.release_hooks(
|
|
553
570
|
pipe_or_adapter,
|
|
@@ -558,14 +575,11 @@ class CachedAdapter:
|
|
|
558
575
|
|
|
559
576
|
# release stats hooks
|
|
560
577
|
from cache_dit.cache_factory.cache_blocks import (
|
|
561
|
-
|
|
578
|
+
remove_stats,
|
|
562
579
|
)
|
|
563
580
|
|
|
564
581
|
cls.release_hooks(
|
|
565
|
-
pipe_or_adapter,
|
|
566
|
-
remove_cached_stats,
|
|
567
|
-
remove_cached_stats,
|
|
568
|
-
remove_cached_stats,
|
|
582
|
+
pipe_or_adapter, remove_stats, remove_stats, remove_stats
|
|
569
583
|
)
|
|
570
584
|
|
|
571
585
|
@classmethod
|