cache-dit 0.2.33__py3-none-any.whl → 0.2.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/__init__.py CHANGED
@@ -4,6 +4,10 @@ except ImportError:
  __version__ = "unknown version"
  version_tuple = (0, 0, "unknown version")
 
+ from cache_dit.utils import summary
+ from cache_dit.utils import strify
+ from cache_dit.utils import disable_print
+ from cache_dit.logger import init_logger
  from cache_dit.cache_factory import load_options
  from cache_dit.cache_factory import enable_cache
  from cache_dit.cache_factory import disable_cache
@@ -18,9 +22,7 @@ from cache_dit.cache_factory import supported_pipelines
  from cache_dit.cache_factory import get_adapter
  from cache_dit.compile import set_compile_configs
  from cache_dit.quantize import quantize
- from cache_dit.utils import summary
- from cache_dit.utils import strify
- from cache_dit.logger import init_logger
+
 
  NONE = CacheType.NONE
  DBCache = CacheType.DBCache
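The helpers re-exported here (summary, strify, disable_print, init_logger) remain available at the package top level; only their position in __init__.py changed. A minimal usage sketch, assuming a diffusers pipeline that has a registered adapter (the model id is illustrative):

    import cache_dit
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # illustrative model id
    cache_dit.enable_cache(pipe)          # enable caching on the pipeline
    # ... run inference with pipe(...) ...
    stats = cache_dit.summary(pipe)       # collect cache statistics for the pipeline
    print(cache_dit.strify(stats))        # render the stats as a compact string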
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '0.2.33'
- __version_tuple__ = version_tuple = (0, 2, 33)
+ __version__ = version = '0.2.36'
+ __version_tuple__ = version_tuple = (0, 2, 36)
 
  __commit_id__ = commit_id = None
@@ -153,7 +153,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
  )
 
 
- @BlockAdapterRegistry.register("LTXVideo")
+ @BlockAdapterRegistry.register("LTX")
  def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
  from diffusers import LTXVideoTransformer3DModel
 
@@ -248,7 +248,10 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
  pipe=pipe,
  transformer=pipe.transformer,
  blocks=pipe.transformer.blocks,
- forward_pattern=ForwardPattern.Pattern_2,
+ # NOTE: Use Pattern_3 instead of Pattern_2 because the
+ # encoder_hidden_states will never change in the blocks
+ # forward loop.
+ forward_pattern=ForwardPattern.Pattern_3,
  has_separate_cfg=True,
  **kwargs,
  )
@@ -285,6 +288,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
  @BlockAdapterRegistry.register("DiT")
  def dit_adapter(pipe, **kwargs) -> BlockAdapter:
  from diffusers import DiTTransformer2DModel
+ from cache_dit.cache_factory.patch_functors import DiTPatchFunctor
 
  assert isinstance(pipe.transformer, DiTTransformer2DModel)
  return BlockAdapter(
@@ -292,6 +296,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
  transformer=pipe.transformer,
  blocks=pipe.transformer.transformer_blocks,
  forward_pattern=ForwardPattern.Pattern_3,
+ patch_functor=DiTPatchFunctor(),
  **kwargs,
  )
 
@@ -331,24 +336,13 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
 
 
  @BlockAdapterRegistry.register("Lumina")
- def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import LuminaNextDiT2DModel
-
- assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.layers,
- forward_pattern=ForwardPattern.Pattern_3,
- **kwargs,
- )
-
-
- @BlockAdapterRegistry.register("Lumina2")
  def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
  from diffusers import Lumina2Transformer2DModel
+ from diffusers import LuminaNextDiT2DModel
 
- assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
+ assert isinstance(
+ pipe.transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel)
+ )
  return BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
@@ -386,12 +380,10 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
  )
 
 
- @BlockAdapterRegistry.register("Sana", supported=False)
+ @BlockAdapterRegistry.register("Sana")
  def sana_adapter(pipe, **kwargs) -> BlockAdapter:
  from diffusers import SanaTransformer2DModel
 
- # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
-
  assert isinstance(pipe.transformer, SanaTransformer2DModel)
  return BlockAdapter(
  pipe=pipe,
@@ -469,6 +461,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
  @BlockAdapterRegistry.register("Chroma")
  def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
  from diffusers import ChromaTransformer2DModel
+ from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor
 
  assert isinstance(pipe.transformer, ChromaTransformer2DModel)
  return BlockAdapter(
@@ -482,6 +475,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
  ForwardPattern.Pattern_1,
  ForwardPattern.Pattern_3,
  ],
+ patch_functor=ChromaPatchFunctor(),
  has_separate_cfg=True,
  **kwargs,
  )
@@ -16,8 +16,52 @@ logger = init_logger(__name__)
 
 
  class ParamsModifier:
- def __init__(self, **kwargs):
- self._context_kwargs = kwargs.copy()
+ def __init__(
+ self,
+ # Cache context kwargs
+ Fn_compute_blocks: Optional[int] = None,
+ Bn_compute_blocks: Optional[int] = None,
+ max_warmup_steps: Optional[int] = None,
+ max_cached_steps: Optional[int] = None,
+ max_continuous_cached_steps: Optional[int] = None,
+ residual_diff_threshold: Optional[float] = None,
+ # Cache CFG or not
+ enable_separate_cfg: Optional[bool] = None,
+ cfg_compute_first: Optional[bool] = None,
+ cfg_diff_compute_separate: Optional[bool] = None,
+ # Hybird TaylorSeer
+ enable_taylorseer: Optional[bool] = None,
+ enable_encoder_taylorseer: Optional[bool] = None,
+ taylorseer_cache_type: Optional[str] = None,
+ taylorseer_order: Optional[int] = None,
+ **other_cache_context_kwargs,
+ ):
+ self._context_kwargs = other_cache_context_kwargs.copy()
+ self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
+ self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
+ self._maybe_update_param("max_warmup_steps", max_warmup_steps)
+ self._maybe_update_param("max_cached_steps", max_cached_steps)
+ self._maybe_update_param(
+ "max_continuous_cached_steps", max_continuous_cached_steps
+ )
+ self._maybe_update_param(
+ "residual_diff_threshold", residual_diff_threshold
+ )
+ self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
+ self._maybe_update_param("cfg_compute_first", cfg_compute_first)
+ self._maybe_update_param(
+ "cfg_diff_compute_separate", cfg_diff_compute_separate
+ )
+ self._maybe_update_param("enable_taylorseer", enable_taylorseer)
+ self._maybe_update_param(
+ "enable_encoder_taylorseer", enable_encoder_taylorseer
+ )
+ self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
+ self._maybe_update_param("taylorseer_order", taylorseer_order)
+
+ def _maybe_update_param(self, key: str, value: Any):
+ if value is not None:
+ self._context_kwargs[key] = value
 
 
  @dataclasses.dataclass
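The rewritten ParamsModifier now spells out the supported cache-context parameters and only merges values that were explicitly provided; anything left at None is dropped rather than overriding the adapter's defaults. A short sketch of that behavior (the import path is an assumption):

    from cache_dit.cache_factory import ParamsModifier  # assumed import path

    m = ParamsModifier(max_warmup_steps=4, residual_diff_threshold=0.12)
    # Only the two explicitly passed values land in the context kwargs:
    # m._context_kwargs == {"max_warmup_steps": 4, "residual_diff_threshold": 0.12}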
@@ -10,13 +10,14 @@ logger = init_logger(__name__)
 
  class BlockAdapterRegistry:
  _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
- _predefined_adapters_has_spearate_cfg: List[str] = [
+ _predefined_adapters_has_separate_cfg: List[str] = [
  "QwenImage",
  "Wan",
  "CogView4",
  "Cosmos",
  "SkyReelsV2",
  "Chroma",
+ "Lumina2",
  ]
 
  @classmethod
@@ -68,7 +69,7 @@ class BlockAdapterRegistry:
  return True
 
  pipe_cls_name = pipe_or_adapter.__class__.__name__
- for name in cls._predefined_adapters_has_spearate_cfg:
+ for name in cls._predefined_adapters_has_separate_cfg:
  if pipe_cls_name.startswith(name):
  return True
 
@@ -114,27 +114,27 @@ class CachedAdapter:
  **cache_context_kwargs,
  ):
  # Check cache_context_kwargs
- if cache_context_kwargs["enable_spearate_cfg"] is None:
+ if cache_context_kwargs["enable_separate_cfg"] is None:
  # Check cfg for some specific case if users don't set it as True
  if BlockAdapterRegistry.has_separate_cfg(block_adapter):
- cache_context_kwargs["enable_spearate_cfg"] = True
+ cache_context_kwargs["enable_separate_cfg"] = True
  logger.info(
- f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+ f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
  )
  else:
- cache_context_kwargs["enable_spearate_cfg"] = (
+ cache_context_kwargs["enable_separate_cfg"] = (
  BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
  )
  logger.info(
- f"Use default 'enable_spearate_cfg' from block adapter "
- f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+ f"Use default 'enable_separate_cfg' from block adapter "
+ f"register: {cache_context_kwargs['enable_separate_cfg']}, "
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
  )
  else:
  logger.info(
- f"Use custom 'enable_spearate_cfg' from cache context "
- f"kwargs: {cache_context_kwargs['enable_spearate_cfg']}. "
+ f"Use custom 'enable_separate_cfg' from cache context "
+ f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
  )
 
@@ -53,20 +53,20 @@ class CachedContext: # Internal CachedContext Impl class
  enable_taylorseer: bool = False
  enable_encoder_taylorseer: bool = False
  taylorseer_cache_type: str = "hidden_states" # residual or hidden_states
- taylorseer_order: int = 2 # The order for TaylorSeer
+ taylorseer_order: int = 1 # The order for TaylorSeer
  taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
  taylorseer: Optional[TaylorSeer] = None
  encoder_tarlorseer: Optional[TaylorSeer] = None
 
- # Support enable_spearate_cfg, such as Wan 2.1,
+ # Support enable_separate_cfg, such as Wan 2.1,
  # Qwen-Image. For model that fused CFG and non-CFG into single
- # forward step, should set enable_spearate_cfg as False.
+ # forward step, should set enable_separate_cfg as False.
  # For example: CogVideoX, HunyuanVideo, Mochi.
- enable_spearate_cfg: bool = False
+ enable_separate_cfg: bool = False
  # Compute cfg forward first or not, default False, namely,
  # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
  cfg_compute_first: bool = False
- # Compute spearate diff values for CFG and non-CFG step,
+ # Compute separate diff values for CFG and non-CFG step,
  # default True. If False, we will use the computed diff from
  # current non-CFG transformer step for current CFG step.
  cfg_diff_compute_separate: bool = True
@@ -89,7 +89,7 @@ class CachedContext: # Internal CachedContext Impl class
  if logger.isEnabledFor(logging.DEBUG):
  logger.info(f"Created _CacheContext: {self.name}")
  # Some checks for settings
- if self.enable_spearate_cfg:
+ if self.enable_separate_cfg:
  if self.cfg_diff_compute_separate:
  assert self.cfg_compute_first is False, (
  "cfg_compute_first must set as False if "
@@ -108,12 +108,12 @@ class CachedContext: # Internal CachedContext Impl class
 
  if self.enable_taylorseer:
  self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
- if self.enable_spearate_cfg:
+ if self.enable_separate_cfg:
  self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
 
  if self.enable_encoder_taylorseer:
  self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
- if self.enable_spearate_cfg:
+ if self.enable_separate_cfg:
  self.cfg_encoder_taylorseer = TaylorSeer(
  **self.taylorseer_kwargs
  )
@@ -145,7 +145,7 @@ class CachedContext: # Internal CachedContext Impl class
  # incr step: prev 0 -> 1; prev 1 -> 2
  # current step: incr step - 1
  self.transformer_executed_steps += 1
- if not self.enable_spearate_cfg:
+ if not self.enable_separate_cfg:
  self.executed_steps += 1
  else:
  # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
@@ -183,7 +183,7 @@
 
  # mark_step_begin of TaylorSeer must be called after the cache is reset.
  if self.enable_taylorseer or self.enable_encoder_taylorseer:
- if self.enable_spearate_cfg:
+ if self.enable_separate_cfg:
  # Assume non-CFG steps: 0, 2, 4, 6, ...
  if not self.is_separate_cfg_step():
  taylorseer, encoder_taylorseer = self.get_taylorseers()
@@ -269,7 +269,7 @@
  return self.transformer_executed_steps - 1
 
  def is_separate_cfg_step(self):
- if not self.enable_spearate_cfg:
+ if not self.enable_separate_cfg:
  return False
  if self.cfg_compute_first:
  # CFG steps: 0, 2, 4, 6, ...
@@ -74,8 +74,8 @@ class CachedContextManager:
  del self._cached_context_manager[cached_context]
 
  def clear_contexts(self):
- for cached_context in self._cached_context_manager:
- self.remove_context(cached_context)
+ for context_name in list(self._cached_context_manager.keys()):
+ self.remove_context(context_name)
 
  @contextlib.contextmanager
  def enter_context(self, cached_context: CachedContext | str):
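The clear_contexts() change above snapshots the keys before removing entries; deleting from a dict while iterating it directly raises a RuntimeError in Python. The same pattern in isolation:

    contexts = {"ctx_0": object(), "ctx_1": object()}
    # Iterating the dict itself while deleting would raise
    # "RuntimeError: dictionary changed size during iteration".
    for name in list(contexts.keys()):  # snapshot first, as clear_contexts() now does
        del contexts[name]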
@@ -364,10 +364,10 @@
  return cached_context.Bn_compute_blocks
 
  @torch.compiler.disable
- def enable_spearate_cfg(self) -> bool:
+ def enable_separate_cfg(self) -> bool:
  cached_context = self.get_context()
  assert cached_context is not None, "cached_context must be set before"
- return cached_context.enable_spearate_cfg
+ return cached_context.enable_separate_cfg
 
  @torch.compiler.disable
  def is_separate_cfg_step(self) -> bool:
@@ -410,7 +410,7 @@
 
  if all(
  (
- self.enable_spearate_cfg(),
+ self.enable_separate_cfg(),
  self.is_separate_cfg_step(),
  not self.cfg_diff_compute_separate(),
  self.get_current_step_residual_diff() is not None,
@@ -1,4 +1,6 @@
  import math
+ import torch
+ from typing import List, Dict
 
 
  class TaylorSeer:
@@ -17,7 +19,7 @@ class TaylorSeer:
  self.reset_cache()
 
  def reset_cache(self):
- self.state = {
+ self.state: Dict[str, List[torch.Tensor]] = {
  "dY_prev": [None] * self.ORDER,
  "dY_current": [None] * self.ORDER,
  }
@@ -36,15 +38,19 @@
  return True
  return False
 
- def approximate_derivative(self, Y):
+ def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
  # n-th order Taylor expansion:
  # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
  # + ... + d^nY(0)/dt^n * t^n / n!
  # TODO: Custom Triton/CUDA kernel for better performance,
  # especially for large n_derivatives.
- dY_current = [None] * self.ORDER
+ dY_current: List[torch.Tensor] = [None] * self.ORDER
  dY_current[0] = Y
  window = self.current_step - self.last_non_approximated_step
+ if self.state["dY_prev"][0] is not None:
+ if dY_current[0].shape != self.state["dY_prev"][0].shape:
+ self.reset_cache()
+
  for i in range(self.n_derivatives):
  if self.state["dY_prev"][i] is not None and self.current_step > 1:
  dY_current[i + 1] = (
@@ -54,7 +60,7 @@
  break
  return dY_current
 
- def approximate_value(self):
+ def approximate_value(self) -> torch.Tensor:
  # TODO: Custom Triton/CUDA kernel for better performance,
  # especially for large n_derivatives.
  elapsed = self.current_step - self.last_non_approximated_step
@@ -69,7 +75,7 @@
  def mark_step_begin(self):
  self.current_step += 1
 
- def update(self, Y):
+ def update(self, Y: torch.Tensor):
  # Directly call this method will ingnore the warmup
  # policy and force full computation.
  # Assume warmup steps is 3, and n_derivatives is 3.
@@ -87,7 +93,7 @@
  self.state["dY_current"] = self.approximate_derivative(Y)
  self.last_non_approximated_step = self.current_step
 
- def step(self, Y):
+ def step(self, Y: torch.Tensor):
  self.mark_step_begin()
  if self.should_compute_full():
  self.update(Y)
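The new shape check in approximate_derivative() resets the TaylorSeer state when the incoming tensor no longer matches the cached derivatives (for example, after the latent resolution or batch size changes between runs), instead of failing on the elementwise subtraction. The guard in isolation, with hypothetical shapes:

    import torch

    cached_dY0 = torch.zeros(2, 16, 64)   # stale first-order entry from an earlier run
    Y = torch.zeros(2, 16, 128)           # new activation with a different shape

    # Same condition the release adds before computing finite differences:
    if cached_dY0 is not None and Y.shape != cached_dY0.shape:
        print("shape changed -> reset_cache() is called and the derivatives start over")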
@@ -24,14 +24,14 @@ def enable_cache(
  max_continuous_cached_steps: int = -1,
  residual_diff_threshold: float = 0.08,
  # Cache CFG or not
- enable_spearate_cfg: bool | None = None,
+ enable_separate_cfg: bool = None,
  cfg_compute_first: bool = False,
  cfg_diff_compute_separate: bool = True,
  # Hybird TaylorSeer
  enable_taylorseer: bool = False,
  enable_encoder_taylorseer: bool = False,
  taylorseer_cache_type: str = "residual",
- taylorseer_order: int = 2,
+ taylorseer_order: int = 1,
  **other_cache_context_kwargs,
  ) -> Union[
  DiffusionPipeline,
@@ -70,15 +70,15 @@ def enable_cache(
  residual_diff_threshold (`float`, *required*, defaults to 0.08):
  he value of residual diff threshold, a higher value leads to faster performance at the
  cost of lower precision.
- enable_spearate_cfg (`bool`, *required*, defaults to None):
+ enable_separate_cfg (`bool`, *required*, defaults to None):
  Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
- and non-CFG into single forward step, should set enable_spearate_cfg as False, for example:
+ and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
  CogVideoX, HunyuanVideo, Mochi, etc.
  cfg_compute_first (`bool`, *required*, defaults to False):
  Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
  1, 3, 5, ... -> CFG step.
  cfg_diff_compute_separate (`bool`, *required*, defaults to True):
- Compute spearate diff values for CFG and non-CFG step, default True. If False, we will
+ Compute separate diff values for CFG and non-CFG step, default True. If False, we will
  use the computed diff from current non-CFG transformer step for current CFG step.
  enable_taylorseer (`bool`, *required*, defaults to False):
  Enable the hybird TaylorSeer for hidden_states or not. We have supported the
@@ -91,10 +91,10 @@ def enable_cache(
  Enable the hybird TaylorSeer for encoder_hidden_states or not.
  taylorseer_cache_type (`str`, *required*, defaults to `residual`):
  The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
- taylorseer_order (`int`, *required*, defaults to 2):
+ taylorseer_order (`int`, *required*, defaults to 1):
  The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
- but may improve precision significantly.
- other_cache_kwargs: (`dict`, *optional*, defaults to {})
+ the recommended value is 1 or 2.
+ other_cache_context_kwargs: (`dict`, *optional*, defaults to {})
  Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
  for more details.
 
@@ -123,7 +123,7 @@
  max_continuous_cached_steps
  )
  cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
- cache_context_kwargs["enable_spearate_cfg"] = enable_spearate_cfg
+ cache_context_kwargs["enable_separate_cfg"] = enable_separate_cfg
  cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
  cache_context_kwargs["cfg_diff_compute_separate"] = (
  cfg_diff_compute_separate
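With this release the keyword is spelled enable_separate_cfg (the misspelled enable_spearate_cfg is gone) and taylorseer_order defaults to 1. A usage sketch of the updated call, assuming a pipeline that runs CFG as a separate forward pass (the model id is illustrative):

    import cache_dit
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")  # illustrative model id

    cache_dit.enable_cache(
        pipe,
        residual_diff_threshold=0.08,
        enable_separate_cfg=True,   # renamed keyword; None keeps the registry default
        enable_taylorseer=True,
        taylorseer_order=1,         # new default; 1 or 2 is the recommended range
    )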
@@ -1,4 +1,5 @@
  from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+ from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
  from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
  from cache_dit.cache_factory.patch_functors.functor_chroma import (
  ChromaPatchFunctor,