cache-dit 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +11 -3
- {cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/METADATA +59 -19
- {cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/RECORD +7 -7
- {cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/WHEEL +0 -0
- {cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/top_level.txt +0 -0
cache_dit/_version.py CHANGED

cache_dit/cache_factory/dynamic_block_prune/prune_context.py CHANGED
@@ -562,7 +562,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
             torch._dynamo.graph_break()

         add_pruned_block(self.pruned_blocks_step)
-        add_actual_block(self.
+        add_actual_block(self.num_transformer_blocks)
         patch_pruned_stats(self.transformer)

         return (
@@ -577,7 +577,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):

     @property
     @torch.compiler.disable
-    def
+    def num_transformer_blocks(self):
         # Total number of transformer blocks, including single transformer blocks.
         num_blocks = len(self.transformer_blocks)
         if self.single_transformer_blocks is not None:
@@ -597,7 +597,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _non_prune_blocks_ids(self):
         # Never prune the first `Fn` and last `Bn` blocks.
-        num_blocks = self.
+        num_blocks = self.num_transformer_blocks
         Fn_compute_blocks_ = (
             Fn_compute_blocks()
             if Fn_compute_blocks() < num_blocks
@@ -627,6 +627,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
         ]
         return sorted(non_prune_blocks_ids)

+    # @torch.compile(dynamic=True)
+    # mark this function as compile with dynamic=True will
+    # cause precision degradate, so, we choose to disable it
+    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _compute_single_hidden_states_residual(
         self,
@@ -663,6 +667,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
             single_encoder_hidden_states_residual,
         )

+    # @torch.compile(dynamic=True)
+    # mark this function as compile with dynamic=True will
+    # cause precision degradate, so, we choose to disable it
+    # now, until we find a better solution or fixed the bug.
     @torch.compiler.disable
     def _split_single_hidden_states(
         self,
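The prune_context.py hunks above rename the block-count property to `num_transformer_blocks` and keep several helpers under `@torch.compiler.disable`, since compiling them with `dynamic=True` was observed to degrade precision. A minimal sketch of that pattern follows; `TinyBlockStack` and its layers are hypothetical stand-ins, not the library's actual classes.

```python
import torch


class TinyBlockStack(torch.nn.Module):
    # Hypothetical stand-in for DBPrunedTransformerBlocks: a main block list
    # plus an optional "single" block list, as in the diff above.
    def __init__(self, num_blocks: int = 4, num_single_blocks: int = 2):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(
            torch.nn.Linear(16, 16) for _ in range(num_blocks)
        )
        self.single_transformer_blocks = torch.nn.ModuleList(
            torch.nn.Linear(16, 16) for _ in range(num_single_blocks)
        )

    @property
    @torch.compiler.disable
    def num_transformer_blocks(self) -> int:
        # Total number of blocks, including the single transformer blocks.
        num_blocks = len(self.transformer_blocks)
        if self.single_transformer_blocks is not None:
            num_blocks += len(self.single_transformer_blocks)
        return num_blocks

    # Kept out of the compiled graph: per the comments in the diff above,
    # compiling helpers like this with dynamic=True hurt precision.
    @torch.compiler.disable
    def _split_hidden_states(self, hidden_states: torch.Tensor):
        return hidden_states.chunk(2, dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for blk in self.transformer_blocks:
            x = blk(x)
        for blk in self.single_transformer_blocks:
            x = blk(x)
        return x


stack = TinyBlockStack()
print(stack.num_transformer_blocks)        # 6
compiled = torch.compile(stack)            # the disabled helpers stay eager
print(compiled(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```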
{cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.1.5
+Version: 0.1.6
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -10,9 +10,9 @@ Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: packaging
-Requires-Dist: torch
-Requires-Dist: transformers
-Requires-Dist: diffusers
+Requires-Dist: torch>=2.5.1
+Requires-Dist: transformers>=4.51.3
+Requires-Dist: diffusers>=0.33.1
 Provides-Extra: all
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
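The metadata hunk above tightens the runtime requirements to torch>=2.5.1, transformers>=4.51.3 and diffusers>=0.33.1. Below is a small, hypothetical environment check (not shipped with the package) that verifies those minimums using only the standard library plus `packaging`, which is already a declared dependency.

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

# Minimum versions pinned by cache_dit 0.1.6 (from the METADATA diff above).
MINIMUMS = {"torch": "2.5.1", "transformers": "4.51.3", "diffusers": "0.33.1"}

for name, minimum in MINIMUMS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (requires >= {minimum})")
        continue
    if Version(installed) >= Version(minimum):
        print(f"{name}: {installed} OK")
    else:
        print(f"{name}: {installed} too old (requires >= {minimum})")
```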
@@ -44,10 +44,10 @@ Dynamic: requires-python
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.1.5-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v0.1.6-brightgreen.svg >
 </div>
 <p align="center">
-DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT
+DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
 </p>
 </div>

@@ -55,7 +55,7 @@ Dynamic: requires-python

 <div align="center">
 <p align="center">
-<h3
+<h3>🔥 DBCache: Dual Block Caching for Diffusion Transformers</h3>
 </p>
 </div>

@@ -77,7 +77,7 @@ Dynamic: requires-python

 <div align="center">
 <p align="center">
-
+DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
 </p>
 </div>

@@ -85,7 +85,7 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-<h3
+<h3>🔥 DBPrune: Dynamic Block Prune with Residual Caching</h3>
 </p>
 </div>

@@ -102,10 +102,10 @@ These case studies demonstrate that even with relatively high thresholds (such a
 </p>
 </div>

-Moreover,
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.

 <p align="center">
-
+♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
 </p>

 ## ©️Citations
@@ -135,7 +135,7 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
-- [
+- [🔥Torch Compile](#compile)
 - [🎉Supported Models](#supported)
 - [👋Contribute](#contribute)
 - [©️License](#license)
@@ -210,6 +210,17 @@ cache_options = {
 }
 ```

+<div align="center">
+<p align="center">
+DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
+
 ## 🎉FBCache: First Block Cache

 <div id="fbcache"></div>
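The README fragment above configures DBCache through a `cache_options` dict and reports timings for several Fn/Bn settings. For orientation, here is a minimal sketch that applies the library's default DBCache options to a FLUX.1-dev pipeline; the imports and the `CacheType.default_options` call mirror the snippets quoted elsewhere in this diff, while the exact `CacheType.DBCache` member, the `.to("cuda")` placement, and the prompt are assumptions taken from the surrounding text.

```python
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory import CacheType, apply_cache_on_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Default DBCache options; pass a custom cache_options dict (as above)
# instead to tune the Fn/Bn block counts and residual-diff thresholds.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
image.save("dbcache_example.png")
```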
@@ -323,14 +334,19 @@ apply_cache_on_pipe(pipe, **cache_options)
 pip3 install para-attn # or install `para-attn` from sources.
 ```

-Then, you can run **DBCache** with **Context Parallelism** on 4 GPUs:
+Then, you can run **DBCache** or **DBPrune** with **Context Parallelism** on 4 GPUs:

 ```python
+import torch.distributed as dist
 from diffusers import FluxPipeline
 from para_attn.context_parallel import init_context_parallel_mesh
 from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

+# Init distributed process group
+dist.init_process_group()
+torch.cuda.set_device(dist.get_rank())
+
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.bfloat16,
@@ -343,13 +359,31 @@ parallelize_pipe(
     )
 )

-#
+# DBPrune with default options from this library
 apply_cache_on_pipe(
-    pipe, **CacheType.default_options(CacheType.
+    pipe, **CacheType.default_options(CacheType.DBPrune)
 )
+
+dist.destroy_process_group()
+```
+Then, run the python test script with `torchrun`:
+```bash
+torchrun --nproc_per_node=4 parallel_cache.py
 ```

-
+<div align="center">
+<p align="center">
+DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+|8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
+
+## 🔥Torch Compile

 <div id="compile"></div>

@@ -357,7 +391,7 @@ apply_cache_on_pipe(

 ```python
 apply_cache_on_pipe(
-    pipe, **CacheType.default_options(CacheType.
+    pipe, **CacheType.default_options(CacheType.DBPrune)
 )
 # Compile the Transformer module
 pipe.transformer = torch.compile(pipe.transformer)
@@ -368,7 +402,13 @@ However, users intending to use **CacheDiT** for DiT with **dynamic input shapes
 torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
 ```
-Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
+Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode. Here is the case of **DBPrune + torch.compile**.
+
+<div align="center">
+<p align="center">
+DBPrune + compile, Steps: 28, "A cat holding a sign that says hello world with complex background"
+</p>
+</div>

 ## 🎉Supported Models

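Combining the two fragments above: a short sketch of the DBPrune + torch.compile setup, raising the dynamo recompile limits before compiling the transformer. The option call, config values, and compile line are taken from the README text quoted in this diff; the pipeline setup and prompt are assumptions carried over from the earlier snippets, so treat this as an illustrative sketch rather than the project's canonical example.

```python
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory import CacheType, apply_cache_on_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# DBPrune with the library's default options.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBPrune))

# Dynamic input shapes can trigger frequent recompiles; raise the limits first,
# otherwise the recompile_limit error drops the module back to eager mode.
torch._dynamo.config.recompile_limit = 96                # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256

# Compile only the Transformer module, as in the README snippet above.
pipe.transformer = torch.compile(pipe.transformer)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
image.save("dbprune_compile_example.png")
```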
@@ -381,7 +421,7 @@ Otherwise, the recompile_limit error may be triggered, causing the module to fal
 ## 👋Contribute
 <div id="contribute"></div>

-How to contribute? Star this repo or check [CONTRIBUTE.md](./CONTRIBUTE.md).
+How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](./CONTRIBUTE.md).

 ## ©️License

{cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/RECORD CHANGED

@@ -1,5 +1,5 @@
 cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=ESbJO0YD7TYfOUv_WDIJJgWELGepEWsoyhqVifEcXPA,511
 cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
 cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
@@ -12,7 +12,7 @@ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=UbE6nIF-EtA92QxIZVMzIssdZKQSPAVX1hchF9R8drU,2754
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=qxMu1L3ycT8F-uxpGsmFQBY_BH1vDiGIOXgS_Qbb7dM,2391
 cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=
+cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=foUGCBtpCbfLWw6pxJguyxOfcp_YrizfDEKawCt_UKI,35028
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=8IjJjZOs5XRzsj7Ni2MXpR2Z1PUyRSONIhmfAn1G0eM,1667
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=ORJpdkXkgziDUo-rpebC6pUemgYaDCoeu0cwwLz175U,2407
 cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=KbEkLSsHtS6xwLWNh3jlOlXRyGRdrI2pWV1zyQxMTj4,2757
@@ -24,8 +24,8 @@ cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=IVH-lroOzvYb4XKLk9MOw54EtijBtuzVaKcVGz0KlBA,2656
-cache_dit-0.1.
-cache_dit-0.1.
-cache_dit-0.1.
-cache_dit-0.1.
-cache_dit-0.1.
+cache_dit-0.1.6.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.1.6.dist-info/METADATA,sha256=B8ddDPBXwFFPYYAxEgu8itLKjpb5IKTaA2JHFV7eQhM,21030
+cache_dit-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.1.6.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.1.6.dist-info/RECORD,,

{cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/WHEEL
File without changes

{cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/licenses/LICENSE
File without changes

{cache_dit-0.1.5.dist-info → cache_dit-0.1.6.dist-info}/top_level.txt
File without changes