cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: the registry flags this release of cache-dit as potentially problematic; see the registry page for details.

Files changed (37)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/adapters.py +47 -5
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
  6. cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
  10. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
  11. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
  12. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
  13. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
  14. cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
  15. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
  16. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
  17. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
  18. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
  19. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
  20. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
  21. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
  22. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
  23. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
  24. cache_dit/cache_factory/patch/flux.py +241 -0
  25. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
  26. cache_dit-0.2.16.dist-info/RECORD +47 -0
  27. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  28. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  29. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  30. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  31. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  32. cache_dit-0.2.14.dist-info/RECORD +0 -49
  33. /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
  34. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
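
Most of the churn is a refactor: the dual-block-cache and dynamic-block-prune runtimes move into dedicated cache_blocks.py / prune_blocks.py modules, Qwen-Image adapters and a Flux patch module are added, and the first_block_cache backend is removed. For orientation, a minimal usage sketch follows; it assumes the apply_cache_on_pipe and CacheType entry points exported from cache_dit.cache_factory in earlier 0.2.x releases are unchanged in 0.2.16, which this diff does not confirm. The full diff of the new prune_blocks.py module follows the sketch.

    # Hypothetical sketch (not taken from this diff): enable the dynamic-block-prune
    # (DBPrune) backend on a diffusers pipeline. The cache-dit API names below are
    # assumed from earlier releases of the package.
    import torch
    from diffusers import FluxPipeline
    from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Wraps the transformer blocks with the pruning runtime shown below.
    apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBPrune))

    image = pipe("a cup of coffee", num_inference_steps=28).images[0]
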
--- /dev/null
+++ b/cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py
@@ -0,0 +1,276 @@
+ import inspect
+ import torch
+
+ from cache_dit.cache_factory.dynamic_block_prune import prune_context
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class DBPrunedTransformerBlocks(torch.nn.Module):
+     def __init__(
+         self,
+         transformer_blocks: torch.nn.ModuleList,
+         *,
+         transformer: torch.nn.Module = None,
+         return_hidden_states_first: bool = True,
+         return_hidden_states_only: bool = False,
+     ):
+         super().__init__()
+
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.return_hidden_states_first = return_hidden_states_first
+         self.return_hidden_states_only = return_hidden_states_only
+         self.pruned_blocks_step: int = 0
+         self._check_forward_params()
+
+     def _check_forward_params(self):
+         # NOTE: DBPrune only supports blocks that follow the pattern:
+         # IN/OUT: (hidden_states, encoder_hidden_states)
+         self.required_parameters = [
+             "hidden_states",
+             "encoder_hidden_states",
+         ]
+         if self.transformer_blocks is not None:
+             for block in self.transformer_blocks:
+                 forward_parameters = set(
+                     inspect.signature(block.forward).parameters.keys()
+                 )
+                 for required_param in self.required_parameters:
+                     assert (
+                         required_param in forward_parameters
+                     ), f"The input parameters must contain: {required_param}."
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         prune_context.mark_step_begin()
+         self.pruned_blocks_step = 0
+         original_hidden_states = hidden_states
+
+         torch._dynamo.graph_break()
+         hidden_states, encoder_hidden_states = self.call_blocks(
+             hidden_states,
+             encoder_hidden_states,
+             *args,
+             **kwargs,
+         )
+
+         del original_hidden_states
+         torch._dynamo.graph_break()
+
+         prune_context.add_pruned_block(self.pruned_blocks_step)
+         prune_context.add_actual_block(self.num_transformer_blocks)
+         patch_pruned_stats(self.transformer)
+
+         return (
+             hidden_states
+             if self.return_hidden_states_only
+             else (
+                 (hidden_states, encoder_hidden_states)
+                 if self.return_hidden_states_first
+                 else (encoder_hidden_states, hidden_states)
+             )
+         )
+
+     @property
+     @torch.compiler.disable
+     def num_transformer_blocks(self):
+         # Total number of transformer blocks.
+         return len(self.transformer_blocks)
+
+     @torch.compiler.disable
+     def _is_parallelized(self):
+         # Compatible with distributed inference.
+         return all(
+             (
+                 self.transformer is not None,
+                 getattr(self.transformer, "_is_parallelized", False),
+             )
+         )
+
+     @torch.compiler.disable
+     def _non_prune_blocks_ids(self):
+         # Never prune the first `Fn` and last `Bn` blocks.
+         num_blocks = self.num_transformer_blocks
+         Fn_compute_blocks_ = (
+             prune_context.Fn_compute_blocks()
+             if prune_context.Fn_compute_blocks() < num_blocks
+             else num_blocks
+         )
+         Fn_compute_blocks_ids = list(range(Fn_compute_blocks_))
+         Bn_compute_blocks_ = (
+             prune_context.Bn_compute_blocks()
+             if prune_context.Bn_compute_blocks() < num_blocks
+             else num_blocks
+         )
+         Bn_compute_blocks_ids = list(
+             range(
+                 num_blocks - Bn_compute_blocks_,
+                 num_blocks,
+             )
+         )
+         non_prune_blocks_ids = list(
+             set(
+                 Fn_compute_blocks_ids
+                 + Bn_compute_blocks_ids
+                 + prune_context.get_non_prune_blocks_ids()
+             )
+         )
+         non_prune_blocks_ids = [
+             d for d in non_prune_blocks_ids if d < num_blocks
+         ]
+         return sorted(non_prune_blocks_ids)
+
+     @torch.compiler.disable
+     def _should_update_residuals(self):
+         # Wrap for non-compiled mode.
+         # Check if the current step is a multiple of
+         # the residual cache update interval.
+         return (
+             prune_context.get_current_step()
+             % prune_context.residual_cache_update_interval()
+             == 0
+         )
+
+     @torch.compiler.disable
+     def _get_can_use_prune(
+         self,
+         block_id: int,  # Block index in the transformer blocks
+         hidden_states: torch.Tensor,  # hidden_states or residual
+         name: str = "Bn_original",  # prev step name for single blocks
+     ):
+         # Wrap for non-compiled mode.
+         can_use_prune = False
+         if block_id not in self._non_prune_blocks_ids():
+             can_use_prune = prune_context.get_can_use_prune(
+                 hidden_states,  # curr step
+                 parallelized=self._is_parallelized(),
+                 name=name,  # prev step
+             )
+         self.pruned_blocks_step += int(can_use_prune)
+         return can_use_prune
+
+     def _compute_or_prune_block(
+         self,
+         block_id: int,  # Block index in the transformer blocks
+         # Below are the inputs to the block
+         block,  # The transformer block to be executed
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         # Helper function for `call_blocks`
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+
+         # block_id: global block index in the transformer blocks +
+         # single_transformer_blocks
+         can_use_prune = self._get_can_use_prune(
+             block_id,
+             hidden_states,  # hidden_states or residual
+             name=f"{block_id}_original",  # prev step
+         )
+         # Prune steps: prune the current block and reuse the cached
+         # residuals to approximate the hidden states.
+         if can_use_prune:
+             hidden_states, encoder_hidden_states = (
+                 prune_context.apply_hidden_states_residual(
+                     hidden_states,
+                     encoder_hidden_states,
+                     name=f"{block_id}_residual",
+                     encoder_name=f"{block_id}_encoder_residual",
+                 )
+             )
+         else:
+             # Normal steps: compute the block and cache the residuals.
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+             # Save original_hidden_states for diff calculation.
+             # May not be necessary to update the hidden
+             # states and residuals each step?
+             if self._should_update_residuals():
+                 # Cache residuals for the non-compute Bn blocks for
+                 # subsequent prune steps.
+                 hidden_states_residual = hidden_states - original_hidden_states
+                 encoder_hidden_states_residual = (
+                     encoder_hidden_states - original_encoder_hidden_states
+                 )
+                 prune_context.set_buffer(
+                     f"{block_id}_original",
+                     original_hidden_states,
+                 )
+
+                 prune_context.set_buffer(
+                     f"{block_id}_residual",
+                     hidden_states_residual,
+                 )
+                 prune_context.set_buffer(
+                     f"{block_id}_encoder_residual",
+                     encoder_hidden_states_residual,
+                 )
+                 del hidden_states_residual
+                 del encoder_hidden_states_residual
+
+         del original_hidden_states
+         del original_encoder_hidden_states
+
+         return hidden_states, encoder_hidden_states
+
+     def call_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         for i, block in enumerate(self.transformer_blocks):
+             hidden_states, encoder_hidden_states = self._compute_or_prune_block(
+                 i,
+                 block,
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+         return hidden_states, encoder_hidden_states
+
+
+ @torch.compiler.disable
+ def patch_pruned_stats(
+     transformer,
+ ):
+     # Patch the pruned stats onto the transformer; the stats are reset
+     # on each call to pipe.__call__(**kwargs).
+     if transformer is None:
+         return
+
+     # TODO: Patch more pruned stats onto the transformer
+     transformer._pruned_blocks = prune_context.get_pruned_blocks()
+     transformer._pruned_steps = prune_context.get_pruned_steps()
+     transformer._residual_diffs = prune_context.get_residual_diffs()
+     transformer._actual_blocks = prune_context.get_actual_blocks()
+
+     transformer._cfg_pruned_blocks = prune_context.get_cfg_pruned_blocks()
+     transformer._cfg_pruned_steps = prune_context.get_cfg_pruned_steps()
+     transformer._cfg_residual_diffs = prune_context.get_cfg_residual_diffs()
+     transformer._cfg_actual_blocks = prune_context.get_cfg_actual_blocks()
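
To make the new prune_blocks.py module above easier to follow, here is an illustrative sketch of how a transformer's block list might be wrapped with DBPrunedTransformerBlocks and how the patched stats can be read back afterwards. The helper name wrap_with_db_prune and the transformer_blocks attribute are assumptions for illustration, not the adapter code shipped in this release.

    import torch
    from cache_dit.cache_factory.dynamic_block_prune.prune_blocks import (
        DBPrunedTransformerBlocks,
    )

    def wrap_with_db_prune(transformer: torch.nn.Module) -> torch.nn.Module:
        # Illustrative only: replace the original block list with a single
        # wrapper module, so the outer forward() iterates once and the wrapper
        # decides per block whether to compute it or reuse cached residuals.
        pruned_blocks = DBPrunedTransformerBlocks(
            transformer.transformer_blocks,  # assumed attribute name
            transformer=transformer,
            return_hidden_states_first=True,
        )
        transformer.transformer_blocks = torch.nn.ModuleList([pruned_blocks])
        return transformer

    # After a pipeline run, patch_pruned_stats() has attached counters such as
    # transformer._pruned_blocks and transformer._pruned_steps (see the hunk
    # above), which can be inspected to see how many blocks were skipped.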