cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +9 -4
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +16 -3
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +121 -563
- cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
- cache_dit/cache_factory/cache_blocks/utils.py +23 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
- cache_dit/cache_factory/cache_interface.py +24 -16
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
- cache_dit/quantize/quantize_ao.py +19 -4
- cache_dit/quantize/quantize_interface.py +2 -2
- cache_dit/utils.py +19 -15
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.25.dist-info/RECORD +0 -36
- /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
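The headline change in this release is the split of `cache_factory/cache_blocks.py` into a `cache_blocks/` package and of `cache_context.py` into a `cache_contexts/` package, alongside the new `block_adapters/` and `patch_functors/` packages. As the renamed `pattern_base.py` diff below shows, the cached-blocks wrapper now takes an explicit `blocks_name` and a named `cache_context`, and every cached buffer key is prefixed with that `blocks_name`. The following is a minimal, self-contained sketch of that bookkeeping pattern; it illustrates the idea only and is not cache-dit's implementation (all class and context names here are made up):

```python
from collections import defaultdict


class NamedCacheRegistry:
    """Toy registry: one buffer dict per named cache context."""

    def __init__(self):
        self._contexts = defaultdict(dict)
        self._current = None

    def set_cache_context(self, name: str) -> None:
        # Activate one named context before a block list runs.
        self._current = name

    def set_buffer(self, key: str, value) -> None:
        self._contexts[self._current][key] = value

    def get_buffer(self, key: str):
        return self._contexts[self._current].get(key)


registry = NamedCacheRegistry()
registry.set_cache_context("FluxTransformer2DModel_0")

# Prefixing buffer keys with the block-list name keeps `transformer_blocks`
# and `single_transformer_blocks` from overwriting each other's residuals
# inside the same context.
registry.set_buffer("transformer_blocks_Fn_residual", 0.1)
registry.set_buffer("single_transformer_blocks_Fn_residual", 0.2)
assert registry.get_buffer("transformer_blocks_Fn_residual") == 0.1
```

This is the same reason the diff below threads `blocks_name` into every `prefix=` argument.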
cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py}

@@ -2,14 +2,14 @@ import inspect
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory import cache_context
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedTransformerBlocks(torch.nn.Module):
+class CachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -19,28 +19,54 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     def __init__(
         self,
         transformer_blocks: torch.nn.ModuleList,
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        blocks_name: str,
+        # Usually, blocks_name, etc.
+        cache_context: str,
         *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_num_outputs: bool = True,
     ):
         super().__init__()
 
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
+        self.blocks_name = blocks_name
+        self.cache_context = cache_context
         self.forward_pattern = forward_pattern
+        self.check_num_outputs = check_num_outputs
         self._check_forward_pattern()
+        logger.info(
+            f"Match Cached Blocks: {self.__class__.__name__}, for "
+            f"{self.blocks_name}, context: {self.cache_context}"
+        )
 
     def _check_forward_pattern(self):
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
-        ), f"Pattern {self.forward_pattern} is not
+        ), f"Pattern {self.forward_pattern} is not supported now!"
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
+
+                if self.check_num_outputs:
+                    num_outputs = str(
+                        inspect.signature(block.forward).return_annotation
+                    ).count("torch.Tensor")
+
+                    if num_outputs > 0:
+                        assert len(self.forward_pattern.Out) == num_outputs, (
+                            f"The number of block's outputs is {num_outputs} don't not "
+                            f"match the number of the pattern: {self.forward_pattern}, "
+                            f"Out: {len(self.forward_pattern.Out)}."
+                        )
+
                 for required_param in self.forward_pattern.In:
                     assert (
                         required_param in forward_parameters
@@ -53,6 +79,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        CachedContext.set_cache_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -66,39 +96,39 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-
+        CachedContext.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = CachedContext.get_can_use_cache(
             (
                 Fn_hidden_states_residual
-                if not
+                if not CachedContext.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "
-                if not
-                else "
+                f"{self.blocks_name}_Fn_residual"
+                if not CachedContext.is_l1_diff_enabled()
+                else f"{self.blocks_name}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-
+            CachedContext.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                CachedContext.apply_hidden_states_residual(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_encoder_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                 )
             )
@@ -112,12 +142,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 **kwargs,
             )
         else:
-
-                Fn_hidden_states_residual,
+            CachedContext.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.blocks_name}_Fn_residual",
             )
-            if
+            if CachedContext.is_l1_diff_enabled():
                 # for hidden states L1 diff
-
+                CachedContext.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.blocks_name}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -132,27 +166,27 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if
-
+            if CachedContext.is_cache_residual():
+                CachedContext.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                CachedContext.set_Bn_buffer(
                     hidden_states,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
-            if
-
+            if CachedContext.is_encoder_cache_residual():
+                CachedContext.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                CachedContext.set_Bn_encoder_buffer(
                     encoder_hidden_states,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
@@ -164,7 +198,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             **kwargs,
         )
 
-
+        # patch cached stats for blocks or remove it.
        torch._dynamo.graph_break()
 
         return (
@@ -198,10 +232,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-
+            CachedContext.get_current_step() in CachedContext.get_cached_steps()
         ) or (
-
-            in
+            CachedContext.get_current_step()
+            in CachedContext.get_cfg_cached_steps()
         )
 
     @torch.compiler.disable
@@ -210,20 +244,20 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            :
+            : CachedContext.Fn_compute_blocks()
         ]
         return selected_Fn_blocks
 
     @torch.compiler.disable
     def _Mn_blocks(self): # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if
+        if CachedContext.Bn_compute_blocks() == 0: # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-
+                CachedContext.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-
+                CachedContext.Fn_compute_blocks() : -CachedContext.Bn_compute_blocks()
             ]
         return selected_Mn_blocks
 
@@ -231,7 +265,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -
+            -CachedContext.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks
 
@@ -242,10 +276,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert
+        assert CachedContext.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         for block in self._Fn_blocks():
@@ -342,7 +376,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if block_id not in
+            if block_id not in CachedContext.Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -351,22 +385,22 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 )
 
                 # Save original_hidden_states for diff calculation.
-
+                CachedContext.set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
                 )
-
+                CachedContext.set_Bn_encoder_buffer(
                     Bn_i_original_encoder_hidden_states,
-                    prefix=f"
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original",
                 )
 
-
+                CachedContext.set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
                 )
-
+                CachedContext.set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
                 )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
@@ -377,7 +411,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in
+            if block_id in CachedContext.Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -395,25 +429,25 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 # Skip the block if it is not in the Bn_compute_blocks_ids.
                 # Use the cached residuals instead.
                 # Check if can use the cached residuals.
-                if
+                if CachedContext.get_can_use_cache(
                     hidden_states, # curr step
                     parallelized=self._is_parallelized(),
-                    threshold=
-                    prefix=f"
+                    threshold=CachedContext.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original", # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
-
+                        CachedContext.apply_hidden_states_residual(
                             hidden_states,
                             encoder_hidden_states,
                             prefix=(
-                                f"
-                                if
-                                else f"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                             ),
                             encoder_prefix=(
-                                f"
-                                if
-                                else f"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_encoder_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                             ),
                         )
                     )
@@ -440,16 +474,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if
+        if CachedContext.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states
 
-        assert
+        assert CachedContext.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(
+        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
             for i, block in enumerate(self._Bn_blocks()):
                 hidden_states, encoder_hidden_states = (
                     self._compute_or_cache_block(
@@ -479,19 +513,3 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         )
 
         return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_cached_stats(
-    transformer,
-):
-    # Patch the cached stats to the transformer, the cached stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = cache_context.get_cached_steps()
-    transformer._residual_diffs = cache_context.get_residual_diffs()
-    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
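Another notable addition in the hunks above is the optional `check_num_outputs` validation in `_check_forward_pattern`, which counts the `torch.Tensor` entries in a block's forward return annotation and compares that count against the pattern's expected number of outputs. A small standalone sketch of that introspection trick follows; the block class and the expected count of 2 are illustrative, not taken from cache-dit:

```python
import inspect
from typing import Tuple

import torch


class ToyBlock(torch.nn.Module):
    # Annotated explicitly so the introspection below can count the outputs.
    def forward(
        self, hidden_states, encoder_hidden_states
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states


# Same idea as the added check: stringify the return annotation and count
# how many `torch.Tensor` outputs it declares.
num_outputs = str(
    inspect.signature(ToyBlock.forward).return_annotation
).count("torch.Tensor")
assert num_outputs == 2, num_outputs
```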
cache_dit/cache_factory/cache_blocks/utils.py (new file)

@@ -0,0 +1,23 @@
+import torch
+
+from typing import Any
+from cache_dit.cache_factory import CachedContext
+
+
+@torch.compiler.disable
+def patch_cached_stats(
+    module: torch.nn.Module | Any, cache_context: str = None
+):
+    # Patch the cached stats to the module, the cached stats
+    # will be reset for each calling of pipe.__call__(**kwargs).
+    if module is None:
+        return
+
+    if cache_context is not None:
+        CachedContext.set_cache_context(cache_context)
+
+    # TODO: Patch more cached stats to the module
+    module._cached_steps = CachedContext.get_cached_steps()
+    module._residual_diffs = CachedContext.get_residual_diffs()
+    module._cfg_cached_steps = CachedContext.get_cfg_cached_steps()
+    module._cfg_residual_diffs = CachedContext.get_cfg_residual_diffs()
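Downstream code that previously read the stats patched onto the transformer can keep doing so; the patch target simply generalises from `transformer` to any `module`, and an optional `cache_context` name can be activated first. A hedged reading sketch (only the attribute names come from the diff above; the `module` object here is a stand-in, since in practice the stats are populated during a cached pipeline run):

```python
import torch

# Stand-in for a module that patch_cached_stats() has already been applied to;
# a real target would be something like pipe.transformer after inference.
module = torch.nn.Module()

cached_steps = getattr(module, "_cached_steps", [])
residual_diffs = getattr(module, "_residual_diffs", {})
cfg_cached_steps = getattr(module, "_cfg_cached_steps", [])

print(f"steps served from cache: {len(cached_steps)}")
print(f"residual diffs recorded: {len(residual_diffs)}")
print(f"CFG steps served from cache: {len(cfg_cached_steps)}")
```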