cache-dit 0.2.2__tar.gz → 0.2.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of cache-dit has been flagged as a potentially problematic release.

Files changed (136)
  1. {cache_dit-0.2.2 → cache_dit-0.2.3}/PKG-INFO +6 -2
  2. {cache_dit-0.2.2 → cache_dit-0.2.3}/README.md +5 -1
  3. {cache_dit-0.2.2 → cache_dit-0.2.3}/bench/bench.py +15 -16
  4. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/_version.py +2 -2
  5. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/cache_context.py +0 -11
  6. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +1 -1
  7. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py +23 -23
  8. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/cache_context.py +0 -11
  9. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/taylorseer.py +1 -2
  10. cache_dit-0.2.3/src/cache_dit/compile/__init__.py +1 -0
  11. cache_dit-0.2.3/src/cache_dit/compile/utils.py +94 -0
  12. cache_dit-0.2.3/src/cache_dit/custom_ops/__init__.py +0 -0
  13. cache_dit-0.2.3/src/cache_dit/custom_ops/triton_taylorseer.py +0 -0
  14. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/logger.py +28 -0
  15. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit.egg-info/PKG-INFO +6 -2
  16. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit.egg-info/SOURCES.txt +4 -2
  17. cache_dit-0.2.2/assets/DBCache.png +0 -0
  18. cache_dit-0.2.2/assets/cache-dit.png +0 -0
  19. {cache_dit-0.2.2 → cache_dit-0.2.3}/.github/workflows/issue.yml +0 -0
  20. {cache_dit-0.2.2 → cache_dit-0.2.3}/.gitignore +0 -0
  21. {cache_dit-0.2.2 → cache_dit-0.2.3}/.pre-commit-config.yaml +0 -0
  22. {cache_dit-0.2.2 → cache_dit-0.2.3}/CONTRIBUTE.md +0 -0
  23. {cache_dit-0.2.2 → cache_dit-0.2.3}/LICENSE +0 -0
  24. {cache_dit-0.2.2 → cache_dit-0.2.3}/MANIFEST.in +0 -0
  25. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F12B12S4_R0.2_S16.png +0 -0
  26. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F12B16S4_R0.08_S6.png +0 -0
  27. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F16B16S2_R0.2_S14.png +0 -0
  28. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F16B16S4_R0.2_S13.png +0 -0
  29. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F1B0S1_R0.08_S11.png +0 -0
  30. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F1B0S1_R0.2_S19.png +0 -0
  31. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F8B0S2_R0.12_S12.png +0 -0
  32. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F8B16S1_R0.2_S18.png +0 -0
  33. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F8B8S1_R0.08_S9.png +0 -0
  34. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F8B8S1_R0.12_S12.png +0 -0
  35. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBCACHE_F8B8S1_R0.15_S15.png +0 -0
  36. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png +0 -0
  37. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png +0 -0
  38. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png +0 -0
  39. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png +0 -0
  40. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.07_P52.3_T12.53s.png +0 -0
  41. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.08_P52.4_T12.52s.png +0 -0
  42. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.09_P59.2_T10.81s.png +0 -0
  43. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.12_P59.5_T10.76s.png +0 -0
  44. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.12_P63.0_T9.90s.png +0 -0
  45. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.1_P62.8_T9.95s.png +0 -0
  46. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png +0 -0
  47. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/DBPRUNE_F1B0_R0.3_P63.1_T9.79s.png +0 -0
  48. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/NONE_R0.08_S0.png +0 -0
  49. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/TEXTURE_DBCACHE_F1B0_R0.08.png +0 -0
  50. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/TEXTURE_DBCACHE_F8B12_R0.12.png +0 -0
  51. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/TEXTURE_DBCACHE_F8B16_R0.2.png +0 -0
  52. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/TEXTURE_DBCACHE_F8B20_R0.2.png +0 -0
  53. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/TEXTURE_DBCACHE_F8B8_R0.12.png +0 -0
  54. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/TEXTURE_NONE_R0.08.png +0 -0
  55. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png +0 -0
  56. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png +0 -0
  57. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png +0 -0
  58. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png +0 -0
  59. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png +0 -0
  60. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png +0 -0
  61. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png +0 -0
  62. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png +0 -0
  63. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png +0 -0
  64. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBPRUNE_F1B0_R0.05_P41.6_T12.70s.png +0 -0
  65. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png +0 -0
  66. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_DBPRUNE_F8B8_R0.08_P23.1_T16.14s.png +0 -0
  67. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U0_C1_NONE_R0.08_S0_T20.43s.png +0 -0
  68. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.62s.png +0 -0
  69. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.03_P27.3_T6.63s.png +0 -0
  70. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.81s.png +0 -0
  71. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.045_P38.2_T5.82s.png +0 -0
  72. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.06s.png +0 -0
  73. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.07s.png +0 -0
  74. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.04_P34.6_T6.08s.png +0 -0
  75. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.27s.png +0 -0
  76. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.055_P45.1_T5.28s.png +0 -0
  77. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.95s.png +0 -0
  78. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.96s.png +0 -0
  79. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_NONE_R0.08_S0_T7.78s.png +0 -0
  80. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/U4_C1_NONE_R0.08_S0_T7.79s.png +0 -0
  81. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/cache-dit-v1.png +0 -0
  82. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/dbcache-fnbn-v1.png +0 -0
  83. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/dbcache-v1.png +0 -0
  84. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/dbprune-v1.png +0 -0
  85. {cache_dit-0.2.2 → cache_dit-0.2.3}/assets/fbcache-v1.png +0 -0
  86. {cache_dit-0.2.2 → cache_dit-0.2.3}/bench/.gitignore +0 -0
  87. {cache_dit-0.2.2 → cache_dit-0.2.3}/docs/.gitignore +0 -0
  88. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/.gitignore +0 -0
  89. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/README.md +0 -0
  90. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/data/cup.png +0 -0
  91. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/data/cup_mask.png +0 -0
  92. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/requirements.txt +0 -0
  93. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/run_cogvideox.py +0 -0
  94. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/run_flux.py +0 -0
  95. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/run_flux_fill.py +0 -0
  96. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/run_hunyuan_video.py +0 -0
  97. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/run_mochi.py +0 -0
  98. {cache_dit-0.2.2 → cache_dit-0.2.3}/examples/run_wan.py +0 -0
  99. {cache_dit-0.2.2 → cache_dit-0.2.3}/pyproject.toml +0 -0
  100. {cache_dit-0.2.2 → cache_dit-0.2.3}/pytest.ini +0 -0
  101. {cache_dit-0.2.2 → cache_dit-0.2.3}/requirements.txt +0 -0
  102. {cache_dit-0.2.2 → cache_dit-0.2.3}/setup.cfg +0 -0
  103. {cache_dit-0.2.2 → cache_dit-0.2.3}/setup.py +0 -0
  104. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/__init__.py +0 -0
  105. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/__init__.py +0 -0
  106. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  107. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -0
  108. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -0
  109. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -0
  110. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -0
  111. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -0
  112. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  113. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -0
  114. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -0
  115. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -0
  116. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -0
  117. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -0
  118. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -0
  119. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  120. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -0
  121. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -0
  122. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -0
  123. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -0
  124. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -0
  125. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -0
  126. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/cache_factory/utils.py +0 -0
  127. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit/primitives.py +0 -0
  128. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit.egg-info/dependency_links.txt +0 -0
  129. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit.egg-info/requires.txt +0 -0
  130. {cache_dit-0.2.2 → cache_dit-0.2.3}/src/cache_dit.egg-info/top_level.txt +0 -0
  131. {cache_dit-0.2.2 → cache_dit-0.2.3}/tests/.gitignore +0 -0
  132. {cache_dit-0.2.2 → cache_dit-0.2.3}/tests/README.md +0 -0
  133. {cache_dit-0.2.2 → cache_dit-0.2.3}/tests/taylorseer_approximation_order_2.png +0 -0
  134. {cache_dit-0.2.2 → cache_dit-0.2.3}/tests/taylorseer_approximation_order_4.png +0 -0
  135. {cache_dit-0.2.2 → cache_dit-0.2.3}/tests/taylorseer_approximation_test.png +0 -0
  136. {cache_dit-0.2.2 → cache_dit-0.2.3}/tests/test_taylorseer.py +0 -0

PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.2
+ Version: 0.2.3
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc

@@ -283,7 +283,11 @@ cache_options = {
  "warmup_steps": 3, # n_derivatives + 1
  "residual_diff_threshold": 0.12,
  }
- ```
+ ```
+
+ > [!Important]
+ > Please note that if you have used TaylorSeer as the calibrator for approximate hidden states, the **Bn** param of DBCache can be set to **0**. In essence, DBCache's Bn is also act as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend using the configuration scheme of **TaylorSeer** + **DBCache FnB0**.
+
  <div align="center">
  <p align="center">
  <b>DBCache F1B0 + TaylorSeer</b>, L20x1, Steps: 28, <br>"A cat holding a sign that says hello world with complex background"

README.md

@@ -248,7 +248,11 @@ cache_options = {
  "warmup_steps": 3, # n_derivatives + 1
  "residual_diff_threshold": 0.12,
  }
- ```
+ ```
+
+ > [!Important]
+ > Please note that if you have used TaylorSeer as the calibrator for approximate hidden states, the **Bn** param of DBCache can be set to **0**. In essence, DBCache's Bn is also act as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend using the configuration scheme of **TaylorSeer** + **DBCache FnB0**.
+
  <div align="center">
  <p align="center">
  <b>DBCache F1B0 + TaylorSeer</b>, L20x1, Steps: 28, <br>"A cat holding a sign that says hello world with complex background"
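
The README note added in this release recommends pairing TaylorSeer with a DBCache FnB0 configuration. As a rough illustration (not part of this diff), a minimal sketch of such a setup is shown below; it assumes the `apply_cache_on_pipe` entry point and option names used elsewhere in the cache-dit README, and the exact values are illustrative only:

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import CacheType, apply_cache_on_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# TaylorSeer acts as the calibrator, so Bn is set to 0 (DBCache F1B0).
cache_options = {
    "cache_type": CacheType.DBCache,
    "Fn_compute_blocks": 1,        # F1: compute the first block exactly
    "Bn_compute_blocks": 0,        # B0: rely on TaylorSeer instead of Bn blocks
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,
    "taylorseer_cache_type": "residual",
    "warmup_steps": 3,             # n_derivatives + 1, as in the snippet above
    "residual_diff_threshold": 0.12,
}
apply_cache_on_pipe(pipe, **cache_options)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
```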

bench/bench.py

@@ -21,9 +21,6 @@ def get_args() -> argparse.ArgumentParser:
  parser.add_argument("--alter", action="store_true", default=False)
  parser.add_argument("--taylorseer", action="store_true", default=False)
  parser.add_argument("--taylorseer-order", "--order", type=int, default=2)
- parser.add_argument(
- "--encoder-taylorseer", action="store_true", default=False
- )
  parser.add_argument("--l1-diff", action="store_true", default=False)
  parser.add_argument("--rdt", type=float, default=0.08)
  parser.add_argument("--Fn-compute-blocks", "--Fn", type=int, default=1)

@@ -32,9 +29,15 @@ def get_args() -> argparse.ArgumentParser:
  parser.add_argument("--warmup-steps", type=int, default=0)
  parser.add_argument("--max-cached-steps", type=int, default=-1)
  parser.add_argument("--max-pruned-steps", type=int, default=-1)
+ parser.add_argument("--gen-device", type=str, default="cuda")
  parser.add_argument("--ulysses", type=int, default=None)
  parser.add_argument("--compile", action="store_true", default=False)
- parser.add_argument("--gen-device", type=str, default="cuda")
+ parser.add_argument(
+ "--force-compile-all",
+ "--compile-all",
+ action="store_true",
+ default=False,
+ )
  return parser.parse_args()

@@ -52,12 +55,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
  elif cache_type == CacheType.DBCache:
  cache_options = {
  "cache_type": CacheType.DBCache,
- "warmup_steps": (
- # TaylorSeer needs at least order + 1 warmup steps
- max(args.warmup_steps, args.taylorseer_order + 1)
- if (args.taylorseer or args.encoder_taylorseer)
- else args.warmup_steps
- ),
+ "warmup_steps": args.warmup_steps,
  "max_cached_steps": args.max_cached_steps, # -1 means no limit
  # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
  "Fn_compute_blocks": args.Fn_compute_blocks, # Fn, F8, etc.

@@ -81,7 +79,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
  "important_condition_threshold": 0.00,
  # TaylorSeer options
  "enable_taylorseer": args.taylorseer,
- "enable_encoder_taylorseer": args.encoder_taylorseer,
+ "enable_encoder_taylorseer": args.taylorseer,
  # Taylorseer cache type cache be hidden_states or residual
  "taylorseer_cache_type": "residual",
  "taylorseer_kwargs": {

@@ -90,7 +88,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
  }
  elif cache_type == CacheType.DBPrune:
  assert (
- args.taylorseer is False and args.encoder_taylorseer is False
+ args.taylorseer is False
  ), "DBPrune does not support TaylorSeer yet."
  cache_options = {
  "cache_type": CacheType.DBPrune,

@@ -122,7 +120,6 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
  f"{cache_type_str}_F{args.Fn_compute_blocks}"
  f"B{args.Bn_compute_blocks}S{args.Bn_steps}"
  f"W{args.warmup_steps}T{int(args.taylorseer)}"
- f"ET{int(args.encoder_taylorseer)}"
  f"O{args.taylorseer_order}"
  )
  elif cache_type == CacheType.DBPrune:

@@ -132,7 +129,7 @@ def get_cache_options(cache_type: CacheType, args: argparse.Namespace):
  )
  elif cache_type == CacheType.FBCache:
  cache_type_str = (
- f"{cache_type_str}_W{args.warmup_steps}" f"T{int(args.taylorseer)}"
+ f"{cache_type_str}_W{args.warmup_steps}T{int(args.taylorseer)}"
  )
  return cache_options, cache_type_str

@@ -201,8 +198,10 @@ def main():
  "Only compile transformer blocks not the whole model "
  "for FluxTransformer2DModel to keep higher precision."
  )
- if args.taylorseer_order <= 2 or (
- not args.taylorseer and not args.encoder_taylorseer
+ if (
+ args.taylorseer_order <= 2
+ or not args.taylorseer
+ or args.force_compile_all
  ):
  # NOTE: Seems like compiling the whole transformer
  # will cause precision issues while using TaylorSeer
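
In 0.2.3 the bench script ties the encoder TaylorSeer to the same `--taylorseer` flag and no longer bumps `warmup_steps` automatically. A hedged sketch (not part of bench.py) of how a caller could keep the old behaviour when building their own options:

```python
# Hypothetical helper mirroring the warmup logic removed from bench.py in 0.2.3.
def resolve_warmup_steps(warmup_steps: int, taylorseer: bool, order: int) -> int:
    # TaylorSeer needs at least order + 1 warmup steps (per the removed comment).
    return max(warmup_steps, order + 1) if taylorseer else warmup_steps


cache_options = {
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,   # now follows the same flag as --taylorseer
    "warmup_steps": resolve_warmup_steps(0, True, 2),  # -> 3
}
```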

src/cache_dit/_version.py

@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.2'
- __version_tuple__ = version_tuple = (0, 2, 2)
+ __version__ = version = '0.2.3'
+ __version_tuple__ = version_tuple = (0, 2, 3)

src/cache_dit/cache_factory/dual_block_cache/cache_context.py

@@ -1598,17 +1598,6 @@ def patch_cached_stats(
  if transformer is None:
  return

- cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
- if cached_transformer_blocks is None:
- return
-
- if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
- cached_transformer_blocks = cached_transformer_blocks[0]
- if not isinstance(
- cached_transformer_blocks, DBCachedTransformerBlocks
- ) or not isinstance(transformer, torch.nn.Module):
- return
-
  # TODO: Patch more cached stats to the transformer
  transformer._cached_steps = get_cached_steps()
  transformer._residual_diffs = get_residual_diffs()

src/cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py

@@ -92,6 +92,6 @@ def apply_db_cache_on_pipe(
  pipe.__class__._is_cached = True

  if not shallow_patch:
- apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+ apply_db_cache_on_transformer(pipe.transformer)

  return pipe

src/cache_dit/cache_factory/dynamic_block_prune/prune_context.py

@@ -55,6 +55,7 @@ class DBPPruneContext:
  default_factory=lambda: defaultdict(list),
  )

+ @torch.compiler.disable
  def get_residual_diff_threshold(self):
  residual_diff_threshold = self.residual_diff_threshold
  if self.l1_hidden_states_diff_threshold is not None:

@@ -98,19 +99,24 @@ class DBPPruneContext:
  )
  return residual_diff_threshold

+ @torch.compiler.disable
  def get_buffer(self, name):
  return self.buffers.get(name)

+ @torch.compiler.disable
  def set_buffer(self, name, buffer):
  self.buffers[name] = buffer

+ @torch.compiler.disable
  def remove_buffer(self, name):
  if name in self.buffers:
  del self.buffers[name]

+ @torch.compiler.disable
  def clear_buffers(self):
  self.buffers.clear()

+ @torch.compiler.disable
  def mark_step_begin(self):
  self.executed_steps += 1
  if self.get_current_step() == 0:

@@ -118,12 +124,15 @@ class DBPPruneContext:
  self.actual_blocks.clear()
  self.residual_diffs.clear()

+ @torch.compiler.disable
  def add_pruned_block(self, num_blocks):
  self.pruned_blocks.append(num_blocks)

+ @torch.compiler.disable
  def add_actual_block(self, num_blocks):
  self.actual_blocks.append(num_blocks)

+ @torch.compiler.disable
  def add_residual_diff(self, diff):
  if isinstance(diff, torch.Tensor):
  diff = diff.item()

@@ -141,9 +150,11 @@ class DBPPruneContext:
  f"residual diff: {diff:.6f}"
  )

+ @torch.compiler.disable
  def get_current_step(self):
  return self.executed_steps - 1

+ @torch.compiler.disable
  def is_in_warmup(self):
  return self.get_current_step() < self.warmup_steps

@@ -348,11 +359,19 @@ def collect_prune_kwargs(default_attrs: dict, **kwargs):
  )
  for attr in prune_attrs
  }
+
  # Manually set sequence fields, such as non_prune_blocks_ids
- prune_kwargs["non_prune_blocks_ids"] = kwargs.pop(
- "non_prune_blocks_ids",
- [],
- )
+ def _safe_set_sequence_field(
+ field_name: str,
+ default_value: Any = None,
+ ):
+ if field_name not in prune_kwargs:
+ prune_kwargs[field_name] = kwargs.pop(
+ field_name,
+ default_value,
+ )
+
+ _safe_set_sequence_field("non_prune_blocks_ids", [])

  assert default_attrs is not None, "default_attrs must be set before"
  for attr in prune_attrs:

@@ -627,10 +646,6 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  ]
  return sorted(non_prune_blocks_ids)

- # @torch.compile(dynamic=True)
- # mark this function as compile with dynamic=True will
- # cause precision degradate, so, we choose to disable it
- # now, until we find a better solution or fixed the bug.
  @torch.compiler.disable
  def _compute_single_hidden_states_residual(
  self,

@@ -667,10 +682,6 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  single_encoder_hidden_states_residual,
  )

- # @torch.compile(dynamic=True)
- # mark this function as compile with dynamic=True will
- # cause precision degradate, so, we choose to disable it
- # now, until we find a better solution or fixed the bug.
  @torch.compiler.disable
  def _split_single_hidden_states(
  self,

@@ -969,17 +980,6 @@ def patch_pruned_stats(
  if transformer is None:
  return

- pruned_transformer_blocks = getattr(transformer, "transformer_blocks", None)
- if pruned_transformer_blocks is None:
- return
-
- if isinstance(pruned_transformer_blocks, torch.nn.ModuleList):
- pruned_transformer_blocks = pruned_transformer_blocks[0]
- if not isinstance(
- pruned_transformer_blocks, DBPrunedTransformerBlocks
- ) or not isinstance(transformer, torch.nn.Module):
- return
-
  # TODO: Patch more pruned stats to the transformer
  transformer._pruned_blocks = get_pruned_blocks()
  transformer._pruned_steps = get_pruned_steps()
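
The `@torch.compiler.disable` decorators added above keep the prune context's small bookkeeping methods out of the `torch.compile` graph. A minimal, standalone illustration of the same pattern (the class below is a hypothetical stand-in, not part of cache-dit):

```python
import torch


class StepTracker:
    """Hypothetical stand-in for DBPPruneContext's bookkeeping."""

    def __init__(self):
        self.executed_steps = 0

    @torch.compiler.disable
    def mark_step_begin(self):
        # Runs eagerly even when the surrounding model is compiled,
        # so dynamo does not trace or specialize on this Python-side state.
        self.executed_steps += 1
```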

src/cache_dit/cache_factory/first_block_cache/cache_context.py

@@ -714,17 +714,6 @@ def patch_cached_stats(
  if transformer is None:
  return

- cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
- if cached_transformer_blocks is None:
- return
-
- if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
- cached_transformer_blocks = cached_transformer_blocks[0]
- if not isinstance(
- cached_transformer_blocks, CachedTransformerBlocks
- ) or not isinstance(transformer, torch.nn.Module):
- return
-
  # TODO: Patch more cached stats to the transformer
  transformer._cached_steps = get_cached_steps()
  transformer._residual_diffs = get_residual_diffs()

src/cache_dit/cache_factory/taylorseer.py

@@ -1,4 +1,5 @@
  # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/taylorseer.py
+ # Reference: https://github.com/Shenyi-Z/TaylorSeer/TaylorSeer-FLUX/src/flux/taylor_utils/__init__.py
  import math
  import torch

@@ -44,8 +45,6 @@ class TaylorSeer:
  # n-th order Taylor expansion:
  # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
  # + ... + d^nY(0)/dt^n * t^n / n!
- # reference: https://github.com/Shenyi-Z/TaylorSeer
- # TaylorSeer-FLUX/src/flux/taylor_utils/__init__.py
  # TODO: Custom Triton/CUDA kernel for better performance,
  # especially for large n_derivatives.
  dY_current = [None] * self.ORDER

src/cache_dit/compile/__init__.py (new file)

@@ -0,0 +1 @@
+ from cache_dit.compile.utils import set_custom_compile_configs

src/cache_dit/compile/utils.py (new file)

@@ -0,0 +1,94 @@
+ import os
+
+ import torch
+ from cache_dit.logger import init_logger, logging_rank_0
+
+ logger = init_logger(__name__)
+
+
+ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
+ mode = kwargs.get("epilogue_prologue_fusion", False)
+ CACHE_DIT_EPILOGUE_PROLOGUE_FUSION = bool(
+ int(os.environ.get("CACHE_DIT_EPILOGUE_PROLOGUE_FUSION", "0"))
+ )
+
+ if CACHE_DIT_EPILOGUE_PROLOGUE_FUSION:
+ logging_rank_0(
+ logger,
+ "CACHE_DIT_EPILOGUE_PROLOGUE_FUSION is set to 1. \n"
+ "Force enable epilogue and prologue fusion.",
+ )
+
+ return CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or mode
+
+
+ def set_custom_compile_configs(
+ cuda_graphs: bool = False,
+ force_disable_compile_caches: bool = False,
+ use_fast_math: bool = False,
+ **kwargs, # other kwargs
+ ):
+ # Alway increase recompile_limit for dynamic shape compilation
+ torch._dynamo.config.recompile_limit = 96 # default is 8
+ torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
+ # Handle compiler caches
+ # https://github.com/vllm-project/vllm/blob/23baa2180b0ebba5ae94073ba9b8e93f88b75486/vllm/compilation/compiler_interface.py#L270
+ torch._inductor.config.fx_graph_cache = True
+ torch._inductor.config.fx_graph_remote_cache = False
+ # https://github.com/pytorch/pytorch/issues/153791
+ torch._inductor.config.autotune_local_cache = False
+
+ FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
+ os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
+ == "1"
+ )
+ if FORCE_DISABLE_CUSTOM_COMPILE_CONFIG:
+ logging_rank_0(
+ logger,
+ "CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG is set to 1. \n"
+ "Force disable custom compile config.",
+ )
+ return
+
+ # Enable compute comm overlap
+ torch._inductor.config.reorder_for_compute_comm_overlap = True
+ # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
+ torch._inductor.config.intra_node_bw = (
+ 64 if "L20" in torch.cuda.get_device_name() else 300
+ )
+
+ # Below are default settings for torch.compile, you can change
+ # them to your needs and test the performance
+ torch._inductor.config.max_fusion_size = 64
+ torch._inductor.config.max_pointwise_cat_inputs = 8
+ torch._inductor.config.triton.cudagraphs = cuda_graphs
+ torch._inductor.config.triton.use_block_ptr = False
+ torch._inductor.config.triton.codegen_upcast_to_fp32 = True
+
+ # Copy from https://pytorch.org/blog/accelerating-generative-ai-3/
+ torch._inductor.config.conv_1x1_as_mm = True
+ torch._inductor.config.coordinate_descent_tuning = True
+ torch._inductor.config.coordinate_descent_check_all_directions = True
+ torch._inductor.config.epilogue_fusion = False
+
+ # Enable epilogue and prologue fusion
+ if epilogue_prologue_fusion_enabled(**kwargs):
+ torch._inductor.config.epilogue_fusion = True
+ torch._inductor.config.prologue_fusion = True
+ torch._inductor.config.epilogue_fusion_first = True
+
+ # Dead code elimination
+ torch._inductor.config.dce = True # default is False
+
+ # May need to force disable all cache
+ if force_disable_compile_caches:
+ torch._inductor.config.force_disable_caches = True
+ torch._inductor.config.fx_graph_cache = False
+ torch._inductor.config.fx_graph_remote_cache = False
+ torch._inductor.config.autotune_local_cache = False # default is True
+
+ # Use fast math
+ if hasattr(torch._inductor.config, "use_fast_math"):
+ torch._inductor.config.use_fast_math = use_fast_math
+ if hasattr(torch._inductor.config, "cuda.use_fast_math"):
+ torch._inductor.config.cuda.use_fast_math = use_fast_math
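
The new module is exported via `cache_dit.compile` (see the one-line `__init__.py` above). A hedged usage sketch: call the helper once before any `torch.compile` call so the dynamo/inductor settings take effect; the tiny stand-in model below is illustrative only.

```python
import torch
from cache_dit.compile import set_custom_compile_configs

# Apply the dynamo/inductor settings shown above before compiling.
# (The helper queries torch.cuda.get_device_name(), so a CUDA device is assumed.)
set_custom_compile_configs()
# Optional knobs exposed by the new helper:
# set_custom_compile_configs(cuda_graphs=True, force_disable_compile_caches=True)

# Any module can then be compiled with the tuned defaults in effect;
# a small stand-in model is used here instead of a full diffusers transformer.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU()).cuda()
compiled = torch.compile(model)
out = compiled(torch.randn(8, 64, device="cuda"))
```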

src/cache_dit/logger.py

@@ -1,6 +1,7 @@
  import logging
  import os
  import sys
+ import torch.distributed as dist

  _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
  _DATE_FORMAT = "%m-%d %H:%M:%S"

@@ -95,3 +96,30 @@ def init_logger(name: str):
  logger.addHandler(_inference_log_file_handler[pid])
  logger.propagate = False
  return logger
+
+
+ def logging_rank_0(
+ logger: logging.Logger, message: str, level: int = logging.INFO
+ ):
+ if not isinstance(logger, logging.Logger):
+ raise TypeError("logger must be an instance of logging.Logger")
+ if not isinstance(message, str):
+ raise TypeError("message must be a string")
+ if not isinstance(level, int):
+ raise TypeError("level must be an integer representing a logging level")
+
+ def _logging_msg():
+ if level == logging.DEBUG:
+ logger.debug(message)
+ elif level == logging.WARNING:
+ logger.warning(message)
+ elif level == logging.ERROR:
+ logger.error(message)
+ else:
+ logger.info(message)
+
+ if dist.is_initialized():
+ if dist.get_rank() == 0:
+ _logging_msg()
+ else:
+ _logging_msg()
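
A short usage sketch for the new rank-aware helper (illustrative only): under `torch.distributed`, only rank 0 emits the message; in single-process runs it logs normally.

```python
import logging
from cache_dit.logger import init_logger, logging_rank_0

logger = init_logger(__name__)

# Logs once per job: rank 0 only when torch.distributed is initialized,
# otherwise the current (single) process.
logging_rank_0(logger, "Custom compile configs applied.", logging.INFO)
```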

src/cache_dit.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.2
+ Version: 0.2.3
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc

@@ -283,7 +283,11 @@ cache_options = {
  "warmup_steps": 3, # n_derivatives + 1
  "residual_diff_threshold": 0.12,
  }
- ```
+ ```
+
+ > [!Important]
+ > Please note that if you have used TaylorSeer as the calibrator for approximate hidden states, the **Bn** param of DBCache can be set to **0**. In essence, DBCache's Bn is also act as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend using the configuration scheme of **TaylorSeer** + **DBCache FnB0**.
+
  <div align="center">
  <p align="center">
  <b>DBCache F1B0 + TaylorSeer</b>, L20x1, Steps: 28, <br>"A cat holding a sign that says hello world with complex background"

src/cache_dit.egg-info/SOURCES.txt

@@ -21,7 +21,6 @@ assets/DBCACHE_F8B16S1_R0.2_S18.png
  assets/DBCACHE_F8B8S1_R0.08_S9.png
  assets/DBCACHE_F8B8S1_R0.12_S12.png
  assets/DBCACHE_F8B8S1_R0.15_S15.png
- assets/DBCache.png
  assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png
  assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png
  assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png

@@ -68,7 +67,6 @@ assets/U4_C1_DBPRUNE_F1B0_R0.2_P59.5_T3.96s.png
  assets/U4_C1_NONE_R0.08_S0_T7.78s.png
  assets/U4_C1_NONE_R0.08_S0_T7.79s.png
  assets/cache-dit-v1.png
- assets/cache-dit.png
  assets/dbcache-fnbn-v1.png
  assets/dbcache-v1.png
  assets/dbprune-v1.png

@@ -123,6 +121,10 @@ src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py
  src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py
  src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py
  src/cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py
+ src/cache_dit/compile/__init__.py
+ src/cache_dit/compile/utils.py
+ src/cache_dit/custom_ops/__init__.py
+ src/cache_dit/custom_ops/triton_taylorseer.py
  tests/.gitignore
  tests/README.md
  tests/taylorseer_approximation_order_2.png