cache-dit 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (27)
  1. cache_dit/__init__.py +4 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +2 -0
  4. cache_dit/cache_factory/cache_adapters.py +375 -26
  5. cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
  8. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
  9. cache_dit/cache_factory/cache_blocks/utils.py +19 -0
  10. cache_dit/cache_factory/cache_context.py +5 -1
  11. cache_dit/cache_factory/cache_interface.py +7 -2
  12. cache_dit/cache_factory/forward_pattern.py +45 -24
  13. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  14. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  15. cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
  16. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
  17. cache_dit/quantize/quantize_ao.py +18 -4
  18. cache_dit/quantize/quantize_interface.py +2 -2
  19. cache_dit/utils.py +3 -2
  20. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/METADATA +35 -8
  21. cache_dit-0.2.26.dist-info/RECORD +42 -0
  22. cache_dit/cache_factory/patch/__init__.py +0 -0
  23. cache_dit-0.2.25.dist-info/RECORD +0 -36
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
  25. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
  26. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
  27. {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0
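The list above amounts to a restructuring of the caching internals (cache_blocks split into per-pattern modules, `patch` renamed to `patch_functors`) behind the same one-call user API. A minimal sketch of that API for orientation, assuming a FLUX.1 pipeline: `enable_cache`, `supported_pipelines` and the `summary()` function in cache_dit/utils.py all appear in this diff, while the model id, device and generation arguments below are only illustrative.

```python
import torch
import cache_dit
from diffusers import FluxPipeline

# Assumed model id / dtype / device, for illustration only.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

cache_dit.enable_cache(pipe)  # one-line, training-free cache acceleration

image = pipe("a cute cat", num_inference_steps=28).images[0]

# summary() lives in cache_dit/utils.py in this diff; calling it via the
# package top level assumes it is re-exported there.
stats = cache_dit.summary(pipe)
print(cache_dit.supported_pipelines())  # (31, ['Flux*', 'Mochi*', ...]) in 0.2.26
```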
cache_dit/cache_factory/patch_functors/functor_chroma.py ADDED
@@ -0,0 +1,273 @@
+ import inspect
+
+ import torch
+ import numpy as np
+ from typing import Tuple, Optional, Dict, Any, Union
+ from diffusers import ChromaTransformer2DModel
+ from diffusers.models.transformers.transformer_chroma import (
+     ChromaSingleTransformerBlock,
+     Transformer2DModelOutput,
+ )
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+
+ from cache_dit.cache_factory.patch_functors.functor_base import (
+     PatchFunctor,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ChromaPatchFunctor(PatchFunctor):
+
+     def apply(
+         self,
+         transformer: ChromaTransformer2DModel,
+         blocks: torch.nn.ModuleList = None,
+         **kwargs,
+     ) -> ChromaTransformer2DModel:
+         if blocks is None:
+             blocks = transformer.single_transformer_blocks
+
+         is_patched = False
+         for block in blocks:
+             if isinstance(block, ChromaSingleTransformerBlock):
+                 forward_parameters = inspect.signature(
+                     block.forward
+                 ).parameters.keys()
+                 if "encoder_hidden_states" not in forward_parameters:
+                     block.forward = __patch_single_forward__.__get__(block)
+                     is_patched = True
+
+         if is_patched:
+             logger.warning("Patched Chroma for cache-dit.")
+             assert not getattr(transformer, "_is_parallelized", False), (
+                 "Please call `cache_dit.enable_cache` before Parallelize, "
+                 "the __patch_transformer_forward__ will overwrite the "
+                 "parallized forward and cause a downgrade of performance."
+             )
+             transformer.forward = __patch_transformer_forward__.__get__(
+                 transformer
+             )
+             transformer._is_patched = True
+
+         cls_name = transformer.__class__.__name__
+         logger.info(
+             f"Applied {self.__class__.__name__} for {cls_name}, "
+             f"Patch: {is_patched}."
+         )
+
+         return transformer
+
+
+ # adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+ def __patch_single_forward__(
+     self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     temb: torch.Tensor,
+     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     text_seq_len = encoder_hidden_states.shape[1]
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+     residual = hidden_states
+     norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+     mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+     joint_attention_kwargs = joint_attention_kwargs or {}
+     attn_output = self.attn(
+         hidden_states=norm_hidden_states,
+         image_rotary_emb=image_rotary_emb,
+         **joint_attention_kwargs,
+     )
+
+     hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+     gate = gate.unsqueeze(1)
+     hidden_states = gate * self.proj_out(hidden_states)
+     hidden_states = residual + hidden_states
+     if hidden_states.dtype == torch.float16:
+         hidden_states = hidden_states.clip(-65504, 65504)
+
+     encoder_hidden_states, hidden_states = (
+         hidden_states[:, :text_seq_len],
+         hidden_states[:, text_seq_len:],
+     )
+     return encoder_hidden_states, hidden_states
+
+
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+ def __patch_transformer_forward__(
+     self: ChromaTransformer2DModel,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor = None,
+     timestep: torch.LongTensor = None,
+     img_ids: torch.Tensor = None,
+     txt_ids: torch.Tensor = None,
+     attention_mask: torch.Tensor = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     controlnet_block_samples=None,
+     controlnet_single_block_samples=None,
+     return_dict: bool = True,
+     controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+     if joint_attention_kwargs is not None:
+         joint_attention_kwargs = joint_attention_kwargs.copy()
+         lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+     else:
+         lora_scale = 1.0
+
+     if USE_PEFT_BACKEND:
+         # weight the lora layers by setting `lora_scale` for each PEFT layer
+         scale_lora_layers(self, lora_scale)
+     else:
+         if (
+             joint_attention_kwargs is not None
+             and joint_attention_kwargs.get("scale", None) is not None
+         ):
+             logger.warning(
+                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+
+     hidden_states = self.x_embedder(hidden_states)
+
+     timestep = timestep.to(hidden_states.dtype) * 1000
+
+     input_vec = self.time_text_embed(timestep)
+     pooled_temb = self.distilled_guidance_layer(input_vec)
+
+     encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+     if txt_ids.ndim == 3:
+         logger.warning(
+             "Passing `txt_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         txt_ids = txt_ids[0]
+     if img_ids.ndim == 3:
+         logger.warning(
+             "Passing `img_ids` 3d torch.Tensor is deprecated."
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         img_ids = img_ids[0]
+
+     ids = torch.cat((txt_ids, img_ids), dim=0)
+     image_rotary_emb = self.pos_embed(ids)
+
+     if (
+         joint_attention_kwargs is not None
+         and "ip_adapter_image_embeds" in joint_attention_kwargs
+     ):
+         ip_adapter_image_embeds = joint_attention_kwargs.pop(
+             "ip_adapter_image_embeds"
+         )
+         ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+         joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+     for index_block, block in enumerate(self.transformer_blocks):
+         img_offset = 3 * len(self.single_transformer_blocks)
+         txt_offset = img_offset + 6 * len(self.transformer_blocks)
+         img_modulation = img_offset + 6 * index_block
+         text_modulation = txt_offset + 6 * index_block
+         temb = torch.cat(
+             (
+                 pooled_temb[:, img_modulation : img_modulation + 6],
+                 pooled_temb[:, text_modulation : text_modulation + 6],
+             ),
+             dim=1,
+         )
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             encoder_hidden_states, hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     image_rotary_emb,
+                     attention_mask,
+                 )
+             )
+
+         else:
+             encoder_hidden_states, hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 attention_mask=attention_mask,
+                 joint_attention_kwargs=joint_attention_kwargs,
+             )
+
+         # controlnet residual
+         if controlnet_block_samples is not None:
+             interval_control = len(self.transformer_blocks) / len(
+                 controlnet_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             # For Xlabs ControlNet.
+             if controlnet_blocks_repeat:
+                 hidden_states = (
+                     hidden_states
+                     + controlnet_block_samples[
+                         index_block % len(controlnet_block_samples)
+                     ]
+                 )
+             else:
+                 hidden_states = (
+                     hidden_states
+                     + controlnet_block_samples[index_block // interval_control]
+                 )
+
+     for index_block, block in enumerate(self.single_transformer_blocks):
+         start_idx = 3 * index_block
+         temb = pooled_temb[:, start_idx : start_idx + 3]
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             encoder_hidden_states, hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     image_rotary_emb,
+                 )
+             )
+
+         else:
+             encoder_hidden_states, hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 attention_mask=attention_mask,
+                 joint_attention_kwargs=joint_attention_kwargs,
+             )
+
+         # controlnet residual
+         if controlnet_single_block_samples is not None:
+             interval_control = len(self.single_transformer_blocks) / len(
+                 controlnet_single_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             hidden_states = (
+                 hidden_states
+                 + controlnet_single_block_samples[
+                     index_block // interval_control
+                 ]
+             )
+
+     temb = pooled_temb[:, -2:]
+     hidden_states = self.norm_out(hidden_states, temb)
+     output = self.proj_out(hidden_states)
+
+     if USE_PEFT_BACKEND:
+         # remove `lora_scale` from each PEFT layer
+         unscale_lora_layers(self, lora_scale)
+
+     if not return_dict:
+         return (output,)
+
+     return Transformer2DModelOutput(sample=output)
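The new ChromaPatchFunctor rebinds the single-block and transformer forwards so that each single block accepts and returns `encoder_hidden_states` separately, which matches the forward patterns the cache blocks key on. Normally `cache_dit.enable_cache` invokes it for you; below is a sketch of calling it directly, using only the classes shown above. The checkpoint id and repo layout are assumptions for illustration.

```python
import torch
from diffusers import ChromaTransformer2DModel
from cache_dit.cache_factory.patch_functors.functor_chroma import (
    ChromaPatchFunctor,
)

# Assumed checkpoint id / subfolder layout, for illustration only.
transformer = ChromaTransformer2DModel.from_pretrained(
    "lodestones/Chroma", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Patches block.forward / transformer.forward in place and returns the module;
# _is_patched is only set when at least one block actually needed rewiring.
transformer = ChromaPatchFunctor().apply(transformer)
print(getattr(transformer, "_is_patched", False))
```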
cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} RENAMED
@@ -14,12 +14,56 @@ from diffusers.utils import (
      unscale_lora_layers,
  )

-
+ from cache_dit.cache_factory.patch_functors.functor_base import (
+     PatchFunctor,
+ )
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)


+ class FluxPatchFunctor(PatchFunctor):
+
+     def apply(
+         self,
+         transformer: FluxTransformer2DModel,
+         blocks: torch.nn.ModuleList = None,
+         **kwargs,
+     ) -> FluxTransformer2DModel:
+         if blocks is None:
+             blocks = transformer.single_transformer_blocks
+
+         is_patched = False
+         for block in blocks:
+             if isinstance(block, FluxSingleTransformerBlock):
+                 forward_parameters = inspect.signature(
+                     block.forward
+                 ).parameters.keys()
+                 if "encoder_hidden_states" not in forward_parameters:
+                     block.forward = __patch_single_forward__.__get__(block)
+                     is_patched = True
+
+         if is_patched:
+             logger.warning("Patched Flux for cache-dit.")
+             assert not getattr(transformer, "_is_parallelized", False), (
+                 "Please call `cache_dit.enable_cache` before Parallelize, "
+                 "the __patch_transformer_forward__ will overwrite the "
+                 "parallized forward and cause a downgrade of performance."
+             )
+             transformer.forward = __patch_transformer_forward__.__get__(
+                 transformer
+             )
+             transformer._is_patched = True
+
+         cls_name = transformer.__class__.__name__
+         logger.info(
+             f"Applied {self.__class__.__name__} for {cls_name}, "
+             f"Patch: {is_patched}."
+         )
+
+         return transformer
+
+
  # copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
  def __patch_single_forward__(
      self: FluxSingleTransformerBlock,
@@ -217,33 +261,3 @@ def __patch_transformer_forward__(
          return (output,)

      return Transformer2DModelOutput(sample=output)
-
-
- def maybe_patch_flux_transformer(
-     transformer: FluxTransformer2DModel,
-     blocks: torch.nn.ModuleList = None,
- ) -> FluxTransformer2DModel:
-     if blocks is None:
-         blocks = transformer.single_transformer_blocks
-
-     is_patched = False
-     for block in blocks:
-         if isinstance(block, FluxSingleTransformerBlock):
-             forward_parameters = inspect.signature(
-                 block.forward
-             ).parameters.keys()
-             if "encoder_hidden_states" not in forward_parameters:
-                 block.forward = __patch_single_forward__.__get__(block)
-                 is_patched = True
-
-     if is_patched:
-         logger.warning("Patched Flux for cache-dit.")
-         assert not getattr(transformer, "_is_parallelized", False), (
-             "Please call apply_cache_on_pipe before Parallelize, "
-             "the __patch_transformer_forward__ will overwrite the "
-             "parallized forward and cause a downgrade of performance."
-         )
-         transformer.forward = __patch_transformer_forward__.__get__(transformer)
-         transformer._is_patched = True
-
-     return transformer
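For FLUX.1 the module-level helper from 0.2.25 is gone and the same logic now lives on the functor class, so code that imported `maybe_patch_flux_transformer` directly needs a small migration. A sketch, with the model id and subfolder as illustrative assumptions:

```python
import torch
from diffusers import FluxTransformer2DModel
from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor

# Assumed model id / subfolder, for illustration only.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# 0.2.25 and earlier:
#   from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer
#   transformer = maybe_patch_flux_transformer(transformer)
# 0.2.26:
transformer = FluxPatchFunctor().apply(transformer)
```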
cache_dit/quantize/quantize_ao.py CHANGED
@@ -10,12 +10,13 @@ logger = init_logger(__name__)
  def quantize_ao(
      module: torch.nn.Module,
      quant_type: str = "fp8_w8a8_dq",
-     per_row: bool = True,
      exclude_layers: List[str] = [
          "embedder",
          "embed",
      ],
      filter_fn: Optional[Callable] = None,
+     # paramters for fp8 quantization
+     per_row: bool = True,
      **kwargs,
  ) -> torch.nn.Module:
      # Apply FP8 DQ for module and skip any `embed` modules
@@ -89,17 +90,30 @@ def quantize_ao(
          )

          quantization_fn = float8_dynamic_activation_float8_weight(
+             weight_dtype=kwargs.get(
+                 "weight_dtype",
+                 torch.float8_e4m3fn,
+             ),
+             activation_dtype=kwargs.get(
+                 "activation_dtype",
+                 torch.float8_e4m3fn,
+             ),
              granularity=(
                  ((PerRow(), PerRow()))
                  if per_row
                  else ((PerTensor(), PerTensor()))
-             )
+             ),
          )

      elif quant_type == "fp8_w8a16_wo":
          from torchao.quantization import float8_weight_only

-         quantization_fn = float8_weight_only()
+         quantization_fn = float8_weight_only(
+             weight_dtype=kwargs.get(
+                 "weight_dtype",
+                 torch.float8_e4m3fn,
+             ),
+         )

      elif quant_type == "int8_w8a8_dq":
          from torchao.quantization import (
@@ -159,7 +173,7 @@ def quantize_ao(
          module,
          _quantization_fn(),
          filter_fn=_filter_fn if filter_fn is None else filter_fn,
-         **kwargs,
+         device=kwargs.get("device", None),
      )

      force_empty_cache()
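The fp8 paths now read `weight_dtype` / `activation_dtype` from `**kwargs` (defaulting to `torch.float8_e4m3fn`), and only the `device` kwarg is still forwarded to torchao's `quantize_` instead of the whole `**kwargs` dict. A sketch of a call that exercises the new keywords; the toy module stands in for a real transformer, and the values shown simply make the defaults explicit.

```python
import torch
from cache_dit.quantize.quantize_ao import quantize_ao

# Stand-in for e.g. pipe.transformer; any torch.nn.Module works.
module = torch.nn.Sequential(torch.nn.Linear(4096, 4096))

module = quantize_ao(
    module,
    quant_type="fp8_w8a8_dq",
    per_row=True,                         # fp8-only option, now listed after filter_fn
    weight_dtype=torch.float8_e4m3fn,     # new: picked up via kwargs.get("weight_dtype", ...)
    activation_dtype=torch.float8_e4m3fn, # new: picked up via kwargs.get("activation_dtype", ...)
    device=None,                          # new: the only kwarg still passed to torchao quantize_()
)
```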
cache_dit/quantize/quantize_interface.py CHANGED
@@ -9,13 +9,13 @@ def quantize(
      module: torch.nn.Module,
      quant_type: str = "fp8_w8a8_dq",
      backend: str = "ao",
-     # only for fp8_w8a8_dq
-     per_row: bool = True,
      exclude_layers: List[str] = [
          "embedder",
          "embed",
      ],
      filter_fn: Optional[Callable] = None,
+     # only for fp8_w8a8_dq
+     per_row: bool = True,
      **kwargs,
  ) -> torch.nn.Module:
      assert isinstance(module, torch.nn.Module)
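Because `per_row` moved from before `exclude_layers` to after `filter_fn` in both `quantize()` and `quantize_ao()`, callers that passed it positionally against the 0.2.25 signature would now silently bind the wrong parameter; passing options by keyword sidesteps that. A sketch against the interface shown above, with a toy module as a stand-in:

```python
import torch
from cache_dit.quantize.quantize_interface import quantize

module = torch.nn.Linear(1024, 1024)  # stand-in for a real transformer

module = quantize(
    module,
    quant_type="fp8_w8a8_dq",
    backend="ao",
    exclude_layers=["embedder", "embed"],
    per_row=True,  # keyword form is unaffected by the signature reorder
)
```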
cache_dit/utils.py CHANGED
@@ -30,13 +30,14 @@ class CacheStats:


  def summary(
-     pipe_or_transformer: DiffusionPipeline | torch.nn.Module,
+     pipe_or_transformer: DiffusionPipeline | torch.nn.Module | Any,
      details: bool = False,
      logging: bool = True,
  ) -> CacheStats:
      cache_stats = CacheStats()
      cls_name = pipe_or_transformer.__class__.__name__
-     if isinstance(pipe_or_transformer, DiffusionPipeline):
+     if not isinstance(pipe_or_transformer, torch.nn.Module):
+         assert hasattr(pipe_or_transformer, "transformer")
          transformer = pipe_or_transformer.transformer
      else:
          transformer = pipe_or_transformer
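With the `isinstance(..., DiffusionPipeline)` check replaced by a duck-typed one, `summary()` now accepts any non-`nn.Module` object that exposes a `.transformer` attribute, not just diffusers pipelines. A structural sketch; the wrapper class is hypothetical, and a real call would pass a pipeline (or transformer) that has been through `cache_dit.enable_cache`:

```python
import torch
from cache_dit.utils import summary


class MyPipelineWrapper:
    # Hypothetical container that is not a DiffusionPipeline.
    def __init__(self, transformer: torch.nn.Module):
        self.transformer = transformer  # the only attribute summary() requires


# In practice wrapper.transformer would be a cached DiT; a bare Linear just
# shows that the relaxed type check accepts the wrapper at all.
wrapper = MyPipelineWrapper(torch.nn.Linear(8, 8))
stats = summary(wrapper, details=False, logging=False)
```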
{cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.25
+ Version: 0.2.26
  Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -59,16 +59,17 @@ Dynamic: requires-python
  🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
  </p>
  <p align="center">
- 🎉Now, <b>cache-dit</b> covers <b>Most</b> mainstream <b>Diffusers'</b> Pipelines</b>🎉<br>
+ 🎉Now, <b>cache-dit</b> covers <b>All</b> mainstream <b>DiT-based</b> Diffusers' Pipelines</b>🎉<br>
  🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
  </p>
  </div>

  ## 🔥News

- - [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.8x⚡️** speedup with `cache-dit + compile`! Check the [example](./examples/run_wan_2.2.py).
- - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x⚡️** speedup! Check example [run_qwen_image_edit.py](./examples/run_qwen_image_edit.py).
- - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
+ - [2025-08-29] 🔥</b>Covers <b>All</b> Diffusers' <b>DiT-based</b> Pipelines via **[BlockAdapter](#unified) + [Pattern Matching](#unified).**
+ - [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.8x⚡️** speedup with `cache-dit + compile`! Please check the [example](./examples/run_wan_2.2.py).
+ - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x⚡️** speedup! Check the example at [run_qwen_image_edit.py](./examples/run_qwen_image_edit.py).
+ - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
  - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x⚡️** speedup! Please refer [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
@@ -79,7 +80,7 @@ Dynamic: requires-python
  - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! **3.3x** speedup for FLUX.1 on NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)**.

  </details>
-
+
  ## 📖Contents

  <div id="contents"></div>
@@ -112,8 +113,19 @@ pip3 install git+https://github.com/vipshop/cache-dit.git

  <div id="supported"></div>

+ ```python
+ >>> import cache_dit
+ >>> cache_dit.supported_pipelines()
+ (31, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTXVideo*',
+ 'Allegro*', 'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'SD3*',
+ 'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'HunyuanDiT*', 'HunyuanDiTPAG*', 'Lumina*', 'Lumina2*',
+ 'OmniGen*', 'PixArt*', 'Sana*', 'ShapE*', 'StableAudio*', 'VisualCloze*', 'AuraFlow*',
+ 'Chroma*', 'HiDream*'])
+ ```
+
  Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:

+
  - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -125,6 +137,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
  - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀HunyuanDiT](https://github.com/vipshop/cache-dit/raw/main/examples)

  <details>
  <summary> More Pipelines </summary>
@@ -138,6 +151,20 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
  - [🚀EasyAnimate](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀SkyReelsV2](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀SD3](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀ConsisID](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀DiT](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Amused](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀HunyuanDiTPAG](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Lumina](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Lumina2](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀OmniGen](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀PixArt](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Sana](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀StableAudio](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀VisualCloze](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀AuraFlow](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Chroma](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀HiDream](https://github.com/vipshop/cache-dit/raw/main/examples)

  </details>

@@ -330,8 +357,8 @@ cache_dit.enable_cache(
      # For model that fused CFG and non-CFG into single forward step,
      # should set do_separate_cfg as False. For example, set it as True
      # for Wan 2.1/Qwen-Image and set it as False for FLUX.1, HunyuanVideo,
-     # CogVideoX, Mochi, etc.
-     do_separate_cfg=True, # Wan 2.1, Qwen-Image
+     # CogVideoX, Mochi, LTXVideo, Allegro, CogView3Plus, EasyAnimate, SD3, etc.
+     do_separate_cfg=True, # Wan 2.1, Qwen-Image, CogView4, Cosmos, SkyReelsV2, etc.
      # Compute cfg forward first or not, default False, namely,
      # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
      cfg_compute_first=False,
cache_dit-0.2.26.dist-info/RECORD ADDED
@@ -0,0 +1,42 @@
+ cache_dit/__init__.py,sha256=6_DrKjVU0N7BpQLz4A5-qLLmpc4plflHbBAss_4FmC8,1140
+ cache_dit/_version.py,sha256=qaUIn8np9pb6UE7Q3omOIYbBHqbmKse_sogKIKw72sA,706
+ cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
+ cache_dit/utils.py,sha256=H4YqlkvenlBxh2-ilOflbVDFqhI1UtnFniDgQac-D6k,7303
+ cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
+ cache_dit/cache_factory/__init__.py,sha256=sHGxoYnUWy4CqWTTnrqF2JdleYGdtN7T3erz__zztzE,635
+ cache_dit/cache_factory/cache_adapters.py,sha256=NAM7Zo7WrdSb7WOUG2WFRkFXmSx9sXm2oTMx4PhtlZk,39302
+ cache_dit/cache_factory/cache_context.py,sha256=krIPLYExwRbwZBj4-eVLV-v5QSEQoVqoLFMZBFWIIT0,41874
+ cache_dit/cache_factory/cache_interface.py,sha256=1lCoN1Co1J6lqRI3mDikgqbscWnZShmEw53uatTHJdc,8588
+ cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
+ cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
+ cache_dit/cache_factory/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
+ cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
+ cache_dit/cache_factory/cache_blocks/__init__.py,sha256=jxO8v6o-Ke30HGfnDfZNZ6XknP3sabA2tlHiBW2BZTo,815
+ cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=dSVcjHPkjlAqLaXxCyvcx8jdFFq6UfIhZk0geziQCVE,434
+ cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=IMfm3HPzCbQC_SHbn74pfri4zwmPKBgpnT_NhdLxRZs,9598
+ cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=_B1PPICBkyLUoHYWo0XmtIDv944bGBWFSN9y3X9uKTM,18739
+ cache_dit/cache_factory/cache_blocks/utils.py,sha256=3VeCcqsYhD719uakrKJNSIFUa0-Qqgw08uu0LCKFa_A,648
+ cache_dit/cache_factory/patch_functors/__init__.py,sha256=yK05iONMGILsTZ83ynrUUJtiJKJ_FDjxmVIzRLy416s,252
+ cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
+ cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=RcN6AmpDp19ILY37tjOANlwXHrcaNMHlbv9XWF8hBwA,9942
+ cache_dit/cache_factory/patch_functors/functor_flux.py,sha256=ycAypjJ34Uh7hmvbRbHadswRQj_fpxU24YfX1vtBL6c,9450
+ cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
+ cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
+ cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
+ cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
+ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
+ cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
+ cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
+ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
+ cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
+ cache_dit/quantize/quantize_ao.py,sha256=x9zm7AX9JjNhh7mqMkjHDGz2rDl4PzBwwU-CP1e_AVA,6012
+ cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
+ cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit-0.2.26.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.26.dist-info/METADATA,sha256=xBljGjnEV9OxSWLRmKx-FLWKUw2DGIv3pNJHg106uOo,21722
+ cache_dit-0.2.26.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.26.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+ cache_dit-0.2.26.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.26.dist-info/RECORD,,
cache_dit-0.2.25.dist-info/RECORD DELETED
@@ -1,36 +0,0 @@
- cache_dit/__init__.py,sha256=VsT0f0R0COp8v6Sx9hGNsqxiElERaDpfG11a9MfK0is,945
- cache_dit/_version.py,sha256=t9iixyDuMWz1nP7KM0bgrLNIpwu8JK6uZApA8DoCQwM,706
- cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
- cache_dit/utils.py,sha256=1oWDMYs6E7FRsd8cidsVOPT-meIRKeuqbGbE6CrCUec,7236
- cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
- cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
- cache_dit/cache_factory/cache_adapters.py,sha256=Yugqljm9tm615srM2BGQlR_tA0QiZo3PbLPceObh4dQ,25988
- cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
- cache_dit/cache_factory/cache_context.py,sha256=HhA5IMSdF-i-koSB1jqf5AMC_UyDV7VinpHm4Qee9Ig,41800
- cache_dit/cache_factory/cache_interface.py,sha256=HymagnKEDs48Ly_x3IM5jTMNJpLrIdJnppVlkr2xHaM,8433
- cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
- cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
- cache_dit/cache_factory/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
- cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
- cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/cache_factory/patch/flux.py,sha256=iNQ-1RlOgXupZ4uPiEvJ__Ro6vKT_fOKja9JrpMrO78,8998
- cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
- cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
- cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
- cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
- cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
- cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
- cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
- cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
- cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
- cache_dit/quantize/quantize_ao.py,sha256=sKz_RmVtxLOpAPnUv_iOjzY_226pfaxgB_HMNrfyqB8,5465
- cache_dit/quantize/quantize_interface.py,sha256=NG4WP7s8CLW6KhVFb9e1aAjW30KWTCcM2aS5n8QuwsA,1241
- cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit-0.2.25.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-0.2.25.dist-info/METADATA,sha256=a5wbENMZ9BDjHbM3Ejb7Il7x4QzD8W7Lzmu4poo95Wo,19913
- cache_dit-0.2.25.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-0.2.25.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
- cache_dit-0.2.25.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-0.2.25.dist-info/RECORD,,