cache-dit 0.1.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +0 -0
- cache_dit/_version.py +21 -0
- cache_dit/cache_factory/__init__.py +166 -0
- cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
- cache_dit/cache_factory/taylorseer.py +76 -0
- cache_dit/cache_factory/utils.py +0 -0
- cache_dit/logger.py +97 -0
- cache_dit/primitives.py +152 -0
- cache_dit-0.1.1.dev2.dist-info/METADATA +31 -0
- cache_dit-0.1.1.dev2.dist-info/RECORD +30 -0
- cache_dit-0.1.1.dev2.dist-info/WHEEL +5 -0
- cache_dit-0.1.1.dev2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,727 @@
|
|
|
1
|
+
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
|
|
2
|
+
import contextlib
|
|
3
|
+
import dataclasses
|
|
4
|
+
import logging
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from typing import Any, DefaultDict, Dict, List, Optional, Union
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
import cache_dit.primitives as DP
|
|
11
|
+
from cache_dit.cache_factory.taylorseer import TaylorSeer
|
|
12
|
+
from cache_dit.logger import init_logger
|
|
13
|
+
|
|
14
|
+
logger = init_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclasses.dataclass
class CacheContext:
    """Mutable state shared by the first-block-cache helpers in this module.

    Holds the similarity thresholds, step counters, named buffers, optional
    TaylorSeer instances and per-run statistics. When ``enable_alter_cache``
    is on, ``mark_step_begin`` flips ``is_alter_cache`` on every call and the
    buffer / threshold / TaylorSeer accessors switch to a second ("alter")
    slot while the flag is set.
    """

    # Similarity thresholds; tensors are converted with ``.item()`` on read.
    residual_diff_threshold: Union[torch.Tensor, float] = 0.0
    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = None

    downsample_factor: int = 1

    enable_alter_cache: bool = False
    num_inference_steps: int = -1
    warmup_steps: int = 0

    enable_taylorseer: bool = False
    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    # Skip Layer Guidance, SLG
    # https://github.com/huggingface/candle/issues/2588
    slg_layers: Optional[List[int]] = None
    slg_start: float = 0.0
    slg_end: float = 0.1

    # Populated by __post_init__ when enable_taylorseer is set.
    taylorseer: Optional[TaylorSeer] = None
    alter_taylorseer: Optional[TaylorSeer] = None

    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int),
    )

    executed_steps: int = 0
    is_alter_cache: bool = True

    # Per-run statistics; both are cleared when the current step is 0.
    max_cached_steps: int = -1
    cached_steps: List[int] = dataclasses.field(default_factory=list)
    residual_diffs: DefaultDict[str, float] = dataclasses.field(
        default_factory=lambda: defaultdict(float),
    )

    def __post_init__(self):
        """Instantiate the TaylorSeer helper(s) when TaylorSeer is enabled."""
        if not self.enable_taylorseer:
            return
        self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
        if self.enable_alter_cache:
            self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)

    def _alter_active(self):
        """Return a truthy value when the "alter" slot is the one in use."""
        return self.enable_alter_cache and self.is_alter_cache

    def _buffer_key(self, name):
        """Map a logical buffer name to the key used inside ``buffers``."""
        return f"{name}_alter" if self._alter_active() else name

    def get_incremental_name(self, name=None):
        """Return ``"<name>_<n>"`` where ``n`` counts earlier calls for ``name``.

        ``None`` is treated as the name ``"default"``.
        """
        key = "default" if name is None else name
        idx = self.incremental_name_counters[key]
        self.incremental_name_counters[key] = idx + 1
        return f"{key}_{idx}"

    def reset_incremental_names(self):
        """Forget every counter used by ``get_incremental_name``."""
        self.incremental_name_counters.clear()

    def get_residual_diff_threshold(self):
        """Return the threshold for the slot in use as a plain Python number.

        The alter threshold is used only while the alter slot is active and
        an alter threshold has been configured.
        """
        threshold = self.residual_diff_threshold
        if (
            self._alter_active()
            and self.alter_residual_diff_threshold is not None
        ):
            threshold = self.alter_residual_diff_threshold
        if isinstance(threshold, torch.Tensor):
            return threshold.item()
        return threshold

    def get_buffer(self, name):
        """Return the buffer stored for ``name`` (``None`` when absent)."""
        return self.buffers.get(self._buffer_key(name))

    def set_buffer(self, name, buffer):
        """Store ``buffer`` under ``name`` for the slot in use."""
        self.buffers[self._buffer_key(name)] = buffer

    def remove_buffer(self, name):
        """Drop the buffer stored for ``name``; a missing entry is ignored."""
        self.buffers.pop(self._buffer_key(name), None)

    def clear_buffers(self):
        """Drop every stored buffer, for both slots."""
        self.buffers.clear()

    def mark_step_begin(self):
        """Advance the step bookkeeping at the start of a forward pass.

        With alter caching the slot flag is flipped first and the step counter
        only advances when the flip lands on the non-alter slot.
        """
        if self.enable_alter_cache:
            self.is_alter_cache = not self.is_alter_cache
            if not self.is_alter_cache:
                self.executed_steps += 1
        else:
            self.executed_steps += 1

        if self.enable_taylorseer:
            self.get_taylorseer().mark_step_begin()

        # First step of a run: start the statistics from scratch.
        if self.get_current_step() == 0:
            self.cached_steps.clear()
            self.residual_diffs.clear()

    def add_residual_diff(self, diff):
        """Record ``diff`` under the current step index (stringified)."""
        self.residual_diffs[str(self.get_current_step())] = diff

    def get_residual_diffs(self):
        """Return a shallow copy of the recorded residual diffs."""
        return self.residual_diffs.copy()

    def add_cached_step(self):
        """Append the current step index to the list of cached steps."""
        self.cached_steps.append(self.get_current_step())

    def get_cached_steps(self):
        """Return a shallow copy of the list of cached step indices."""
        return self.cached_steps.copy()

    def get_taylorseer(self):
        """Return the TaylorSeer belonging to the slot in use."""
        return self.alter_taylorseer if self._alter_active() else self.taylorseer

    def is_slg_enabled(self):
        """Return True when ``slg_layers`` has been configured."""
        return self.slg_layers is not None

    def slg_should_skip_block(self, block_idx):
        """Decide whether block ``block_idx`` is skipped at the current step.

        Returns False unless the alter slot is active and ``slg_layers`` is
        set, and also when ``[slg_start, slg_end]`` covers the whole
        ``[0, 1]`` range. Otherwise the block is skipped when it is listed in
        ``slg_layers`` and the current step lies in
        ``[num_inference_steps * slg_start, num_inference_steps * slg_end)``.
        """
        if not (self.enable_alter_cache and self.is_alter_cache):
            return False
        if self.slg_layers is None:
            return False
        if self.slg_start <= 0.0 and self.slg_end >= 1.0:
            return False

        total_steps = self.num_inference_steps
        assert total_steps >= 0, "num_inference_steps must be non-negative"

        if block_idx not in self.slg_layers:
            return False
        first_step = total_steps * self.slg_start
        end_step = total_steps * self.slg_end
        return first_step <= self.get_current_step() < end_step

    def get_current_step(self):
        """Return the zero-based index of the step being executed."""
        return self.executed_steps - 1

    def is_in_warmup(self):
        """Return True while the current step is below ``warmup_steps``."""
        return self.get_current_step() < self.warmup_steps
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@torch.compiler.disable
def get_residual_diff_threshold():
    """Return the residual-diff threshold of the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_residual_diff_threshold()
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@torch.compiler.disable
def get_buffer(name):
    """Look up the buffer called ``name`` in the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_buffer(name)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@torch.compiler.disable
