diffcb 0.1.7__tar.gz → 0.1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffcb-0.1.7 → diffcb-0.1.9}/PKG-INFO +13 -11
- {diffcb-0.1.7 → diffcb-0.1.9}/README.md +12 -10
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/__init__.py +1 -1
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/fft_kde.py +12 -6
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/layer.py +22 -31
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/solver.py +16 -8
- {diffcb-0.1.7 → diffcb-0.1.9}/pyproject.toml +1 -1
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_gradcheck.py +14 -22
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_layer.py +9 -6
- diffcb-0.1.9/v018_local_bench.py +394 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/.gitignore +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/.zenodo.json +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/LICENSE +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/diagnostics.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/kde.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/training.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/dcb/utils.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/notebooks/.gitkeep +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/round24_cumulative_bench.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/round24_v016_test.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/round25_full_range_sweep.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/round25_write_csv.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_kde.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_r18c_denom_audit.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_r18c_deprecation_warn.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_r19_default_fft.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_r19_diagnostics.py +0 -0
- {diffcb-0.1.7 → diffcb-0.1.9}/tests/test_solver.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffcb
|
|
3
|
-
Version: 0.1.7
|
|
3
|
+
Version: 0.1.9
|
|
4
4
|
Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
|
|
5
5
|
Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
6
6
|
Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
@@ -295,15 +295,17 @@ Cumulative speedup vs v0.1.4 on CPU: 1.1× (100K), 1.7× (1M), **4.2× (10M)**.
|
|
|
295
295
|
|
|
296
296
|
```python
|
|
297
297
|
DCBLayer(
|
|
298
|
-
target_modes=1,
|
|
299
|
-
use_fft=True,
|
|
300
|
-
max_n_exact=None,
|
|
301
|
-
G_min=16384,
|
|
302
|
-
use_richardson=True,
|
|
303
|
-
direct_n_max=25_000,
|
|
304
|
-
direct_M=2048,
|
|
305
|
-
|
|
306
|
-
|
|
298
|
+
target_modes=1, # target number of modes (default 1)
|
|
299
|
+
use_fft=True, # FFT path for n > 50K (default True)
|
|
300
|
+
max_n_exact=None, # sketch above this n (None = always exact)
|
|
301
|
+
G_min=16384, # minimum FFT histogram bins (accuracy ↑ with G)
|
|
302
|
+
use_richardson="auto", # Richardson on CPU, off on GPU (30% accuracy gain on CPU)
|
|
303
|
+
direct_n_max=25_000, # direct-KDE active only when forward_path='auto'/'direct'
|
|
304
|
+
direct_M=2048, # direct-KDE evaluation grid size
|
|
305
|
+
forward_path='smooth', # 'smooth' (default, strictly differentiable) |
|
|
306
|
+
# 'auto' (direct-KDE at n≤25K, surrogate gradient) |
|
|
307
|
+
# 'direct' (force direct-KDE, accuracy benchmarks)
|
|
308
|
+
safe_backward=False, # clamp IFT denominator near bifurcations
|
|
307
309
|
)
|
|
308
310
|
```
|
|
309
311
|
|
|
@@ -343,7 +345,7 @@ By default (`use_richardson=True`), DCB runs a second bisection at G/2=8192 and
|
|
|
343
345
|
|
|
344
346
|
- **`compile=True` on MPS**: blocked by float64 in `_refine_hcrit` fallback (fix in v0.1.7)
|
|
345
347
|
- **`compile=True` on CUDA with Python 3.12**: requires torch ≥ 2.4 or Python ≤ 3.11
|
|
346
|
-
- **`gradcheck
|
|
348
|
+
- **`gradcheck`**: passes with the default `forward_path='smooth'`; the default is strictly differentiable at all n. Opt into `forward_path='auto'` only for forward-only accuracy benchmarks (surrogate gradient at n≤25K)
|
|
347
349
|
- **n > 100M**: requires streaming histogram (not yet public API); use `max_n_exact=1_000_000` sketch as workaround
|
|
348
350
|
|
|
349
351
|
## Confirmed Experimental Results
|
|
@@ -73,15 +73,17 @@ Cumulative speedup vs v0.1.4 on CPU: 1.1× (100K), 1.7× (1M), **4.2× (10M)**.
|
|
|
73
73
|
|
|
74
74
|
```python
|
|
75
75
|
DCBLayer(
|
|
76
|
-
target_modes=1,
|
|
77
|
-
use_fft=True,
|
|
78
|
-
max_n_exact=None,
|
|
79
|
-
G_min=16384,
|
|
80
|
-
use_richardson=True,
|
|
81
|
-
direct_n_max=25_000,
|
|
82
|
-
direct_M=2048,
|
|
83
|
-
|
|
84
|
-
|
|
76
|
+
target_modes=1, # target number of modes (default 1)
|
|
77
|
+
use_fft=True, # FFT path for n > 50K (default True)
|
|
78
|
+
max_n_exact=None, # sketch above this n (None = always exact)
|
|
79
|
+
G_min=16384, # minimum FFT histogram bins (accuracy ↑ with G)
|
|
80
|
+
use_richardson="auto", # Richardson on CPU, off on GPU (30% accuracy gain on CPU)
|
|
81
|
+
direct_n_max=25_000, # direct-KDE active only when forward_path='auto'/'direct'
|
|
82
|
+
direct_M=2048, # direct-KDE evaluation grid size
|
|
83
|
+
forward_path='smooth', # 'smooth' (default, strictly differentiable) |
|
|
84
|
+
# 'auto' (direct-KDE at n≤25K, surrogate gradient) |
|
|
85
|
+
# 'direct' (force direct-KDE, accuracy benchmarks)
|
|
86
|
+
safe_backward=False, # clamp IFT denominator near bifurcations
|
|
85
87
|
)
|
|
86
88
|
```
|
|
87
89
|
|
|
@@ -121,7 +123,7 @@ By default (`use_richardson=True`), DCB runs a second bisection at G/2=8192 and
|
|
|
121
123
|
|
|
122
124
|
- **`compile=True` on MPS**: blocked by float64 in `_refine_hcrit` fallback (fix in v0.1.7)
|
|
123
125
|
- **`compile=True` on CUDA with Python 3.12**: requires torch ≥ 2.4 or Python ≤ 3.11
|
|
124
|
-
- **`gradcheck
|
|
126
|
+
- **`gradcheck`**: passes with the default `forward_path='smooth'`; the default is strictly differentiable at all n. Opt into `forward_path='auto'` only for forward-only accuracy benchmarks (surrogate gradient at n≤25K)
|
|
125
127
|
- **n > 100M**: requires streaming histogram (not yet public API); use `max_n_exact=1_000_000` sketch as workaround
|
|
126
128
|
|
|
127
129
|
## Confirmed Experimental Results
|
|
@@ -137,12 +137,15 @@ def mode_count_from_C(
|
|
|
137
137
|
if C.numel() == 0:
|
|
138
138
|
return 1 # degenerate single-point distribution
|
|
139
139
|
|
|
140
|
+
# Caller must pass C computed with fft_dtype=torch.float64 (complex128).
|
|
141
|
+
# Float32 histogram errors (~1e-4 relative) create spurious sign changes
|
|
142
|
+
# near small h_crit that cannot be filtered without killing genuine lobes.
|
|
143
|
+
# A relative threshold of 1e-12 removes machine-epsilon edge-bin noise
|
|
144
|
+
# (empty histogram bins at domain boundaries contribute ~eps_f64 to f′).
|
|
140
145
|
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
141
146
|
f_prime = torch.fft.irfft(C * K_deriv, n=N).real[:G]
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
# one host sync (.any()) and two allocations (nonzero_mask, s) per call.
|
|
145
|
-
return int(((f_prime[:-1] > 0) & (f_prime[1:] < 0)).sum().item())
|
|
147
|
+
thresh = f_prime.abs().max() * 1e-12
|
|
148
|
+
return int(((f_prime[:-1] > thresh) & (f_prime[1:] < -thresh)).sum().item())
|
|
146
149
|
|
|
147
150
|
|
|
148
151
|
def mode_count_from_C_batch(
|
|
@@ -186,12 +189,15 @@ def mode_count_from_C_batch(
|
|
|
186
189
|
else:
|
|
187
190
|
h_t = h_batch.to(dtype=omega.dtype, device=omega.device) # (B,)
|
|
188
191
|
|
|
192
|
+
# Caller must pass C in float64 (complex128) — same rationale as mode_count_from_C.
|
|
189
193
|
# Build (B, M) kernel matrix in one vectorised op
|
|
190
194
|
omega_h = omega.unsqueeze(0) * h_t.unsqueeze(1) # (B, M)
|
|
191
195
|
K_batch = 1j * omega.unsqueeze(0) * torch.exp(-0.5 * omega_h ** 2) # (B, M)
|
|
192
196
|
# One batched irfft dispatch instead of B separate calls
|
|
193
|
-
f_prime_batch = torch.fft.irfft(C.unsqueeze(0) * K_batch, n=N)[:, :G]
|
|
194
|
-
|
|
197
|
+
f_prime_batch = torch.fft.irfft(C.unsqueeze(0) * K_batch, n=N)[:, :G] # (B, G)
|
|
198
|
+
# Per-bandwidth threshold: remove edge-bin machine-epsilon noise.
|
|
199
|
+
thresholds = f_prime_batch.abs().amax(dim=1, keepdim=True) * 1e-12 # (B, 1)
|
|
200
|
+
return ((f_prime_batch[:, :-1] > thresholds) & (f_prime_batch[:, 1:] < -thresholds)).sum(dim=1)
|
|
195
201
|
|
|
196
202
|
|
|
197
203
|
def fft_mode_count(
|
|
@@ -14,16 +14,16 @@ the computational graph of the iterative solver and maintaining O(1) memory
|
|
|
14
14
|
cost relative to the number of solver iterations. Hyperparameters ε and τ
|
|
15
15
|
may be supplied explicitly or computed adaptively via `dcb.utils`.
|
|
16
16
|
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
17
|
+
Strict differentiability
|
|
18
|
+
------------------------
|
|
19
|
+
``forward_path='smooth'`` (the default) forces both the forward and backward
|
|
20
|
+
to use the same smooth M̃ surrogate at all n — ``torch.autograd.gradcheck``
|
|
21
|
+
passes and ∂h_crit/∂X is the exact IFT gradient of the computed h_crit.
|
|
22
|
+
|
|
23
|
+
``forward_path='auto'`` opts into the direct-KDE forward at n ≤ direct_n_max
|
|
24
|
+
for zero-bias accuracy, but forward and backward then use different implicit
|
|
25
|
+
functions — gradcheck will fail and the gradient is a surrogate. Only use
|
|
26
|
+
'auto' for forward-only inference benchmarks.
|
|
27
27
|
"""
|
|
28
28
|
|
|
29
29
|
from __future__ import annotations
|
|
@@ -189,26 +189,17 @@ class DCBLayer(nn.Module):
|
|
|
189
189
|
forward_path : str
|
|
190
190
|
Controls forward-pass routing. One of:
|
|
191
191
|
|
|
192
|
-
- ``'
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
-----
|
|
204
|
-
Forward/backward path at n <= direct_n_max:
|
|
205
|
-
The forward pass uses direct KDE (no histogram) for accuracy.
|
|
206
|
-
The backward pass uses the smooth IFT on M̃ (soft mode count) at all n.
|
|
207
|
-
These are different implicit functions, so ``torch.autograd.gradcheck``
|
|
208
|
-
will fail for n <= direct_n_max. Use ``forward_path='smooth'`` to
|
|
209
|
-
force the smooth path at all n — gradcheck will then pass.
|
|
210
|
-
For ML training this mismatch is correct by design: the smooth
|
|
211
|
-
gradient is the appropriate object for gradient descent.
|
|
192
|
+
- ``'smooth'`` **(default)**: always use the FFT histogram or
|
|
193
|
+
chunked-KDE path. Both forward and backward use the same smooth
|
|
194
|
+
M̃ surrogate — ``torch.autograd.gradcheck`` passes and gradients
|
|
195
|
+
are the exact IFT derivatives of the computed h_crit.
|
|
196
|
+
- ``'auto'``: use direct-KDE for n ≤ direct_n_max (zero histogram
|
|
197
|
+
bias), FFT for n > 50K, chunked-KDE otherwise. Forward and
|
|
198
|
+
backward use different implicit functions at n ≤ direct_n_max,
|
|
199
|
+
so gradcheck will fail and the gradient is a surrogate.
|
|
200
|
+
Use only for forward-only inference / accuracy benchmarks.
|
|
201
|
+
- ``'direct'``: force direct-KDE at all n (accuracy benchmark only;
|
|
202
|
+
slow for large n; surrogate gradient at all n).
|
|
212
203
|
|
|
213
204
|
Examples
|
|
214
205
|
--------
|
|
@@ -244,7 +235,7 @@ class DCBLayer(nn.Module):
|
|
|
244
235
|
use_compile: bool = False,
|
|
245
236
|
direct_n_max: int = 25_000,
|
|
246
237
|
direct_M: int = 2048,
|
|
247
|
-
forward_path: str = 'auto',
|
|
238
|
+
forward_path: str = 'smooth',
|
|
248
239
|
):
|
|
249
240
|
super().__init__()
|
|
250
241
|
if forward_path not in ('auto', 'smooth', 'direct'):
|
|
@@ -238,28 +238,36 @@ def find_h_crit_hard(
|
|
|
238
238
|
|
|
239
239
|
with torch.no_grad():
|
|
240
240
|
# Worker 1: precomputed C — hoist histogram + rfft out of bisection.
|
|
241
|
-
# Worker 3: float32 FFT
|
|
241
|
+
# Worker 3: float32 FFT for _refine_hcrit (2× faster; refinement is
|
|
242
|
+
# accuracy-insensitive to histogram dtype). Mode counting requires
|
|
243
|
+
# float64 histogram precision: float32 FFT errors (~1e-4 relative)
|
|
244
|
+
# create spurious sign changes near small h_crit that a relative
|
|
245
|
+
# threshold cannot remove without also killing genuine small lobes.
|
|
242
246
|
C, omega, _domain = precompute_fft(
|
|
243
247
|
X, G=G_fft, domain=_domain, pad_factor=pad_factor, fft_dtype=fft_dtype,
|
|
244
248
|
)
|
|
249
|
+
C_mc, omega_mc, _ = precompute_fft(
|
|
250
|
+
X, G=G_fft, domain=_domain, pad_factor=pad_factor,
|
|
251
|
+
fft_dtype=torch.float64,
|
|
252
|
+
)
|
|
245
253
|
|
|
246
254
|
# Verify bracket using FFT mode count on full X
|
|
247
|
-
count_lo = mode_count_from_C(
|
|
255
|
+
count_lo = mode_count_from_C(C_mc, omega_mc, h_lo, G_fft, N)
|
|
248
256
|
if count_lo <= target_modes:
|
|
249
257
|
h_lo_try = h_lo
|
|
250
258
|
for _ in range(30):
|
|
251
259
|
h_lo_try *= 0.5
|
|
252
260
|
if h_lo_try < 1e-10:
|
|
253
261
|
break
|
|
254
|
-
if mode_count_from_C(
|
|
262
|
+
if mode_count_from_C(C_mc, omega_mc, h_lo_try, G_fft, N) > target_modes:
|
|
255
263
|
h_lo = h_lo_try
|
|
256
264
|
break
|
|
257
265
|
|
|
258
|
-
count_hi = mode_count_from_C(
|
|
266
|
+
count_hi = mode_count_from_C(C_mc, omega_mc, h_hi, G_fft, N)
|
|
259
267
|
if count_hi > target_modes:
|
|
260
268
|
for _ in range(30):
|
|
261
269
|
h_hi *= 2.0
|
|
262
|
-
if mode_count_from_C(
|
|
270
|
+
if mode_count_from_C(C_mc, omega_mc, h_hi, G_fft, N) <= target_modes:
|
|
263
271
|
break
|
|
264
272
|
|
|
265
273
|
# Compile-friendly trisection: lo/hi are 0-d tensors, no .item()
|
|
@@ -267,8 +275,8 @@ def find_h_crit_hard(
|
|
|
267
275
|
# than enough for any bracket). torch.where replaces the Python
|
|
268
276
|
# if/elif/else so the loop body is a pure tensor computation that
|
|
269
277
|
# torch.compile(mode="reduce-overhead") can trace and replay.
|
|
270
|
-
_dtype =
|
|
271
|
-
_dev =
|
|
278
|
+
_dtype = omega_mc.dtype
|
|
279
|
+
_dev = C_mc.device
|
|
272
280
|
lo_t = torch.tensor(h_lo, dtype=_dtype, device=_dev)
|
|
273
281
|
hi_t = torch.tensor(h_hi, dtype=_dtype, device=_dev)
|
|
274
282
|
_target = torch.tensor(target_modes, dtype=torch.long, device=_dev)
|
|
@@ -277,7 +285,7 @@ def find_h_crit_hard(
|
|
|
277
285
|
h1 = lo_t + width * (1.0 / 3.0)
|
|
278
286
|
h2 = lo_t + width * (2.0 / 3.0)
|
|
279
287
|
counts = mode_count_from_C_batch(
|
|
280
|
-
|
|
288
|
+
C_mc, omega_mc, torch.stack([h1, h2]), G_fft, N
|
|
281
289
|
)
|
|
282
290
|
c1 = counts[0]
|
|
283
291
|
c2 = counts[1]
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "diffcb"
|
|
7
|
-
version = "0.1.7"
|
|
7
|
+
version = "0.1.9"
|
|
8
8
|
description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { file = "LICENSE" }
|
|
@@ -17,37 +17,29 @@ from dcb.layer import DCBLayer
|
|
|
17
17
|
# ---------------------------------------------------------------------------
|
|
18
18
|
|
|
19
19
|
def test_gradcheck_smooth_path():
|
|
20
|
-
"""gradcheck passes
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
the
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
2. The IFT backward uses smooth M̃ evaluated at this h_crit, which is
|
|
31
|
-
not perfectly aligned with the bisection steps.
|
|
32
|
-
3. These two effects cause O(1%) discrepancies between the analytical
|
|
33
|
-
and numerical Jacobians — well within the intended ML-training use case,
|
|
34
|
-
but requiring loose gradcheck tolerances to pass reliably.
|
|
20
|
+
"""gradcheck passes with the default DCBLayer (forward_path='smooth').
|
|
21
|
+
|
|
22
|
+
forward_path='smooth' (the default since v0.1.7) forces both the forward
|
|
23
|
+
pass and the IFT backward to use the same smooth M̃ surrogate — they are
|
|
24
|
+
the same implicit function, so gradcheck is internally consistent.
|
|
25
|
+
|
|
26
|
+
Tolerances are loose (atol=0.05, rtol=0.2) because the hard bisection
|
|
27
|
+
root-finder introduces O(1%) quantisation noise between the analytical
|
|
28
|
+
IFT gradient and the finite-difference Jacobian. This is the expected
|
|
29
|
+
residual for the discrete-step bisection and does not affect ML training.
|
|
35
30
|
"""
|
|
36
31
|
torch.manual_seed(42)
|
|
37
32
|
X = torch.cat([torch.randn(30) - 1.0, torch.randn(30) + 1.0]).double()
|
|
38
33
|
X.requires_grad_(True)
|
|
34
|
+
# Default forward_path='smooth' — strictly differentiable at all n
|
|
39
35
|
layer = DCBLayer(
|
|
40
|
-
use_fft=False,
|
|
41
|
-
|
|
42
|
-
use_richardson=False, # Richardson adds a second backward; skip for gradcheck
|
|
36
|
+
use_fft=False,
|
|
37
|
+
use_richardson=False, # skip Richardson for gradcheck clarity
|
|
43
38
|
)
|
|
44
39
|
|
|
45
|
-
# eps=1e-3 for finite differences; atol/rtol are loose because the smooth
|
|
46
|
-
# IFT gradient and FD Jacobian can differ by ~1–5% due to bisection
|
|
47
|
-
# quantisation — verified reliable across 20 seeds.
|
|
48
40
|
result = torch.autograd.gradcheck(layer, (X,), eps=1e-3, atol=0.05, rtol=0.2,
|
|
49
41
|
raise_exception=True)
|
|
50
|
-
assert result, "gradcheck failed with forward_path='smooth'"
|
|
42
|
+
assert result, "gradcheck failed with default DCBLayer (forward_path='smooth')"
|
|
51
43
|
|
|
52
44
|
|
|
53
45
|
# ---------------------------------------------------------------------------
|
|
@@ -121,24 +121,27 @@ def test_dcblayer_state_dict():
|
|
|
121
121
|
|
|
122
122
|
@pytest.mark.xfail(
|
|
123
123
|
reason=(
|
|
124
|
-
"
|
|
125
|
-
"
|
|
126
|
-
"Qualitative correctness verified in test_ift_gradient_matches_finite_diff
|
|
124
|
+
"Hard bisection introduces quantisation noise even with forward_path='smooth'. "
|
|
125
|
+
"atol=1e-3 is too strict for the bisection step-function discretisation. "
|
|
126
|
+
"Qualitative correctness verified in test_ift_gradient_matches_finite_diff; "
|
|
127
|
+
"loose-tolerance gradcheck passes in tests/test_gradcheck.py."
|
|
127
128
|
),
|
|
128
129
|
strict=False,
|
|
129
130
|
)
|
|
130
131
|
def test_dcblayer_gradcheck():
|
|
131
132
|
"""torch.autograd.gradcheck with double precision, eps=1e-4, atol=1e-3.
|
|
132
133
|
|
|
133
|
-
Uses bimodal X with n=30 (small for speed).
|
|
134
|
-
|
|
134
|
+
Uses bimodal X with n=30 (small for speed). Default forward_path='smooth'
|
|
135
|
+
means forward and backward use the same M̃ surrogate — the remaining xfail
|
|
136
|
+
is due to bisection quantisation noise (step-function root-finding), not a
|
|
137
|
+
forward/backward path mismatch. Loose-tolerance gradcheck passes separately.
|
|
135
138
|
"""
|
|
136
139
|
torch.manual_seed(42)
|
|
137
140
|
n = 30
|
|
138
141
|
X_base = torch.cat([torch.randn(15) - 1.0, torch.randn(15) + 1.0])
|
|
139
142
|
X = X_base.double().requires_grad_(True)
|
|
140
143
|
|
|
141
|
-
layer = DCBLayer(target_modes=1, G=64)
|
|
144
|
+
layer = DCBLayer(target_modes=1, G=64) # default forward_path='smooth'
|
|
142
145
|
|
|
143
146
|
def fn(x):
|
|
144
147
|
return layer(x)
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""
|
|
2
|
+
v018_local_bench.py
|
|
3
|
+
Comprehensive benchmark for diffcb v0.1.8 (forward_path='smooth' default).
|
|
4
|
+
Produces 3 CSV files and prints summary tables.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
import math
|
|
10
|
+
import subprocess
|
|
11
|
+
import tempfile
|
|
12
|
+
import statistics
|
|
13
|
+
import csv
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from dcb import DCBLayer
|
|
17
|
+
|
|
18
|
+
RESULTS_DIR = "/Users/h/Downloads/DCB-workspace/02_projects/01_dcb_proposal/04_analysis/results"
|
|
19
|
+
os.makedirs(RESULTS_DIR, exist_ok=True)
|
|
20
|
+
|
|
21
|
+
SPEED_NS = [1_000, 2_000, 5_000, 10_000, 25_000, 50_000, 100_000,
|
|
22
|
+
500_000, 1_000_000, 5_000_000, 10_000_000]
|
|
23
|
+
ACCURACY_NS = [1_000, 5_000, 10_000, 25_000, 50_000, 100_000, 500_000, 1_000_000]
|
|
24
|
+
|
|
25
|
+
DEVICES = ['cpu']
|
|
26
|
+
if torch.backends.mps.is_available():
|
|
27
|
+
DEVICES.append('mps')
|
|
28
|
+
print("MPS is available — will benchmark both CPU and MPS.")
|
|
29
|
+
else:
|
|
30
|
+
print("MPS not available — benchmarking CPU only.")
|
|
31
|
+
|
|
32
|
+
R_BINARY = '/usr/local/bin/Rscript'
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ─────────────────────────────────────────────────────────
|
|
36
|
+
# Calibration: T_fft_one_ms per device
|
|
37
|
+
# ─────────────────────────────────────────────────────────
|
|
38
|
+
|
|
39
|
+
def calibrate_fft(device: str) -> float:
|
|
40
|
+
"""Return mean time in ms for one rfft(ones(16384)) call."""
|
|
41
|
+
x = torch.ones(16384, dtype=torch.float32, device=device)
|
|
42
|
+
WARMUP = 500
|
|
43
|
+
TIMED = 5000
|
|
44
|
+
|
|
45
|
+
# Warm-up
|
|
46
|
+
for _ in range(WARMUP):
|
|
47
|
+
torch.fft.rfft(x)
|
|
48
|
+
|
|
49
|
+
if device == 'mps':
|
|
50
|
+
torch.mps.synchronize()
|
|
51
|
+
t0 = time.perf_counter()
|
|
52
|
+
for _ in range(TIMED):
|
|
53
|
+
torch.fft.rfft(x)
|
|
54
|
+
torch.mps.synchronize()
|
|
55
|
+
t1 = time.perf_counter()
|
|
56
|
+
else:
|
|
57
|
+
t0 = time.perf_counter()
|
|
58
|
+
for _ in range(TIMED):
|
|
59
|
+
torch.fft.rfft(x)
|
|
60
|
+
t1 = time.perf_counter()
|
|
61
|
+
|
|
62
|
+
total_ms = (t1 - t0) * 1000.0
|
|
63
|
+
return total_ms / TIMED
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
print("\n=== Calibrating FFT units ===")
|
|
67
|
+
T_fft = {}
|
|
68
|
+
for dev in DEVICES:
|
|
69
|
+
T_fft[dev] = calibrate_fft(dev)
|
|
70
|
+
print(f" {dev}: T_fft_one_ms = {T_fft[dev]:.6f} ms")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# ─────────────────────────────────────────────────────────
|
|
74
|
+
# Part 1 — Speed benchmark
|
|
75
|
+
# ─────────────────────────────────────────────────────────
|
|
76
|
+
|
|
77
|
+
print("\n=== Part 1: Speed Benchmark ===")
|
|
78
|
+
speed_rows = []
|
|
79
|
+
|
|
80
|
+
for device in DEVICES:
|
|
81
|
+
# Build a shared DCBLayer for this device (reuse across n — it's stateless)
|
|
82
|
+
layer = DCBLayer()
|
|
83
|
+
if device == 'mps':
|
|
84
|
+
layer = layer.to(device)
|
|
85
|
+
|
|
86
|
+
T_fft_one_ms = T_fft[device]
|
|
87
|
+
|
|
88
|
+
for n in SPEED_NS:
|
|
89
|
+
if n >= 5_000_000 and device == 'mps':
|
|
90
|
+
print(f" [mps, n={n:,}] — skipping (too large for MPS RAM guard)")
|
|
91
|
+
speed_rows.append({
|
|
92
|
+
'device': device, 'n': n,
|
|
93
|
+
't_median_ms': float('nan'), 't_mean_ms': float('nan'),
|
|
94
|
+
't_std_ms': float('nan'), 'throughput_ns': float('nan'),
|
|
95
|
+
'fft_norm_cost': float('nan'), 'T_fft_one_ms': T_fft_one_ms,
|
|
96
|
+
})
|
|
97
|
+
continue
|
|
98
|
+
|
|
99
|
+
try:
|
|
100
|
+
X = torch.randn(n, device=device)
|
|
101
|
+
|
|
102
|
+
# 3 warm-up calls
|
|
103
|
+
for _ in range(3):
|
|
104
|
+
_ = layer(X)
|
|
105
|
+
if device == 'mps':
|
|
106
|
+
torch.mps.synchronize()
|
|
107
|
+
|
|
108
|
+
# 15 timed calls
|
|
109
|
+
times_ms = []
|
|
110
|
+
REPS = 15
|
|
111
|
+
for _ in range(REPS):
|
|
112
|
+
if device == 'mps':
|
|
113
|
+
torch.mps.synchronize()
|
|
114
|
+
t0 = time.perf_counter()
|
|
115
|
+
layer(X)
|
|
116
|
+
torch.mps.synchronize()
|
|
117
|
+
t1 = time.perf_counter()
|
|
118
|
+
else:
|
|
119
|
+
t0 = time.perf_counter()
|
|
120
|
+
layer(X)
|
|
121
|
+
t1 = time.perf_counter()
|
|
122
|
+
times_ms.append((t1 - t0) * 1000.0)
|
|
123
|
+
|
|
124
|
+
t_median_ms = statistics.median(times_ms)
|
|
125
|
+
t_mean_ms = statistics.mean(times_ms)
|
|
126
|
+
t_std_ms = statistics.stdev(times_ms) if len(times_ms) > 1 else 0.0
|
|
127
|
+
throughput_ns = n / (t_median_ms / 1000.0)
|
|
128
|
+
fft_norm_cost = t_median_ms / T_fft_one_ms
|
|
129
|
+
|
|
130
|
+
speed_rows.append({
|
|
131
|
+
'device': device, 'n': n,
|
|
132
|
+
't_median_ms': t_median_ms, 't_mean_ms': t_mean_ms,
|
|
133
|
+
't_std_ms': t_std_ms, 'throughput_ns': throughput_ns,
|
|
134
|
+
'fft_norm_cost': fft_norm_cost, 'T_fft_one_ms': T_fft_one_ms,
|
|
135
|
+
})
|
|
136
|
+
print(f" [{device}, n={n:>10,}] median={t_median_ms:.3f}ms tput={throughput_ns/1e6:.2f}M/s fft_cost={fft_norm_cost:.1f}")
|
|
137
|
+
|
|
138
|
+
except RuntimeError as e:
|
|
139
|
+
print(f" [{device}, n={n:,}] RuntimeError (OOM?): {e}")
|
|
140
|
+
speed_rows.append({
|
|
141
|
+
'device': device, 'n': n,
|
|
142
|
+
't_median_ms': float('nan'), 't_mean_ms': float('nan'),
|
|
143
|
+
't_std_ms': float('nan'), 'throughput_ns': float('nan'),
|
|
144
|
+
'fft_norm_cost': float('nan'), 'T_fft_one_ms': T_fft_one_ms,
|
|
145
|
+
})
|
|
146
|
+
|
|
147
|
+
# Write speed CSV
|
|
148
|
+
speed_csv = os.path.join(RESULTS_DIR, 'v018_local_speed.csv')
|
|
149
|
+
speed_fields = ['device', 'n', 't_median_ms', 't_mean_ms', 't_std_ms',
|
|
150
|
+
'throughput_ns', 'fft_norm_cost', 'T_fft_one_ms']
|
|
151
|
+
with open(speed_csv, 'w', newline='') as f:
|
|
152
|
+
w = csv.DictWriter(f, fieldnames=speed_fields)
|
|
153
|
+
w.writeheader()
|
|
154
|
+
w.writerows(speed_rows)
|
|
155
|
+
print(f"\nSpeed CSV saved: {speed_csv}")
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# ─────────────────────────────────────────────────────────
|
|
159
|
+
# Helper: call R bw.crit on a tensor
|
|
160
|
+
# ─────────────────────────────────────────────────────────
|
|
161
|
+
|
|
162
|
+
def r_bwcrit(X_tensor: torch.Tensor) -> float:
|
|
163
|
+
"""Write tensor to temp CSV and call R bw.crit. Returns NaN on failure."""
|
|
164
|
+
try:
|
|
165
|
+
with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f:
|
|
166
|
+
f.write('x\n')
|
|
167
|
+
for v in X_tensor.tolist():
|
|
168
|
+
f.write(f'{v:.10f}\n')
|
|
169
|
+
fname = f.name
|
|
170
|
+
result = subprocess.run(
|
|
171
|
+
[R_BINARY, '--vanilla', '-e',
|
|
172
|
+
f'library(multimode); x<-read.csv("{fname}")$x; cat(bw.crit(x,mod0=1L))'],
|
|
173
|
+
capture_output=True, text=True, timeout=300
|
|
174
|
+
)
|
|
175
|
+
os.unlink(fname)
|
|
176
|
+
if result.returncode != 0 or not result.stdout.strip():
|
|
177
|
+
print(f" WARNING: R call failed. stderr: {result.stderr.strip()[:200]}")
|
|
178
|
+
return float('nan')
|
|
179
|
+
return float(result.stdout.strip())
|
|
180
|
+
except Exception as e:
|
|
181
|
+
print(f" WARNING: R call exception: {e}")
|
|
182
|
+
try:
|
|
183
|
+
os.unlink(fname)
|
|
184
|
+
except Exception:
|
|
185
|
+
pass
|
|
186
|
+
return float('nan')
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
# ─────────────────────────────────────────────────────────
|
|
190
|
+
# Part 2 — Accuracy: Independent-sample
|
|
191
|
+
# ─────────────────────────────────────────────────────────
|
|
192
|
+
|
|
193
|
+
print("\n=== Part 2: Accuracy — Independent-sample ===")
|
|
194
|
+
indep_rows = []
|
|
195
|
+
|
|
196
|
+
for n in ACCURACY_NS:
|
|
197
|
+
for seed in range(20):
|
|
198
|
+
torch.manual_seed(seed)
|
|
199
|
+
X = torch.randn(n)
|
|
200
|
+
|
|
201
|
+
# DCB
|
|
202
|
+
try:
|
|
203
|
+
layer_cpu = DCBLayer()
|
|
204
|
+
h_dcb = layer_cpu(X).item()
|
|
205
|
+
except Exception as e:
|
|
206
|
+
print(f" WARNING: DCB failed n={n}, seed={seed}: {e}")
|
|
207
|
+
h_dcb = float('nan')
|
|
208
|
+
|
|
209
|
+
# R (skip if n > 1_000_000)
|
|
210
|
+
if n > 1_000_000:
|
|
211
|
+
h_r = float('nan')
|
|
212
|
+
err_pct = float('nan')
|
|
213
|
+
else:
|
|
214
|
+
h_r = r_bwcrit(X)
|
|
215
|
+
if math.isnan(h_r) or math.isnan(h_dcb):
|
|
216
|
+
err_pct = float('nan')
|
|
217
|
+
else:
|
|
218
|
+
err_pct = abs(h_dcb - h_r) / h_r * 100.0
|
|
219
|
+
|
|
220
|
+
indep_rows.append({'n': n, 'seed': seed, 'h_dcb': h_dcb,
|
|
221
|
+
'h_r': h_r, 'err_pct': err_pct})
|
|
222
|
+
|
|
223
|
+
valid = [r for r in indep_rows if r['n'] == n and not math.isnan(r.get('err_pct', float('nan')))]
|
|
224
|
+
if valid:
|
|
225
|
+
mean_e = statistics.mean(r['err_pct'] for r in valid)
|
|
226
|
+
print(f" n={n:>9,}: mean_err={mean_e:.4f}% over {len(valid)} seeds with R comparison")
|
|
227
|
+
else:
|
|
228
|
+
dcb_vals = [r['h_dcb'] for r in indep_rows if r['n'] == n]
|
|
229
|
+
print(f" n={n:>9,}: no R comparison (skipped); h_dcb range [{min(dcb_vals):.4f}, {max(dcb_vals):.4f}]")
|
|
230
|
+
|
|
231
|
+
indep_csv = os.path.join(RESULTS_DIR, 'v018_local_accuracy_indep.csv')
|
|
232
|
+
with open(indep_csv, 'w', newline='') as f:
|
|
233
|
+
w = csv.DictWriter(f, fieldnames=['n', 'seed', 'h_dcb', 'h_r', 'err_pct'])
|
|
234
|
+
w.writeheader()
|
|
235
|
+
w.writerows(indep_rows)
|
|
236
|
+
print(f"\nIndep accuracy CSV saved: {indep_csv}")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# ─────────────────────────────────────────────────────────
|
|
240
|
+
# Part 3 — Accuracy: Same-sample
|
|
241
|
+
# ─────────────────────────────────────────────────────────
|
|
242
|
+
|
|
243
|
+
print("\n=== Part 3: Accuracy — Same-sample ===")
|
|
244
|
+
same_rows = []
|
|
245
|
+
|
|
246
|
+
SAME_NS = [1_000, 5_000, 10_000, 25_000, 50_000, 100_000, 500_000, 1_000_000]
|
|
247
|
+
|
|
248
|
+
for n in SAME_NS:
|
|
249
|
+
for seed in range(10):
|
|
250
|
+
torch.manual_seed(seed)
|
|
251
|
+
X = torch.randn(n)
|
|
252
|
+
|
|
253
|
+
# Write to temp CSV once — both DCB and R use same data
|
|
254
|
+
try:
|
|
255
|
+
with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f:
|
|
256
|
+
f.write('x\n')
|
|
257
|
+
for v in X.tolist():
|
|
258
|
+
f.write(f'{v:.10f}\n')
|
|
259
|
+
fname = f.name
|
|
260
|
+
|
|
261
|
+
# DCB
|
|
262
|
+
try:
|
|
263
|
+
layer_cpu = DCBLayer()
|
|
264
|
+
h_dcb = layer_cpu(X).item()
|
|
265
|
+
except Exception as e:
|
|
266
|
+
print(f" WARNING: DCB failed n={n}, seed={seed}: {e}")
|
|
267
|
+
h_dcb = float('nan')
|
|
268
|
+
|
|
269
|
+
# R (using same file)
|
|
270
|
+
try:
|
|
271
|
+
result = subprocess.run(
|
|
272
|
+
[R_BINARY, '--vanilla', '-e',
|
|
273
|
+
f'library(multimode); x<-read.csv("{fname}")$x; cat(bw.crit(x,mod0=1L))'],
|
|
274
|
+
capture_output=True, text=True, timeout=300
|
|
275
|
+
)
|
|
276
|
+
if result.returncode != 0 or not result.stdout.strip():
|
|
277
|
+
print(f" WARNING: R call failed n={n}, seed={seed}. stderr: {result.stderr.strip()[:200]}")
|
|
278
|
+
h_r = float('nan')
|
|
279
|
+
else:
|
|
280
|
+
h_r = float(result.stdout.strip())
|
|
281
|
+
except Exception as e:
|
|
282
|
+
print(f" WARNING: R call exception n={n}, seed={seed}: {e}")
|
|
283
|
+
h_r = float('nan')
|
|
284
|
+
|
|
285
|
+
os.unlink(fname)
|
|
286
|
+
|
|
287
|
+
except Exception as e:
|
|
288
|
+
print(f" WARNING: Temp file error n={n}, seed={seed}: {e}")
|
|
289
|
+
h_dcb = float('nan')
|
|
290
|
+
h_r = float('nan')
|
|
291
|
+
|
|
292
|
+
if math.isnan(h_r) or math.isnan(h_dcb):
|
|
293
|
+
err_pct = float('nan')
|
|
294
|
+
else:
|
|
295
|
+
err_pct = abs(h_dcb - h_r) / h_r * 100.0
|
|
296
|
+
|
|
297
|
+
same_rows.append({'n': n, 'seed': seed, 'h_dcb': h_dcb,
|
|
298
|
+
'h_r': h_r, 'err_pct': err_pct})
|
|
299
|
+
|
|
300
|
+
valid = [r for r in same_rows if r['n'] == n and not math.isnan(r.get('err_pct', float('nan')))]
|
|
301
|
+
if valid:
|
|
302
|
+
mean_e = statistics.mean(r['err_pct'] for r in valid)
|
|
303
|
+
print(f" n={n:>9,}: mean_err={mean_e:.4f}% over {len(valid)} seeds")
|
|
304
|
+
else:
|
|
305
|
+
print(f" n={n:>9,}: no valid comparisons")
|
|
306
|
+
|
|
307
|
+
same_csv = os.path.join(RESULTS_DIR, 'v018_local_accuracy_same.csv')
|
|
308
|
+
with open(same_csv, 'w', newline='') as f:
|
|
309
|
+
w = csv.DictWriter(f, fieldnames=['n', 'seed', 'h_dcb', 'h_r', 'err_pct'])
|
|
310
|
+
w.writeheader()
|
|
311
|
+
w.writerows(same_rows)
|
|
312
|
+
print(f"\nSame-sample accuracy CSV saved: {same_csv}")
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# ─────────────────────────────────────────────────────────
|
|
316
|
+
# Summary Tables
|
|
317
|
+
# ─────────────────────────────────────────────────────────
|
|
318
|
+
|
|
319
|
+
def fmt_float(x, fmt='.3f'):
|
|
320
|
+
if math.isnan(x):
|
|
321
|
+
return 'N/A'
|
|
322
|
+
return format(x, fmt)
|
|
323
|
+
|
|
324
|
+
print("\n" + "="*110)
|
|
325
|
+
print("SUMMARY")
|
|
326
|
+
print("="*110)
|
|
327
|
+
|
|
328
|
+
# Calibration
|
|
329
|
+
print("\n--- Calibration Units ---")
|
|
330
|
+
print(f"{'Device':<8} | {'T_fft_one_ms':>14}")
|
|
331
|
+
print("-" * 28)
|
|
332
|
+
for dev in DEVICES:
|
|
333
|
+
print(f"{dev:<8} | {T_fft[dev]:>14.6f}")
|
|
334
|
+
|
|
335
|
+
# Speed table
|
|
336
|
+
print("\n--- Speed Benchmark ---")
|
|
337
|
+
col_w = 14
|
|
338
|
+
hdr = (f"{'n':>12} | {'CPU t_med(ms)':>{col_w}} | {'CPU fft_cost':>{col_w}} | "
|
|
339
|
+
f"{'CPU n/s':>{col_w}}")
|
|
340
|
+
if 'mps' in DEVICES:
|
|
341
|
+
hdr += (f" | {'MPS t_med(ms)':>{col_w}} | {'MPS fft_cost':>{col_w}} | "
|
|
342
|
+
f"{'MPS n/s':>{col_w}}")
|
|
343
|
+
print(hdr)
|
|
344
|
+
print("-" * len(hdr))
|
|
345
|
+
|
|
346
|
+
speed_by_n = {}
|
|
347
|
+
for r in speed_rows:
|
|
348
|
+
key = (r['n'], r['device'])
|
|
349
|
+
speed_by_n[key] = r
|
|
350
|
+
|
|
351
|
+
for n in SPEED_NS:
|
|
352
|
+
cpu_r = speed_by_n.get((n, 'cpu'), {})
|
|
353
|
+
cpu_med = fmt_float(cpu_r.get('t_median_ms', float('nan')))
|
|
354
|
+
cpu_fft = fmt_float(cpu_r.get('fft_norm_cost', float('nan')), '.1f')
|
|
355
|
+
cpu_tput_raw = cpu_r.get('throughput_ns', float('nan'))
|
|
356
|
+
cpu_tput = 'N/A' if math.isnan(cpu_tput_raw) else f"{cpu_tput_raw/1e6:.2f}M/s"
|
|
357
|
+
|
|
358
|
+
line = f"{n:>12,} | {cpu_med:>{col_w}} | {cpu_fft:>{col_w}} | {cpu_tput:>{col_w}}"
|
|
359
|
+
|
|
360
|
+
if 'mps' in DEVICES:
|
|
361
|
+
mps_r = speed_by_n.get((n, 'mps'), {})
|
|
362
|
+
mps_med = fmt_float(mps_r.get('t_median_ms', float('nan')))
|
|
363
|
+
mps_fft = fmt_float(mps_r.get('fft_norm_cost', float('nan')), '.1f')
|
|
364
|
+
mps_tput_raw = mps_r.get('throughput_ns', float('nan'))
|
|
365
|
+
mps_tput = 'N/A' if math.isnan(mps_tput_raw) else f"{mps_tput_raw/1e6:.2f}M/s"
|
|
366
|
+
line += f" | {mps_med:>{col_w}} | {mps_fft:>{col_w}} | {mps_tput:>{col_w}}"
|
|
367
|
+
|
|
368
|
+
print(line)
|
|
369
|
+
|
|
370
|
+
# Accuracy table
|
|
371
|
+
print("\n--- Accuracy (mean % error vs R bw.crit) ---")
|
|
372
|
+
print(f"{'n':>12} | {'Indep mean_err%':>16} | {'Indep std_err%':>15} | {'Same mean_err%':>15} | {'Same std_err%':>14}")
|
|
373
|
+
print("-" * 90)
|
|
374
|
+
|
|
375
|
+
def acc_stats(rows, n_val):
|
|
376
|
+
valid = [r['err_pct'] for r in rows if r['n'] == n_val and not math.isnan(r.get('err_pct', float('nan')))]
|
|
377
|
+
if not valid:
|
|
378
|
+
return float('nan'), float('nan')
|
|
379
|
+
mean_e = statistics.mean(valid)
|
|
380
|
+
std_e = statistics.stdev(valid) if len(valid) > 1 else 0.0
|
|
381
|
+
return mean_e, std_e
|
|
382
|
+
|
|
383
|
+
for n in ACCURACY_NS:
|
|
384
|
+
indep_mean, indep_std = acc_stats(indep_rows, n)
|
|
385
|
+
same_mean, same_std = acc_stats(same_rows, n)
|
|
386
|
+
print(f"{n:>12,} | {fmt_float(indep_mean, '.4f'):>16} | {fmt_float(indep_std, '.4f'):>15} | "
|
|
387
|
+
f"{fmt_float(same_mean, '.4f'):>15} | {fmt_float(same_std, '.4f'):>14}")
|
|
388
|
+
|
|
389
|
+
print("\n" + "="*110)
|
|
390
|
+
print("Benchmark complete.")
|
|
391
|
+
print(f"CSVs saved to: {RESULTS_DIR}")
|
|
392
|
+
print(f" Speed: v018_local_speed.csv")
|
|
393
|
+
print(f" Accuracy indep: v018_local_accuracy_indep.csv")
|
|
394
|
+
print(f" Accuracy same: v018_local_accuracy_same.csv")
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|