cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/adapters.py +47 -5
- cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
- cache_dit/cache_factory/patch/flux.py +241 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
- cache_dit-0.2.16.dist-info/RECORD +47 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.14.dist-info/RECORD +0 -49
- /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/dynamic_block_prune/prune_context.py

@@ -1,14 +1,12 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
 import logging
 import contextlib
 import dataclasses
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, DefaultDict
 
 import torch
 
 import cache_dit.primitives as primitives
-from cache_dit.utils import is_diffusers_at_least_0_3_5
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -41,21 +39,56 @@ class DBPPruneContext:
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     # Other settings
-    downsample_factor: int = 1
+    downsample_factor: int = 1  # un-used
     num_inference_steps: int = -1
     warmup_steps: int = 0  # DON'T pruned in warmup steps
     # DON'T prune if the number of pruned steps >= max_pruned_steps
     max_pruned_steps: int = -1
 
-    #
-    executed_steps: int = 0
+    # Record the steps that have been cached, both cached and non-cache
+    executed_steps: int = 0  # cache + non-cache steps pippeline
+    # steps for transformer, for CFG, transformer_executed_steps will
+    # be double of executed_steps.
+    transformer_executed_steps: int = 0
+
+    # Support do_separate_classifier_free_guidance, such as Wan 2.1,
+    # Qwen-Image. For model that fused CFG and non-CFG into single
+    # forward step, should set do_separate_classifier_free_guidance
+    # as False. For example: CogVideoX, HunyuanVideo, Mochi.
+    do_separate_classifier_free_guidance: bool = False
+    # Compute cfg forward first or not, default False, namely,
+    # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # Compute spearate diff values for CFG and non-CFG step,
+    # default True. If False, we will use the computed diff from
+    # current non-CFG transformer step for current CFG step.
+    cfg_diff_compute_separate: bool = True
+
+    # CFG & non-CFG pruned steps
     pruned_blocks: List[int] = dataclasses.field(default_factory=list)
     actual_blocks: List[int] = dataclasses.field(default_factory=list)
-
-
+    residual_diffs: DefaultDict[str, list[float]] = dataclasses.field(
+        default_factory=lambda: defaultdict(list),
+    )
+    cfg_pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_actual_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_residual_diffs: DefaultDict[str, list[float]] = dataclasses.field(
         default_factory=lambda: defaultdict(list),
     )
 
+    @torch.compiler.disable
+    def __post_init__(self):
+        # Some checks for settings
+        if self.do_separate_classifier_free_guidance:
+            assert (
+                self.cfg_diff_compute_separate
+            ), "cfg_diff_compute_separate must be True"
+        if self.cfg_diff_compute_separate:
+            assert self.cfg_compute_first is False, (
+                "cfg_compute_first must set as False if "
+                "cfg_diff_compute_separate is enabled."
+            )
+
     @torch.compiler.disable
     def get_residual_diff_threshold(self):
         residual_diff_threshold = self.residual_diff_threshold
@@ -119,42 +152,89 @@ class DBPPruneContext:
 
     @torch.compiler.disable
     def mark_step_begin(self):
-
-
+        # Always increase transformer executed steps
+        # incr step: prev 0 -> 1; prev 1 -> 2
+        # current step: incr step - 1
+        self.transformer_executed_steps += 1
+        if not self.do_separate_classifier_free_guidance:
+            self.executed_steps += 1
+        else:
+            # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
+            if not self.cfg_compute_first:
+                if not self.is_separate_classifier_free_guidance_step():
+                    # transformer step: 0,2,4,...
+                    self.executed_steps += 1
+            else:
+                if self.is_separate_classifier_free_guidance_step():
+                    # transformer step: 0,2,4,...
+                    self.executed_steps += 1
+
+        # Reset the cached steps and residual diffs at the beginning
+        # of each inference.
+        if self.get_current_transformer_step() == 0:
             self.pruned_blocks.clear()
             self.actual_blocks.clear()
             self.residual_diffs.clear()
+            self.cfg_pruned_blocks.clear()
+            self.cfg_actual_blocks.clear()
+            self.cfg_residual_diffs.clear()
 
     @torch.compiler.disable
     def add_pruned_block(self, num_blocks):