def set_buffer(name, buffer):
    """Store ``buffer`` under ``name`` in the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.set_buffer(name, buffer)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@torch.compiler.disable
def remove_buffer(name):
    """Remove the buffer called ``name`` from the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.remove_buffer(name)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@torch.compiler.disable
def mark_step_begin():
    """Tell the active cache context that a new step is starting.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.mark_step_begin()
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@torch.compiler.disable
def get_current_step():
    """Return the current step index of the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_current_step()
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@torch.compiler.disable
def get_cached_steps():
    """Return the cached-step list reported by the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_cached_steps()
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
@torch.compiler.disable
def get_max_cached_steps():
    """Return the ``max_cached_steps`` setting of the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.max_cached_steps
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@torch.compiler.disable
def add_cached_step():
    """Record the current step as cached in the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.add_cached_step()
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
@torch.compiler.disable
def add_residual_diff(diff):
    """Record ``diff`` for the current step in the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.add_residual_diff(diff)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@torch.compiler.disable
def get_residual_diffs():
    """Return the residual diffs reported by the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_residual_diffs()
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@torch.compiler.disable
def is_taylorseer_enabled():
    """Return the ``enable_taylorseer`` flag of the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.enable_taylorseer
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@torch.compiler.disable
def get_taylorseer():
    """Return the TaylorSeer selected by the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_taylorseer()
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
@torch.compiler.disable
def is_slg_enabled():
    """Return whether the active cache context has SLG layers configured.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.is_slg_enabled()
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@torch.compiler.disable
def slg_should_skip_block(block_idx):
    """Ask the active cache context whether block ``block_idx`` is skipped.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.slg_should_skip_block(block_idx)
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
@torch.compiler.disable
def is_in_warmup():
    """Return whether the active cache context is still in its warmup steps.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.is_in_warmup()
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
# Module-level slot for the active cache context. It starts as None and is
# reassigned by set_current_cache_context() and the cache_context() manager.
_current_cache_context: Optional[CacheContext] = None
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def create_cache_context(*args, **kwargs):
    """Build a new ``CacheContext``, forwarding every argument unchanged."""
    new_context = CacheContext(*args, **kwargs)
    return new_context
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def get_current_cache_context():
    """Return the module-level active cache context (``None`` when unset)."""
    active = _current_cache_context
    return active
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def set_current_cache_context(cache_context=None):
    """Replace the module-level active cache context.

    Args:
        cache_context: the context to install; the default ``None`` clears it.
    """
    global _current_cache_context
    _current_cache_context = cache_context
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def collect_cache_kwargs(default_attrs: dict, **kwargs):
    """Split ``kwargs`` into ``CacheContext`` kwargs and everything else.

    Fields of ``CacheContext`` with a plain default are always emitted: the
    value is popped from ``kwargs`` when present, otherwise the class-level
    default is used. Fields declared with ``default_factory`` have no
    class-level attribute (dataclasses delete it), so they are forwarded only
    when the caller supplied them explicitly; otherwise ``CacheContext``
    builds its own fresh default. Previously such fields (for example
    ``taylorseer_kwargs``) were silently left in the "other" kwargs.

    Args:
        default_attrs: pipeline-specific settings. Any entry whose key is a
            ``CacheContext`` field overrides the collected value, including a
            value passed through ``kwargs``.
        **kwargs: mixed keyword arguments to split.

    Returns:
        A ``(cache_kwargs, other_kwargs)`` tuple.

    Raises:
        AssertionError: if ``default_attrs`` is ``None``.
    """
    context_fields = dataclasses.fields(CacheContext)

    cache_kwargs = {}
    for attr in context_fields:
        if hasattr(CacheContext, attr.name):
            # Plain default: always present in the result.
            cache_kwargs[attr.name] = kwargs.pop(
                attr.name,
                getattr(CacheContext, attr.name),
            )
        elif attr.name in kwargs:
            # default_factory field: forward only an explicit value so the
            # dataclass can still create a fresh default per instance.
            cache_kwargs[attr.name] = kwargs.pop(attr.name)

    assert default_attrs is not None, "default_attrs must be set before"
    for attr in context_fields:
        if attr.name in default_attrs:
            cache_kwargs[attr.name] = default_attrs[attr.name]

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Collected Cache kwargs: {cache_kwargs}")

    return cache_kwargs, kwargs
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
@contextlib.contextmanager
def cache_context(cache_context):
    """Temporarily install ``cache_context`` as the active cache context.

    The previously active context is put back when the ``with`` block exits,
    whether it exits normally or through an exception.
    """
    global _current_cache_context
    previous = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = previous
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
@torch.compiler.disable
def are_two_tensors_similar(
    t1: torch.Tensor,
    t2: torch.Tensor,
    *,
    threshold: float,
    parallelized: bool = False,
):
    """Report whether ``t2`` is within ``threshold`` of ``t1`` in relative terms.

    The measure is ``mean(|t1 - t2|) / mean(|t1|)``. When ``parallelized`` is
    true both means are first passed through ``DP.all_reduce_sync(..., "avg")``.
    The measured value is also recorded via ``add_residual_diff``.

    Returns:
        False when ``threshold`` is not positive or the shapes differ (nothing
        is recorded in that case); otherwise whether the measure is strictly
        below ``threshold``.
    """
    if threshold <= 0.0 or t1.shape != t2.shape:
        return False

    abs_delta_mean = (t1 - t2).abs().mean()
    reference_mean = t1.abs().mean()
    if parallelized:
        abs_delta_mean = DP.all_reduce_sync(abs_delta_mean, "avg")
        reference_mean = DP.all_reduce_sync(reference_mean, "avg")

    relative_diff = (abs_delta_mean / reference_mean).item()
    add_residual_diff(relative_diff)
    return relative_diff < threshold
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
@torch.compiler.disable
def apply_prev_hidden_states_residual(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
):
    """Add the stored residual(s) to the given states.

    The hidden-state residual from ``get_hidden_states_residual`` is always
    added. The encoder residual is added only when TaylorSeer is disabled;
    with TaylorSeer enabled ``encoder_hidden_states`` is returned untouched.

    Returns:
        ``(hidden_states, encoder_hidden_states)``.

    Raises:
        AssertionError: if a residual that is needed has not been stored.
    """
    taylorseer_on = is_taylorseer_enabled()

    residual = get_hidden_states_residual()
    assert residual is not None, "hidden_states_residual must be set before"
    hidden_states = residual + hidden_states

    if not taylorseer_on:
        encoder_residual = get_encoder_hidden_states_residual()
        assert (
            encoder_residual is not None
        ), "encoder_hidden_states_residual must be set before"
        encoder_hidden_states = encoder_residual + encoder_hidden_states
        encoder_hidden_states = encoder_hidden_states.contiguous()

    hidden_states = hidden_states.contiguous()
    return hidden_states, encoder_hidden_states
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
@torch.compiler.disable
def get_downsample_factor():
    """Return the ``downsample_factor`` setting of the active cache context.

    Raises:
        AssertionError: if no cache context is currently set.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.downsample_factor
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
@torch.compiler.disable
def get_can_use_cache(
    first_hidden_states_residual: torch.Tensor,
    parallelized: bool = False,
):
    """Decide whether the cached residuals may be reused for this step.

    Reuse is refused during warmup, once ``max_cached_steps`` (when
    non-negative) has been reached, when the threshold is not positive, and
    when no earlier first-block residual is stored. Otherwise the stored
    residual is compared with the current one, which is strided along its
    last dimension by ``downsample_factor`` when that factor exceeds 1.
    """
    if is_in_warmup():
        return False

    steps_cached_so_far = get_cached_steps()
    step_budget = get_max_cached_steps()
    if step_budget >= 0 and len(steps_cached_so_far) >= step_budget:
        return False

    threshold = get_residual_diff_threshold()
    if threshold <= 0.0:
        return False

    stride = get_downsample_factor()
    current = first_hidden_states_residual
    if stride > 1:
        current = current[..., ::stride]

    previous = get_first_hidden_states_residual()
    if previous is None:
        return False
    return are_two_tensors_similar(
        previous,
        current,
        threshold=threshold,
        parallelized=parallelized,
    )
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
@torch.compiler.disable
def set_first_hidden_states_residual(
    first_hidden_states_residual: torch.Tensor,
):
    """Store the first-block residual in the ``first_hidden_states_residual`` buffer.

    When ``downsample_factor`` exceeds 1 the tensor is strided along its last
    dimension first; the stored tensor is always made contiguous.
    """
    stride = get_downsample_factor()
    residual = first_hidden_states_residual
    if stride > 1:
        residual = residual[..., ::stride]
    set_buffer("first_hidden_states_residual", residual.contiguous())
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@torch.compiler.disable
def get_first_hidden_states_residual():
    """Return the ``first_hidden_states_residual`` buffer (``None`` if unset)."""
    stored = get_buffer("first_hidden_states_residual")
    return stored
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@torch.compiler.disable
def set_hidden_states_residual(hidden_states_residual: torch.Tensor):
    """Remember the hidden-state residual of the remaining blocks.

    With TaylorSeer enabled the value is passed to the TaylorSeer's
    ``update``; otherwise it is stored in the ``hidden_states_residual`` buffer.
    """
    if not is_taylorseer_enabled():
        set_buffer("hidden_states_residual", hidden_states_residual)
        return
    get_taylorseer().update(hidden_states_residual)
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
@torch.compiler.disable
def get_hidden_states_residual():
    """Return the hidden-state residual to re-apply.

    With TaylorSeer enabled this is the TaylorSeer's ``approximate_value()``;
    otherwise it is the ``hidden_states_residual`` buffer.
    """
    if is_taylorseer_enabled():
        return get_taylorseer().approximate_value()
    return get_buffer("hidden_states_residual")
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@torch.compiler.disable
def set_encoder_hidden_states_residual(
    encoder_hidden_states_residual: torch.Tensor,
):
    """Store the encoder residual, unless TaylorSeer is enabled.

    With TaylorSeer enabled this is a no-op.
    """
    if not is_taylorseer_enabled():
        set_buffer(
            "encoder_hidden_states_residual", encoder_hidden_states_residual
        )
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
@torch.compiler.disable
def get_encoder_hidden_states_residual():
    """Return the ``encoder_hidden_states_residual`` buffer (``None`` if unset)."""
    stored = get_buffer("encoder_hidden_states_residual")
    return stored
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
class CachedTransformerBlocks(torch.nn.Module):
    """Run a stack of transformer blocks with first-block caching.

    ``forward`` always executes ``transformer_blocks[0]`` and measures the
    residual it adds to ``hidden_states``. ``get_can_use_cache`` then decides
    whether the residuals stored on an earlier step are re-applied (skipping
    every remaining block) or whether the remaining blocks are executed and
    their residuals stored for later steps.
    """

    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
    ):
        """Store the wrapped blocks and the output-format flags.

        Args:
            transformer_blocks: blocks called as
                ``block(hidden_states, encoder_hidden_states, *args, **kwargs)``;
                index 0 is the block that is always executed.
            single_transformer_blocks: optional blocks called as
                ``block(hidden_states, *args, **kwargs)`` on the concatenation
                of encoder and hidden states along dim 1.
            transformer: optional owning object; its ``_is_parallelized``
                attribute is read and it is passed to ``patch_cached_stats``.
            return_hidden_states_first: when False, tuple outputs of the blocks
                are treated as ``(encoder_hidden_states, hidden_states)`` and
                ``forward`` returns its pair in that order as well.
            return_hidden_states_only: when True, ``forward`` returns only
                ``hidden_states``.
        """
        super().__init__()

        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Run the first block, then either reuse cached residuals or run the rest.

        Extra ``*args`` / ``**kwargs`` are forwarded unchanged to every block.

        Returns:
            ``hidden_states`` alone when ``return_hidden_states_only`` is set,
            otherwise the pair of states ordered per
            ``return_hidden_states_first``.
        """
        original_hidden_states = hidden_states
        first_transformer_block = self.transformer_blocks[0]
        hidden_states = first_transformer_block(
            hidden_states,
            encoder_hidden_states,
            *args,
            **kwargs,
        )
        # A non-tensor result is unpacked as a pair of states, swapped when
        # the blocks return the encoder states first.
        if not isinstance(hidden_states, torch.Tensor):
            hidden_states, encoder_hidden_states = hidden_states
            if not self.return_hidden_states_first:
                hidden_states, encoder_hidden_states = (
                    encoder_hidden_states,
                    hidden_states,
                )
        # What the first block added; this is the signal compared across steps.
        first_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        # The step counter must advance before the cache decision is taken.
        mark_step_begin()
        can_use_cache = get_can_use_cache(
            first_hidden_states_residual,
            parallelized=self._is_parallelized(),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            # Cache hit: skip the remaining blocks and re-apply stored residuals.
            add_cached_step()
            del first_hidden_states_residual
            hidden_states, encoder_hidden_states = (
                apply_prev_hidden_states_residual(
                    hidden_states, encoder_hidden_states
                )
            )
        else:
            # Cache miss: refresh the reference residual, run the remaining
            # blocks and store the residuals they produced.
            set_first_hidden_states_residual(first_hidden_states_residual)
            del first_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_remaining_transformer_blocks(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            set_hidden_states_residual(hidden_states_residual)
            set_encoder_hidden_states_residual(encoder_hidden_states_residual)

        patch_cached_stats(self.transformer)
        torch._dynamo.graph_break()

        return (
            hidden_states
            if self.return_hidden_states_only
            else (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )
        )

    def _is_parallelized(self):
        """Return True when a transformer is attached and its ``_is_parallelized`` is truthy."""
        return all(
            (
                self.transformer is not None,
                getattr(self.transformer, "_is_parallelized", False),
            )
        )

    def call_remaining_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Run every block after the first and compute the residuals they add.

        When SLG is enabled, each block is first checked with
        ``slg_should_skip_block`` and skipped when that returns true. Block
        indices passed to that check count ``transformer_blocks`` first and
        ``single_transformer_blocks`` after them.

        Returns:
            ``(hidden_states, encoder_hidden_states, hidden_states_residual,
            encoder_hidden_states_residual)`` where each residual is the
            output minus the corresponding input of this method.
        """
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        if not is_slg_enabled():
            for block in self.transformer_blocks[1:]:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
                if not isinstance(hidden_states, torch.Tensor):
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = (
                            encoder_hidden_states,
                            hidden_states,
                        )
            if self.single_transformer_blocks is not None:
                # Single blocks operate on [encoder | hidden] joined along dim 1.
                hidden_states = torch.cat(
                    [encoder_hidden_states, hidden_states], dim=1
                )
                for block in self.single_transformer_blocks:
                    hidden_states = block(
                        hidden_states,
                        *args,
                        **kwargs,
                    )
                # Split the joined tensor back using the encoder length.
                encoder_hidden_states, hidden_states = hidden_states.split(
                    [
                        encoder_hidden_states.shape[1],
                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
                    ],
                    dim=1,
                )
        else:
            for i, encoder_block in enumerate(self.transformer_blocks[1:]):
                # i + 1 because the enumeration starts at the second block.
                if slg_should_skip_block(i + 1):
                    continue
                hidden_states = encoder_block(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
                if not isinstance(hidden_states, torch.Tensor):
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = (
                            encoder_hidden_states,
                            hidden_states,
                        )
            if self.single_transformer_blocks is not None:
                hidden_states = torch.cat(
                    [encoder_hidden_states, hidden_states], dim=1
                )
                for i, block in enumerate(self.single_transformer_blocks):
                    # Single-block indices continue after the regular blocks.
                    if slg_should_skip_block(len(self.transformer_blocks) + i):
                        continue
                    hidden_states = block(
                        hidden_states,
                        *args,
                        **kwargs,
                    )
                encoder_hidden_states, hidden_states = hidden_states.split(
                    [
                        encoder_hidden_states.shape[1],
                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
                    ],
                    dim=1,
                )

        # hidden_states_shape = hidden_states.shape
        # encoder_hidden_states_shape = encoder_hidden_states.shape
        # Flatten, force a contiguous layout, then restore the input shapes.
        hidden_states = (
            hidden_states.reshape(-1)
            .contiguous()
            .reshape(
                original_hidden_states.shape,
            )
        )
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1)
            .contiguous()
            .reshape(
                original_encoder_hidden_states.shape,
            )
        )

        # hidden_states = hidden_states.contiguous()
        # encoder_hidden_states = encoder_hidden_states.contiguous()

        # Residuals contributed by all blocks executed in this method.
        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = (
            encoder_hidden_states - original_encoder_hidden_states
        )

        hidden_states_residual = (
            hidden_states_residual.reshape(-1)
            .contiguous()
            .reshape(
                original_hidden_states.shape,
            )
        )
        encoder_hidden_states_residual = (
            encoder_hidden_states_residual.reshape(-1)
            .contiguous()
            .reshape(
                original_encoder_hidden_states.shape,
            )
        )

        return (
            hidden_states,
            encoder_hidden_states,
            hidden_states_residual,
            encoder_hidden_states_residual,
        )
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
@torch.compiler.disable
def patch_cached_stats(
    transformer,
):
    """Copy the current cache statistics onto ``transformer``.

    Sets ``transformer._cached_steps`` and ``transformer._residual_diffs``
    from the active cache context. Nothing is done when ``transformer`` is
    ``None`` or has no ``transformer_blocks`` attribute. When that attribute
    is a ``torch.nn.ModuleList``, its first element must be a
    ``CachedTransformerBlocks`` and ``transformer`` must be a
    ``torch.nn.Module``, otherwise nothing is done.
    """
    # The cached stats will be reset for each calling of pipe.__call__(**kwargs).
    if transformer is None:
        return

    blocks = getattr(transformer, "transformer_blocks", None)
    if blocks is None:
        return

    if isinstance(blocks, torch.nn.ModuleList):
        wrapped = blocks[0]
        is_cached_wrapper = isinstance(wrapped, CachedTransformerBlocks)
        is_module = isinstance(transformer, torch.nn.Module)
        if not (is_cached_wrapper and is_module):
            return

    # TODO: Patch more cached stats to the transformer
    transformer._cached_steps = get_cached_steps()
    transformer._residual_diffs = get_residual_diffs()
|