cache-dit 0.2.23__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/cache_adapters.py +137 -76
- cache_dit/cache_factory/cache_context.py +85 -15
- cache_dit/cache_factory/cache_interface.py +10 -3
- cache_dit/cache_factory/taylorseer.py +5 -4
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/utils.py +25 -22
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/METADATA +8 -6
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/RECORD +13 -14
- cache_dit/primitives.py +0 -152
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.23'
-__version_tuple__ = version_tuple = (0, 2, 23)
+__version__ = version = '0.2.24'
+__version_tuple__ = version_tuple = (0, 2, 24)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/cache_adapters.py
CHANGED

@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses
 
-from typing import Any, Tuple, List
+from typing import Any, Tuple, List, Optional
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (

@@ -40,6 +40,7 @@ class BlockAdapter:
            "layers",
        ]
    )
+    check_prefixes: bool = True
    allow_suffixes: List[str] = dataclasses.field(
        default_factory=lambda: ["TransformerBlock"]
    )

@@ -48,8 +49,24 @@ class BlockAdapter:
        default="max", metadata={"allowed_values": ["max", "min"]}
    )
 
+    def __post_init__(self):
+        self.maybe_apply_patch()
+
+    def maybe_apply_patch(self):
+        # Process some specificial cases, specific for transformers
+        # that has different forward patterns between single_transformer_blocks
+        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
+        if self.transformer.__class__.__name__.startswith("Flux"):
+            self.transformer = maybe_patch_flux_transformer(
+                self.transformer,
+                blocks=self.blocks,
+            )
+
     @staticmethod
-    def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+    def auto_block_adapter(
+        adapter: "BlockAdapter",
+        forward_pattern: Optional[ForwardPattern] = None,
+    ) -> "BlockAdapter":
         assert adapter.auto, (
             "Please manually set `auto` to True, or, manually "
             "set all the transformer blocks configuration."
@@ -66,8 +83,10 @@ class BlockAdapter:
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
+            check_prefixes=adapter.check_prefixes,
             check_suffixes=adapter.check_suffixes,
             blocks_policy=adapter.blocks_policy,
+            forward_pattern=forward_pattern,
         )
 
         return BlockAdapter(

@@ -87,6 +106,8 @@ class BlockAdapter:
             and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
+
+        logger.warning("Check block adapter failed!")
         return False
 
     @staticmethod

@@ -101,24 +122,30 @@ class BlockAdapter:
         allow_suffixes: List[str] = [
             "TransformerBlock",
         ],
+        check_prefixes: bool = True,
         check_suffixes: bool = False,
         **kwargs,
     ) -> Tuple[torch.nn.ModuleList, str]:
+        # Check prefixes
+        if check_prefixes:
+            blocks_names = []
+            for attr_name in dir(transformer):
+                for prefix in allow_prefixes:
+                    if attr_name.startswith(prefix):
+                        blocks_names.append(attr_name)
+        else:
+            blocks_names = dir(transformer)
 
-
-        for attr_name in dir(transformer):
-            for prefix in allow_prefixes:
-                if attr_name.startswith(prefix):
-                    blocks_names.append(attr_name)
-
-        # Type check
+        # Check ModuleList
         valid_names = []
         valid_count = []
+        forward_pattern = kwargs.get("forward_pattern", None)
         for blocks_name in blocks_names:
             if blocks := getattr(transformer, blocks_name, None):
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
                     block_cls_name = block.__class__.__name__
+                    # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
                             (

@@ -128,8 +155,18 @@ class BlockAdapter:
                         )
                         or (not check_suffixes)
                     ):
-                        valid_names.append(blocks_name)
-                        valid_count.append(len(blocks))
+                        # May check forward pattern
+                        if forward_pattern is not None:
+                            if BlockAdapter.match_blocks_pattern(
+                                blocks,
+                                forward_pattern,
+                                logging=False,
+                            ):
+                                valid_names.append(blocks_name)
+                                valid_count.append(len(blocks))
+                        else:
+                            valid_names.append(blocks_name)
+                            valid_count.append(len(blocks))
 
         if not valid_names:
             raise ValueError(

@@ -139,6 +176,7 @@ class BlockAdapter:
         final_name = valid_names[0]
         final_count = valid_count[0]
         block_policy = kwargs.get("blocks_policy", "max")
+
         for blocks_name, count in zip(valid_names, valid_count):
             blocks = getattr(transformer, blocks_name)
             logger.info(
@@ -165,6 +203,67 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def match_block_pattern(
+        block: torch.nn.Module,
+        forward_pattern: ForwardPattern,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        forward_parameters = set(
+            inspect.signature(block.forward).parameters.keys()
+        )
+        num_outputs = str(
+            inspect.signature(block.forward).return_annotation
+        ).count("torch.Tensor")
+
+        in_matched = True
+        out_matched = True
+        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+            # output pattern not match
+            out_matched = False
+
+        for required_param in forward_pattern.In:
+            if required_param not in forward_parameters:
+                in_matched = False
+
+        return in_matched and out_matched
+
+    @staticmethod
+    def match_blocks_pattern(
+        transformer_blocks: torch.nn.ModuleList,
+        forward_pattern: ForwardPattern,
+        logging: bool = True,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+        pattern_matched_states = []
+        for block in transformer_blocks:
+            pattern_matched_states.append(
+                BlockAdapter.match_block_pattern(
+                    block,
+                    forward_pattern,
+                )
+            )
+
+        pattern_matched = all(pattern_matched_states)  # all block match
+        if pattern_matched and logging:
+            block_cls_name = transformer_blocks[0].__class__.__name__
+            logger.info(
+                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
+                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+            )
+
+        return pattern_matched
+
 
 @dataclasses.dataclass
 class UnifiedCacheParams:
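The new `match_block_pattern` / `match_blocks_pattern` helpers decide whether a block's `forward` signature fits a declared `ForwardPattern` by inspecting its parameter names and counting `torch.Tensor` entries in the return annotation. A minimal, self-contained sketch of the same check; the `DummyBlock` class and the plain in/out tuples below are illustrative stand-ins, not the package's `ForwardPattern` enum:

```python
import inspect
from typing import Tuple

import torch


class DummyBlock(torch.nn.Module):
    # Illustrative block whose forward signature resembles a typical DiT block.
    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states


def matches(block: torch.nn.Module, pattern_in, pattern_out) -> bool:
    # Same idea as BlockAdapter.match_block_pattern: every required input name
    # must be a forward() parameter, and if the return annotation mentions
    # torch.Tensor at all, the count must equal the expected number of outputs.
    sig = inspect.signature(block.forward)
    params = set(sig.parameters.keys())
    num_outputs = str(sig.return_annotation).count("torch.Tensor")
    in_ok = all(name in params for name in pattern_in)
    out_ok = num_outputs == 0 or num_outputs == len(pattern_out)
    return in_ok and out_ok


print(matches(
    DummyBlock(),
    pattern_in=("hidden_states", "encoder_hidden_states"),
    pattern_out=("hidden_states", "encoder_hidden_states"),
))  # True
```

With this in place, `auto_block_adapter` threads the requested pattern into `find_blocks`, so candidate `ModuleList`s whose blocks fail the check are filtered out before a blocks attribute is chosen.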
@@ -463,19 +562,42 @@ class UnifiedCacheAdapter:
     ) -> DiffusionPipeline:
 
         if block_adapter.auto:
-            block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+            block_adapter = BlockAdapter.auto_block_adapter(
+                block_adapter,
+                forward_pattern,
+            )
 
         if BlockAdapter.check_block_adapter(block_adapter):
-            assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(block_adapter.pipe, **cache_context_kwargs)
+            cls.create_context(
+                block_adapter.pipe,
+                **cache_context_kwargs,
+            )
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
                 block_adapter,
                 forward_pattern=forward_pattern,
             )
+            cls.patch_params(
+                block_adapter,
+                forward_pattern=forward_pattern,
+                **cache_context_kwargs,
+            )
         return block_adapter.pipe
 
+    @classmethod
+    def patch_params(
+        cls,
+        block_adapter: BlockAdapter,
+        forward_pattern: ForwardPattern = None,
+        **cache_context_kwargs,
+    ):
+        block_adapter.transformer._forward_pattern = forward_pattern
+        block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
+        block_adapter.pipe.__class__._cache_context_kwargs = (
+            cache_context_kwargs
+        )
+
     @classmethod
     def has_separate_cfg(
         cls,

@@ -534,7 +656,6 @@ class UnifiedCacheAdapter:
 
         pipe.__class__.__call__ = new_call
         pipe.__class__._is_cached = True
-        pipe.__class__._cache_options = cache_kwargs
         return pipe
 
     @classmethod

@@ -544,28 +665,11 @@ class UnifiedCacheAdapter:
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
 
-        if (
-            block_adapter.transformer is None
-            or block_adapter.blocks_name is None
-            or block_adapter.blocks is None
-        ):
-            assert block_adapter.auto, (
-                "Please manually set `auto` to True, or, "
-                "manually set transformer blocks configuration."
-            )
-
         if getattr(block_adapter.transformer, "_is_cached", False):
             return block_adapter.transformer
 
-        # Firstly, process some specificial cases (TODO: more patches)
-        if block_adapter.transformer.__class__.__name__.startswith("Flux"):
-            block_adapter.transformer = maybe_patch_flux_transformer(
-                block_adapter.transformer,
-                blocks=block_adapter.blocks,
-            )
-
         # Check block forward pattern matching
-        assert cls.match_pattern(
+        assert BlockAdapter.match_blocks_pattern(
             block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
@@ -615,46 +719,3 @@ class UnifiedCacheAdapter:
         block_adapter.transformer._is_cached = True
 
         return block_adapter.transformer
-
-    @classmethod
-    def match_pattern(
-        cls,
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    ) -> bool:
-        pattern_matched_states = []
-
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        for block in transformer_blocks:
-            forward_parameters = set(
-                inspect.signature(block.forward).parameters.keys()
-            )
-            num_outputs = str(
-                inspect.signature(block.forward).return_annotation
-            ).count("torch.Tensor")
-
-            in_matched = True
-            out_matched = True
-            if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-                # output pattern not match
-                out_matched = False
-
-            for required_param in forward_pattern.In:
-                if required_param not in forward_parameters:
-                    in_matched = False
-
-            pattern_matched_states.append(in_matched and out_matched)
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
cache_dit/cache_factory/cache_context.py
CHANGED

@@ -5,8 +5,8 @@ from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
+import torch.distributed as dist
 
-import cache_dit.primitives as primitives
 from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 

@@ -47,10 +47,11 @@ class DBCacheContext:
 
     # Other settings
     downsample_factor: int = 1
-    num_inference_steps: int = -1  #
-    warmup_steps: int = 0  # DON'T Cache in warmup steps
+    num_inference_steps: int = -1  # for future use
+    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
     # DON'T Cache if the number of cached steps >= max_cached_steps
     max_cached_steps: int = -1  # for both CFG and non-CFG
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
 
     # Record the steps that have been cached, both cached and non-cache
     executed_steps: int = 0  # cache + non-cache steps pippeline

@@ -89,10 +90,12 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    continuous_cached_steps: int = 0
     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
     cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    cfg_continuous_cached_steps: int = 0
 
     @torch.compiler.disable
     def __post_init__(self):

@@ -108,17 +111,17 @@ class DBCacheContext:
                 "cfg_diff_compute_separate is enabled."
             )
 
-        if "warmup_steps" not in self.taylorseer_kwargs:
-            # If warmup_steps is not set in taylorseer_kwargs,
-            # set the same as warmup_steps for DBCache
-            self.taylorseer_kwargs["warmup_steps"] = (
-                self.warmup_steps if self.warmup_steps > 0 else 1
+        if "max_warmup_steps" not in self.taylorseer_kwargs:
+            # If max_warmup_steps is not set in taylorseer_kwargs,
+            # set the same as max_warmup_steps for DBCache
+            self.taylorseer_kwargs["max_warmup_steps"] = (
+                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
             )
 
         # Only set n_derivatives as 2 or 3, which is enough for most cases.
         if "n_derivatives" not in self.taylorseer_kwargs:
             self.taylorseer_kwargs["n_derivatives"] = max(
-                2, min(3, self.taylorseer_kwargs["warmup_steps"])
+                2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
             )
 
         if self.enable_taylorseer:
@@ -268,10 +271,31 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def add_cached_step(self):
+        curr_cached_step = self.get_current_step()
         if not self.is_separate_cfg_step():
-            self.cached_steps.append(self.get_current_step())
+            if self.cached_steps:
+                prev_cached_step = self.cached_steps[-1]
+                if curr_cached_step - prev_cached_step == 1:
+                    if self.continuous_cached_steps == 0:
+                        self.continuous_cached_steps += 2
+                    else:
+                        self.continuous_cached_steps += 1
+            else:
+                self.continuous_cached_steps += 1
+
+            self.cached_steps.append(curr_cached_step)
         else:
-            self.cfg_cached_steps.append(self.get_current_step())
+            if self.cfg_cached_steps:
+                prev_cfg_cached_step = self.cfg_cached_steps[-1]
+                if curr_cached_step - prev_cfg_cached_step == 1:
+                    if self.cfg_continuous_cached_steps == 0:
+                        self.cfg_continuous_cached_steps += 2
+                    else:
+                        self.cfg_continuous_cached_steps += 1
+            else:
+                self.cfg_continuous_cached_steps += 1
+
+            self.cfg_cached_steps.append(curr_cached_step)
 
     @torch.compiler.disable
     def get_cached_steps(self):

@@ -301,7 +325,7 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def is_in_warmup(self):
-        return self.get_current_step() < self.warmup_steps
+        return self.get_current_step() < self.max_warmup_steps
 
 
     @torch.compiler.disable

@@ -396,6 +420,27 @@ def get_max_cached_steps():
     return cache_context.max_cached_steps
 
 
+@torch.compiler.disable
+def get_max_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.max_continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_cfg_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.cfg_continuous_cached_steps
+
+
 @torch.compiler.disable
 def add_cached_step():
     cache_context = get_current_cache_context()
@@ -744,8 +789,8 @@ def are_two_tensors_similar(
     mean_t1 = t1.abs().mean()
 
     if parallelized:
-        mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
-        mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
+        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
 
     # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
     # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,

@@ -1020,6 +1065,7 @@ def get_can_use_cache(
     if is_in_warmup():
         return False
 
+    # max cached steps
     max_cached_steps = get_max_cached_steps()
     if not is_separate_cfg_step():
         cached_steps = get_cached_steps()

@@ -1030,8 +1076,32 @@ def get_can_use_cache(
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug(
                 f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                "
+                "can not use cache."
+            )
+        return False
+
+    # max continuous cached steps
+    max_continuous_cached_steps = get_max_continuous_cached_steps()
+    if not is_separate_cfg_step():
+        continuous_cached_steps = get_continuous_cached_steps()
+    else:
+        continuous_cached_steps = get_cfg_continuous_cached_steps()
+
+    if max_continuous_cached_steps >= 0 and (
+        continuous_cached_steps >= max_continuous_cached_steps
+    ):
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                f"{prefix}, max_continuous_cached_steps "
+                f"reached: {max_continuous_cached_steps}, "
+                "can not use cache."
             )
+        # reset continuous cached steps stats
+        cache_context = get_current_cache_context()
+        if not is_separate_cfg_step():
+            cache_context.continuous_cached_steps = 0
+        else:
+            cache_context.cfg_continuous_cached_steps = 0
         return False
 
     if threshold is None or threshold <= 0.0:
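Taken together with `add_cached_step`, the new `max_continuous_cached_steps` limit bounds how many steps in a row may be served from cache before a full compute step is forced, at which point the continuous counter is reset. An illustrative scheduler sketch (not the package code; `would_cache` stands in for the residual-diff decision made elsewhere in `get_can_use_cache`):

```python
def plan_steps(would_cache: list[bool], max_continuous: int = -1) -> list[str]:
    # would_cache[i] is what the residual-diff heuristic alone would decide.
    # A cached run longer than max_continuous (when >= 0) is broken up by
    # forcing a compute step and resetting the run counter, mirroring the
    # reset performed in get_can_use_cache.
    plan, run = [], 0
    for cacheable in would_cache:
        if cacheable and (max_continuous < 0 or run < max_continuous):
            plan.append("cache")
            run += 1
        else:
            plan.append("compute")
            run = 0
    return plan


print(plan_steps([False, True, True, True, True, True, True], max_continuous=3))
# ['compute', 'cache', 'cache', 'cache', 'compute', 'cache', 'cache']
```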
cache_dit/cache_factory/cache_interface.py
CHANGED

@@ -16,8 +16,9 @@ def enable_cache(
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,
-    warmup_steps: int = 8,
+    max_warmup_steps: int = 8,
     max_cached_steps: int = -1,
+    max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
     do_separate_cfg: bool = False,

@@ -54,12 +55,15 @@ def enable_cache(
         Further fuses approximate information in the **last n** Transformer blocks to enhance
         prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
         that use residual cache.
-    warmup_steps (`int`, *required*, defaults to 8):
+    max_warmup_steps (`int`, *required*, defaults to 8):
         DBCache does not apply the caching strategy when the number of running steps is less than
         or equal to this value, ensuring the model sufficiently learns basic features during warmup.
     max_cached_steps (`int`, *required*, defaults to -1):
         DBCache disables the caching strategy when the previous cached steps exceed this value to
         prevent precision degradation.
+    max_continuous_cached_steps (`int`, *required*, defaults to -1):
+        DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+        prevent precision degradation.
     residual_diff_threshold (`float`, *required*, defaults to 0.08):
         he value of residual diff threshold, a higher value leads to faster performance at the
         cost of lower precision.

@@ -106,8 +110,11 @@ def enable_cache(
     cache_context_kwargs["cache_type"] = CacheType.DBCache
     cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
     cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
-    cache_context_kwargs["warmup_steps"] = warmup_steps
+    cache_context_kwargs["max_warmup_steps"] = max_warmup_steps
     cache_context_kwargs["max_cached_steps"] = max_cached_steps
+    cache_context_kwargs["max_continuous_cached_steps"] = (
+        max_continuous_cached_steps
+    )
     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
     cache_context_kwargs["do_separate_cfg"] = do_separate_cfg
     cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
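With the renamed warmup knob and the new continuous-cache cap, a typical `enable_cache` call now looks like the following; the model id, prompt, and particular values are only for illustration:

```python
import cache_dit
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

cache_dit.enable_cache(
    pipe,
    Fn_compute_blocks=8,            # F8
    Bn_compute_blocks=0,            # B0
    max_warmup_steps=8,             # no caching during the first 8 steps
    max_cached_steps=-1,            # no global cap on cached steps
    max_continuous_cached_steps=4,  # force a compute step after 4 cached steps in a row
    residual_diff_threshold=0.08,
)

output = pipe("a tiny robot reading a book", num_inference_steps=28)
```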
cache_dit/cache_factory/taylorseer.py
CHANGED

@@ -6,13 +6,13 @@ class TaylorSeer:
     def __init__(
         self,
         n_derivatives=2,
-        warmup_steps=1,
+        max_warmup_steps=1,
         skip_interval_steps=1,
         compute_step_map=None,
     ):
         self.n_derivatives = n_derivatives
         self.ORDER = n_derivatives + 1
-        self.warmup_steps = warmup_steps
+        self.max_warmup_steps = max_warmup_steps
         self.skip_interval_steps = skip_interval_steps
         self.compute_step_map = compute_step_map
         self.reset_cache()

@@ -32,8 +32,9 @@ class TaylorSeer:
         if self.compute_step_map is not None:
             return self.compute_step_map[step]
         if (
-            step < self.warmup_steps
-            or (step - self.warmup_steps + 1) % self.skip_interval_steps == 0
+            step < self.max_warmup_steps
+            or (step - self.max_warmup_steps + 1) % self.skip_interval_steps
+            == 0
         ):
             return True
         return False
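After the rename, TaylorSeer performs a full computation for every step below `max_warmup_steps` and afterwards only on every `skip_interval_steps`-th step; the remaining steps reuse the Taylor-series approximation. A standalone sketch of that schedule (the function name here is illustrative, not the package API):

```python
def is_full_compute_step(step: int, max_warmup_steps: int, skip_interval_steps: int) -> bool:
    # Mirrors the condition in the hunk above.
    return (
        step < max_warmup_steps
        or (step - max_warmup_steps + 1) % skip_interval_steps == 0
    )


print([is_full_compute_step(s, max_warmup_steps=3, skip_interval_steps=2) for s in range(8)])
# [True, True, True, False, True, False, True, False]
```

This also lines up with the `# prefer: >= n_derivatives + 1` comment in the README example further down: enough warmup samples are needed to estimate the derivatives before any step can be approximated.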
cache_dit/cache_factory/utils.py
CHANGED
cache_dit/utils.py
CHANGED
@@ -27,22 +27,26 @@ class CacheStats:
 
 
 def summary(
-    pipe: DiffusionPipeline, details: bool = False, logging: bool = True
-) -> CacheStats:
+    pipe_or_transformer: DiffusionPipeline | torch.nn.Module,
+    details: bool = False,
+    logging: bool = True,
+) -> CacheStats:
     cache_stats = CacheStats()
-    pipe_cls_name = pipe.__class__.__name__
+    cls_name = pipe_or_transformer.__class__.__name__
+    if isinstance(pipe_or_transformer, DiffusionPipeline):
+        transformer = pipe_or_transformer.transformer
+    else:
+        transformer = pipe_or_transformer
 
-    if hasattr(pipe.transformer, "_cache_options"):
-        cache_options = pipe.transformer._cache_options
+    if hasattr(transformer, "_cache_context_kwargs"):
+        cache_options = transformer._cache_context_kwargs
         cache_stats.cache_options = cache_options
         if logging:
-            print(f"\n🤗Cache Options: {pipe_cls_name}\n\n{cache_options}")
+            print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
 
-    if hasattr(pipe.transformer, "_cached_steps"):
-        cached_steps: list[int] = pipe.transformer._cached_steps
-        residual_diffs: dict[str, float] = dict(
-            pipe.transformer._residual_diffs
-        )
+    if hasattr(transformer, "_cached_steps"):
+        cached_steps: list[int] = transformer._cached_steps
+        residual_diffs: dict[str, float] = dict(transformer._residual_diffs)
         cache_stats.cached_steps = cached_steps
         cache_stats.residual_diffs = residual_diffs
 

@@ -57,7 +61,7 @@ def summary(
             qmax = np.max(diffs_values)
 
             print(
-                f"\n⚡️Cache Steps and Residual Diffs Statistics: {pipe_cls_name}\n"
+                f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
             )
 
             print(

@@ -74,9 +78,7 @@ def summary(
             print("")
 
         if details:
-            print(
-                f"📚Cache Steps and Residual Diffs Details: {pipe_cls_name}\n"
-            )
+            print(f"📚Cache Steps and Residual Diffs Details: {cls_name}\n")
             pprint(
                 f"Cache Steps: {len(cached_steps)}, {cached_steps}",
             )

@@ -85,10 +87,10 @@ def summary(
                 compact=True,
             )
 
-    if hasattr(pipe.transformer, "_cfg_cached_steps"):
-        cfg_cached_steps: list[int] = pipe.transformer._cfg_cached_steps
+    if hasattr(transformer, "_cfg_cached_steps"):
+        cfg_cached_steps: list[int] = transformer._cfg_cached_steps
         cfg_residual_diffs: dict[str, float] = dict(
-            pipe.transformer._cfg_residual_diffs
+            transformer._cfg_residual_diffs
         )
         cache_stats.cfg_cached_steps = cfg_cached_steps
         cache_stats.cfg_residual_diffs = cfg_residual_diffs

@@ -104,7 +106,7 @@ def summary(
             qmax = np.max(cfg_diffs_values)
 
             print(
-                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {pipe_cls_name}\n"
+                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
             )
 
             print(

@@ -122,7 +124,7 @@ def summary(
 
         if details:
             print(
-                f"📚CFG Cache Steps and Residual Diffs Details: {pipe_cls_name}\n"
+                f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
            )
             pprint(
                 f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",

@@ -149,9 +151,10 @@ def strify(pipe_or_stats: DiffusionPipeline | CacheStats):
 
     cache_type_str = (
         f"DBCACHE_F{cache_options['Fn_compute_blocks']}"
-        f"B{cache_options['Bn_compute_blocks']}"
-        f"W{cache_options['warmup_steps']}"
+        f"B{cache_options['Bn_compute_blocks']}_"
+        f"W{cache_options['max_warmup_steps']}"
         f"M{max(0, cache_options['max_cached_steps'])}"
+        f"MC{max(0, cache_options['max_continuous_cached_steps'])}_"
         f"T{int(cache_options['enable_taylorseer'])}"
         f"O{cache_options['taylorseer_kwargs']['n_derivatives']}_"
         f"R{cache_options['residual_diff_threshold']}_"
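Because the cache options now live on the transformer as `_cache_context_kwargs`, `summary` accepts either the pipeline or the transformer itself. A hedged usage sketch, assuming `pipe` has already been run with caching enabled and that `summary`/`strify` are exported at the package level as in earlier releases:

```python
import cache_dit

stats = cache_dit.summary(pipe, details=True)  # pass the whole pipeline
stats = cache_dit.summary(pipe.transformer)    # or just the transformer

# Compact tag of the active configuration; with the renamed fields it now
# carries W (max_warmup_steps) and MC (max_continuous_cached_steps), e.g.
# something along the lines of "DBCACHE_F8B0_W8M0MC4_T0O2_R0.08_...".
print(cache_dit.strify(stats))
```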
{cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.23
+Version: 0.2.24
 Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc

@@ -59,12 +59,13 @@ Dynamic: requires-python
 </p>
 <p align="center">
 🎉Now, <b>cache-dit</b> covers <b>Most</b> mainstream <b>Diffusers'</b> Pipelines</b>🎉<br>
-🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
+🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
 </p>
 </div>
 
 ## 🔥News
 
+- [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.5x⚡️** speedup! Please check [run_wan_2.2.py](./examples/run_wan_2.2.py) as an example.
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x⚡️** speedup! Check example [run_qwen_image_edit.py](./examples/run_qwen_image_edit.py).
 - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x⚡️** speedup! Please refer [run_qwen_image.py](./examples/run_qwen_image.py) as an example.

@@ -119,6 +120,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)

@@ -166,7 +168,7 @@ cache_dit.enable_cache(pipe)
 output = pipe(...)
 ```
 
-### 🔥
+### 🔥Automatic Block Adapter
 
 But in some cases, you may have a **modified** Diffusion Pipeline or Transformer that is not located in the diffusers library or not officially supported by **cache-dit** at this time. The **BlockAdapter** can help you solve this problems. Please refer to [Qwen-Image w/ BlockAdapter](./examples/run_qwen_image_adapter.py) as an example.

@@ -181,7 +183,7 @@ cache_dit.enable_cache(
     forward_pattern=ForwardPattern.Pattern_1,
 )
 
-# Or,
+# Or, manually setup transformer configurations.
 cache_dit.enable_cache(
     BlockAdapter(
         pipe=pipe, # Qwen-Image, etc.

@@ -238,7 +240,7 @@ cache_dit.enable_cache(pipe)
 # Custom options, F8B8, higher precision
 cache_dit.enable_cache(
     pipe,
-    warmup_steps=8, # steps do not cache
+    max_warmup_steps=8, # steps do not cache
     max_cached_steps=-1, # -1 means no limit
     Fn_compute_blocks=8, # Fn, F8, etc.
     Bn_compute_blocks=8, # Bn, B8, etc.

@@ -297,7 +299,7 @@ cache_dit.enable_cache(
     taylorseer_kwargs={
         "n_derivatives": 2, # default is 2.
     },
-    warmup_steps=3, # prefer: >= n_derivatives + 1
+    max_warmup_steps=3, # prefer: >= n_derivatives + 1
     residual_diff_threshold=0.12
 )
 ```
{cache_dit-0.2.23.dist-info → cache_dit-0.2.24.dist-info}/RECORD
CHANGED

@@ -1,18 +1,17 @@
 cache_dit/__init__.py,sha256=KwhX9NfYkWSvDFuuUVeVjcuiZiGS_22y386l8j4afMo,905
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=AZPr2DJJAwMsYN7GLT_kjMvP33B8Rgy4O_7h4o_T_88,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/primitives.py,sha256=
-cache_dit/utils.py,sha256=3UgVhfmTFG28w6CV-Rfxp5u1uzLrRozocHwLCTGiQ5M,5865
+cache_dit/utils.py,sha256=kzwF98nzfzIFHSLtCx7Vq4a9aTW42lY-Bth7Oi4jAhg,6083
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
-cache_dit/cache_factory/cache_adapters.py,sha256=
+cache_dit/cache_factory/cache_adapters.py,sha256=Yugqljm9tm615srM2BGQlR_tA0QiZo3PbLPceObh4dQ,25988
 cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
-cache_dit/cache_factory/cache_context.py,sha256=
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_context.py,sha256=Cexr1_uwEkX7v8gB7DSyhCX0SI2dqS_e_ccTR16G2es,41738
+cache_dit/cache_factory/cache_interface.py,sha256=ri8wAxmHOsDW8c6qYP6VquOJQaTSXuOchWXG3PdcYQM,8434
 cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
 cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
-cache_dit/cache_factory/taylorseer.py,sha256=
-cache_dit/cache_factory/utils.py,sha256=
+cache_dit/cache_factory/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
+cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
 cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/patch/flux.py,sha256=iNQ-1RlOgXupZ4uPiEvJ__Ro6vKT_fOKja9JrpMrO78,8998
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56

@@ -25,9 +24,9 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit-0.2.23.dist-info/licenses/LICENSE,sha256=
-cache_dit-0.2.23.dist-info/METADATA,sha256=
-cache_dit-0.2.23.dist-info/WHEEL,sha256=
-cache_dit-0.2.23.dist-info/entry_points.txt,sha256=
-cache_dit-0.2.23.dist-info/top_level.txt,sha256=
-cache_dit-0.2.23.dist-info/RECORD,,
+cache_dit-0.2.24.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.24.dist-info/METADATA,sha256=zq_bGjQ_X--m1njAbOob--MwOpTDlUlAzZ3u_MiNiFM,19977
+cache_dit-0.2.24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.24.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.24.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.24.dist-info/RECORD,,
cache_dit/primitives.py
DELETED
@@ -1,152 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/primitives.py
-
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.distributed as dist
-
-if dist.is_available():
-    import torch.distributed._functional_collectives as ft_c
-    import torch.distributed.distributed_c10d as c10d
-else:
-    ft_c = None
-    c10d = None
-
-
-def get_group(group=None):
-    if group is None:
-        group = c10d._get_default_group()
-
-    if isinstance(group, dist.ProcessGroup):
-        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group
-    else:
-        pg = group.get_group()
-
-    return pg
-
-
-def get_world_size(group=None):
-    pg = get_group(group)
-    return dist.get_world_size(pg)
-
-
-def get_rank(group=None):
-    pg = get_group(group)
-    return dist.get_rank(pg)
-
-
-def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
-    """
-    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
-    so we cannot call ``wait()``.
-    """
-    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
-        return tensor.wait()
-    return tensor
-
-
-def all_gather_tensor_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_gather_tensor_autograd_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor_autograd(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_autograd_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single_autograd(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_reduce_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x = ft_c.all_reduce(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    return x
-
-
-def get_buffer(
-    shape_or_tensor: Union[Tuple[int], torch.Tensor],
-    *,
-    repeats: int = 1,
-    dim: int = 0,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    group=None,
-) -> torch.Tensor:
-    if repeats is None:
-        repeats = get_world_size(group)
-
-    if isinstance(shape_or_tensor, torch.Tensor):
-        shape = shape_or_tensor.shape
-        dtype = shape_or_tensor.dtype
-        device = shape_or_tensor.device
-
-    assert dtype is not None
-    assert device is not None
-
-    shape = list(shape)
-    if repeats > 1:
-        shape[dim] *= repeats
-
-    buffer = torch.empty(shape, dtype=dtype, device=device)
-    return buffer
-
-
-def get_assigned_chunk(
-    tensor: torch.Tensor,
-    dim: int = 0,
-    idx: Optional[int] = None,
-    group=None,
-) -> torch.Tensor:
-    if idx is None:
-        idx = get_rank(group)
-    world_size = get_world_size(group)
-    total_size = tensor.shape[dim]
-    assert (
-        total_size % world_size == 0
-    ), f"tensor.shape[{dim}]={total_size} is not divisible by world_size={world_size}"
-    return tensor.chunk(world_size, dim=dim)[idx]
-
-
-def get_complete_tensor(
-    tensor: torch.Tensor,
-    *,
-    dim: int = 0,
-    group=None,
-) -> torch.Tensor:
-    tensor = tensor.transpose(0, dim).contiguous()
-    output_tensor = all_gather_tensor_sync(tensor, gather_dim=0, group=group)
-    output_tensor = output_tensor.transpose(0, dim)
-    return output_tensor
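With `cache_dit/primitives.py` gone, the one call site that needed a collective (`are_two_tensors_similar` in cache_context.py) now calls `torch.distributed` directly, as shown in that hunk above. A minimal sketch of the equivalent averaging collective, assuming a process group has already been initialized with a backend that supports `ReduceOp.AVG` (e.g. NCCL); the function name is illustrative:

```python
import torch
import torch.distributed as dist


def averaged_rel_diff(t1: torch.Tensor, t2: torch.Tensor, parallelized: bool) -> float:
    # Reduce the two scalar statistics across ranks with an AVG all-reduce,
    # the same pattern the updated cache_context.py uses in place of the
    # removed primitives helpers.
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    if parallelized:
        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
    return (mean_diff / mean_t1).item()
```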
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|