-        self.
+        if not self.is_separate_classifier_free_guidance_step():
+            self.pruned_blocks.append(num_blocks)
+        else:
+            self.cfg_pruned_blocks.append(num_blocks)
 
     @torch.compiler.disable
     def add_actual_block(self, num_blocks):
-        self.
+        if not self.is_separate_classifier_free_guidance_step():
+            self.actual_blocks.append(num_blocks)
+        else:
+            self.cfg_actual_blocks.append(num_blocks)
 
     @torch.compiler.disable
     def add_residual_diff(self, diff):
-
-
-
-        self.
-
-
-
-
-
-
-
-
-
-
-
+        # step: executed_steps - 1, not transformer_steps - 1
+        step = str(self.get_current_step())
+        # Only add the diff if it is not already recorded for this step
+        if not self.is_separate_classifier_free_guidance_step():
+            if step not in self.residual_diffs:
+                self.residual_diffs[step] = [diff]
+            else:
+                self.residual_diffs[step].append(diff)
+        else:
+            if step not in self.cfg_residual_diffs:
+                self.cfg_residual_diffs[step] = [diff]
+            else:
+                self.cfg_residual_diffs[step].append(diff)
+
+    @torch.compiler.disable
+    def get_pruned_blocks(self):
+        return self.pruned_blocks.copy()
+
+    @torch.compiler.disable
+    def get_cfg_pruned_blocks(self):
+        return self.cfg_pruned_blocks.copy()
 
     @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1
 
+    @torch.compiler.disable
+    def get_current_transformer_step(self):
+        return self.transformer_executed_steps - 1
+
+    @torch.compiler.disable
+    def is_separate_classifier_free_guidance_step(self):
+        if not self.do_separate_classifier_free_guidance:
+            return False
+        if self.cfg_compute_first:
+            # CFG steps: 0, 2, 4, 6, ...
+            return self.get_current_transformer_step() % 2 == 0
+        # CFG steps: 1, 3, 5, 7, ...
+        return self.get_current_transformer_step() % 2 != 0
+
     @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.warmup_steps
@@ -168,38 +248,35 @@ def get_residual_diff_threshold():
 
 
 @torch.compiler.disable
-def
-    prune_context = get_current_prune_context()
-    assert prune_context is not None, "prune_context must be set before"
-    return prune_context.get_buffer(name)
-
-
-@torch.compiler.disable
-def set_buffer(name, buffer):
+def mark_step_begin():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-    prune_context.
+    prune_context.mark_step_begin()
 
 
 @torch.compiler.disable
-def
+def get_current_step():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-    prune_context.
+    return prune_context.get_current_step()
 
 
 @torch.compiler.disable
-def
+def get_current_step_cfg_residual_diff():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-
+    step = str(get_current_step())
+    cfg_residual_diffs = get_cfg_residual_diffs()
+    if step in cfg_residual_diffs:
+        return cfg_residual_diffs[step]
+    return None
 
 
 @torch.compiler.disable
-def
+def get_current_transformer_step():
     prune_context = get_current_prune_context()
     assert prune_context is not None, "prune_context must be set before"
-    return prune_context.
+    return prune_context.get_current_transformer_step()
 
 
 @torch.compiler.disable
@@ -226,6 +303,13 @@ def get_pruned_blocks():
     return prune_context.pruned_blocks.copy()
 
 
