cache-dit 0.1.2__tar.gz → 0.1.5__tar.gz

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (77)
  1. {cache_dit-0.1.2 → cache_dit-0.1.5}/PKG-INFO +48 -11
  2. {cache_dit-0.1.2 → cache_dit-0.1.5}/README.md +47 -10
  3. {cache_dit-0.1.2 → cache_dit-0.1.5}/bench/bench.py +2 -1
  4. cache_dit-0.1.5/examples/run_cogvideox.py +30 -0
  5. cache_dit-0.1.5/examples/run_mochi.py +25 -0
  6. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/_version.py +2 -2
  7. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit.egg-info/PKG-INFO +48 -11
  8. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit.egg-info/SOURCES.txt +2 -0
  9. {cache_dit-0.1.2 → cache_dit-0.1.5}/.github/workflows/issue.yml +0 -0
  10. {cache_dit-0.1.2 → cache_dit-0.1.5}/.gitignore +0 -0
  11. {cache_dit-0.1.2 → cache_dit-0.1.5}/.pre-commit-config.yaml +0 -0
  12. {cache_dit-0.1.2 → cache_dit-0.1.5}/CONTRIBUTE.md +0 -0
  13. {cache_dit-0.1.2 → cache_dit-0.1.5}/LICENSE +0 -0
  14. {cache_dit-0.1.2 → cache_dit-0.1.5}/MANIFEST.in +0 -0
  15. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
  16. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
  17. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
  18. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
  19. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
  20. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
  21. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
  22. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
  23. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
  24. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
  25. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
  26. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBCache.png +0 -0
  27. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
  28. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
  29. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
  30. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
  31. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
  32. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
  33. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
  34. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
  35. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
  36. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
  37. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
  38. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
  39. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/NONE_R0.08_S0.png +0 -0
  40. {cache_dit-0.1.2 → cache_dit-0.1.5}/assets/cache-dit.png +0 -0
  41. {cache_dit-0.1.2 → cache_dit-0.1.5}/bench/.gitignore +0 -0
  42. {cache_dit-0.1.2 → cache_dit-0.1.5}/docs/.gitignore +0 -0
  43. {cache_dit-0.1.2 → cache_dit-0.1.5}/examples/.gitignore +0 -0
  44. {cache_dit-0.1.2 → cache_dit-0.1.5}/examples/run_flux.py +0 -0
  45. {cache_dit-0.1.2 → cache_dit-0.1.5}/pyproject.toml +0 -0
  46. {cache_dit-0.1.2 → cache_dit-0.1.5}/pytest.ini +0 -0
  47. {cache_dit-0.1.2 → cache_dit-0.1.5}/requirements.txt +0 -0
  48. {cache_dit-0.1.2 → cache_dit-0.1.5}/setup.cfg +0 -0
  49. {cache_dit-0.1.2 → cache_dit-0.1.5}/setup.py +0 -0
  50. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/__init__.py +0 -0
  51. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/__init__.py +0 -0
  52. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  53. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -0
  54. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -0
  55. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
  56. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
  57. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
  58. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  59. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -0
  60. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
  61. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
  62. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
  63. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -0
  64. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  65. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -0
  66. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -0
  67. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
  68. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
  69. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
  70. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -0
  71. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/taylorseer.py +0 -0
  72. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/cache_factory/utils.py +0 -0
  73. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/logger.py +0 -0
  74. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/primitives.py +0 -0
  75. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit.egg-info/dependency_links.txt +0 -0
  76. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit.egg-info/requires.txt +0 -0
  77. {cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit.egg-info/top_level.txt +0 -0
{cache_dit-0.1.2 → cache_dit-0.1.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.1.2
+Version: 0.1.5
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -37,13 +37,14 @@ Dynamic: requires-python
 <p align="center">
 <h3>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h3>
 </p>
-<img src=https://github.com/vipshop/cache-dit/raw/dev/assets/cache-dit.png >
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
-<img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
-<img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
-<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
-<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.1.2-brightgreen.svg >
+<img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
+<img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
+<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
+<img src=https://static.pepy.tech/badge/cache-dit >
+<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
+<img src=https://img.shields.io/badge/Release-v0.1.5-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT provides <br>a series of training-free, UNet-style cache accelerators for DiT: DBCache, DBPrune, FBCache, etc.
@@ -166,7 +167,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the running steps exceed this value to prevent precision degradation.
+- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
 
 For a good balance between performance and precision, DBCache is configured by default with **F8B8**, 8 warmup steps, and unlimited cached steps.
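Note: the parameters listed above correspond one-to-one to `cache_options` keys. A minimal sketch of a custom DBCache setup, with key names inferred from the DBPrune example further down in this diff and from bench/bench.py (the FLUX.1 model id is an assumption, not part of this hunk):

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed model id for illustration
    torch_dtype=torch.bfloat16,
).to("cuda")

# Hypothetical custom DBCache options; key names mirror the DBPrune
# example in this release rather than confirmed DBCache documentation.
cache_options = {
    "cache_type": CacheType.DBCache,
    "residual_diff_threshold": 0.08,  # higher = faster, lower precision
    "Fn_compute_blocks": 8,  # F8: first 8 blocks fit the L1 diff
    "Bn_compute_blocks": 8,  # B8: last 8 blocks act as an auto-scaler
    "warmup_steps": 8,       # no caching during the first 8 steps
    "max_cached_steps": -1,  # -1 means unlimited cached steps
}

apply_cache_on_pipe(pipe, **cache_options)
```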
@@ -202,8 +203,9 @@ Moreover, users configuring higher **Bn** values (e.g., **F8B16**) while aiming
 cache_options = {
     # 0, 2, 4, ..., 14, 15, etc. [0,16)
     "Bn_compute_blocks_ids": CacheType.range(0, 16, 2),
-    # Skip Bn blocks (1, 3, 5 ,..., etc.) only if the L1 diff
-    # lower than this value, otherwise, compute it.
+    # If the L1 difference is below this threshold, skip Bn blocks
+    # not in `Bn_compute_blocks_ids`(1, 3,..., etc), Otherwise,
+    # compute these blocks.
     "non_compute_blocks_diff_threshold": 0.08,
 }
 ```
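Note: read together, the updated comment means blocks listed in `Bn_compute_blocks_ids` are always computed, while the remaining Bn blocks are computed only when the L1 diff exceeds the threshold. A hypothetical restatement of that rule (illustrative helper, not the library's internal API):

```python
# Hypothetical restatement of the Bn skip rule described in the
# comment above; `should_compute_bn_block` is not a cache-dit function.
def should_compute_bn_block(block_id: int, l1_diff: float,
                            compute_ids: list[int],
                            diff_threshold: float = 0.08) -> bool:
    if block_id in compute_ids:
        return True  # blocks in Bn_compute_blocks_ids are always computed
    # Other Bn blocks are skipped only while the L1 diff stays small.
    return l1_diff >= diff_threshold
```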
@@ -259,12 +261,47 @@ pipe = FluxPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to("cuda")
 
-# Using DBPrune
+# Using DBPrune with default options
 cache_options = CacheType.default_options(CacheType.DBPrune)
 
 apply_cache_on_pipe(pipe, **cache_options)
 ```
 
+We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_blocks_ids** to avoid aggressive pruning. For example:
+
+```python
+# Custom options for DBPrune
+cache_options = {
+    "cache_type": CacheType.DBPrune,
+    "residual_diff_threshold": 0.05,
+    # Never prune the first `Fn` and last `Bn` blocks.
+    "Fn_compute_blocks": 8,  # default 1
+    "Bn_compute_blocks": 8,  # default 0
+    "warmup_steps": 8,  # default -1
+    # Disables the pruning strategy when the previous
+    # pruned steps greater than this value.
+    "max_pruned_steps": 12,  # default, -1 means no limit
+    # Enable dynamic prune threshold within step, higher
+    # `max_dynamic_prune_threshold` value may introduce a more
+    # ageressive pruning strategy.
+    "enable_dynamic_prune_threshold": True,
+    "max_dynamic_prune_threshold": 2 * 0.05,
+    # (New thresh) = mean(previous_block_diffs_within_step) * 1.25
+    # (New thresh) = ((New thresh) if (New thresh) <
+    # max_dynamic_prune_threshold else residual_diff_threshold)
+    "dynamic_prune_threshold_relax_ratio": 1.25,
+    # The step interval to update residual cache. For example,
+    # 2: means the update steps will be [0, 2, 4, ...].
+    "residual_cache_update_interval": 1,
+    # You can set non-prune blocks to avoid ageressive pruning.
+    # For example, FLUX.1 has 19 + 38 blocks, so we can set it
+    # to 0, 2, 4, ..., 56, etc.
+    "non_prune_blocks_ids": [],
+}
+
+apply_cache_on_pipe(pipe, **cache_options)
+```
+
 <div align="center">
 <p align="center">
 DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
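Note: the two `(New thresh)` comment lines in the hunk above define how the dynamic prune threshold is derived each step. A sketch of that computation exactly as the comments describe it (an illustration, not the actual prune_context.py code):

```python
# Hypothetical restatement of the dynamic-threshold comments above;
# this helper does not exist in cache-dit under this name.
def dynamic_prune_threshold(prev_block_diffs: list[float],
                            residual_diff_threshold: float = 0.05,
                            relax_ratio: float = 1.25,
                            max_threshold: float = 2 * 0.05) -> float:
    if not prev_block_diffs:  # no diffs observed yet this step
        return residual_diff_threshold
    # (New thresh) = mean(previous_block_diffs_within_step) * relax_ratio
    new_thresh = sum(prev_block_diffs) / len(prev_block_diffs) * relax_ratio
    # Fall back to the static threshold if the relaxed value overshoots.
    return new_thresh if new_thresh < max_threshold else residual_diff_threshold
```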
{cache_dit-0.1.2 → cache_dit-0.1.5}/README.md
@@ -2,13 +2,14 @@
 <p align="center">
 <h3>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h3>
 </p>
-<img src=https://github.com/vipshop/cache-dit/raw/dev/assets/cache-dit.png >
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
-<img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
-<img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
-<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
-<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.1.2-brightgreen.svg >
+<img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
+<img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
+<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
+<img src=https://static.pepy.tech/badge/cache-dit >
+<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
+<img src=https://img.shields.io/badge/Release-v0.1.5-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT provides <br>a series of training-free, UNet-style cache accelerators for DiT: DBCache, DBPrune, FBCache, etc.
@@ -131,7 +132,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the running steps exceed this value to prevent precision degradation.
+- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
 
 For a good balance between performance and precision, DBCache is configured by default with **F8B8**, 8 warmup steps, and unlimited cached steps.
@@ -167,8 +168,9 @@ Moreover, users configuring higher **Bn** values (e.g., **F8B16**) while aiming
 cache_options = {
     # 0, 2, 4, ..., 14, 15, etc. [0,16)
     "Bn_compute_blocks_ids": CacheType.range(0, 16, 2),
-    # Skip Bn blocks (1, 3, 5 ,..., etc.) only if the L1 diff
-    # lower than this value, otherwise, compute it.
+    # If the L1 difference is below this threshold, skip Bn blocks
+    # not in `Bn_compute_blocks_ids`(1, 3,..., etc), Otherwise,
+    # compute these blocks.
     "non_compute_blocks_diff_threshold": 0.08,
 }
 ```
@@ -224,12 +226,47 @@ pipe = FluxPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to("cuda")
 
-# Using DBPrune
+# Using DBPrune with default options
 cache_options = CacheType.default_options(CacheType.DBPrune)
 
 apply_cache_on_pipe(pipe, **cache_options)
 ```
 
+We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_blocks_ids** to avoid aggressive pruning. For example:
+
+```python
+# Custom options for DBPrune
+cache_options = {
+    "cache_type": CacheType.DBPrune,
+    "residual_diff_threshold": 0.05,
+    # Never prune the first `Fn` and last `Bn` blocks.
+    "Fn_compute_blocks": 8,  # default 1
+    "Bn_compute_blocks": 8,  # default 0
+    "warmup_steps": 8,  # default -1
+    # Disables the pruning strategy when the previous
+    # pruned steps greater than this value.
+    "max_pruned_steps": 12,  # default, -1 means no limit
+    # Enable dynamic prune threshold within step, higher
+    # `max_dynamic_prune_threshold` value may introduce a more
+    # ageressive pruning strategy.
+    "enable_dynamic_prune_threshold": True,
+    "max_dynamic_prune_threshold": 2 * 0.05,
+    # (New thresh) = mean(previous_block_diffs_within_step) * 1.25
+    # (New thresh) = ((New thresh) if (New thresh) <
+    # max_dynamic_prune_threshold else residual_diff_threshold)
+    "dynamic_prune_threshold_relax_ratio": 1.25,
+    # The step interval to update residual cache. For example,
+    # 2: means the update steps will be [0, 2, 4, ...].
+    "residual_cache_update_interval": 1,
+    # You can set non-prune blocks to avoid ageressive pruning.
+    # For example, FLUX.1 has 19 + 38 blocks, so we can set it
+    # to 0, 2, 4, ..., 56, etc.
+    "non_prune_blocks_ids": [],
+}
+
+apply_cache_on_pipe(pipe, **cache_options)
+```
+
 <div align="center">
 <p align="center">
 DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
{cache_dit-0.1.2 → cache_dit-0.1.5}/bench/bench.py
@@ -25,6 +25,7 @@ def get_args() -> argparse.ArgumentParser:
     parser.add_argument("--Bn-steps", "--BnS", type=int, default=1)
     parser.add_argument("--warmup-steps", type=int, default=0)
     parser.add_argument("--max-cached-steps", type=int, default=-1)
+    parser.add_argument("--max-pruned-steps", type=int, default=-1)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument(
         "--compile",
@@ -79,7 +80,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
         "Fn_compute_blocks": args.Fn_compute_blocks,
         "Bn_compute_blocks": args.Bn_compute_blocks,
         "warmup_steps": args.warmup_steps,
-        "max_pruned_steps": args.max_cached_steps,  # -1 means no limit
+        "max_pruned_steps": args.max_pruned_steps,  # -1 means no limit
         # releative token diff threshold, default is 0.0
         "important_condition_threshold": 0.00,
         "enable_dynamic_prune_threshold": (
cache_dit-0.1.5/examples/run_cogvideox.py (new file)
@@ -0,0 +1,30 @@
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.utils import export_to_video
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = CogVideoXPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
+video = pipe(
+    prompt=prompt,
+    num_videos_per_prompt=1,
+    num_inference_steps=50,
+    num_frames=49,
+    guidance_scale=6,
+    generator=torch.Generator("cuda").manual_seed(0),
+).frames[0]
+
+print("Saving video to cogvideox.mp4")
+export_to_video(video, "cogvideox.mp4", fps=8)
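Note: the new example wires DBCache into CogVideoX. Since this release also ships a CogVideoX adapter under dynamic_block_prune/diffusers_adapters (see the file list above), swapping in DBPrune should be a one-line change to the same script; an untested sketch:

```python
# Hypothetical variant of run_cogvideox.py: use DBPrune defaults
# instead of DBCache (assumes the same `pipe` as in the example above).
cache_options = CacheType.default_options(CacheType.DBPrune)
apply_cache_on_pipe(pipe, **cache_options)
```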
cache_dit-0.1.5/examples/run_mochi.py (new file)
@@ -0,0 +1,25 @@
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
+
+pipe = MochiPipeline.from_pretrained(
+    "genmo/mochi-1-preview",
+    torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B8, good balance between performance and precision
+cache_options = CacheType.default_options(CacheType.DBCache)
+
+apply_cache_on_pipe(pipe, **cache_options)
+
+pipe.enable_vae_tiling()
+
+prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+video = pipe(
+    prompt,
+    num_frames=84,
+).frames[0]
+
+print("Saving video to mochi.mp4")
+export_to_video(video, "mochi.mp4", fps=30)
{cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit/_version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.1.2'
-__version_tuple__ = version_tuple = (0, 1, 2)
+__version__ = version = '0.1.5'
+__version_tuple__ = version_tuple = (0, 1, 5)
{cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.1.2
+Version: 0.1.5
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -37,13 +37,14 @@ Dynamic: requires-python
 <p align="center">
 <h3>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h3>
 </p>
-<img src=https://github.com/vipshop/cache-dit/raw/dev/assets/cache-dit.png >
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
 <div align='center'>
-<img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
-<img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
-<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
-<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.1.2-brightgreen.svg >
+<img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
+<img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
+<img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
+<img src=https://static.pepy.tech/badge/cache-dit >
+<img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
+<img src=https://img.shields.io/badge/Release-v0.1.5-brightgreen.svg >
 </div>
 <p align="center">
 DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT provides <br>a series of training-free, UNet-style cache accelerators for DiT: DBCache, DBPrune, FBCache, etc.
@@ -166,7 +167,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the running steps exceed this value to prevent precision degradation.
+- **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
 
 For a good balance between performance and precision, DBCache is configured by default with **F8B8**, 8 warmup steps, and unlimited cached steps.
@@ -202,8 +203,9 @@ Moreover, users configuring higher **Bn** values (e.g., **F8B16**) while aiming
 cache_options = {
     # 0, 2, 4, ..., 14, 15, etc. [0,16)
     "Bn_compute_blocks_ids": CacheType.range(0, 16, 2),
-    # Skip Bn blocks (1, 3, 5 ,..., etc.) only if the L1 diff
-    # lower than this value, otherwise, compute it.
+    # If the L1 difference is below this threshold, skip Bn blocks
+    # not in `Bn_compute_blocks_ids`(1, 3,..., etc), Otherwise,
+    # compute these blocks.
     "non_compute_blocks_diff_threshold": 0.08,
 }
 ```
@@ -259,12 +261,47 @@ pipe = FluxPipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
 ).to("cuda")
 
-# Using DBPrune
+# Using DBPrune with default options
 cache_options = CacheType.default_options(CacheType.DBPrune)
 
 apply_cache_on_pipe(pipe, **cache_options)
 ```
 
+We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_blocks_ids** to avoid aggressive pruning. For example:
+
+```python
+# Custom options for DBPrune
+cache_options = {
+    "cache_type": CacheType.DBPrune,
+    "residual_diff_threshold": 0.05,
+    # Never prune the first `Fn` and last `Bn` blocks.
+    "Fn_compute_blocks": 8,  # default 1
+    "Bn_compute_blocks": 8,  # default 0
+    "warmup_steps": 8,  # default -1
+    # Disables the pruning strategy when the previous
+    # pruned steps greater than this value.
+    "max_pruned_steps": 12,  # default, -1 means no limit
+    # Enable dynamic prune threshold within step, higher
+    # `max_dynamic_prune_threshold` value may introduce a more
+    # ageressive pruning strategy.
+    "enable_dynamic_prune_threshold": True,
+    "max_dynamic_prune_threshold": 2 * 0.05,
+    # (New thresh) = mean(previous_block_diffs_within_step) * 1.25
+    # (New thresh) = ((New thresh) if (New thresh) <
+    # max_dynamic_prune_threshold else residual_diff_threshold)
+    "dynamic_prune_threshold_relax_ratio": 1.25,
+    # The step interval to update residual cache. For example,
+    # 2: means the update steps will be [0, 2, 4, ...].
+    "residual_cache_update_interval": 1,
+    # You can set non-prune blocks to avoid ageressive pruning.
+    # For example, FLUX.1 has 19 + 38 blocks, so we can set it
+    # to 0, 2, 4, ..., 56, etc.
+    "non_prune_blocks_ids": [],
+}
+
+apply_cache_on_pipe(pipe, **cache_options)
+```
+
 <div align="center">
 <p align="center">
 DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
{cache_dit-0.1.2 → cache_dit-0.1.5}/src/cache_dit.egg-info/SOURCES.txt
@@ -40,7 +40,9 @@ bench/.gitignore
 bench/bench.py
 docs/.gitignore
 examples/.gitignore
+examples/run_cogvideox.py
 examples/run_flux.py
+examples/run_mochi.py
 src/cache_dit/__init__.py
 src/cache_dit/_version.py
 src/cache_dit/logger.py