cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic.

Files changed (37)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/adapters.py +47 -5
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
  6. cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
  10. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
  11. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
  12. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
  13. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
  14. cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
  15. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
  16. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
  17. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
  18. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
  19. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
  20. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
  21. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
  22. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
  23. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
  24. cache_dit/cache_factory/patch/flux.py +241 -0
  25. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
  26. cache_dit-0.2.16.dist-info/RECORD +47 -0
  27. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  28. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  29. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  30. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  31. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  32. cache_dit-0.2.14.dist-info/RECORD +0 -49
  33. /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
  34. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/first_block_cache/cache_context.py (deleted)
@@ -1,719 +0,0 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
- import contextlib
- import dataclasses
- import logging
- from collections import defaultdict
- from typing import Any, DefaultDict, Dict, List, Optional, Union
-
- import torch
-
- import cache_dit.primitives as primitives
- from cache_dit.cache_factory.taylorseer import TaylorSeer
- from cache_dit.logger import init_logger
-
- logger = init_logger(__name__)
-
-
- @dataclasses.dataclass
- class CacheContext:
-     residual_diff_threshold: Union[torch.Tensor, float] = 0.0
-     alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = None
-
-     downsample_factor: int = 1
-
-     enable_alter_cache: bool = False
-     num_inference_steps: int = -1
-     warmup_steps: int = 0
-
-     enable_taylorseer: bool = False
-     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-
-     # Skip Layer Guidance, SLG
-     # https://github.com/huggingface/candle/issues/2588
-     slg_layers: Optional[List[int]] = None
-     slg_start: float = 0.0
-     slg_end: float = 0.1
-
-     taylorseer: Optional[TaylorSeer] = None
-     alter_taylorseer: Optional[TaylorSeer] = None
-
-     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
-     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
-         default_factory=lambda: defaultdict(int),
-     )
-
-     executed_steps: int = 0
-     is_alter_cache: bool = True
-
-     max_cached_steps: int = -1
-     cached_steps: List[int] = dataclasses.field(default_factory=list)
-     residual_diffs: DefaultDict[str, float] = dataclasses.field(
-         default_factory=lambda: defaultdict(float),
-     )
-
-     def __post_init__(self):
-         if self.enable_taylorseer:
-             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-             if self.enable_alter_cache:
-                 self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-
-     def get_incremental_name(self, name=None):
-         if name is None:
-             name = "default"
-         idx = self.incremental_name_counters[name]
-         self.incremental_name_counters[name] += 1
-         return f"{name}_{idx}"
-
-     def reset_incremental_names(self):
-         self.incremental_name_counters.clear()
-
-     def get_residual_diff_threshold(self):
-         if all(
-             (
-                 self.enable_alter_cache,
-                 self.is_alter_cache,
-                 self.alter_residual_diff_threshold is not None,
-             )
-         ):
-             residual_diff_threshold = self.alter_residual_diff_threshold
-         else:
-             residual_diff_threshold = self.residual_diff_threshold
-         if isinstance(residual_diff_threshold, torch.Tensor):
-             residual_diff_threshold = residual_diff_threshold.item()
-         return residual_diff_threshold
-
-     def get_buffer(self, name):
-         if self.enable_alter_cache and self.is_alter_cache:
-             name = f"{name}_alter"
-         return self.buffers.get(name)
-
-     def set_buffer(self, name, buffer):
-         if self.enable_alter_cache and self.is_alter_cache:
-             name = f"{name}_alter"
-         self.buffers[name] = buffer
-
-     def remove_buffer(self, name):
-         if self.enable_alter_cache and self.is_alter_cache:
-             name = f"{name}_alter"
-         if name in self.buffers:
-             del self.buffers[name]
-
-     def clear_buffers(self):
-         self.buffers.clear()
-
-     def mark_step_begin(self):
-         if not self.enable_alter_cache:
-             self.executed_steps += 1
-         else:
-             self.is_alter_cache = not self.is_alter_cache
-             if not self.is_alter_cache:
-                 self.executed_steps += 1
-         if self.enable_taylorseer:
-             taylorseer = self.get_taylorseer()
-             taylorseer.mark_step_begin()
-         if self.get_current_step() == 0:
-             self.cached_steps.clear()
-             self.residual_diffs.clear()
-
-     def add_residual_diff(self, diff):
-         step = str(self.get_current_step())
-         self.residual_diffs[step] = diff
-
-     def get_residual_diffs(self):
-         return self.residual_diffs.copy()
-
-     def add_cached_step(self):
-         self.cached_steps.append(self.get_current_step())
-
-     def get_cached_steps(self):
-         return self.cached_steps.copy()
-
-     def get_taylorseer(self):
-         if self.enable_alter_cache and self.is_alter_cache:
-             return self.alter_taylorseer
-         return self.taylorseer
-
-     def is_slg_enabled(self):
-         return self.slg_layers is not None
-
-     def slg_should_skip_block(self, block_idx):
-         if not self.enable_alter_cache or not self.is_alter_cache:
-             return False
-         if self.slg_layers is None:
-             return False
-         if self.slg_start <= 0.0 and self.slg_end >= 1.0:
-             return False
-         num_inference_steps = self.num_inference_steps
-         assert (
-             num_inference_steps >= 0
-         ), "num_inference_steps must be non-negative"
-         return (
-             block_idx in self.slg_layers
-             and num_inference_steps * self.slg_start
-             <= self.get_current_step()
-             < num_inference_steps * self.slg_end
-         )
-
-     def get_current_step(self):
-         return self.executed_steps - 1
-
-     def is_in_warmup(self):
-         return self.get_current_step() < self.warmup_steps
-
-
- @torch.compiler.disable
- def get_residual_diff_threshold():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.get_residual_diff_threshold()
-
-
- @torch.compiler.disable
- def get_buffer(name):
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.get_buffer(name)
-
-
- @torch.compiler.disable
- def set_buffer(name, buffer):
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     cache_context.set_buffer(name, buffer)
-
-
- @torch.compiler.disable
- def remove_buffer(name):
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     cache_context.remove_buffer(name)
-
-
- @torch.compiler.disable
- def mark_step_begin():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     cache_context.mark_step_begin()
-
-
- @torch.compiler.disable
- def get_current_step():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.get_current_step()
-
-
- @torch.compiler.disable
- def get_cached_steps():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.get_cached_steps()
-
-
- @torch.compiler.disable
- def get_max_cached_steps():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.max_cached_steps
-
-
- @torch.compiler.disable
- def add_cached_step():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     cache_context.add_cached_step()
-
-
- @torch.compiler.disable
- def add_residual_diff(diff):
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     cache_context.add_residual_diff(diff)
-
-
- @torch.compiler.disable
- def get_residual_diffs():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.get_residual_diffs()
-
-
- @torch.compiler.disable
- def is_taylorseer_enabled():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.enable_taylorseer
-
-
- @torch.compiler.disable
- def get_taylorseer():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.get_taylorseer()
-
-
- @torch.compiler.disable
- def is_slg_enabled():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.is_slg_enabled()
-
-
- @torch.compiler.disable
- def slg_should_skip_block(block_idx):
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.slg_should_skip_block(block_idx)
-
-
- @torch.compiler.disable
- def is_in_warmup():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.is_in_warmup()
-
-
- _current_cache_context: CacheContext = None
-
-
- def create_cache_context(*args, **kwargs):
-     return CacheContext(*args, **kwargs)
-
-
- def get_current_cache_context():
-     return _current_cache_context
-
-
- def set_current_cache_context(cache_context=None):
-     global _current_cache_context
-     _current_cache_context = cache_context
-
-
- def collect_cache_kwargs(default_attrs: dict, **kwargs):
-     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
-     # default_attrs: specific settings for different pipelines
-     cache_attrs = dataclasses.fields(CacheContext)
-     cache_attrs = [
-         attr
-         for attr in cache_attrs
-         if hasattr(
-             CacheContext,
-             attr.name,
-         )
-     ]
-     cache_kwargs = {
-         attr.name: kwargs.pop(
-             attr.name,
-             getattr(CacheContext, attr.name),
-         )
-         for attr in cache_attrs
-     }
-
-     assert default_attrs is not None, "default_attrs must be set before"
-     for attr in cache_attrs:
-         if attr.name in default_attrs:
-             cache_kwargs[attr.name] = default_attrs[attr.name]
-
-     if logger.isEnabledFor(logging.DEBUG):
-         logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
-
-     return cache_kwargs, kwargs
-
-
- @contextlib.contextmanager
- def cache_context(cache_context):
-     global _current_cache_context
-     old_cache_context = _current_cache_context
-     _current_cache_context = cache_context
-     try:
-         yield
-     finally:
-         _current_cache_context = old_cache_context
-
-
- @torch.compiler.disable
- def are_two_tensors_similar(
-     t1: torch.Tensor,
-     t2: torch.Tensor,
-     *,
-     threshold: float,
-     parallelized: bool = False,
- ):
-     if threshold <= 0.0:
-         return False
-
-     if t1.shape != t2.shape:
-         return False
-
-     mean_diff = (t1 - t2).abs().mean()
-     mean_t1 = t1.abs().mean()
-     if parallelized:
-         mean_diff = primitives.all_reduce_sync(mean_diff, "avg")
-         mean_t1 = primitives.all_reduce_sync(mean_t1, "avg")
-     diff = (mean_diff / mean_t1).item()
-
-     add_residual_diff(diff)
-
-     return diff < threshold
-
-
- @torch.compiler.disable
- def apply_prev_hidden_states_residual(
-     hidden_states: torch.Tensor,
-     encoder_hidden_states: torch.Tensor,
- ):
-     if is_taylorseer_enabled():
-         hidden_states_residual = get_hidden_states_residual()
-         assert (
-             hidden_states_residual is not None
-         ), "hidden_states_residual must be set before"
-         hidden_states = hidden_states_residual + hidden_states
-
-         hidden_states = hidden_states.contiguous()
-         # NOTE: We should also support taylorseer for
-         # encoder_hidden_states approximation. Please
-         # use DBCache instead.
-     else:
-         hidden_states_residual = get_hidden_states_residual()
-         assert (
-             hidden_states_residual is not None
-         ), "hidden_states_residual must be set before"
-         hidden_states = hidden_states_residual + hidden_states
-
-         encoder_hidden_states_residual = get_encoder_hidden_states_residual()
-         assert (
-             encoder_hidden_states_residual is not None
-         ), "encoder_hidden_states_residual must be set before"
-         encoder_hidden_states = (
-             encoder_hidden_states_residual + encoder_hidden_states
-         )
-
-         hidden_states = hidden_states.contiguous()
-         encoder_hidden_states = encoder_hidden_states.contiguous()
-
-     return hidden_states, encoder_hidden_states
-
-
- @torch.compiler.disable
- def get_downsample_factor():
-     cache_context = get_current_cache_context()
-     assert cache_context is not None, "cache_context must be set before"
-     return cache_context.downsample_factor
-
-
- @torch.compiler.disable
- def get_can_use_cache(
-     first_hidden_states_residual: torch.Tensor,
-     parallelized: bool = False,
- ):
-     if is_in_warmup():
-         return False
-     cached_steps = get_cached_steps()
-     max_cached_steps = get_max_cached_steps()
-     if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
-         return False
-     threshold = get_residual_diff_threshold()
-     if threshold <= 0.0:
-         return False
-     downsample_factor = get_downsample_factor()
-     if downsample_factor > 1:
-         first_hidden_states_residual = first_hidden_states_residual[
-             ..., ::downsample_factor
-         ]
-     prev_first_hidden_states_residual = get_first_hidden_states_residual()
-     can_use_cache = (
-         prev_first_hidden_states_residual is not None
-         and are_two_tensors_similar(
-             prev_first_hidden_states_residual,
-             first_hidden_states_residual,
-             threshold=threshold,
-             parallelized=parallelized,
-         )
-     )
-     return can_use_cache
-
-
- @torch.compiler.disable
- def set_first_hidden_states_residual(
-     first_hidden_states_residual: torch.Tensor,
- ):
-     downsample_factor = get_downsample_factor()
-     if downsample_factor > 1:
-         first_hidden_states_residual = first_hidden_states_residual[
-             ..., ::downsample_factor
-         ]
-         first_hidden_states_residual = first_hidden_states_residual.contiguous()
-     set_buffer("first_hidden_states_residual", first_hidden_states_residual)
-
-
- @torch.compiler.disable
- def get_first_hidden_states_residual():
-     return get_buffer("first_hidden_states_residual")
-
-
- @torch.compiler.disable
- def set_hidden_states_residual(hidden_states_residual: torch.Tensor):
-     if is_taylorseer_enabled():
-         taylorseer = get_taylorseer()
-         taylorseer.update(hidden_states_residual)
-     else:
-         set_buffer("hidden_states_residual", hidden_states_residual)
-
-
- @torch.compiler.disable
- def get_hidden_states_residual():
-     if is_taylorseer_enabled():
-         taylorseer = get_taylorseer()
-         return taylorseer.approximate_value()
-     else:
-         return get_buffer("hidden_states_residual")
-
-
- @torch.compiler.disable
- def set_encoder_hidden_states_residual(
-     encoder_hidden_states_residual: torch.Tensor,
- ):
-     if is_taylorseer_enabled():
-         return
-     set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
-
-
- @torch.compiler.disable
- def get_encoder_hidden_states_residual():
-     return get_buffer("encoder_hidden_states_residual")
-
-
- class CachedTransformerBlocks(torch.nn.Module):
-     def __init__(
-         self,
-         transformer_blocks,
-         single_transformer_blocks=None,
-         *,
-         transformer=None,
-         return_hidden_states_first=True,
-         return_hidden_states_only=False,
-     ):
-         super().__init__()
-
-         self.transformer = transformer
-         self.transformer_blocks = transformer_blocks
-         self.single_transformer_blocks = single_transformer_blocks
-         self.return_hidden_states_first = return_hidden_states_first
-         self.return_hidden_states_only = return_hidden_states_only
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         original_hidden_states = hidden_states
-         first_transformer_block = self.transformer_blocks[0]
-         hidden_states = first_transformer_block(
-             hidden_states,
-             encoder_hidden_states,
-             *args,
-             **kwargs,
-         )
-         if not isinstance(hidden_states, torch.Tensor):
-             hidden_states, encoder_hidden_states = hidden_states
-             if not self.return_hidden_states_first:
-                 hidden_states, encoder_hidden_states = (
-                     encoder_hidden_states,
-                     hidden_states,
-                 )
-         first_hidden_states_residual = hidden_states - original_hidden_states
-         del original_hidden_states
-
-         mark_step_begin()
-         can_use_cache = get_can_use_cache(
-             first_hidden_states_residual,
-             parallelized=self._is_parallelized(),
-         )
-
-         torch._dynamo.graph_break()
-         if can_use_cache:
-             add_cached_step()
-             del first_hidden_states_residual
-             hidden_states, encoder_hidden_states = (
-                 apply_prev_hidden_states_residual(
-                     hidden_states, encoder_hidden_states
-                 )
-             )
-         else:
-             set_first_hidden_states_residual(first_hidden_states_residual)
-             del first_hidden_states_residual
-             (
-                 hidden_states,
-                 encoder_hidden_states,
-                 hidden_states_residual,
-                 encoder_hidden_states_residual,
-             ) = self.call_remaining_transformer_blocks(
-                 hidden_states,
-                 encoder_hidden_states,
-                 *args,
-                 **kwargs,
-             )
-             set_hidden_states_residual(hidden_states_residual)
-             set_encoder_hidden_states_residual(encoder_hidden_states_residual)
-
-         patch_cached_stats(self.transformer)
-         torch._dynamo.graph_break()
-
-         return (
-             hidden_states
-             if self.return_hidden_states_only
-             else (
-                 (hidden_states, encoder_hidden_states)
-                 if self.return_hidden_states_first
-                 else (encoder_hidden_states, hidden_states)
-             )
-         )
-
-     def _is_parallelized(self):
-         return all(
-             (
-                 self.transformer is not None,
-                 getattr(self.transformer, "_is_parallelized", False),
-             )
-         )
-
-     def call_remaining_transformer_blocks(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         *args,
-         **kwargs,
-     ):
-         original_hidden_states = hidden_states
-         original_encoder_hidden_states = encoder_hidden_states
-         if not is_slg_enabled():
-             for block in self.transformer_blocks[1:]:
-                 hidden_states = block(
-                     hidden_states,
-                     encoder_hidden_states,
-                     *args,
-                     **kwargs,
-                 )
-                 if not isinstance(hidden_states, torch.Tensor):
-                     hidden_states, encoder_hidden_states = hidden_states
-                     if not self.return_hidden_states_first:
-                         hidden_states, encoder_hidden_states = (
-                             encoder_hidden_states,
-                             hidden_states,
-                         )
-             if self.single_transformer_blocks is not None:
-                 hidden_states = torch.cat(
-                     [encoder_hidden_states, hidden_states], dim=1
-                 )
-                 for block in self.single_transformer_blocks:
-                     hidden_states = block(
-                         hidden_states,
-                         *args,
-                         **kwargs,
-                     )
-                 encoder_hidden_states, hidden_states = hidden_states.split(
-                     [
-                         encoder_hidden_states.shape[1],
-                         hidden_states.shape[1] - encoder_hidden_states.shape[1],
-                     ],
-                     dim=1,
-                 )
-         else:
-             for i, encoder_block in enumerate(self.transformer_blocks[1:]):
-                 if slg_should_skip_block(i + 1):
-                     continue
-                 hidden_states = encoder_block(
-                     hidden_states,
-                     encoder_hidden_states,
-                     *args,
-                     **kwargs,
-                 )
-                 if not isinstance(hidden_states, torch.Tensor):
-                     hidden_states, encoder_hidden_states = hidden_states
-                     if not self.return_hidden_states_first:
-                         hidden_states, encoder_hidden_states = (
-                             encoder_hidden_states,
-                             hidden_states,
-                         )
-             if self.single_transformer_blocks is not None:
-                 hidden_states = torch.cat(
-                     [encoder_hidden_states, hidden_states], dim=1
-                 )
-                 for i, block in enumerate(self.single_transformer_blocks):
-                     if slg_should_skip_block(len(self.transformer_blocks) + i):
-                         continue
-                     hidden_states = block(
-                         hidden_states,
-                         *args,
-                         **kwargs,
-                     )
-                 encoder_hidden_states, hidden_states = hidden_states.split(
-                     [
-                         encoder_hidden_states.shape[1],
-                         hidden_states.shape[1] - encoder_hidden_states.shape[1],
-                     ],
-                     dim=1,
-                 )
-
-         # hidden_states_shape = hidden_states.shape
-         # encoder_hidden_states_shape = encoder_hidden_states.shape
-         hidden_states = (
-             hidden_states.reshape(-1)
-             .contiguous()
-             .reshape(
-                 original_hidden_states.shape,
-             )
-         )
-         encoder_hidden_states = (
-             encoder_hidden_states.reshape(-1)
-             .contiguous()
-             .reshape(
-                 original_encoder_hidden_states.shape,
-             )
-         )
-
-         # hidden_states = hidden_states.contiguous()
-         # encoder_hidden_states = encoder_hidden_states.contiguous()
-
-         hidden_states_residual = hidden_states - original_hidden_states
-         encoder_hidden_states_residual = (
-             encoder_hidden_states - original_encoder_hidden_states
-         )
-
-         hidden_states_residual = (
-             hidden_states_residual.reshape(-1)
-             .contiguous()
-             .reshape(
-                 original_hidden_states.shape,
-             )
-         )
-         encoder_hidden_states_residual = (
-             encoder_hidden_states_residual.reshape(-1)
-             .contiguous()
-             .reshape(
-                 original_encoder_hidden_states.shape,
-             )
-         )
-
-         return (
-             hidden_states,
-             encoder_hidden_states,
-             hidden_states_residual,
-             encoder_hidden_states_residual,
-         )
-
-
- @torch.compiler.disable
- def patch_cached_stats(
-     transformer,
- ):
-     # Patch the cached stats to the transformer, the cached stats
-     # will be reset for each calling of pipe.__call__(**kwargs).
-     if transformer is None:
-         return
-
-     # TODO: Patch more cached stats to the transformer
-     transformer._cached_steps = get_cached_steps()
-     transformer._residual_diffs = get_residual_diffs()
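
For reference, the removed module above implements the first-block-cache (FBCache) pattern: run only the first transformer block, compare its output residual against the previous step's, and if the relative L1 difference stays under residual_diff_threshold, replay the cached residuals instead of running the remaining blocks. The sketch below shows roughly how that 0.2.14 API was driven. It is illustrative only: DummyBlock, the tensor shapes, and the threshold value are invented for demonstration, and the import path no longer exists in 0.2.16, where the dual-block-cache (DBCache) path replaces it.

# Illustrative sketch only: exercises the removed 0.2.14 module shown above.
# DummyBlock and all values here are placeholders; the real adapters wrapped
# a diffusers transformer's block lists instead.
import torch

from cache_dit.cache_factory.first_block_cache import cache_context as fbc


class DummyBlock(torch.nn.Module):
    # Stands in for a transformer block: takes and returns
    # (hidden_states, encoder_hidden_states).
    def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
        return hidden_states + 0.01, encoder_hidden_states


blocks = torch.nn.ModuleList([DummyBlock() for _ in range(4)])
cached_blocks = fbc.CachedTransformerBlocks(
    blocks,
    return_hidden_states_first=True,
)

# A positive residual_diff_threshold enables caching; warmup_steps forces the
# first step(s) to compute fully so there is a residual to compare against.
ctx = fbc.create_cache_context(
    residual_diff_threshold=0.05,
    warmup_steps=1,
)

hidden_states = torch.randn(1, 16, 64)
encoder_hidden_states = torch.randn(1, 8, 64)

with fbc.cache_context(ctx):
    for _ in range(4):  # stand-in for denoising steps
        hidden_states, encoder_hidden_states = cached_blocks(
            hidden_states, encoder_hidden_states
        )

# Steps whose first-block residual stayed within the threshold were served
# from the cached residuals; per-step relative L1 diffs are recorded as well.
print(ctx.get_cached_steps())
print(ctx.get_residual_diffs())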