cache-dit 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.1.5'
- __version_tuple__ = version_tuple = (0, 1, 5)
+ __version__ = version = '0.1.6'
+ __version_tuple__ = version_tuple = (0, 1, 6)
cache_dit/cache_factory/dynamic_block_prune/prune_context.py CHANGED
@@ -562,7 +562,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  torch._dynamo.graph_break()

  add_pruned_block(self.pruned_blocks_step)
- add_actual_block(self._num_transformer_blocks)
+ add_actual_block(self.num_transformer_blocks)
  patch_pruned_stats(self.transformer)

  return (
@@ -577,7 +577,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):

  @property
  @torch.compiler.disable
- def _num_transformer_blocks(self):
+ def num_transformer_blocks(self):
  # Total number of transformer blocks, including single transformer blocks.
  num_blocks = len(self.transformer_blocks)
  if self.single_transformer_blocks is not None:
@@ -597,7 +597,7 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  @torch.compiler.disable
  def _non_prune_blocks_ids(self):
  # Never prune the first `Fn` and last `Bn` blocks.
- num_blocks = self._num_transformer_blocks
+ num_blocks = self.num_transformer_blocks
  Fn_compute_blocks_ = (
  Fn_compute_blocks()
  if Fn_compute_blocks() < num_blocks
@@ -627,6 +627,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  ]
  return sorted(non_prune_blocks_ids)

+ # @torch.compile(dynamic=True)
+ # mark this function as compile with dynamic=True will
+ # cause precision degradate, so, we choose to disable it
+ # now, until we find a better solution or fixed the bug.
  @torch.compiler.disable
  def _compute_single_hidden_states_residual(
  self,
@@ -663,6 +667,10 @@ class DBPrunedTransformerBlocks(torch.nn.Module):
  single_encoder_hidden_states_residual,
  )

+ # @torch.compile(dynamic=True)
+ # mark this function as compile with dynamic=True will
+ # cause precision degradate, so, we choose to disable it
+ # now, until we find a better solution or fixed the bug.
  @torch.compiler.disable
  def _split_single_hidden_states(
  self,
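The prune_context.py hunks above rename the private `_num_transformer_blocks` helper to a public `num_transformer_blocks` property, and the new comments explain why `_compute_single_hidden_states_residual` and `_split_single_hidden_states` stay under `@torch.compiler.disable` instead of `@torch.compile(dynamic=True)`. A minimal sketch of the property-plus-decorator pattern, using a hypothetical `TinyBlocks` class in place of `DBPrunedTransformerBlocks`:

```python
# Sketch only: shows the public property + torch.compiler.disable stacking
# from the hunks above. `TinyBlocks` is a hypothetical stand-in and is not
# part of cache-dit.
import torch


class TinyBlocks(torch.nn.Module):
    def __init__(self, transformer_blocks, single_transformer_blocks=None):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(transformer_blocks)
        self.single_transformer_blocks = (
            torch.nn.ModuleList(single_transformer_blocks)
            if single_transformer_blocks is not None
            else None
        )

    @property
    @torch.compiler.disable
    def num_transformer_blocks(self):
        # Total number of transformer blocks, including single transformer blocks.
        num_blocks = len(self.transformer_blocks)
        if self.single_transformer_blocks is not None:
            num_blocks += len(self.single_transformer_blocks)
        return num_blocks
```

Usage is simply `TinyBlocks([...]).num_transformer_blocks`; the counting logic runs eagerly even when the surrounding module is compiled.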
cache_dit-0.1.5.dist-info/METADATA → cache_dit-0.1.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.1.5
+ Version: 0.1.6
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -10,9 +10,9 @@ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: packaging
- Requires-Dist: torch
- Requires-Dist: transformers
- Requires-Dist: diffusers
+ Requires-Dist: torch>=2.5.1
+ Requires-Dist: transformers>=4.51.3
+ Requires-Dist: diffusers>=0.33.1
  Provides-Extra: all
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
@@ -44,10 +44,10 @@ Dynamic: requires-python
  <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
  <img src=https://static.pepy.tech/badge/cache-dit >
  <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
- <img src=https://img.shields.io/badge/Release-v0.1.5-brightgreen.svg >
+ <img src=https://img.shields.io/badge/Release-v0.1.6-brightgreen.svg >
  </div>
  <p align="center">
- DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT provides <br>a series of training-free, UNet-style cache accelerators for DiT: DBCache, DBPrune, FBCache, etc.
+ DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>offers a set of training-free cache accelerators for DiT: 🔥DBCache, DBPrune, FBCache, etc🔥
  </p>
  </div>

@@ -55,7 +55,7 @@ Dynamic: requires-python

  <div align="center">
  <p align="center">
- <h3>DBCache: Dual Block Caching for Diffusion Transformers</h3>
+ <h3>🔥 DBCache: Dual Block Caching for Diffusion Transformers</h3>
  </p>
  </div>

@@ -77,7 +77,7 @@ Dynamic: requires-python

  <div align="center">
  <p align="center">
- DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
+ DBCache, <b> L20x4 </b>, Steps: 20, case to show the texture recovery ability of DBCache
  </p>
  </div>

@@ -85,7 +85,7 @@ These case studies demonstrate that even with relatively high thresholds (such a

  <div align="center">
  <p align="center">
- <h3>DBPrune: Dynamic Block Prune with Residual Caching</h3>
+ <h3>🔥 DBPrune: Dynamic Block Prune with Residual Caching</h3>
  </p>
  </div>

@@ -102,10 +102,10 @@ These case studies demonstrate that even with relatively high thresholds (such a
  </p>
  </div>

- Moreover, both DBCache and DBPrune are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.

  <p align="center">
- ♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
+ ♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
  </p>

  ## ©️Citations
@@ -135,7 +135,7 @@ The **CacheDiT** codebase was adapted from FBCache's implementation at the [Para
  - [🎉First Block Cache](#fbcache)
  - [⚡️Dynamic Block Prune](#dbprune)
  - [🎉Context Parallelism](#context-parallelism)
- - [⚡️Torch Compile](#compile)
+ - [🔥Torch Compile](#compile)
  - [🎉Supported Models](#supported)
  - [👋Contribute](#contribute)
  - [©️License](#license)
@@ -210,6 +210,17 @@ cache_options = {
  }
  ```

+ <div align="center">
+ <p align="center">
+ DBCache, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>
+
+ |Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
+ |:---:|:---:|:---:|:---:|:---:|:---:|
+ |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
+
  ## 🎉FBCache: First Block Cache

  <div id="fbcache"></div>
@@ -323,14 +334,19 @@ apply_cache_on_pipe(pipe, **cache_options)
  pip3 install para-attn # or install `para-attn` from sources.
  ```

- Then, you can run **DBCache** with **Context Parallelism** on 4 GPUs:
+ Then, you can run **DBCache** or **DBPrune** with **Context Parallelism** on 4 GPUs:

  ```python
+ import torch.distributed as dist
  from diffusers import FluxPipeline
  from para_attn.context_parallel import init_context_parallel_mesh
  from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
  from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

+ # Init distributed process group
+ dist.init_process_group()
+ torch.cuda.set_device(dist.get_rank())
+
  pipe = FluxPipeline.from_pretrained(
  "black-forest-labs/FLUX.1-dev",
  torch_dtype=torch.bfloat16,
@@ -343,13 +359,31 @@ parallelize_pipe(
  )
  )

- # DBCache with F8B8 from this library
+ # DBPrune with default options from this library
  apply_cache_on_pipe(
- pipe, **CacheType.default_options(CacheType.DBCache)
+ pipe, **CacheType.default_options(CacheType.DBPrune)
  )
+
+ dist.destroy_process_group()
+ ```
+ Then, run the python test script with `torchrun`:
+ ```bash
+ torchrun --nproc_per_node=4 parallel_cache.py
  ```

- ## ⚡️Torch Compile
+ <div align="center">
+ <p align="center">
+ DBPrune, <b> L20x4 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>
+
+ |Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
+ |:---:|:---:|:---:|:---:|:---:|:---:|
+ |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
+ |8.54s (L20x4)|7.20s (L20x4)|6.61s (L20x4)|6.09s (L20x4)|5.54s (L20x4)|4.22s (L20x4)|
+ |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|
+
+ ## 🔥Torch Compile

  <div id="compile"></div>

@@ -357,7 +391,7 @@ apply_cache_on_pipe(

  ```python
  apply_cache_on_pipe(
- pipe, **CacheType.default_options(CacheType.DBCache)
+ pipe, **CacheType.default_options(CacheType.DBPrune)
  )
  # Compile the Transformer module
  pipe.transformer = torch.compile(pipe.transformer)
@@ -368,7 +402,13 @@ However, users intending to use **CacheDiT** for DiT with **dynamic input shapes
  torch._dynamo.config.recompile_limit = 96 # default is 8
  torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
  ```
- Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
+ Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode. Here is the case of **DBPrune + torch.compile**.
+
+ <div align="center">
+ <p align="center">
+ DBPrune + compile, Steps: 28, "A cat holding a sign that says hello world with complex background"
+ </p>
+ </div>

  ## 🎉Supported Models

@@ -381,7 +421,7 @@ Otherwise, the recompile_limit error may be triggered, causing the module to fal
  ## 👋Contribute
  <div id="contribute"></div>

- How to contribute? Star this repo or check [CONTRIBUTE.md](./CONTRIBUTE.md).
+ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](./CONTRIBUTE.md).

  ## ©️License

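The README hunks above switch the compile example from DBCache to DBPrune and recommend raising the Dynamo recompile limits for dynamic input shapes. A hedged, single-GPU sketch that strings those quoted snippets together (it assumes `cache_dit`, `diffusers`, and a CUDA device are available; the final generation call is illustrative, not part of the quoted README):

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

# Raise recompile limits up front, as the README advises for dynamic input shapes.
torch._dynamo.config.recompile_limit = 96  # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# DBPrune with default options, then compile the Transformer module.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBPrune))
pipe.transformer = torch.compile(pipe.transformer)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
```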
cache_dit-0.1.5.dist-info/RECORD → cache_dit-0.1.6.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/_version.py,sha256=Y4jy4bEMmwl_qNPCmiMFnlQ2ofMoqyG37hp8uwI3m10,511
+ cache_dit/_version.py,sha256=ESbJO0YD7TYfOUv_WDIJJgWELGepEWsoyhqVifEcXPA,511
  cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
  cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
@@ -12,7 +12,7 @@ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=UbE6nIF-EtA92QxIZVMzIssdZKQSPAVX1hchF9R8drU,2754
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=qxMu1L3ycT8F-uxpGsmFQBY_BH1vDiGIOXgS_Qbb7dM,2391
  cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=cE27f5NPgQ_COmTnF__e85Uz5pWyXID0ut-tmtSQfVQ,34597
+ cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=foUGCBtpCbfLWw6pxJguyxOfcp_YrizfDEKawCt_UKI,35028
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=8IjJjZOs5XRzsj7Ni2MXpR2Z1PUyRSONIhmfAn1G0eM,1667
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=ORJpdkXkgziDUo-rpebC6pUemgYaDCoeu0cwwLz175U,2407
  cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=KbEkLSsHtS6xwLWNh3jlOlXRyGRdrI2pWV1zyQxMTj4,2757
@@ -24,8 +24,8 @@ cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=IVH-lroOzvYb4XKLk9MOw54EtijBtuzVaKcVGz0KlBA,2656
- cache_dit-0.1.5.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-0.1.5.dist-info/METADATA,sha256=87hqiy00L-lsW4yktOKTnWAnk7jHCVhxYKOkxaK2K48,18478
- cache_dit-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-0.1.5.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-0.1.5.dist-info/RECORD,,
+ cache_dit-0.1.6.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.1.6.dist-info/METADATA,sha256=B8ddDPBXwFFPYYAxEgu8itLKjpb5IKTaA2JHFV7eQhM,21030
+ cache_dit-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.1.6.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.1.6.dist-info/RECORD,,