cache-dit 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,13 +14,9 @@ logger = init_logger(__name__)
  @dataclasses.dataclass
  class CachedContext: # Internal CachedContext Impl class
      name: str = "default"
-     # Dual Block Cache
-     # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
+     # Dual Block Cache with flexible FnBn configuration.
      Fn_compute_blocks: int = 1
      Bn_compute_blocks: int = 0
-     # We have added residual cache pattern for selected compute blocks
-     Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
-     Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
      # non compute blocks diff threshold, we don't skip the non
      # compute blocks if the diff >= threshold
      non_compute_blocks_diff_threshold: float = 0.08
@@ -31,13 +27,6 @@ class CachedContext: # Internal CachedContext Impl class
      l1_hidden_states_diff_threshold: float = None
      important_condition_threshold: float = 0.0

-     # Alter Cache Settings
-     # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
-     enable_alter_cache: bool = False
-     is_alter_cache: bool = True
-     # 1.0 means we always cache the residuals if alter_cache is enabled.
-     alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0
-
      # Buffer for storing the residuals and other tensors
      buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
      incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
@@ -63,22 +52,21 @@ class CachedContext: # Internal CachedContext Impl class
      # Url: https://arxiv.org/pdf/2503.06923
      enable_taylorseer: bool = False
      enable_encoder_taylorseer: bool = False
-     # NOTE: use residual cache for taylorseer may incur precision loss
      taylorseer_cache_type: str = "hidden_states" # residual or hidden_states
-     taylorseer_order: int = 2 # The order for TaylorSeer
+     taylorseer_order: int = 1 # The order for TaylorSeer
      taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
      taylorseer: Optional[TaylorSeer] = None
      encoder_tarlorseer: Optional[TaylorSeer] = None

-     # Support enable_spearate_cfg, such as Wan 2.1,
+     # Support enable_separate_cfg, such as Wan 2.1,
      # Qwen-Image. For model that fused CFG and non-CFG into single
-     # forward step, should set enable_spearate_cfg as False.
+     # forward step, should set enable_separate_cfg as False.
      # For example: CogVideoX, HunyuanVideo, Mochi.
-     enable_spearate_cfg: bool = False
+     enable_separate_cfg: bool = False
      # Compute cfg forward first or not, default False, namely,
      # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
      cfg_compute_first: bool = False
-     # Compute spearate diff values for CFG and non-CFG step,
+     # Compute separate diff values for CFG and non-CFG step,
      # default True. If False, we will use the computed diff from
      # current non-CFG transformer step for current CFG step.
      cfg_diff_compute_separate: bool = True
@@ -97,16 +85,11 @@ class CachedContext: # Internal CachedContext Impl class
      )
      cfg_continuous_cached_steps: int = 0

-     @torch.compiler.disable
      def __post_init__(self):
          if logger.isEnabledFor(logging.DEBUG):
              logger.info(f"Created _CacheContext: {self.name}")
          # Some checks for settings
-         if self.enable_spearate_cfg:
-             assert self.enable_alter_cache is False, (
-                 "enable_alter_cache must set as False if "
-                 "enable_spearate_cfg is enabled."
-             )
+         if self.enable_separate_cfg:
              if self.cfg_diff_compute_separate:
                  assert self.cfg_compute_first is False, (
                      "cfg_compute_first must set as False if "
@@ -125,59 +108,44 @@ class CachedContext: # Internal CachedContext Impl class

          if self.enable_taylorseer:
              self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-             if self.enable_spearate_cfg:
+             if self.enable_separate_cfg:
                  self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)

          if self.enable_encoder_taylorseer:
              self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-             if self.enable_spearate_cfg:
+             if self.enable_separate_cfg:
                  self.cfg_encoder_taylorseer = TaylorSeer(
                      **self.taylorseer_kwargs
                  )

-     @torch.compiler.disable
      def get_residual_diff_threshold(self):
-         if self.enable_alter_cache:
-             residual_diff_threshold = self.alter_residual_diff_threshold
-         else:
-             residual_diff_threshold = self.residual_diff_threshold
-             if self.l1_hidden_states_diff_threshold is not None:
-                 # Use the L1 hidden states diff threshold if set
-                 residual_diff_threshold = self.l1_hidden_states_diff_threshold
+         residual_diff_threshold = self.residual_diff_threshold
+         if self.l1_hidden_states_diff_threshold is not None:
+             # Use the L1 hidden states diff threshold if set
+             residual_diff_threshold = self.l1_hidden_states_diff_threshold
          if isinstance(residual_diff_threshold, torch.Tensor):
              residual_diff_threshold = residual_diff_threshold.item()
          return residual_diff_threshold

-     @torch.compiler.disable
      def get_buffer(self, name):
-         if self.enable_alter_cache and self.is_alter_cache:
-             name = f"{name}_alter"
          return self.buffers.get(name)

-     @torch.compiler.disable
      def set_buffer(self, name, buffer):
-         if self.enable_alter_cache and self.is_alter_cache:
-             name = f"{name}_alter"
          self.buffers[name] = buffer

-     @torch.compiler.disable
      def remove_buffer(self, name):
-         if self.enable_alter_cache and self.is_alter_cache:
-             name = f"{name}_alter"
          if name in self.buffers:
              del self.buffers[name]

-     @torch.compiler.disable
      def clear_buffers(self):
          self.buffers.clear()

-     @torch.compiler.disable
      def mark_step_begin(self):
          # Always increase transformer executed steps
-         # incr step: prev 0 -> 1; prev 1 -> 2
-         # current step: incr step - 1
+         # incr step: prev 0 -> 1; prev 1 -> 2
+         # current step: incr step - 1
          self.transformer_executed_steps += 1
-         if not self.enable_spearate_cfg:
+         if not self.enable_separate_cfg:
              self.executed_steps += 1
          else:
              # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
@@ -190,10 +158,6 @@ class CachedContext: # Internal CachedContext Impl class
                  # transformer step: 0,2,4,...
                  self.executed_steps += 1

-         if not self.enable_alter_cache:
-             # 0 F 1 T 2 F 3 T 4 F 5 T ...
-             self.is_alter_cache = not self.is_alter_cache
-
          # Reset the cached steps and residual diffs at the beginning
          # of each inference.
          if self.get_current_transformer_step() == 0:
@@ -219,7 +183,7 @@ class CachedContext: # Internal CachedContext Impl class

          # mark_step_begin of TaylorSeer must be called after the cache is reset.
          if self.enable_taylorseer or self.enable_encoder_taylorseer:
-             if self.enable_spearate_cfg:
+             if self.enable_separate_cfg:
                  # Assume non-CFG steps: 0, 2, 4, 6, ...
                  if not self.is_separate_cfg_step():
                      taylorseer, encoder_taylorseer = self.get_taylorseers()
@@ -248,7 +212,6 @@ class CachedContext: # Internal CachedContext Impl class
      def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
          return self.cfg_taylorseer, self.cfg_encoder_taylorseer

-     @torch.compiler.disable
      def add_residual_diff(self, diff):
          # step: executed_steps - 1, not transformer_steps - 1
          step = str(self.get_current_step())
@@ -260,15 +223,12 @@ class CachedContext: # Internal CachedContext Impl class
              if step not in self.cfg_residual_diffs:
                  self.cfg_residual_diffs[step] = diff

-     @torch.compiler.disable
      def get_residual_diffs(self):
          return self.residual_diffs.copy()

-     @torch.compiler.disable
      def get_cfg_residual_diffs(self):
          return self.cfg_residual_diffs.copy()

-     @torch.compiler.disable
      def add_cached_step(self):
          curr_cached_step = self.get_current_step()
          if not self.is_separate_cfg_step():
@@ -296,25 +256,20 @@ class CachedContext: # Internal CachedContext Impl class

              self.cfg_cached_steps.append(curr_cached_step)

-     @torch.compiler.disable
      def get_cached_steps(self):
          return self.cached_steps.copy()

-     @torch.compiler.disable
      def get_cfg_cached_steps(self):
          return self.cfg_cached_steps.copy()

-     @torch.compiler.disable
      def get_current_step(self):
          return self.executed_steps - 1

-     @torch.compiler.disable
      def get_current_transformer_step(self):
          return self.transformer_executed_steps - 1

-     @torch.compiler.disable
      def is_separate_cfg_step(self):
-         if not self.enable_spearate_cfg:
+         if not self.enable_separate_cfg:
              return False
          if self.cfg_compute_first:
              # CFG steps: 0, 2, 4, 6, ...
@@ -322,6 +277,5 @@ class CachedContext: # Internal CachedContext Impl class
          # CFG steps: 1, 3, 5, 7, ...
          return self.get_current_transformer_step() % 2 != 0

-     @torch.compiler.disable
      def is_in_warmup(self):
          return self.get_current_step() < self.max_warmup_steps
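The separate-CFG bookkeeping above reduces to a simple parity rule on the transformer step counter. A minimal standalone sketch of that rule (a hypothetical helper written only for illustration, not part of the package):

    def is_separate_cfg_step(transformer_step: int, cfg_compute_first: bool = False) -> bool:
        # With cfg_compute_first=False, transformer steps 0, 2, 4, ... are the
        # non-CFG forwards and 1, 3, 5, ... are the CFG forwards; the parity
        # flips when the CFG forward is computed first.
        if cfg_compute_first:
            return transformer_step % 2 == 0
        return transformer_step % 2 != 0

    for t in range(6):
        # Each non-CFG/CFG pair of transformer steps maps to one logical step.
        print(t, t // 2, is_separate_cfg_step(t))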
@@ -74,8 +74,8 @@ class CachedContextManager:
          del self._cached_context_manager[cached_context]

      def clear_contexts(self):
-         for cached_context in self._cached_context_manager:
-             self.remove_context(cached_context)
+         for context_name in list(self._cached_context_manager.keys()):
+             self.remove_context(context_name)

      @contextlib.contextmanager
      def enter_context(self, cached_context: CachedContext | str):
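The clear_contexts() change avoids mutating the registry dict while iterating over it. A minimal illustration of the failure mode and of the snapshot fix, independent of cache-dit:

    contexts = {"default": object(), "cfg": object()}
    try:
        for name in contexts:   # iterating the live dict
            del contexts[name]  # mutating it raises RuntimeError in Python 3
    except RuntimeError as exc:
        print("direct iteration fails:", exc)

    contexts = {"default": object(), "cfg": object()}
    for name in list(contexts.keys()):  # snapshot the keys first, as in the new code
        del contexts[name]
    print(contexts)  # {}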
@@ -122,10 +122,7 @@ class CachedContextManager:
                  default_value,
              )

-         # Manually set sequence fields, namely, Fn_compute_blocks_ids
-         # and Bn_compute_blocks_ids, which are lists or sets.
-         _safe_set_sequence_field("Fn_compute_blocks_ids", [])
-         _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+         # Manually set sequence fields
          _safe_set_sequence_field("taylorseer_kwargs", {})

          for attr in cache_attrs:
@@ -301,18 +298,6 @@ class CachedContextManager:
              return self.is_taylorseer_cache_residual()
          return True

-     @torch.compiler.disable
-     def is_alter_cache_enabled(self) -> bool:
-         cached_context = self.get_context()
-         assert cached_context is not None, "cached_context must be set before"
-         return cached_context.enable_alter_cache
-
-     @torch.compiler.disable
-     def is_alter_cache(self) -> bool:
-         cached_context = self.get_context()
-         assert cached_context is not None, "cached_context must be set before"
-         return cached_context.is_alter_cache
-
      @torch.compiler.disable
      def is_in_warmup(self) -> bool:
          cached_context = self.get_context()
@@ -359,20 +344,6 @@ class CachedContextManager:
          )
          return cached_context.Fn_compute_blocks

-     @torch.compiler.disable
-     def Fn_compute_blocks_ids(self) -> List[int]:
-         cached_context = self.get_context()
-         assert cached_context is not None, "cached_context must be set before"
-         assert (
-             len(cached_context.Fn_compute_blocks_ids)
-             <= cached_context.Fn_compute_blocks
-         ), (
-             "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
-             f"{cached_context.Fn_compute_blocks}, but got "
-             f"{len(cached_context.Fn_compute_blocks_ids)}"
-         )
-         return cached_context.Fn_compute_blocks_ids
-
      @torch.compiler.disable
      def Bn_compute_blocks(self) -> int:
          cached_context = self.get_context()
@@ -393,24 +364,10 @@ class CachedContextManager:
          return cached_context.Bn_compute_blocks

      @torch.compiler.disable
-     def Bn_compute_blocks_ids(self) -> List[int]:
-         cached_context = self.get_context()
-         assert cached_context is not None, "cached_context must be set before"
-         assert (
-             len(cached_context.Bn_compute_blocks_ids)
-             <= cached_context.Bn_compute_blocks
-         ), (
-             "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
-             f"{cached_context.Bn_compute_blocks}, but got "
-             f"{len(cached_context.Bn_compute_blocks_ids)}"
-         )
-         return cached_context.Bn_compute_blocks_ids
-
-     @torch.compiler.disable
-     def enable_spearate_cfg(self) -> bool:
+     def enable_separate_cfg(self) -> bool:
          cached_context = self.get_context()
          assert cached_context is not None, "cached_context must be set before"
-         return cached_context.enable_spearate_cfg
+         return cached_context.enable_separate_cfg

      @torch.compiler.disable
      def is_separate_cfg_step(self) -> bool:
@@ -453,7 +410,7 @@ class CachedContextManager:

          if all(
              (
-                 self.enable_spearate_cfg(),
+                 self.enable_separate_cfg(),
                  self.is_separate_cfg_step(),
                  not self.cfg_diff_compute_separate(),
                  self.get_current_step_residual_diff() is not None,
@@ -525,6 +482,9 @@ class CachedContextManager:
      # Fn buffers
      @torch.compiler.disable
      def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         # DON'T set None Buffer
+         if buffer is None:
+             return
          # Set hidden_states or residual for Fn blocks.
          # This buffer is only use for L1 diff calculation.
          downsample_factor = self.get_downsample_factor()
@@ -548,6 +508,9 @@ class CachedContextManager:

      @torch.compiler.disable
      def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         # DON'T set None Buffer
+         if buffer is None:
+             return
          if self.is_separate_cfg_step():
              self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
              self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
@@ -566,6 +529,9 @@ class CachedContextManager:
      # Bn buffers
      @torch.compiler.disable
      def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
+         # DON'T set None Buffer
+         if buffer is None:
+             return
          # Set hidden_states or residual for Bn blocks.
          # This buffer is use for hidden states approximation.
          if self.is_taylorseer_enabled():
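The three set_*_buffer hunks above add the same early return so that a None tensor is never stored. A hedged sketch of that guard factored into a decorator (a hypothetical refactor for illustration, not cache-dit API):

    from functools import wraps

    def skip_none_buffer(fn):
        # Hypothetical helper mirroring the "DON'T set None Buffer" guards above:
        # silently drop the call instead of storing a None tensor.
        @wraps(fn)
        def wrapper(self, buffer, *args, **kwargs):
            if buffer is None:
                return None
            return fn(self, buffer, *args, **kwargs)
        return wrapper

    class BufferStore:
        def __init__(self):
            self.buffers = {}

        @skip_none_buffer
        def set_buffer(self, buffer, name="Fn"):
            self.buffers[name] = buffer

    store = BufferStore()
    store.set_buffer(None)        # ignored, nothing stored
    store.set_buffer([1.0, 2.0])  # stored under "Fn"
    print(store.buffers)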
@@ -820,26 +786,12 @@ class CachedContextManager:
          else:
              prev_states_tensor = self.get_Fn_buffer(prefix)

-         if not self.is_alter_cache_enabled():
-             # Dynamic cache according to the residual diff
-             can_cache = prev_states_tensor is not None and self.similarity(
-                 prev_states_tensor,
-                 states_tensor,
-                 threshold=threshold,
-                 parallelized=parallelized,
-                 prefix=prefix,
-             )
-         else:
-             # Only cache in the alter cache steps
-             can_cache = (
-                 prev_states_tensor is not None
-                 and self.similarity(
-                     prev_states_tensor,
-                     states_tensor,
-                     threshold=threshold,
-                     parallelized=parallelized,
-                     prefix=prefix,
-                 )
-                 and self.is_alter_cache()
-             )
+         # Dynamic cache according to the residual diff
+         can_cache = prev_states_tensor is not None and self.similarity(
+             prev_states_tensor,
+             states_tensor,
+             threshold=threshold,
+             parallelized=parallelized,
+             prefix=prefix,
+         )
          return can_cache
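With the alternate-cache branch removed, the caching decision is purely threshold-based: reuse the cached residuals only while the current Fn states stay close to the previously stored ones. A minimal numeric sketch, assuming a mean relative L1 metric (the package's similarity() helper may compute the diff differently):

    import torch

    def mean_relative_l1(prev: torch.Tensor, curr: torch.Tensor) -> float:
        # Assumed metric for illustration only.
        return ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()

    def can_use_cache(prev, curr, threshold: float = 0.08) -> bool:
        # Mirrors the simplified branch above: only cache when previous Fn
        # states exist and their diff against the current states is small.
        return prev is not None and mean_relative_l1(prev, curr) < threshold

    prev = torch.randn(2, 16, 64)
    curr = prev + 0.01 * torch.randn_like(prev)
    print(can_use_cache(prev, curr))  # True: diff is roughly 0.01 < 0.08
    print(can_use_cache(None, curr))  # False: nothing cached yet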
@@ -1,5 +1,6 @@
  import math
  import torch
+ from typing import List, Dict


  class TaylorSeer:
@@ -17,16 +18,14 @@ class TaylorSeer:
          self.compute_step_map = compute_step_map
          self.reset_cache()

-     @torch.compiler.disable
      def reset_cache(self):
-         self.state = {
+         self.state: Dict[str, List[torch.Tensor]] = {
              "dY_prev": [None] * self.ORDER,
              "dY_current": [None] * self.ORDER,
          }
          self.current_step = -1
          self.last_non_approximated_step = -1

-     @torch.compiler.disable
      def should_compute_full(self, step=None):
          step = self.current_step if step is None else step
          if self.compute_step_map is not None:
@@ -39,16 +38,19 @@ class TaylorSeer:
                  return True
          return False

-     @torch.compiler.disable
-     def approximate_derivative(self, Y):
+     def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
          # n-th order Taylor expansion:
          # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
          # + ... + d^nY(0)/dt^n * t^n / n!
          # TODO: Custom Triton/CUDA kernel for better performance,
          # especially for large n_derivatives.
-         dY_current = [None] * self.ORDER
+         dY_current: List[torch.Tensor] = [None] * self.ORDER
          dY_current[0] = Y
          window = self.current_step - self.last_non_approximated_step
+         if self.state["dY_prev"][0] is not None:
+             if dY_current[0].shape != self.state["dY_prev"][0].shape:
+                 self.reset_cache()
+
          for i in range(self.n_derivatives):
              if self.state["dY_prev"][i] is not None and self.current_step > 1:
                  dY_current[i + 1] = (
@@ -58,8 +60,7 @@ class TaylorSeer:
                  break
          return dY_current

-     @torch.compiler.disable
-     def approximate_value(self):
+     def approximate_value(self) -> torch.Tensor:
          # TODO: Custom Triton/CUDA kernel for better performance,
          # especially for large n_derivatives.
          elapsed = self.current_step - self.last_non_approximated_step
@@ -71,12 +72,10 @@ class TaylorSeer:
                  break
          return output

-     @torch.compiler.disable
      def mark_step_begin(self):
          self.current_step += 1

-     @torch.compiler.disable
-     def update(self, Y):
+     def update(self, Y: torch.Tensor):
          # Directly call this method will ingnore the warmup
          # policy and force full computation.
          # Assume warmup steps is 3, and n_derivatives is 3.
@@ -94,8 +93,7 @@ class TaylorSeer:
          self.state["dY_current"] = self.approximate_derivative(Y)
          self.last_non_approximated_step = self.current_step

-     @torch.compiler.disable
-     def step(self, Y):
+     def step(self, Y: torch.Tensor):
          self.mark_step_begin()
          if self.should_compute_full():
              self.update(Y)
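The TaylorSeer above keeps the last fully computed value together with finite-difference estimates of its derivatives, then extrapolates skipped steps with a truncated Taylor series. A self-contained numeric sketch of that idea (plain floats instead of hidden-state tensors, written only to illustrate the formula in the comments above):

    import math

    def taylor_extrapolate(derivatives, elapsed_steps):
        # derivatives[i] approximates d^iY/dt^i at the last fully computed step;
        # Y(t0 + t) ~= sum_i derivatives[i] * t**i / i!
        return sum(
            d * (elapsed_steps ** i) / math.factorial(i)
            for i, d in enumerate(derivatives)
        )

    # Pretend the cached quantity follows Y(t) = 0.5 * t**2 and the last two
    # full computations happened at t=4 and t=5 (a window of 1 step).
    y_prev, y_curr = 0.5 * 4**2, 0.5 * 5**2
    derivatives = [y_curr, (y_curr - y_prev) / 1.0]  # order-1 TaylorSeer state

    print(taylor_extrapolate(derivatives, 1))  # 17.0, true Y(6) = 18.0
    print(taylor_extrapolate(derivatives, 2))  # 21.5, true Y(7) = 24.5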
@@ -24,14 +24,14 @@ def enable_cache(
      max_continuous_cached_steps: int = -1,
      residual_diff_threshold: float = 0.08,
      # Cache CFG or not
-     enable_spearate_cfg: bool | None = None,
+     enable_separate_cfg: bool = None,
      cfg_compute_first: bool = False,
      cfg_diff_compute_separate: bool = True,
      # Hybird TaylorSeer
      enable_taylorseer: bool = False,
      enable_encoder_taylorseer: bool = False,
      taylorseer_cache_type: str = "residual",
-     taylorseer_order: int = 2,
+     taylorseer_order: int = 1,
      **other_cache_context_kwargs,
  ) -> Union[
      DiffusionPipeline,
@@ -70,15 +70,15 @@ def enable_cache(
          residual_diff_threshold (`float`, *required*, defaults to 0.08):
              he value of residual diff threshold, a higher value leads to faster performance at the
              cost of lower precision.
-         enable_spearate_cfg (`bool`, *required*, defaults to None):
+         enable_separate_cfg (`bool`, *required*, defaults to None):
              Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-             and non-CFG into single forward step, should set enable_spearate_cfg as False, for example:
+             and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
              CogVideoX, HunyuanVideo, Mochi, etc.
          cfg_compute_first (`bool`, *required*, defaults to False):
              Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
              1, 3, 5, ... -> CFG step.
          cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-             Compute spearate diff values for CFG and non-CFG step, default True. If False, we will
+             Compute separate diff values for CFG and non-CFG step, default True. If False, we will
              use the computed diff from current non-CFG transformer step for current CFG step.
          enable_taylorseer (`bool`, *required*, defaults to False):
              Enable the hybird TaylorSeer for hidden_states or not. We have supported the
@@ -91,10 +91,10 @@ def enable_cache(
              Enable the hybird TaylorSeer for encoder_hidden_states or not.
          taylorseer_cache_type (`str`, *required*, defaults to `residual`):
              The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
-         taylorseer_order (`int`, *required*, defaults to 2):
+         taylorseer_order (`int`, *required*, defaults to 1):
              The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
-             but may improve precision significantly.
-         other_cache_kwargs: (`dict`, *optional*, defaults to {})
+             the recommended value is 1 or 2.
+         other_cache_context_kwargs: (`dict`, *optional*, defaults to {})
              Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
              for more details.

@@ -123,7 +123,7 @@ def enable_cache(
          max_continuous_cached_steps
      )
      cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
-     cache_context_kwargs["enable_spearate_cfg"] = enable_spearate_cfg
+     cache_context_kwargs["enable_separate_cfg"] = enable_separate_cfg
      cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
      cache_context_kwargs["cfg_diff_compute_separate"] = (
          cfg_diff_compute_separate
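A hedged usage sketch for the renamed and re-tuned enable_cache options above. The model id and the way the pipeline object is passed are illustrative assumptions (based on the function's signature and docstring in this diff, and on the top-level re-export used in the project's examples), not something taken from the diff itself:

    import cache_dit
    from diffusers import DiffusionPipeline

    # Illustrative model id; any Diffusers pipeline with a DiT-style transformer applies.
    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

    cache_dit.enable_cache(
        pipe,
        residual_diff_threshold=0.08,
        enable_separate_cfg=True,    # Wan 2.1 / Qwen-Image run CFG as a separate forward
        cfg_compute_first=False,     # non-CFG on even steps, CFG on odd steps
        enable_taylorseer=True,
        taylorseer_cache_type="residual",
        taylorseer_order=1,          # new default in 0.2.34; 1 or 2 is the recommended range
    )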
@@ -1,4 +1,5 @@
  from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+ from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
  from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
  from cache_dit.cache_factory.patch_functors.functor_chroma import (
      ChromaPatchFunctor,