cache-dit 0.1.0: cache_dit-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of cache-dit might be problematic.

Files changed (31)
  1. cache_dit/__init__.py +0 -0
  2. cache_dit/_version.py +21 -0
  3. cache_dit/cache_factory/__init__.py +166 -0
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
  10. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
  16. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  17. cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
  18. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
  19. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
  20. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
  21. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
  22. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
  23. cache_dit/cache_factory/taylorseer.py +76 -0
  24. cache_dit/cache_factory/utils.py +0 -0
  25. cache_dit/logger.py +97 -0
  26. cache_dit/primitives.py +152 -0
  27. cache_dit-0.1.0.dist-info/METADATA +350 -0
  28. cache_dit-0.1.0.dist-info/RECORD +31 -0
  29. cache_dit-0.1.0.dist-info/WHEEL +5 -0
  30. cache_dit-0.1.0.dist-info/licenses/LICENSE +53 -0
  31. cache_dit-0.1.0.dist-info/top_level.txt +1 -0
cache_dit/cache_factory/first_block_cache/cache_context.py
@@ -0,0 +1,727 @@
+ # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
+ import contextlib
+ import dataclasses
+ import logging
+ from collections import defaultdict
+ from typing import Any, DefaultDict, Dict, List, Optional, Union
+
+ import torch
+
+ import cache_dit.primitives as DP
+ from cache_dit.cache_factory.taylorseer import TaylorSeer
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @dataclasses.dataclass
+ class CacheContext:
+     residual_diff_threshold: Union[torch.Tensor, float] = 0.0
+     alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = None
+
+     downsample_factor: int = 1
+
+     enable_alter_cache: bool = False
+     num_inference_steps: int = -1
+     warmup_steps: int = 0
+
+     enable_taylorseer: bool = False
+     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+     # Skip Layer Guidance, SLG
+     # https://github.com/huggingface/candle/issues/2588
+     slg_layers: Optional[List[int]] = None
+     slg_start: float = 0.0
+     slg_end: float = 0.1
+
+     taylorseer: Optional[TaylorSeer] = None
+     alter_taylorseer: Optional[TaylorSeer] = None
+
+     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
+     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
+         default_factory=lambda: defaultdict(int),
+     )
+
+     executed_steps: int = 0
+     is_alter_cache: bool = True
+
+     max_cached_steps: int = -1
+     cached_steps: List[int] = dataclasses.field(default_factory=list)
+     residual_diffs: DefaultDict[str, float] = dataclasses.field(
+         default_factory=lambda: defaultdict(float),
+     )
+
+     def __post_init__(self):
+         if self.enable_taylorseer:
+             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+             if self.enable_alter_cache:
+                 self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+
+     def get_incremental_name(self, name=None):
+         if name is None:
+             name = "default"
+         idx = self.incremental_name_counters[name]
+         self.incremental_name_counters[name] += 1
+         return f"{name}_{idx}"
+
+     def reset_incremental_names(self):
+         self.incremental_name_counters.clear()
+
+     def get_residual_diff_threshold(self):
+         if all(
+             (
+                 self.enable_alter_cache,
+                 self.is_alter_cache,
+                 self.alter_residual_diff_threshold is not None,
+             )
+         ):
+             residual_diff_threshold = self.alter_residual_diff_threshold
+         else:
+             residual_diff_threshold = self.residual_diff_threshold
+         if isinstance(residual_diff_threshold, torch.Tensor):
+             residual_diff_threshold = residual_diff_threshold.item()
+         return residual_diff_threshold
+
+     def get_buffer(self, name):
+         if self.enable_alter_cache and self.is_alter_cache:
+             name = f"{name}_alter"
+         return self.buffers.get(name)
+
+     def set_buffer(self, name, buffer):
+         if self.enable_alter_cache and self.is_alter_cache:
+             name = f"{name}_alter"
+         self.buffers[name] = buffer
+
+     def remove_buffer(self, name):
+         if self.enable_alter_cache and self.is_alter_cache:
+             name = f"{name}_alter"
+         if name in self.buffers:
+             del self.buffers[name]
+
+     def clear_buffers(self):
+         self.buffers.clear()
+
+     def mark_step_begin(self):
+         if not self.enable_alter_cache:
+             self.executed_steps += 1
+         else:
+             self.is_alter_cache = not self.is_alter_cache
+             if not self.is_alter_cache:
+                 self.executed_steps += 1
+         if self.enable_taylorseer:
+             taylorseer = self.get_taylorseer()
+             taylorseer.mark_step_begin()
+         if self.get_current_step() == 0:
+             self.cached_steps.clear()
+             self.residual_diffs.clear()
+
+     def add_residual_diff(self, diff):
+         step = str(self.get_current_step())
+         self.residual_diffs[step] = diff
+
+     def get_residual_diffs(self):
+         return self.residual_diffs.copy()
+
+     def add_cached_step(self):
+         self.cached_steps.append(self.get_current_step())
+
+     def get_cached_steps(self):
+         return self.cached_steps.copy()
+
+     def get_taylorseer(self):
+         if self.enable_alter_cache and self.is_alter_cache:
+             return self.alter_taylorseer
+         return self.taylorseer
+
+     def is_slg_enabled(self):
+         return self.slg_layers is not None
+
+     def slg_should_skip_block(self, block_idx):
+         if not self.enable_alter_cache or not self.is_alter_cache:
+             return False
+         if self.slg_layers is None:
+             return False
+         if self.slg_start <= 0.0 and self.slg_end >= 1.0:
+             return False
+         num_inference_steps = self.num_inference_steps
+         assert (
+             num_inference_steps >= 0
+         ), "num_inference_steps must be non-negative"
+         return (
+             block_idx in self.slg_layers
+             and num_inference_steps * self.slg_start
+             <= self.get_current_step()
+             < num_inference_steps * self.slg_end
+         )
+
+     def get_current_step(self):
+         return self.executed_steps - 1
+
+     def is_in_warmup(self):
+         return self.get_current_step() < self.warmup_steps
+
+
+ @torch.compiler.disable
+ def get_residual_diff_threshold():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_residual_diff_threshold()
+
+
+ @torch.compiler.disable
+ def get_buffer(name):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_buffer(name)
+
+
+ @torch.compiler.disable
+ def set_buffer(name, buffer):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     cache_context.set_buffer(name, buffer)
+
+
+ @torch.compiler.disable
+ def remove_buffer(name):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     cache_context.remove_buffer(name)
+
+
+ @torch.compiler.disable
+ def mark_step_begin():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     cache_context.mark_step_begin()
+
+
+ @torch.compiler.disable
+ def get_current_step():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_current_step()
+
+
+ @torch.compiler.disable
+ def get_cached_steps():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_cached_steps()
+
+
+ @torch.compiler.disable
+ def get_max_cached_steps():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.max_cached_steps
+
+
+ @torch.compiler.disable
+ def add_cached_step():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     cache_context.add_cached_step()
+
+
+ @torch.compiler.disable
+ def add_residual_diff(diff):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     cache_context.add_residual_diff(diff)
+
+
+ @torch.compiler.disable
+ def get_residual_diffs():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_residual_diffs()
+
+
+ @torch.compiler.disable
+ def is_taylorseer_enabled():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.enable_taylorseer
+
+
+ @torch.compiler.disable
+ def get_taylorseer():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_taylorseer()
+
+
+ @torch.compiler.disable
+ def is_slg_enabled():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.is_slg_enabled()
+
+
+ @torch.compiler.disable
+ def slg_should_skip_block(block_idx):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.slg_should_skip_block(block_idx)
+
+
+ @torch.compiler.disable
+ def is_in_warmup():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.is_in_warmup()
+
+
+ _current_cache_context: CacheContext = None
+
+
+ def create_cache_context(*args, **kwargs):
+     return CacheContext(*args, **kwargs)
+
+
+ def get_current_cache_context():
+     return _current_cache_context
+
+
+ def set_current_cache_context(cache_context=None):
+     global _current_cache_context
+     _current_cache_context = cache_context
+
+
+ def collect_cache_kwargs(default_attrs: dict, **kwargs):
+     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
+     # default_attrs: specific settings for different pipelines
+     cache_attrs = dataclasses.fields(CacheContext)
+     cache_attrs = [
+         attr
+         for attr in cache_attrs
+         if hasattr(
+             CacheContext,
+             attr.name,
+         )
+     ]
+     cache_kwargs = {
+         attr.name: kwargs.pop(
+             attr.name,
+             getattr(CacheContext, attr.name),
+         )
+         for attr in cache_attrs
+     }
+
+     assert default_attrs is not None, "default_attrs must be set before"
+     for attr in cache_attrs:
+         if attr.name in default_attrs:
+             cache_kwargs[attr.name] = default_attrs[attr.name]
+
+     if logger.isEnabledFor(logging.DEBUG):
+         logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
+
+     return cache_kwargs, kwargs
+
+
+ @contextlib.contextmanager
+ def cache_context(cache_context):
+     global _current_cache_context
+     old_cache_context = _current_cache_context
+     _current_cache_context = cache_context
+     try:
+         yield
+     finally:
+         _current_cache_context = old_cache_context
+
+
+ @torch.compiler.disable
+ def are_two_tensors_similar(
+     t1: torch.Tensor,
+     t2: torch.Tensor,
+     *,
+     threshold: float,
+     parallelized: bool = False,
+ ):
+     if threshold <= 0.0:
+         return False
+
+     if t1.shape != t2.shape:
+         return False
+
+     mean_diff = (t1 - t2).abs().mean()
+     mean_t1 = t1.abs().mean()
+     if parallelized:
+         mean_diff = DP.all_reduce_sync(mean_diff, "avg")
+         mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
+     diff = (mean_diff / mean_t1).item()
+
+     add_residual_diff(diff)
+
+     return diff < threshold
+
+
+ @torch.compiler.disable
+ def apply_prev_hidden_states_residual(
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+ ):
+     if is_taylorseer_enabled():
+         hidden_states_residual = get_hidden_states_residual()
+         assert (
+             hidden_states_residual is not None
+         ), "hidden_states_residual must be set before"
+         hidden_states = hidden_states_residual + hidden_states
+
+         hidden_states = hidden_states.contiguous()
+     else:
+         hidden_states_residual = get_hidden_states_residual()
+         assert (
+             hidden_states_residual is not None
+         ), "hidden_states_residual must be set before"
+         hidden_states = hidden_states_residual + hidden_states
+
+         encoder_hidden_states_residual = get_encoder_hidden_states_residual()
+         assert (
+             encoder_hidden_states_residual is not None
+         ), "encoder_hidden_states_residual must be set before"
+         encoder_hidden_states = (
+             encoder_hidden_states_residual + encoder_hidden_states
+         )
+
+         hidden_states = hidden_states.contiguous()
+         encoder_hidden_states = encoder_hidden_states.contiguous()
+
+     return hidden_states, encoder_hidden_states
+
+
+ @torch.compiler.disable
+ def get_downsample_factor():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.downsample_factor
+
+
+ @torch.compiler.disable
+ def get_can_use_cache(
+     first_hidden_states_residual: torch.Tensor,
+     parallelized: bool = False,
+ ):
+     if is_in_warmup():
+         return False
+     cached_steps = get_cached_steps()
+     max_cached_steps = get_max_cached_steps()
+     if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
+         return False
+     threshold = get_residual_diff_threshold()
+     if threshold <= 0.0:
+         return False
+     downsample_factor = get_downsample_factor()
+     if downsample_factor > 1:
+         first_hidden_states_residual = first_hidden_states_residual[
+             ..., ::downsample_factor
+         ]
+     prev_first_hidden_states_residual = get_first_hidden_states_residual()
+     can_use_cache = (
+         prev_first_hidden_states_residual is not None
+         and are_two_tensors_similar(
+             prev_first_hidden_states_residual,
+             first_hidden_states_residual,
+             threshold=threshold,
+             parallelized=parallelized,
+         )
+     )
+     return can_use_cache
+
+
+ @torch.compiler.disable
+ def set_first_hidden_states_residual(
+     first_hidden_states_residual: torch.Tensor,
+ ):
+     downsample_factor = get_downsample_factor()
+     if downsample_factor > 1:
+         first_hidden_states_residual = first_hidden_states_residual[
+             ..., ::downsample_factor
+         ]
+         first_hidden_states_residual = first_hidden_states_residual.contiguous()
+     set_buffer("first_hidden_states_residual", first_hidden_states_residual)
+
+
+ @torch.compiler.disable
+ def get_first_hidden_states_residual():
+     return get_buffer("first_hidden_states_residual")
+
+
+ @torch.compiler.disable
+ def set_hidden_states_residual(hidden_states_residual: torch.Tensor):
+     if is_taylorseer_enabled():
+         taylorseer = get_taylorseer()
+         taylorseer.update(hidden_states_residual)
+     else:
+         set_buffer("hidden_states_residual", hidden_states_residual)
+
+
+ @torch.compiler.disable
+ def get_hidden_states_residual():
+     if is_taylorseer_enabled():
+         taylorseer = get_taylorseer()
+         return taylorseer.approximate_value()
+     else:
+         return get_buffer("hidden_states_residual")
+
+
+ @torch.compiler.disable
+ def set_encoder_hidden_states_residual(
+     encoder_hidden_states_residual: torch.Tensor,
+ ):
+     if is_taylorseer_enabled():
+         return
+     set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
+
+
+ @torch.compiler.disable
+ def get_encoder_hidden_states_residual():
+     return get_buffer("encoder_hidden_states_residual")
+
+
+ class CachedTransformerBlocks(torch.nn.Module):
+     def __init__(
+         self,
+         transformer_blocks,
+         single_transformer_blocks=None,
+         *,
+         transformer=None,
+         return_hidden_states_first=True,
+         return_hidden_states_only=False,
+     ):
+         super().__init__()
+
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.single_transformer_blocks = single_transformer_blocks
+         self.return_hidden_states_first = return_hidden_states_first
+         self.return_hidden_states_only = return_hidden_states_only
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         first_transformer_block = self.transformer_blocks[0]
+         hidden_states = first_transformer_block(
+             hidden_states,
+             encoder_hidden_states,
+             *args,
+             **kwargs,
+         )
+         if not isinstance(hidden_states, torch.Tensor):
+             hidden_states, encoder_hidden_states = hidden_states
+             if not self.return_hidden_states_first:
+                 hidden_states, encoder_hidden_states = (
+                     encoder_hidden_states,
+                     hidden_states,
+                 )
+         first_hidden_states_residual = hidden_states - original_hidden_states
+         del original_hidden_states
+
+         mark_step_begin()
+         can_use_cache = get_can_use_cache(
+             first_hidden_states_residual,
+             parallelized=self._is_parallelized(),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             add_cached_step()
+             del first_hidden_states_residual
+             hidden_states, encoder_hidden_states = (
+                 apply_prev_hidden_states_residual(
+                     hidden_states, encoder_hidden_states
+                 )
+             )
+         else:
+             set_first_hidden_states_residual(first_hidden_states_residual)
+             del first_hidden_states_residual
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 encoder_hidden_states_residual,
+             ) = self.call_remaining_transformer_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             set_hidden_states_residual(hidden_states_residual)
+             set_encoder_hidden_states_residual(encoder_hidden_states_residual)
+
+         patch_cached_stats(self.transformer)
+         torch._dynamo.graph_break()
+
+         return (
+             hidden_states
+             if self.return_hidden_states_only
+             else (
+                 (hidden_states, encoder_hidden_states)
+                 if self.return_hidden_states_first
+                 else (encoder_hidden_states, hidden_states)
+             )
+         )
+
+     def _is_parallelized(self):
+         return all(
+             (
+                 self.transformer is not None,
+                 getattr(self.transformer, "_is_parallelized", False),
+             )
+         )
+
+     def call_remaining_transformer_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         if not is_slg_enabled():
+             for block in self.transformer_blocks[1:]:
+                 hidden_states = block(
+                     hidden_states,
+                     encoder_hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+                 if not isinstance(hidden_states, torch.Tensor):
+                     hidden_states, encoder_hidden_states = hidden_states
+                     if not self.return_hidden_states_first:
+                         hidden_states, encoder_hidden_states = (
+                             encoder_hidden_states,
+                             hidden_states,
+                         )
+             if self.single_transformer_blocks is not None:
+                 hidden_states = torch.cat(
+                     [encoder_hidden_states, hidden_states], dim=1
+                 )
+                 for block in self.single_transformer_blocks:
+                     hidden_states = block(
+                         hidden_states,
+                         *args,
+                         **kwargs,
+                     )
+                 encoder_hidden_states, hidden_states = hidden_states.split(
+                     [
+                         encoder_hidden_states.shape[1],
+                         hidden_states.shape[1] - encoder_hidden_states.shape[1],
+                     ],
+                     dim=1,
+                 )
+         else:
+             for i, encoder_block in enumerate(self.transformer_blocks[1:]):
+                 if slg_should_skip_block(i + 1):
+                     continue
+                 hidden_states = encoder_block(
+                     hidden_states,
+                     encoder_hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+                 if not isinstance(hidden_states, torch.Tensor):
+                     hidden_states, encoder_hidden_states = hidden_states
+                     if not self.return_hidden_states_first:
+                         hidden_states, encoder_hidden_states = (
+                             encoder_hidden_states,
+                             hidden_states,
+                         )
+             if self.single_transformer_blocks is not None:
+                 hidden_states = torch.cat(
+                     [encoder_hidden_states, hidden_states], dim=1
+                 )
+                 for i, block in enumerate(self.single_transformer_blocks):
+                     if slg_should_skip_block(len(self.transformer_blocks) + i):
+                         continue
+                     hidden_states = block(
+                         hidden_states,
+                         *args,
+                         **kwargs,
+                     )
+                 encoder_hidden_states, hidden_states = hidden_states.split(
+                     [
+                         encoder_hidden_states.shape[1],
+                         hidden_states.shape[1] - encoder_hidden_states.shape[1],
+                     ],
+                     dim=1,
+                 )
+
+         # hidden_states_shape = hidden_states.shape
+         # encoder_hidden_states_shape = encoder_hidden_states.shape
+         hidden_states = (
+             hidden_states.reshape(-1)
+             .contiguous()
+             .reshape(
+                 original_hidden_states.shape,
+             )
+         )
+         encoder_hidden_states = (
+             encoder_hidden_states.reshape(-1)
+             .contiguous()
+             .reshape(
+                 original_encoder_hidden_states.shape,
+             )
+         )
+
+         # hidden_states = hidden_states.contiguous()
+         # encoder_hidden_states = encoder_hidden_states.contiguous()
+
+         hidden_states_residual = hidden_states - original_hidden_states
+         encoder_hidden_states_residual = (
+             encoder_hidden_states - original_encoder_hidden_states
+         )
+
+         hidden_states_residual = (
+             hidden_states_residual.reshape(-1)
+             .contiguous()
+             .reshape(
+                 original_hidden_states.shape,
+             )
+         )
+         encoder_hidden_states_residual = (
+             encoder_hidden_states_residual.reshape(-1)
+             .contiguous()
+             .reshape(
+                 original_encoder_hidden_states.shape,
+             )
+         )
+
+         return (
+             hidden_states,
+             encoder_hidden_states,
+             hidden_states_residual,
+             encoder_hidden_states_residual,
+         )
+
+
+ @torch.compiler.disable
+ def patch_cached_stats(
+     transformer,
+ ):
+     # Patch the cached stats to the transformer, the cached stats
+     # will be reset for each calling of pipe.__call__(**kwargs).
+     if transformer is None:
+         return
+
+     cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
+     if cached_transformer_blocks is None:
+         return
+
+     if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
+         cached_transformer_blocks = cached_transformer_blocks[0]
+     if not isinstance(
+         cached_transformer_blocks, CachedTransformerBlocks
+     ) or not isinstance(transformer, torch.nn.Module):
+         return
+
+     # TODO: Patch more cached stats to the transformer
+     transformer._cached_steps = get_cached_steps()
+     transformer._residual_diffs = get_residual_diffs()
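
For orientation, the sketch below shows how the pieces defined in this file compose: create_cache_context() builds a CacheContext, the cache_context() context manager makes it current, and CachedTransformerBlocks wraps a list of transformer blocks so that a denoising step whose first-block residual stays close to the previous step's reuses the cached residuals instead of running the remaining blocks. This is only an illustrative sketch against this module's public functions; the ToyBlock module, tensor shapes, and threshold values are invented for the example, and the real wiring for Flux/CogVideoX/Mochi/Wan lives in the diffusers_adapters packages listed above and may differ.

import torch

from cache_dit.cache_factory.first_block_cache import cache_context as fbc


class ToyBlock(torch.nn.Module):
    # Hypothetical stand-in for a DiT block: returns (hidden_states, encoder_hidden_states).
    def forward(self, hidden_states, encoder_hidden_states):
        return hidden_states + 0.001, encoder_hidden_states


blocks = torch.nn.ModuleList([ToyBlock() for _ in range(4)])
cached_blocks = fbc.CachedTransformerBlocks(
    blocks,
    return_hidden_states_first=True,
)

ctx = fbc.create_cache_context(
    residual_diff_threshold=0.05,  # relative mean-abs threshold checked by are_two_tensors_similar
    warmup_steps=2,                # is_in_warmup() blocks caching for the first steps
    max_cached_steps=-1,           # negative value: no cap on the number of cached steps
)

hidden_states = torch.randn(1, 16, 64)
encoder_hidden_states = torch.randn(1, 8, 64)

with fbc.cache_context(ctx):
    for _ in range(10):  # one CachedTransformerBlocks call per denoising step
        hidden_states, encoder_hidden_states = cached_blocks(
            hidden_states, encoder_hidden_states
        )

print("cached steps:", ctx.get_cached_steps())
print("residual diffs:", ctx.get_residual_diffs())

The caching decision itself is are_two_tensors_similar: a step is served from cache when mean(|r_t - r_{t-1}|) / mean(|r_{t-1}|) falls below residual_diff_threshold, where r is the residual produced by the first transformer block.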