cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/adapters.py +47 -5
- cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
- cache_dit/cache_factory/patch/flux.py +241 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
- cache_dit-0.2.16.dist-info/RECORD +47 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.14.dist-info/RECORD +0 -49
- /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
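The headline changes: the `first_block_cache` backend is removed, the per-block cache/prune logic is split out of the context modules into dedicated `cache_blocks.py` / `prune_blocks.py` files, and Qwen-Image adapters are added for both the dual-block cache (DBCache) and dynamic-block-prune (DBPrune) backends. As a point of reference, here is a minimal enablement sketch; it assumes the `apply_cache_on_pipe` / `CacheType` entry points that the 0.2.x line exports from `cache_dit.cache_factory`, and infers the `QwenImagePipeline` target from the new `qwen_image.py` adapters, so verify both names against your installed versions:

```python
# Hedged sketch: assumes the 0.2.x cache_dit.cache_factory API
# (apply_cache_on_pipe, CacheType) and the diffusers QwenImagePipeline;
# check both against your installed versions before relying on this.
import torch
from diffusers import QwenImagePipeline
from cache_dit.cache_factory import CacheType, apply_cache_on_pipe

pipe = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")

# DBPrune: skip transformer blocks whose inputs barely changed since the
# previous step, reusing cached residuals instead (see prune_blocks.py below).
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBPrune))

image = pipe("a cup of coffee on a wooden desk", num_inference_steps=50).images[0]
```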
cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py (new file)

```diff
@@ -0,0 +1,276 @@
+import inspect
+import torch
+
+from cache_dit.cache_factory.dynamic_block_prune import prune_context
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DBPrunedTransformerBlocks(torch.nn.Module):
+    def __init__(
+        self,
+        transformer_blocks: torch.nn.ModuleList,
+        *,
+        transformer: torch.nn.Module = None,
+        return_hidden_states_first: bool = True,
+        return_hidden_states_only: bool = False,
+    ):
+        super().__init__()
+
+        self.transformer = transformer
+        self.transformer_blocks = transformer_blocks
+        self.return_hidden_states_first = return_hidden_states_first
+        self.return_hidden_states_only = return_hidden_states_only
+        self.pruned_blocks_step: int = 0
+        self._check_forward_params()
+
+    def _check_forward_params(self):
+        # NOTE: DBPrune only support blocks which have the pattern:
+        # IN/OUT: (hidden_states, encoder_hidden_states)
+        self.required_parameters = [
+            "hidden_states",
+            "encoder_hidden_states",
+        ]
+        if self.transformer_blocks is not None:
+            for block in self.transformer_blocks:
+                forward_parameters = set(
+                    inspect.signature(block.forward).parameters.keys()
+                )
+                for required_param in self.required_parameters:
+                    assert (
+                        required_param in forward_parameters
+                    ), f"The input parameters must contains: {required_param}."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        prune_context.mark_step_begin()
+        self.pruned_blocks_step = 0
+        original_hidden_states = hidden_states
+
+        torch._dynamo.graph_break()
+        hidden_states, encoder_hidden_states = self.call_blocks(
+            hidden_states,
+            encoder_hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        del original_hidden_states
+        torch._dynamo.graph_break()
+
+        prune_context.add_pruned_block(self.pruned_blocks_step)
+        prune_context.add_actual_block(self.num_transformer_blocks)
+        patch_pruned_stats(self.transformer)
+
+        return (
+            hidden_states
+            if self.return_hidden_states_only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.return_hidden_states_first
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_transformer_blocks(self):
+        # Total number of transformer blocks.
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _is_parallelized(self):
+        # Compatible with distributed inference.
+        return all(
+            (
+                self.transformer is not None,
+                getattr(self.transformer, "_is_parallelized", False),
+            )
+        )
+
+    @torch.compiler.disable
+    def _non_prune_blocks_ids(self):
+        # Never prune the first `Fn` and last `Bn` blocks.
+        num_blocks = self.num_transformer_blocks
+        Fn_compute_blocks_ = (
+            prune_context.Fn_compute_blocks()
+            if prune_context.Fn_compute_blocks() < num_blocks
+            else num_blocks
+        )
+        Fn_compute_blocks_ids = list(range(Fn_compute_blocks_))
+        Bn_compute_blocks_ = (
+            prune_context.Bn_compute_blocks()
+            if prune_context.Bn_compute_blocks() < num_blocks
+            else num_blocks
+        )
+        Bn_compute_blocks_ids = list(
+            range(
+                num_blocks - Bn_compute_blocks_,
+                num_blocks,
+            )
+        )
+        non_prune_blocks_ids = list(
+            set(
+                Fn_compute_blocks_ids
+                + Bn_compute_blocks_ids
+                + prune_context.get_non_prune_blocks_ids()
+            )
+        )
+        non_prune_blocks_ids = [
+            d for d in non_prune_blocks_ids if d < num_blocks
+        ]
+        return sorted(non_prune_blocks_ids)
+
+    @torch.compiler.disable
+    def _should_update_residuals(self):
+        # Wrap for non compiled mode.
+        # Check if the current step is a multiple of
+        # the residual cache update interval.
+        return (
+            prune_context.get_current_step()
+            % prune_context.residual_cache_update_interval()
+            == 0
+        )
+
+    @torch.compiler.disable
+    def _get_can_use_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        name: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if block_id not in self._non_prune_blocks_ids():
+            can_use_prune = prune_context.get_can_use_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                name=name,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def _compute_or_prune_block(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Helper function for `call_blocks`
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+
+        # block_id: global block index in the transformer blocks +
+        # single_transformer_blocks
+        can_use_prune = self._get_can_use_prune(
+            block_id,
+            hidden_states,  # hidden_states or residual
+            name=f"{block_id}_original",  # prev step
+        )
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        if can_use_prune:
+            hidden_states, encoder_hidden_states = (
+                prune_context.apply_hidden_states_residual(
+                    hidden_states,
+                    encoder_hidden_states,
+                    name=f"{block_id}_residual",
+                    encoder_name=f"{block_id}_encoder_residual",
+                )
+            )
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+            # Save original_hidden_states for diff calculation.
+            # May not be necessary to update the hidden
+            # states and residuals each step?
+            if self._should_update_residuals():
+                # Cache residuals for the non-compute Bn blocks for
+                # subsequent prune steps.
+                hidden_states_residual = hidden_states - original_hidden_states
+                encoder_hidden_states_residual = (
+                    encoder_hidden_states - original_encoder_hidden_states
+                )
+                prune_context.set_buffer(
+                    f"{block_id}_original",
+                    original_hidden_states,
+                )
+
+                prune_context.set_buffer(
+                    f"{block_id}_residual",
+                    hidden_states_residual,
+                )
+                prune_context.set_buffer(
+                    f"{block_id}_encoder_residual",
+                    encoder_hidden_states_residual,
+                )
+                del hidden_states_residual
+                del encoder_hidden_states_residual
+
+        del original_hidden_states
+        del original_encoder_hidden_states
+
+        return hidden_states, encoder_hidden_states
+
+    def call_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, encoder_hidden_states = self._compute_or_prune_block(
+                i,
+                block,
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        return hidden_states, encoder_hidden_states
+
+
+@torch.compiler.disable
+def patch_pruned_stats(
+    transformer,
+):
+    # Patch the pruned stats to the transformer, the pruned stats
+    # will be reset for each calling of pipe.__call__(**kwargs).
+    if transformer is None:
+        return
+
+    # TODO: Patch more pruned stats to the transformer
+    transformer._pruned_blocks = prune_context.get_pruned_blocks()
+    transformer._pruned_steps = prune_context.get_pruned_steps()
+    transformer._residual_diffs = prune_context.get_residual_diffs()
+    transformer._actual_blocks = prune_context.get_actual_blocks()
+
+    transformer._cfg_pruned_blocks = prune_context.get_cfg_pruned_blocks()
+    transformer._cfg_pruned_steps = prune_context.get_cfg_pruned_steps()
+    transformer._cfg_residual_diffs = prune_context.get_cfg_residual_diffs()
+    transformer._cfg_actual_blocks = prune_context.get_cfg_actual_blocks()
```