cache-dit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +0 -0
- cache_dit/_version.py +21 -0
- cache_dit/cache_factory/__init__.py +166 -0
- cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
- cache_dit/cache_factory/taylorseer.py +76 -0
- cache_dit/cache_factory/utils.py +0 -0
- cache_dit/logger.py +97 -0
- cache_dit/primitives.py +152 -0
- cache_dit-0.1.0.dist-info/METADATA +350 -0
- cache_dit-0.1.0.dist-info/RECORD +31 -0
- cache_dit-0.1.0.dist-info/WHEEL +5 -0
- cache_dit-0.1.0.dist-info/licenses/LICENSE +53 -0
- cache_dit-0.1.0.dist-info/top_level.txt +1 -0
cache_dit/cache_factory/dynamic_block_prune/prune_context.py
@@ -0,0 +1,979 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
import logging
import contextlib
import dataclasses
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import torch

import cache_dit.primitives as DP
from cache_dit.logger import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class DBPPruneContext:
    # Dynamic Block Prune
    # At least compute the first `Fn` and last `Bn` blocks
    # FnBn designs are inspired by the Dual Block Cache
    Fn_compute_blocks: int = 8
    Bn_compute_blocks: int = 8
    # Non-prune block IDs, e.g., [0, 1, 2, 3, 4, 5, 6, 7]
    non_prune_blocks_ids: List[int] = dataclasses.field(default_factory=list)
    # L1 hidden states or residual diff threshold for Fn
    residual_diff_threshold: Union[torch.Tensor, float] = 0.0
    l1_hidden_states_diff_threshold: Optional[float] = None
    important_condition_threshold: float = 0.0
    # Compute the dynamic prune threshold based on the mean of the
    # residual diffs of the previously computed or pruned blocks,
    # but cap it (by default at 2x residual_diff_threshold)
    # to avoid overly aggressive pruning.
    enable_dynamic_prune_threshold: bool = False
    max_dynamic_prune_threshold: Optional[float] = None
    dynamic_prune_threshold_relax_ratio: float = 1.25
    # Residual cache update interval, in steps.
    residual_cache_update_interval: int = 1

    # Buffer for storing the residuals and other tensors
    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)

    # Other settings
    downsample_factor: int = 1
    num_inference_steps: int = -1
    warmup_steps: int = 0  # DON'T prune in warmup steps
    # DON'T prune if the number of pruned steps >= max_pruned_steps
    max_pruned_steps: int = -1

    # Statistics
    executed_steps: int = 0
    pruned_blocks: List[int] = dataclasses.field(default_factory=list)
    actual_blocks: List[int] = dataclasses.field(default_factory=list)
    # Residual diffs for each step, {step: list[float]}
    residual_diffs: Dict[int, List[float]] = dataclasses.field(
        default_factory=lambda: defaultdict(list),
    )

    def get_residual_diff_threshold(self):
        residual_diff_threshold = self.residual_diff_threshold
        if self.l1_hidden_states_diff_threshold is not None:
            # Use the L1 hidden states diff threshold if set
            residual_diff_threshold = self.l1_hidden_states_diff_threshold
        if isinstance(residual_diff_threshold, torch.Tensor):
            residual_diff_threshold = residual_diff_threshold.item()
        if self.enable_dynamic_prune_threshold:
            # Compute the dynamic prune threshold based on the mean of the
            # residual diffs of the previously computed or pruned blocks.
            step = self.get_current_step()
            if step >= 0 and step in self.residual_diffs:
                # TODO: Should we only use the last 5 diffs?
                diffs = self.residual_diffs[step][:]
                diffs = [d for d in diffs if d > 0.0]
                if diffs:
                    mean_diff = sum(diffs) / len(diffs)
                    relaxed_diff = (
                        mean_diff * self.dynamic_prune_threshold_relax_ratio
                    )
                    if self.max_dynamic_prune_threshold is None:
                        max_dynamic_prune_threshold = (
                            2 * residual_diff_threshold
                        )
                    else:
                        max_dynamic_prune_threshold = (
                            self.max_dynamic_prune_threshold
                        )
                    if relaxed_diff < max_dynamic_prune_threshold:
                        # If the relaxed mean diff stays below the cap,
                        # use it as the dynamic prune threshold (never
                        # lower than the static threshold).
                        residual_diff_threshold = (
                            relaxed_diff
                            if relaxed_diff > residual_diff_threshold
                            else residual_diff_threshold
                        )
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(
                                f"Dynamic prune threshold for step {step}: "
                                f"{residual_diff_threshold:.6f}"
                            )
        return residual_diff_threshold

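    # Illustrative arithmetic for the dynamic threshold above (comment only,
    # not part of the original file): with residual_diff_threshold=0.05,
    # relax ratio 1.25 and per-step diffs [0.04, 0.08], mean_diff = 0.06 and
    # relaxed_diff = 0.075; the default cap is 2 * 0.05 = 0.10, so 0.075
    # becomes the effective threshold for subsequent blocks in this step.
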
    def get_buffer(self, name):
        return self.buffers.get(name)

    def set_buffer(self, name, buffer):
        self.buffers[name] = buffer

    def remove_buffer(self, name):
        if name in self.buffers:
            del self.buffers[name]

    def clear_buffers(self):
        self.buffers.clear()

    def mark_step_begin(self):
        self.executed_steps += 1
        if self.get_current_step() == 0:
            self.pruned_blocks.clear()
            self.actual_blocks.clear()
            self.residual_diffs.clear()

    def add_pruned_block(self, num_blocks):
        self.pruned_blocks.append(num_blocks)

    def add_actual_block(self, num_blocks):
        self.actual_blocks.append(num_blocks)

    def add_residual_diff(self, diff):
        if isinstance(diff, torch.Tensor):
            diff = diff.item()
        step = self.get_current_step()
        self.residual_diffs[step].append(diff)
        max_num_block_diffs = 1000
        # Avoid a memory leak: keep only the last 1000 diffs per step
        if len(self.residual_diffs[step]) > max_num_block_diffs:
            self.residual_diffs[step] = self.residual_diffs[step][
                -max_num_block_diffs:
            ]
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"Step {step}, block: {len(self.residual_diffs[step])}, "
                f"residual diff: {diff:.6f}"
            )

    def get_current_step(self):
        return self.executed_steps - 1

    def is_in_warmup(self):
        return self.get_current_step() < self.warmup_steps


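# Illustration only (not part of the original file): the per-step statistics
# lifecycle of a context instance, using hypothetical settings.
#
#   ctx = DBPPruneContext(residual_diff_threshold=0.08, warmup_steps=2)
#   ctx.mark_step_begin()        # executed_steps -> 1, current step -> 0
#   ctx.is_in_warmup()           # True while current step < warmup_steps
#   ctx.add_residual_diff(0.05)  # recorded under the current step

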
@torch.compiler.disable
def get_residual_diff_threshold():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.get_residual_diff_threshold()


@torch.compiler.disable
def get_buffer(name):
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.get_buffer(name)


@torch.compiler.disable
def set_buffer(name, buffer):
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    prune_context.set_buffer(name, buffer)


@torch.compiler.disable
def remove_buffer(name):
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    prune_context.remove_buffer(name)


@torch.compiler.disable
def mark_step_begin():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    prune_context.mark_step_begin()


@torch.compiler.disable
def get_current_step():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.get_current_step()


@torch.compiler.disable
def get_max_pruned_steps():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.max_pruned_steps


@torch.compiler.disable
def add_pruned_block(num_blocks):
    assert (
        isinstance(num_blocks, int) and num_blocks >= 0
    ), "num_blocks must be a non-negative integer"
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    prune_context.add_pruned_block(num_blocks)


@torch.compiler.disable
def get_pruned_blocks():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.pruned_blocks.copy()


@torch.compiler.disable
def add_actual_block(num_blocks):
    assert (
        isinstance(num_blocks, int) and num_blocks >= 0
    ), "num_blocks must be a non-negative integer"
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    prune_context.add_actual_block(num_blocks)


@torch.compiler.disable
def get_actual_blocks():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.actual_blocks.copy()


@torch.compiler.disable
def get_pruned_steps():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    pruned_blocks = get_pruned_blocks()
    pruned_blocks = [x for x in pruned_blocks if x > 0]
    return len(pruned_blocks)


@torch.compiler.disable
def is_in_warmup():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.is_in_warmup()


@torch.compiler.disable
def is_l1_diff_enabled():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return (
        prune_context.l1_hidden_states_diff_threshold is not None
        and prune_context.l1_hidden_states_diff_threshold > 0.0
    )


@torch.compiler.disable
def add_residual_diff(diff):
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    prune_context.add_residual_diff(diff)


@torch.compiler.disable
def get_residual_diffs():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    # Return a copy of the residual diffs to avoid modification
    return prune_context.residual_diffs.copy()


@torch.compiler.disable
def get_important_condition_threshold():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.important_condition_threshold


@torch.compiler.disable
def residual_cache_update_interval():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.residual_cache_update_interval


@torch.compiler.disable
def Fn_compute_blocks():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    assert (
        prune_context.Fn_compute_blocks >= 0
    ), "Fn_compute_blocks must be >= 0"
    return prune_context.Fn_compute_blocks


@torch.compiler.disable
def Bn_compute_blocks():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    assert (
        prune_context.Bn_compute_blocks >= 0
    ), "Bn_compute_blocks must be >= 0"
    return prune_context.Bn_compute_blocks


@torch.compiler.disable
def get_non_prune_blocks_ids():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.non_prune_blocks_ids


_current_prune_context: Optional[DBPPruneContext] = None


def create_prune_context(*args, **kwargs):
    return DBPPruneContext(*args, **kwargs)


def get_current_prune_context():
    return _current_prune_context


def set_current_prune_context(prune_context=None):
    global _current_prune_context
    _current_prune_context = prune_context


def collect_prune_kwargs(default_attrs: dict, **kwargs):
    # NOTE: This API will split kwargs into prune_kwargs and other_kwargs
    # default_attrs: specific settings for different pipelines
    prune_attrs = dataclasses.fields(DBPPruneContext)
    prune_attrs = [
        attr
        for attr in prune_attrs
        if hasattr(
            DBPPruneContext,
            attr.name,
        )
    ]
    prune_kwargs = {
        attr.name: kwargs.pop(
            attr.name,
            getattr(DBPPruneContext, attr.name),
        )
        for attr in prune_attrs
    }
    # Manually set sequence fields, such as non_prune_blocks_ids
    prune_kwargs["non_prune_blocks_ids"] = kwargs.pop(
        "non_prune_blocks_ids",
        [],
    )

    assert default_attrs is not None, "default_attrs must be set before"
    for attr in prune_attrs:
        if attr.name in default_attrs:
            prune_kwargs[attr.name] = default_attrs[attr.name]

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Collected DBPrune kwargs: {prune_kwargs}")

    return prune_kwargs, kwargs


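# Illustration only (not part of the original file): how kwargs are split,
# with hypothetical values. Dataclass fields declared with default_factory
# (e.g., buffers) have no class attribute, so the hasattr() filter above
# excludes them; that is why non_prune_blocks_ids is handled manually.
#
#   prune_kwargs, other_kwargs = collect_prune_kwargs(
#       default_attrs={"residual_diff_threshold": 0.08},
#       warmup_steps=2,
#       guidance_scale=3.5,  # unrelated kwarg, passed through untouched
#   )
#   # prune_kwargs["warmup_steps"] == 2
#   # prune_kwargs["residual_diff_threshold"] == 0.08 (default_attrs wins)
#   # other_kwargs == {"guidance_scale": 3.5}

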
@contextlib.contextmanager
def prune_context(prune_context):
    global _current_prune_context
    old_prune_context = _current_prune_context
    _current_prune_context = prune_context
    try:
        yield
    finally:
        _current_prune_context = old_prune_context


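# Illustration only (not part of the original file): activating a context
# around a denoising loop with hypothetical settings; the module-level
# wrappers above read whichever context is currently installed.
#
#   ctx = create_prune_context(residual_diff_threshold=0.08, warmup_steps=2)
#   with prune_context(ctx):
#       for _ in range(num_inference_steps):
#           ...  # each transformer forward calls mark_step_begin() itself

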
@torch.compiler.disable
def are_two_tensors_similar(
    t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
    t2: torch.Tensor,  # curr residual R(t  ,n) = H(t  ,n) - H(t  ,0)
    *,
    threshold: float,
    parallelized: bool = False,
    name: str = "Bn",  # for debugging
):
    # Special sentinel diffs recorded below: -0.0 means the threshold is
    # disabled, -1.0 means the tensors are always treated as similar
    # (threshold >= 1.0), and -2.0 means the shapes did not match.
    if threshold <= 0.0:
        add_residual_diff(-0.0)
        return False

    if threshold >= 1.0:
        # If threshold is 1.0 or more, we consider them always similar.
        add_residual_diff(-1.0)
        return True

    if t1.shape != t2.shape:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"{name}, shape error: {t1.shape} != {t2.shape}")
        add_residual_diff(-2.0)
        return False

    # Find the most significant tokens in t1 and t2 and weigh the
    # diff by those tokens: the more significant, the more important.
    condition_thresh = get_important_condition_threshold()
    if condition_thresh > 0.0:
        raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
        token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
        token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
        # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
        token_diff = token_m_df / token_m_t1  # [B, seq_len]
        condition = token_diff > condition_thresh  # [B, seq_len]
        if condition.sum() > 0:
            condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
            condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
            mean_diff = raw_diff[condition].mean()
            mean_t1 = t1[condition].abs().mean()
        else:
            mean_diff = (t1 - t2).abs().mean()
            mean_t1 = t1.abs().mean()
    else:
        # Use the mean of the absolute difference of the tensors
        mean_diff = (t1 - t2).abs().mean()
        mean_t1 = t1.abs().mean()

    if parallelized:
        mean_diff = DP.all_reduce_sync(mean_diff, "avg")
        mean_t1 = DP.all_reduce_sync(mean_t1, "avg")

    # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
    # Further, if we assume that (H(t,0) - H(t-1,0)) ~ 0, then
    # H(t-1,n) ~ H(t,n), which means the hidden states are similar.
    diff = (mean_diff / mean_t1).item()

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"{name}, diff: {diff:.6f}, threshold: {threshold:.6f}")

    add_residual_diff(diff)

    return diff < threshold


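# Illustrative arithmetic (not part of the original file): if the mean of
# |t1 - t2| is 0.006 and the mean of |t1| is 0.12, then
# diff = 0.006 / 0.12 = 0.05, so with threshold=0.08 the residuals count
# as similar and the corresponding block can be pruned.

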
@torch.compiler.disable
def apply_hidden_states_residual(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    name: str = "Bn",
    encoder_name: str = "Bn_encoder",
):
    hidden_states_residual = get_buffer(f"{name}")

    assert hidden_states_residual is not None, f"{name} must be set before"
    hidden_states = hidden_states_residual + hidden_states

    encoder_hidden_states_residual = get_buffer(f"{encoder_name}")
    assert (
        encoder_hidden_states_residual is not None
    ), f"{encoder_name} must be set before"
    encoder_hidden_states = (
        encoder_hidden_states_residual + encoder_hidden_states
    )

    hidden_states = hidden_states.contiguous()
    encoder_hidden_states = encoder_hidden_states.contiguous()

    return hidden_states, encoder_hidden_states


@torch.compiler.disable
def get_downsample_factor():
    prune_context = get_current_prune_context()
    assert prune_context is not None, "prune_context must be set before"
    return prune_context.downsample_factor


@torch.compiler.disable
def get_can_use_prune(
    states_tensor: torch.Tensor,  # hidden_states or residual
    parallelized: bool = False,
    threshold: Optional[float] = None,  # can manually set threshold
    name: str = "Bn",
):
    if is_in_warmup():
        return False

    pruned_steps = get_pruned_steps()
    max_pruned_steps = get_max_pruned_steps()
    if max_pruned_steps >= 0 and (pruned_steps >= max_pruned_steps):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"{name}, max_pruned_steps reached: {max_pruned_steps}, "
                "cannot use prune."
            )
        return False

    if threshold is None or threshold <= 0.0:
        threshold = get_residual_diff_threshold()
    if threshold <= 0.0:
        return False

    downsample_factor = get_downsample_factor()
    prev_states_tensor = get_buffer(f"{name}")

    if downsample_factor > 1:
        states_tensor = states_tensor[..., ::downsample_factor]
        states_tensor = states_tensor.contiguous()
        if prev_states_tensor is not None:
            prev_states_tensor = prev_states_tensor[..., ::downsample_factor]
            prev_states_tensor = prev_states_tensor.contiguous()

    return prev_states_tensor is not None and are_two_tensors_similar(
        prev_states_tensor,
        states_tensor,
        threshold=threshold,
        parallelized=parallelized,
        name=name,
    )


class DBPrunedTransformerBlocks(torch.nn.Module):
    def __init__(
        self,
        transformer_blocks,
        single_transformer_blocks=None,
        *,
        transformer=None,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
    ):
        super().__init__()

        self.transformer = transformer
        self.transformer_blocks = transformer_blocks
        self.single_transformer_blocks = single_transformer_blocks
        self.return_hidden_states_first = return_hidden_states_first
        self.return_hidden_states_only = return_hidden_states_only
        self.pruned_blocks_step: int = 0

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        mark_step_begin()
        self.pruned_blocks_step = 0
        original_hidden_states = hidden_states

        torch._dynamo.graph_break()
        hidden_states, encoder_hidden_states = self.call_transformer_blocks(
            hidden_states,
            encoder_hidden_states,
            *args,
            **kwargs,
        )

        del original_hidden_states
        torch._dynamo.graph_break()

        add_pruned_block(self.pruned_blocks_step)
        add_actual_block(self._num_transformer_blocks)
        patch_pruned_stats(self.transformer)

        return (
            hidden_states
            if self.return_hidden_states_only
            else (
                (hidden_states, encoder_hidden_states)
                if self.return_hidden_states_first
                else (encoder_hidden_states, hidden_states)
            )
        )

    @property
    @torch.compiler.disable
    def _num_transformer_blocks(self):
        # Total number of transformer blocks, including single transformer blocks.
        num_blocks = len(self.transformer_blocks)
        if self.single_transformer_blocks is not None:
            num_blocks += len(self.single_transformer_blocks)
        return num_blocks

    @torch.compiler.disable
    def _is_parallelized(self):
        # Compatible with distributed inference.
        return all(
            (
                self.transformer is not None,
                getattr(self.transformer, "_is_parallelized", False),
            )
        )

    @torch.compiler.disable
    def _non_prune_blocks_ids(self):
        # Never prune the first `Fn` and last `Bn` blocks.
        num_blocks = self._num_transformer_blocks
        Fn_compute_blocks_ = (
            Fn_compute_blocks()
            if Fn_compute_blocks() < num_blocks
            else num_blocks
        )
        Fn_compute_blocks_ids = list(range(Fn_compute_blocks_))
        Bn_compute_blocks_ = (
            Bn_compute_blocks()
            if Bn_compute_blocks() < num_blocks
            else num_blocks
        )
        Bn_compute_blocks_ids = list(
            range(
                num_blocks - Bn_compute_blocks_,
                num_blocks,
            )
        )
        non_prune_blocks_ids = list(
            set(
                Fn_compute_blocks_ids
                + Bn_compute_blocks_ids
                + get_non_prune_blocks_ids()
            )
        )
        non_prune_blocks_ids = [
            d for d in non_prune_blocks_ids if d < num_blocks
        ]
        return sorted(non_prune_blocks_ids)

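    # Illustration (not part of the original file): with a hypothetical
    # 38-block model, Fn = Bn = 8 and no extra IDs, blocks 0-7 and 30-37
    # are always computed, so only blocks 8-29 are candidates for pruning.
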
    @torch.compiler.disable
    def _compute_single_hidden_states_residual(
        self,
        single_hidden_states: torch.Tensor,
        single_original_hidden_states: torch.Tensor,
        # global original single hidden states
        original_single_hidden_states: torch.Tensor,
        original_single_encoder_hidden_states: torch.Tensor,
    ):
        single_hidden_states, single_encoder_hidden_states = (
            self._split_single_hidden_states(
                single_hidden_states,
                original_single_hidden_states,
                original_single_encoder_hidden_states,
            )
        )

        single_original_hidden_states, single_original_encoder_hidden_states = (
            self._split_single_hidden_states(
                single_original_hidden_states,
                original_single_hidden_states,
                original_single_encoder_hidden_states,
            )
        )

        single_hidden_states_residual = (
            single_hidden_states - single_original_hidden_states
        )
        single_encoder_hidden_states_residual = (
            single_encoder_hidden_states - single_original_encoder_hidden_states
        )
        return (
            single_hidden_states_residual,
            single_encoder_hidden_states_residual,
        )

    @torch.compiler.disable
    def _split_single_hidden_states(
        self,
        single_hidden_states: torch.Tensor,
        # global original single hidden states
        original_single_hidden_states: torch.Tensor,
        original_single_encoder_hidden_states: torch.Tensor,
    ):
        single_encoder_hidden_states, single_hidden_states = (
            single_hidden_states.split(
                [
                    original_single_encoder_hidden_states.shape[1],
                    single_hidden_states.shape[1]
                    - original_single_encoder_hidden_states.shape[1],
                ],
                dim=1,
            )
        )
        # Reshape the single_hidden_states and single_encoder_hidden_states
        # to the original shape. This is necessary to ensure that the
        # residuals are computed correctly.
        single_hidden_states = (
            single_hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_single_hidden_states.shape)
        )
        single_encoder_hidden_states = (
            single_encoder_hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_single_encoder_hidden_states.shape)
        )
        return single_hidden_states, single_encoder_hidden_states

    @torch.compiler.disable
    def _should_update_residuals(self):
        # Wrap for non-compiled mode.
        # Check if the current step is a multiple of
        # the residual cache update interval.
        return get_current_step() % residual_cache_update_interval() == 0

    @torch.compiler.disable
    def _get_can_use_prune(
        self,
        block_id: int,  # Block index in the transformer blocks
        hidden_states: torch.Tensor,  # hidden_states or residual
        name: str = "Bn_original",  # prev step name for single blocks
    ):
        # Wrap for non-compiled mode.
        can_use_prune = False
        if block_id not in self._non_prune_blocks_ids():
            can_use_prune = get_can_use_prune(
                hidden_states,  # curr step
                parallelized=self._is_parallelized(),
                name=name,  # prev step
            )
        self.pruned_blocks_step += int(can_use_prune)
        return can_use_prune

    def _compute_or_prune_single_transformer_block(
        self,
        block_id: int,  # Block index in the transformer blocks
        # Helper inputs for hidden states split and reshape:
        # global original single hidden states
        original_single_hidden_states: torch.Tensor,
        original_single_encoder_hidden_states: torch.Tensor,
        # Below are the inputs to the block
        block,  # The transformer block to be executed
        hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        # Helper function for `call_transformer_blocks`
        # block_id: global block index in the transformer blocks +
        # single_transformer_blocks
        can_use_prune = self._get_can_use_prune(
            block_id,
            hidden_states,  # hidden_states or residual
            name=f"{block_id}_single_original",  # prev step
        )

        # Prune steps: prune the current block and reuse the cached
        # residuals to approximate the hidden states.
        if can_use_prune:
            single_original_hidden_states = hidden_states
            (
                single_original_hidden_states,
                single_original_encoder_hidden_states,
            ) = self._split_single_hidden_states(
                single_original_hidden_states,
                original_single_hidden_states,
                original_single_encoder_hidden_states,
            )
            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
                single_original_hidden_states,
                single_original_encoder_hidden_states,
                name=f"{block_id}_single_residual",
                encoder_name=f"{block_id}_single_encoder_residual",
            )
            hidden_states = torch.cat(
                [encoder_hidden_states, hidden_states],
                dim=1,
            )
            del single_original_hidden_states
            del single_original_encoder_hidden_states

        else:
            # Normal steps: compute the block and cache the residuals.
            single_original_hidden_states = hidden_states
            hidden_states = block(
                hidden_states,
                *args,
                **kwargs,
            )

            # Save the original hidden states for diff calculation.
            # It may not be necessary to update the hidden
            # states and residuals at every step.
            if self._should_update_residuals():
                # Cache residuals of the non-compute Bn blocks for
                # subsequent prune steps.
                single_hidden_states = hidden_states
                (
                    single_hidden_states_residual,
                    single_encoder_hidden_states_residual,
                ) = self._compute_single_hidden_states_residual(
                    single_hidden_states,
                    single_original_hidden_states,
                    original_single_hidden_states,
                    original_single_encoder_hidden_states,
                )

                set_buffer(
                    f"{block_id}_single_original",
                    single_original_hidden_states,
                )

                set_buffer(
                    f"{block_id}_single_residual",
                    single_hidden_states_residual,
                )
                set_buffer(
                    f"{block_id}_single_encoder_residual",
                    single_encoder_hidden_states_residual,
                )

                del single_hidden_states
                del single_hidden_states_residual
                del single_encoder_hidden_states_residual

            del single_original_hidden_states

        return hidden_states

    def _compute_or_prune_transformer_block(
        self,
        block_id: int,  # Block index in the transformer blocks
        # Below are the inputs to the block
        block,  # The transformer block to be executed
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        # Helper function for `call_transformer_blocks`
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states

        # block_id: global block index in the transformer blocks +
        # single_transformer_blocks
        can_use_prune = self._get_can_use_prune(
            block_id,
            hidden_states,  # hidden_states or residual
            name=f"{block_id}_original",  # prev step
        )

        # Prune steps: prune the current block and reuse the cached
        # residuals to approximate the hidden states.
        if can_use_prune:
            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
                hidden_states,
                encoder_hidden_states,
                name=f"{block_id}_residual",
                encoder_name=f"{block_id}_encoder_residual",
            )
        else:
            # Normal steps: compute the block and cache the residuals.
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                *args,
                **kwargs,
            )
            if not isinstance(hidden_states, torch.Tensor):
                hidden_states, encoder_hidden_states = hidden_states
                if not self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = (
                        encoder_hidden_states,
                        hidden_states,
                    )

            # Save the original hidden states for diff calculation.
            # It may not be necessary to update the hidden
            # states and residuals at every step.
            if self._should_update_residuals():
                # Cache residuals of the non-compute Bn blocks for
                # subsequent prune steps.
                hidden_states_residual = hidden_states - original_hidden_states
                encoder_hidden_states_residual = (
                    encoder_hidden_states - original_encoder_hidden_states
                )
                set_buffer(
                    f"{block_id}_original",
                    original_hidden_states,
                )

                set_buffer(
                    f"{block_id}_residual",
                    hidden_states_residual,
                )
                set_buffer(
                    f"{block_id}_encoder_residual",
                    encoder_hidden_states_residual,
                )
                del hidden_states_residual
                del encoder_hidden_states_residual

        del original_hidden_states
        del original_encoder_hidden_states

        return hidden_states, encoder_hidden_states

    def call_transformer_blocks(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        *args,
        **kwargs,
    ):
        original_hidden_states = hidden_states
        original_encoder_hidden_states = encoder_hidden_states

        for i, block in enumerate(self.transformer_blocks):
            hidden_states, encoder_hidden_states = (
                self._compute_or_prune_transformer_block(
                    i,
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
            )

        if self.single_transformer_blocks is not None:
            hidden_states = torch.cat(
                [encoder_hidden_states, hidden_states], dim=1
            )
            for j, block in enumerate(self.single_transformer_blocks):
                hidden_states = self._compute_or_prune_single_transformer_block(
                    j + len(self.transformer_blocks),
                    original_hidden_states,
                    original_encoder_hidden_states,
                    block,
                    hidden_states,
                    *args,
                    **kwargs,
                )

            encoder_hidden_states, hidden_states = hidden_states.split(
                [
                    encoder_hidden_states.shape[1],
                    hidden_states.shape[1] - encoder_hidden_states.shape[1],
                ],
                dim=1,
            )

        hidden_states = (
            hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_hidden_states.shape)
        )
        encoder_hidden_states = (
            encoder_hidden_states.reshape(-1)
            .contiguous()
            .reshape(original_encoder_hidden_states.shape)
        )
        return hidden_states, encoder_hidden_states


@torch.compiler.disable
def patch_pruned_stats(
    transformer,
):
    # Patch the pruned stats onto the transformer; the pruned stats
    # are reset on each call of pipe.__call__(**kwargs).
    if transformer is None:
        return

    pruned_transformer_blocks = getattr(transformer, "transformer_blocks", None)
    if pruned_transformer_blocks is None:
        return

    if isinstance(pruned_transformer_blocks, torch.nn.ModuleList):
        pruned_transformer_blocks = pruned_transformer_blocks[0]
    if not isinstance(
        pruned_transformer_blocks, DBPrunedTransformerBlocks
    ) or not isinstance(transformer, torch.nn.Module):
        return

    # TODO: Patch more pruned stats to the transformer
    transformer._pruned_blocks = get_pruned_blocks()
    transformer._pruned_steps = get_pruned_steps()
    transformer._residual_diffs = get_residual_diffs()
    transformer._actual_blocks = get_actual_blocks()
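

# Illustration only (not part of the original file): a sketch of how the
# wrapper might be attached to a diffusers-style transformer; the released
# adapters under dynamic_block_prune/diffusers_adapters/ do the actual wiring.
#
#   transformer = pipe.transformer
#   dbp_blocks = DBPrunedTransformerBlocks(
#       transformer.transformer_blocks,
#       transformer.single_transformer_blocks,  # None for models without them
#       transformer=transformer,
#   )
#   transformer.transformer_blocks = torch.nn.ModuleList([dbp_blocks])
#   transformer.single_transformer_blocks = None
#   ctx = create_prune_context(residual_diff_threshold=0.08)
#   with prune_context(ctx):
#       result = pipe(prompt)  # blocks are pruned when residuals repeat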