cache-dit 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/cache_adapters.py +137 -76
- cache_dit/cache_factory/cache_context.py +112 -39
- cache_dit/cache_factory/cache_interface.py +11 -4
- cache_dit/cache_factory/taylorseer.py +5 -4
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/compile/utils.py +1 -1
- cache_dit/quantize/__init__.py +1 -0
- cache_dit/quantize/quantize_ao.py +182 -0
- cache_dit/quantize/quantize_interface.py +46 -0
- cache_dit/quantize/quantize_svdq.py +0 -0
- cache_dit/utils.py +68 -34
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/METADATA +15 -15
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/RECORD +19 -16
- cache_dit/primitives.py +0 -152
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
@@ -12,6 +12,7 @@ from cache_dit.cache_factory import CacheType
 from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.compile import set_compile_configs
+from cache_dit.quantize import quantize
 from cache_dit.utils import summary
 from cache_dit.utils import strify
 from cache_dit.logger import init_logger
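The new top-level `quantize` export is backed by the quantize subpackage added in this release (`quantize_ao.py`, `quantize_interface.py`). A hedged usage sketch only: the diff confirms the import path but not the call signature, so the argument below is an assumption; check `cache_dit/quantize/quantize_interface.py` for the real interface.

```python
# Sketch, not the confirmed API: assumes `quantize` accepts a torch.nn.Module.
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
cache_dit.quantize(pipe.transformer)  # presumably torchao-backed, per quantize_ao.py
```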
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.23'
-__version_tuple__ = version_tuple = (0, 2, 23)
+__version__ = version = '0.2.25'
+__version_tuple__ = version_tuple = (0, 2, 25)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/cache_adapters.py
CHANGED

@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses
 
-from typing import Any, Tuple, List
+from typing import Any, Tuple, List, Optional
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (
@@ -40,6 +40,7 @@ class BlockAdapter:
             "layers",
         ]
     )
+    check_prefixes: bool = True
    allow_suffixes: List[str] = dataclasses.field(
        default_factory=lambda: ["TransformerBlock"]
    )
@@ -48,8 +49,24 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
+    def __post_init__(self):
+        self.maybe_apply_patch()
+
+    def maybe_apply_patch(self):
+        # Process some specificial cases, specific for transformers
+        # that has different forward patterns between single_transformer_blocks
+        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
+        if self.transformer.__class__.__name__.startswith("Flux"):
+            self.transformer = maybe_patch_flux_transformer(
+                self.transformer,
+                blocks=self.blocks,
+            )
+
     @staticmethod
-    def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+    def auto_block_adapter(
+        adapter: "BlockAdapter",
+        forward_pattern: Optional[ForwardPattern] = None,
+    ) -> "BlockAdapter":
         assert adapter.auto, (
             "Please manually set `auto` to True, or, manually "
             "set all the transformer blocks configuration."
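For orientation, a sketch of how the `auto` path is typically driven. The `BlockAdapter` field names used here (`pipe`, `transformer`, `auto`) are the ones visible in this diff; the construction is illustrative and assumes the remaining fields have defaults:

```python
# Hedged sketch: BlockAdapter is a dataclass, so fields are passed as kwargs.
# With auto=True, auto_block_adapter() (above) locates the transformer blocks;
# __post_init__ may patch Flux-style models via maybe_apply_patch().
from cache_dit.cache_factory import BlockAdapter

adapter = BlockAdapter(
    pipe=pipe,                     # an already-loaded DiffusionPipeline
    transformer=pipe.transformer,  # the DiT module to scan
    auto=True,                     # defer block discovery to auto_block_adapter
)
```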
@@ -66,8 +83,10 @@ class BlockAdapter:
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
+            check_prefixes=adapter.check_prefixes,
             check_suffixes=adapter.check_suffixes,
             blocks_policy=adapter.blocks_policy,
+            forward_pattern=forward_pattern,
         )
 
         return BlockAdapter(
@@ -87,6 +106,8 @@ class BlockAdapter:
             and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
+
+        logger.warning("Check block adapter failed!")
         return False
 
     @staticmethod
@@ -101,24 +122,30 @@ class BlockAdapter:
         allow_suffixes: List[str] = [
             "TransformerBlock",
         ],
+        check_prefixes: bool = True,
         check_suffixes: bool = False,
         **kwargs,
     ) -> Tuple[torch.nn.ModuleList, str]:
+        # Check prefixes
+        if check_prefixes:
+            blocks_names = []
+            for attr_name in dir(transformer):
+                for prefix in allow_prefixes:
+                    if attr_name.startswith(prefix):
+                        blocks_names.append(attr_name)
+        else:
+            blocks_names = dir(transformer)
 
-        blocks_names = []
-        for attr_name in dir(transformer):
-            for prefix in allow_prefixes:
-                if attr_name.startswith(prefix):
-                    blocks_names.append(attr_name)
-
-        # Type check
+        # Check ModuleList
         valid_names = []
         valid_count = []
+        forward_pattern = kwargs.get("forward_pattern", None)
         for blocks_name in blocks_names:
             if blocks := getattr(transformer, blocks_name, None):
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
                     block_cls_name = block.__class__.__name__
+                    # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
                             (
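The attribute scan above is plain `dir()` reflection. A standalone illustration of the same pattern on a toy module (not cache-dit code):

```python
# Toy illustration of the prefix scan that find_blocks() now gates behind
# check_prefixes: collect ModuleList attributes whose names match a prefix.
import torch

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8)])
        self.single_transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8)])

model = ToyTransformer()
allow_prefixes = ["transformer", "single_transformer"]
blocks_names = [
    name for name in dir(model)
    if any(name.startswith(p) for p in allow_prefixes)
    and isinstance(getattr(model, name), torch.nn.ModuleList)
]
print(blocks_names)  # ['single_transformer_blocks', 'transformer_blocks']
```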
@@ -128,8 +155,18 @@ class BlockAdapter:
                         )
                         or (not check_suffixes)
                     ):
-                        valid_names.append(blocks_name)
-                        valid_count.append(len(blocks))
+                        # May check forward pattern
+                        if forward_pattern is not None:
+                            if BlockAdapter.match_blocks_pattern(
+                                blocks,
+                                forward_pattern,
+                                logging=False,
+                            ):
+                                valid_names.append(blocks_name)
+                                valid_count.append(len(blocks))
+                        else:
+                            valid_names.append(blocks_name)
+                            valid_count.append(len(blocks))
 
         if not valid_names:
             raise ValueError(
@@ -139,6 +176,7 @@ class BlockAdapter:
         final_name = valid_names[0]
         final_count = valid_count[0]
         block_policy = kwargs.get("blocks_policy", "max")
+
         for blocks_name, count in zip(valid_names, valid_count):
             blocks = getattr(transformer, blocks_name)
             logger.info(
@@ -165,6 +203,67 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def match_block_pattern(
+        block: torch.nn.Module,
+        forward_pattern: ForwardPattern,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        forward_parameters = set(
+            inspect.signature(block.forward).parameters.keys()
+        )
+        num_outputs = str(
+            inspect.signature(block.forward).return_annotation
+        ).count("torch.Tensor")
+
+        in_matched = True
+        out_matched = True
+        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+            # output pattern not match
+            out_matched = False
+
+        for required_param in forward_pattern.In:
+            if required_param not in forward_parameters:
+                in_matched = False
+
+        return in_matched and out_matched
+
+    @staticmethod
+    def match_blocks_pattern(
+        transformer_blocks: torch.nn.ModuleList,
+        forward_pattern: ForwardPattern,
+        logging: bool = True,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+        pattern_matched_states = []
+        for block in transformer_blocks:
+            pattern_matched_states.append(
+                BlockAdapter.match_block_pattern(
+                    block,
+                    forward_pattern,
+                )
+            )
+
+        pattern_matched = all(pattern_matched_states)  # all block match
+        if pattern_matched and logging:
+            block_cls_name = transformer_blocks[0].__class__.__name__
+            logger.info(
+                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
+                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+            )
+
+        return pattern_matched
+
 
 @dataclasses.dataclass
 class UnifiedCacheParams:
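The matching relies on `inspect.signature` only. A self-contained illustration of the In/Out check on a toy block (the `required_in`/`expected_out` values stand in for a `ForwardPattern`'s `In` and `Out` lists):

```python
# Standalone illustration of the signature test used by match_block_pattern:
# a block "matches" when its forward() accepts every required input name and
# its return annotation mentions the expected number of torch.Tensor outputs.
import inspect
from typing import Tuple
import torch

class ToyBlock(torch.nn.Module):
    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states

required_in = ("hidden_states", "encoder_hidden_states")  # like pattern.In
expected_out = 2                                          # like len(pattern.Out)

sig = inspect.signature(ToyBlock.forward)
params = set(sig.parameters.keys())
num_outputs = str(sig.return_annotation).count("torch.Tensor")

in_matched = all(p in params for p in required_in)
out_matched = num_outputs == 0 or num_outputs == expected_out
print(in_matched and out_matched)  # True
```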
@@ -463,19 +562,42 @@ class UnifiedCacheAdapter:
     ) -> DiffusionPipeline:
 
         if block_adapter.auto:
-            block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+            block_adapter = BlockAdapter.auto_block_adapter(
+                block_adapter,
+                forward_pattern,
+            )
 
         if BlockAdapter.check_block_adapter(block_adapter):
-            assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(block_adapter.pipe, **cache_context_kwargs)
+            cls.create_context(
+                block_adapter.pipe,
+                **cache_context_kwargs,
+            )
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
                 block_adapter,
                 forward_pattern=forward_pattern,
             )
+            cls.patch_params(
+                block_adapter,
+                forward_pattern=forward_pattern,
+                **cache_context_kwargs,
+            )
         return block_adapter.pipe
 
+    @classmethod
+    def patch_params(
+        cls,
+        block_adapter: BlockAdapter,
+        forward_pattern: ForwardPattern = None,
+        **cache_context_kwargs,
+    ):
+        block_adapter.transformer._forward_pattern = forward_pattern
+        block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
+        block_adapter.pipe.__class__._cache_context_kwargs = (
+            cache_context_kwargs
+        )
+
     @classmethod
     def has_separate_cfg(
         cls,
@@ -534,7 +656,6 @@ class UnifiedCacheAdapter:
 
         pipe.__class__.__call__ = new_call
         pipe.__class__._is_cached = True
-        pipe.__class__._cache_options = cache_kwargs
         return pipe
 
     @classmethod
@@ -544,28 +665,11 @@ class UnifiedCacheAdapter:
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
 
-        if (
-            block_adapter.transformer is None
-            or block_adapter.blocks_name is None
-            or block_adapter.blocks is None
-        ):
-            assert block_adapter.auto, (
-                "Please manually set `auto` to True, or, "
-                "manually set transformer blocks configuration."
-            )
-
         if getattr(block_adapter.transformer, "_is_cached", False):
             return block_adapter.transformer
 
-        # Firstly, process some specificial cases (TODO: more patches)
-        if block_adapter.transformer.__class__.__name__.startswith("Flux"):
-            block_adapter.transformer = maybe_patch_flux_transformer(
-                block_adapter.transformer,
-                blocks=block_adapter.blocks,
-            )
-
         # Check block forward pattern matching
-        assert cls.match_pattern(
+        assert BlockAdapter.match_blocks_pattern(
             block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
@@ -615,46 +719,3 @@ class UnifiedCacheAdapter:
         block_adapter.transformer._is_cached = True
 
         return block_adapter.transformer
-
-    @classmethod
-    def match_pattern(
-        cls,
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    ) -> bool:
-        pattern_matched_states = []
-
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        for block in transformer_blocks:
-            forward_parameters = set(
-                inspect.signature(block.forward).parameters.keys()
-            )
-            num_outputs = str(
-                inspect.signature(block.forward).return_annotation
-            ).count("torch.Tensor")
-
-            in_matched = True
-            out_matched = True
-            if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-                # output pattern not match
-                out_matched = False
-
-            for required_param in forward_pattern.In:
-                if required_param not in forward_parameters:
-                    in_matched = False
-
-            pattern_matched_states.append(in_matched and out_matched)
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
cache_dit/cache_factory/cache_context.py
CHANGED

@@ -5,8 +5,8 @@ from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
+import torch.distributed as dist
 
-import cache_dit.primitives as primitives
 from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
@@ -47,10 +47,11 @@ class DBCacheContext:
 
     # Other settings
     downsample_factor: int = 1
-    num_inference_steps: int = -1 #
-    warmup_steps: int = 0 # DON'T Cache in warmup steps
+    num_inference_steps: int = -1 # for future use
+    max_warmup_steps: int = 0 # DON'T Cache in warmup steps
     # DON'T Cache if the number of cached steps >= max_cached_steps
     max_cached_steps: int = -1 # for both CFG and non-CFG
+    max_continuous_cached_steps: int = -1 # the max continuous cached steps
 
     # Record the steps that have been cached, both cached and non-cache
     executed_steps: int = 0 # cache + non-cache steps pippeline
@@ -89,10 +90,12 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    continuous_cached_steps: int = 0
     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
     cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    cfg_continuous_cached_steps: int = 0
 
     @torch.compiler.disable
     def __post_init__(self):
@@ -108,17 +111,17 @@ class DBCacheContext:
                 "cfg_diff_compute_separate is enabled."
             )
 
-        if "warmup_steps" not in self.taylorseer_kwargs:
-            # If warmup_steps is not set in taylorseer_kwargs,
-            # set the same as warmup_steps for DBCache
-            self.taylorseer_kwargs["warmup_steps"] = (
-                self.warmup_steps if self.warmup_steps > 0 else 1
+        if "max_warmup_steps" not in self.taylorseer_kwargs:
+            # If max_warmup_steps is not set in taylorseer_kwargs,
+            # set the same as max_warmup_steps for DBCache
+            self.taylorseer_kwargs["max_warmup_steps"] = (
+                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
             )
 
         # Only set n_derivatives as 2 or 3, which is enough for most cases.
         if "n_derivatives" not in self.taylorseer_kwargs:
             self.taylorseer_kwargs["n_derivatives"] = max(
-                2, min(3, self.taylorseer_kwargs["warmup_steps"])
+                2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
             )
 
         if self.enable_taylorseer:
@@ -268,10 +271,31 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def add_cached_step(self):
+        curr_cached_step = self.get_current_step()
         if not self.is_separate_cfg_step():
-            self.cached_steps.append(self.get_current_step())
+            if self.cached_steps:
+                prev_cached_step = self.cached_steps[-1]
+                if curr_cached_step - prev_cached_step == 1:
+                    if self.continuous_cached_steps == 0:
+                        self.continuous_cached_steps += 2
+                    else:
+                        self.continuous_cached_steps += 1
+            else:
+                self.continuous_cached_steps += 1
+
+            self.cached_steps.append(curr_cached_step)
         else:
-            self.cfg_cached_steps.append(self.get_current_step())
+            if self.cfg_cached_steps:
+                prev_cfg_cached_step = self.cfg_cached_steps[-1]
+                if curr_cached_step - prev_cfg_cached_step == 1:
+                    if self.cfg_continuous_cached_steps == 0:
+                        self.cfg_continuous_cached_steps += 2
+                    else:
+                        self.cfg_continuous_cached_steps += 1
+            else:
+                self.cfg_continuous_cached_steps += 1
+
+            self.cfg_cached_steps.append(curr_cached_step)
 
     @torch.compiler.disable
     def get_cached_steps(self):
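A quick trace of the new bookkeeping, mirroring the logic above as reconstructed (illustrative only): caching at steps 3, 4, 5, then a gap, then step 7.

```python
# Mirrors add_cached_step()'s counter: a consecutive pair bumps the counter by
# 2 when it was 0 (counting both steps of the pair), by 1 afterwards; the very
# first cached step also counts as 1. A non-consecutive step leaves it as-is.
cached_steps, continuous = [], 0
for step in (3, 4, 5, 7):
    if cached_steps:
        if step - cached_steps[-1] == 1:
            continuous += 2 if continuous == 0 else 1
    else:
        continuous += 1
    cached_steps.append(step)
print(cached_steps, continuous)  # [3, 4, 5, 7] 3
```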
@@ -301,7 +325,34 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def is_in_warmup(self):
-        return self.get_current_step() < self.warmup_steps
+        return self.get_current_step() < self.max_warmup_steps
+
+
+# TODO: Support context manager for different cache_context
+
+
+def create_cache_context(*args, **kwargs):
+    return DBCacheContext(*args, **kwargs)
+
+
+def get_current_cache_context():
+    return _current_cache_context
+
+
+def set_current_cache_context(cache_context=None):
+    global _current_cache_context
+    _current_cache_context = cache_context
+
+
+@contextlib.contextmanager
+def cache_context(cache_context):
+    global _current_cache_context
+    old_cache_context = _current_cache_context
+    _current_cache_context = cache_context
+    try:
+        yield
+    finally:
+        _current_cache_context = old_cache_context
 
 
 @torch.compiler.disable
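These helpers only moved earlier in the module (their old definitions are deleted below at `@@ -612,19` and `@@ -671,17`); behavior is unchanged. A hedged usage sketch, with the `DBCacheContext` kwargs taken from fields visible in this diff:

```python
# Sketch: install a context for a scope, then restore the previous one.
import cache_dit.cache_factory.cache_context as cc

ctx = cc.create_cache_context(max_warmup_steps=8, max_continuous_cached_steps=4)
with cc.cache_context(ctx):
    assert cc.get_current_cache_context() is ctx
assert cc.get_current_cache_context() is None  # previous (None) restored
```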
@@ -396,6 +447,27 @@ def get_max_cached_steps():
     return cache_context.max_cached_steps
 
 
+@torch.compiler.disable
+def get_max_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.max_continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_cfg_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.cfg_continuous_cached_steps
+
+
 @torch.compiler.disable
 def add_cached_step():
     cache_context = get_current_cache_context()
@@ -612,19 +684,6 @@ def cfg_diff_compute_separate():
 _current_cache_context: DBCacheContext = None
 
 
-def create_cache_context(*args, **kwargs):
-    return DBCacheContext(*args, **kwargs)
-
-
-def get_current_cache_context():
-    return _current_cache_context
-
-
-def set_current_cache_context(cache_context=None):
-    global _current_cache_context
-    _current_cache_context = cache_context
-
-
 def collect_cache_kwargs(default_attrs: dict, **kwargs):
     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
     # default_attrs: specific settings for different pipelines
@@ -671,17 +730,6 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
     return cache_kwargs, kwargs
 
 
-@contextlib.contextmanager
-def cache_context(cache_context):
-    global _current_cache_context
-    old_cache_context = _current_cache_context
-    _current_cache_context = cache_context
-    try:
-        yield
-    finally:
-        _current_cache_context = old_cache_context
-
-
 @torch.compiler.disable
 def are_two_tensors_similar(
     t1: torch.Tensor, # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
@@ -744,8 +792,8 @@
     mean_t1 = t1.abs().mean()
 
     if parallelized:
-
-
+        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
 
     # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
     # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
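The rewritten branch calls `torch.distributed` directly instead of the removed `cache_dit.primitives` helpers (the exact removed lines did not survive extraction above). A minimal standalone sketch of the same pattern; it assumes a process group is already initialized (e.g. via `torchrun`) and notes that `ReduceOp.AVG` requires the NCCL backend:

```python
import torch
import torch.distributed as dist

def reduce_mean_stats(mean_diff: torch.Tensor, mean_t1: torch.Tensor):
    # Average the per-rank scalars in place so every rank computes the same
    # relative residual diff and therefore makes the same caching decision.
    dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
    dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
    return mean_diff, mean_t1
```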
@@ -1020,6 +1068,7 @@
     if is_in_warmup():
         return False
 
+    # max cached steps
     max_cached_steps = get_max_cached_steps()
     if not is_separate_cfg_step():
         cached_steps = get_cached_steps()
@@ -1030,10 +1079,34 @@
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug(
                 f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                "
+                "can not use cache."
             )
         return False
 
+    # max continuous cached steps
+    max_continuous_cached_steps = get_max_continuous_cached_steps()
+    if not is_separate_cfg_step():
+        continuous_cached_steps = get_continuous_cached_steps()
+    else:
+        continuous_cached_steps = get_cfg_continuous_cached_steps()
+
+    if max_continuous_cached_steps >= 0 and (
+        continuous_cached_steps >= max_continuous_cached_steps
+    ):
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                f"{prefix}, max_continuous_cached_steps "
+                f"reached: {max_continuous_cached_steps}, "
+                "can not use cache."
+            )
+        # reset continuous cached steps stats
+        cache_context = get_current_cache_context()
+        if not is_separate_cfg_step():
+            cache_context.continuous_cached_steps = 0
+        else:
+            cache_context.cfg_continuous_cached_steps = 0
+        return False
+
     if threshold is None or threshold <= 0.0:
         threshold = get_residual_diff_threshold()
     if threshold <= 0.0:
cache_dit/cache_factory/cache_interface.py
CHANGED

@@ -16,13 +16,13 @@ def enable_cache(
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,
-    warmup_steps: int = 8,
+    max_warmup_steps: int = 8,
     max_cached_steps: int = -1,
+    max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
     do_separate_cfg: bool = False,
     cfg_compute_first: bool = False,
-    cfg_diff_compute_separate: bool = False,
+    cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
     enable_taylorseer: bool = False,
     enable_encoder_taylorseer: bool = False,
@@ -54,12 +55,15 @@
         Further fuses approximate information in the **last n** Transformer blocks to enhance
         prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
         that use residual cache.
-    warmup_steps (`int`, *required*, defaults to 8):
+    max_warmup_steps (`int`, *required*, defaults to 8):
         DBCache does not apply the caching strategy when the number of running steps is less than
         or equal to this value, ensuring the model sufficiently learns basic features during warmup.
     max_cached_steps (`int`, *required*, defaults to -1):
         DBCache disables the caching strategy when the previous cached steps exceed this value to
         prevent precision degradation.
+    max_continuous_cached_steps (`int`, *required*, defaults to -1):
+        DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+        prevent precision degradation.
     residual_diff_threshold (`float`, *required*, defaults to 0.08):
         he value of residual diff threshold, a higher value leads to faster performance at the
         cost of lower precision.
@@ -106,8 +110,11 @@ def enable_cache(
     cache_context_kwargs["cache_type"] = CacheType.DBCache
     cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
     cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
-    cache_context_kwargs["warmup_steps"] = warmup_steps
+    cache_context_kwargs["max_warmup_steps"] = max_warmup_steps
     cache_context_kwargs["max_cached_steps"] = max_cached_steps
+    cache_context_kwargs["max_continuous_cached_steps"] = (
+        max_continuous_cached_steps
+    )
     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
     cache_context_kwargs["do_separate_cfg"] = do_separate_cfg
     cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
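Putting the renamed and new knobs together, a usage sketch (assuming `enable_cache` is re-exported at the package top level like the other interfaces; the model choice is illustrative):

```python
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
cache_dit.enable_cache(
    pipe,
    max_warmup_steps=8,             # renamed from warmup_steps
    max_cached_steps=-1,            # unlimited total cached steps
    max_continuous_cached_steps=4,  # new: cap on consecutive cache hits
    residual_diff_threshold=0.08,
)
```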
cache_dit/cache_factory/taylorseer.py
CHANGED

@@ -6,13 +6,13 @@ class TaylorSeer:
     def __init__(
         self,
         n_derivatives=2,
-        warmup_steps=1,
+        max_warmup_steps=1,
         skip_interval_steps=1,
         compute_step_map=None,
     ):
         self.n_derivatives = n_derivatives
         self.ORDER = n_derivatives + 1
-        self.warmup_steps = warmup_steps
+        self.max_warmup_steps = max_warmup_steps
         self.skip_interval_steps = skip_interval_steps
         self.compute_step_map = compute_step_map
         self.reset_cache()
@@ -32,8 +32,9 @@ class TaylorSeer:
         if self.compute_step_map is not None:
             return self.compute_step_map[step]
         if (
-            step < self.warmup_steps
-            or (step - self.warmup_steps + 1) % self.skip_interval_steps == 0
+            step < self.max_warmup_steps
+            or (step - self.max_warmup_steps + 1) % self.skip_interval_steps
+            == 0
         ):
             return True
         return False
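A quick check of the updated schedule logic: with `max_warmup_steps=3` and `skip_interval_steps=2`, full computes happen at steps 0-2 (warmup) and then every other step.

```python
# Illustrative re-implementation of the condition above (not the class itself).
def should_compute_full(step, max_warmup_steps=3, skip_interval_steps=2):
    return (
        step < max_warmup_steps
        or (step - max_warmup_steps + 1) % skip_interval_steps == 0
    )

print([s for s in range(10) if should_compute_full(s)])  # [0, 1, 2, 4, 6, 8]
```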
cache_dit/cache_factory/utils.py
CHANGED
cache_dit/compile/utils.py
CHANGED
@@ -24,7 +24,7 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
 
 
 def set_compile_configs(
-    descent_tuning: bool = True,
+    descent_tuning: bool = False,
     cuda_graphs: bool = False,
     force_disable_compile_caches: bool = False,
     use_fast_math: bool = False,
cache_dit/quantize/__init__.py
CHANGED

@@ -0,0 +1 @@
+from cache_dit.quantize.quantize_interface import quantize