cache-dit 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff shows the changes between publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (43)
  1. cache_dit/__init__.py +12 -0
  2. cache_dit/_version.py +16 -3
  3. cache_dit/cache_factory/.gitignore +2 -0
  4. cache_dit/cache_factory/__init__.py +52 -2
  5. cache_dit/cache_factory/cache_adapters.py +654 -0
  6. cache_dit/cache_factory/cache_blocks.py +487 -0
  7. cache_dit/cache_factory/{dual_block_cache/cache_context.py → cache_context.py} +11 -862
  8. cache_dit/cache_factory/patch/flux.py +249 -0
  9. cache_dit/cache_factory/utils.py +1 -1
  10. cache_dit/compile/__init__.py +1 -1
  11. cache_dit/compile/utils.py +1 -1
  12. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/METADATA +87 -204
  13. cache_dit-0.2.17.dist-info/RECORD +30 -0
  14. cache_dit/cache_factory/adapters.py +0 -169
  15. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -55
  16. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -87
  17. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -98
  18. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -294
  19. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -87
  20. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py +0 -88
  21. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -97
  22. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  23. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -51
  24. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -87
  25. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -98
  26. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -294
  27. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -87
  28. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -97
  29. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -1005
  30. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  31. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  32. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  33. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -89
  34. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  35. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  36. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -89
  37. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  38. cache_dit-0.2.15.dist-info/RECORD +0 -50
  39. /cache_dit/cache_factory/{dual_block_cache → patch}/__init__.py +0 -0
  40. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/WHEEL +0 -0
  41. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/entry_points.txt +0 -0
  42. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/licenses/LICENSE +0 -0
  43. {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/top_level.txt +0 -0
--- cache_dit-0.2.15.dist-info/METADATA
+++ cache_dit-0.2.17.dist-info/METADATA
@@ -1,7 +1,7 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.15
- Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
+ Version: 0.2.17
+ Summary: 🤗 CacheDiT: A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
  Project-URL: Repository, https://github.com/vipshop/cache-dit.git
@@ -41,7 +41,7 @@ Dynamic: requires-python
 
  <div align="center">
  <p align="center">
- <h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
+ <h2>🤗 CacheDiT: A Unified and Training-free Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
  </p>
  <img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-v1.png >
  <div align='center'>
@@ -52,13 +52,23 @@ Dynamic: requires-python
  <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
  <img src=https://img.shields.io/badge/Release-v0.2-brightgreen.svg >
  </div>
- <b>🔥<a href="#dbcache">DBCache</a> | <a href="#dbprune">DBPrune</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a> | <a href="#fbcache">FBCache</a></b>🔥
+ 🔥<b><a href="#unified">Unified Cache APIs</a> | <a href="#dbcache">DBCache</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a></b>🔥
  </div>
 
- ## 🔥News
- - [2025-08-11] 🔥[Qwen-Image](./examples/run_qwen_image.py) is supported! Please check [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
- - [2025-08-10] 🔥[FLUX.1-Kontext-dev](./examples/run_flux_kontext.py) is supported! Please check [run_flux_kontext.py](./examples/run_flux_kontext.py) as an example.
- - [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the [PR](https://github.com/huggingface/flux-fast/pull/13).
+ <div align="center">
+ <p align="center">
+ ♥️ Cache <b>Acceleration</b> with <b>One-line</b> Code ~ ♥️
+ </p>
+ </div>
+
+ ## 🔥News
+
+ - [2025-08-18] 🎉Early **[Unified Cache APIs](#unified)** released! Check [Qwen-Image w/ UAPI](./examples/run_qwen_image_uapi.py) as an example.
+ - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
+ - [2025-08-11] 🔥[Qwen-Image](https://github.com/QwenLM/Qwen-Image) is supported now! Please refer to [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
+ - [2025-08-10] 🔥[FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please refer to [run_flux_kontext.py](./examples/run_flux_kontext.py) as an example.
+ - [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).
  - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! **3.3x** speedup for FLUX.1 on NVIDIA L20 with `cache-dit`.
 
  ## 📖Contents
@@ -67,17 +77,12 @@ Dynamic: requires-python
 
  - [⚙️Installation](#️installation)
  - [🔥Supported Models](#supported)
+ - [🎉Unified Cache APIs](#unified)
  - [⚡️Dual Block Cache](#dbcache)
  - [🔥Hybrid TaylorSeer](#taylorseer)
  - [⚡️Hybrid Cache CFG](#cfg)
- - [🎉First Block Cache](#fbcache)
- - [⚡️Dynamic Block Prune](#dbprune)
- - [🎉Context Parallelism](#context-parallelism)
  - [🔥Torch Compile](#compile)
- - [⚙️Metrics CLI](#metrics)
- - [👋Contribute](#contribute)
- - [©️License](#license)
- - [©️Citations](#citations)
+ - [🛠Metrics CLI](#metrics)
 
  ## ⚙️Installation
 
@@ -98,6 +103,8 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
  <div id="supported"></div>
 
+ Currently, the **cache-dit** library supports almost **Any** Diffusion Transformer (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models:
+
  - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀FLUX.1-Fill-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -108,7 +115,48 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
  - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀LTXVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Allegro](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀CogView3Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀CogView4](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀Cosmos](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀EasyAnimate](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀SkyReelsV2](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀SD3](https://github.com/vipshop/cache-dit/raw/main/examples)
+
+ ## 🎉Unified Cache APIs
+
+ <div id="unified"></div>
+
+ Currently, for any **Diffusion** model with **Transformer Blocks** that match a specific **Input/Output pattern**, we can use the **Unified Cache APIs** from **cache-dit**. The supported patterns are listed as follows:
+
+ ```bash
+ (IN: hidden_states, encoder_hidden_states, ...) -> (OUT: hidden_states, encoder_hidden_states)
+ (IN: hidden_states, encoder_hidden_states, ...) -> (OUT: encoder_hidden_states, hidden_states)
+ (IN: hidden_states, encoder_hidden_states, ...) -> (OUT: hidden_states)
+ (IN: hidden_states, ...) -> (OUT: hidden_states) # TODO, DiT, Lumina2, etc.
+ ```
+
+ Please refer to [Qwen-Image w/ UAPI](./examples/run_qwen_image_uapi.py) as an example. The `pipe` parameter can be **Any** Diffusion Pipeline. The **Unified Cache APIs** are currently in the experimental phase; please stay tuned for updates.
+
+ ```python
+ import cache_dit
+ from diffusers import DiffusionPipeline # Can be [Any] Diffusion Pipeline
 
+ pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+
+ # Just use the one-line code with default cache options.
+ cache_dit.enable_cache(pipe)
+
+ # Or, enable cache with custom settings according to your models.
+ cache_dit.enable_cache(
+     pipe, transformer=pipe.transformer,
+     blocks=pipe.transformer.transformer_blocks,
+     return_hidden_states_first=False,
+     **cache_dit.default_options(),
+ )
+ ```
 
  ## ⚡️DBCache: Dual Block Cache
 
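For readers skimming this diff, the Input/Output "patterns" above refer to the call signature of each transformer block the unified adapter hooks. Below is a minimal sketch of a block matching the first pattern; the class name and body are illustrative only and are not part of the cache-dit package.

```python
import torch

class IllustrativeBlock(torch.nn.Module):
    # Matches: (IN: hidden_states, encoder_hidden_states, ...)
    #       -> (OUT: hidden_states, encoder_hidden_states)
    def forward(self, hidden_states, encoder_hidden_states, **kwargs):
        # ... attention / feed-forward computation elided ...
        return hidden_states, encoder_hidden_states

# Blocks that instead return (encoder_hidden_states, hidden_states) match
# the second pattern; the `return_hidden_states_first=False` argument in
# the example above appears to declare exactly that output ordering.
```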
@@ -153,31 +201,31 @@ These case studies demonstrate that even with relatively high thresholds (such a
  - **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
  - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
 
- For a good balance between performance and precision, DBCache is configured by default with **F8B8**, 8 warmup steps, and unlimited cached steps.
+ For a good balance between performance and precision, DBCache is configured by default with **F8B0**, 8 warmup steps, and unlimited cached steps.
 
  ```python
+ import cache_dit
  from diffusers import FluxPipeline
- from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
 
  pipe = FluxPipeline.from_pretrained(
      "black-forest-labs/FLUX.1-dev",
      torch_dtype=torch.bfloat16,
  ).to("cuda")
 
- # Default options, F8B8, good balance between performance and precision
- cache_options = CacheType.default_options(CacheType.DBCache)
+ # Default options, F8B0, good balance between performance and precision
+ cache_options = cache_dit.default_options()
 
- # Custom options, F8B16, higher precision
+ # Custom options, F8B8, higher precision
  cache_options = {
-     "cache_type": CacheType.DBCache,
+     "cache_type": cache_dit.DBCache,
      "warmup_steps": 8,
-     "max_cached_steps": 8, # -1 means no limit
-     "Fn_compute_blocks": 8, # Fn, F8, etc.
-     "Bn_compute_blocks": 16, # Bn, B16, etc.
+     "max_cached_steps": -1, # -1 means no limit
+     "Fn_compute_blocks": 8, # Fn, F8, etc.
+     "Bn_compute_blocks": 8, # Bn, B8, etc.
      "residual_diff_threshold": 0.12,
  }
 
- apply_cache_on_pipe(pipe, **cache_options)
+ cache_dit.enable_cache(pipe, **cache_options)
  ```
  Moreover, users configuring higher **Bn** values (e.g., **F8B16**) while aiming to maintain good performance can specify **Bn_compute_blocks_ids** to work with Bn. DBCache will only compute the specified blocks, with the remaining estimated using the previous step's residual cache.
 
@@ -185,7 +233,7 @@ Moreover, users configuring higher **Bn** values (e.g., **F8B16**) while aiming
  # Custom options, F8B16, higher precision with good performance.
  cache_options = {
      # 0, 2, 4, ..., 14, 15, etc. [0,16)
-     "Bn_compute_blocks_ids": CacheType.range(0, 16, 2),
+     "Bn_compute_blocks_ids": cache_dit.block_range(0, 16, 2),
      # If the L1 difference is below this threshold, skip Bn blocks
      # not in `Bn_compute_blocks_ids`(1, 3,..., etc), Otherwise,
      # compute these blocks.
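The `[0,16)` comment above implies that `block_range(0, 16, 2)` selects blocks 0, 2, 4, ..., 14, 15. A minimal sketch of these semantics, assuming the new `cache_dit.block_range` keeps the behavior of the `CacheType.range` helper removed in this release (its source appears in the deleted adapters.py at the end of this diff):

```python
def block_range(start: int, end: int, step: int = 1) -> list[int]:
    if start > end or end <= 0 or step <= 1:
        return []
    # Always include block 0 and block end - 1, so the first and last
    # blocks are computed on every step.
    return sorted(set([0] + list(range(start, end, step)) + [end - 1]))

print(block_range(0, 16, 2))  # [0, 2, 4, 6, 8, 10, 12, 14, 15]
```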
@@ -203,7 +251,7 @@
  \mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)=\mathcal{F}\left(x_t^l\right)+\sum_{i=1}^m \frac{\Delta^i \mathcal{F}\left(x_t^l\right)}{i!\cdot N^i}(-k)^i
  $$
 
- **TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. That is $\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)$ can be a residual cache or a hidden-state cache.
+ **TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in cache-dit supports both hidden states and residual cache types. That is $\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)$ can be a residual cache or a hidden-state cache.
 
  ```python
  cache_options = {
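The prediction formula above can be made concrete with a small numeric sketch. This is not the package's implementation; it only illustrates how finite differences over features cached every N timesteps approximate the derivatives in the Taylor expansion of order m:

```python
import math
import torch

def taylor_predict(history: list[torch.Tensor], k: int, N: int, m: int) -> torch.Tensor:
    # `history` holds the feature F at the last m + 1 cached steps,
    # newest first, spaced N timesteps apart.
    diffs, level = [history[0]], list(history)
    for _ in range(m):
        # i-th order finite differences approximate Δ^i F(x_t)
        level = [a - b for a, b in zip(level[:-1], level[1:])]
        diffs.append(level[0])
    # F_pred = F(x_t) + sum_i Δ^i F(x_t) / (i! * N^i) * (-k)^i
    pred = diffs[0].clone()
    for i in range(1, m + 1):
        pred = pred + diffs[i] * ((-k) ** i) / (math.factorial(i) * N**i)
    return pred
```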
@@ -240,7 +288,7 @@ cache_options = {
 
  <div id="cfg"></div>
 
- CacheDiT supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set `do_separate_classifier_free_guidance` param to **False (default)**. Otherwise, set it to True. For examples:
+ cache-dit supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set the `do_separate_classifier_free_guidance` param to **False (default)**. Otherwise, set it to True. For example:
 
  ```python
  cache_options = {
@@ -249,7 +297,7 @@ cache_options = {
      # should set do_separate_classifier_free_guidance as False.
      # For example, set it as True for Wan 2.1 and set it as False
      # for FLUX.1, HunyuanVideo, CogVideoX, Mochi.
-     "do_separate_classifier_free_guidance": True, # Wan 2.1
+     "do_separate_classifier_free_guidance": True, # Wan 2.1, Qwen-Image
      # Compute cfg forward first or not, default False, namely,
      # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
      "cfg_compute_first": False,
@@ -260,185 +308,20 @@ cache_options = {
  }
  ```
 
- ## 🎉FBCache: First Block Cache
-
- <div id="fbcache"></div>
-
- ![](https://github.com/vipshop/cache-dit/raw/main/assets/fbcache-v1.png)
-
- **DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can either use the original FBCache implementation directly or configure **DBCache** with **F1B0** settings to achieve the same functionality.
-
- ```python
- from diffusers import FluxPipeline
- from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
-
- pipe = FluxPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-dev",
-     torch_dtype=torch.bfloat16,
- ).to("cuda")
-
- # Using FBCache directly
- cache_options = CacheType.default_options(CacheType.FBCache)
-
- # Or using DBCache with F1B0.
- # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
- cache_options = {
-     "cache_type": CacheType.DBCache,
-     "warmup_steps": 8,
-     "max_cached_steps": 8, # -1 means no limit
-     "Fn_compute_blocks": 1, # Fn, F1, etc.
-     "Bn_compute_blocks": 0, # Bn, B0, etc.
-     "residual_diff_threshold": 0.12,
- }
-
- apply_cache_on_pipe(pipe, **cache_options)
- ```
-
- ## ⚡️DBPrune: Dynamic Block Prune
-
- <div id="dbprune"></div>
-
- ![](https://github.com/vipshop/cache-dit/raw/main/assets/dbprune-v1.png)
-
- We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, which is referred to as **DBPrune**. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently in the experimental phase, and we kindly invite you to stay tuned for upcoming updates.
-
- ```python
- from diffusers import FluxPipeline
- from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
-
- pipe = FluxPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-dev",
-     torch_dtype=torch.bfloat16,
- ).to("cuda")
-
- # Using DBPrune with default options
- cache_options = CacheType.default_options(CacheType.DBPrune)
-
- apply_cache_on_pipe(pipe, **cache_options)
- ```
-
- We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_blocks_ids** to avoid aggressive pruning. For example:
-
- ```python
- # Custom options for DBPrune
- cache_options = {
-     "cache_type": CacheType.DBPrune,
-     "residual_diff_threshold": 0.05,
-     # Never prune the first `Fn` and last `Bn` blocks.
-     "Fn_compute_blocks": 8, # default 1
-     "Bn_compute_blocks": 8, # default 0
-     "warmup_steps": 8, # default -1
-     # Disables the pruning strategy when the previous
-     # pruned steps greater than this value.
-     "max_pruned_steps": 12, # default, -1 means no limit
-     # Enable dynamic prune threshold within step, higher
-     # `max_dynamic_prune_threshold` value may introduce a more
-     # ageressive pruning strategy.
-     "enable_dynamic_prune_threshold": True,
-     "max_dynamic_prune_threshold": 2 * 0.05,
-     # (New thresh) = mean(previous_block_diffs_within_step) * 1.25
-     # (New thresh) = ((New thresh) if (New thresh) <
-     # max_dynamic_prune_threshold else residual_diff_threshold)
-     "dynamic_prune_threshold_relax_ratio": 1.25,
-     # The step interval to update residual cache. For example,
-     # 2: means the update steps will be [0, 2, 4, ...].
-     "residual_cache_update_interval": 1,
-     # You can set non-prune blocks to avoid ageressive pruning.
-     # For example, FLUX.1 has 19 + 38 blocks, so we can set it
-     # to 0, 2, 4, ..., 56, etc.
-     "non_prune_blocks_ids": [],
- }
-
- apply_cache_on_pipe(pipe, **cache_options)
- ```
-
- > [!Important]
- > Please note that for GPUs with lower VRAM, DBPrune may not be suitable for use on video DiTs, as it caches the hidden states and residuals of each block, leading to higher GPU memory requirements. In such cases, please use DBCache, which only caches the hidden states and residuals of 2 blocks.
-
- <div align="center">
- <p align="center">
- DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
- </p>
- </div>
-
- |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
- |:---:|:---:|:---:|:---:|:---:|:---:|
- |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
-
- ## 🎉Context Parallelism
-
- <div id="context-parallelism"></div>
-
- **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can **easily tap into** its **Context Parallelism** features for distributed inference. Firstly, install `para-attn` from PyPI:
-
- ```bash
- pip3 install para-attn # or install `para-attn` from sources.
- ```
-
- Then, you can run **DBCache** or **DBPrune** with **Context Parallelism** on 4 GPUs:
-
- ```python
- import torch.distributed as dist
- from diffusers import FluxPipeline
- from para_attn.context_parallel import init_context_parallel_mesh
- from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
- from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
-
- # Init distributed process group
- dist.init_process_group()
- torch.cuda.set_device(dist.get_rank())
-
- pipe = FluxPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-dev",
-     torch_dtype=torch.bfloat16,
- ).to("cuda")
-
- # Context Parallel from ParaAttention
- parallelize_pipe(
-     pipe, mesh=init_context_parallel_mesh(
-         pipe.device.type, max_ulysses_dim_size=4
-     )
- )
-
- # DBPrune with default options from this library
- apply_cache_on_pipe(
-     pipe, **CacheType.default_options(CacheType.DBPrune)
- )
-
- dist.destroy_process_group()
- ```
- Then, run the python test script with `torchrun`:
- ```bash
- torchrun --nproc_per_node=4 parallel_cache.py
- ```
-
- <div align="center">
- <p align="center">
- DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
- </p>
- </div>
-
- |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
- |:---:|:---:|:---:|:---:|:---:|:---:|
- |+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
- |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
-
  ## 🔥Torch Compile
 
  <div id="compile"></div>
 
- By the way, **CacheDiT** is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance. For example:
+ By the way, **cache-dit** is designed to work compatibly with **torch.compile**. You can easily use cache-dit with torch.compile to achieve further performance gains. For example:
 
  ```python
- apply_cache_on_pipe(
-     pipe, **CacheType.default_options(CacheType.DBPrune)
+ cache_dit.enable_cache(
+     pipe, **cache_dit.default_options()
  )
  # Compile the Transformer module
  pipe.transformer = torch.compile(pipe.transformer)
  ```
- However, users intending to use **CacheDiT** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
+ However, users intending to use **cache-dit** for DiT with **dynamic input shapes** should consider increasing the **recompile limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
  ```python
  torch._dynamo.config.recompile_limit = 96 # default is 8
  torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
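Taken together, the Torch Compile snippets above assemble into the following end-to-end sketch. This is an illustration rather than package documentation: it assumes a CUDA device and adds the `import torch` that the README snippets omit.

```python
import torch
import cache_dit
from diffusers import FluxPipeline

# Raise dynamo recompile limits first if input shapes are dynamic.
torch._dynamo.config.recompile_limit = 96                 # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048   # default is 256

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Enable caching before compiling the transformer.
cache_dit.enable_cache(pipe, **cache_dit.default_options())
pipe.transformer = torch.compile(pipe.transformer)
```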
@@ -447,11 +330,11 @@ torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
  Please check [bench.py](./bench/bench.py) for more details.
 
 
- ## ⚙️Metrics CLI
+ ## 🛠Metrics CLI
 
  <div id="metrics"></div>
 
- You can utilize the APIs provided by CacheDiT to quickly evaluate the accuracy losses caused by different cache configurations. For example:
+ You can utilize the APIs provided by cache-dit to quickly evaluate the accuracy losses caused by different cache configurations. For example:
 
  ```python
  from cache_dit.metrics import compute_psnr
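The metrics example continues in the next hunk with the CLI form (`cache-dit-metrics-cli psnr -i1 true_dir -i2 test_dir`). A Python sketch mirroring that call is shown below; the directory-based call style of `compute_psnr` is an assumption inferred from the CLI flags, not a documented signature.

```python
from cache_dit.metrics import compute_psnr

# Assumed to mirror: cache-dit-metrics-cli psnr -i1 true_dir -i2 test_dir
psnr = compute_psnr("true_dir", "test_dir")  # hypothetical call style
print(f"PSNR: {psnr}")
```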
@@ -480,21 +363,21 @@ cache-dit-metrics-cli psnr -i1 true_dir -i2 test_dir # PSNR
  ## 👋Contribute
  <div id="contribute"></div>
 
- How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](https://github.com/vipshop/cache-dit/raw/main/CONTRIBUTE.md).
+ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](./CONTRIBUTE.md).
 
  ## ©️License
 
  <div id="license"></div>
 
- The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We have followed the original License from [FBCache](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
+ The **cache-dit** codebase is adapted from FBCache. Special thanks to their excellent work! We have followed the original License from FBCache; please check [LICENSE](./LICENSE) for more details.
 
  ## ©️Citations
 
  <div id="citations"></div>
 
  ```BibTeX
- @misc{CacheDiT@2025,
-   title={CacheDiT: A Training-free and Easy-to-use cache acceleration Toolbox for Diffusion Transformers},
+ @misc{cache-dit@2025,
+   title={cache-dit: A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers},
    url={https://github.com/vipshop/cache-dit.git},
    note={Open-source software available at https://github.com/vipshop/cache-dit.git},
    author={vipshop.com},
--- /dev/null
+++ cache_dit-0.2.17.dist-info/RECORD
@@ -0,0 +1,30 @@
+ cache_dit/__init__.py,sha256=gRJrSVrj-700qjgjwHfcHkiIHKbGm2cutP1TybxQZk4,605
+ cache_dit/_version.py,sha256=sRnPbdnyLakHrE7uBPRC_AQNPiFphtVIa4BPaftkqk4,706
+ cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
+ cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
+ cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
+ cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
+ cache_dit/cache_factory/__init__.py,sha256=2td8ivq0DDzu00Kq1oPvq0Bh5C76w_gwsMfyUo2xW9U,1652
+ cache_dit/cache_factory/cache_adapters.py,sha256=ECYRvgx6ePX6Jd6sqUXmXi6kbWaqlOdvm6aZLhpedW0,23455
+ cache_dit/cache_factory/cache_blocks.py,sha256=9jgK2IT0Y_AlbhJLnhgA47lOxQNwNizDgHve45818gg,18390
+ cache_dit/cache_factory/cache_context.py,sha256=f-ihx14NXIZNakN2b_dduegRpJr5SwcPtc2PqnpDdUY,39818
+ cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
+ cache_dit/cache_factory/utils.py,sha256=iQg3dqBfQTGkvMdKeO5-YmzkQO5LBSoZ8sYKwQA_7_I,1805
+ cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/cache_factory/patch/flux.py,sha256=eTdq-3limKHgwtVCILkZTwt9FwYUhH7_VlhKnfu55BU,8999
+ cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
+ cache_dit/compile/utils.py,sha256=ugHrv3QRieG1xKwcg_pi3yVZF6EpSOEJjRmbnfa7VG0,3779
+ cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
+ cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
+ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
+ cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
+ cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
+ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
+ cache_dit-0.2.17.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.17.dist-info/METADATA,sha256=HqEAEr08N7whWcxOMOVJKThQPglCW_GAj-LcynXmIDI,19804
+ cache_dit-0.2.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.17.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+ cache_dit-0.2.17.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.17.dist-info/RECORD,,
@@ -1,169 +0,0 @@
1
- from enum import Enum
2
-
3
- from diffusers import DiffusionPipeline
4
-
5
- from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
6
- apply_db_cache_on_pipe,
7
- )
8
- from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
9
- apply_fb_cache_on_pipe,
10
- )
11
- from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
12
- apply_db_prune_on_pipe,
13
- )
14
-
15
- from cache_dit.logger import init_logger
16
-
17
-
18
- logger = init_logger(__name__)
19
-
20
-
21
- class CacheType(Enum):
22
- NONE = "NONE"
23
- FBCache = "First_Block_Cache"
24
- DBCache = "Dual_Block_Cache"
25
- DBPrune = "Dynamic_Block_Prune"
26
-
27
- @staticmethod
28
- def type(cache_type: "CacheType | str") -> "CacheType":
29
- if isinstance(cache_type, CacheType):
30
- return cache_type
31
- return CacheType.cache_type(cache_type)
32
-
33
- @staticmethod
34
- def cache_type(cache_type: "CacheType | str") -> "CacheType":
35
- if cache_type is None:
36
- return CacheType.NONE
37
-
38
- if isinstance(cache_type, CacheType):
39
- return cache_type
40
- if cache_type.lower() in (
41
- "first_block_cache",
42
- "fb_cache",
43
- "fbcache",
44
- "fb",
45
- ):
46
- return CacheType.FBCache
47
- elif cache_type.lower() in (
48
- "dual_block_cache",
49
- "db_cache",
50
- "dbcache",
51
- "db",
52
- ):
53
- return CacheType.DBCache
54
- elif cache_type.lower() in (
55
- "dynamic_block_prune",
56
- "db_prune",
57
- "dbprune",
58
- "dbp",
59
- ):
60
- return CacheType.DBPrune
61
- elif cache_type.lower() in (
62
- "none_cache",
63
- "nonecache",
64
- "no_cache",
65
- "nocache",
66
- "none",
67
- "no",
68
- ):
69
- return CacheType.NONE
70
- else:
71
- raise ValueError(f"Unknown cache type: {cache_type}")
72
-
73
- @staticmethod
74
- def range(start: int, end: int, step: int = 1) -> list[int]:
75
- if start > end or end <= 0 or step <= 1:
76
- return []
77
- # Always compute 0 and end - 1 blocks for DB Cache
78
- return list(
79
- sorted(set([0] + list(range(start, end, step)) + [end - 1]))
80
- )
81
-
82
- @staticmethod
83
- def default_options(cache_type: "CacheType | str") -> dict:
84
- _no_options = {
85
- "cache_type": CacheType.NONE,
86
- }
87
-
88
- _fb_options = {
89
- "cache_type": CacheType.FBCache,
90
- "residual_diff_threshold": 0.08,
91
- "warmup_steps": 8,
92
- "max_cached_steps": 8,
93
- }
94
-
95
- _Fn_compute_blocks = 8
96
- _Bn_compute_blocks = 8
97
-
98
- _db_options = {
99
- "cache_type": CacheType.DBCache,
100
- "residual_diff_threshold": 0.12,
101
- "warmup_steps": 8,
102
- "max_cached_steps": -1, # -1 means no limit
103
- # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
104
- "Fn_compute_blocks": _Fn_compute_blocks,
105
- "Bn_compute_blocks": _Bn_compute_blocks,
106
- "max_Fn_compute_blocks": 16,
107
- "max_Bn_compute_blocks": 16,
108
- "Fn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
109
- "Bn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
110
- }
111
-
112
- _dbp_options = {
113
- "cache_type": CacheType.DBPrune,
114
- "residual_diff_threshold": 0.08,
115
- "Fn_compute_blocks": _Fn_compute_blocks,
116
- "Bn_compute_blocks": _Bn_compute_blocks,
117
- "warmup_steps": 8,
118
- "max_pruned_steps": -1, # -1 means no limit
119
- }
120
-
121
- if cache_type == CacheType.FBCache:
122
- return _fb_options
123
- elif cache_type == CacheType.DBCache:
124
- return _db_options
125
- elif cache_type == CacheType.DBPrune:
126
- return _dbp_options
127
- elif cache_type == CacheType.NONE:
128
- return _no_options
129
- else:
130
- raise ValueError(f"Unknown cache type: {cache_type}")
131
-
132
-
133
- def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
134
- assert isinstance(pipe, DiffusionPipeline)
135
-
136
- if hasattr(pipe, "_is_cached") and pipe._is_cached:
137
- return pipe
138
-
139
- if hasattr(pipe, "_is_pruned") and pipe._is_pruned:
140
- return pipe
141
-
142
- cache_type = kwargs.pop("cache_type", None)
143
- if cache_type is None:
144
- logger.warning(
145
- "No cache type specified, we will use DBCache by default. "
146
- "Please specify the cache_type explicitly if you want to "
147
- "use a different cache type."
148
- )
149
- # Force to use DBCache with default cache options
150
- return apply_db_cache_on_pipe(
151
- pipe,
152
- **CacheType.default_options(CacheType.DBCache),
153
- )
154
-
155
- cache_type = CacheType.type(cache_type)
156
-
157
- if cache_type == CacheType.FBCache:
158
- return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
159
- elif cache_type == CacheType.DBCache:
160
- return apply_db_cache_on_pipe(pipe, *args, **kwargs)
161
- elif cache_type == CacheType.DBPrune:
162
- return apply_db_prune_on_pipe(pipe, *args, **kwargs)
163
- elif cache_type == CacheType.NONE:
164
- logger.warning(
165
- f"Cache type is {cache_type}, no caching will be applied."
166
- )
167
- return pipe
168
- else:
169
- raise ValueError(f"Unknown cache type: {cache_type}")
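Since this release deletes `cache_dit/cache_factory/adapters.py` outright, downstream code importing `apply_cache_on_pipe` or `CacheType` from it will break on upgrade. A minimal migration sketch, mapping the removed 0.2.15 call onto the 0.2.17 API shown in the METADATA diff above:

```python
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# 0.2.15 (removed in this release):
#   from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
#   apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))
# 0.2.17 equivalent:
cache_dit.enable_cache(pipe, **cache_dit.default_options())
```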