cache-dit 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +47 -3
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +23 -62
- cache_dit/cache_factory/cache_blocks/pattern_base.py +23 -168
- cache_dit/cache_factory/cache_contexts/cache_context.py +18 -64
- cache_dit/cache_factory/cache_contexts/cache_manager.py +23 -71
- cache_dit/cache_factory/cache_contexts/taylorseer.py +11 -13
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- cache_dit/quantize/quantize_ao.py +3 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/METADATA +184 -39
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/RECORD +21 -21
- cache_dit/quantize/quantize_svdq.py +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.34.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.2.32'
-__version_tuple__ = version_tuple = (0, 2, 32)
+__version__ = version = '0.2.34'
+__version_tuple__ = version_tuple = (0, 2, 34)

 __commit_id__ = commit_id = None


cache_dit/cache_factory/block_adapters/__init__.py
CHANGED

@@ -153,7 +153,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
     )


-@BlockAdapterRegistry.register("
+@BlockAdapterRegistry.register("LTX")
 def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import LTXVideoTransformer3DModel

@@ -248,7 +248,10 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
         pipe=pipe,
         transformer=pipe.transformer,
         blocks=pipe.transformer.blocks,
-        forward_pattern=ForwardPattern.Pattern_2,
+        # NOTE: Use Pattern_3 instead of Pattern_2 because the
+        # encoder_hidden_states will never change in the blocks
+        # forward loop.
+        forward_pattern=ForwardPattern.Pattern_3,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -285,6 +288,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("DiT")
 def dit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import DiTTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import DiTPatchFunctor

     assert isinstance(pipe.transformer, DiTTransformer2DModel)
     return BlockAdapter(
@@ -292,6 +296,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=DiTPatchFunctor(),
         **kwargs,
     )

@@ -331,24 +336,13 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:


 @BlockAdapterRegistry.register("Lumina")
-def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import LuminaNextDiT2DModel
-
-    assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.layers,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
-@BlockAdapterRegistry.register("Lumina2")
 def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import Lumina2Transformer2DModel
+    from diffusers import LuminaNextDiT2DModel

-    assert isinstance(
+    assert isinstance(
+        pipe.transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel)
+    )
     return BlockAdapter(
         pipe=pipe,
         transformer=pipe.transformer,
@@ -386,12 +380,10 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
     )


-@BlockAdapterRegistry.register("Sana"
+@BlockAdapterRegistry.register("Sana")
 def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SanaTransformer2DModel

-    # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
-
     assert isinstance(pipe.transformer, SanaTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
@@ -469,6 +461,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("Chroma")
 def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import ChromaTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor

     assert isinstance(pipe.transformer, ChromaTransformer2DModel)
     return BlockAdapter(
@@ -482,6 +475,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_3,
         ],
+        patch_functor=ChromaPatchFunctor(),
         has_separate_cfg=True,
         **kwargs,
     )

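
For context, a consolidated "after" view of the DiT registration, assembled from the two DiT hunks above (the pipe=pipe line is not shown in either hunk and is inferred from the neighboring adapters; treat this as a sketch, not the exact file contents):

@BlockAdapterRegistry.register("DiT")
def dit_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import DiTTransformer2DModel
    from cache_dit.cache_factory.patch_functors import DiTPatchFunctor

    assert isinstance(pipe.transformer, DiTTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        patch_functor=DiTPatchFunctor(),  # new in 0.2.34, paired with functor_dit.py
        **kwargs,
    )
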

cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED

@@ -16,8 +16,52 @@ logger = init_logger(__name__)


 class ParamsModifier:
-    def __init__(
-        self
+    def __init__(
+        self,
+        # Cache context kwargs
+        Fn_compute_blocks: Optional[int] = None,
+        Bn_compute_blocks: Optional[int] = None,
+        max_warmup_steps: Optional[int] = None,
+        max_cached_steps: Optional[int] = None,
+        max_continuous_cached_steps: Optional[int] = None,
+        residual_diff_threshold: Optional[float] = None,
+        # Cache CFG or not
+        enable_separate_cfg: Optional[bool] = None,
+        cfg_compute_first: Optional[bool] = None,
+        cfg_diff_compute_separate: Optional[bool] = None,
+        # Hybird TaylorSeer
+        enable_taylorseer: Optional[bool] = None,
+        enable_encoder_taylorseer: Optional[bool] = None,
+        taylorseer_cache_type: Optional[str] = None,
+        taylorseer_order: Optional[int] = None,
+        **other_cache_context_kwargs,
+    ):
+        self._context_kwargs = other_cache_context_kwargs.copy()
+        self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
+        self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
+        self._maybe_update_param("max_warmup_steps", max_warmup_steps)
+        self._maybe_update_param("max_cached_steps", max_cached_steps)
+        self._maybe_update_param(
+            "max_continuous_cached_steps", max_continuous_cached_steps
+        )
+        self._maybe_update_param(
+            "residual_diff_threshold", residual_diff_threshold
+        )
+        self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
+        self._maybe_update_param("cfg_compute_first", cfg_compute_first)
+        self._maybe_update_param(
+            "cfg_diff_compute_separate", cfg_diff_compute_separate
+        )
+        self._maybe_update_param("enable_taylorseer", enable_taylorseer)
+        self._maybe_update_param(
+            "enable_encoder_taylorseer", enable_encoder_taylorseer
+        )
+        self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
+        self._maybe_update_param("taylorseer_order", taylorseer_order)
+
+    def _maybe_update_param(self, key: str, value: Any):
+        if value is not None:
+            self._context_kwargs[key] = value


 @dataclasses.dataclass

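
For reference, a brief usage sketch of the new ParamsModifier (argument names are taken from the hunk above; the values are illustrative). Only keyword arguments that are not None end up in the internal _context_kwargs dict:

modifier = ParamsModifier(
    Fn_compute_blocks=8,
    residual_diff_threshold=0.08,
    enable_taylorseer=True,
)
# Unset parameters are skipped, so _context_kwargs holds exactly:
# {"Fn_compute_blocks": 8, "residual_diff_threshold": 0.08, "enable_taylorseer": True}
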

@@ -579,7 +623,7 @@ class BlockAdapter:
             assert isinstance(adapter[0], torch.nn.Module)
             return getattr(adapter[0], "_is_cached", False)
         else:
-
+            return getattr(adapter, "_is_cached", False)

     @classmethod
     def nested_depth(cls, obj: Any):


cache_dit/cache_factory/block_adapters/block_registers.py
CHANGED

@@ -10,13 +10,14 @@ logger = init_logger(__name__)

 class BlockAdapterRegistry:
     _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
-
+    _predefined_adapters_has_separate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
         "SkyReelsV2",
         "Chroma",
+        "Lumina2",
     ]

     @classmethod
@@ -68,7 +69,7 @@ class BlockAdapterRegistry:
             return True

         pipe_cls_name = pipe_or_adapter.__class__.__name__
-        for name in cls.
+        for name in cls._predefined_adapters_has_separate_cfg:
             if pipe_cls_name.startswith(name):
                 return True

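
The practical effect of adding "Lumina2" to this list is that has_separate_cfg now resolves to True for Lumina2 pipelines by class-name prefix. A minimal sketch of that lookup, using an illustrative pipeline class name:

pipe_cls_name = "Lumina2Pipeline"  # illustrative; any class name starting with "Lumina2"
has_cfg = any(
    pipe_cls_name.startswith(name)
    for name in BlockAdapterRegistry._predefined_adapters_has_separate_cfg
)
# has_cfg is True once "Lumina2" is in the predefined list.
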

cache_dit/cache_factory/cache_adapters.py
CHANGED

@@ -114,27 +114,27 @@ class CachedAdapter:
         **cache_context_kwargs,
     ):
         # Check cache_context_kwargs
-        if cache_context_kwargs["
+        if cache_context_kwargs["enable_separate_cfg"] is None:
             # Check cfg for some specific case if users don't set it as True
             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
-                cache_context_kwargs["
+                cache_context_kwargs["enable_separate_cfg"] = True
                 logger.info(
-                    f"Use custom '
+                    f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
             else:
-                cache_context_kwargs["
+                cache_context_kwargs["enable_separate_cfg"] = (
                     BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
                 )
                 logger.info(
-                    f"Use default '
-                    f"register: {cache_context_kwargs['
+                    f"Use default 'enable_separate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_separate_cfg']}, "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
         else:
             logger.info(
-                f"Use custom '
-                f"kwargs: {cache_context_kwargs['
+                f"Use custom 'enable_separate_cfg' from cache context "
+                f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
             )

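
The kwarg is now consistently spelled enable_separate_cfg here and in ParamsModifier above; cache_interface.py (+9 -9 in the summary) carries the matching rename. A hedged usage sketch, assuming the package's top-level cache_dit.enable_cache(...) entry point and an illustrative checkpoint:

import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # illustrative model id

# An explicit value wins; left as None, it falls back to
# BlockAdapterRegistry.has_separate_cfg(...) exactly as the hunk above shows.
cache_dit.enable_cache(pipe, enable_separate_cfg=True)
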

cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
CHANGED

@@ -1,6 +1,5 @@
 import torch

-from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -24,14 +23,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(
-
-        )
+        self.cache_manager.set_context(self.cache_context)
+        self._check_cache_params()

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
-        # encoder_hidden_states: None Pattern 3, else 4, 5
         hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
@@ -109,10 +106,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 *args,
                 **kwargs,
             )
-
-            new_encoder_hidden_states_residual = (
-                new_encoder_hidden_states - old_encoder_hidden_states
-            )
+
             torch._dynamo.graph_break()
             if self.cache_manager.is_cache_residual():
                 self.cache_manager.set_Bn_buffer(
@@ -125,6 +119,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )

+            if new_encoder_hidden_states is not None:
+                new_encoder_hidden_states_residual = (
+                    new_encoder_hidden_states - old_encoder_hidden_states
+                )
             if self.cache_manager.is_encoder_cache_residual():
                 if new_encoder_hidden_states is not None:
                     self.cache_manager.set_Bn_encoder_buffer(
@@ -159,27 +157,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     )
                 )

-    @torch.compiler.disable
-    def maybe_update_kwargs(
-        self, encoder_hidden_states, kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        # if "encoder_hidden_states" in kwargs:
-        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
-        # return kwargs
-        return kwargs
-
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,
         *args,
         **kwargs,
     ):
-        assert self.cache_manager.Fn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
         new_encoder_hidden_states = None
         for block in self._Fn_blocks():
             hidden_states = block(
@@ -194,10 +177,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                         new_encoder_hidden_states,
                         hidden_states,
                     )
-            kwargs = self.maybe_update_kwargs(
-                new_encoder_hidden_states,
-                kwargs,
-            )

         return hidden_states, new_encoder_hidden_states

@@ -222,11 +201,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                         new_encoder_hidden_states,
                         hidden_states,
                     )
-            kwargs = self.maybe_update_kwargs(
-                new_encoder_hidden_states,
-                kwargs,
-            )
-
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
@@ -243,35 +217,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-
-
-
-
-
-
-
-
-
-            f"patterns: {self._supported_patterns}."
+        new_encoder_hidden_states = None
+        if self.cache_manager.Bn_compute_blocks() == 0:
+            return hidden_states, new_encoder_hidden_states
+
+        for block in self._Bn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
             )
-
-
-
-
-
-
-
-            )
-            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-                hidden_states, new_encoder_hidden_states = hidden_states
-                if not self.forward_pattern.Return_H_First:
-                    hidden_states, new_encoder_hidden_states = (
-                        new_encoder_hidden_states,
-                        hidden_states,
-                    )
-            kwargs = self.maybe_update_kwargs(
-                new_encoder_hidden_states,
-                kwargs,
-            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                hidden_states, new_encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
+                        hidden_states,
+                    )

         return hidden_states, new_encoder_hidden_states


cache_dit/cache_factory/cache_blocks/pattern_base.py
CHANGED

@@ -93,6 +93,21 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 required_param in forward_parameters
             ), f"The input parameters must contains: {required_param}."

+    @torch.compiler.disable
+    def _check_cache_params(self):
+        assert self.cache_manager.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        assert self.cache_manager.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -100,7 +115,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        # Use it's own cache context.
         self.cache_manager.set_context(self.cache_context)
+        self._check_cache_params()

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -191,18 +208,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
                 self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
@@ -296,12 +312,6 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert self.cache_manager.Fn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
         for block in self._Fn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -366,28 +376,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 encoder_hidden_states_residual,
             )

-    def 
+    def call_Bn_blocks(
         self,
-        # Block index in the transformer blocks
-        # Bn: 8, block_id should be in [0, 8)
-        block_id: int,
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         *args,
         **kwargs,
     ):
-
-
-
-
-        # and cache the residuals in non-cache steps.
-
-        # Normal steps: Compute the block and cache the residuals.
-        if not self._is_in_cache_step():
-            Bn_i_original_hidden_states = hidden_states
-            Bn_i_original_encoder_hidden_states = encoder_hidden_states
+        if self.cache_manager.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        for block in self._Bn_blocks():
             hidden_states = block(
                 hidden_states,
                 encoder_hidden_states,
@@ -401,149 +400,5 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                         encoder_hidden_states,
                         hidden_states,
                     )
-            # Cache residuals for the non-compute Bn blocks for
-            # subsequent cache steps.
-            if block_id not in self.cache_manager.Bn_compute_blocks_ids():
-                Bn_i_hidden_states_residual = (
-                    hidden_states - Bn_i_original_hidden_states
-                )
-                if (
-                    encoder_hidden_states is not None
-                    and Bn_i_original_encoder_hidden_states is not None
-                ):
-                    Bn_i_encoder_hidden_states_residual = (
-                        encoder_hidden_states
-                        - Bn_i_original_encoder_hidden_states
-                    )
-                else:
-                    Bn_i_encoder_hidden_states_residual = None
-
-                # Save original_hidden_states for diff calculation.
-                self.cache_manager.set_Bn_buffer(
-                    Bn_i_original_hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
-                )
-                self.cache_manager.set_Bn_encoder_buffer(
-                    Bn_i_original_encoder_hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
-                )
-
-                self.cache_manager.set_Bn_buffer(
-                    Bn_i_hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
-                )
-                self.cache_manager.set_Bn_encoder_buffer(
-                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
-                )
-                del Bn_i_hidden_states_residual
-                del Bn_i_encoder_hidden_states_residual
-
-            del Bn_i_original_hidden_states
-            del Bn_i_original_encoder_hidden_states
-
-        else:
-            # Cache steps: Reuse the cached residuals.
-            # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in self.cache_manager.Bn_compute_blocks_ids():
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):
-                    hidden_states, encoder_hidden_states = hidden_states
-                    if not self.forward_pattern.Return_H_First:
-                        hidden_states, encoder_hidden_states = (
-                            encoder_hidden_states,
-                            hidden_states,
-                        )
-            else:
-                # Skip the block if it is not in the Bn_compute_blocks_ids.
-                # Use the cached residuals instead.
-                # Check if can use the cached residuals.
-                if self.cache_manager.can_cache(
-                    hidden_states,  # curr step
-                    parallelized=self._is_parallelized(),
-                    threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",  # prev step
-                ):
-                    hidden_states, encoder_hidden_states = (
-                        self.cache_manager.apply_cache(
-                            hidden_states,
-                            encoder_hidden_states,
-                            prefix=(
-                                f"{self.cache_prefix}_Bn_{block_id}_residual"
-                                if self.cache_manager.is_cache_residual()
-                                else f"{self.cache_prefix}_Bn_{block_id}_original"
-                            ),
-                            encoder_prefix=(
-                                f"{self.cache_prefix}_Bn_{block_id}_residual"
-                                if self.cache_manager.is_encoder_cache_residual()
-                                else f"{self.cache_prefix}_Bn_{block_id}_original"
-                            ),
-                        )
-                    )
-                else:
-                    hidden_states = block(
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-                    if not isinstance(hidden_states, torch.Tensor):
-                        hidden_states, encoder_hidden_states = hidden_states
-                        if not self.forward_pattern.Return_H_First:
-                            hidden_states, encoder_hidden_states = (
-                                encoder_hidden_states,
-                                hidden_states,
-                            )
-        return hidden_states, encoder_hidden_states
-
-    def call_Bn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
-        assert self.cache_manager.Bn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
-        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
-            for i, block in enumerate(self._Bn_blocks()):
-                hidden_states, encoder_hidden_states = (
-                    self._compute_or_cache_block(
-                        i,
-                        block,
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-                )
-        else:
-            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
-            for block in self._Bn_blocks():
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):
-                    hidden_states, encoder_hidden_states = hidden_states
-                    if not self.forward_pattern.Return_H_First:
-                        hidden_states, encoder_hidden_states = (
-                            encoder_hidden_states,
-                            hidden_states,
-                        )

         return hidden_states, encoder_hidden_states
