cache-dit 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/block_adapters/__init__.py +166 -160
- cache_dit/cache_factory/block_adapters/block_adapters.py +195 -125
- cache_dit/cache_factory/block_adapters/block_registers.py +25 -13
- cache_dit/cache_factory/cache_adapters.py +209 -86
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_blocks/utils.py +16 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +22 -10
- cache_dit/cache_factory/cache_interface.py +26 -14
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -2
- cache_dit/cache_factory/patch_functors/functor_flux.py +3 -2
- cache_dit/utils.py +168 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/METADATA +34 -55
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/RECORD +21 -21
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.28.dist-info → cache_dit-0.2.30.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_adapters.py

@@ -4,7 +4,7 @@ import unittest
 import functools

 from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any
+from typing import Dict, List, Tuple, Any, Union, Callable

 from diffusers import DiffusionPipeline

@@ -14,7 +14,10 @@ from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import BlockAdapterRegistry
 from cache_dit.cache_factory import CachedContextManager
 from cache_dit.cache_factory import CachedBlocks
-
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+    remove_cached_stats,
+)
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)
@@ -29,36 +32,45 @@ class CachedAdapter:
     @classmethod
     def apply(
         cls,
-
-
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
         **cache_context_kwargs,
-    ) ->
+    ) -> Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ]:
         assert (
-
+            pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"

-        if
-        if BlockAdapterRegistry.is_supported(
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
-                    f"{
-                    "Use it's pre-defined BlockAdapter
+                    f"{pipe_or_adapter.__class__.__name__} is officially "
+                    "supported by cache-dit. Use it's pre-defined BlockAdapter "
+                    "directly!"
+                )
+                block_adapter = BlockAdapterRegistry.get_adapter(
+                    pipe_or_adapter
                 )
-                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
                     block_adapter,
                     **cache_context_kwargs,
-                )
+                ).pipe
             else:
                 raise ValueError(
-                    f"{
+                    f"{pipe_or_adapter.__class__.__name__} is not officially supported "
                     "by cache-dit, please set BlockAdapter instead!"
                 )
         else:
+            assert isinstance(pipe_or_adapter, BlockAdapter)
             logger.info(
-                "Adapting
+                "Adapting Cache Acceleration using custom BlockAdapter!"
             )
             return cls.cachify(
-
+                pipe_or_adapter,
                 **cache_context_kwargs,
             )

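The new apply signature takes a single pipe_or_adapter and returns the same kind of object it was given: a registered DiffusionPipeline comes back as the cached pipeline (note the trailing .pipe on cachify), while a custom BlockAdapter is cachified and returned as-is. A minimal usage sketch, assuming a pipeline class that is registered in BlockAdapterRegistry; the checkpoint name and the keyword argument are illustrative, not taken from this diff, and enable_spearate_cfg is spelled as in the source:

    # Sketch only: checkpoint and kwargs are illustrative.
    from diffusers import DiffusionPipeline
    from cache_dit.cache_factory.cache_adapters import CachedAdapter

    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

    # For a registered pipeline class, apply() looks up its pre-defined
    # BlockAdapter and returns the (now cached) DiffusionPipeline.
    pipe = CachedAdapter.apply(pipe, enable_spearate_cfg=False)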
@@ -67,7 +79,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) ->
+    ) -> BlockAdapter:

         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
@@ -79,7 +91,7 @@ class CachedAdapter:
         # 0. Must normalize block_adapter before apply cache
         block_adapter = BlockAdapter.normalize(block_adapter)
         if BlockAdapter.is_cached(block_adapter):
-            return block_adapter
+            return block_adapter

         # 1. Apply cache on pipeline: wrap cache context, must
         # call create_context before mock_blocks.
@@ -93,53 +105,36 @@ class CachedAdapter:
             block_adapter,
         )

-        return block_adapter
+        return block_adapter

     @classmethod
-    def
+    def check_context_kwargs(
         cls,
         block_adapter: BlockAdapter,
-
+        **cache_context_kwargs,
     ):
-        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
-
-        params_shift = 0
-        for i in range(len(block_adapter.transformer)):
-
-            block_adapter.transformer[i]._forward_pattern = (
-                block_adapter.forward_pattern
-            )
-            block_adapter.transformer[i]._has_separate_cfg = (
-                block_adapter.has_separate_cfg
-            )
-            block_adapter.transformer[i]._cache_context_kwargs = (
-                contexts_kwargs[params_shift]
-            )
-
-            blocks = block_adapter.blocks[i]
-            for j in range(len(blocks)):
-                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j]._cache_context_kwargs = contexts_kwargs[
-                    params_shift + j
-                ]
-
-            params_shift += len(blocks)
-
-    @classmethod
-    def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
         if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-
-
-
-
-
-
-
-
+            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
+                cache_context_kwargs["enable_spearate_cfg"] = True
+                logger.info(
+                    f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
+            else:
+                cache_context_kwargs["enable_spearate_cfg"] = (
+                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
+                )
+                logger.info(
+                    f"Use default 'enable_spearate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )

-        if
+        if (
+            cache_type := cache_context_kwargs.pop("cache_type", None)
+        ) is not None:
             assert (
                 cache_type == CacheType.DBCache
             ), "Custom cache setting only support for DBCache now!"
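check_context_kwargs now receives the block_adapter itself: when the caller leaves enable_spearate_cfg off, a custom BlockAdapter that declares separate CFG forces it to True, otherwise the value registered for the pipeline class is used; a cache_type kwarg, if present, is popped and must be CacheType.DBCache. A hedged sketch of the kwargs this path consumes; the import path for CacheType is assumed from the package layout in the file list above:

    # Sketch only: import path and kwarg values are assumptions, not confirmed by this diff.
    from cache_dit.cache_factory.cache_types import CacheType

    pipe = CachedAdapter.apply(
        pipe,
        cache_type=CacheType.DBCache,  # popped by check_context_kwargs; only DBCache is accepted
        enable_spearate_cfg=False,     # the registry / BlockAdapter may flip this to True
    )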
@@ -160,8 +155,7 @@ class CachedAdapter:

         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            block_adapter
-            **cache_context_kwargs,
+            block_adapter, **cache_context_kwargs
         )
         # Apply cache on pipeline: wrap cache context
         pipe_cls_name = block_adapter.pipe.__class__.__name__
@@ -197,14 +191,14 @@ class CachedAdapter:
                 )
             )
             outputs = original_call(self, *args, **kwargs)
-            cls.
+            cls.apply_stats_hooks(block_adapter)
             return outputs

         block_adapter.pipe.__class__.__call__ = new_call
         block_adapter.pipe.__class__._original_call = original_call
         block_adapter.pipe.__class__._is_cached = True

-        cls.
+        cls.apply_params_hooks(block_adapter, contexts_kwargs)

         return block_adapter.pipe

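The renamed hooks make the wiring explicit: apply_params_hooks records the forward pattern, the separate-CFG flag, and the per-context kwargs on the pipeline, its transformer(s), and each blocks module, while apply_stats_hooks copies the cache statistics onto them after every wrapped __call__. A small inspection sketch, assuming caching has already been applied as above; the call itself is illustrative:

    # Sketch: attribute names are taken from the hunks above; the prompt/call is illustrative.
    _ = pipe("a prompt")                         # wrapped __call__ also runs apply_stats_hooks
    print(pipe._cache_context_kwargs)            # kwargs of the first cache context
    print(pipe.transformer._forward_pattern)     # recorded forward pattern(s)
    print(pipe.transformer._has_separate_cfg)    # separate-CFG flag
    print(getattr(pipe.transformer, "_cached_steps", None))  # stats filled by patch_cached_stats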
@@ -248,33 +242,6 @@ class CachedAdapter:

         return flatten_contexts, contexts_kwargs

-    @classmethod
-    def patch_stats(
-        cls,
-        block_adapter: BlockAdapter,
-    ):
-        from cache_dit.cache_factory.cache_blocks.utils import (
-            patch_cached_stats,
-        )
-
-        cache_manager = block_adapter.pipe._cache_manager
-
-        for i in range(len(block_adapter.transformer)):
-            patch_cached_stats(
-                block_adapter.transformer[i],
-                cache_context=block_adapter.unique_blocks_name[i][-1],
-                cache_manager=cache_manager,
-            )
-            for blocks, unique_name in zip(
-                block_adapter.blocks[i],
-                block_adapter.unique_blocks_name[i],
-            ):
-                patch_cached_stats(
-                    blocks,
-                    cache_context=unique_name,
-                    cache_manager=cache_manager,
-                )
-
     @classmethod
     def mock_blocks(
         cls,
@@ -392,3 +359,159 @@ class CachedAdapter:
             total_cached_blocks.append(cached_blocks_bind_context)

         return total_cached_blocks
+
+    @classmethod
+    def apply_params_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
+    ):
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)
+
+    @classmethod
+    def apply_stats_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
+
+    @classmethod
+    def maybe_release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+    ):
+        # release model hooks
+        def _release_blocks_hooks(blocks):
+            return
+
+        def _release_transformer_hooks(transformer):
+            if hasattr(transformer, "_original_forward"):
+                original_forward = transformer._original_forward
+                transformer.forward = original_forward.__get__(transformer)
+                del transformer._original_forward
+            if hasattr(transformer, "_is_cached"):
+                del transformer._is_cached
+
+        def _release_pipeline_hooks(pipe):
+            if hasattr(pipe, "_original_call"):
+                original_call = pipe.__class__._original_call
+                pipe.__class__.__call__ = original_call
+                del pipe.__class__._original_call
+            if hasattr(pipe, "_cache_manager"):
+                cache_manager = pipe._cache_manager
+                if isinstance(cache_manager, CachedContextManager):
+                    cache_manager.clear_contexts()
+                del pipe._cache_manager
+            if hasattr(pipe, "_is_cached"):
+                del pipe.__class__._is_cached
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_hooks,
+            _release_transformer_hooks,
+            _release_pipeline_hooks,
+        )
+
+        # release params hooks
+        def _release_blocks_params(blocks):
+            if hasattr(blocks, "_forward_pattern"):
+                del blocks._forward_pattern
+            if hasattr(blocks, "_cache_context_kwargs"):
+                del blocks._cache_context_kwargs
+
+        def _release_transformer_params(transformer):
+            if hasattr(transformer, "_forward_pattern"):
+                del transformer._forward_pattern
+            if hasattr(transformer, "_has_separate_cfg"):
+                del transformer._has_separate_cfg
+            if hasattr(transformer, "_cache_context_kwargs"):
+                del transformer._cache_context_kwargs
+            for blocks in BlockAdapter.find_blocks(transformer):
+                _release_blocks_params(blocks)
+
+        def _release_pipeline_params(pipe):
+            if hasattr(pipe, "_cache_context_kwargs"):
+                del pipe._cache_context_kwargs
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_params,
+            _release_transformer_params,
+            _release_pipeline_params,
+        )
+
+        # release stats hooks
+        cls.release_hooks(
+            pipe_or_adapter,
+            remove_cached_stats,
+            remove_cached_stats,
+            remove_cached_stats,
+        )
+
+    @classmethod
+    def release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+        _release_blocks: Callable,
+        _release_transformer: Callable,
+        _release_pipeline: Callable,
+    ):
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            pipe = pipe_or_adapter
+            _release_pipeline(pipe)
+            if hasattr(pipe, "transformer"):
+                _release_transformer(pipe.transformer)
+            if hasattr(pipe, "transformer_2"):  # Wan 2.2
+                _release_transformer(pipe.transformer_2)
+        elif isinstance(pipe_or_adapter, BlockAdapter):
+            adapter = pipe_or_adapter
+            BlockAdapter.assert_normalized(adapter)
+            _release_pipeline(adapter.pipe)
+            for transformer in BlockAdapter.flatten(adapter.transformer):
+                _release_transformer(transformer)
+            for blocks in BlockAdapter.flatten(adapter.blocks):
+                _release_blocks(blocks)
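maybe_release_hooks is the new teardown path: it restores the pipeline's original __call__ and the transformer's original forward, clears cached contexts through the CachedContextManager, and strips the _forward_pattern / _cache_context_kwargs / stats attributes, accepting either a DiffusionPipeline (including transformer_2, e.g. Wan 2.2) or a normalized BlockAdapter. A hedged round-trip sketch, with an illustrative inference call:

    # Sketch: enable caching, run, then restore the original pipeline.
    pipe = CachedAdapter.apply(pipe, enable_spearate_cfg=False)
    _ = pipe("a prompt")                         # illustrative cached run
    CachedAdapter.maybe_release_hooks(pipe)      # restores __call__/forward, clears contexts
    assert not hasattr(pipe, "_cache_manager")   # cached state and hooks are gone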
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

@@ -1,5 +1,6 @@
 import torch

+from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -31,7 +32,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
         # encoder_hidden_states: None Pattern 3, else 4, 5
-        hidden_states,
+        hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
             **kwargs,
@@ -60,11 +61,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         if can_use_cache:
             self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
-            hidden_states,
+            hidden_states, new_encoder_hidden_states = (
                 self.cache_manager.apply_cache(
                     hidden_states,
-                    #
-                    encoder_hidden_states,
+                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
                         if self.cache_manager.is_cache_residual()
@@ -80,12 +80,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-
-            hidden_states,
-
-
-
-
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
         else:
             self.cache_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
@@ -99,19 +99,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
+            old_encoder_hidden_states = new_encoder_hidden_states
             (
                 hidden_states,
-
+                new_encoder_hidden_states,
                 hidden_states_residual,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states_residual,
             ) = self.call_Mn_blocks(  # middle
                 hidden_states,
-                # None Pattern 3, else 4, 5
-                encoder_hidden_states,
                 *args,
                 **kwargs,
             )
+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
             torch._dynamo.graph_break()
             if self.cache_manager.is_cache_residual():
                 self.cache_manager.set_Bn_buffer(
@@ -119,34 +120,32 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
-
-
-
-
-
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_residual",
+                    )
             else:
-
-
-
-
-
-                )
+                if new_encoder_hidden_states is not None:
+                    self.cache_manager.set_Bn_encoder_buffer(
+                        new_encoder_hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
+                    )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
-
-            hidden_states,
-
-
-
-
-            )
+            if self.cache_manager.Bn_compute_blocks() > 0:
+                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )

         torch._dynamo.graph_break()

@@ -154,12 +153,21 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             hidden_states
             if self.forward_pattern.Return_H_Only
             else (
-                (hidden_states,
+                (hidden_states, new_encoder_hidden_states)
                 if self.forward_pattern.Return_H_First
-                else (
+                else (new_encoder_hidden_states, hidden_states)
             )
         )

+    @torch.compiler.disable
+    def maybe_update_kwargs(
+        self, encoder_hidden_states, kwargs: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        # if "encoder_hidden_states" in kwargs:
+        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
+        #     return kwargs
+        return kwargs
+
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,
@@ -172,7 +180,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-
+        new_encoder_hidden_states = None
         for block in self._Fn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -180,25 +188,27 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )

-        return hidden_states,
+        return hidden_states, new_encoder_hidden_states

     def call_Mn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
         original_hidden_states = hidden_states
-
+        new_encoder_hidden_states = None
         for block in self._Mn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -206,44 +216,33 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )

         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
-        if (
-            original_encoder_hidden_states is not None
-            and encoder_hidden_states is not None
-        ):  # Pattern 4, 5
-            encoder_hidden_states_residual = (
-                encoder_hidden_states - original_encoder_hidden_states
-            )
-        else:
-            encoder_hidden_states_residual = None  # Pattern 3

         return (
             hidden_states,
-
+            new_encoder_hidden_states,
             hidden_states_residual,
-            encoder_hidden_states_residual,
         )

     def call_Bn_blocks(
         self,
         hidden_states: torch.Tensor,
-        # None Pattern 3, else 4, 5
-        encoder_hidden_states: torch.Tensor | None,
         *args,
         **kwargs,
     ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
         assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
@@ -264,11 +263,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
             if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states,
+                hidden_states, new_encoder_hidden_states = hidden_states
                 if not self.forward_pattern.Return_H_First:
-                    hidden_states,
-
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
                         hidden_states,
                     )
+            kwargs = self.maybe_update_kwargs(
+                new_encoder_hidden_states,
+                kwargs,
+            )

-        return hidden_states,
+        return hidden_states, new_encoder_hidden_states
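Across these hunks, CachedBlocks_Pattern_3_4_5 no longer takes encoder_hidden_states as a parameter; each of call_Fn_blocks / call_Mn_blocks / call_Bn_blocks derives new_encoder_hidden_states from the block outputs (honoring Return_H_First), maybe_update_kwargs is currently a commented-out no-op, and the final return ordering still follows the forward-pattern flags. An illustrative restatement of that ordering rule, not part of the diff:

    # Illustrative only: mirrors the return ordering in the hunks above.
    def order_outputs(pattern, hidden_states, new_encoder_hidden_states):
        if pattern.Return_H_Only:
            return hidden_states
        if pattern.Return_H_First:
            return hidden_states, new_encoder_hidden_states
        return new_encoder_hidden_states, hidden_states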
cache_dit/cache_factory/cache_blocks/utils.py

@@ -23,3 +23,19 @@ def patch_cached_stats(
     module._residual_diffs = cache_manager.get_residual_diffs()
     module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
     module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+
+
+def remove_cached_stats(
+    module: torch.nn.Module | Any,
+):
+    if module is None:
+        return
+
+    if hasattr(module, "_cached_steps"):
+        del module._cached_steps
+    if hasattr(module, "_residual_diffs"):
+        del module._residual_diffs
+    if hasattr(module, "_cfg_cached_steps"):
+        del module._cfg_cached_steps
+    if hasattr(module, "_cfg_residual_diffs"):
+        del module._cfg_residual_diffs
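remove_cached_stats is the inverse of patch_cached_stats and is what release_hooks passes for blocks, transformer, and pipeline alike; it is safe on modules that were never patched and on None. A short usage sketch, assuming pipe is a cached pipeline as in the hunks above:

    # Sketch: read, then clear, the cached statistics on a transformer.
    from cache_dit.cache_factory.cache_blocks.utils import remove_cached_stats

    steps = getattr(pipe.transformer, "_cached_steps", None)   # set by patch_cached_stats
    diffs = getattr(pipe.transformer, "_residual_diffs", None)
    remove_cached_stats(pipe.transformer)  # deletes the stats attributes if present
    remove_cached_stats(None)              # no-op, per the guard above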