cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (37)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/adapters.py +47 -5
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
  6. cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
  10. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
  11. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
  12. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
  13. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
  14. cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
  15. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
  16. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
  17. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
  18. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
  19. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
  20. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
  21. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
  22. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
  23. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
  24. cache_dit/cache_factory/patch/flux.py +241 -0
  25. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
  26. cache_dit-0.2.16.dist-info/RECORD +47 -0
  27. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  28. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  29. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  30. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  31. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  32. cache_dit-0.2.14.dist-info/RECORD +0 -49
  33. /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
  34. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,241 @@
+ import inspect
+
+ import torch
+ import numpy as np
+ from typing import Tuple, Optional, Dict, Any, Union
+ from diffusers import FluxTransformer2DModel
+ from diffusers.models.transformers.transformer_flux import (
+     FluxSingleTransformerBlock,
+     Transformer2DModelOutput,
+ )
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ # Copied from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
+ def __patch_single_forward__(
+     self: FluxSingleTransformerBlock,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     temb: torch.Tensor,
+     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     text_seq_len = encoder_hidden_states.shape[1]
+     hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+     residual = hidden_states
+     norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+     mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+     joint_attention_kwargs = joint_attention_kwargs or {}
+     attn_output = self.attn(
+         hidden_states=norm_hidden_states,
+         image_rotary_emb=image_rotary_emb,
+         **joint_attention_kwargs,
+     )
+
+     hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+     gate = gate.unsqueeze(1)
+     hidden_states = gate * self.proj_out(hidden_states)
+     hidden_states = residual + hidden_states
+     if hidden_states.dtype == torch.float16:
+         hidden_states = hidden_states.clip(-65504, 65504)
+
+     encoder_hidden_states, hidden_states = (
+         hidden_states[:, :text_seq_len],
+         hidden_states[:, text_seq_len:],
+     )
+     return encoder_hidden_states, hidden_states
+
+
+ # Copied from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L631
+ def __patch_transformer_forward__(
+     self: FluxTransformer2DModel,
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor = None,
+     pooled_projections: torch.Tensor = None,
+     timestep: torch.LongTensor = None,
+     img_ids: torch.Tensor = None,
+     txt_ids: torch.Tensor = None,
+     guidance: torch.Tensor = None,
+     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+     controlnet_block_samples=None,
+     controlnet_single_block_samples=None,
+     return_dict: bool = True,
+     controlnet_blocks_repeat: bool = False,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+     if joint_attention_kwargs is not None:
+         joint_attention_kwargs = joint_attention_kwargs.copy()
+         lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+     else:
+         lora_scale = 1.0
+
+     if USE_PEFT_BACKEND:
+         # weight the lora layers by setting `lora_scale` for each PEFT layer
+         scale_lora_layers(self, lora_scale)
+     else:
+         if (
+             joint_attention_kwargs is not None
+             and joint_attention_kwargs.get("scale", None) is not None
+         ):
+             logger.warning(
+                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+
+     hidden_states = self.x_embedder(hidden_states)
+
+     timestep = timestep.to(hidden_states.dtype) * 1000
+     if guidance is not None:
+         guidance = guidance.to(hidden_states.dtype) * 1000
+
+     temb = (
+         self.time_text_embed(timestep, pooled_projections)
+         if guidance is None
+         else self.time_text_embed(timestep, guidance, pooled_projections)
+     )
+     encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+     if txt_ids.ndim == 3:
+         logger.warning(
+             "Passing `txt_ids` 3d torch.Tensor is deprecated. "
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         txt_ids = txt_ids[0]
+     if img_ids.ndim == 3:
+         logger.warning(
+             "Passing `img_ids` 3d torch.Tensor is deprecated. "
+             "Please remove the batch dimension and pass it as a 2d torch Tensor"
+         )
+         img_ids = img_ids[0]
+
+     ids = torch.cat((txt_ids, img_ids), dim=0)
+     image_rotary_emb = self.pos_embed(ids)
+
+     if (
+         joint_attention_kwargs is not None
+         and "ip_adapter_image_embeds" in joint_attention_kwargs
+     ):
+         ip_adapter_image_embeds = joint_attention_kwargs.pop(
+             "ip_adapter_image_embeds"
+         )
+         ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+         joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+     for index_block, block in enumerate(self.transformer_blocks):
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             encoder_hidden_states, hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     image_rotary_emb,
+                     joint_attention_kwargs,
+                 )
+             )
+         else:
+             encoder_hidden_states, hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 joint_attention_kwargs=joint_attention_kwargs,
+             )
+
+         # controlnet residual
+         if controlnet_block_samples is not None:
+             interval_control = len(self.transformer_blocks) / len(
+                 controlnet_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             # For Xlabs ControlNet.
+             if controlnet_blocks_repeat:
+                 hidden_states = (
+                     hidden_states
+                     + controlnet_block_samples[
+                         index_block % len(controlnet_block_samples)
+                     ]
+                 )
+             else:
+                 hidden_states = (
+                     hidden_states
+                     + controlnet_block_samples[index_block // interval_control]
+                 )
+
+     for index_block, block in enumerate(self.single_transformer_blocks):
+         if torch.is_grad_enabled() and self.gradient_checkpointing:
+             encoder_hidden_states, hidden_states = (
+                 self._gradient_checkpointing_func(
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     temb,
+                     image_rotary_emb,
+                     joint_attention_kwargs,
+                 )
+             )
+         else:
+             encoder_hidden_states, hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+                 joint_attention_kwargs=joint_attention_kwargs,
+             )
+
+         # controlnet residual
+         if controlnet_single_block_samples is not None:
+             interval_control = len(self.single_transformer_blocks) / len(
+                 controlnet_single_block_samples
+             )
+             interval_control = int(np.ceil(interval_control))
+             hidden_states = (
+                 hidden_states
+                 + controlnet_single_block_samples[
+                     index_block // interval_control
+                 ]
+             )
+
+     hidden_states = self.norm_out(hidden_states, temb)
+     output = self.proj_out(hidden_states)
+
+     if USE_PEFT_BACKEND:
+         # remove `lora_scale` from each PEFT layer
+         unscale_lora_layers(self, lora_scale)
+
+     if not return_dict:
+         return (output,)
+
+     return Transformer2DModelOutput(sample=output)
+
+
+ def maybe_patch_flux_transformer(
+     transformer: FluxTransformer2DModel,
+ ) -> FluxTransformer2DModel:
+     single_forward_parameters = inspect.signature(
+         transformer.single_transformer_blocks[0].forward
+     ).parameters.keys()
+     if "encoder_hidden_states" not in single_forward_parameters:
+         logger.warning("Patching FluxSingleTransformerBlock for cache-dit.")
+         for block in transformer.single_transformer_blocks:
+             block.forward = __patch_single_forward__.__get__(block)
+
+     assert not getattr(transformer, "_is_parallelized", False), (
+         "Please call apply_cache_on_pipe before parallelizing; "
+         "__patch_transformer_forward__ would overwrite the "
+         "parallelized forward and degrade performance."
+     )
+     transformer.forward = __patch_transformer_forward__.__get__(transformer)
+     transformer._is_patched = True
+
+     return transformer
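For orientation, here is a minimal, hypothetical usage sketch of the new helper; it is not part of this diff. The import path follows the new file `cache_dit/cache_factory/patch/flux.py`, and calling the helper directly is only illustrative, since `apply_cache_on_pipe` presumably invokes it internally:

```python
# Hypothetical sketch, not from this release: patch a Flux transformer
# before any parallelization, as the assertion above requires.
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Rebinds each FluxSingleTransformerBlock.forward only when the installed
# diffusers version lacks `encoder_hidden_states` in its signature, and
# always replaces transformer.forward with the patched variant.
pipe.transformer = maybe_patch_flux_transformer(pipe.transformer)
```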
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.14
+ Version: 0.2.16
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -52,20 +52,22 @@ Dynamic: requires-python
  <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
  <img src=https://img.shields.io/badge/Release-v0.2-brightgreen.svg >
  </div>
- <p align="center">
- DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. <br> CacheDiT offers a set of training-free cache accelerators for Diffusion Transformers: <br> <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">Hybrid TaylorSeer</a>, <a href="#cfg">Hybrid Cache CFG</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
- </p>
+ 🔥<a href="#dbcache">DBCache</a> | <a href="#dbprune">DBPrune</a> | <a href="#taylorseer">Hybrid TaylorSeer</a> | <a href="#cfg">Hybrid Cache CFG</a> | <a href="#fbcache">FBCache</a>🔥
  </div>

  <div align="center">
  <p align="center">
- <b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+ ♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️
  </p>
  </div>

- ## 🔥News🔥
- - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the **[PR](https://github.com/huggingface/flux-fast/pull/13)**.
- - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
+ ## 🔥News
+
+ - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
+ - [2025-08-11] 🔥[Qwen-Image](https://github.com/QwenLM/Qwen-Image) is supported now! Please check [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
+ - [2025-08-10] 🔥[FLUX.1-Kontext-dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please check [run_flux_kontext.py](./examples/run_flux_kontext.py) as an example.
+ - [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).
+ - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! **3.3x** speedup for FLUX.1 on NVIDIA L20 with `cache-dit`.

  ## 📖Contents

@@ -78,7 +80,6 @@ Dynamic: requires-python
  - [⚡️Hybrid Cache CFG](#cfg)
  - [🎉First Block Cache](#fbcache)
  - [⚡️Dynamic Block Prune](#dbprune)
- - [🎉Context Parallelism](#context-parallelism)
  - [🔥Torch Compile](#compile)
  - [⚙️Metrics CLI](#metrics)
  - [👋Contribute](#contribute)
@@ -104,8 +105,10 @@ pip3 install git+https://github.com/vipshop/cache-dit.git

  <div id="supported"></div>

+ - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀FLUX.1-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀FLUX.1-Fill-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
+ - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀mochi-1-preview](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
  - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -157,7 +160,7 @@ These case studies demonstrate that even with relatively high thresholds (such a
  - **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
  - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.

- For a good balance between performance and precision, DBCache is configured by default with **F8B8**, 8 warmup steps, and unlimited cached steps.
+ For a good balance between performance and precision, DBCache is configured by default with **F8B0**, 8 warmup steps, and unlimited cached steps.

  ```python
  from diffusers import FluxPipeline
@@ -168,16 +171,16 @@ pipe = FluxPipeline.from_pretrained(
      torch_dtype=torch.bfloat16,
  ).to("cuda")

- # Default options, F8B8, good balance between performance and precision
+ # Default options, F8B0, good balance between performance and precision
  cache_options = CacheType.default_options(CacheType.DBCache)

- # Custom options, F8B16, higher precision
+ # Custom options, F8B0, higher precision
  cache_options = {
      "cache_type": CacheType.DBCache,
      "warmup_steps": 8,
-     "max_cached_steps": 8, # -1 means no limit
-     "Fn_compute_blocks": 8, # Fn, F8, etc.
-     "Bn_compute_blocks": 16, # Bn, B16, etc.
+     "max_cached_steps": -1, # -1 means no limit
+     "Fn_compute_blocks": 8, # Fn, F8, etc.
+     "Bn_compute_blocks": 0, # Bn, B0, etc.
      "residual_diff_threshold": 0.12,
  }
@@ -253,7 +256,7 @@ cache_options = {
      # should set do_separate_classifier_free_guidance as False.
      # For example, set it as True for Wan 2.1 and set it as False
      # for FLUX.1, HunyuanVideo, CogVideoX, Mochi.
-     "do_separate_classifier_free_guidance": True, # Wan 2.1
+     "do_separate_classifier_free_guidance": True, # Wan 2.1, Qwen-Image
      # Compute cfg forward first or not, default False, namely,
      # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
      "cfg_compute_first": False,
@@ -270,7 +273,7 @@ cache_options = {

  ![](https://github.com/vipshop/cache-dit/raw/main/assets/fbcache-v1.png)

- **DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can either use the original FBCache implementation directly or configure **DBCache** with **F1B0** settings to achieve the same functionality.
+ **DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can configure **DBCache** with **F1B0** settings to achieve the same functionality.

  ```python
  from diffusers import FluxPipeline
@@ -281,15 +284,12 @@ pipe = FluxPipeline.from_pretrained(
      torch_dtype=torch.bfloat16,
  ).to("cuda")

- # Using FBCache directly
- cache_options = CacheType.default_options(CacheType.FBCache)
-
  # Or using DBCache with F1B0.
  # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
  cache_options = {
      "cache_type": CacheType.DBCache,
      "warmup_steps": 8,
-     "max_cached_steps": 8, # -1 means no limit
+     "max_cached_steps": -1, # -1 means no limit
      "Fn_compute_blocks": 1, # Fn, F1, etc.
      "Bn_compute_blocks": 0, # Bn, B0, etc.
      "residual_diff_threshold": 0.12,
@@ -370,64 +370,6 @@ apply_cache_on_pipe(pipe, **cache_options)
  |24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
  |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.03_P24.0_T19.43s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.04_P34.6_T16.82s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.05_P38.3_T15.95s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.06_P45.2_T14.24s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBPRUNE_F1B0_R0.2_P59.5_T10.66s.png width=105px>|

- ## 🎉Context Parallelism
-
- <div id="context-parallelism"></div>
-
- **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can **easily tap into** its **Context Parallelism** features for distributed inference. Firstly, install `para-attn` from PyPI:
-
- ```bash
- pip3 install para-attn # or install `para-attn` from sources.
- ```
-
- Then, you can run **DBCache** or **DBPrune** with **Context Parallelism** on 4 GPUs:
-
- ```python
- import torch.distributed as dist
- from diffusers import FluxPipeline
- from para_attn.context_parallel import init_context_parallel_mesh
- from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
- from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
-
- # Init distributed process group
- dist.init_process_group()
- torch.cuda.set_device(dist.get_rank())
-
- pipe = FluxPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-dev",
-     torch_dtype=torch.bfloat16,
- ).to("cuda")
-
- # Context Parallel from ParaAttention
- parallelize_pipe(
-     pipe, mesh=init_context_parallel_mesh(
-         pipe.device.type, max_ulysses_dim_size=4
-     )
- )
-
- # DBPrune with default options from this library
- apply_cache_on_pipe(
-     pipe, **CacheType.default_options(CacheType.DBPrune)
- )
-
- dist.destroy_process_group()
- ```
- Then, run the python test script with `torchrun`:
- ```bash
- torchrun --nproc_per_node=4 parallel_cache.py
- ```
-
- <div align="center">
- <p align="center">
- DBPrune + <b>torch.compile + context parallelism</b> <br>Steps: 28, "A cat holding a sign that says hello world with complex background"
- </p>
- </div>
-
- |Baseline|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
- |:---:|:---:|:---:|:---:|:---:|:---:|
- |+compile:20.43s|16.25s|14.12s|13.41s|12.00s|8.86s|
- |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
- |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

  ## 🔥Torch Compile

@@ -490,7 +432,7 @@ How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](

  <div id="license"></div>

- The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We have followed the original License from [FBCache](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
+ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work! We have followed the original License from FBCache, please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.

  ## ©️Citations

@@ -0,0 +1,47 @@
+ cache_dit/__init__.py,sha256=0-B173-fLi3IA8nJXoS71zK0zD33Xplysd9skmLfEOY,171
+ cache_dit/_version.py,sha256=uFGhweCFKwebVyMUvDALfnhYcWJQj8O3h_9xJIOhTtk,513
+ cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
+ cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
+ cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
+ cache_dit/cache_factory/__init__.py,sha256=2UuUJ-CRXrLbv_ZoC2nV3qPHoipfqeWvO7xZO3CxOD4,263
+ cache_dit/cache_factory/adapters.py,sha256=3iHIkhkb_2s1f-W-jw0bCToZyLYvbJlPpxASv4EqrqU,6714
+ cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
+ cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
+ cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=gkouVv5IgcsiTOQ5_I-a3S3TnJifNZhrwdkrO1KRCqw,120
+ cache_dit/cache_factory/dual_block_cache/cache_blocks.py,sha256=M9R_6t-X6vrsSNvSMQLn24I_fZ9EpFK4RlRzgAtEdac,18407
+ cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=28g3lqgsarCk-u880QF0yl967glXIZrWpUfJYKJZqxg,39872
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=uSqF5aD2-feHB25vEbx1STBQVjWVAOn_wYTdAEmS4NU,2045
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=RNF24Ysuddo5cjv4hUT4s3C4C3pgud5YJ1OklvtrrlU,2286
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=URWqROg_ANseraFf-on5ZqvJzD7tH3WouboFaJgOAkk,2853
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=y0Ab_tTJOixTaK-Remg5IZLCDGIgBOZFszI32PnK9gc,9981
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=nqpbVjT0vFecIIqVGBvCUIBysOtFFjh-MhrXw4VPBGo,2278
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py,sha256=qhwpGDxx5RBbumSArc38z4aMv5YVwxH3eiyIpzjXycQ,2322
+ cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=P9-r0HNofBO8rWpB4pAQvGOEbsfiHUAKV16Noq8ZzWg,2610
+ cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=d_cbjrkBaprzQw4HJI3sRtqMzrhDVL56moLkpSGnqO4,123
+ cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py,sha256=6L9WhuWP6GHFRLnGRO_XZ1oE4v9xm-yqQ7cgNP_fOdY,9704
+ cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=94e0VKmbU09hb-s8JQ9vOYTJtZ07gva7XT9OQm8EBSc,25530
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=9WLorIqd7m_dIDs6pPPj-lcd9e56fdvGM2D3DDdWwEU,2045
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=Ny1f3gvtcbVbTvGXNdyv7YC2Ez1rGonttx7xnsvfHM4,2299
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=ibANvG29Sa2CU2pG_IndkEYarshuQDWGTP3klnAStV0,2792
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=ylNgdkCgFwWbuCb2sbyIV6XoHTd8MzGbgMO0t1Z6D50,9994
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=iMZd473vkh2hsJRIlKN2DCsCRJmsV4JSB7SZzZdg4eI,2291
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/qwen_image.py,sha256=0MgHcT4EVUcUnq8-g4GO2p11MzMtRhQp5D3wG9t_KlA,2398
+ cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=VOdnpIZRbcyKHP7e4yRPSlS5hZSyq4dcp9fmHOfI-Fc,2630
+ cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/cache_factory/patch/flux.py,sha256=zRUWZlt02vZGJcK2WOfmSbGR4UOUvsFvSBDgeBNZxh8,8813
+ cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
+ cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
+ cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
+ cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
+ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
+ cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
+ cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
+ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
+ cache_dit-0.2.16.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.16.dist-info/METADATA,sha256=y8oidHX3B0iZXfjym_mHtGgB0fGE4fiEA6JIlIeWaRo,22769
+ cache_dit-0.2.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.16.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+ cache_dit-0.2.16.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.16.dist-info/RECORD,,