cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
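Most of the churn above is a package reorganization rather than new behavior: `cache_dit/cache_factory/` is renamed to `cache_dit/caching/`, `cache_dit/custom_ops/` becomes `cache_dit/kernels/`, the torchao quantization code moves under `cache_dit/quantize/backends/`, new `cache_dit/parallelism/` and `cache_dit/summary.py` modules are added, and the old summary helpers are stripped out of `cache_dit/utils.py` (see the diff below). A minimal migration sketch for downstream imports follows; whether these names are re-exported from the new paths (or still shimmed at the old ones) is an assumption inferred from the renames, not something this file list confirms.

```python
# Hypothetical import update after the cache_factory -> caching rename in 1.0.14.
# The re-export locations below are assumptions inferred from the renamed files
# listed above; check cache_dit/caching/__init__.py (or the top-level cache_dit
# package) for the actual public names.

# cache-dit 1.0.3 style:
#   from cache_dit.cache_factory import CacheType, BlockAdapter, BasicCacheConfig
# cache-dit 1.0.14 style (assumed):
from cache_dit.caching import CacheType, BlockAdapter, BasicCacheConfig  # noqa: F401
```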
cache_dit/utils.py
CHANGED
@@ -1,18 +1,10 @@
+import gc
+import time
 import torch
-import dataclasses
 import diffusers
 import builtins as __builtin__
 import contextlib
 
-import numpy as np
-from pprint import pprint
-from diffusers import DiffusionPipeline
-
-from typing import Dict, Any, List, Union
-from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import BlockAdapter
-from cache_dit.cache_factory import BasicCacheConfig
-from cache_dit.cache_factory import CalibratorConfig
 from cache_dit.logger import init_logger
 
 
@@ -36,290 +28,54 @@ def is_diffusers_at_least_0_3_5() -> bool:
     return diffusers.__version__ >= "0.35.0"
 
 
-@
-
-
-
-
-
-
-
-
-
+@torch.compiler.disable
+def maybe_empty_cache():
+    try:
+        time.sleep(1)
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        time.sleep(1)
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+    except Exception:
+        pass
 
-def summary(
-    adapter_or_others: Union[
-        BlockAdapter,
-        DiffusionPipeline,
-        torch.nn.Module,
-    ],
-    details: bool = False,
-    logging: bool = True,
-    **kwargs,
-) -> List[CacheStats]:
-    if adapter_or_others is None:
-        return [CacheStats()]
 
-
-
-
-
+@torch.compiler.disable
+def print_tensor(
+    x: torch.Tensor,
+    name: str,
+    dim: int = 1,
+    no_dist_shape: bool = True,
+    disable: bool = False,
+):
+    if disable:
+        return
+
+    x = x.contiguous()
+    if torch.distributed.is_initialized():
+        # all gather hidden_states and check values mean
+        gather_x = [
+            torch.zeros_like(x)
+            for _ in range(torch.distributed.get_world_size())
+        ]
+        torch.distributed.all_gather(gather_x, x)
+        gather_x = torch.cat(gather_x, dim=dim)
+
+        if not no_dist_shape:
+            x_shape = gather_x.shape
         else:
-
-        transformer_2 = None
-        if hasattr(adapter_or_others, "transformer_2"):
-            transformer_2 = adapter_or_others.transformer_2
-
-        if not BlockAdapter.is_cached(transformer):
-            return [CacheStats()]
-
-        blocks_stats: List[CacheStats] = []
-        for blocks in BlockAdapter.find_blocks(transformer):
-            blocks_stats.append(
-                _summary(
-                    blocks,
-                    details=details,
-                    logging=logging,
-                    **kwargs,
-                )
-            )
-
-        if transformer_2 is not None:
-            for blocks in BlockAdapter.find_blocks(transformer_2):
-                blocks_stats.append(
-                    _summary(
-                        blocks,
-                        details=details,
-                        logging=logging,
-                        **kwargs,
-                    )
-                )
+            x_shape = x.shape
 
-
-
-
-
-                logging=logging,
-                **kwargs,
-            )
-        )
-        if transformer_2 is not None:
-            blocks_stats.append(
-                _summary(
-                    transformer_2,
-                    details=details,
-                    logging=logging,
-                    **kwargs,
-                )
-            )
-
-        blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
-
-        return blocks_stats if len(blocks_stats) else [CacheStats()]
-
-    adapter = adapter_or_others
-    if not BlockAdapter.check_block_adapter(adapter):
-        return [CacheStats()]
-
-    blocks_stats = []
-    flatten_blocks = BlockAdapter.flatten(adapter.blocks)
-    for blocks in flatten_blocks:
-        blocks_stats.append(
-            _summary(
-                blocks,
-                details=details,
-                logging=logging,
-                **kwargs,
+        if torch.distributed.get_rank() == 0:
+            print(
+                f"{name}, mean: {gather_x.float().mean().item()}, "
+                f"std: {gather_x.float().std().item()}, shape: {x_shape}"
             )
-        )
-
-    blocks_stats = [stats for stats in blocks_stats if stats.cache_options]
-
-    return blocks_stats if len(blocks_stats) else [CacheStats()]
-
-
-def strify(
-    adapter_or_others: Union[
-        BlockAdapter,
-        DiffusionPipeline,
-        CacheStats,
-        List[CacheStats],
-        Dict[str, Any],
-    ],
-) -> str:
-    if isinstance(adapter_or_others, BlockAdapter):
-        stats = summary(adapter_or_others, logging=False)[-1]
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, DiffusionPipeline):
-        stats = summary(adapter_or_others, logging=False)[-1]
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, CacheStats):
-        stats = adapter_or_others
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, list):
-        stats = adapter_or_others[0]
-        cache_options = stats.cache_options
-        cached_steps = len(stats.cached_steps)
-    elif isinstance(adapter_or_others, dict):
-
-        # Assume cache_context_kwargs
-        cache_options = adapter_or_others
-        cached_steps = None
-        cache_type = cache_options.get("cache_type", CacheType.NONE)
-
-        if cache_type == CacheType.NONE:
-            return "NONE"
     else:
-
-            "
-            "
-        )
-
-    if not cache_options:
-        return "NONE"
-
-    def basic_cache_str():
-        cache_config: BasicCacheConfig = cache_options.get("cache_config", None)
-        if cache_config is not None:
-            return cache_config.strify()
-        return "NONE"
-
-    def calibrator_str():
-        calibrator_config: CalibratorConfig = cache_options.get(
-            "calibrator_config", None
+        print(
+            f"{name}, mean: {x.float().mean().item()}, "
+            f"std: {x.float().std().item()}, shape: {x.shape}"
        )
-        if calibrator_config is not None:
-            return calibrator_config.strify()
-        return "T0O0"
-
-    cache_type_str = f"{basic_cache_str()}_{calibrator_str()}"
-
-    if cached_steps:
-        cache_type_str += f"_S{cached_steps}"
-
-    return cache_type_str
-
-
-def _summary(
-    pipe_or_module: Union[
-        DiffusionPipeline,
-        torch.nn.Module,
-    ],
-    details: bool = False,
-    logging: bool = True,
-    **kwargs,
-) -> CacheStats:
-    cache_stats = CacheStats()
-
-    if not isinstance(pipe_or_module, torch.nn.Module):
-        assert hasattr(pipe_or_module, "transformer")
-        module = pipe_or_module.transformer
-        cls_name = module.__class__.__name__
-    else:
-        module = pipe_or_module
-
-    cls_name = module.__class__.__name__
-    if isinstance(module, torch.nn.ModuleList):
-        cls_name = module[0].__class__.__name__
-
-    if hasattr(module, "_cache_context_kwargs"):
-        cache_options = module._cache_context_kwargs
-        cache_stats.cache_options = cache_options
-        if logging:
-            print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
-    else:
-        if logging:
-            logger.warning(f"Can't find Cache Options for: {cls_name}")
-
-    if hasattr(module, "_cached_steps"):
-        cached_steps: list[int] = module._cached_steps
-        residual_diffs: dict[str, float] = dict(module._residual_diffs)
-        cache_stats.cached_steps = cached_steps
-        cache_stats.residual_diffs = residual_diffs
-
-        if residual_diffs and logging:
-            diffs_values = list(residual_diffs.values())
-            qmin = np.min(diffs_values)
-            q0 = np.percentile(diffs_values, 0)
-            q1 = np.percentile(diffs_values, 25)
-            q2 = np.percentile(diffs_values, 50)
-            q3 = np.percentile(diffs_values, 75)
-            q4 = np.percentile(diffs_values, 95)
-            qmax = np.max(diffs_values)
-
-            print(
-                f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
-            )
-
-            print(
-                "| Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
-            )
-            print(
-                "|-------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
-            )
-            print(
-                f"| {len(cached_steps):<11} | {round(q0, 3):<9} | {round(q1, 3):<9} "
-                f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
-                f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
-            )
-            print("")
-
-        if details:
-            print(f"📚Cache Steps and Residual Diffs Details: {cls_name}\n")
-            pprint(
-                f"Cache Steps: {len(cached_steps)}, {cached_steps}",
-            )
-            pprint(
-                f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
-                compact=True,
-            )
-
-    if hasattr(module, "_cfg_cached_steps"):
-        cfg_cached_steps: list[int] = module._cfg_cached_steps
-        cfg_residual_diffs: dict[str, float] = dict(module._cfg_residual_diffs)
-        cache_stats.cfg_cached_steps = cfg_cached_steps
-        cache_stats.cfg_residual_diffs = cfg_residual_diffs
-
-        if cfg_residual_diffs and logging:
-            cfg_diffs_values = list(cfg_residual_diffs.values())
-            qmin = np.min(cfg_diffs_values)
-            q0 = np.percentile(cfg_diffs_values, 0)
-            q1 = np.percentile(cfg_diffs_values, 25)
-            q2 = np.percentile(cfg_diffs_values, 50)
-            q3 = np.percentile(cfg_diffs_values, 75)
-            q4 = np.percentile(cfg_diffs_values, 95)
-            qmax = np.max(cfg_diffs_values)
-
-            print(
-                f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
-            )
-
-            print(
-                "| CFG Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
-            )
-            print(
-                "|-----------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
-            )
-            print(
-                f"| {len(cfg_cached_steps):<15} | {round(q0, 3):<9} | {round(q1, 3):<9} "
-                f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
-                f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
-            )
-            print("")
-
-        if details:
-            print(
-                f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
-            )
-            pprint(
-                f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
-            )
-            pprint(
-                f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
-                compact=True,
-            )
-
-    return cache_stats
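In short, `cache_dit/utils.py` drops the old `summary`/`strify`/`_summary` reporting code (the new `cache_dit/summary.py` listed above is the likely new home) and gains two small helpers, `maybe_empty_cache()` and `print_tensor()`, which look like debugging aids for inspecting hidden states. A usage sketch, with made-up tensor shapes and names:

```python
# Minimal usage sketch of the helpers added to cache_dit/utils.py in 1.0.14.
# The tensor below is a stand-in for real hidden states.
import torch
from cache_dit.utils import maybe_empty_cache, print_tensor

x = torch.randn(2, 128, 64)                   # hypothetical hidden states
print_tensor(x, name="hidden_states", dim=1)  # prints mean/std/shape; all-gathers along `dim`
                                              # first when torch.distributed is initialized
maybe_empty_cache()                           # best-effort gc + CUDA cache/IPC cleanup (exceptions swallowed)
```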