cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (37)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/adapters.py +47 -5
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
  6. cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
  10. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
  11. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
  12. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
  13. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
  14. cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
  15. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
  16. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
  17. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
  18. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
  19. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
  20. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
  21. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
  22. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
  23. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
  24. cache_dit/cache_factory/patch/flux.py +241 -0
  25. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
  26. cache_dit-0.2.16.dist-info/RECORD +47 -0
  27. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  28. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  29. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  30. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  31. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  32. cache_dit-0.2.14.dist-info/RECORD +0 -49
  33. /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
  34. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,12 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
 import logging
 import contextlib
 import dataclasses
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, DefaultDict
 
 import torch
 
 import cache_dit.primitives as primitives
-from cache_dit.utils import is_diffusers_at_least_0_3_5
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -41,21 +39,56 @@ class DBPPruneContext:
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     # Other settings
-    downsample_factor: int = 1
+    downsample_factor: int = 1  # unused
     num_inference_steps: int = -1
     warmup_steps: int = 0  # DON'T prune in warmup steps
     # DON'T prune if the number of pruned steps >= max_pruned_steps
    max_pruned_steps: int = -1
 
-    # Statistics
-    executed_steps: int = 0
+    # Record the executed steps, both cached and non-cached
+    executed_steps: int = 0  # cached + non-cached steps of the pipeline
+    # Steps for the transformer. For separate CFG, transformer_executed_steps
+    # will be double of executed_steps.
+    transformer_executed_steps: int = 0
+
+    # Support do_separate_classifier_free_guidance for models such as Wan 2.1
+    # and Qwen-Image. For models that fuse CFG and non-CFG into a single
+    # forward step, do_separate_classifier_free_guidance should be set to
+    # False, e.g. CogVideoX, HunyuanVideo, Mochi.
+    do_separate_classifier_free_guidance: bool = False
+    # Whether to compute the CFG forward pass first. Default False, namely
+    # 0, 2, 4, ... -> non-CFG step; 1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # Compute separate diff values for CFG and non-CFG steps, default True.
+    # If False, we will reuse the diff computed for the current non-CFG
+    # transformer step for the current CFG step.
+    cfg_diff_compute_separate: bool = True
+
+    # CFG & non-CFG pruned steps
     pruned_blocks: List[int] = dataclasses.field(default_factory=list)
     actual_blocks: List[int] = dataclasses.field(default_factory=list)
-    # Residual diffs for each step, [step: list[float]]
-    residual_diffs: Dict[str, List[float]] = dataclasses.field(
+    residual_diffs: DefaultDict[str, list[float]] = dataclasses.field(
+        default_factory=lambda: defaultdict(list),
+    )
+    cfg_pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_actual_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_residual_diffs: DefaultDict[str, list[float]] = dataclasses.field(
         default_factory=lambda: defaultdict(list),
     )
 
+    @torch.compiler.disable
+    def __post_init__(self):
+        # Some checks for settings
+        if self.do_separate_classifier_free_guidance:
+            assert (
+                self.cfg_diff_compute_separate
+            ), "cfg_diff_compute_separate must be True"
+        if self.cfg_diff_compute_separate:
+            assert self.cfg_compute_first is False, (
+                "cfg_compute_first must be set as False if "
+                "cfg_diff_compute_separate is enabled."
+            )
+
     @torch.compiler.disable
     def get_residual_diff_threshold(self):
         residual_diff_threshold = self.residual_diff_threshold
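
The hunk above adds the separate-CFG settings to DBPPruneContext and validates them in __post_init__. A minimal standalone sketch of that validation logic for readers skimming the diff (the ToyCFGSettings name is illustrative and not part of cache-dit; only the field names and checks mirror the hunk):

import dataclasses


@dataclasses.dataclass
class ToyCFGSettings:
    # Field names mirror the diff above; only the class name is invented.
    do_separate_classifier_free_guidance: bool = False
    cfg_compute_first: bool = False
    cfg_diff_compute_separate: bool = True

    def __post_init__(self):
        # Same checks as DBPPruneContext.__post_init__ in the hunk above.
        if self.do_separate_classifier_free_guidance:
            assert (
                self.cfg_diff_compute_separate
            ), "cfg_diff_compute_separate must be True"
        if self.cfg_diff_compute_separate:
            assert self.cfg_compute_first is False, (
                "cfg_compute_first must be set as False if "
                "cfg_diff_compute_separate is enabled."
            )


ToyCFGSettings(do_separate_classifier_free_guidance=True)  # OK (Wan 2.1 / Qwen-Image style)
try:
    ToyCFGSettings(cfg_compute_first=True)  # rejected while cfg_diff_compute_separate is True
except AssertionError as exc:
    print(exc)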
@@ -119,42 +152,89 @@ class DBPPruneContext:
 
     @torch.compiler.disable
     def mark_step_begin(self):
-        self.executed_steps += 1
-        if self.get_current_step() == 0:
+        # Always increase transformer executed steps
+        # incr step: prev 0 -> 1; prev 1 -> 2
+        # current step: incr step - 1
+        self.transformer_executed_steps += 1
+        if not self.do_separate_classifier_free_guidance:
+            self.executed_steps += 1
+        else:
+            # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
+            if not self.cfg_compute_first:
+                if not self.is_separate_classifier_free_guidance_step():
+                    # transformer step: 0,2,4,...
+                    self.executed_steps += 1
+            else:
+                if self.is_separate_classifier_free_guidance_step():
+                    # transformer step: 0,2,4,...
+                    self.executed_steps += 1
+
+        # Reset the cached steps and residual diffs at the beginning
+        # of each inference.
+        if self.get_current_transformer_step() == 0:
             self.pruned_blocks.clear()
             self.actual_blocks.clear()
             self.residual_diffs.clear()
+            self.cfg_pruned_blocks.clear()
+            self.cfg_actual_blocks.clear()
+            self.cfg_residual_diffs.clear()
 
     @torch.compiler.disable
     def add_pruned_block(self, num_blocks):
-        self.pruned_blocks.append(num_blocks)
+        if not self.is_separate_classifier_free_guidance_step():
+            self.pruned_blocks.append(num_blocks)
+        else:
+            self.cfg_pruned_blocks.append(num_blocks)
 
     @torch.compiler.disable
     def add_actual_block(self, num_blocks):
-        self.actual_blocks.append(num_blocks)
+        if not self.is_separate_classifier_free_guidance_step():
+            self.actual_blocks.append(num_blocks)
+        else:
+            self.cfg_actual_blocks.append(num_blocks)
 
     @torch.compiler.disable
     def add_residual_diff(self, diff):
-        if isinstance(diff, torch.Tensor):
-            diff = diff.item()
-        step = self.get_current_step()
-        self.residual_diffs[step].append(diff)
-        max_num_block_diffs = 1000
-        # Avoid memory leak, keep only the last 1000 diffs
-        if len(self.residual_diffs[step]) > max_num_block_diffs:
-            self.residual_diffs[step] = self.residual_diffs[step][
-                -max_num_block_diffs:
-            ]
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"Step {step}, block: {len(self.residual_diffs[step])}, "
-                f"residual diff: {diff:.6f}"
-            )
+        # step: executed_steps - 1, not transformer_steps - 1
+        step = str(self.get_current_step())
+        # Only add the diff if it is not already recorded for this step
+        if not self.is_separate_classifier_free_guidance_step():
+            if step not in self.residual_diffs:
+                self.residual_diffs[step] = [diff]
+            else:
+                self.residual_diffs[step].append(diff)
+        else:
+            if step not in self.cfg_residual_diffs:
+                self.cfg_residual_diffs[step] = [diff]
+            else:
+                self.cfg_residual_diffs[step].append(diff)
+
+    @torch.compiler.disable
+    def get_pruned_blocks(self):
+        return self.pruned_blocks.copy()
+
+    @torch.compiler.disable
+    def get_cfg_pruned_blocks(self):
+        return self.cfg_pruned_blocks.copy()
 
     @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1
 
+    @torch.compiler.disable
+    def get_current_transformer_step(self):
+        return self.transformer_executed_steps - 1
+
+    @torch.compiler.disable
+    def is_separate_classifier_free_guidance_step(self):
+        if not self.do_separate_classifier_free_guidance:
+            return False
+        if self.cfg_compute_first:
+            # CFG steps: 0, 2, 4, 6, ...
+            return self.get_current_transformer_step() % 2 == 0
+        # CFG steps: 1, 3, 5, 7, ...
+        return self.get_current_transformer_step() % 2 != 0
+
     @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.warmup_steps
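
The mark_step_begin / is_separate_classifier_free_guidance_step logic above maps two transformer forwards (non-CFG + CFG) onto one pipeline step. A self-contained sketch of that bookkeeping under the default cfg_compute_first=False ordering (StepCounter is illustrative, not a cache-dit class):

from dataclasses import dataclass


@dataclass
class StepCounter:
    do_separate_classifier_free_guidance: bool = True
    cfg_compute_first: bool = False
    transformer_executed_steps: int = 0
    executed_steps: int = 0

    def is_cfg_step(self) -> bool:
        # Mirrors is_separate_classifier_free_guidance_step(): with
        # cfg_compute_first=False, transformer steps 1, 3, 5, ... are CFG steps;
        # with cfg_compute_first=True, steps 0, 2, 4, ... are.
        if not self.do_separate_classifier_free_guidance:
            return False
        parity = 0 if self.cfg_compute_first else 1
        return (self.transformer_executed_steps - 1) % 2 == parity

    def mark_step_begin(self) -> None:
        self.transformer_executed_steps += 1
        if not self.do_separate_classifier_free_guidance:
            self.executed_steps += 1
        elif (self.transformer_executed_steps - 1) % 2 == 0:
            # With separate CFG, two transformer forwards share one pipeline step.
            self.executed_steps += 1


counter = StepCounter()
for _ in range(4):
    counter.mark_step_begin()
    print(
        counter.transformer_executed_steps - 1,  # current transformer step
        counter.executed_steps - 1,              # current pipeline step
        counter.is_cfg_step(),
    )
# -> 0 0 False / 1 0 True / 2 1 False / 3 1 True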
@@ -168,38 +248,35 @@ def get_residual_diff_threshold():
 
 
 @torch.compiler.disable
-def get_buffer(name):
-    prune_context = get_current_prune_context()
-    assert prune_context is not None, "prune_context must be set before"
-    return prune_context.get_buffer(name)
-
-
-@torch.compiler.disable
-def set_buffer(name, buffer):
+def mark_step_begin():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-    prune_context.set_buffer(name, buffer)
+    prune_context.mark_step_begin()
 
 
 @torch.compiler.disable
-def remove_buffer(name):
+def get_current_step():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-    prune_context.remove_buffer(name)
+    return prune_context.get_current_step()
 
 
 @torch.compiler.disable
-def mark_step_begin():
+def get_current_step_cfg_residual_diff():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-    prune_context.mark_step_begin()
+    step = str(get_current_step())
+    cfg_residual_diffs = get_cfg_residual_diffs()
+    if step in cfg_residual_diffs:
+        return cfg_residual_diffs[step]
+    return None
 
 
 @torch.compiler.disable
-def get_current_step():
+def get_current_transformer_step():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-    return prune_context.get_current_step()
+    return prune_context.get_current_transformer_step()
 
 
 @torch.compiler.disable
@@ -226,6 +303,13 @@ def get_pruned_blocks():
     return prune_context.pruned_blocks.copy()
 
 
+@torch.compiler.disable
+def get_cfg_pruned_blocks():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.cfg_pruned_blocks.copy()
+
+
 @torch.compiler.disable
 def add_actual_block(num_blocks):
     assert (
@@ -243,6 +327,13 @@ def get_actual_blocks():
     return prune_context.actual_blocks.copy()
 
 
+@torch.compiler.disable
+def get_cfg_actual_blocks():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.cfg_actual_blocks.copy()
+
+
 @torch.compiler.disable
 def get_pruned_steps():
     prune_context = get_current_prune_context()
@@ -252,6 +343,15 @@ def get_pruned_steps():
     return len(pruned_blocks)
 
 
+@torch.compiler.disable
+def get_cfg_pruned_steps():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    cfg_pruned_blocks = get_cfg_pruned_blocks()
+    cfg_pruned_blocks = [x for x in cfg_pruned_blocks if x > 0]
+    return len(cfg_pruned_blocks)
+
+
 @torch.compiler.disable
 def is_in_warmup():
     prune_context = get_current_prune_context()
@@ -284,6 +384,14 @@ def get_residual_diffs():
     return prune_context.residual_diffs.copy()
 
 
+@torch.compiler.disable
+def get_cfg_residual_diffs():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    # Return a copy of the residual diffs to avoid modification
+    return prune_context.cfg_residual_diffs.copy()
+
+
 @torch.compiler.disable
 def get_important_condition_threshold():
     prune_context = get_current_prune_context()
@@ -325,6 +433,27 @@ def get_non_prune_blocks_ids():
     return prune_context.non_prune_blocks_ids
 
 
+@torch.compiler.disable
+def do_separate_classifier_free_guidance():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.do_separate_classifier_free_guidance
+
+
+@torch.compiler.disable
+def is_separate_classifier_free_guidance_step():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.is_separate_classifier_free_guidance_step()
+
+
+@torch.compiler.disable
+def cfg_diff_compute_separate():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.cfg_diff_compute_separate
+
+
 _current_prune_context: DBPPruneContext = None
 
 
@@ -463,6 +592,58 @@ def are_two_tensors_similar(
     return diff < threshold
 
 
+@torch.compiler.disable
+def _debugging_set_buffer(prefix):
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            f"set {prefix}, "
+            f"transformer step: {get_current_transformer_step()}, "
+            f"executed step: {get_current_step()}"
+        )
+
+
+@torch.compiler.disable
+def _debugging_get_buffer(prefix):
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            f"get {prefix}, "
+            f"transformer step: {get_current_transformer_step()}, "
+            f"executed step: {get_current_step()}"
+        )
+
+
+@torch.compiler.disable
+def set_buffer(name: str, buffer: torch.Tensor):
+    # Set hidden_states or residual for Fn blocks.
+    # This buffer is only used for L1 diff calculation.
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    if is_separate_classifier_free_guidance_step():
+        _debugging_set_buffer(f"{name}_buffer_cfg")
+        prune_context.set_buffer(f"{name}_buffer_cfg", buffer)
+    else:
+        _debugging_set_buffer(f"{name}_buffer")
+        prune_context.set_buffer(f"{name}_buffer", buffer)
+
+
+@torch.compiler.disable
+def get_buffer(name: str):
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    if is_separate_classifier_free_guidance_step():
+        _debugging_get_buffer(f"{name}_buffer_cfg")
+        return prune_context.get_buffer(f"{name}_buffer_cfg")
+    _debugging_get_buffer(f"{name}_buffer")
+    return prune_context.get_buffer(f"{name}_buffer")
+
+
+@torch.compiler.disable
+def remove_buffer(name: str):
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    prune_context.remove_buffer(name)
+
+
 @torch.compiler.disable
 def apply_hidden_states_residual(
     hidden_states: torch.Tensor,
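
The set_buffer/get_buffer helpers above route reads and writes through a "_buffer" or "_buffer_cfg" suffix so the CFG and non-CFG branches never overwrite each other's cached states. A toy version of that routing (ToyPruneContext is illustrative; the real buffers live on DBPPruneContext):

from typing import Dict, Optional

import torch


class ToyPruneContext:
    def __init__(self) -> None:
        self.buffers: Dict[str, torch.Tensor] = {}

    def set_buffer(self, name: str, buffer: torch.Tensor) -> None:
        self.buffers[name] = buffer

    def get_buffer(self, name: str) -> Optional[torch.Tensor]:
        return self.buffers.get(name)


def set_buffer(ctx: ToyPruneContext, name: str, buffer: torch.Tensor, is_cfg_step: bool) -> None:
    # Same naming rule as the diff: CFG steps get a "_buffer_cfg" suffix.
    suffix = "_buffer_cfg" if is_cfg_step else "_buffer"
    ctx.set_buffer(f"{name}{suffix}", buffer)


def get_buffer(ctx: ToyPruneContext, name: str, is_cfg_step: bool) -> Optional[torch.Tensor]:
    suffix = "_buffer_cfg" if is_cfg_step else "_buffer"
    return ctx.get_buffer(f"{name}{suffix}")


ctx = ToyPruneContext()
set_buffer(ctx, "0_original", torch.zeros(2), is_cfg_step=False)
set_buffer(ctx, "0_original", torch.ones(2), is_cfg_step=True)
print(get_buffer(ctx, "0_original", is_cfg_step=False))  # tensor([0., 0.])
print(get_buffer(ctx, "0_original", is_cfg_step=True))   # tensor([1., 1.])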
@@ -506,7 +687,11 @@ def get_can_use_prune(
     if is_in_warmup():
         return False
 
-    pruned_steps = get_pruned_steps()
+    if not is_separate_classifier_free_guidance_step():
+        pruned_steps = get_pruned_steps()
+    else:
+        pruned_steps = get_cfg_pruned_steps()
+
     max_pruned_steps = get_max_pruned_steps()
     if max_pruned_steps >= 0 and (pruned_steps >= max_pruned_steps):
         if logger.isEnabledFor(logging.DEBUG):
@@ -521,16 +706,8 @@ def get_can_use_prune(
     if threshold <= 0.0:
         return False
 
-    downsample_factor = get_downsample_factor()
     prev_states_tensor = get_buffer(f"{name}")
 
-    if downsample_factor > 1:
-        states_tensor = states_tensor[..., ::downsample_factor]
-        states_tensor = states_tensor.contiguous()
-        if prev_states_tensor is not None:
-            prev_states_tensor = prev_states_tensor[..., ::downsample_factor]
-            prev_states_tensor = prev_states_tensor.contiguous()
-
     return prev_states_tensor is not None and are_two_tensors_similar(
         prev_states_tensor,
         states_tensor,
@@ -538,468 +715,3 @@ def get_can_use_prune(
         parallelized=parallelized,
         name=name,
     )
-
-
-class DBPrunedTransformerBlocks(torch.nn.Module):
-    def __init__(
-        self,
-        transformer_blocks,
-        single_transformer_blocks=None,
-        *,
-        transformer=None,
-        return_hidden_states_first=True,
-        return_hidden_states_only=False,
-    ):
-        super().__init__()
-
-        self.transformer = transformer
-        self.transformer_blocks = transformer_blocks
-        self.single_transformer_blocks = single_transformer_blocks
-        self.return_hidden_states_first = return_hidden_states_first
-        self.return_hidden_states_only = return_hidden_states_only
-        self.pruned_blocks_step: int = 0
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        mark_step_begin()
-        self.pruned_blocks_step = 0
-        original_hidden_states = hidden_states
-
-        torch._dynamo.graph_break()
-        hidden_states, encoder_hidden_states = self.call_transformer_blocks(
-            hidden_states,
-            encoder_hidden_states,
-            *args,
-            **kwargs,
-        )
-
-        del original_hidden_states
-        torch._dynamo.graph_break()
-
-        add_pruned_block(self.pruned_blocks_step)
-        add_actual_block(self.num_transformer_blocks)
-        patch_pruned_stats(self.transformer)
-
-        return (
-            hidden_states
-            if self.return_hidden_states_only
-            else (
-                (hidden_states, encoder_hidden_states)
-                if self.return_hidden_states_first
-                else (encoder_hidden_states, hidden_states)
-            )
-        )
-
-    @property
-    @torch.compiler.disable
-    def num_transformer_blocks(self):
-        # Total number of transformer blocks, including single transformer blocks.
-        num_blocks = len(self.transformer_blocks)
-        if self.single_transformer_blocks is not None:
-            num_blocks += len(self.single_transformer_blocks)
-        return num_blocks
-
-    @torch.compiler.disable
-    def _is_parallelized(self):
-        # Compatible with distributed inference.
-        return all(
-            (
-                self.transformer is not None,
-                getattr(self.transformer, "_is_parallelized", False),
-            )
-        )
-
-    @torch.compiler.disable
-    def _non_prune_blocks_ids(self):
-        # Never prune the first `Fn` and last `Bn` blocks.
-        num_blocks = self.num_transformer_blocks
-        Fn_compute_blocks_ = (
-            Fn_compute_blocks()
-            if Fn_compute_blocks() < num_blocks
-            else num_blocks
-        )
-        Fn_compute_blocks_ids = list(range(Fn_compute_blocks_))
-        Bn_compute_blocks_ = (
-            Bn_compute_blocks()
-            if Bn_compute_blocks() < num_blocks
-            else num_blocks
-        )
-        Bn_compute_blocks_ids = list(
-            range(
-                num_blocks - Bn_compute_blocks_,
-                num_blocks,
-            )
-        )
-        non_prune_blocks_ids = list(
-            set(
-                Fn_compute_blocks_ids
-                + Bn_compute_blocks_ids
-                + get_non_prune_blocks_ids()
-            )
-        )
-        non_prune_blocks_ids = [
-            d for d in non_prune_blocks_ids if d < num_blocks
-        ]
-        return sorted(non_prune_blocks_ids)
-
-    @torch.compiler.disable
-    def _compute_single_hidden_states_residual(
-        self,
-        single_hidden_states: torch.Tensor,
-        single_original_hidden_states: torch.Tensor,
-        # global original single hidden states
-        original_single_hidden_states: torch.Tensor,
-        original_single_encoder_hidden_states: torch.Tensor,
-    ):
-        single_hidden_states, single_encoder_hidden_states = (
-            self._split_single_hidden_states(
-                single_hidden_states,
-                original_single_hidden_states,
-                original_single_encoder_hidden_states,
-            )
-        )
-
-        single_original_hidden_states, single_original_encoder_hidden_states = (
-            self._split_single_hidden_states(
-                single_original_hidden_states,
-                original_single_hidden_states,
-                original_single_encoder_hidden_states,
-            )
-        )
-
-        single_hidden_states_residual = (
-            single_hidden_states - single_original_hidden_states
-        )
-        single_encoder_hidden_states_residual = (
-            single_encoder_hidden_states - single_original_encoder_hidden_states
-        )
-        return (
-            single_hidden_states_residual,
-            single_encoder_hidden_states_residual,
-        )
-
-    @torch.compiler.disable
-    def _split_single_hidden_states(
-        self,
-        single_hidden_states: torch.Tensor,
-        # global original single hidden states
-        original_single_hidden_states: torch.Tensor,
-        original_single_encoder_hidden_states: torch.Tensor,
-    ):
-        single_encoder_hidden_states, single_hidden_states = (
-            single_hidden_states.split(
-                [
-                    original_single_encoder_hidden_states.shape[1],
-                    single_hidden_states.shape[1]
-                    - original_single_encoder_hidden_states.shape[1],
-                ],
-                dim=1,
-            )
-        )
-        # Reshape the single_hidden_states and single_encoder_hidden_states
-        # to the original shape. This is necessary to ensure that the
-        # residuals are computed correctly.
-        single_hidden_states = (
-            single_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_single_hidden_states.shape)
-        )
-        single_encoder_hidden_states = (
-            single_encoder_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_single_encoder_hidden_states.shape)
-        )
-        return single_hidden_states, single_encoder_hidden_states
-
-    @torch.compiler.disable
-    def _should_update_residuals(self):
-        # Wrap for non compiled mode.
-        # Check if the current step is a multiple of
-        # the residual cache update interval.
-        return get_current_step() % residual_cache_update_interval() == 0
-
-    @torch.compiler.disable
-    def _get_can_use_prune(
-        self,
-        block_id: int,  # Block index in the transformer blocks
-        hidden_states: torch.Tensor,  # hidden_states or residual
-        name: str = "Bn_original",  # prev step name for single blocks
-    ):
-        # Wrap for non compiled mode.
-        can_use_prune = False
-        if block_id not in self._non_prune_blocks_ids():
-            can_use_prune = get_can_use_prune(
-                hidden_states,  # curr step
-                parallelized=self._is_parallelized(),
-                name=name,  # prev step
-            )
-        self.pruned_blocks_step += int(can_use_prune)
-        return can_use_prune
-
-    def _compute_or_prune_single_transformer_block(
-        self,
-        block_id: int,  # Block index in the transformer blocks
-        # Helper inputs for hidden states split and reshape
-        # Global original single hidden states
-        original_single_hidden_states: torch.Tensor,
-        original_single_encoder_hidden_states: torch.Tensor,
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Helper function for `call_transformer_blocks`
-        # block_id: global block index in the transformer blocks +
-        # single_transformer_blocks
-        can_use_prune = self._get_can_use_prune(
-            block_id,
-            hidden_states,  # hidden_states or residual
-            name=f"{block_id}_single_original",  # prev step
-        )
-
-        # Prune steps: Prune current block and reuse the cached
-        # residuals for hidden states approximate.
-        if can_use_prune:
-            single_original_hidden_states = hidden_states
-            (
-                single_original_hidden_states,
-                single_original_encoder_hidden_states,
-            ) = self._split_single_hidden_states(
-                single_original_hidden_states,
-                original_single_hidden_states,
-                original_single_encoder_hidden_states,
-            )
-            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                single_original_hidden_states,
-                single_original_encoder_hidden_states,
-                name=f"{block_id}_single_residual",
-                encoder_name=f"{block_id}_single_encoder_residual",
-            )
-            hidden_states = torch.cat(
-                [encoder_hidden_states, hidden_states],
-                dim=1,
-            )
-            del single_original_hidden_states
-            del single_original_encoder_hidden_states
-
-        else:
-            # Normal steps: Compute the block and cache the residuals.
-            single_original_hidden_states = hidden_states
-            hidden_states = block(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-
-            # Save original_hidden_states for diff calculation.
-            # May not be necessary to update the hidden
-            # states and residuals each step?
-            if self._should_update_residuals():
-                # Cache residuals for the non-compute Bn blocks for
-                # subsequent prune steps.
-                single_hidden_states = hidden_states
-                (
-                    single_hidden_states_residual,
-                    single_encoder_hidden_states_residual,
-                ) = self._compute_single_hidden_states_residual(
-                    single_hidden_states,
-                    single_original_hidden_states,
-                    original_single_hidden_states,
-                    original_single_encoder_hidden_states,
-                )
-
-                set_buffer(
-                    f"{block_id}_single_original",
-                    single_original_hidden_states,
-                )
-
-                set_buffer(
-                    f"{block_id}_single_residual",
-                    single_hidden_states_residual,
-                )
-                set_buffer(
-                    f"{block_id}_single_encoder_residual",
-                    single_encoder_hidden_states_residual,
-                )
-
-                del single_hidden_states
-                del single_hidden_states_residual
-                del single_encoder_hidden_states_residual
-
-            del single_original_hidden_states
-
-        return hidden_states
-
-    def _compute_or_prune_transformer_block(
-        self,
-        block_id: int,  # Block index in the transformer blocks
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Helper function for `call_transformer_blocks`
-        original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
-
-        # block_id: global block index in the transformer blocks +
-        # single_transformer_blocks
-        can_use_prune = self._get_can_use_prune(
-            block_id,
-            hidden_states,  # hidden_states or residual
-            name=f"{block_id}_original",  # prev step
-        )
-
-        # Prune steps: Prune current block and reuse the cached
-        # residuals for hidden states approximate.
-        if can_use_prune:
-            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                hidden_states,
-                encoder_hidden_states,
-                name=f"{block_id}_residual",
-                encoder_name=f"{block_id}_encoder_residual",
-            )
-        else:
-            # Normal steps: Compute the block and cache the residuals.
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
-
-            # Save original_hidden_states for diff calculation.
-            # May not be necessary to update the hidden
-            # states and residuals each step?
-            if self._should_update_residuals():
-                # Cache residuals for the non-compute Bn blocks for
-                # subsequent prune steps.
-                hidden_states_residual = hidden_states - original_hidden_states
-                encoder_hidden_states_residual = (
-                    encoder_hidden_states - original_encoder_hidden_states
-                )
-                set_buffer(
-                    f"{block_id}_original",
-                    original_hidden_states,
-                )
-
-                set_buffer(
-                    f"{block_id}_residual",
-                    hidden_states_residual,
-                )
-                set_buffer(
-                    f"{block_id}_encoder_residual",
-                    encoder_hidden_states_residual,
-                )
-                del hidden_states_residual
-                del encoder_hidden_states_residual
-
-        del original_hidden_states
-        del original_encoder_hidden_states
-
-        return hidden_states, encoder_hidden_states
-
-    def call_transformer_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
-
-        for i, block in enumerate(self.transformer_blocks):
-            hidden_states, encoder_hidden_states = (
-                self._compute_or_prune_transformer_block(
-                    i,
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-            )
-
-        if self.single_transformer_blocks is not None:
-            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
-            if is_diffusers_at_least_0_3_5():
-                for j, block in enumerate(self.single_transformer_blocks):
-                    # NOTE: Reuse _compute_or_prune_transformer_block here.
-                    hidden_states, encoder_hidden_states = (
-                        self._compute_or_prune_transformer_block(
-                            j + len(self.transformer_blocks),
-                            block,
-                            hidden_states,
-                            encoder_hidden_states,
-                            *args,
-                            **kwargs,
-                        )
-                    )
-            else:
-                hidden_states = torch.cat(
-                    [encoder_hidden_states, hidden_states], dim=1
-                )
-                for j, block in enumerate(self.single_transformer_blocks):
-                    hidden_states = (
-                        self._compute_or_prune_single_transformer_block(
-                            j + len(self.transformer_blocks),
-                            original_hidden_states,
-                            original_encoder_hidden_states,
-                            block,
-                            hidden_states,
-                            *args,
-                            **kwargs,
-                        )
-                    )
-
-                encoder_hidden_states, hidden_states = hidden_states.split(
-                    [
-                        encoder_hidden_states.shape[1],
-                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
-                    ],
-                    dim=1,
-                )
-
-        hidden_states = (
-            hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_hidden_states.shape)
-        )
-        encoder_hidden_states = (
-            encoder_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_encoder_hidden_states.shape)
-        )
-        return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_pruned_stats(
-    transformer,
-):
-    # Patch the pruned stats to the transformer, the pruned stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more pruned stats to the transformer
-    transformer._pruned_blocks = get_pruned_blocks()
-    transformer._pruned_steps = get_pruned_steps()
-    transformer._residual_diffs = get_residual_diffs()
-    transformer._actual_blocks = get_actual_blocks()
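
Note: the DBPrunedTransformerBlocks class and the patch_pruned_stats helper removed in this final hunk are not simply dropped from the package; per the files-changed list above, block-level prune logic now lives in the new cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py (+276 lines), mirroring the dual_block_cache split into cache_blocks.py.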