cache-dit 0.1.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +0 -0
- cache_dit/_version.py +21 -0
- cache_dit/cache_factory/__init__.py +166 -0
- cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
- cache_dit/cache_factory/taylorseer.py +76 -0
- cache_dit/cache_factory/utils.py +0 -0
- cache_dit/logger.py +97 -0
- cache_dit/primitives.py +152 -0
- cache_dit-0.1.1.dev2.dist-info/METADATA +31 -0
- cache_dit-0.1.1.dev2.dist-info/RECORD +30 -0
- cache_dit-0.1.1.dev2.dist-info/WHEEL +5 -0
- cache_dit-0.1.1.dev2.dist-info/top_level.txt +1 -0
cache_dit/cache_factory/dual_block_cache/cache_context.py
@@ -0,0 +1,1361 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py

import logging
import contextlib
import dataclasses
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Union

import torch

import cache_dit.primitives as DP
from cache_dit.logger import init_logger

logger = init_logger(__name__)

@dataclasses.dataclass
class DBCacheContext:
    # Dual Block Cache
    # Fn=1, Bn=0 means FB Cache; otherwise, Dual Block Cache.
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0
    # We have added a residual cache pattern for selected compute blocks.
    Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
    Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
    # Non-compute blocks diff threshold: we don't skip the non-compute
    # blocks if the diff >= threshold.
    non_compute_blocks_diff_threshold: float = 0.08
    max_Fn_compute_blocks: int = -1
    max_Bn_compute_blocks: int = -1
    # L1 hidden states or residual diff threshold for Fn
    residual_diff_threshold: Union[torch.Tensor, float] = 0.0
    l1_hidden_states_diff_threshold: Optional[float] = None
    important_condition_threshold: float = 0.0

    # Alter Cache Settings
    # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
    enable_alter_cache: bool = False
    is_alter_cache: bool = True
    # 1.0 means we always cache the residuals if alter_cache is enabled.
    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0

    # Buffer for storing the residuals and other tensors
    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int),
    )

    # Other settings
    downsample_factor: int = 1
    num_inference_steps: int = -1
    warmup_steps: int = 0  # DON'T Cache in warmup steps
    # DON'T Cache if the number of cached steps >= max_cached_steps
    max_cached_steps: int = -1

    # Statistics for both alter cache and non-alter cache.
    # Record the steps that have been cached, both alter cache and non-alter cache.
    executed_steps: int = 0  # cache + non-cache steps
    cached_steps: List[int] = dataclasses.field(default_factory=list)
    residual_diffs: DefaultDict[str, float] = dataclasses.field(
        default_factory=lambda: defaultdict(float),
    )

    def get_incremental_name(self, name=None):
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        self.incremental_name_counters.clear()

    def get_residual_diff_threshold(self):
        if self.enable_alter_cache:
            residual_diff_threshold = self.alter_residual_diff_threshold
        else:
            residual_diff_threshold = self.residual_diff_threshold
            if self.l1_hidden_states_diff_threshold is not None:
                # Use the L1 hidden states diff threshold if set
                residual_diff_threshold = self.l1_hidden_states_diff_threshold
        if isinstance(residual_diff_threshold, torch.Tensor):
            residual_diff_threshold = residual_diff_threshold.item()
        return residual_diff_threshold

    def get_buffer(self, name):
        if self.enable_alter_cache and self.is_alter_cache:
            name = f"{name}_alter"
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        if self.enable_alter_cache and self.is_alter_cache:
            name = f"{name}_alter"
        self.buffers[name] = buffer

    def remove_buffer(self, name):
        if self.enable_alter_cache and self.is_alter_cache:
            name = f"{name}_alter"
        if name in self.buffers:
            del self.buffers[name]

    def clear_buffers(self):
        self.buffers.clear()

    def mark_step_begin(self):
        if not self.enable_alter_cache:
            self.executed_steps += 1
        else:
            self.executed_steps += 1
            # 0 F 1 T 2 F 3 T 4 F 5 T ...
            self.is_alter_cache = not self.is_alter_cache

        # Reset the cached steps and residual diffs at the beginning
        # of each inference.
        if self.get_current_step() == 0:
            self.cached_steps.clear()
            self.residual_diffs.clear()
            self.reset_incremental_names()

    def add_residual_diff(self, diff):
        step = str(self.get_current_step())
        if step not in self.residual_diffs:
            # Only add the diff if it is not already recorded for this step
            self.residual_diffs[step] = diff

    def get_residual_diffs(self):
        return self.residual_diffs.copy()

    def add_cached_step(self):
        self.cached_steps.append(self.get_current_step())

    def get_cached_steps(self):
        return self.cached_steps.copy()

    def get_current_step(self):
        return self.executed_steps - 1

    def is_in_warmup(self):
        return self.get_current_step() < self.warmup_steps


@torch.compiler.disable
def get_residual_diff_threshold():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_residual_diff_threshold()


@torch.compiler.disable
def get_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.set_buffer(name, buffer)


@torch.compiler.disable
def remove_buffer(name):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.remove_buffer(name)


@torch.compiler.disable
def mark_step_begin():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.mark_step_begin()


@torch.compiler.disable
def get_current_step():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_current_step()


@torch.compiler.disable
def get_cached_steps():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_cached_steps()


@torch.compiler.disable
def get_max_cached_steps():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.max_cached_steps


@torch.compiler.disable
def add_cached_step():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.add_cached_step()


@torch.compiler.disable
def add_residual_diff(diff):
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    cache_context.add_residual_diff(diff)


@torch.compiler.disable
def get_residual_diffs():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.get_residual_diffs()


@torch.compiler.disable
def is_alter_cache_enabled():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.enable_alter_cache


@torch.compiler.disable
def is_alter_cache():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.is_alter_cache


@torch.compiler.disable
def is_in_warmup():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.is_in_warmup()


@torch.compiler.disable
def is_l1_diff_enabled():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return (
        cache_context.l1_hidden_states_diff_threshold is not None
        and cache_context.l1_hidden_states_diff_threshold > 0.0
    )


@torch.compiler.disable
def get_important_condition_threshold():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.important_condition_threshold


@torch.compiler.disable
def non_compute_blocks_diff_threshold():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.non_compute_blocks_diff_threshold


@torch.compiler.disable
def Fn_compute_blocks():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    assert (
        cache_context.Fn_compute_blocks >= 1
    ), "Fn_compute_blocks must be >= 1"
    if cache_context.max_Fn_compute_blocks > 0:
        # NOTE: Fn_compute_blocks can be 1, which means FB Cache,
        # but it must be less than or equal to max_Fn_compute_blocks.
        assert (
            cache_context.Fn_compute_blocks
            <= cache_context.max_Fn_compute_blocks
        ), (
            f"Fn_compute_blocks must be <= {cache_context.max_Fn_compute_blocks}, "
            f"but got {cache_context.Fn_compute_blocks}"
        )
    return cache_context.Fn_compute_blocks


@torch.compiler.disable
def Fn_compute_blocks_ids():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    assert (
        len(cache_context.Fn_compute_blocks_ids)
        <= cache_context.Fn_compute_blocks
    ), (
        "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
        f"{cache_context.Fn_compute_blocks}, but got "
        f"{len(cache_context.Fn_compute_blocks_ids)}"
    )
    return cache_context.Fn_compute_blocks_ids


@torch.compiler.disable
def Bn_compute_blocks():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    assert (
        cache_context.Bn_compute_blocks >= 0
    ), "Bn_compute_blocks must be >= 0"
    if cache_context.max_Bn_compute_blocks > 0:
        # NOTE: Bn_compute_blocks can be 0, which means FB Cache,
        # but it must be less than or equal to max_Bn_compute_blocks.
        assert (
            cache_context.Bn_compute_blocks
            <= cache_context.max_Bn_compute_blocks
        ), (
            f"Bn_compute_blocks must be <= {cache_context.max_Bn_compute_blocks}, "
            f"but got {cache_context.Bn_compute_blocks}"
        )
    return cache_context.Bn_compute_blocks


@torch.compiler.disable
def Bn_compute_blocks_ids():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    assert (
        len(cache_context.Bn_compute_blocks_ids)
        <= cache_context.Bn_compute_blocks
    ), (
        "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
        f"{cache_context.Bn_compute_blocks}, but got "
        f"{len(cache_context.Bn_compute_blocks_ids)}"
    )
    return cache_context.Bn_compute_blocks_ids


_current_cache_context: Optional[DBCacheContext] = None


def create_cache_context(*args, **kwargs):
    return DBCacheContext(*args, **kwargs)


def get_current_cache_context():
    return _current_cache_context


def set_current_cache_context(cache_context=None):
    global _current_cache_context
    _current_cache_context = cache_context


def collect_cache_kwargs(default_attrs: dict, **kwargs):
    # NOTE: This API will split kwargs into cache_kwargs and other_kwargs.
    # default_attrs: specific settings for different pipelines.
    cache_attrs = dataclasses.fields(DBCacheContext)
    cache_attrs = [
        attr
        for attr in cache_attrs
        if hasattr(
            DBCacheContext,
            attr.name,
        )
    ]
    cache_kwargs = {
        attr.name: kwargs.pop(
            attr.name,
            getattr(DBCacheContext, attr.name),
        )
        for attr in cache_attrs
    }

    # Manually set sequence fields, namely, Fn_compute_blocks_ids
    # and Bn_compute_blocks_ids, which are lists or sets.
    cache_kwargs["Fn_compute_blocks_ids"] = kwargs.pop(
        "Fn_compute_blocks_ids",
        [],
    )
    cache_kwargs["Bn_compute_blocks_ids"] = kwargs.pop(
        "Bn_compute_blocks_ids",
        [],
    )

    assert default_attrs is not None, "default_attrs must be set before"
    for attr in cache_attrs:
        if attr.name in default_attrs:
            cache_kwargs[attr.name] = default_attrs[attr.name]

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Collected DBCache kwargs: {cache_kwargs}")

    return cache_kwargs, kwargs


@contextlib.contextmanager
def cache_context(cache_context):
    global _current_cache_context
    old_cache_context = _current_cache_context
    _current_cache_context = cache_context
    try:
        yield
    finally:
        _current_cache_context = old_cache_context

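For orientation, here is a minimal sketch of how the context plumbing above is meant to be driven from a pipeline. The denoising loop, `num_steps`, and `run_transformer` are hypothetical stand-ins, not part of this package:

# Hypothetical usage sketch; only create_cache_context(), cache_context()
# and the DBCacheContext fields come from this file.
ctx = create_cache_context(
    Fn_compute_blocks=1,           # Fn=1, Bn=0 degenerates to FB Cache
    Bn_compute_blocks=0,
    residual_diff_threshold=0.08,
    warmup_steps=2,                # never cache the first 2 steps
)
with cache_context(ctx):
    for _ in range(num_steps):     # hypothetical denoising loop
        run_transformer()          # forward() below calls mark_step_begin()
print(ctx.get_cached_steps())      # step indices that were served from cache
print(ctx.get_residual_diffs())    # per-step relative L1 diffs
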
@torch.compiler.disable
def are_two_tensors_similar(
    t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
    t2: torch.Tensor,  # curr residual R(t  ,n) = H(t  ,n) - H(t  ,0)
    *,
    threshold: float,
    parallelized: bool = False,
    prefix: str = "Fn",  # for debugging
):
    # Special cases for the recorded diff: -0.0 means the threshold is
    # disabled, -1.0 means the tensors are always treated as similar,
    # and -2.0 means the shapes did not match.
    if threshold <= 0.0:
        add_residual_diff(-0.0)
        return False

    if threshold >= 1.0:
        # If threshold is 1.0 or more, we consider them always similar.
        add_residual_diff(-1.0)
        return True

    if t1.shape != t2.shape:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
        add_residual_diff(-2.0)
        return False

    # Find the most significant tokens through t1 and t2, and
    # consider the diff of the significant tokens. The more significant,
    # the more important.
    condition_thresh = get_important_condition_threshold()
    if condition_thresh > 0.0:
        raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
        token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
        token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
        # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
        token_diff = token_m_df / token_m_t1  # [B, seq_len]
        condition = token_diff > condition_thresh  # [B, seq_len]
        if condition.sum() > 0:
            condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
            condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
            mean_diff = raw_diff[condition].mean()
            mean_t1 = t1[condition].abs().mean()
        else:
            mean_diff = (t1 - t2).abs().mean()
            mean_t1 = t1.abs().mean()
    else:
        # Use the mean of the absolute difference of the tensors
        mean_diff = (t1 - t2).abs().mean()
        mean_t1 = t1.abs().mean()

    if parallelized:
        mean_diff = DP.all_reduce_sync(mean_diff, "avg")
        mean_t1 = DP.all_reduce_sync(mean_t1, "avg")

    # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
    # Further, if we assume that (H(t, 0) - H(t-1, 0)) ~ 0, then
    # H(t-1, n) ~ H(t, n), which means the hidden states are similar.
    diff = (mean_diff / mean_t1).item()

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}")

    add_residual_diff(diff)

    return diff < threshold

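A quick, self-contained sanity check of the relative L1 metric above; the tensors and threshold are made-up illustrative values:

import torch

# diff = mean(|t1 - t2|) / mean(|t1|); tensors count as similar when
# diff < threshold.
t1 = torch.tensor([1.0, 2.0, 3.0, 4.0])  # previous residual (made up)
t2 = torch.tensor([1.1, 2.1, 2.9, 4.2])  # current residual (made up)
diff = (t1 - t2).abs().mean() / t1.abs().mean()
print(round(diff.item(), 3))  # 0.05 -> similar under the default 0.08
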
# Fn buffers
@torch.compiler.disable
def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
    # Set hidden_states or residual for Fn blocks.
    downsample_factor = get_downsample_factor()
    if downsample_factor > 1:
        buffer = buffer[..., ::downsample_factor]
        buffer = buffer.contiguous()
    set_buffer(f"{prefix}_buffer", buffer)


@torch.compiler.disable
def get_Fn_buffer(prefix: str = "Fn"):
    return get_buffer(f"{prefix}_buffer")


@torch.compiler.disable
def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
    set_buffer(f"{prefix}_encoder_buffer", buffer)


@torch.compiler.disable
def get_Fn_encoder_buffer(prefix: str = "Fn"):
    return get_buffer(f"{prefix}_encoder_buffer")


# Bn buffers
@torch.compiler.disable
def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
    # Set hidden_states or residual for Bn blocks.
    set_buffer(f"{prefix}_buffer", buffer)


@torch.compiler.disable
def get_Bn_buffer(prefix: str = "Bn"):
    return get_buffer(f"{prefix}_buffer")


@torch.compiler.disable
def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
    set_buffer(f"{prefix}_encoder_buffer", buffer)


@torch.compiler.disable
def get_Bn_encoder_buffer(prefix: str = "Bn"):
    return get_buffer(f"{prefix}_encoder_buffer")


@torch.compiler.disable
def apply_hidden_states_residual(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    prefix: str = "Bn",
):
    # Allow Bn and Fn prefix to be used for residual cache.
    if "Bn" in prefix:
        hidden_states_residual = get_Bn_buffer(prefix)
    else:
        hidden_states_residual = get_Fn_buffer(prefix)

    assert (
        hidden_states_residual is not None
    ), f"{prefix}_buffer must be set before"
    hidden_states = hidden_states_residual + hidden_states

    if "Bn" in prefix:
        encoder_hidden_states_residual = get_Bn_encoder_buffer(prefix)
    else:
        encoder_hidden_states_residual = get_Fn_encoder_buffer(prefix)

    assert (
        encoder_hidden_states_residual is not None
    ), f"{prefix}_encoder_buffer must be set before"
    encoder_hidden_states = (
        encoder_hidden_states_residual + encoder_hidden_states
    )

    hidden_states = hidden_states.contiguous()
    encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states


@torch.compiler.disable
def get_downsample_factor():
    cache_context = get_current_cache_context()
    assert cache_context is not None, "cache_context must be set before"
    return cache_context.downsample_factor


@torch.compiler.disable
def get_can_use_cache(
    states_tensor: torch.Tensor,  # hidden_states or residual
    parallelized: bool = False,
    threshold: Optional[float] = None,  # can manually set threshold
    prefix: str = "Fn",
):
    if is_in_warmup():
        return False
    cached_steps = get_cached_steps()
    max_cached_steps = get_max_cached_steps()
    if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
                "cannot use cache."
            )
        return False
    if threshold is None or threshold <= 0.0:
        threshold = get_residual_diff_threshold()
    if threshold <= 0.0:
        return False
    downsample_factor = get_downsample_factor()
    if downsample_factor > 1 and "Bn" not in prefix:
        states_tensor = states_tensor[..., ::downsample_factor]
        states_tensor = states_tensor.contiguous()

    # Allow Bn and Fn prefix to be used for diff calculation.
    if "Bn" in prefix:
        prev_states_tensor = get_Bn_buffer(prefix)
    else:
        prev_states_tensor = get_Fn_buffer(prefix)

    if not is_alter_cache_enabled():
        # Dynamic cache according to the residual diff
        can_use_cache = (
            prev_states_tensor is not None
            and are_two_tensors_similar(
                prev_states_tensor,
                states_tensor,
                threshold=threshold,
                parallelized=parallelized,
                prefix=prefix,
            )
        )
    else:
        # Only cache in the alter cache steps
        can_use_cache = (
            prev_states_tensor is not None
            and are_two_tensors_similar(
                prev_states_tensor,
                states_tensor,
                threshold=threshold,
                parallelized=parallelized,
                prefix=prefix,
            )
            and is_alter_cache()
        )
    return can_use_cache

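The class below partitions the N transformer blocks into three groups: the first Fn blocks are always computed and feed the diff check, the middle N - Fn - Bn blocks are skipped on cache hits, and the last Bn blocks are always computed for precision. A tiny slicing illustration, assuming a hypothetical stack of 19 blocks with Fn=8 and Bn=8:

blocks = list(range(19))    # stand-in for a transformer block list
Fn, Bn = 8, 8               # assumed values, mirroring Fn/Bn_compute_blocks
fn_part = blocks[:Fn]       # [0..7],  always computed (diff check)
mid_part = blocks[Fn:-Bn]   # [8..10], skipped when the cache is reused
bn_part = blocks[-Bn:]      # [11..18], always computed for precision
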
class DBCachedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
    ):
        super().__init__()

        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        original_hidden_states = hidden_states
        # Call first `n` blocks to process the hidden states for
        # more stable diff calculation.
        hidden_states, encoder_hidden_states = self.call_Fn_transformer_blocks(
            hidden_states,
            encoder_hidden_states,
            *args,
            **kwargs,
        )

        Fn_hidden_states_residual = hidden_states - original_hidden_states
        del original_hidden_states

        mark_step_begin()
        # Residual L1 diff or Hidden States L1 diff
        can_use_cache = get_can_use_cache(
            (
                Fn_hidden_states_residual
                if not is_l1_diff_enabled()
                else hidden_states
            ),
            parallelized=self._is_parallelized(),
            prefix=(
                "Fn_residual"
                if not is_l1_diff_enabled()
                else "Fn_hidden_states"
            ),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
            add_cached_step()
            del Fn_hidden_states_residual
            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
                hidden_states, encoder_hidden_states, prefix="Bn_residual"
            )
            # Call last `n` blocks to further process the hidden states
            # for higher precision.
            hidden_states, encoder_hidden_states = (
                self.call_Bn_transformer_blocks(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
            )
        else:
            set_Fn_buffer(Fn_hidden_states_residual, prefix="Fn_residual")
            if is_l1_diff_enabled():
                # for hidden states L1 diff
                set_Fn_buffer(hidden_states, "Fn_hidden_states")
            del Fn_hidden_states_residual
            (
                hidden_states,
                encoder_hidden_states,
                hidden_states_residual,
                encoder_hidden_states_residual,
            ) = self.call_MN2n_transformer_blocks(  # middle
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            set_Bn_buffer(hidden_states_residual, prefix="Bn_residual")
            set_Bn_encoder_buffer(
                encoder_hidden_states_residual, prefix="Bn_residual"
            )
            # Call last `n` blocks to further process the hidden states
            # for higher precision.
            hidden_states, encoder_hidden_states = (
                self.call_Bn_transformer_blocks(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
            )

        patch_cached_stats(self.transformer)
        torch._dynamo.graph_break()

        return (
            hidden_states
            if self.return_hidden_states_only
            else (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )
        )

    @torch.compiler.disable
    def _is_parallelized(self):
        # Compatible with distributed inference.
        return all(
            (
                self.transformer is not None,
                getattr(self.transformer, "_is_parallelized", False),
            )
        )

    @torch.compiler.disable
    def _is_in_cache_step(self):
        # Check if the current step is in cache steps.
        # If so, we can skip some Bn blocks and directly
        # use the cached values.
        return get_current_step() in get_cached_steps()

    @torch.compiler.disable
    def _Fn_transformer_blocks(self):
        # Select first `n` blocks to process the hidden states for
        # more stable diff calculation.
        # Fn: [0,...,n-1]
        selected_Fn_transformer_blocks = self.transformer_blocks[
            : Fn_compute_blocks()
        ]
        # Skip the blocks if they are not in the Fn_compute_blocks_ids.
        # WARN: DON'T set len(Fn_compute_blocks_ids) > 0 NOW; there are
        # still some precision issues. We don't know whether a step should
        # be cached or not before the first Fn blocks are processed.
        if len(Fn_compute_blocks_ids()) > 0:
            selected_Fn_transformer_blocks = [
                selected_Fn_transformer_blocks[i]
                for i in Fn_compute_blocks_ids()
                if i < len(selected_Fn_transformer_blocks)
            ]
        return selected_Fn_transformer_blocks

    @torch.compiler.disable
    def _MN2n_single_transformer_blocks(self):  # middle
        # M(N-2n): transformer_blocks [n,...] + single_transformer_blocks [0,...,N-n]
        selected_MN2n_single_transformer_blocks = []
        if self.single_transformer_blocks is not None:
            if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
                selected_MN2n_single_transformer_blocks = (
                    self.single_transformer_blocks
                )
            else:
                selected_MN2n_single_transformer_blocks = (
                    self.single_transformer_blocks[: -Bn_compute_blocks()]
                )
        return selected_MN2n_single_transformer_blocks

    @torch.compiler.disable
    def _MN2n_transformer_blocks(self):
        # M(N-2n): only transformer_blocks [n,...,N-n], middle
        if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
            selected_MN2n_transformer_blocks = self.transformer_blocks[
                Fn_compute_blocks() :
            ]
        else:
            selected_MN2n_transformer_blocks = self.transformer_blocks[
                Fn_compute_blocks() : -Bn_compute_blocks()
            ]
        return selected_MN2n_transformer_blocks

    @torch.compiler.disable
    def _Bn_single_transformer_blocks(self):
        # Bn: single_transformer_blocks [N-n+1,...,N-1]
        selected_Bn_single_transformer_blocks = []
        if self.single_transformer_blocks is not None:
            selected_Bn_single_transformer_blocks = (
                self.single_transformer_blocks[-Bn_compute_blocks() :]
            )
        return selected_Bn_single_transformer_blocks

    @torch.compiler.disable
    def _Bn_transformer_blocks(self):
        # Bn: transformer_blocks [N-n+1,...,N-1]
        selected_Bn_transformer_blocks = self.transformer_blocks[
            -Bn_compute_blocks() :
        ]
        return selected_Bn_transformer_blocks

    def call_Fn_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        assert Fn_compute_blocks() <= len(self.transformer_blocks), (
            f"Fn_compute_blocks {Fn_compute_blocks()} must not exceed "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
        for block in self._Fn_transformer_blocks():
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            if not isinstance(hidden_states, torch.Tensor):
                hidden_states, encoder_hidden_states = hidden_states
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = (
                        encoder_hidden_states,
                        hidden_states,
                    )

        return hidden_states, encoder_hidden_states

    def call_MN2n_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        if self.single_transformer_blocks is not None:
            for block in self.transformer_blocks[Fn_compute_blocks() :]:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
                if not isinstance(hidden_states, torch.Tensor):
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = (
                            encoder_hidden_states,
                            hidden_states,
                        )

            hidden_states = torch.cat(
                [encoder_hidden_states, hidden_states], dim=1
            )
            for block in self._MN2n_single_transformer_blocks():
                hidden_states = block(
                    hidden_states,
                    *args,
                    **kwargs,
                )
            encoder_hidden_states, hidden_states = hidden_states.split(
                [
                    encoder_hidden_states.shape[1],
                    hidden_states.shape[1] - encoder_hidden_states.shape[1],
                ],
                dim=1,
            )
        else:
            for block in self._MN2n_transformer_blocks():
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
                if not isinstance(hidden_states, torch.Tensor):
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = (
                            encoder_hidden_states,
                            hidden_states,
                        )

        # hidden_states_shape = hidden_states.shape
        # encoder_hidden_states_shape = encoder_hidden_states.shape
        hidden_states = (
            hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_hidden_states.shape)
        )
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_encoder_hidden_states.shape)
        )

        # hidden_states = hidden_states.contiguous()
        # encoder_hidden_states = encoder_hidden_states.contiguous()

        hidden_states_residual = hidden_states - original_hidden_states
        encoder_hidden_states_residual = (
            encoder_hidden_states - original_encoder_hidden_states
        )

        hidden_states_residual = (
            hidden_states_residual.reshape(-1)
            .contiguous()
            .reshape(original_hidden_states.shape)
        )
        encoder_hidden_states_residual = (
            encoder_hidden_states_residual.reshape(-1)
            .contiguous()
            .reshape(original_encoder_hidden_states.shape)
        )

        return (
            hidden_states,
            encoder_hidden_states,
            hidden_states_residual,
            encoder_hidden_states_residual,
        )

    @torch.compiler.disable
    def _Bn_i_single_hidden_states_residual(
        self,
        Bn_i_hidden_states: torch.Tensor,
        Bn_i_original_hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        original_encoder_hidden_states: torch.Tensor,
    ):
        # Split the Bn_i_hidden_states and Bn_i_original_hidden_states
        # into encoder_hidden_states and hidden_states.
        Bn_i_hidden_states, Bn_i_encoder_hidden_states = (
            self._split_Bn_i_single_hidden_states(
                Bn_i_hidden_states,
                original_hidden_states,
                original_encoder_hidden_states,
            )
        )
        # Split the Bn_i_original_hidden_states into encoder_hidden_states
        # and hidden_states.
        Bn_i_original_hidden_states, Bn_i_original_encoder_hidden_states = (
            self._split_Bn_i_single_hidden_states(
                Bn_i_original_hidden_states,
                original_hidden_states,
                original_encoder_hidden_states,
            )
        )

        # Compute the residuals for the Bn_i_hidden_states and
        # Bn_i_encoder_hidden_states.
        Bn_i_hidden_states_residual = (
            Bn_i_hidden_states - Bn_i_original_hidden_states
        )
        Bn_i_encoder_hidden_states_residual = (
            Bn_i_encoder_hidden_states - Bn_i_original_encoder_hidden_states
        )
        return (
            Bn_i_hidden_states_residual,
            Bn_i_encoder_hidden_states_residual,
        )

    @torch.compiler.disable
    def _split_Bn_i_single_hidden_states(
        self,
        Bn_i_hidden_states: torch.Tensor,
        original_hidden_states: torch.Tensor,
        original_encoder_hidden_states: torch.Tensor,
    ):
        # Split the Bn_i_hidden_states into encoder_hidden_states and hidden_states.
        Bn_i_encoder_hidden_states, Bn_i_hidden_states = (
            Bn_i_hidden_states.split(
                [
                    original_encoder_hidden_states.shape[1],
                    Bn_i_hidden_states.shape[1]
                    - original_encoder_hidden_states.shape[1],
                ],
                dim=1,
            )
        )
        # Reshape the Bn_i_hidden_states and Bn_i_encoder_hidden_states
        # to the original shape. This is necessary to ensure that the
        # residuals are computed correctly.
        Bn_i_hidden_states = (
            Bn_i_hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_hidden_states.shape)
        )
        Bn_i_encoder_hidden_states = (
            Bn_i_encoder_hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_encoder_hidden_states.shape)
        )
        return Bn_i_hidden_states, Bn_i_encoder_hidden_states

    def _compute_and_cache_single_transformer_block(
        self,
        i: int,  # Block index in the transformer blocks
        # Helper inputs for hidden states split and reshape
        original_hidden_states: torch.Tensor,
        original_encoder_hidden_states: torch.Tensor,
        # Below are the inputs to the block
        block,  # The transformer block to be executed
        hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        # Helper function for `call_Bn_transformer_blocks`.
        # Skip the blocks by reusing the residual cache if they are not
        # in the Bn_compute_blocks_ids. NOTE: We should only skip
        # the specific Bn blocks in cache steps. Compute the block
        # and cache the residuals in non-cache steps.

        # Normal steps: Compute the block and cache the residuals.
        if not self._is_in_cache_step():
            Bn_i_original_hidden_states = hidden_states
            hidden_states = block(
                hidden_states,
                *args,
                **kwargs,
            )
            # Cache residuals for the non-compute Bn blocks for
            # subsequent cache steps.
            if i not in Bn_compute_blocks_ids():
                Bn_i_hidden_states = hidden_states
                (
                    Bn_i_hidden_states_residual,
                    Bn_i_encoder_hidden_states_residual,
                ) = self._Bn_i_single_hidden_states_residual(
                    Bn_i_hidden_states,
                    Bn_i_original_hidden_states,
                    original_hidden_states,
                    original_encoder_hidden_states,
                )

                # Save original_hidden_states for diff calculation.
                set_Bn_buffer(
                    Bn_i_original_hidden_states,
                    prefix=f"Bn_{i}_single_original",
                )

                set_Bn_buffer(
                    Bn_i_hidden_states_residual,
                    prefix=f"Bn_{i}_single_residual",
                )
                set_Bn_encoder_buffer(
                    Bn_i_encoder_hidden_states_residual,
                    prefix=f"Bn_{i}_single_residual",
                )
                del Bn_i_hidden_states
                del Bn_i_hidden_states_residual
                del Bn_i_encoder_hidden_states_residual

            del Bn_i_original_hidden_states

        else:
            # Cache steps: Reuse the cached residuals.
            # Check if the block is in the Bn_compute_blocks_ids.
            if i in Bn_compute_blocks_ids():
                hidden_states = block(
                    hidden_states,
                    *args,
                    **kwargs,
                )
            else:
                # Skip the block if it is not in the Bn_compute_blocks_ids.
                # Use the cached residuals instead.
                # Check if we can use the cached residuals.
                if get_can_use_cache(
                    hidden_states,  # curr step
                    parallelized=self._is_parallelized(),
                    threshold=non_compute_blocks_diff_threshold(),
                    prefix=f"Bn_{i}_single_original",  # prev step
                ):
                    Bn_i_original_hidden_states = hidden_states
                    (
                        Bn_i_original_hidden_states,
                        Bn_i_original_encoder_hidden_states,
                    ) = self._split_Bn_i_single_hidden_states(
                        Bn_i_original_hidden_states,
                        original_hidden_states,
                        original_encoder_hidden_states,
                    )
                    hidden_states, encoder_hidden_states = (
                        apply_hidden_states_residual(
                            Bn_i_original_hidden_states,
                            Bn_i_original_encoder_hidden_states,
                            prefix=f"Bn_{i}_single_residual",
                        )
                    )
                    hidden_states = torch.cat(
                        [encoder_hidden_states, hidden_states],
                        dim=1,
                    )
                    del Bn_i_original_hidden_states
                    del Bn_i_original_encoder_hidden_states
                else:
                    hidden_states = block(
                        hidden_states,
                        *args,
                        **kwargs,
                    )
        return hidden_states

    def _compute_and_cache_transformer_block(
        self,
        i: int,  # Block index in the transformer blocks
        # Below are the inputs to the block
        block,  # The transformer block to be executed
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        # Helper function for `call_Bn_transformer_blocks`.
        # Skip the blocks by reusing the residual cache if they are not
        # in the Bn_compute_blocks_ids. NOTE: We should only skip
        # the specific Bn blocks in cache steps. Compute the block
        # and cache the residuals in non-cache steps.

        # Normal steps: Compute the block and cache the residuals.
        if not self._is_in_cache_step():
            Bn_i_original_hidden_states = hidden_states
            Bn_i_original_encoder_hidden_states = encoder_hidden_states
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            if not isinstance(hidden_states, torch.Tensor):
                hidden_states, encoder_hidden_states = hidden_states
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = (
                        encoder_hidden_states,
                        hidden_states,
                    )
            # Cache residuals for the non-compute Bn blocks for
            # subsequent cache steps.
            if i not in Bn_compute_blocks_ids():
                Bn_i_hidden_states_residual = (
                    hidden_states - Bn_i_original_hidden_states
                )
                Bn_i_encoder_hidden_states_residual = (
                    encoder_hidden_states - Bn_i_original_encoder_hidden_states
                )

                # Save original_hidden_states for diff calculation.
                set_Bn_buffer(
                    Bn_i_original_hidden_states,
                    prefix=f"Bn_{i}_original",
                )

                set_Bn_buffer(
                    Bn_i_hidden_states_residual,
                    prefix=f"Bn_{i}_residual",
                )
                set_Bn_encoder_buffer(
                    Bn_i_encoder_hidden_states_residual,
                    prefix=f"Bn_{i}_residual",
                )
                del Bn_i_hidden_states_residual
                del Bn_i_encoder_hidden_states_residual

            del Bn_i_original_hidden_states
            del Bn_i_original_encoder_hidden_states

        else:
            # Cache steps: Reuse the cached residuals.
            # Check if the block is in the Bn_compute_blocks_ids.
            if i in Bn_compute_blocks_ids():
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
                if not isinstance(hidden_states, torch.Tensor):
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = (
                            encoder_hidden_states,
                            hidden_states,
                        )
            else:
                # Skip the block if it is not in the Bn_compute_blocks_ids.
                # Use the cached residuals instead.
                # Check if we can use the cached residuals.
                if get_can_use_cache(
                    hidden_states,  # curr step
                    parallelized=self._is_parallelized(),
                    threshold=non_compute_blocks_diff_threshold(),
                    prefix=f"Bn_{i}_original",  # prev step
                ):
                    hidden_states, encoder_hidden_states = (
                        apply_hidden_states_residual(
                            hidden_states,
                            encoder_hidden_states,
                            prefix=f"Bn_{i}_residual",
                        )
                    )
                else:
                    hidden_states = block(
                        hidden_states,
                        encoder_hidden_states,
                        *args,
                        **kwargs,
                    )
                    if not isinstance(hidden_states, torch.Tensor):
                        hidden_states, encoder_hidden_states = hidden_states
                        if not self.return_hidden_states_first:
                            hidden_states, encoder_hidden_states = (
                                encoder_hidden_states,
                                hidden_states,
                            )
        return hidden_states, encoder_hidden_states

    def call_Bn_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        if Bn_compute_blocks() == 0:
            return hidden_states, encoder_hidden_states

        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states
        if self.single_transformer_blocks is not None:
            assert Bn_compute_blocks() <= len(self.single_transformer_blocks), (
                f"Bn_compute_blocks {Bn_compute_blocks()} must not exceed "
                f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
            )

            torch._dynamo.graph_break()
            hidden_states = torch.cat(
                [encoder_hidden_states, hidden_states], dim=1
            )
            if len(Bn_compute_blocks_ids()) > 0:
                for i, block in enumerate(self._Bn_single_transformer_blocks()):
                    hidden_states = (
                        self._compute_and_cache_single_transformer_block(
                            i,
                            original_hidden_states,
                            original_encoder_hidden_states,
                            block,
                            hidden_states,
                            *args,
                            **kwargs,
                        )
                    )
            else:
                # Compute all Bn blocks if no specific Bn compute blocks ids are set.
                for block in self._Bn_single_transformer_blocks():
                    hidden_states = block(
                        hidden_states,
                        *args,
                        **kwargs,
                    )
            encoder_hidden_states, hidden_states = hidden_states.split(
                [
                    encoder_hidden_states.shape[1],
                    hidden_states.shape[1] - encoder_hidden_states.shape[1],
                ],
                dim=1,
            )
            torch._dynamo.graph_break()
        else:
            assert Bn_compute_blocks() <= len(self.transformer_blocks), (
                f"Bn_compute_blocks {Bn_compute_blocks()} must not exceed "
                f"the number of transformer blocks {len(self.transformer_blocks)}"
            )
            torch._dynamo.graph_break()
            if len(Bn_compute_blocks_ids()) > 0:
                for i, block in enumerate(self._Bn_transformer_blocks()):
                    hidden_states, encoder_hidden_states = (
                        self._compute_and_cache_transformer_block(
                            i,
                            block,
                            hidden_states,
                            encoder_hidden_states,
                            *args,
                            **kwargs,
                        )
                    )
            else:
                # Compute all Bn blocks if no specific Bn compute blocks ids are set.
                for block in self._Bn_transformer_blocks():
                    hidden_states = block(
                        hidden_states,
                        encoder_hidden_states,
                        *args,
                        **kwargs,
                    )
                    if not isinstance(hidden_states, torch.Tensor):
                        hidden_states, encoder_hidden_states = hidden_states
                        if not self.return_hidden_states_first:
                            hidden_states, encoder_hidden_states = (
                                encoder_hidden_states,
                                hidden_states,
                            )
            torch._dynamo.graph_break()

        hidden_states = (
            hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_hidden_states.shape)
        )
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_encoder_hidden_states.shape)
        )
        return hidden_states, encoder_hidden_states


@torch.compiler.disable
def patch_cached_stats(
    transformer,
):
    # Patch the cached stats to the transformer; the cached stats
    # will be reset for each call of pipe.__call__(**kwargs).
    if transformer is None:
        return

    cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
    if cached_transformer_blocks is None:
        return

    if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
        cached_transformer_blocks = cached_transformer_blocks[0]
    if not isinstance(
        cached_transformer_blocks, DBCachedTransformerBlocks
    ) or not isinstance(transformer, torch.nn.Module):
        return

    # TODO: Patch more cached stats to the transformer
    transformer._cached_steps = get_cached_steps()
    transformer._residual_diffs = get_residual_diffs()
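
Finally, a hedged sketch of how DBCachedTransformerBlocks might be wired into a diffusers-style transformer. The adapter below is illustrative only; the authoritative versions ship in cache_dit/cache_factory/dual_block_cache/diffusers_adapters/ (see the manifest above):

# Hypothetical wiring sketch, assuming a diffusers-style model that exposes
# `transformer_blocks` (and optionally `single_transformer_blocks`).
import torch

def apply_db_cache(transformer: torch.nn.Module) -> torch.nn.Module:
    cached_blocks = DBCachedTransformerBlocks(
        transformer.transformer_blocks,
        single_transformer_blocks=getattr(
            transformer, "single_transformer_blocks", None
        ),
        transformer=transformer,
        return_hidden_states_first=False,  # model-family dependent
    )
    # Route every forward pass through the cached wrapper.
    transformer.transformer_blocks = torch.nn.ModuleList([cached_blocks])
    if getattr(transformer, "single_transformer_blocks", None) is not None:
        transformer.single_transformer_blocks = torch.nn.ModuleList()
    return transformer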