cache-dit 0.2.33__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.33'
32
- __version_tuple__ = version_tuple = (0, 2, 33)
31
+ __version__ = version = '0.2.34'
32
+ __version_tuple__ = version_tuple = (0, 2, 34)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -153,7 +153,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
153
153
  )
154
154
 
155
155
 
156
- @BlockAdapterRegistry.register("LTXVideo")
156
+ @BlockAdapterRegistry.register("LTX")
157
157
  def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
158
158
  from diffusers import LTXVideoTransformer3DModel
159
159
 
@@ -248,7 +248,10 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
248
248
  pipe=pipe,
249
249
  transformer=pipe.transformer,
250
250
  blocks=pipe.transformer.blocks,
251
- forward_pattern=ForwardPattern.Pattern_2,
251
+ # NOTE: Use Pattern_3 instead of Pattern_2 because the
252
+ # encoder_hidden_states will never change in the blocks
253
+ # forward loop.
254
+ forward_pattern=ForwardPattern.Pattern_3,
252
255
  has_separate_cfg=True,
253
256
  **kwargs,
254
257
  )
@@ -285,6 +288,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
285
288
  @BlockAdapterRegistry.register("DiT")
286
289
  def dit_adapter(pipe, **kwargs) -> BlockAdapter:
287
290
  from diffusers import DiTTransformer2DModel
291
+ from cache_dit.cache_factory.patch_functors import DiTPatchFunctor
288
292
 
289
293
  assert isinstance(pipe.transformer, DiTTransformer2DModel)
290
294
  return BlockAdapter(
@@ -292,6 +296,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
292
296
  transformer=pipe.transformer,
293
297
  blocks=pipe.transformer.transformer_blocks,
294
298
  forward_pattern=ForwardPattern.Pattern_3,
299
+ patch_functor=DiTPatchFunctor(),
295
300
  **kwargs,
296
301
  )
297
302
 
@@ -331,24 +336,13 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
331
336
 
332
337
 
333
338
  @BlockAdapterRegistry.register("Lumina")
334
- def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
335
- from diffusers import LuminaNextDiT2DModel
336
-
337
- assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
338
- return BlockAdapter(
339
- pipe=pipe,
340
- transformer=pipe.transformer,
341
- blocks=pipe.transformer.layers,
342
- forward_pattern=ForwardPattern.Pattern_3,
343
- **kwargs,
344
- )
345
-
346
-
347
- @BlockAdapterRegistry.register("Lumina2")
348
339
  def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
349
340
  from diffusers import Lumina2Transformer2DModel
341
+ from diffusers import LuminaNextDiT2DModel
350
342
 
351
- assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
343
+ assert isinstance(
344
+ pipe.transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel)
345
+ )
352
346
  return BlockAdapter(
353
347
  pipe=pipe,
354
348
  transformer=pipe.transformer,
@@ -386,12 +380,10 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
386
380
  )
387
381
 
388
382
 
389
- @BlockAdapterRegistry.register("Sana", supported=False)
383
+ @BlockAdapterRegistry.register("Sana")
390
384
  def sana_adapter(pipe, **kwargs) -> BlockAdapter:
391
385
  from diffusers import SanaTransformer2DModel
392
386
 
393
- # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
394
-
395
387
  assert isinstance(pipe.transformer, SanaTransformer2DModel)
396
388
  return BlockAdapter(
397
389
  pipe=pipe,
@@ -469,6 +461,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
469
461
  @BlockAdapterRegistry.register("Chroma")
470
462
  def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
471
463
  from diffusers import ChromaTransformer2DModel
464
+ from cache_dit.cache_factory.patch_functors import ChromaPatchFunctor
472
465
 
473
466
  assert isinstance(pipe.transformer, ChromaTransformer2DModel)
474
467
  return BlockAdapter(
@@ -482,6 +475,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
482
475
  ForwardPattern.Pattern_1,
483
476
  ForwardPattern.Pattern_3,
484
477
  ],
478
+ patch_functor=ChromaPatchFunctor(),
485
479
  has_separate_cfg=True,
486
480
  **kwargs,
487
481
  )
@@ -16,8 +16,52 @@ logger = init_logger(__name__)
16
16
 
17
17
 
18
18
  class ParamsModifier:
19
- def __init__(self, **kwargs):
20
- self._context_kwargs = kwargs.copy()
19
+ def __init__(
20
+ self,
21
+ # Cache context kwargs
22
+ Fn_compute_blocks: Optional[int] = None,
23
+ Bn_compute_blocks: Optional[int] = None,
24
+ max_warmup_steps: Optional[int] = None,
25
+ max_cached_steps: Optional[int] = None,
26
+ max_continuous_cached_steps: Optional[int] = None,
27
+ residual_diff_threshold: Optional[float] = None,
28
+ # Cache CFG or not
29
+ enable_separate_cfg: Optional[bool] = None,
30
+ cfg_compute_first: Optional[bool] = None,
31
+ cfg_diff_compute_separate: Optional[bool] = None,
32
 + # Hybrid TaylorSeer
33
+ enable_taylorseer: Optional[bool] = None,
34
+ enable_encoder_taylorseer: Optional[bool] = None,
35
+ taylorseer_cache_type: Optional[str] = None,
36
+ taylorseer_order: Optional[int] = None,
37
+ **other_cache_context_kwargs,
38
+ ):
39
+ self._context_kwargs = other_cache_context_kwargs.copy()
40
+ self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
41
+ self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
42
+ self._maybe_update_param("max_warmup_steps", max_warmup_steps)
43
+ self._maybe_update_param("max_cached_steps", max_cached_steps)
44
+ self._maybe_update_param(
45
+ "max_continuous_cached_steps", max_continuous_cached_steps
46
+ )
47
+ self._maybe_update_param(
48
+ "residual_diff_threshold", residual_diff_threshold
49
+ )
50
+ self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
51
+ self._maybe_update_param("cfg_compute_first", cfg_compute_first)
52
+ self._maybe_update_param(
53
+ "cfg_diff_compute_separate", cfg_diff_compute_separate
54
+ )
55
+ self._maybe_update_param("enable_taylorseer", enable_taylorseer)
56
+ self._maybe_update_param(
57
+ "enable_encoder_taylorseer", enable_encoder_taylorseer
58
+ )
59
+ self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
60
+ self._maybe_update_param("taylorseer_order", taylorseer_order)
61
+
62
+ def _maybe_update_param(self, key: str, value: Any):
63
+ if value is not None:
64
+ self._context_kwargs[key] = value
21
65
 
22
66
 
23
67
  @dataclasses.dataclass
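For reference, each explicit keyword above is copied into the modifier's context kwargs only when a non-None value is passed, so a `ParamsModifier` carries just the settings it names. A minimal usage sketch (values are illustrative; the import path follows the README examples later in this diff):

```python
from cache_dit import ParamsModifier

# Only explicitly passed, non-None values end up in the modifier's
# context kwargs; everything left at None falls back to the base
# cache-context settings.
modifier = ParamsModifier(
    max_warmup_steps=4,
    residual_diff_threshold=0.10,
)
```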
@@ -10,13 +10,14 @@ logger = init_logger(__name__)
10
10
 
11
11
  class BlockAdapterRegistry:
12
12
  _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
13
- _predefined_adapters_has_spearate_cfg: List[str] = [
13
+ _predefined_adapters_has_separate_cfg: List[str] = [
14
14
  "QwenImage",
15
15
  "Wan",
16
16
  "CogView4",
17
17
  "Cosmos",
18
18
  "SkyReelsV2",
19
19
  "Chroma",
20
+ "Lumina2",
20
21
  ]
21
22
 
22
23
  @classmethod
@@ -68,7 +69,7 @@ class BlockAdapterRegistry:
68
69
  return True
69
70
 
70
71
  pipe_cls_name = pipe_or_adapter.__class__.__name__
71
- for name in cls._predefined_adapters_has_spearate_cfg:
72
+ for name in cls._predefined_adapters_has_separate_cfg:
72
73
  if pipe_cls_name.startswith(name):
73
74
  return True
74
75
 
@@ -114,27 +114,27 @@ class CachedAdapter:
114
114
  **cache_context_kwargs,
115
115
  ):
116
116
  # Check cache_context_kwargs
117
- if cache_context_kwargs["enable_spearate_cfg"] is None:
117
+ if cache_context_kwargs["enable_separate_cfg"] is None:
118
118
  # Check cfg for some specific case if users don't set it as True
119
119
  if BlockAdapterRegistry.has_separate_cfg(block_adapter):
120
- cache_context_kwargs["enable_spearate_cfg"] = True
120
+ cache_context_kwargs["enable_separate_cfg"] = True
121
121
  logger.info(
122
- f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
122
+ f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
123
123
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
124
124
  )
125
125
  else:
126
- cache_context_kwargs["enable_spearate_cfg"] = (
126
+ cache_context_kwargs["enable_separate_cfg"] = (
127
127
  BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
128
128
  )
129
129
  logger.info(
130
- f"Use default 'enable_spearate_cfg' from block adapter "
131
- f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
130
+ f"Use default 'enable_separate_cfg' from block adapter "
131
+ f"register: {cache_context_kwargs['enable_separate_cfg']}, "
132
132
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
133
133
  )
134
134
  else:
135
135
  logger.info(
136
- f"Use custom 'enable_spearate_cfg' from cache context "
137
- f"kwargs: {cache_context_kwargs['enable_spearate_cfg']}. "
136
+ f"Use custom 'enable_separate_cfg' from cache context "
137
+ f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
138
138
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
139
139
  )
140
140
 
@@ -53,20 +53,20 @@ class CachedContext: # Internal CachedContext Impl class
53
53
  enable_taylorseer: bool = False
54
54
  enable_encoder_taylorseer: bool = False
55
55
  taylorseer_cache_type: str = "hidden_states" # residual or hidden_states
56
- taylorseer_order: int = 2 # The order for TaylorSeer
56
+ taylorseer_order: int = 1 # The order for TaylorSeer
57
57
  taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
58
58
  taylorseer: Optional[TaylorSeer] = None
59
59
  encoder_tarlorseer: Optional[TaylorSeer] = None
60
60
 
61
- # Support enable_spearate_cfg, such as Wan 2.1,
61
+ # Support enable_separate_cfg, such as Wan 2.1,
62
62
  # Qwen-Image. For model that fused CFG and non-CFG into single
63
- # forward step, should set enable_spearate_cfg as False.
63
+ # forward step, should set enable_separate_cfg as False.
64
64
  # For example: CogVideoX, HunyuanVideo, Mochi.
65
- enable_spearate_cfg: bool = False
65
+ enable_separate_cfg: bool = False
66
66
  # Compute cfg forward first or not, default False, namely,
67
67
  # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
68
68
  cfg_compute_first: bool = False
69
- # Compute spearate diff values for CFG and non-CFG step,
69
+ # Compute separate diff values for CFG and non-CFG step,
70
70
  # default True. If False, we will use the computed diff from
71
71
  # current non-CFG transformer step for current CFG step.
72
72
  cfg_diff_compute_separate: bool = True
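For reference, the step ordering described in the comments above is a plain parity rule over transformer calls. A standalone illustration (the helper is illustrative, not library code):

```python
def is_cfg_call(call_index: int, cfg_compute_first: bool = False) -> bool:
    # cfg_compute_first=False: calls 0, 2, 4, ... are non-CFG and 1, 3, 5, ... are CFG.
    # cfg_compute_first=True: the parity is flipped.
    odd = call_index % 2 == 1
    return (not odd) if cfg_compute_first else odd

assert [is_cfg_call(i) for i in range(4)] == [False, True, False, True]
```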
@@ -89,7 +89,7 @@ class CachedContext: # Internal CachedContext Impl class
89
89
  if logger.isEnabledFor(logging.DEBUG):
90
90
  logger.info(f"Created _CacheContext: {self.name}")
91
91
  # Some checks for settings
92
- if self.enable_spearate_cfg:
92
+ if self.enable_separate_cfg:
93
93
  if self.cfg_diff_compute_separate:
94
94
  assert self.cfg_compute_first is False, (
95
95
  "cfg_compute_first must set as False if "
@@ -108,12 +108,12 @@ class CachedContext: # Internal CachedContext Impl class
108
108
 
109
109
  if self.enable_taylorseer:
110
110
  self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
111
- if self.enable_spearate_cfg:
111
+ if self.enable_separate_cfg:
112
112
  self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
113
113
 
114
114
  if self.enable_encoder_taylorseer:
115
115
  self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
116
- if self.enable_spearate_cfg:
116
+ if self.enable_separate_cfg:
117
117
  self.cfg_encoder_taylorseer = TaylorSeer(
118
118
  **self.taylorseer_kwargs
119
119
  )
@@ -145,7 +145,7 @@ class CachedContext: # Internal CachedContext Impl class
145
145
  # incr step: prev 0 -> 1; prev 1 -> 2
146
146
  # current step: incr step - 1
147
147
  self.transformer_executed_steps += 1
148
- if not self.enable_spearate_cfg:
148
+ if not self.enable_separate_cfg:
149
149
  self.executed_steps += 1
150
150
  else:
151
151
  # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
@@ -183,7 +183,7 @@ class CachedContext: # Internal CachedContext Impl class
183
183
 
184
184
  # mark_step_begin of TaylorSeer must be called after the cache is reset.
185
185
  if self.enable_taylorseer or self.enable_encoder_taylorseer:
186
- if self.enable_spearate_cfg:
186
+ if self.enable_separate_cfg:
187
187
  # Assume non-CFG steps: 0, 2, 4, 6, ...
188
188
  if not self.is_separate_cfg_step():
189
189
  taylorseer, encoder_taylorseer = self.get_taylorseers()
@@ -269,7 +269,7 @@ class CachedContext: # Internal CachedContext Impl class
269
269
  return self.transformer_executed_steps - 1
270
270
 
271
271
  def is_separate_cfg_step(self):
272
- if not self.enable_spearate_cfg:
272
+ if not self.enable_separate_cfg:
273
273
  return False
274
274
  if self.cfg_compute_first:
275
275
  # CFG steps: 0, 2, 4, 6, ...
@@ -74,8 +74,8 @@ class CachedContextManager:
74
74
  del self._cached_context_manager[cached_context]
75
75
 
76
76
  def clear_contexts(self):
77
- for cached_context in self._cached_context_manager:
78
- self.remove_context(cached_context)
77
+ for context_name in list(self._cached_context_manager.keys()):
78
+ self.remove_context(context_name)
79
79
 
80
80
  @contextlib.contextmanager
81
81
  def enter_context(self, cached_context: CachedContext | str):
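The `clear_contexts` change above avoids deleting entries from the manager's dict while iterating over it, which raises `RuntimeError: dictionary changed size during iteration`. A standalone sketch of the failure mode and the fix (names are illustrative):

```python
contexts = {"ctx_a": object(), "ctx_b": object()}

# Old pattern (raises RuntimeError: dictionary changed size during iteration):
# for name in contexts:
#     del contexts[name]

# New pattern: snapshot the keys first, then remove each entry safely.
for name in list(contexts.keys()):
    del contexts[name]

assert not contexts
```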
@@ -364,10 +364,10 @@ class CachedContextManager:
364
364
  return cached_context.Bn_compute_blocks
365
365
 
366
366
  @torch.compiler.disable
367
- def enable_spearate_cfg(self) -> bool:
367
+ def enable_separate_cfg(self) -> bool:
368
368
  cached_context = self.get_context()
369
369
  assert cached_context is not None, "cached_context must be set before"
370
- return cached_context.enable_spearate_cfg
370
+ return cached_context.enable_separate_cfg
371
371
 
372
372
  @torch.compiler.disable
373
373
  def is_separate_cfg_step(self) -> bool:
@@ -410,7 +410,7 @@ class CachedContextManager:
410
410
 
411
411
  if all(
412
412
  (
413
- self.enable_spearate_cfg(),
413
+ self.enable_separate_cfg(),
414
414
  self.is_separate_cfg_step(),
415
415
  not self.cfg_diff_compute_separate(),
416
416
  self.get_current_step_residual_diff() is not None,
@@ -1,4 +1,6 @@
1
1
  import math
2
+ import torch
3
+ from typing import List, Dict
2
4
 
3
5
 
4
6
  class TaylorSeer:
@@ -17,7 +19,7 @@ class TaylorSeer:
17
19
  self.reset_cache()
18
20
 
19
21
  def reset_cache(self):
20
- self.state = {
22
+ self.state: Dict[str, List[torch.Tensor]] = {
21
23
  "dY_prev": [None] * self.ORDER,
22
24
  "dY_current": [None] * self.ORDER,
23
25
  }
@@ -36,15 +38,19 @@ class TaylorSeer:
36
38
  return True
37
39
  return False
38
40
 
39
- def approximate_derivative(self, Y):
41
+ def approximate_derivative(self, Y: torch.Tensor) -> List[torch.Tensor]:
40
42
  # n-th order Taylor expansion:
41
43
  # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
42
44
  # + ... + d^nY(0)/dt^n * t^n / n!
43
45
  # TODO: Custom Triton/CUDA kernel for better performance,
44
46
  # especially for large n_derivatives.
45
- dY_current = [None] * self.ORDER
47
+ dY_current: List[torch.Tensor] = [None] * self.ORDER
46
48
  dY_current[0] = Y
47
49
  window = self.current_step - self.last_non_approximated_step
50
+ if self.state["dY_prev"][0] is not None:
51
+ if dY_current[0].shape != self.state["dY_prev"][0].shape:
52
+ self.reset_cache()
53
+
48
54
  for i in range(self.n_derivatives):
49
55
  if self.state["dY_prev"][i] is not None and self.current_step > 1:
50
56
  dY_current[i + 1] = (
@@ -54,7 +60,7 @@ class TaylorSeer:
54
60
  break
55
61
  return dY_current
56
62
 
57
- def approximate_value(self):
63
+ def approximate_value(self) -> torch.Tensor:
58
64
  # TODO: Custom Triton/CUDA kernel for better performance,
59
65
  # especially for large n_derivatives.
60
66
  elapsed = self.current_step - self.last_non_approximated_step
@@ -69,7 +75,7 @@ class TaylorSeer:
69
75
  def mark_step_begin(self):
70
76
  self.current_step += 1
71
77
 
72
- def update(self, Y):
78
+ def update(self, Y: torch.Tensor):
73
79
  # Directly calling this method will ignore the warmup
74
80
  # policy and force full computation.
75
81
  # Assume warmup steps is 3, and n_derivatives is 3.
@@ -87,7 +93,7 @@ class TaylorSeer:
87
93
  self.state["dY_current"] = self.approximate_derivative(Y)
88
94
  self.last_non_approximated_step = self.current_step
89
95
 
90
- def step(self, Y):
96
+ def step(self, Y: torch.Tensor):
91
97
  self.mark_step_begin()
92
98
  if self.should_compute_full():
93
99
  self.update(Y)
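For reference, the approximation implemented here is the plain Taylor expansion described in the comments of `approximate_derivative`: finite-difference derivatives are estimated from consecutive fully computed steps and reused to extrapolate the skipped steps. A self-contained toy version of the extrapolation (not the library's exact code):

```python
import math
from typing import List, Optional

import torch


def taylor_extrapolate(dY: List[Optional[torch.Tensor]], elapsed: int) -> torch.Tensor:
    # dY[0] is the last fully computed output and dY[i] the i-th order
    # finite-difference derivative; predict the value `elapsed` steps later.
    out = torch.zeros_like(dY[0])
    for i, d in enumerate(dY):
        if d is None:
            break
        out = out + d * (elapsed**i) / math.factorial(i)
    return out
```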
@@ -24,14 +24,14 @@ def enable_cache(
24
24
  max_continuous_cached_steps: int = -1,
25
25
  residual_diff_threshold: float = 0.08,
26
26
  # Cache CFG or not
27
- enable_spearate_cfg: bool | None = None,
27
+ enable_separate_cfg: bool = None,
28
28
  cfg_compute_first: bool = False,
29
29
  cfg_diff_compute_separate: bool = True,
30
30
  # Hybrid TaylorSeer
31
31
  enable_taylorseer: bool = False,
32
32
  enable_encoder_taylorseer: bool = False,
33
33
  taylorseer_cache_type: str = "residual",
34
- taylorseer_order: int = 2,
34
+ taylorseer_order: int = 1,
35
35
  **other_cache_context_kwargs,
36
36
  ) -> Union[
37
37
  DiffusionPipeline,
@@ -70,15 +70,15 @@ def enable_cache(
70
70
  residual_diff_threshold (`float`, *required*, defaults to 0.08):
71
71
  The value of the residual diff threshold; a higher value leads to faster performance at the
72
72
  cost of lower precision.
73
- enable_spearate_cfg (`bool`, *required*, defaults to None):
73
+ enable_separate_cfg (`bool`, *required*, defaults to None):
74
74
  Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
75
- and non-CFG into single forward step, should set enable_spearate_cfg as False, for example:
75
+ and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
76
76
  CogVideoX, HunyuanVideo, Mochi, etc.
77
77
  cfg_compute_first (`bool`, *required*, defaults to False):
78
78
  Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
79
79
  1, 3, 5, ... -> CFG step.
80
80
  cfg_diff_compute_separate (`bool`, *required*, defaults to True):
81
- Compute spearate diff values for CFG and non-CFG step, default True. If False, we will
81
+ Compute separate diff values for CFG and non-CFG step, default True. If False, we will
82
82
  use the computed diff from current non-CFG transformer step for current CFG step.
83
83
  enable_taylorseer (`bool`, *required*, defaults to False):
84
84
  Enable the hybrid TaylorSeer for hidden_states or not. We have supported the
@@ -91,10 +91,10 @@ def enable_cache(
91
91
  Enable the hybrid TaylorSeer for encoder_hidden_states or not.
92
92
  taylorseer_cache_type (`str`, *required*, defaults to `residual`):
93
93
  The TaylorSeer implemented in cache-dit supports both `hidden_states` and `residual` as cache type.
94
- taylorseer_order (`int`, *required*, defaults to 2):
94
+ taylorseer_order (`int`, *required*, defaults to 1):
95
95
  The order of taylorseer, higher values of n_derivatives will lead to longer computation time,
96
- but may improve precision significantly.
97
- other_cache_kwargs: (`dict`, *optional*, defaults to {})
96
+ the recommended value is 1 or 2.
97
+ other_cache_context_kwargs: (`dict`, *optional*, defaults to {})
98
98
  Other cache context kwargs, please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/cache_contexts/cache_context.py
99
99
  for more details.
100
100
 
@@ -123,7 +123,7 @@ def enable_cache(
123
123
  max_continuous_cached_steps
124
124
  )
125
125
  cache_context_kwargs["residual_diff_threshold"] = residual_diff_threshold
126
- cache_context_kwargs["enable_spearate_cfg"] = enable_spearate_cfg
126
+ cache_context_kwargs["enable_separate_cfg"] = enable_separate_cfg
127
127
  cache_context_kwargs["cfg_compute_first"] = cfg_compute_first
128
128
  cache_context_kwargs["cfg_diff_compute_separate"] = (
129
129
  cfg_diff_compute_separate
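A minimal call sketch reflecting the renamed arguments above, assuming the first positional argument is the pipeline or BlockAdapter as in the README examples (the checkpoint and values are illustrative placeholders; only the keyword names come from the signature):

```python
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/diffusion-transformer")  # placeholder checkpoint

cache_dit.enable_cache(
    pipe,
    residual_diff_threshold=0.08,
    enable_separate_cfg=None,  # None: fall back to the registered BlockAdapter's default
    enable_taylorseer=True,
    taylorseer_order=1,        # new default; 1 or 2 is recommended
)
```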
@@ -1,4 +1,5 @@
1
1
  from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
2
+ from cache_dit.cache_factory.patch_functors.functor_dit import DiTPatchFunctor
2
3
  from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
3
4
  from cache_dit.cache_factory.patch_functors.functor_chroma import (
4
5
  ChromaPatchFunctor,
@@ -1,10 +1,9 @@
1
- import inspect
2
-
3
1
  import torch
4
2
  import numpy as np
5
3
  from typing import Tuple, Optional, Dict, Any, Union
6
4
  from diffusers import ChromaTransformer2DModel
7
5
  from diffusers.models.transformers.transformer_chroma import (
6
+ ChromaTransformerBlock,
8
7
  ChromaSingleTransformerBlock,
9
8
  Transformer2DModelOutput,
10
9
  )
@@ -27,24 +26,31 @@ class ChromaPatchFunctor(PatchFunctor):
27
26
  def apply(
28
27
  self,
29
28
  transformer: ChromaTransformer2DModel,
30
- blocks: torch.nn.ModuleList = None,
31
29
  **kwargs,
32
30
  ) -> ChromaTransformer2DModel:
33
31
  if hasattr(transformer, "_is_patched"):
34
32
  return transformer
35
33
 
36
- if blocks is None:
37
- blocks = transformer.single_transformer_blocks
38
-
39
34
  is_patched = False
40
- for block in blocks:
41
- if isinstance(block, ChromaSingleTransformerBlock):
42
- forward_parameters = inspect.signature(
43
- block.forward
44
- ).parameters.keys()
45
- if "encoder_hidden_states" not in forward_parameters:
46
- block.forward = __patch_single_forward__.__get__(block)
47
- is_patched = True
35
+ for index_block, block in enumerate(transformer.transformer_blocks):
36
+ assert isinstance(block, ChromaTransformerBlock)
37
+ img_offset = 3 * len(transformer.single_transformer_blocks)
38
+ txt_offset = img_offset + 6 * len(transformer.transformer_blocks)
39
+ img_modulation = img_offset + 6 * index_block
40
+ text_modulation = txt_offset + 6 * index_block
41
+ block._img_modulation = img_modulation
42
+ block._text_modulation = text_modulation
43
+ block.forward = __patch_double_forward__.__get__(block)
44
+
45
+ for index_block, block in enumerate(
46
+ transformer.single_transformer_blocks
47
+ ):
48
+ assert isinstance(block, ChromaSingleTransformerBlock)
49
+ start_idx = 3 * index_block
50
+ block._start_idx = start_idx
51
+ block.forward = __patch_single_forward__.__get__(block)
52
+
53
+ is_patched = True
48
54
 
49
55
  cls_name = transformer.__class__.__name__
50
56
 
@@ -69,25 +75,123 @@ class ChromaPatchFunctor(PatchFunctor):
69
75
  return transformer
70
76
 
71
77
 
78
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
79
+ def __patch_double_forward__(
80
+ self: ChromaTransformerBlock,
81
+ hidden_states: torch.Tensor,
82
+ encoder_hidden_states: torch.Tensor,
83
+ pooled_temb: torch.Tensor,
84
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
85
+ attention_mask: Optional[torch.Tensor] = None,
86
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
87
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
88
+ # TODO: Fuse controlnet into block forward
89
+ img_modulation = self._img_modulation
90
+ text_modulation = self._text_modulation
91
+ temb = torch.cat(
92
+ (
93
+ pooled_temb[:, img_modulation : img_modulation + 6],
94
+ pooled_temb[:, text_modulation : text_modulation + 6],
95
+ ),
96
+ dim=1,
97
+ )
98
+
99
+ temb_img, temb_txt = temb[:, :6], temb[:, 6:]
100
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
101
+ hidden_states, emb=temb_img
102
+ )
103
+
104
+ (
105
+ norm_encoder_hidden_states,
106
+ c_gate_msa,
107
+ c_shift_mlp,
108
+ c_scale_mlp,
109
+ c_gate_mlp,
110
+ ) = self.norm1_context(encoder_hidden_states, emb=temb_txt)
111
+ joint_attention_kwargs = joint_attention_kwargs or {}
112
+ if attention_mask is not None:
113
+ attention_mask = (
114
+ attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
115
+ )
116
+
117
+ # Attention.
118
+ attention_outputs = self.attn(
119
+ hidden_states=norm_hidden_states,
120
+ encoder_hidden_states=norm_encoder_hidden_states,
121
+ image_rotary_emb=image_rotary_emb,
122
+ attention_mask=attention_mask,
123
+ **joint_attention_kwargs,
124
+ )
125
+
126
+ if len(attention_outputs) == 2:
127
+ attn_output, context_attn_output = attention_outputs
128
+ elif len(attention_outputs) == 3:
129
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
130
+
131
+ # Process attention outputs for the `hidden_states`.
132
+ attn_output = gate_msa.unsqueeze(1) * attn_output
133
+ hidden_states = hidden_states + attn_output
134
+
135
+ norm_hidden_states = self.norm2(hidden_states)
136
+ norm_hidden_states = (
137
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
138
+ )
139
+
140
+ ff_output = self.ff(norm_hidden_states)
141
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
142
+
143
+ hidden_states = hidden_states + ff_output
144
+ if len(attention_outputs) == 3:
145
+ hidden_states = hidden_states + ip_attn_output
146
+
147
+ # Process attention outputs for the `encoder_hidden_states`.
148
+
149
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
150
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
151
+
152
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
153
+ norm_encoder_hidden_states = (
154
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
155
+ + c_shift_mlp[:, None]
156
+ )
157
+
158
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
159
+ encoder_hidden_states = (
160
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
161
+ )
162
+ if encoder_hidden_states.dtype == torch.float16:
163
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
164
+
165
+ return encoder_hidden_states, hidden_states
166
+
167
+
72
168
  # adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
73
169
  def __patch_single_forward__(
74
170
  self: ChromaSingleTransformerBlock, # Almost same as FluxSingleTransformerBlock
75
171
  hidden_states: torch.Tensor,
76
- encoder_hidden_states: torch.Tensor,
77
- temb: torch.Tensor,
172
+ pooled_temb: torch.Tensor,
78
173
  image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
174
+ attention_mask: Optional[torch.Tensor] = None,
79
175
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
80
- ) -> Tuple[torch.Tensor, torch.Tensor]:
81
- text_seq_len = encoder_hidden_states.shape[1]
82
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
176
+ ) -> torch.Tensor:
177
+ # TODO: Fuse controlnet into block forward
178
+ start_idx = self._start_idx
179
+ temb = pooled_temb[:, start_idx : start_idx + 3]
83
180
 
84
181
  residual = hidden_states
85
182
  norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
86
183
  mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
87
184
  joint_attention_kwargs = joint_attention_kwargs or {}
185
+
186
+ if attention_mask is not None:
187
+ attention_mask = (
188
+ attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
189
+ )
190
+
88
191
  attn_output = self.attn(
89
192
  hidden_states=norm_hidden_states,
90
193
  image_rotary_emb=image_rotary_emb,
194
+ attention_mask=attention_mask,
91
195
  **joint_attention_kwargs,
92
196
  )
93
197
 
@@ -98,11 +202,7 @@ def __patch_single_forward__(
98
202
  if hidden_states.dtype == torch.float16:
99
203
  hidden_states = hidden_states.clip(-65504, 65504)
100
204
 
101
- encoder_hidden_states, hidden_states = (
102
- hidden_states[:, :text_seq_len],
103
- hidden_states[:, text_seq_len:],
104
- )
105
- return encoder_hidden_states, hidden_states
205
+ return hidden_states
106
206
 
107
207
 
108
208
  # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
@@ -174,24 +274,13 @@ def __patch_transformer_forward__(
174
274
  joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
175
275
 
176
276
  for index_block, block in enumerate(self.transformer_blocks):
177
- img_offset = 3 * len(self.single_transformer_blocks)
178
- txt_offset = img_offset + 6 * len(self.transformer_blocks)
179
- img_modulation = img_offset + 6 * index_block
180
- text_modulation = txt_offset + 6 * index_block
181
- temb = torch.cat(
182
- (
183
- pooled_temb[:, img_modulation : img_modulation + 6],
184
- pooled_temb[:, text_modulation : text_modulation + 6],
185
- ),
186
- dim=1,
187
- )
188
277
  if torch.is_grad_enabled() and self.gradient_checkpointing:
189
278
  encoder_hidden_states, hidden_states = (
190
279
  self._gradient_checkpointing_func(
191
280
  block,
192
281
  hidden_states,
193
282
  encoder_hidden_states,
194
- temb,
283
+ pooled_temb,
195
284
  image_rotary_emb,
196
285
  attention_mask,
197
286
  )
@@ -201,12 +290,13 @@ def __patch_transformer_forward__(
201
290
  encoder_hidden_states, hidden_states = block(
202
291
  hidden_states=hidden_states,
203
292
  encoder_hidden_states=encoder_hidden_states,
204
- temb=temb,
293
+ pooled_temb=pooled_temb,
205
294
  image_rotary_emb=image_rotary_emb,
206
295
  attention_mask=attention_mask,
207
296
  joint_attention_kwargs=joint_attention_kwargs,
208
297
  )
209
298
 
299
+ # TODO: Fuse controlnet into block forward
210
300
  # controlnet residual
211
301
  if controlnet_block_samples is not None:
212
302
  interval_control = len(self.transformer_blocks) / len(
@@ -227,43 +317,43 @@ def __patch_transformer_forward__(
227
317
  + controlnet_block_samples[index_block // interval_control]
228
318
  )
229
319
 
320
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
321
+
230
322
  for index_block, block in enumerate(self.single_transformer_blocks):
231
- start_idx = 3 * index_block
232
- temb = pooled_temb[:, start_idx : start_idx + 3]
233
323
  if torch.is_grad_enabled() and self.gradient_checkpointing:
234
- encoder_hidden_states, hidden_states = (
235
- self._gradient_checkpointing_func(
236
- block,
237
- hidden_states,
238
- encoder_hidden_states,
239
- temb,
240
- image_rotary_emb,
241
- )
324
+ hidden_states = self._gradient_checkpointing_func(
325
+ block,
326
+ hidden_states,
327
+ pooled_temb,
328
+ image_rotary_emb,
329
+ attention_mask,
330
+ joint_attention_kwargs,
242
331
  )
243
332
 
244
333
  else:
245
- encoder_hidden_states, hidden_states = block(
334
+ hidden_states = block(
246
335
  hidden_states=hidden_states,
247
- encoder_hidden_states=encoder_hidden_states,
248
- temb=temb,
336
+ pooled_temb=pooled_temb,
249
337
  image_rotary_emb=image_rotary_emb,
250
338
  attention_mask=attention_mask,
251
339
  joint_attention_kwargs=joint_attention_kwargs,
252
340
  )
253
341
 
342
+ # TODO: Fuse controlnet into block forward
254
343
  # controlnet residual
255
344
  if controlnet_single_block_samples is not None:
256
345
  interval_control = len(self.single_transformer_blocks) / len(
257
346
  controlnet_single_block_samples
258
347
  )
259
348
  interval_control = int(np.ceil(interval_control))
260
- hidden_states = (
261
- hidden_states
349
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
350
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
262
351
  + controlnet_single_block_samples[
263
352
  index_block // interval_control
264
353
  ]
265
354
  )
266
355
 
356
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
267
357
  temb = pooled_temb[:, -2:]
268
358
  hidden_states = self.norm_out(hidden_states, temb)
269
359
  output = self.proj_out(hidden_states)
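For reference, the modulation bookkeeping moved into `apply()` above just slices the pooled conditioning tensor per block. A worked example with illustrative block counts (not Chroma's real configuration):

```python
# Illustrative counts only.
num_double_blocks = 2   # stands in for len(transformer.transformer_blocks)
num_single_blocks = 3   # stands in for len(transformer.single_transformer_blocks)

img_offset = 3 * num_single_blocks                # 9
txt_offset = img_offset + 6 * num_double_blocks   # 21

# The double block with index 1 reads pooled_temb[:, 15:21] for image
# modulation and pooled_temb[:, 27:33] for text modulation.
index_block = 1
img_modulation = img_offset + 6 * index_block     # 15
text_modulation = txt_offset + 6 * index_block    # 27

# The single block with index 2 reads pooled_temb[:, 6:9].
start_idx = 3 * 2                                 # 6
```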
@@ -0,0 +1,130 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from typing import Optional, Dict, Any
5
+ from diffusers.models.transformers.dit_transformer_2d import (
6
+ DiTTransformer2DModel,
7
+ Transformer2DModelOutput,
8
+ )
9
+ from cache_dit.cache_factory.patch_functors.functor_base import (
10
+ PatchFunctor,
11
+ )
12
+ from cache_dit.logger import init_logger
13
+
14
+ logger = init_logger(__name__)
15
+
16
+
17
+ class DiTPatchFunctor(PatchFunctor):
18
+
19
+ def apply(
20
+ self,
21
+ transformer: DiTTransformer2DModel,
22
+ **kwargs,
23
+ ) -> DiTTransformer2DModel:
24
+ if hasattr(transformer, "_is_patched"):
25
+ return transformer
26
+
27
+ is_patched = False
28
+
29
+ transformer._norm1_emb = transformer.transformer_blocks[0].norm1.emb
30
+
31
+ is_patched = True
32
+
33
+ cls_name = transformer.__class__.__name__
34
+
35
+ if is_patched:
36
+ logger.warning(f"Patched {cls_name} for cache-dit.")
37
+ assert not getattr(transformer, "_is_parallelized", False), (
38
+ "Please call `cache_dit.enable_cache` before Parallelize, "
39
+ "the __patch_transformer_forward__ will overwrite the "
40
+ "parallized forward and cause a downgrade of performance."
41
+ )
42
+ transformer.forward = __patch_transformer_forward__.__get__(
43
+ transformer
44
+ )
45
+
46
+ transformer._is_patched = is_patched # True or False
47
+
48
+ logger.info(
49
+ f"Applied {self.__class__.__name__} for {cls_name}, "
50
+ f"Patch: {is_patched}."
51
+ )
52
+
53
+ return transformer
54
+
55
+
56
+ def __patch_transformer_forward__(
57
+ self: DiTTransformer2DModel,
58
+ hidden_states: torch.Tensor,
59
+ timestep: Optional[torch.LongTensor] = None,
60
+ class_labels: Optional[torch.LongTensor] = None,
61
+ cross_attention_kwargs: Dict[str, Any] = None,
62
+ return_dict: bool = True,
63
+ ):
64
+ height, width = (
65
+ hidden_states.shape[-2] // self.patch_size,
66
+ hidden_states.shape[-1] // self.patch_size,
67
+ )
68
+ hidden_states = self.pos_embed(hidden_states)
69
+
70
+ # 2. Blocks
71
+ for block in self.transformer_blocks:
72
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
73
+ hidden_states = self._gradient_checkpointing_func(
74
+ block,
75
+ hidden_states,
76
+ None,
77
+ None,
78
+ None,
79
+ timestep,
80
+ cross_attention_kwargs,
81
+ class_labels,
82
+ )
83
+ else:
84
+ hidden_states = block(
85
+ hidden_states,
86
+ attention_mask=None,
87
+ encoder_hidden_states=None,
88
+ encoder_attention_mask=None,
89
+ timestep=timestep,
90
+ cross_attention_kwargs=cross_attention_kwargs,
91
+ class_labels=class_labels,
92
+ )
93
+
94
+ # 3. Output
95
+ # conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
96
+ conditioning = self._norm1_emb(
97
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
98
+ )
99
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
100
+ hidden_states = (
101
+ self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
102
+ )
103
+ hidden_states = self.proj_out_2(hidden_states)
104
+
105
+ # unpatchify
106
+ height = width = int(hidden_states.shape[1] ** 0.5)
107
+ hidden_states = hidden_states.reshape(
108
+ shape=(
109
+ -1,
110
+ height,
111
+ width,
112
+ self.patch_size,
113
+ self.patch_size,
114
+ self.out_channels,
115
+ )
116
+ )
117
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
118
+ output = hidden_states.reshape(
119
+ shape=(
120
+ -1,
121
+ self.out_channels,
122
+ height * self.patch_size,
123
+ width * self.patch_size,
124
+ )
125
+ )
126
+
127
+ if not return_dict:
128
+ return (output,)
129
+
130
+ return Transformer2DModelOutput(sample=output)
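For reference, the unpatchify step at the end of the patched forward reshapes (N, H·W, p·p·C) tokens back into an (N, C, H·p, W·p) image. A tiny shape check mirroring that reshape/einsum (sizes are illustrative):

```python
import torch

N, H, W, p, C = 2, 16, 16, 2, 4            # illustrative sizes
tokens = torch.randn(N, H * W, p * p * C)  # output of proj_out_2 before unpatchify

x = tokens.reshape(-1, H, W, p, p, C)
x = torch.einsum("nhwpqc->nchpwq", x)
image = x.reshape(-1, C, H * p, W * p)
assert image.shape == (N, C, H * p, W * p)
```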
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.33
3
+ Version: 0.2.34
4
4
  Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -60,40 +60,97 @@ Dynamic: requires-python
60
60
  </p>
61
61
  <p align="center">
62
62
  🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
63
- 🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1/2.2 </a>🔥<br>
64
- 🔥<a href="#supported">HunyuanImage-2.1</a> | <a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a>🔥<br>
65
- 🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">CogVideoX 1/1.5</a>🔥<br>
66
- 🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported"> OmniGen </a> | <a href="#supported">Lumina 1/2</a>🔥<br>
67
- 🔥<a href="#supported">Allegro</a> | <a href="#supported">EasyAnimate</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported"> ... </a> | <a href="#supported">PixArt</a>🔥
63
+ 🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1 </a> | <a href="#supported"> Wan 2.2 </a>🔥<br>
64
+ 🔥<a href="#supported">HunyuanImage-2.1</a> | <a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">AuraFlow</a>🔥<br>
65
+ 🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">LTXVideo</a> | <a href="#supported">CogVideoX</a> | <a href="#supported">CogVideoX 1.5</a> | <a href="#supported">ConsisID</a>🔥<br>
66
+ 🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported">OmniGen 1/2</a> | <a href="#supported">Lumina 1/2</a> | <a href="#supported">PixArt</a>🔥<br>
67
+ 🔥<a href="#supported">Chroma</a> | <a href="#supported">Sana</a> | <a href="#supported">Allegro</a> | <a href="#supported">Mochi</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported">Amused</a> | <a href="#supported"> ... </a> | <a href="#supported">DiT-XL</a>🔥
68
68
  </p>
69
69
  </div>
70
70
  <div align='center'>
71
71
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=124px>
72
72
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=124px>
73
- <img src=./assets/gifs/hunyuan_video.C0_L0_Q0_NONE.gif width=126px>
74
- <img src=./assets/gifs/hunyuan_video.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S27.gif width=126px>
73
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/hunyuan_video.C0_L0_Q0_NONE.gif width=126px>
74
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/hunyuan_video.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S27.gif width=126px>
75
75
  <p><b>🔥Wan2.2 MoE</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.0x↑🎉 | <b>HunyuanVideo</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
76
76
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
77
77
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
78
- <img src=./assets/flux.C0_Q0_NONE_T23.69s.png width=90px>
79
- <img src=./assets/flux.C0_Q0_DBCACHE_F1B0_W4M0MC0_T1O2_R0.15_S16_T11.39s.png width=90px>
78
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux.C0_Q0_NONE_T23.69s.png width=90px>
79
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux.C0_Q0_DBCACHE_F1B0_W4M0MC0_T1O2_R0.15_S16_T11.39s.png width=90px>
80
80
  <p><b>🔥Qwen-Image</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>FLUX.1-dev</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
81
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext-cat.C0_L0_Q0_NONE.png width=100px>
82
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_NONE.png width=100px>
83
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S10.png width=100px>
84
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S12.png width=100px>
85
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/flux-kontext.C0_L0_Q0_DBCACHE_F1B0_W2M0MC2_T0O2_R0.15_S15.png width=100px>
86
+ <p><b>🔥FLUX-Kontext-dev</b> | Baseline | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.3x↑🎉 | 1.7x↑🎉 | 2.0x↑ 🎉</p>
81
87
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_NONE.png width=160px>
82
88
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_DBCACHE_F16B16_W2M1MC1_T0O2_R0.9_S1.png width=160px>
83
- <img src=./assets/sd_3_5.C0_L0_Q0_NONE.png width=90px>
84
- <img src=./assets/sd_3_5.C0_L0_Q0_DBCACHE_F1B0_W8M0MC3_T0O2_R0.12_S30.png width=90px>
85
- <p><b>🔥Qwen-Image-Lightning</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.14x↑🎉 | <b>SD 3.5</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.5x↑🎉</p>
86
- <img src=./assets/hidream.C0_L0_Q0_NONE.png width=100px>
87
- <img src=./assets/hidream.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S24.png width=100px>
88
- <img src=./assets/cogview4.C0_L0_Q0_NONE.png width=100px>
89
- <img src=./assets/cogview4.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
90
- <img src=./assets/cogview4.C0_L0_Q0_DBCACHE_F1B0_W4M0MC4_T0O2_R0.2_S22.png width=100px>
89
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_NONE.png width=90px>
90
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_DBCACHE_F8B0_W8M0MC2_T1O2_R0.12_S25.png width=90px>
91
+ <p><b>🔥Qwen...Lightning</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.14x↑🎉 | <b>HunyuanImage</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.7x↑🎉</p>
92
+ <img src=https://github.com/vipshop/cache-dit/raw/main/examples/data/bear.png width=125px>
93
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_NONE.png width=125px>
94
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S18.png width=125px>
95
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-edit.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S24.png width=125px>
96
+ <p><b>🔥Qwen-Image-Edit</b> | Input w/o Edit | Baseline | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.6x↑🎉 | 1.9x↑🎉 </p>
97
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/hidream.C0_L0_Q0_NONE.png width=100px>
98
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/hidream.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S24.png width=100px>
99
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview4.C0_L0_Q0_NONE.png width=100px>
100
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview4.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
101
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview4.C0_L0_Q0_DBCACHE_F1B0_W4M0MC4_T0O2_R0.2_S22.png width=100px>
91
102
  <p><b>🔥HiDream-I1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉 | <b>CogView4</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.4x↑🎉 | 1.7x↑🎉</p>
92
- <img src=./assets/gifs/mochi.C0_L0_Q0_NONE.gif width=160px>
93
- <img src=./assets/gifs/mochi.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S34.gif width=160px>
94
- <img src=./assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_NONE.png width=91px>
95
- <img src=./assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_DBCACHE_F8B0_W8M0MC2_T1O2_R0.12_S25.png width=91px>
96
- <p><b>🔥Mochi-1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>HunyuanImage-2.1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.7x↑🎉
103
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview3_plus.C0_L0_Q0_NONE.png width=100px>
104
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview3_plus.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
105
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cogview3_plus.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.08_S25.png width=100px>
106
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/chroma1-hd.C0_L0_Q0_NONE.png width=100px>
107
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/chroma1-hd.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S20.png width=100px>
108
+ <p><b>🔥CogView3</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.5x↑🎉 | 2.0x↑🎉| <b>Chroma1-HD</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉</p>
109
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/mochi.C0_L0_Q0_NONE.gif width=125px>
110
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/mochi.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S34.gif width=125px>
111
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/skyreels_v2.C0_L0_Q0_NONE.gif width=125px>
112
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/skyreels_v2.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.12_S17.gif width=125px>
113
+ <p><b>🔥Mochi-1-preview</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>SkyReelsV2</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.6x↑🎉</p>
114
+ <img src=./examples/data/visualcloze/00555_00.jpg width=100px>
115
+ <img src=./examples/data/visualcloze/12265_00.jpg width=100px>
116
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/visualcloze-512.C0_L0_Q0_NONE.png width=100px>
117
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/visualcloze-512.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
118
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/visualcloze-512.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S18.png width=100px>
119
+ <p><b>🔥VisualCloze-512</b> | Model | Cloth | Baseline | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.4x↑🎉 | 1.7x↑🎉 </p>
120
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/ltx-video.C0_L0_Q0_NONE.gif width=144px>
121
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/ltx-video.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.15_S13.gif width=144px>
122
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/cogvideox1.5.C0_L0_Q0_NONE.gif width=105px>
123
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/cogvideox1.5.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.12_S22.gif width=105px>
124
+ <p><b>🔥LTX-Video-0.9.7</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.7x↑🎉 | <b>CogVideoX1.5</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.0x↑🎉</p>
125
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/omingen-v1.C0_L0_Q0_NONE.png width=100px>
126
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/omingen-v1.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S24.png width=100px>
127
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/omingen-v1.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T1O2_R0.08_S38.png width=100px>
128
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/lumina2.C0_L0_Q0_NONE.png width=100px>
129
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/lumina2.C0_L0_Q0_DBCACHE_F1B0_W2M0MC2_T0O2_R0.12_S14.png width=100px>
130
+ <p><b>🔥OmniGen-v1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.5x↑🎉 | 3.3x↑🎉 | <b>Lumina2</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉</p>
131
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/allegro.C0_L0_Q0_NONE.gif width=117px>
132
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/allegro.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.26_S27.gif width=117px>
133
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/auraflow.C0_L0_Q0_NONE.png width=133px>
134
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/auraflow.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.08_S28.png width=133px>
135
+ <p><b>🔥Allegro</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.36x↑🎉 | <b>AuraFlow-v0.3</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.27x↑🎉 </p>
136
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/sana.C0_L0_Q0_NONE.png width=100px>
137
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/sana.C0_L0_Q0_DBCACHE_F8B0_W8M0MC2_T0O2_R0.25_S6.png width=100px>
138
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/sana.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.3_S8.png width=100px>
139
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-sigma.C0_L0_Q0_NONE.png width=100px>
140
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-sigma.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S28.png width=100px>
141
+ <p><b>🔥Sana</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.3x↑🎉 | 1.6x↑🎉| <b>PixArt-Sigma</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.3x↑🎉</p>
142
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-alpha.C0_L0_Q0_NONE.png width=100px>
143
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-alpha.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.05_S27.png width=100px>
144
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/pixart-alpha.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S32.png width=100px>
145
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/sd_3_5.C0_L0_Q0_NONE.png width=100px>
146
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/sd_3_5.C0_L0_Q0_DBCACHE_F1B0_W8M0MC3_T0O2_R0.12_S30.png width=100px>
147
+ <p><b>🔥PixArt-Alpha</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.6x↑🎉 | 1.8x↑🎉| <b>SD 3.5</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.5x↑🎉</p>
148
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/amused.C0_L0_Q0_NONE.png width=100px>
149
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/amused.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.34_S1.png width=100px>
150
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/amused.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.38_S2.png width=100px>
151
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/dit-xl.C0_L0_Q0_NONE.png width=100px>
152
+ <img src=https://github.com/vipshop/cache-dit/raw/main/assets/dit-xl.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.15_S11.png width=100px>
153
 + <p><b>🔥Amused</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.1x↑🎉 | 1.2x↑🎉 | <b>DiT-XL-256</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉
97
154
  <br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
98
155
  </div>
99
156
 
@@ -126,9 +183,10 @@ Dynamic: requires-python
126
183
  - [🔥Supported Models](#supported)
127
184
  - [🎉Unified Cache APIs](#unified)
128
185
  - [📚Forward Pattern Matching](#unified)
129
- - [🎉Cache with One-line Code](#unified)
186
+ - [♥️Cache with One-line Code](#unified)
130
187
  - [🔥Automatic Block Adapter](#unified)
131
188
  - [📚Hybrid Forward Pattern](#unified)
189
+ - [📚Implement Patch Functor](#unified)
132
190
  - [🤖Cache Acceleration Stats](#unified)
133
191
  - [⚡️Dual Block Cache](#dbcache)
134
192
  - [🔥Hybrid TaylorSeer](#taylorseer)
@@ -161,9 +219,15 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
161
219
  - [🚀Qwen-Image-Lightning](https://github.com/vipshop/cache-dit/raw/main/examples)
162
220
  - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
163
221
  - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
222
+ - [🚀SkyReelsV2](https://github.com/vipshop/cache-dit/raw/main/examples)
223
+ - [🚀LTXVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
224
+ - [🚀OmniGen](https://github.com/vipshop/cache-dit/raw/main/examples)
225
+ - [🚀Lumina2](https://github.com/vipshop/cache-dit/raw/main/examples)
164
226
  - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
165
227
  - [🚀FLUX.1-Fill-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
166
228
  - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
229
+ - [🚀Chroma1-HD](https://github.com/vipshop/cache-dit/raw/main/examples)
230
+ - [🚀VisualCloze](https://github.com/vipshop/cache-dit/raw/main/examples)
167
231
  - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
168
232
  - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
169
233
  - [🚀CogView3-Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -175,9 +239,16 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
175
239
  - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
176
240
  - [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)
177
241
  - [🚀HiDream-I1-Full](https://github.com/vipshop/cache-dit/raw/main/examples)
242
+ - [🚀AuraFlow-v0.3](https://github.com/vipshop/cache-dit/raw/main/examples)
178
243
  - [🚀PixArt-Alpha](https://github.com/vipshop/cache-dit/raw/main/examples)
179
244
  - [🚀PixArt-Sigma](https://github.com/vipshop/cache-dit/raw/main/examples)
245
+ - [🚀NVIDIA Sana](https://github.com/vipshop/cache-dit/raw/main/examples)
180
246
  - [🚀SD-3/3.5](https://github.com/vipshop/cache-dit/raw/main/examples)
247
+ - [🚀ConsisID](https://github.com/vipshop/cache-dit/raw/main/examples)
248
+ - [🚀Allegro](https://github.com/vipshop/cache-dit/raw/main/examples)
249
+ - [🚀Amused](https://github.com/vipshop/cache-dit/raw/main/examples)
250
+ - [🚀DiT-XL](https://github.com/vipshop/cache-dit/raw/main/examples)
251
+ - ...
181
252
 
182
253
  </details>
183
254
 
@@ -265,6 +336,75 @@ cache_dit.enable_cache(
265
336
  )
266
337
  ```
267
338
 
339
 + Sometimes you may have even more complex cases, such as **Wan 2.2 MoE**, which has more than one Transformer (namely `transformer` and `transformer_2`) in its structure. Fortunately, **cache-dit** handles this situation well too. Please refer to [📚Wan 2.2 MoE](./examples/pipeline/run_wan_2.2.py) as an example.
340
+
341
+ ```python
342
+ from cache_dit import ForwardPattern, BlockAdapter, ParamsModifier
343
+
344
+ cache_dit.enable_cache(
345
+ BlockAdapter(
346
+ pipe=pipe,
347
+ transformer=[
348
+ pipe.transformer,
349
+ pipe.transformer_2,
350
+ ],
351
+ blocks=[
352
+ pipe.transformer.blocks,
353
+ pipe.transformer_2.blocks,
354
+ ],
355
+ forward_pattern=[
356
+ ForwardPattern.Pattern_2,
357
+ ForwardPattern.Pattern_2,
358
+ ],
359
 + # Set up different cache params for each 'blocks'. You can
360
 + # pass any specific cache params to ParamsModifier; the old
361
 + # value will be overwritten by the new one.
362
+ params_modifiers=[
363
+ ParamsModifier(
364
+ max_warmup_steps=4,
365
+ max_cached_steps=8,
366
+ ),
367
+ ParamsModifier(
368
+ max_warmup_steps=2,
369
+ max_cached_steps=20,
370
+ ),
371
+ ],
372
+ has_separate_cfg=True,
373
+ ),
374
+ )
375
+ ```
376
+ ### 📚Implement Patch Functor
377
+
378
 + For any pattern not in {0...5}, we introduce the abstract concept of a **Patch Functor**. Users can implement a subclass of Patch Functor to convert an unknown pattern into a known one; for some models, users may also need to fuse the operations within the blocks' for-loop into the block forward.
379
+
380
+ ![](https://github.com/vipshop/cache-dit/raw/main/assets/patch-functor.png)
381
+
382
 + Some Patch Functors are already provided in cache-dit: [📚HiDreamPatchFunctor](./src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [📚ChromaPatchFunctor](./src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc. After implementing a Patch Functor, set the `patch_functor` property of **BlockAdapter**.
383
+
384
+ ```python
385
+ @BlockAdapterRegistry.register("HiDream")
386
+ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
387
+ from diffusers import HiDreamImageTransformer2DModel
388
+ from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
389
+
390
+ assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
391
+ return BlockAdapter(
392
+ pipe=pipe,
393
+ transformer=pipe.transformer,
394
+ blocks=[
395
+ pipe.transformer.double_stream_blocks,
396
+ pipe.transformer.single_stream_blocks,
397
+ ],
398
+ forward_pattern=[
399
+ ForwardPattern.Pattern_0,
400
+ ForwardPattern.Pattern_3,
401
+ ],
402
+         # NOTE: Set up your custom patch functor here.
403
+ patch_functor=HiDreamPatchFunctor(),
404
+ **kwargs,
405
+ )
406
+ ```
407
+
268
408
  ### 🤖Cache Acceleration Stats Summary
269
409
 
270
410
  After finishing each inference of `pipe(...)`, you can call the `cache_dit.summary()` API on pipe to get the details of the **Cache Acceleration Stats** for the current inference.
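+ A minimal sketch of how this can look in a script is shown below; the pipeline choice is arbitrary, `enable_cache` is shown with default settings only, and the exact return value of `summary()` is assumed to be printable for illustration:
+
+ ```python
+ import torch
+ from diffusers import FluxPipeline
+
+ import cache_dit
+
+ # Any supported pipeline works; FLUX.1-dev is used here as an example.
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ ).to("cuda")
+ cache_dit.enable_cache(pipe)
+
+ # Run one inference as usual, then ask for the stats of that run.
+ image = pipe("a cute cat", num_inference_steps=28).images[0]
+ stats = cache_dit.summary(pipe)  # Cache Acceleration Stats for this inference
+ print(stats)
+ ```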
@@ -348,7 +488,7 @@ cache_dit.enable_cache(
348
488
  # Taylorseer cache type can be hidden_states or residual.
349
489
  taylorseer_cache_type="residual",
350
490
  # Higher values of order will lead to longer computation time
351
- taylorseer_order=2, # default is 2.
491
+ taylorseer_order=1, # default is 1.
352
492
  max_warmup_steps=3, # prefer: >= order + 1
353
493
  residual_diff_threshold=0.12
354
494
  )
@@ -372,7 +512,7 @@ cache_dit.enable_cache(
372
512
 
373
513
  <div id="cfg"></div>
374
514
 
375
- cache-dit supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set `enable_spearate_cfg` param to **False (default)**. Otherwise, set it to True. For examples:
515
+ cache-dit supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not use CFG in the forward step, please set the `enable_separate_cfg` param to **False (default: None)**. Otherwise, set it to True. For example:
376
516
 
377
517
  ```python
378
518
  cache_dit.enable_cache(
@@ -380,14 +520,14 @@ cache_dit.enable_cache(
380
520
  ...,
381
521
  # CFG: classifier free guidance or not
382
522
  # For model that fused CFG and non-CFG into single forward step,
383
- # should set enable_spearate_cfg as False. For example, set it as True
523
+ # should set enable_separate_cfg as False. For example, set it as True
384
524
  # for Wan 2.1/Qwen-Image and set it as False for FLUX.1, HunyuanVideo,
385
525
  # CogVideoX, Mochi, LTXVideo, Allegro, CogView3Plus, EasyAnimate, SD3, etc.
386
- enable_spearate_cfg=True, # Wan 2.1, Qwen-Image, CogView4, Cosmos, SkyReelsV2, etc.
526
+ enable_separate_cfg=True, # Wan 2.1, Qwen-Image, CogView4, Cosmos, SkyReelsV2, etc.
387
527
  # Compute cfg forward first or not, default False, namely,
388
528
  # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
389
529
  cfg_compute_first=False,
390
- # Compute spearate diff values for CFG and non-CFG step,
530
+ # Compute separate diff values for CFG and non-CFG step,
391
531
  # default True. If False, we will use the computed diff from
392
532
  # current non-CFG transformer step for current CFG step.
393
533
  cfg_diff_compute_separate=True,
@@ -1,29 +1,30 @@
1
1
  cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
2
- cache_dit/_version.py,sha256=gTEHTWtuqv38KTvjBsXd5hC019b6d7AyfC8gLMY7KAo,706
2
+ cache_dit/_version.py,sha256=CtkelOzOJFXtgJ0APT8pLd5zWrG63eLavWaOD_cX7xo,706
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
5
5
  cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
6
6
  cache_dit/cache_factory/__init__.py,sha256=Iw6-iJLFbdzCsIDZXXOw371L-HPmoeZO_P9a3sDjP5s,1103
7
- cache_dit/cache_factory/cache_adapters.py,sha256=dmNX68nBD52HtQvHnNAuSn1zjDWrQdycD0qXy-w-mwc,18212
8
- cache_dit/cache_factory/cache_interface.py,sha256=LpyCy-tQ_GcTRAYLpMMf9hFVIktABHI6CObn5Ll8bMw,8548
7
+ cache_dit/cache_factory/cache_adapters.py,sha256=OFJlxxyODhoZstN4EfPgC7tE8M1ZdQFcE25gDNrW7NA,18212
8
+ cache_dit/cache_factory/cache_interface.py,sha256=tHQv7i8Hp6nfbjZWHwDx3nEvCfxLeBw26aMYjyu6nMw,8541
9
9
  cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkbv55uj2Sqk,991
10
10
  cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
11
11
  cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
12
- cache_dit/cache_factory/block_adapters/__init__.py,sha256=OZM5vJwmQIkoIwVmMxKXiHqKvs31NyAva1Z91C_ko3w,17547
13
- cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=IqHV10aK2qA8kEVDi7EEoUSBt0GzwCUM4GpLNf8Jgww,21656
14
- cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
12
+ cache_dit/cache_factory/block_adapters/__init__.py,sha256=33geXMz56TxFWMp0c-H4__MY5SGRzKMKj3TXnUYOMlc,17512
13
+ cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=zZbbsZYWbUClfa6He69w_Wdf8ZLhKwMAb9gURYEUmgQ,23725
14
+ cache_dit/cache_factory/block_adapters/block_registers.py,sha256=2L7QeM4ygnaKQpC9PoJod0QRYyxidUKU2AYpysDCUwE,2572
15
15
  cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
16
16
  cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
17
17
  cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=Bv56qETXhsREvCrNvnZpSqDIIHsi6Ze3FJW4Yk2x3uI,8597
18
18
  cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=d4H9kEB0AgnVMT8aF0Y54SUMUQUxw5HQ8gRkoCuTQ_A,14577
19
19
  cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
20
20
  cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
21
- cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=zqixcxV_LjnyoYDZ6q3HAC-hqYyVV6g0MWKBI2hA1nQ,11855
22
- cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=Mcj1upIpXT_CwO4AdY4ZNJSWoOXn3Lx2mBZRi_QuLbU,32710
23
- cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=hgLmgIkQgwbFTjxqtLUCJ3mgDGEcJK09B7RK8sBdPiI,3593
24
- cache_dit/cache_factory/patch_functors/__init__.py,sha256=06zdddrjvSCgBzJ0a8niRHd3ucF2qsbzlbL00d4aCvk,451
21
+ cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=FWdgInClWY8VZBsZIevtYk--rX-RL8c3QfNOJtqR8a4,11855
22
+ cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=Ig5VKoQ46iG3lKmsaMulYxd2vCm__2rY8NBvERwexwM,32719
23
+ cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=4nxgSEZvDy-w-7XuJYzsyzdtF1_uFrDwlF06XBDFVKQ,3922
24
+ cache_dit/cache_factory/patch_functors/__init__.py,sha256=oI6F3N9ezahRHaFUOZ1GfrAw1qFdKrxFXXmlwwehHj4,530
25
25
  cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
26
- cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=2iLxlsc-1dDHRveqCXaC07E9CeMNOuBNkvpJ1atpK7E,10048
26
+ cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=xD0Q96VArp1vYBLQ0pcjRIyFB1i_Y7muZ2q07Hz8Oqs,13430
27
+ cache_dit/cache_factory/patch_functors/functor_dit.py,sha256=SDjhzCWa6PoFNN4_upoQEf6DHvW1yJ7zuXMS2VvyJco,3904
27
28
  cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=UMkyuEYjO7UO_zmXi9Djd-nD-XMgCUgE-qkYA3plWSM,9559
28
29
  cache_dit/cache_factory/patch_functors/functor_hidream.py,sha256=pi_vvpDy1lsgQHxu3eK3v93rdJL7oNwkt3WakRP8pbw,15375
29
30
  cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py,sha256=iSo5dD5uKnjQQeysDUIkKt0wdnK5bzXTc_F_lfHG70w,6401
@@ -40,9 +41,9 @@ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,
40
41
  cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
41
42
  cache_dit/quantize/quantize_ao.py,sha256=Fx1KW4l3gdEkdrcAYtPoDW7WKBJWrs3glOHiEwW_TgE,6160
42
43
  cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
43
- cache_dit-0.2.33.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
44
- cache_dit-0.2.33.dist-info/METADATA,sha256=GQBvDzKLXL3tABguCRqLNc-Z39h0AcMK_J37demDTu8,25977
45
- cache_dit-0.2.33.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
46
- cache_dit-0.2.33.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
47
- cache_dit-0.2.33.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
48
- cache_dit-0.2.33.dist-info/RECORD,,
44
+ cache_dit-0.2.34.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
45
+ cache_dit-0.2.34.dist-info/METADATA,sha256=BvEY08xjrGPcqTEZSHvSDtJP4sGZv1T6jzhGj-jQbvo,38284
46
+ cache_dit-0.2.34.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
47
+ cache_dit-0.2.34.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
48
+ cache_dit-0.2.34.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
49
+ cache_dit-0.2.34.dist-info/RECORD,,