cache-dit 0.2.23__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.23'
-__version_tuple__ = version_tuple = (0, 2, 23)
+__version__ = version = '0.2.24'
+__version_tuple__ = version_tuple = (0, 2, 24)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/cache_adapters.py CHANGED
@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses
 
-from typing import Any, Tuple, List
+from typing import Any, Tuple, List, Optional
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (
@@ -40,6 +40,7 @@ class BlockAdapter:
             "layers",
         ]
     )
+    check_prefixes: bool = True
     allow_suffixes: List[str] = dataclasses.field(
         default_factory=lambda: ["TransformerBlock"]
     )
@@ -48,8 +49,24 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
+    def __post_init__(self):
+        self.maybe_apply_patch()
+
+    def maybe_apply_patch(self):
+        # Handle some special cases, specifically transformers that have
+        # different forward patterns between single_transformer_blocks
+        # and transformer_blocks, such as Flux (diffusers < 0.35.0).
+        if self.transformer.__class__.__name__.startswith("Flux"):
+            self.transformer = maybe_patch_flux_transformer(
+                self.transformer,
+                blocks=self.blocks,
+            )
+
     @staticmethod
-    def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+    def auto_block_adapter(
+        adapter: "BlockAdapter",
+        forward_pattern: Optional[ForwardPattern] = None,
+    ) -> "BlockAdapter":
         assert adapter.auto, (
             "Please manually set `auto` to True, or, manually "
             "set all the transformer blocks configuration."
@@ -66,8 +83,10 @@ class BlockAdapter:
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
+            check_prefixes=adapter.check_prefixes,
             check_suffixes=adapter.check_suffixes,
             blocks_policy=adapter.blocks_policy,
+            forward_pattern=forward_pattern,
         )
 
         return BlockAdapter(
@@ -87,6 +106,8 @@ class BlockAdapter:
             and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
+
+        logger.warning("Check block adapter failed!")
         return False
 
     @staticmethod
@@ -101,24 +122,30 @@ class BlockAdapter:
         allow_suffixes: List[str] = [
             "TransformerBlock",
         ],
+        check_prefixes: bool = True,
         check_suffixes: bool = False,
         **kwargs,
     ) -> Tuple[torch.nn.ModuleList, str]:
+        # Check prefixes
+        if check_prefixes:
+            blocks_names = []
+            for attr_name in dir(transformer):
+                for prefix in allow_prefixes:
+                    if attr_name.startswith(prefix):
+                        blocks_names.append(attr_name)
+        else:
+            blocks_names = dir(transformer)
 
-        blocks_names = []
-        for attr_name in dir(transformer):
-            for prefix in allow_prefixes:
-                if attr_name.startswith(prefix):
-                    blocks_names.append(attr_name)
-
-        # Type check
+        # Check ModuleList
         valid_names = []
         valid_count = []
+        forward_pattern = kwargs.get("forward_pattern", None)
         for blocks_name in blocks_names:
             if blocks := getattr(transformer, blocks_name, None):
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
                     block_cls_name = block.__class__.__name__
+                    # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
                             (
@@ -128,8 +155,18 @@ class BlockAdapter:
                         )
                         or (not check_suffixes)
                     ):
-                        valid_names.append(blocks_name)
-                        valid_count.append(len(blocks))
+                        # May check forward pattern
+                        if forward_pattern is not None:
+                            if BlockAdapter.match_blocks_pattern(
+                                blocks,
+                                forward_pattern,
+                                logging=False,
+                            ):
+                                valid_names.append(blocks_name)
+                                valid_count.append(len(blocks))
+                        else:
+                            valid_names.append(blocks_name)
+                            valid_count.append(len(blocks))
 
         if not valid_names:
             raise ValueError(
@@ -139,6 +176,7 @@ class BlockAdapter:
         final_name = valid_names[0]
         final_count = valid_count[0]
         block_policy = kwargs.get("blocks_policy", "max")
+
         for blocks_name, count in zip(valid_names, valid_count):
             blocks = getattr(transformer, blocks_name)
             logger.info(
@@ -165,6 +203,67 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def match_block_pattern(
+        block: torch.nn.Module,
+        forward_pattern: ForwardPattern,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        forward_parameters = set(
+            inspect.signature(block.forward).parameters.keys()
+        )
+        num_outputs = str(
+            inspect.signature(block.forward).return_annotation
+        ).count("torch.Tensor")
+
+        in_matched = True
+        out_matched = True
+        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+            # output pattern not match
+            out_matched = False
+
+        for required_param in forward_pattern.In:
+            if required_param not in forward_parameters:
+                in_matched = False
+
+        return in_matched and out_matched
+
+    @staticmethod
+    def match_blocks_pattern(
+        transformer_blocks: torch.nn.ModuleList,
+        forward_pattern: ForwardPattern,
+        logging: bool = True,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+        pattern_matched_states = []
+        for block in transformer_blocks:
+            pattern_matched_states.append(
+                BlockAdapter.match_block_pattern(
+                    block,
+                    forward_pattern,
+                )
+            )
+
+        pattern_matched = all(pattern_matched_states)  # all block match
+        if pattern_matched and logging:
+            block_cls_name = transformer_blocks[0].__class__.__name__
+            logger.info(
+                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
+                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+            )
+
+        return pattern_matched
+
 
 @dataclasses.dataclass
 class UnifiedCacheParams:
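
The two new static methods above decide pattern compatibility purely from `inspect.signature` of each block's `forward`: the pattern's required parameter names must be present, and the number of `torch.Tensor` mentions in the return annotation must match the pattern's output count. A minimal standalone sketch of that check (the `DummyBlock` class and the pattern constants below are hypothetical, not part of cache-dit):

```python
import inspect

import torch

# Hypothetical stand-ins for a ForwardPattern's In/Out definitions.
PATTERN_IN = ("hidden_states", "encoder_hidden_states")
PATTERN_OUT_LEN = 2


class DummyBlock(torch.nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states


def matches_pattern(block: torch.nn.Module) -> bool:
    sig = inspect.signature(block.forward)
    params = set(sig.parameters.keys())
    # Same heuristic as match_block_pattern: count "torch.Tensor" in the
    # stringified return annotation to estimate the number of outputs.
    num_outputs = str(sig.return_annotation).count("torch.Tensor")
    in_matched = all(name in params for name in PATTERN_IN)
    out_matched = num_outputs == 0 or num_outputs == PATTERN_OUT_LEN
    return in_matched and out_matched


print(matches_pattern(DummyBlock()))  # True
```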
@@ -463,19 +562,42 @@ class UnifiedCacheAdapter:
     ) -> DiffusionPipeline:
 
         if block_adapter.auto:
-            block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+            block_adapter = BlockAdapter.auto_block_adapter(
+                block_adapter,
+                forward_pattern,
+            )
 
         if BlockAdapter.check_block_adapter(block_adapter):
-            assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(block_adapter.pipe, **cache_context_kwargs)
+            cls.create_context(
+                block_adapter.pipe,
+                **cache_context_kwargs,
+            )
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
                 block_adapter,
                 forward_pattern=forward_pattern,
             )
+            cls.patch_params(
+                block_adapter,
+                forward_pattern=forward_pattern,
+                **cache_context_kwargs,
+            )
         return block_adapter.pipe
 
+    @classmethod
+    def patch_params(
+        cls,
+        block_adapter: BlockAdapter,
+        forward_pattern: ForwardPattern = None,
+        **cache_context_kwargs,
+    ):
+        block_adapter.transformer._forward_pattern = forward_pattern
+        block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
+        block_adapter.pipe.__class__._cache_context_kwargs = (
+            cache_context_kwargs
+        )
+
     @classmethod
     def has_separate_cfg(
         cls,
@@ -534,7 +656,6 @@ class UnifiedCacheAdapter:
 
         pipe.__class__.__call__ = new_call
         pipe.__class__._is_cached = True
-        pipe.__class__._cache_options = cache_kwargs
         return pipe
 
     @classmethod
@@ -544,28 +665,11 @@ class UnifiedCacheAdapter:
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
 
-        if (
-            block_adapter.transformer is None
-            or block_adapter.blocks_name is None
-            or block_adapter.blocks is None
-        ):
-            assert block_adapter.auto, (
-                "Please manually set `auto` to True, or, "
-                "manually set transformer blocks configuration."
-            )
-
         if getattr(block_adapter.transformer, "_is_cached", False):
             return block_adapter.transformer
 
-        # Firstly, process some specificial cases (TODO: more patches)
-        if block_adapter.transformer.__class__.__name__.startswith("Flux"):
-            block_adapter.transformer = maybe_patch_flux_transformer(
-                block_adapter.transformer,
-                blocks=block_adapter.blocks,
-            )
-
         # Check block forward pattern matching
-        assert cls.match_pattern(
+        assert BlockAdapter.match_blocks_pattern(
             block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
@@ -615,46 +719,3 @@ class UnifiedCacheAdapter:
         block_adapter.transformer._is_cached = True
 
         return block_adapter.transformer
-
-    @classmethod
-    def match_pattern(
-        cls,
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    ) -> bool:
-        pattern_matched_states = []
-
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        for block in transformer_blocks:
-            forward_parameters = set(
-                inspect.signature(block.forward).parameters.keys()
-            )
-            num_outputs = str(
-                inspect.signature(block.forward).return_annotation
-            ).count("torch.Tensor")
-
-            in_matched = True
-            out_matched = True
-            if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-                # output pattern not match
-                out_matched = False
-
-            for required_param in forward_pattern.In:
-                if required_param not in forward_parameters:
-                    in_matched = False
-
-            pattern_matched_states.append(in_matched and out_matched)
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
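
With the pattern check moved onto `BlockAdapter` and `auto_block_adapter` now accepting a `forward_pattern`, automatic block discovery can reject candidate `ModuleList`s whose blocks do not match the declared pattern. A hedged usage sketch in the spirit of the README's BlockAdapter example (the pipeline id and the top-level import paths are assumptions):

```python
import cache_dit
from cache_dit import BlockAdapter, ForwardPattern  # assumed re-exports
from diffusers import DiffusionPipeline

# Illustrative pipeline; any Diffusers pipeline with a transformer backbone.
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# auto=True lets BlockAdapter.find_blocks discover the transformer blocks;
# the forward_pattern is forwarded so non-matching block lists are skipped.
cache_dit.enable_cache(
    BlockAdapter(pipe=pipe, auto=True),
    forward_pattern=ForwardPattern.Pattern_1,
)

output = pipe(prompt="a cup of coffee on a wooden desk")
```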
cache_dit/cache_factory/cache_context.py CHANGED
@@ -5,8 +5,8 @@ from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
+import torch.distributed as dist
 
-import cache_dit.primitives as primitives
 from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
@@ -47,10 +47,11 @@ class DBCacheContext:
 
     # Other settings
     downsample_factor: int = 1
-    num_inference_steps: int = -1  # un-used now
-    warmup_steps: int = 0  # DON'T Cache in warmup steps
+    num_inference_steps: int = -1  # for future use
+    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
     # DON'T Cache if the number of cached steps >= max_cached_steps
     max_cached_steps: int = -1  # for both CFG and non-CFG
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
 
     # Record the steps that have been cached, both cached and non-cache
     executed_steps: int = 0  # cache + non-cache steps of the pipeline
@@ -89,10 +90,12 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    continuous_cached_steps: int = 0
     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
     cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    cfg_continuous_cached_steps: int = 0
 
     @torch.compiler.disable
     def __post_init__(self):
@@ -108,17 +111,17 @@ class DBCacheContext:
                 "cfg_diff_compute_separate is enabled."
             )
 
-        if "warmup_steps" not in self.taylorseer_kwargs:
-            # If warmup_steps is not set in taylorseer_kwargs,
-            # set the same as warmup_steps for DBCache
-            self.taylorseer_kwargs["warmup_steps"] = (
-                self.warmup_steps if self.warmup_steps > 0 else 1
+        if "max_warmup_steps" not in self.taylorseer_kwargs:
+            # If max_warmup_steps is not set in taylorseer_kwargs,
+            # set the same as max_warmup_steps for DBCache
+            self.taylorseer_kwargs["max_warmup_steps"] = (
+                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
             )
 
         # Only set n_derivatives as 2 or 3, which is enough for most cases.
         if "n_derivatives" not in self.taylorseer_kwargs:
             self.taylorseer_kwargs["n_derivatives"] = max(
-                2, min(3, self.taylorseer_kwargs["warmup_steps"])
+                2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
             )
 
         if self.enable_taylorseer:
@@ -268,10 +271,31 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def add_cached_step(self):
+        curr_cached_step = self.get_current_step()
         if not self.is_separate_cfg_step():
-            self.cached_steps.append(self.get_current_step())
+            if self.cached_steps:
+                prev_cached_step = self.cached_steps[-1]
+                if curr_cached_step - prev_cached_step == 1:
+                    if self.continuous_cached_steps == 0:
+                        self.continuous_cached_steps += 2
+                    else:
+                        self.continuous_cached_steps += 1
+            else:
+                self.continuous_cached_steps += 1
+
+            self.cached_steps.append(curr_cached_step)
         else:
-            self.cfg_cached_steps.append(self.get_current_step())
+            if self.cfg_cached_steps:
+                prev_cfg_cached_step = self.cfg_cached_steps[-1]
+                if curr_cached_step - prev_cfg_cached_step == 1:
+                    if self.cfg_continuous_cached_steps == 0:
+                        self.cfg_continuous_cached_steps += 2
+                    else:
+                        self.cfg_continuous_cached_steps += 1
+            else:
+                self.cfg_continuous_cached_steps += 1
+
+            self.cfg_cached_steps.append(curr_cached_step)
 
     @torch.compiler.disable
     def get_cached_steps(self):
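
The run-length bookkeeping in `add_cached_step` is easier to follow in isolation: the first cached step seen contributes 1, a cached step adjacent to the previous one extends the run (adding 2 when the counter had been reset to 0, otherwise 1), and a non-adjacent cached step leaves the counter alone until `get_can_use_cache` resets it. A small standalone sketch of that rule (hypothetical helper, not part of cache-dit):

```python
def replay_continuous_counter(cached_steps: list[int]) -> int:
    """Replay the continuous-cached-steps counter for a sequence of cached
    step indices, mirroring DBCacheContext.add_cached_step."""
    history: list[int] = []
    continuous = 0
    for step in cached_steps:
        if history:
            if step - history[-1] == 1:
                # Adjacent to the previous cached step: extend the run.
                continuous += 2 if continuous == 0 else 1
        else:
            # Very first cached step.
            continuous += 1
        history.append(step)
    return continuous


print(replay_continuous_counter([3, 4, 5]))  # 3
print(replay_continuous_counter([3, 5, 6]))  # 2
```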
@@ -301,7 +325,7 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def is_in_warmup(self):
-        return self.get_current_step() < self.warmup_steps
+        return self.get_current_step() < self.max_warmup_steps
 
 
 @torch.compiler.disable
@@ -396,6 +420,27 @@ def get_max_cached_steps():
     return cache_context.max_cached_steps
 
 
+@torch.compiler.disable
+def get_max_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.max_continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.continuous_cached_steps
+
+
+@torch.compiler.disable
+def get_cfg_continuous_cached_steps():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.cfg_continuous_cached_steps
+
+
 @torch.compiler.disable
 def add_cached_step():
     cache_context = get_current_cache_context()
@@ -744,8 +789,8 @@ def are_two_tensors_similar(
     mean_t1 = t1.abs().mean()
 
     if parallelized:
-        mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
-        mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
+        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
 
     # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
     # Further, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
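
Dropping `cache_dit.primitives` in favor of `torch.distributed` keeps only the one collective the cache logic needs. A hedged sketch of the parallelized branch shown above (assumes the default process group is already initialized and that the backend supports `ReduceOp.AVG`, e.g. NCCL):

```python
import torch
import torch.distributed as dist


def distributed_relative_diff(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # Mirrors the parallelized branch of are_two_tensors_similar: average the
    # local statistics across ranks before forming the relative difference.
    mean_diff = (t1 - t2).abs().mean()
    mean_t1 = t1.abs().mean()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)  # in-place average
        dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
    return mean_diff / mean_t1
```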
@@ -1020,6 +1065,7 @@ def get_can_use_cache(
     if is_in_warmup():
         return False
 
+    # max cached steps
     max_cached_steps = get_max_cached_steps()
     if not is_separate_cfg_step():
         cached_steps = get_cached_steps()
@@ -1030,8 +1076,32 @@ def get_can_use_cache(
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug(
                 f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                "cannot use cache."
+                "can not use cache."
+            )
+        return False
+
+    # max continuous cached steps
+    max_continuous_cached_steps = get_max_continuous_cached_steps()
+    if not is_separate_cfg_step():
+        continuous_cached_steps = get_continuous_cached_steps()
+    else:
+        continuous_cached_steps = get_cfg_continuous_cached_steps()
+
+    if max_continuous_cached_steps >= 0 and (
+        continuous_cached_steps >= max_continuous_cached_steps
+    ):
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.debug(
+                f"{prefix}, max_continuous_cached_steps "
+                f"reached: {max_continuous_cached_steps}, "
+                "can not use cache."
             )
+        # reset continuous cached steps stats
+        cache_context = get_current_cache_context()
+        if not is_separate_cfg_step():
+            cache_context.continuous_cached_steps = 0
+        else:
+            cache_context.cfg_continuous_cached_steps = 0
         return False
 
     if threshold is None or threshold <= 0.0:
cache_dit/cache_factory/cache_interface.py CHANGED
@@ -16,8 +16,9 @@ def enable_cache(
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
    Bn_compute_blocks: int = 0,
-    warmup_steps: int = 8,
+    max_warmup_steps: int = 8,
     max_cached_steps: int = -1,
+    max_continuous_cached_steps: int = -1,
     residual_diff_threshold: float = 0.08,
     # Cache CFG or not
     do_separate_cfg: bool = False,
@@ -54,12 +55,15 @@ def enable_cache(
         Further fuses approximate information in the **last n** Transformer blocks to enhance
         prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
         that use residual cache.
-    warmup_steps (`int`, *required*, defaults to 8):
+    max_warmup_steps (`int`, *required*, defaults to 8):
         DBCache does not apply the caching strategy when the number of running steps is less than
         or equal to this value, ensuring the model sufficiently learns basic features during warmup.
     max_cached_steps (`int`, *required*, defaults to -1):
         DBCache disables the caching strategy when the previous cached steps exceed this value to
         prevent precision degradation.
+    max_continuous_cached_steps (`int`, *required*, defaults to -1):
+        DBCache disables the caching strategy when the previous continuous cached steps exceed this value to
+        prevent precision degradation.
     residual_diff_threshold (`float`, *required*, defaults to 0.08):
         The value of the residual diff threshold; a higher value leads to faster performance at the
         cost of lower precision.
@@ -106,8 +110,11 @@ def enable_cache(
     cache_context_kwargs["cache_type"] = CacheType.DBCache
     cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
     cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
-    cache_context_kwargs["warmup_steps"] = warmup_steps
+    cache_context_kwargs["max_warmup_steps"] = max_warmup_steps
     cache_context_kwargs["max_cached_steps"] = max_cached_steps
+    cache_context_kwargs["max_continuous_cached_steps"] = (
+        max_continuous_cached_steps
+    )
     cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
     cache_context_kwargs["do_separate_cfg"] = do_separate_cfg
     cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
cache_dit/cache_factory/taylorseer.py CHANGED
@@ -6,13 +6,13 @@ class TaylorSeer:
     def __init__(
         self,
         n_derivatives=2,
-        warmup_steps=1,
+        max_warmup_steps=1,
         skip_interval_steps=1,
         compute_step_map=None,
     ):
         self.n_derivatives = n_derivatives
         self.ORDER = n_derivatives + 1
-        self.warmup_steps = warmup_steps
+        self.max_warmup_steps = max_warmup_steps
         self.skip_interval_steps = skip_interval_steps
         self.compute_step_map = compute_step_map
         self.reset_cache()
@@ -32,8 +32,9 @@ class TaylorSeer:
         if self.compute_step_map is not None:
             return self.compute_step_map[step]
         if (
-            step < self.warmup_steps
-            or (step - self.warmup_steps + 1) % self.skip_interval_steps == 0
+            step < self.max_warmup_steps
+            or (step - self.max_warmup_steps + 1) % self.skip_interval_steps
+            == 0
         ):
             return True
         return False
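
The renamed parameter leaves the schedule unchanged: every step below `max_warmup_steps` is a full compute, and afterwards every `skip_interval_steps`-th step is. A standalone sketch of that condition (reimplemented here for illustration, not imported from TaylorSeer):

```python
def is_full_compute_step(
    step: int, max_warmup_steps: int = 3, skip_interval_steps: int = 2
) -> bool:
    # Same condition as the diff above, minus the compute_step_map shortcut.
    return (
        step < max_warmup_steps
        or (step - max_warmup_steps + 1) % skip_interval_steps == 0
    )


print([s for s in range(10) if is_full_compute_step(s)])  # [0, 1, 2, 4, 6, 8]
```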
cache_dit/cache_factory/utils.py CHANGED
@@ -9,7 +9,7 @@ def load_cache_options_from_yaml(yaml_file_path):
 
         required_keys = [
             "cache_type",
-            "warmup_steps",
+            "max_warmup_steps",
             "max_cached_steps",
             "Fn_compute_blocks",
             "Bn_compute_blocks",
cache_dit/utils.py CHANGED
@@ -27,22 +27,26 @@ class CacheStats:
 
 
 def summary(
-    pipe: DiffusionPipeline, details: bool = False, logging: bool = True
-):
+    pipe_or_transformer: DiffusionPipeline | torch.nn.Module,
+    details: bool = False,
+    logging: bool = True,
+) -> CacheStats:
     cache_stats = CacheStats()
-    pipe_cls_name = pipe.__class__.__name__
+    cls_name = pipe_or_transformer.__class__.__name__
+    if isinstance(pipe_or_transformer, DiffusionPipeline):
+        transformer = pipe_or_transformer.transformer
+    else:
+        transformer = pipe_or_transformer
 
-    if hasattr(pipe, "_cache_options"):
-        cache_options = pipe._cache_options
+    if hasattr(transformer, "_cache_context_kwargs"):
+        cache_options = transformer._cache_context_kwargs
         cache_stats.cache_options = cache_options
         if logging:
-            print(f"\n🤗Cache Options: {pipe_cls_name}\n\n{cache_options}")
+            print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
 
-    if hasattr(pipe.transformer, "_cached_steps"):
-        cached_steps: list[int] = pipe.transformer._cached_steps
-        residual_diffs: dict[str, float] = dict(
-            pipe.transformer._residual_diffs
-        )
+    if hasattr(transformer, "_cached_steps"):
+        cached_steps: list[int] = transformer._cached_steps
+        residual_diffs: dict[str, float] = dict(transformer._residual_diffs)
         cache_stats.cached_steps = cached_steps
         cache_stats.residual_diffs = residual_diffs
 
@@ -57,7 +61,7 @@ def summary(
             qmax = np.max(diffs_values)
 
             print(
-                f"\n⚡️Cache Steps and Residual Diffs Statistics: {pipe_cls_name}\n"
+                f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
             )
 
             print(
@@ -74,9 +78,7 @@ def summary(
             print("")
 
             if details:
-                print(
-                    f"📚Cache Steps and Residual Diffs Details: {pipe_cls_name}\n"
-                )
+                print(f"📚Cache Steps and Residual Diffs Details: {cls_name}\n")
                 pprint(
                     f"Cache Steps: {len(cached_steps)}, {cached_steps}",
                 )
@@ -85,10 +87,10 @@ def summary(
                     compact=True,
                 )
 
-    if hasattr(pipe.transformer, "_cfg_cached_steps"):
-        cfg_cached_steps: list[int] = pipe.transformer._cfg_cached_steps
+    if hasattr(transformer, "_cfg_cached_steps"):
+        cfg_cached_steps: list[int] = transformer._cfg_cached_steps
         cfg_residual_diffs: dict[str, float] = dict(
-            pipe.transformer._cfg_residual_diffs
+            transformer._cfg_residual_diffs
         )
         cache_stats.cfg_cached_steps = cfg_cached_steps
         cache_stats.cfg_residual_diffs = cfg_residual_diffs
@@ -104,7 +106,7 @@ def summary(
             qmax = np.max(cfg_diffs_values)
 
             print(
-                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {pipe_cls_name}\n"
+                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
             )
 
             print(
@@ -122,7 +124,7 @@ def summary(
 
             if details:
                 print(
-                    f"📚CFG Cache Steps and Residual Diffs Details: {pipe_cls_name}\n"
+                    f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
                 )
                 pprint(
                     f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
@@ -149,9 +151,10 @@ def strify(pipe_or_stats: DiffusionPipeline | CacheStats):
 
     cache_type_str = (
         f"DBCACHE_F{cache_options['Fn_compute_blocks']}"
-        f"B{cache_options['Bn_compute_blocks']}"
-        f"W{cache_options['warmup_steps']}"
+        f"B{cache_options['Bn_compute_blocks']}_"
+        f"W{cache_options['max_warmup_steps']}"
         f"M{max(0, cache_options['max_cached_steps'])}"
+        f"MC{max(0, cache_options['max_continuous_cached_steps'])}_"
        f"T{int(cache_options['enable_taylorseer'])}"
         f"O{cache_options['taylorseer_kwargs']['n_derivatives']}_"
         f"R{cache_options['residual_diff_threshold']}_"
cache_dit-0.2.24.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.23
+Version: 0.2.24
 Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -59,12 +59,13 @@ Dynamic: requires-python
 </p>
 <p align="center">
 🎉Now, <b>cache-dit</b> covers <b>Most</b> mainstream <b>Diffusers'</b> Pipelines</b>🎉<br>
-🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
+🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
 </p>
 </div>
 
 ## 🔥News
 
+- [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.5x⚡️** speedup! Please check [run_wan_2.2.py](./examples/run_wan_2.2.py) as an example.
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x⚡️** speedup! Check example [run_qwen_image_edit.py](./examples/run_qwen_image_edit.py).
 - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x⚡️** speedup! Please refer [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
@@ -119,6 +120,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -166,7 +168,7 @@ cache_dit.enable_cache(pipe)
 output = pipe(...)
 ```
 
-### 🔥BlockAdapter: Cache Acceleration for Custom Diffusion Models
+### 🔥Automatic Block Adapter
 
 But in some cases, you may have a **modified** Diffusion Pipeline or Transformer that is not located in the diffusers library or not officially supported by **cache-dit** at this time. The **BlockAdapter** can help you solve these problems. Please refer to [Qwen-Image w/ BlockAdapter](./examples/run_qwen_image_adapter.py) as an example.
 
@@ -181,7 +183,7 @@ cache_dit.enable_cache(
     forward_pattern=ForwardPattern.Pattern_1,
 )
 
-# Or, manualy setup transformer configurations.
+# Or, manually setup transformer configurations.
 cache_dit.enable_cache(
     BlockAdapter(
         pipe=pipe, # Qwen-Image, etc.
@@ -238,7 +240,7 @@ cache_dit.enable_cache(pipe)
 # Custom options, F8B8, higher precision
 cache_dit.enable_cache(
     pipe,
-    warmup_steps=8,  # steps do not cache
+    max_warmup_steps=8,  # steps do not cache
     max_cached_steps=-1,  # -1 means no limit
     Fn_compute_blocks=8,  # Fn, F8, etc.
     Bn_compute_blocks=8,  # Bn, B8, etc.
@@ -297,7 +299,7 @@ cache_dit.enable_cache(
     taylorseer_kwargs={
         "n_derivatives": 2,  # default is 2.
     },
-    warmup_steps=3,  # prefer: >= n_derivatives + 1
+    max_warmup_steps=3,  # prefer: >= n_derivatives + 1
     residual_diff_threshold=0.12
 )
 ```
cache_dit-0.2.24.dist-info/RECORD CHANGED
@@ -1,18 +1,17 @@
 cache_dit/__init__.py,sha256=KwhX9NfYkWSvDFuuUVeVjcuiZiGS_22y386l8j4afMo,905
-cache_dit/_version.py,sha256=6GZdGbiFdhndXqR5oFLOd8VGzUvRkESP-NyStAZWYUw,706
+cache_dit/_version.py,sha256=AZPr2DJJAwMsYN7GLT_kjMvP33B8Rgy4O_7h4o_T_88,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
-cache_dit/utils.py,sha256=3UgVhfmTFG28w6CV-Rfxp5u1uzLrRozocHwLCTGiQ5M,5865
+cache_dit/utils.py,sha256=kzwF98nzfzIFHSLtCx7Vq4a9aTW42lY-Bth7Oi4jAhg,6083
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
-cache_dit/cache_factory/cache_adapters.py,sha256=QTSwjdCmHDeF80TLp6D3KhzQS_oMPna0_bESgJBrdkg,23978
+cache_dit/cache_factory/cache_adapters.py,sha256=Yugqljm9tm615srM2BGQlR_tA0QiZo3PbLPceObh4dQ,25988
 cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
-cache_dit/cache_factory/cache_context.py,sha256=4thx9NYxVaYZ_Nr2quUVE8bsNmTsXhZK0F960rccOc8,39000
-cache_dit/cache_factory/cache_interface.py,sha256=PohG_2oy747O37YSsWz_DwxxTXN7ORhQatyEbg_6fQs,8045
+cache_dit/cache_factory/cache_context.py,sha256=Cexr1_uwEkX7v8gB7DSyhCX0SI2dqS_e_ccTR16G2es,41738
+cache_dit/cache_factory/cache_interface.py,sha256=ri8wAxmHOsDW8c6qYP6VquOJQaTSXuOchWXG3PdcYQM,8434
 cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
 cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
-cache_dit/cache_factory/taylorseer.py,sha256=WeK2WlAJa4Px_pnAKokmnZXeqQYylQkPw4-EDqBIqeQ,3770
-cache_dit/cache_factory/utils.py,sha256=YGtn02O3fVlrfQ32gGV4WAtTRvzzwSXNxzP_FmnE2Uk,1867
+cache_dit/cache_factory/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
+cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
 cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/patch/flux.py,sha256=iNQ-1RlOgXupZ4uPiEvJ__Ro6vKT_fOKja9JrpMrO78,8998
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
@@ -25,9 +24,9 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit-0.2.23.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
-cache_dit-0.2.23.dist-info/METADATA,sha256=Dq2f8TlyTmv36otIJ2F-fRGkJlZmpW2SY6O14P2AYKo,19772
-cache_dit-0.2.23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-cache_dit-0.2.23.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
-cache_dit-0.2.23.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
-cache_dit-0.2.23.dist-info/RECORD,,
+cache_dit-0.2.24.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.24.dist-info/METADATA,sha256=zq_bGjQ_X--m1njAbOob--MwOpTDlUlAzZ3u_MiNiFM,19977
+cache_dit-0.2.24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.24.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.24.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.24.dist-info/RECORD,,
cache_dit/primitives.py DELETED
@@ -1,152 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/primitives.py
-
-from typing import List, Optional, Tuple, Union
-
-import torch
-import torch.distributed as dist
-
-if dist.is_available():
-    import torch.distributed._functional_collectives as ft_c
-    import torch.distributed.distributed_c10d as c10d
-else:
-    ft_c = None
-    c10d = None
-
-
-def get_group(group=None):
-    if group is None:
-        group = c10d._get_default_group()
-
-    if isinstance(group, dist.ProcessGroup):
-        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group
-    else:
-        pg = group.get_group()
-
-    return pg
-
-
-def get_world_size(group=None):
-    pg = get_group(group)
-    return dist.get_world_size(pg)
-
-
-def get_rank(group=None):
-    pg = get_group(group)
-    return dist.get_rank(pg)
-
-
-def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor:
-    """
-    When tracing the code, the result tensor is not an AsyncCollectiveTensor,
-    so we cannot call ``wait()``.
-    """
-    if isinstance(tensor, ft_c.AsyncCollectiveTensor):
-        return tensor.wait()
-    return tensor
-
-
-def all_gather_tensor_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_gather_tensor_autograd_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x_shape = x.shape
-    x = x.flatten()
-    x_numel = x.numel()
-    x = ft_c.all_gather_tensor_autograd(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    x_shape = list(x_shape)
-    x_shape[0] *= x.numel() // x_numel
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_to_all_single_autograd_sync(x, *args, **kwargs):
-    x_shape = x.shape
-    x = x.flatten()
-    x = ft_c.all_to_all_single_autograd(x, *args, **kwargs)
-    x = _maybe_wait(x)
-    x = x.reshape(x_shape)
-    return x
-
-
-def all_reduce_sync(x, *args, group=None, **kwargs):
-    group = get_group(group)
-    x = ft_c.all_reduce(x, *args, group=group, **kwargs)
-    x = _maybe_wait(x)
-    return x
-
-
-def get_buffer(
-    shape_or_tensor: Union[Tuple[int], torch.Tensor],
-    *,
-    repeats: int = 1,
-    dim: int = 0,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    group=None,
-) -> torch.Tensor:
-    if repeats is None:
-        repeats = get_world_size(group)
-
-    if isinstance(shape_or_tensor, torch.Tensor):
-        shape = shape_or_tensor.shape
-        dtype = shape_or_tensor.dtype
-        device = shape_or_tensor.device
-
-    assert dtype is not None
-    assert device is not None
-
-    shape = list(shape)
-    if repeats > 1:
-        shape[dim] *= repeats
-
-    buffer = torch.empty(shape, dtype=dtype, device=device)
-    return buffer
-
-
-def get_assigned_chunk(
-    tensor: torch.Tensor,
-    dim: int = 0,
-    idx: Optional[int] = None,
-    group=None,
-) -> torch.Tensor:
-    if idx is None:
-        idx = get_rank(group)
-    world_size = get_world_size(group)
-    total_size = tensor.shape[dim]
-    assert (
-        total_size % world_size == 0
-    ), f"tensor.shape[{dim}]={total_size} is not divisible by world_size={world_size}"
-    return tensor.chunk(world_size, dim=dim)[idx]
-
-
-def get_complete_tensor(
-    tensor: torch.Tensor,
-    *,
-    dim: int = 0,
-    group=None,
-) -> torch.Tensor:
-    tensor = tensor.transpose(0, dim).contiguous()
-    output_tensor = all_gather_tensor_sync(tensor, gather_dim=0, group=group)
-    output_tensor = output_tensor.transpose(0, dim)
-    return output_tensor