+@torch.compiler.disable
+def get_cfg_pruned_blocks():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.cfg_pruned_blocks.copy()
+
+
 @torch.compiler.disable
 def add_actual_block(num_blocks):
     assert (
@@ -243,6 +327,13 @@ def get_actual_blocks():
     return prune_context.actual_blocks.copy()
 
 
+@torch.compiler.disable
+def get_cfg_actual_blocks():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.cfg_actual_blocks.copy()
+
+
 @torch.compiler.disable
 def get_pruned_steps():
     prune_context = get_current_prune_context()
@@ -252,6 +343,15 @@ def get_pruned_steps():
     return len(pruned_blocks)
 
 
+@torch.compiler.disable
+def get_cfg_pruned_steps():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    cfg_pruned_blocks = get_cfg_pruned_blocks()
+    cfg_pruned_blocks = [x for x in cfg_pruned_blocks if x > 0]
+    return len(cfg_pruned_blocks)
+
+
 @torch.compiler.disable
 def is_in_warmup():
     prune_context = get_current_prune_context()
@@ -284,6 +384,14 @@ def get_residual_diffs():
     return prune_context.residual_diffs.copy()
 
 
+@torch.compiler.disable
+def get_cfg_residual_diffs():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    # Return a copy of the residual diffs to avoid modification
+    return prune_context.cfg_residual_diffs.copy()
+
+
 @torch.compiler.disable
 def get_important_condition_threshold():
     prune_context = get_current_prune_context()
@@ -325,6 +433,27 @@ def get_non_prune_blocks_ids():
     return prune_context.non_prune_blocks_ids
 
 
+@torch.compiler.disable
+def do_separate_classifier_free_guidance():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.do_separate_classifier_free_guidance
+
+
+@torch.compiler.disable
+def is_separate_classifier_free_guidance_step():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.is_separate_classifier_free_guidance_step()
+
+
+@torch.compiler.disable
+def cfg_diff_compute_separate():
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    return prune_context.cfg_diff_compute_separate
+
+
 _current_prune_context: DBPPruneContext = None
 
 
@@ -463,6 +592,58 @@ def are_two_tensors_similar(
     return diff < threshold
 
 
+@torch.compiler.disable
+def _debugging_set_buffer(prefix):
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            f"set {prefix}, "
+            f"transformer step: {get_current_transformer_step()}, "
+            f"executed step: {get_current_step()}"
+        )
+
+
+@torch.compiler.disable
+def _debugging_get_buffer(prefix):
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug(
+            f"get {prefix}, "
+            f"transformer step: {get_current_transformer_step()}, "
+            f"executed step: {get_current_step()}"
+        )
+
+
+@torch.compiler.disable
+def set_buffer(name: str, buffer: torch.Tensor):
+    # Set hidden_states or residual for Fn blocks.
+    # This buffer is only use for L1 diff calculation.
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    if is_separate_classifier_free_guidance_step():
+        _debugging_set_buffer(f"{name}_buffer_cfg")
+        prune_context.set_buffer(f"{name}_buffer_cfg", buffer)
+    else:
+        _debugging_set_buffer(f"{name}_buffer")
+        prune_context.set_buffer(f"{name}_buffer", buffer)
+
+
+@torch.compiler.disable
+def get_buffer(name: str):
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    if is_separate_classifier_free_guidance_step():
+        _debugging_get_buffer(f"{name}_buffer_cfg")
+        return prune_context.get_buffer(f"{name}_buffer_cfg")
+    _debugging_get_buffer(f"{name}_buffer")
+    return prune_context.get_buffer(f"{name}_buffer")
+
+
+@torch.compiler.disable
+def remove_buffer(name: str):
+    prune_context = get_current_prune_context()
+    assert prune_context is not None, "prune_context must be set before"
+    prune_context.remove_buffer(name)
+
+
 @torch.compiler.disable
 def apply_hidden_states_residual(
     hidden_states: torch.Tensor,
@@ -506,7 +687,11 @@ def get_can_use_prune(
     if is_in_warmup():
         return False
 
-
+    if not is_separate_classifier_free_guidance_step():
+        pruned_steps = get_pruned_steps()
+    else:
+        pruned_steps = get_cfg_pruned_steps()
+
     max_pruned_steps = get_max_pruned_steps()
     if max_pruned_steps >= 0 and (pruned_steps >= max_pruned_steps):
        if logger.isEnabledFor(logging.DEBUG):
@@ -521,16 +706,8 @@ def get_can_use_prune(
     if threshold <= 0.0:
         return False
 
-    downsample_factor = get_downsample_factor()
     prev_states_tensor = get_buffer(f"{name}")
 
-    if downsample_factor > 1:
-        states_tensor = states_tensor[..., ::downsample_factor]
-        states_tensor = states_tensor.contiguous()
-        if prev_states_tensor is not None:
-            prev_states_tensor = prev_states_tensor[..., ::downsample_factor]
-            prev_states_tensor = prev_states_tensor.contiguous()
-
     return prev_states_tensor is not None and are_two_tensors_similar(
         prev_states_tensor,
         states_tensor,
@@ -538,468 +715,3 @@ def get_can_use_prune(
         parallelized=parallelized,
         name=name,
     )
-
-
-class DBPrunedTransformerBlocks(torch.nn.Module):
-    def __init__(
-        self,
-        transformer_blocks,
-        single_transformer_blocks=None,
-        *,
-        transformer=None,
-        return_hidden_states_first=True,
-        return_hidden_states_only=False,
-    ):
-        super().__init__()
-
-        self.transformer = transformer
-        self.transformer_blocks = transformer_blocks
-        self.single_transformer_blocks = single_transformer_blocks
-        self.return_hidden_states_first = return_hidden_states_first
-        self.return_hidden_states_only = return_hidden_states_only
-        self.pruned_blocks_step: int = 0
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        mark_step_begin()
-        self.pruned_blocks_step = 0
-        original_hidden_states = hidden_states
-
-        torch._dynamo.graph_break()
-        hidden_states, encoder_hidden_states = self.call_transformer_blocks(
-            hidden_states,
-            encoder_hidden_states,
-            *args,
-            **kwargs,
-        )
-
-        del original_hidden_states
-        torch._dynamo.graph_break()
-
-        add_pruned_block(self.pruned_blocks_step)
-        add_actual_block(self.num_transformer_blocks)
-        patch_pruned_stats(self.transformer)
-
-        return (
-            hidden_states
-            if self.return_hidden_states_only
-            else (
-                (hidden_states, encoder_hidden_states)
-                if self.return_hidden_states_first
-                else (encoder_hidden_states, hidden_states)
-            )
-        )
-
-    @property
-    @torch.compiler.disable
-    def num_transformer_blocks(self):
-        # Total number of transformer blocks, including single transformer blocks.
-        num_blocks = len(self.transformer_blocks)
-        if self.single_transformer_blocks is not None:
-            num_blocks += len(self.single_transformer_blocks)
-        return num_blocks
-
-    @torch.compiler.disable
-    def _is_parallelized(self):
-        # Compatible with distributed inference.
-        return all(
-            (
-                self.transformer is not None,
-                getattr(self.transformer, "_is_parallelized", False),
-            )
-        )
-
-    @torch.compiler.disable
-    def _non_prune_blocks_ids(self):
-        # Never prune the first `Fn` and last `Bn` blocks.
-        num_blocks = self.num_transformer_blocks
-        Fn_compute_blocks_ = (
-            Fn_compute_blocks()
-            if Fn_compute_blocks() < num_blocks
-            else num_blocks
-        )
-        Fn_compute_blocks_ids = list(range(Fn_compute_blocks_))
-        Bn_compute_blocks_ = (
-            Bn_compute_blocks()
-            if Bn_compute_blocks() < num_blocks
-            else num_blocks
-        )
-        Bn_compute_blocks_ids = list(
-            range(
-                num_blocks - Bn_compute_blocks_,
-                num_blocks,
-            )
-        )
-        non_prune_blocks_ids = list(
-            set(
-                Fn_compute_blocks_ids
-                + Bn_compute_blocks_ids
-                + get_non_prune_blocks_ids()
-            )
-        )
-        non_prune_blocks_ids = [
-            d for d in non_prune_blocks_ids if d < num_blocks
-        ]
-        return sorted(non_prune_blocks_ids)
-
-    @torch.compiler.disable
-    def _compute_single_hidden_states_residual(
-        self,
-        single_hidden_states: torch.Tensor,
-        single_original_hidden_states: torch.Tensor,
-        # global original single hidden states
-        original_single_hidden_states: torch.Tensor,
-        original_single_encoder_hidden_states: torch.Tensor,
-    ):
-        single_hidden_states, single_encoder_hidden_states = (
-            self._split_single_hidden_states(
-                single_hidden_states,
-                original_single_hidden_states,
-                original_single_encoder_hidden_states,
-            )
-        )
-
-        single_original_hidden_states, single_original_encoder_hidden_states = (
-            self._split_single_hidden_states(
-                single_original_hidden_states,
-                original_single_hidden_states,
-                original_single_encoder_hidden_states,
-            )
-        )
-
-        single_hidden_states_residual = (
-            single_hidden_states - single_original_hidden_states
-        )
-        single_encoder_hidden_states_residual = (
-            single_encoder_hidden_states - single_original_encoder_hidden_states
-        )
-        return (
-            single_hidden_states_residual,
-            single_encoder_hidden_states_residual,
-        )
-
-    @torch.compiler.disable
-    def _split_single_hidden_states(
-        self,
-        single_hidden_states: torch.Tensor,
-        # global original single hidden states
-        original_single_hidden_states: torch.Tensor,
-        original_single_encoder_hidden_states: torch.Tensor,
-    ):
-        single_encoder_hidden_states, single_hidden_states = (
-            single_hidden_states.split(
-                [
-                    original_single_encoder_hidden_states.shape[1],
-                    single_hidden_states.shape[1]
-                    - original_single_encoder_hidden_states.shape[1],
-                ],
-                dim=1,
-            )
-        )
-        # Reshape the single_hidden_states and single_encoder_hidden_states
-        # to the original shape. This is necessary to ensure that the
-        # residuals are computed correctly.
-        single_hidden_states = (
-            single_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_single_hidden_states.shape)
-        )
-        single_encoder_hidden_states = (
-            single_encoder_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_single_encoder_hidden_states.shape)
-        )
-        return single_hidden_states, single_encoder_hidden_states
-
-    @torch.compiler.disable
-    def _should_update_residuals(self):
-        # Wrap for non compiled mode.
-        # Check if the current step is a multiple of
-        # the residual cache update interval.
-        return get_current_step() % residual_cache_update_interval() == 0
-
-    @torch.compiler.disable
-    def _get_can_use_prune(
-        self,
-        block_id: int,  # Block index in the transformer blocks
-        hidden_states: torch.Tensor,  # hidden_states or residual
-        name: str = "Bn_original",  # prev step name for single blocks
-    ):
-        # Wrap for non compiled mode.
-        can_use_prune = False
-        if block_id not in self._non_prune_blocks_ids():
-            can_use_prune = get_can_use_prune(
-                hidden_states,  # curr step
-                parallelized=self._is_parallelized(),
-                name=name,  # prev step
-            )
-        self.pruned_blocks_step += int(can_use_prune)
-        return can_use_prune
-
-    def _compute_or_prune_single_transformer_block(
-        self,
-        block_id: int,  # Block index in the transformer blocks
-        # Helper inputs for hidden states split and reshape
-        # Global original single hidden states
-        original_single_hidden_states: torch.Tensor,
-        original_single_encoder_hidden_states: torch.Tensor,
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Helper function for `call_transformer_blocks`
-        # block_id: global block index in the transformer blocks +
-        # single_transformer_blocks
-        can_use_prune = self._get_can_use_prune(
-            block_id,
-            hidden_states,  # hidden_states or residual
-            name=f"{block_id}_single_original",  # prev step
-        )
-
-        # Prune steps: Prune current block and reuse the cached
-        # residuals for hidden states approximate.
-        if can_use_prune:
-            single_original_hidden_states = hidden_states
-            (
-                single_original_hidden_states,
-                single_original_encoder_hidden_states,
-            ) = self._split_single_hidden_states(
-                single_original_hidden_states,
-                original_single_hidden_states,
-                original_single_encoder_hidden_states,
-            )
-            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                single_original_hidden_states,
-                single_original_encoder_hidden_states,
-                name=f"{block_id}_single_residual",
-                encoder_name=f"{block_id}_single_encoder_residual",
-            )
-            hidden_states = torch.cat(
-                [encoder_hidden_states, hidden_states],
-                dim=1,
-            )
-            del single_original_hidden_states
-            del single_original_encoder_hidden_states
-
-        else:
-            # Normal steps: Compute the block and cache the residuals.
-            single_original_hidden_states = hidden_states
-            hidden_states = block(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-
-            # Save original_hidden_states for diff calculation.
-            # May not be necessary to update the hidden
-            # states and residuals each step?
-            if self._should_update_residuals():
-                # Cache residuals for the non-compute Bn blocks for
-                # subsequent prune steps.
-                single_hidden_states = hidden_states
-                (
-                    single_hidden_states_residual,
-                    single_encoder_hidden_states_residual,
-                ) = self._compute_single_hidden_states_residual(
-                    single_hidden_states,
-                    single_original_hidden_states,
-                    original_single_hidden_states,
-                    original_single_encoder_hidden_states,
-                )
-
-                set_buffer(
-                    f"{block_id}_single_original",
-                    single_original_hidden_states,
-                )
-
-                set_buffer(
-                    f"{block_id}_single_residual",
-                    single_hidden_states_residual,
-                )
-                set_buffer(
-                    f"{block_id}_single_encoder_residual",
-                    single_encoder_hidden_states_residual,
-                )
-
-                del single_hidden_states
-                del single_hidden_states_residual
-                del single_encoder_hidden_states_residual
-
-            del single_original_hidden_states
-
-        return hidden_states
-
-    def _compute_or_prune_transformer_block(
-        self,
-        block_id: int,  # Block index in the transformer blocks
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Helper function for `call_transformer_blocks`
-        original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
-
-        # block_id: global block index in the transformer blocks +
-        # single_transformer_blocks
-        can_use_prune = self._get_can_use_prune(
-            block_id,
-            hidden_states,  # hidden_states or residual
-            name=f"{block_id}_original",  # prev step
-        )
-
-        # Prune steps: Prune current block and reuse the cached
-        # residuals for hidden states approximate.
-        if can_use_prune:
-            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                hidden_states,
-                encoder_hidden_states,
-                name=f"{block_id}_residual",
-                encoder_name=f"{block_id}_encoder_residual",
-            )
-        else:
-            # Normal steps: Compute the block and cache the residuals.
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
-
-            # Save original_hidden_states for diff calculation.
-            # May not be necessary to update the hidden
-            # states and residuals each step?
-            if self._should_update_residuals():
-                # Cache residuals for the non-compute Bn blocks for
-                # subsequent prune steps.
-                hidden_states_residual = hidden_states - original_hidden_states
-                encoder_hidden_states_residual = (
-                    encoder_hidden_states - original_encoder_hidden_states
-                )
-                set_buffer(
-                    f"{block_id}_original",
-                    original_hidden_states,
-                )
-
-                set_buffer(
-                    f"{block_id}_residual",
-                    hidden_states_residual,
-                )
-                set_buffer(
-                    f"{block_id}_encoder_residual",
-                    encoder_hidden_states_residual,
-                )
-                del hidden_states_residual
-                del encoder_hidden_states_residual
-
-        del original_hidden_states
-        del original_encoder_hidden_states
-
-        return hidden_states, encoder_hidden_states
-
-    def call_transformer_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
-
-        for i, block in enumerate(self.transformer_blocks):
-            hidden_states, encoder_hidden_states = (
-                self._compute_or_prune_transformer_block(
-                    i,
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-            )
-
-        if self.single_transformer_blocks is not None:
-            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
-            if is_diffusers_at_least_0_3_5():
-                for j, block in enumerate(self.single_transformer_blocks):
-                    # NOTE: Reuse _compute_or_prune_transformer_block here.
-                    hidden_states, encoder_hidden_states = (
-                        self._compute_or_prune_transformer_block(
-                            j + len(self.transformer_blocks),
-                            block,
-                            hidden_states,
-                            encoder_hidden_states,
-                            *args,
-                            **kwargs,
-                        )
-                    )
-            else:
-                hidden_states = torch.cat(
-                    [encoder_hidden_states, hidden_states], dim=1
-                )
-                for j, block in enumerate(self.single_transformer_blocks):
-                    hidden_states = (
-                        self._compute_or_prune_single_transformer_block(
-                            j + len(self.transformer_blocks),
-                            original_hidden_states,
-                            original_encoder_hidden_states,
-                            block,
-                            hidden_states,
-                            *args,
-                            **kwargs,
-                        )
-                    )
-
-                encoder_hidden_states, hidden_states = hidden_states.split(
-                    [
-                        encoder_hidden_states.shape[1],
-                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
-                    ],
-                    dim=1,
-                )
-
-                hidden_states = (
-                    hidden_states.reshape(-1)
-                    .contiguous()
-                    .reshape(original_hidden_states.shape)
-                )
-                encoder_hidden_states = (
-                    encoder_hidden_states.reshape(-1)
-                    .contiguous()
-                    .reshape(original_encoder_hidden_states.shape)
-                )
-        return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_pruned_stats(
-    transformer,
-):
-    # Patch the pruned stats to the transformer, the pruned stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more pruned stats to the transformer
-    transformer._pruned_blocks = get_pruned_blocks()
-    transformer._pruned_steps = get_pruned_steps()
-    transformer._residual_diffs = get_residual_diffs()
-    transformer._actual_blocks = get_actual_blocks()