cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +8 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +17 -4
- cache_dit/cache_factory/block_adapters/__init__.py +555 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +262 -938
- cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
- cache_dit/cache_factory/cache_blocks/utils.py +16 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +31 -31
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +26 -26
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
- cache_dit-0.2.28.dist-info/RECORD +47 -0
- cache_dit/cache_factory/cache_context.py +0 -1155
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
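
The reorganization visible in this file list splits the old monolithic `cache_factory/cache_context.py` (removed, -1155 lines) into a new `cache_contexts` package (`cache_context.py`, `cache_manager.py`, and the relocated `taylorseer.py`), plus a new `block_adapters` package. The hunks below appear to come from `cache_blocks/pattern_base.py` and `cache_blocks/utils.py`, judging by the class name in the hunk headers and the +106/-80 and +16/-10 entries above. A minimal sketch of the import paths the new code relies on; only imports actually shown in this diff are listed, everything else about these modules is outside its scope:

```python
# Import paths introduced in 0.2.28, exactly as they appear in the hunks below.
from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager,
)

# The same names are also importable from the cache_factory package root
# (the final hunk imports them this way), alongside the existing ForwardPattern:
from cache_dit.cache_factory import CachedContext, CachedContextManager
from cache_dit.cache_factory import ForwardPattern
```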
@@ -2,17 +2,17 @@ import inspect
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory import
-from cache_dit.cache_factory import
-
-    patch_cached_stats,
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
 )
+from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class
+class CachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -21,18 +21,35 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
 
     def __init__(
         self,
+        # 0. Transformer blocks configuration
         transformer_blocks: torch.nn.ModuleList,
-        *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
     ):
         super().__init__()
 
+        # 0. Transformer blocks configuration
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
+        self.check_num_outputs = check_num_outputs
+        # 1. Cache context configuration
+        self.cache_prefix = cache_prefix
+        self.cache_context = cache_context
+        self.cache_manager = cache_manager
+
         self._check_forward_pattern()
-        logger.info(
+        logger.info(
+            f"Match Cached Blocks: {self.__class__.__name__}, for "
+            f"{self.cache_prefix}, cache_context: {self.cache_context}, "
+            f"cache_manager: {self.cache_manager.name}."
+        )
 
     def _check_forward_pattern(self):
         assert (
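
The constructor no longer reads cache state from a global context; callers now pass the cache wiring explicitly: a `cache_prefix` used to namespace buffer keys, a `cache_context` (object or name), and a shared `CachedContextManager`. A hypothetical wiring sketch follows; the `CachedContextManager` constructor arguments, the import location of the class, and the `pipe` object are assumptions, since only the attributes used above (`set_context`, `.name`, the getters) are visible in this diff:

```python
from cache_dit.cache_factory import ForwardPattern
from cache_dit.cache_factory import CachedContextManager
from cache_dit.cache_factory.cache_blocks.pattern_base import (
    CachedBlocks_Pattern_Base,
)

# Assumptions: `pipe` is an already-loaded diffusers pipeline and
# CachedContextManager accepts a name argument; neither is shown in this diff.
cache_manager = CachedContextManager(name="transformer_cache")

cached_blocks = CachedBlocks_Pattern_Base(
    pipe.transformer.transformer_blocks,   # transformer_blocks (now positional)
    transformer=pipe.transformer,
    forward_pattern=ForwardPattern.Pattern_0,
    check_num_outputs=True,
    cache_prefix="transformer",            # namespaces all buffer keys
    cache_context="transformer",           # a CachedContext or its name
    cache_manager=cache_manager,
)
```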
@@ -45,16 +62,18 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
-
-
-
-
-
-
-
-
-
-
+
+                if self.check_num_outputs:
+                    num_outputs = str(
+                        inspect.signature(block.forward).return_annotation
+                    ).count("torch.Tensor")
+
+                    if num_outputs > 0:
+                        assert len(self.forward_pattern.Out) == num_outputs, (
+                            f"The number of block's outputs is {num_outputs} don't not "
+                            f"match the number of the pattern: {self.forward_pattern}, "
+                            f"Out: {len(self.forward_pattern.Out)}."
+                        )
 
                 for required_param in self.forward_pattern.In:
                     assert (
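
The new `check_num_outputs` guard stringifies the block's forward return annotation and counts how many `torch.Tensor` mentions it contains, then compares that with the pattern's declared outputs; unannotated blocks yield a count of 0 and skip the assertion. A self-contained illustration of the counting trick (the dummy block below is only for demonstration):

```python
import inspect
from typing import Tuple

import torch


class DummyBlock(torch.nn.Module):
    # Only the return annotation matters for this check.
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states


num_outputs = str(
    inspect.signature(DummyBlock.forward).return_annotation
).count("torch.Tensor")
print(num_outputs)  # 2 -> must equal len(forward_pattern.Out) when > 0
```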
@@ -68,6 +87,8 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        self.cache_manager.set_context(self.cache_context)
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -81,39 +102,39 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "
-                if not
-                else "
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
            ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                self.cache_manager.apply_cache(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
             )
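
The cache decision now flows entirely through the injected manager: `set_context` binds this instance's context at the top of `forward`, `mark_step_begin` advances the step counter, `can_cache` compares the residual produced by the first Fn blocks against the previous step, and on a hit `apply_cache` re-applies the stored Bn residual (or hidden states) instead of running the middle blocks. The internals of `can_cache` are not part of this diff; the sketch below shows one common formulation of such a residual-diff test, purely as an illustration:

```python
import torch


def relative_l1_diff(curr: torch.Tensor, prev: torch.Tensor) -> float:
    # Mean |delta| normalized by the previous magnitude. The real
    # CachedContextManager.can_cache() may use a different formula/threshold.
    return ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()


prev_Fn_residual = torch.randn(2, 256, 64)
curr_Fn_residual = prev_Fn_residual + 0.01 * torch.randn_like(prev_Fn_residual)

threshold = 0.05  # placeholder value, not taken from cache-dit defaults
can_use_cache = relative_l1_diff(curr_Fn_residual, prev_Fn_residual) < threshold
print(can_use_cache)  # True -> skip the middle blocks, re-apply the cached Bn residual
```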
@@ -127,12 +148,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-
-                Fn_hidden_states_residual,
+            self.cache_manager.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if
+            if self.cache_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-
+                self.cache_manager.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.cache_prefix}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -147,27 +172,27 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if
-
+            if self.cache_manager.is_cache_residual():
+                self.cache_manager.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                self.cache_manager.set_Bn_buffer(
                     hidden_states,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
-            if
-
+            if self.cache_manager.is_encoder_cache_residual():
+                self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
         # Call last `n` blocks to further process the hidden states
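
In the `else: # TaylorSeer` branches above, the manager stores the raw hidden states rather than residuals: TaylorSeer-style caching predicts the next step's features from a short history of cached states instead of re-adding a fixed residual (the actual implementation lives in the relocated `cache_contexts/taylorseer.py`, which is not shown in this diff). A toy first-order version of that idea, for intuition only:

```python
import torch


def taylor_first_order(f_prev: torch.Tensor, f_prev2: torch.Tensor) -> torch.Tensor:
    # Extrapolate along the finite-difference "derivative" of the last two
    # cached feature maps. Illustrative only; not cache-dit's implementation.
    return f_prev + (f_prev - f_prev2)


f_prev2 = torch.randn(2, 256, 64)
f_prev = f_prev2 + 0.1 * torch.randn_like(f_prev2)
f_pred = taylor_first_order(f_prev, f_prev2)
print(f_pred.shape)  # torch.Size([2, 256, 64])
```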
@@ -179,7 +204,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             **kwargs,
         )
 
-
+        # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()
 
         return (
@@ -213,10 +238,11 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cached_steps()
         ) or (
-
-            in
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cfg_cached_steps()
         )
 
     @torch.compiler.disable
@@ -225,20 +251,20 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            :
+            : self.cache_manager.Fn_compute_blocks()
         ]
         return selected_Fn_blocks
 
     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-
+                self.cache_manager.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-
+                self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
             ]
         return selected_Mn_blocks
 
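
The Fn/Mn/Bn split selects which transformer blocks are always computed (the first `Fn` and the last `Bn`) and which middle blocks are eligible for caching. The `# WARN: x[:-0] = []` comment above guards a real slicing pitfall: when `Bn_compute_blocks()` is 0, an upper bound of `-0` would make the middle slice empty. A small stand-alone illustration with made-up block counts:

```python
import torch

blocks = torch.nn.ModuleList(torch.nn.Identity() for _ in range(10))  # N = 10
Fn, Bn = 2, 2  # stand-ins for Fn_compute_blocks() / Bn_compute_blocks()

Fn_blocks = blocks[:Fn]                                   # always-computed head
Mn_blocks = blocks[Fn:] if Bn == 0 else blocks[Fn:-Bn]    # cache-eligible middle
Bn_blocks = blocks[-Bn:]                                  # always-computed tail
print(len(Fn_blocks), len(Mn_blocks), len(Bn_blocks))     # 2 6 2

# The pitfall behind the WARN comment: a -0 upper bound yields an empty slice.
assert len(blocks[Fn:-0]) == 0
```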
@@ -246,7 +272,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -
+            -self.cache_manager.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks
 
@@ -257,10 +283,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert
+        assert self.cache_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         for block in self._Fn_blocks():
@@ -357,7 +383,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if block_id not in
+            if block_id not in self.cache_manager.Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -366,22 +392,22 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 )
 
                 # Save original_hidden_states for diff calculation.
-
+                self.cache_manager.set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                 )
-
+                self.cache_manager.set_Bn_encoder_buffer(
                     Bn_i_original_encoder_hidden_states,
-                    prefix=f"
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                 )
 
-
+                self.cache_manager.set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                 )
-
+                self.cache_manager.set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                 )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
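
Every buffer the manager stores is keyed by a prefix derived from `cache_prefix`, so several cached block groups (for example a CFG branch or multiple transformers) can share one `CachedContextManager` without colliding; the per-block Bn buffers additionally embed the `block_id` and whether they hold the original states or the residual. The key layout implied by the f-strings in these hunks:

```python
# Buffer-key layout implied by the f-strings above (illustrative values).
cache_prefix = "transformer"
block_id = 37

buffer_keys = [
    f"{cache_prefix}_Fn_residual",             # residual after the first Fn blocks
    f"{cache_prefix}_Fn_hidden_states",        # only when L1-diff mode is enabled
    f"{cache_prefix}_Bn_residual",             # whole middle/tail residual
    f"{cache_prefix}_Bn_hidden_states",        # raw states (TaylorSeer branches)
    f"{cache_prefix}_Bn_{block_id}_original",  # per-Bn-block original states
    f"{cache_prefix}_Bn_{block_id}_residual",  # per-Bn-block residual
]
print("\n".join(buffer_keys))
```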
@@ -392,7 +418,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in
+            if block_id in self.cache_manager.Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -410,25 +436,25 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 # Skip the block if it is not in the Bn_compute_blocks_ids.
                 # Use the cached residuals instead.
                 # Check if can use the cached residuals.
-                if
+                if self.cache_manager.can_cache(
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
-                    threshold=
-                    prefix=f"
+                    threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",  # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
-
+                        self.cache_manager.apply_cache(
                             hidden_states,
                             encoder_hidden_states,
                             prefix=(
-                                f"
-                                if
-                                else f"
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                             ),
                             encoder_prefix=(
-                                f"
-                                if
-                                else f"
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_encoder_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                             ),
                         )
                     )
@@ -455,16 +481,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states
 
-        assert
+        assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             for i, block in enumerate(self._Bn_blocks()):
                 hidden_states, encoder_hidden_states = (
                     self._compute_or_cache_block(
@@ -1,19 +1,25 @@
 import torch
 
-from
+from typing import Any
+from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory import CachedContextManager
 
 
-@torch.compiler.disable
 def patch_cached_stats(
-
+    module: torch.nn.Module | Any,
+    cache_context: CachedContext | str = None,
+    cache_manager: CachedContextManager = None,
 ):
-    # Patch the cached stats to the
+    # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if
+    if module is None or cache_manager is None:
         return
 
-
-
-
-
-
+    if cache_context is not None:
+        cache_manager.set_context(cache_context)
+
+    # TODO: Patch more cached stats to the module
+    module._cached_steps = cache_manager.get_cached_steps()
+    module._residual_diffs = cache_manager.get_residual_diffs()
+    module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
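
This last hunk appears to belong to `cache_dit/cache_factory/cache_blocks/utils.py` (its 16 added and 10 removed lines match that entry in the file list). `patch_cached_stats` now receives the module, an optional context, and the manager explicitly, and copies the manager's per-step statistics onto the module as private attributes. A hedged sketch of reading those attributes back after a cached run; `pipe` and `cache_manager` are assumed to exist and are not defined by this diff:

```python
from cache_dit.cache_factory.cache_blocks.utils import patch_cached_stats

# Assumptions: `pipe.transformer` is the cached diffusers transformer and
# `cache_manager` is the CachedContextManager driving the cached blocks.
patch_cached_stats(
    pipe.transformer,
    cache_context="transformer",   # a CachedContext instance or its name
    cache_manager=cache_manager,
)

print(pipe.transformer._cached_steps)        # steps where the cache was applied
print(pipe.transformer._residual_diffs)      # per-step residual diffs
print(pipe.transformer._cfg_cached_steps)    # same statistics for the CFG branch
print(pipe.transformer._cfg_residual_diffs)
```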