cache-dit 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the public registry, and is provided for informational purposes only.

This version of cache-dit has been flagged as potentially problematic.

cache_dit/__init__.py CHANGED
@@ -12,6 +12,7 @@ from cache_dit.cache_factory import CacheType
  from cache_dit.cache_factory import BlockAdapter
  from cache_dit.cache_factory import ForwardPattern
  from cache_dit.compile import set_compile_configs
+ from cache_dit.quantize import quantize
  from cache_dit.utils import summary
  from cache_dit.utils import strify
  from cache_dit.logger import init_logger
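
0.2.25 re-exports a new `quantize` helper at the package root; the backing module is the one-line `cache_dit/quantize/__init__.py` added at the end of this diff. A minimal sketch of the new import surface follows; the signature of `quantize` itself is not part of this diff, so the commented call is only an assumed shape.

    # New in 0.2.25: quantize is importable from the package root.
    from cache_dit import quantize

    # The exact call signature is not shown in this diff; a module-in,
    # module-out shape is an assumption:
    # pipe.transformer = quantize(pipe.transformer)
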
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.2.23'
- __version_tuple__ = version_tuple = (0, 2, 23)
+ __version__ = version = '0.2.25'
+ __version_tuple__ = version_tuple = (0, 2, 25)

  __commit_id__ = commit_id = None
@@ -5,7 +5,7 @@ import unittest
  import functools
  import dataclasses

- from typing import Any, Tuple, List
+ from typing import Any, Tuple, List, Optional
  from contextlib import ExitStack
  from diffusers import DiffusionPipeline
  from cache_dit.cache_factory.patch.flux import (
@@ -40,6 +40,7 @@ class BlockAdapter:
              "layers",
          ]
      )
+     check_prefixes: bool = True
      allow_suffixes: List[str] = dataclasses.field(
          default_factory=lambda: ["TransformerBlock"]
      )
@@ -48,8 +49,24 @@ class BlockAdapter:
          default="max", metadata={"allowed_values": ["max", "min"]}
      )

+     def __post_init__(self):
+         self.maybe_apply_patch()
+
+     def maybe_apply_patch(self):
+         # Process some specificial cases, specific for transformers
+         # that has different forward patterns between single_transformer_blocks
+         # and transformer_blocks, such as Flux (diffusers < 0.35.0).
+         if self.transformer.__class__.__name__.startswith("Flux"):
+             self.transformer = maybe_patch_flux_transformer(
+                 self.transformer,
+                 blocks=self.blocks,
+             )
+
      @staticmethod
-     def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+     def auto_block_adapter(
+         adapter: "BlockAdapter",
+         forward_pattern: Optional[ForwardPattern] = None,
+     ) -> "BlockAdapter":
          assert adapter.auto, (
              "Please manually set `auto` to True, or, manually "
              "set all the transformer blocks configuration."
@@ -66,8 +83,10 @@ class BlockAdapter:
              transformer=transformer,
              allow_prefixes=adapter.allow_prefixes,
              allow_suffixes=adapter.allow_suffixes,
+             check_prefixes=adapter.check_prefixes,
              check_suffixes=adapter.check_suffixes,
              blocks_policy=adapter.blocks_policy,
+             forward_pattern=forward_pattern,
          )

          return BlockAdapter(
@@ -87,6 +106,8 @@ class BlockAdapter:
              and isinstance(adapter.blocks, torch.nn.ModuleList)
          ):
              return True
+
+         logger.warning("Check block adapter failed!")
          return False

      @staticmethod
@@ -101,24 +122,30 @@ class BlockAdapter:
          allow_suffixes: List[str] = [
              "TransformerBlock",
          ],
+         check_prefixes: bool = True,
          check_suffixes: bool = False,
          **kwargs,
      ) -> Tuple[torch.nn.ModuleList, str]:
+         # Check prefixes
+         if check_prefixes:
+             blocks_names = []
+             for attr_name in dir(transformer):
+                 for prefix in allow_prefixes:
+                     if attr_name.startswith(prefix):
+                         blocks_names.append(attr_name)
+         else:
+             blocks_names = dir(transformer)

-         blocks_names = []
-         for attr_name in dir(transformer):
-             for prefix in allow_prefixes:
-                 if attr_name.startswith(prefix):
-                     blocks_names.append(attr_name)
-
-         # Type check
+         # Check ModuleList
          valid_names = []
          valid_count = []
+         forward_pattern = kwargs.get("forward_pattern", None)
          for blocks_name in blocks_names:
              if blocks := getattr(transformer, blocks_name, None):
                  if isinstance(blocks, torch.nn.ModuleList):
                      block = blocks[0]
                      block_cls_name = block.__class__.__name__
+                     # Check suffixes
                      if isinstance(block, torch.nn.Module) and (
                          any(
                              (
@@ -128,8 +155,18 @@ class BlockAdapter:
                          )
                          or (not check_suffixes)
                      ):
-                         valid_names.append(blocks_name)
-                         valid_count.append(len(blocks))
+                         # May check forward pattern
+                         if forward_pattern is not None:
+                             if BlockAdapter.match_blocks_pattern(
+                                 blocks,
+                                 forward_pattern,
+                                 logging=False,
+                             ):
+                                 valid_names.append(blocks_name)
+                                 valid_count.append(len(blocks))
+                         else:
+                             valid_names.append(blocks_name)
+                             valid_count.append(len(blocks))

          if not valid_names:
              raise ValueError(
@@ -139,6 +176,7 @@ class BlockAdapter:
          final_name = valid_names[0]
          final_count = valid_count[0]
          block_policy = kwargs.get("blocks_policy", "max")
+
          for blocks_name, count in zip(valid_names, valid_count):
              blocks = getattr(transformer, blocks_name)
              logger.info(
@@ -165,6 +203,67 @@ class BlockAdapter:

          return final_blocks, final_name

+     @staticmethod
+     def match_block_pattern(
+         block: torch.nn.Module,
+         forward_pattern: ForwardPattern,
+     ) -> bool:
+         assert (
+             forward_pattern.Supported
+             and forward_pattern in ForwardPattern.supported_patterns()
+         ), f"Pattern {forward_pattern} is not support now!"
+
+         forward_parameters = set(
+             inspect.signature(block.forward).parameters.keys()
+         )
+         num_outputs = str(
+             inspect.signature(block.forward).return_annotation
+         ).count("torch.Tensor")
+
+         in_matched = True
+         out_matched = True
+         if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+             # output pattern not match
+             out_matched = False
+
+         for required_param in forward_pattern.In:
+             if required_param not in forward_parameters:
+                 in_matched = False
+
+         return in_matched and out_matched
+
+     @staticmethod
+     def match_blocks_pattern(
+         transformer_blocks: torch.nn.ModuleList,
+         forward_pattern: ForwardPattern,
+         logging: bool = True,
+     ) -> bool:
+         assert (
+             forward_pattern.Supported
+             and forward_pattern in ForwardPattern.supported_patterns()
+         ), f"Pattern {forward_pattern} is not support now!"
+
+         assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+         pattern_matched_states = []
+         for block in transformer_blocks:
+             pattern_matched_states.append(
+                 BlockAdapter.match_block_pattern(
+                     block,
+                     forward_pattern,
+                 )
+             )
+
+         pattern_matched = all(pattern_matched_states)  # all block match
+         if pattern_matched and logging:
+             block_cls_name = transformer_blocks[0].__class__.__name__
+             logger.info(
+                 f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
+                 f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+             )
+
+         return pattern_matched
+

  @dataclasses.dataclass
  class UnifiedCacheParams:
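
The new `match_block_pattern` relies purely on `inspect.signature` of each block's `forward`. A standalone illustration of that check; the check itself follows this diff, while `ToyBlock` and the required input names are assumptions used only for demonstration:

    import inspect
    from typing import Tuple

    import torch


    class ToyBlock(torch.nn.Module):
        def forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            return hidden_states, encoder_hidden_states


    block = ToyBlock()
    forward_parameters = set(inspect.signature(block.forward).parameters.keys())
    num_outputs = str(
        inspect.signature(block.forward).return_annotation
    ).count("torch.Tensor")

    # A pattern matches when all of its required input names are present and,
    # if the return annotation mentions tensors, the output arity agrees.
    required_in = ("hidden_states", "encoder_hidden_states")  # example pattern inputs
    in_matched = all(p in forward_parameters for p in required_in)
    out_matched = num_outputs in (0, 2)
    print(in_matched and out_matched)  # True
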
@@ -463,19 +562,42 @@ class UnifiedCacheAdapter:
      ) -> DiffusionPipeline:

          if block_adapter.auto:
-             block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+             block_adapter = BlockAdapter.auto_block_adapter(
+                 block_adapter,
+                 forward_pattern,
+             )

          if BlockAdapter.check_block_adapter(block_adapter):
-             assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
              # Apply cache on pipeline: wrap cache context
-             cls.create_context(block_adapter.pipe, **cache_context_kwargs)
+             cls.create_context(
+                 block_adapter.pipe,
+                 **cache_context_kwargs,
+             )
              # Apply cache on transformer: mock cached transformer blocks
              cls.mock_blocks(
                  block_adapter,
                  forward_pattern=forward_pattern,
              )
+             cls.patch_params(
+                 block_adapter,
+                 forward_pattern=forward_pattern,
+                 **cache_context_kwargs,
+             )
          return block_adapter.pipe

+     @classmethod
+     def patch_params(
+         cls,
+         block_adapter: BlockAdapter,
+         forward_pattern: ForwardPattern = None,
+         **cache_context_kwargs,
+     ):
+         block_adapter.transformer._forward_pattern = forward_pattern
+         block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
+         block_adapter.pipe.__class__._cache_context_kwargs = (
+             cache_context_kwargs
+         )
+
      @classmethod
      def has_separate_cfg(
          cls,
@@ -534,7 +656,6 @@ class UnifiedCacheAdapter:

          pipe.__class__.__call__ = new_call
          pipe.__class__._is_cached = True
-         pipe.__class__._cache_options = cache_kwargs
          return pipe

      @classmethod
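
Note that `create_context` no longer stores `_cache_options` on the pipeline class; per the new `patch_params` above, the effective configuration is now recorded as `_cache_context_kwargs` (plus `_forward_pattern`). A small sketch of where to look after caching has been applied; `pipe` is assumed to be a DiffusionPipeline already processed by `UnifiedCacheAdapter`, and `pipe.transformer` is assumed to exist:

    # Attribute names follow patch_params in this diff.
    if getattr(pipe, "_is_cached", False):
        print(pipe._cache_context_kwargs)              # kwargs recorded on the pipe class
    if getattr(pipe.transformer, "_is_cached", False):
        print(pipe.transformer._forward_pattern)       # pattern used when mocking blocks
        print(pipe.transformer._cache_context_kwargs)  # same kwargs, on the transformer
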
@@ -544,28 +665,11 @@ class UnifiedCacheAdapter:
          forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
      ) -> torch.nn.Module:

-         if (
-             block_adapter.transformer is None
-             or block_adapter.blocks_name is None
-             or block_adapter.blocks is None
-         ):
-             assert block_adapter.auto, (
-                 "Please manually set `auto` to True, or, "
-                 "manually set transformer blocks configuration."
-             )
-
          if getattr(block_adapter.transformer, "_is_cached", False):
              return block_adapter.transformer

-         # Firstly, process some specificial cases (TODO: more patches)
-         if block_adapter.transformer.__class__.__name__.startswith("Flux"):
-             block_adapter.transformer = maybe_patch_flux_transformer(
-                 block_adapter.transformer,
-                 blocks=block_adapter.blocks,
-             )
-
          # Check block forward pattern matching
-         assert cls.match_pattern(
+         assert BlockAdapter.match_blocks_pattern(
              block_adapter.blocks,
              forward_pattern=forward_pattern,
          ), (
@@ -615,46 +719,3 @@ class UnifiedCacheAdapter:
          block_adapter.transformer._is_cached = True

          return block_adapter.transformer
-
-     @classmethod
-     def match_pattern(
-         cls,
-         transformer_blocks: torch.nn.ModuleList,
-         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-     ) -> bool:
-         pattern_matched_states = []
-
-         assert (
-             forward_pattern.Supported
-             and forward_pattern in ForwardPattern.supported_patterns()
-         ), f"Pattern {forward_pattern} is not support now!"
-
-         for block in transformer_blocks:
-             forward_parameters = set(
-                 inspect.signature(block.forward).parameters.keys()
-             )
-             num_outputs = str(
-                 inspect.signature(block.forward).return_annotation
-             ).count("torch.Tensor")
-
-             in_matched = True
-             out_matched = True
-             if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-                 # output pattern not match
-                 out_matched = False
-
-             for required_param in forward_pattern.In:
-                 if required_param not in forward_parameters:
-                     in_matched = False
-
-             pattern_matched_states.append(in_matched and out_matched)
-
-         pattern_matched = all(pattern_matched_states)  # all block match
-         if pattern_matched:
-             block_cls_name = transformer_blocks[0].__class__.__name__
-             logger.info(
-                 f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                 f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-             )
-
-         return pattern_matched
@@ -5,8 +5,8 @@ from collections import defaultdict
  from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple

  import torch
+ import torch.distributed as dist

- import cache_dit.primitives as primitives
  from cache_dit.cache_factory.taylorseer import TaylorSeer
  from cache_dit.logger import init_logger

@@ -47,10 +47,11 @@ class DBCacheContext:

      # Other settings
      downsample_factor: int = 1
-     num_inference_steps: int = -1  # un-used now
-     warmup_steps: int = 0  # DON'T Cache in warmup steps
+     num_inference_steps: int = -1  # for future use
+     max_warmup_steps: int = 0  # DON'T Cache in warmup steps
      # DON'T Cache if the number of cached steps >= max_cached_steps
      max_cached_steps: int = -1  # for both CFG and non-CFG
+     max_continuous_cached_steps: int = -1  # the max continuous cached steps

      # Record the steps that have been cached, both cached and non-cache
      executed_steps: int = 0  # cache + non-cache steps pippeline
@@ -89,10 +90,12 @@ class DBCacheContext:
      residual_diffs: DefaultDict[str, float] = dataclasses.field(
          default_factory=lambda: defaultdict(float),
      )
+     continuous_cached_steps: int = 0
      cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
      cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
          default_factory=lambda: defaultdict(float),
      )
+     cfg_continuous_cached_steps: int = 0

      @torch.compiler.disable
      def __post_init__(self):
@@ -108,17 +111,17 @@ class DBCacheContext:
              "cfg_diff_compute_separate is enabled."
          )

-         if "warmup_steps" not in self.taylorseer_kwargs:
-             # If warmup_steps is not set in taylorseer_kwargs,
-             # set the same as warmup_steps for DBCache
-             self.taylorseer_kwargs["warmup_steps"] = (
-                 self.warmup_steps if self.warmup_steps > 0 else 1
+         if "max_warmup_steps" not in self.taylorseer_kwargs:
+             # If max_warmup_steps is not set in taylorseer_kwargs,
+             # set the same as max_warmup_steps for DBCache
+             self.taylorseer_kwargs["max_warmup_steps"] = (
+                 self.max_warmup_steps if self.max_warmup_steps > 0 else 1
              )

          # Only set n_derivatives as 2 or 3, which is enough for most cases.
          if "n_derivatives" not in self.taylorseer_kwargs:
              self.taylorseer_kwargs["n_derivatives"] = max(
-                 2, min(3, self.taylorseer_kwargs["warmup_steps"])
+                 2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
              )

          if self.enable_taylorseer:
@@ -268,10 +271,31 @@ class DBCacheContext:

      @torch.compiler.disable
      def add_cached_step(self):
+         curr_cached_step = self.get_current_step()
          if not self.is_separate_cfg_step():
-             self.cached_steps.append(self.get_current_step())
+             if self.cached_steps:
+                 prev_cached_step = self.cached_steps[-1]
+                 if curr_cached_step - prev_cached_step == 1:
+                     if self.continuous_cached_steps == 0:
+                         self.continuous_cached_steps += 2
+                     else:
+                         self.continuous_cached_steps += 1
+                 else:
+                     self.continuous_cached_steps += 1
+
+             self.cached_steps.append(curr_cached_step)
          else:
-             self.cfg_cached_steps.append(self.get_current_step())
+             if self.cfg_cached_steps:
+                 prev_cfg_cached_step = self.cfg_cached_steps[-1]
+                 if curr_cached_step - prev_cfg_cached_step == 1:
+                     if self.cfg_continuous_cached_steps == 0:
+                         self.cfg_continuous_cached_steps += 2
+                     else:
+                         self.cfg_continuous_cached_steps += 1
+                 else:
+                     self.cfg_continuous_cached_steps += 1
+
+             self.cfg_cached_steps.append(curr_cached_step)

      @torch.compiler.disable
      def get_cached_steps(self):
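
The counter added here feeds the new `max_continuous_cached_steps` limit checked in `get_can_use_cache` further down. A pure-Python trace of how the counter evolves for an example sequence of cached steps (the step values are arbitrary):

    cached_steps, continuous_cached_steps = [], 0

    for step in [3, 4, 5, 9, 10]:  # steps on which the cache was actually used
        if cached_steps:
            if step - cached_steps[-1] == 1:
                # the first adjacent pair counts as 2, later adjacent steps add 1
                continuous_cached_steps += 2 if continuous_cached_steps == 0 else 1
            else:
                continuous_cached_steps += 1
        cached_steps.append(step)

    print(cached_steps, continuous_cached_steps)  # [3, 4, 5, 9, 10] 5

When the limit is reached, `get_can_use_cache` resets the counter to 0 and forces a full computation for that step.
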
@@ -301,7 +325,34 @@ class DBCacheContext:

      @torch.compiler.disable
      def is_in_warmup(self):
-         return self.get_current_step() < self.warmup_steps
+         return self.get_current_step() < self.max_warmup_steps
+
+
+ # TODO: Support context manager for different cache_context
+
+
+ def create_cache_context(*args, **kwargs):
+     return DBCacheContext(*args, **kwargs)
+
+
+ def get_current_cache_context():
+     return _current_cache_context
+
+
+ def set_current_cache_context(cache_context=None):
+     global _current_cache_context
+     _current_cache_context = cache_context
+
+
+ @contextlib.contextmanager
+ def cache_context(cache_context):
+     global _current_cache_context
+     old_cache_context = _current_cache_context
+     _current_cache_context = cache_context
+     try:
+         yield
+     finally:
+         _current_cache_context = old_cache_context


  @torch.compiler.disable
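
These context helpers previously lived near the end of the module (they are removed from there in a later hunk below) and now sit directly after `DBCacheContext`. A minimal usage sketch, assuming the helpers are imported from the cache-context module this hunk belongs to (the file path is not shown in the diff):

    context = create_cache_context(
        max_warmup_steps=8,
        max_cached_steps=-1,
        max_continuous_cached_steps=-1,  # field names taken from DBCacheContext above
    )
    with cache_context(context):
        # helpers such as get_max_cached_steps() read from this context;
        # the previous context is restored when the block exits.
        assert get_current_cache_context() is context
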
@@ -396,6 +447,27 @@ def get_max_cached_steps():
      return cache_context.max_cached_steps


+ @torch.compiler.disable
+ def get_max_continuous_cached_steps():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.max_continuous_cached_steps
+
+
+ @torch.compiler.disable
+ def get_continuous_cached_steps():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.continuous_cached_steps
+
+
+ @torch.compiler.disable
+ def get_cfg_continuous_cached_steps():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.cfg_continuous_cached_steps
+
+
  @torch.compiler.disable
  def add_cached_step():
      cache_context = get_current_cache_context()
@@ -612,19 +684,6 @@ def cfg_diff_compute_separate():
  _current_cache_context: DBCacheContext = None


- def create_cache_context(*args, **kwargs):
-     return DBCacheContext(*args, **kwargs)
-
-
- def get_current_cache_context():
-     return _current_cache_context
-
-
- def set_current_cache_context(cache_context=None):
-     global _current_cache_context
-     _current_cache_context = cache_context
-
-
  def collect_cache_kwargs(default_attrs: dict, **kwargs):
      # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
      # default_attrs: specific settings for different pipelines
@@ -671,17 +730,6 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
      return cache_kwargs, kwargs


- @contextlib.contextmanager
- def cache_context(cache_context):
-     global _current_cache_context
-     old_cache_context = _current_cache_context
-     _current_cache_context = cache_context
-     try:
-         yield
-     finally:
-         _current_cache_context = old_cache_context
-
-
  @torch.compiler.disable
  def are_two_tensors_similar(
      t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
@@ -744,8 +792,8 @@ def are_two_tensors_similar(
      mean_t1 = t1.abs().mean()

      if parallelized:
-         mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
-         mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
+         dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+         dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)

      # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
      # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
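
The custom `primitives.all_reduce_sync` wrapper is replaced with the stock collective. Note the change in calling convention: `torch.distributed.all_reduce` mutates the tensor in place rather than returning a new one, which is why the assignments disappear. A hedged standalone sketch of the equivalent call (only meaningful when a process group is initialized, e.g. under torchrun, and ReduceOp.AVG requires the NCCL backend; the scalar value is arbitrary):

    import torch
    import torch.distributed as dist

    if dist.is_available() and dist.is_initialized():
        mean_diff = torch.tensor(0.5, device="cuda")  # per-rank scalar (arbitrary value)
        dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)  # in-place average across ranks
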
@@ -1020,6 +1068,7 @@ def get_can_use_cache(
      if is_in_warmup():
          return False

+     # max cached steps
      max_cached_steps = get_max_cached_steps()
      if not is_separate_cfg_step():
          cached_steps = get_cached_steps()
@@ -1030,10 +1079,34 @@
          if logger.isEnabledFor(logging.DEBUG):
              logger.debug(
                  f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                 "cannot use cache."
+                 "can not use cache."
              )
          return False

+     # max continuous cached steps
+     max_continuous_cached_steps = get_max_continuous_cached_steps()
+     if not is_separate_cfg_step():
+         continuous_cached_steps = get_continuous_cached_steps()
+     else:
+         continuous_cached_steps = get_cfg_continuous_cached_steps()
+
+     if max_continuous_cached_steps >= 0 and (
+         continuous_cached_steps >= max_continuous_cached_steps
+     ):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"{prefix}, max_continuous_cached_steps "
+                 f"reached: {max_continuous_cached_steps}, "
+                 "can not use cache."
+             )
+         # reset continuous cached steps stats
+         cache_context = get_current_cache_context()
+         if not is_separate_cfg_step():
+             cache_context.continuous_cached_steps = 0
+         else:
+             cache_context.cfg_continuous_cached_steps = 0
+         return False
+
      if threshold is None or threshold <= 0.0:
          threshold = get_residual_diff_threshold()
      if threshold <= 0.0:
@@ -16,13 +16,14 @@ def enable_cache(
      # Cache context kwargs
      Fn_compute_blocks: int = 8,
      Bn_compute_blocks: int = 0,
-     warmup_steps: int = 8,
+     max_warmup_steps: int = 8,
      max_cached_steps: int = -1,
+     max_continuous_cached_steps: int = -1,
      residual_diff_threshold: float = 0.08,
      # Cache CFG or not
      do_separate_cfg: bool = False,
      cfg_compute_first: bool = False,
-     cfg_diff_compute_separate: bool = False,
+     cfg_diff_compute_separate: bool = True,
      # Hybird TaylorSeer
      enable_taylorseer: bool = False,
      enable_encoder_taylorseer: bool = False,
@@ -54,12 +55,15 @@
          Further fuses approximate information in the **last n** Transformer blocks to enhance
          prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
          that use residual cache.
-     warmup_steps (`int`, *required*, defaults to 8):
+     max_warmup_steps (`int`, *required*, defaults to 8):
          DBCache does not apply the caching strategy when the number of running steps is less than
          or equal to this value, ensuring the model sufficiently learns basic features during warmup.
      max_cached_steps (`int`, *required*, defaults to -1):
          DBCache disables the caching strategy when the previous cached steps exceed this value to
          prevent precision degradation.
+     max_continuous_cached_steps (`int`, *required*, defaults to -1):
+         DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+         prevent precision degradation.
      residual_diff_threshold (`float`, *required*, defaults to 0.08):
          he value of residual diff threshold, a higher value leads to faster performance at the
          cost of lower precision.
@@ -106,8 +110,11 @@
      cache_context_kwargs["cache_type"] = CacheType.DBCache
      cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
      cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
-     cache_context_kwargs["warmup_steps"] = warmup_steps
+     cache_context_kwargs["max_warmup_steps"] = max_warmup_steps
      cache_context_kwargs["max_cached_steps"] = max_cached_steps
+     cache_context_kwargs["max_continuous_cached_steps"] = (
+         max_continuous_cached_steps
+     )
      cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
      cache_context_kwargs["do_separate_cfg"] = do_separate_cfg
      cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
@@ -6,13 +6,13 @@ class TaylorSeer:
      def __init__(
          self,
          n_derivatives=2,
-         warmup_steps=1,
+         max_warmup_steps=1,
          skip_interval_steps=1,
          compute_step_map=None,
      ):
          self.n_derivatives = n_derivatives
          self.ORDER = n_derivatives + 1
-         self.warmup_steps = warmup_steps
+         self.max_warmup_steps = max_warmup_steps
          self.skip_interval_steps = skip_interval_steps
          self.compute_step_map = compute_step_map
          self.reset_cache()
@@ -32,8 +32,9 @@
          if self.compute_step_map is not None:
              return self.compute_step_map[step]
          if (
-             step < self.warmup_steps
-             or (step - self.warmup_steps + 1) % self.skip_interval_steps == 0
+             step < self.max_warmup_steps
+             or (step - self.max_warmup_steps + 1) % self.skip_interval_steps
+             == 0
          ):
              return True
          return False
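
A pure-Python trace of the renamed warmup logic: a step computes the full result while still inside `max_warmup_steps`, and afterwards only every `skip_interval_steps`-th step (the values below are an arbitrary example):

    max_warmup_steps, skip_interval_steps = 1, 2

    def computes_full(step: int) -> bool:
        return (
            step < max_warmup_steps
            or (step - max_warmup_steps + 1) % skip_interval_steps == 0
        )

    print([s for s in range(8) if computes_full(s)])  # [0, 2, 4, 6]
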
@@ -9,7 +9,7 @@ def load_cache_options_from_yaml(yaml_file_path):

      required_keys = [
          "cache_type",
-         "warmup_steps",
+         "max_warmup_steps",
          "max_cached_steps",
          "Fn_compute_blocks",
          "Bn_compute_blocks",
@@ -24,7 +24,7 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:


  def set_compile_configs(
-     descent_tuning: bool = True,
+     descent_tuning: bool = False,
      cuda_graphs: bool = False,
      force_disable_compile_caches: bool = False,
      use_fast_math: bool = False,
@@ -0,0 +1 @@
+ from cache_dit.quantize.quantize_interface import quantize