diffcb 0.1.8__tar.gz → 0.1.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffcb-0.1.8 → diffcb-0.1.10}/PKG-INFO +1 -1
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/__init__.py +1 -1
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/fft_kde.py +26 -8
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/solver.py +41 -16
- {diffcb-0.1.8 → diffcb-0.1.10}/pyproject.toml +1 -1
- diffcb-0.1.10/v018_local_bench.py +394 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/.gitignore +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/.zenodo.json +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/LICENSE +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/README.md +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/diagnostics.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/kde.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/layer.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/training.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/dcb/utils.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/notebooks/.gitkeep +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/round24_cumulative_bench.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/round24_v016_test.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/round25_full_range_sweep.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/round25_write_csv.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_gradcheck.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_kde.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_layer.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r18c_denom_audit.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r18c_deprecation_warn.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r19_default_fft.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r19_diagnostics.py +0 -0
- {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_solver.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffcb
|
|
3
|
-
Version: 0.1.8
|
|
3
|
+
Version: 0.1.10
|
|
4
4
|
Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
|
|
5
5
|
Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
6
6
|
Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
@@ -137,12 +137,15 @@ def mode_count_from_C(
|
|
|
137
137
|
if C.numel() == 0:
|
|
138
138
|
return 1 # degenerate single-point distribution
|
|
139
139
|
|
|
140
|
+
# Caller must pass C computed with fft_dtype=torch.float64 (complex128).
|
|
141
|
+
# Float32 histogram errors (~1e-4 relative) create spurious sign changes
|
|
142
|
+
# near small h_crit that cannot be filtered without killing genuine lobes.
|
|
143
|
+
# A relative threshold of 1e-12 removes machine-epsilon edge-bin noise
|
|
144
|
+
# (empty histogram bins at domain boundaries contribute ~eps_f64 to f′).
|
|
140
145
|
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
141
146
|
f_prime = torch.fft.irfft(C * K_deriv, n=N).real[:G]
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
# one host sync (.any()) and two allocations (nonzero_mask, s) per call.
|
|
145
|
-
return int(((f_prime[:-1] > 0) & (f_prime[1:] < 0)).sum().item())
|
|
147
|
+
thresh = f_prime.abs().max() * 1e-12
|
|
148
|
+
return int(((f_prime[:-1] > thresh) & (f_prime[1:] < -thresh)).sum().item())
|
|
146
149
|
|
|
147
150
|
|
|
148
151
|
def mode_count_from_C_batch(
|
|
@@ -186,12 +189,15 @@ def mode_count_from_C_batch(
|
|
|
186
189
|
else:
|
|
187
190
|
h_t = h_batch.to(dtype=omega.dtype, device=omega.device) # (B,)
|
|
188
191
|
|
|
192
|
+
# Caller must pass C in float64 (complex128) — same rationale as mode_count_from_C.
|
|
189
193
|
# Build (B, M) kernel matrix in one vectorised op
|
|
190
194
|
omega_h = omega.unsqueeze(0) * h_t.unsqueeze(1) # (B, M)
|
|
191
195
|
K_batch = 1j * omega.unsqueeze(0) * torch.exp(-0.5 * omega_h ** 2) # (B, M)
|
|
192
196
|
# One batched irfft dispatch instead of B separate calls
|
|
193
|
-
f_prime_batch = torch.fft.irfft(C.unsqueeze(0) * K_batch, n=N)[:, :G]
|
|
194
|
-
|
|
197
|
+
f_prime_batch = torch.fft.irfft(C.unsqueeze(0) * K_batch, n=N)[:, :G] # (B, G)
|
|
198
|
+
# Per-bandwidth threshold: remove edge-bin machine-epsilon noise.
|
|
199
|
+
thresholds = f_prime_batch.abs().amax(dim=1, keepdim=True) * 1e-12 # (B, 1)
|
|
200
|
+
return ((f_prime_batch[:, :-1] > thresholds) & (f_prime_batch[:, 1:] < -thresholds)).sum(dim=1)
|
|
195
201
|
|
|
196
202
|
|
|
197
203
|
def fft_mode_count(
|
|
@@ -355,8 +361,20 @@ def _refine_hcrit(
|
|
|
355
361
|
y_mid = fprime(h_mid)[j].item()
|
|
356
362
|
y_hi = fp_hi_[j].item()
|
|
357
363
|
|
|
358
|
-
#
|
|
359
|
-
#
|
|
364
|
+
# Square-root model: near a saddle-node bifurcation, f'_peak ∝ √(h_crit − h),
|
|
365
|
+
# so f'²_peak is linear in h. Fit y² = a + b·h through 3 points and solve
|
|
366
|
+
# y²=0 → h_crit = −a/b. More accurate than quadratic for Gaussian/unimodal data.
|
|
367
|
+
y_sq = [y_lo ** 2, y_mid ** 2, y_hi ** 2]
|
|
368
|
+
# Ensure y_lo >= 0 (it should be positive by candidate_mask selection)
|
|
369
|
+
if y_lo > 0:
|
|
370
|
+
sqrt_coeffs = np.polyfit([ref_lo, h_mid, ref_hi], y_sq, 1) # degree-1 fit
|
|
371
|
+
b_sqrt, a_sqrt = sqrt_coeffs[0], sqrt_coeffs[1]
|
|
372
|
+
if b_sqrt != 0:
|
|
373
|
+
sqrt_root = -a_sqrt / b_sqrt
|
|
374
|
+
if ref_lo <= sqrt_root <= ref_hi:
|
|
375
|
+
return float(sqrt_root)
|
|
376
|
+
|
|
377
|
+
# Fallback: original quadratic fit
|
|
360
378
|
coeffs = np.polyfit([ref_lo, h_mid, ref_hi], [y_lo, y_mid, y_hi], 2)
|
|
361
379
|
roots = np.roots(coeffs)
|
|
362
380
|
real_roots = [
|
|
@@ -231,35 +231,49 @@ def find_h_crit_hard(
|
|
|
231
231
|
lo_domain = X.min().item() - 3 * sigma
|
|
232
232
|
hi_domain = X.max().item() + 3 * sigma
|
|
233
233
|
data_range = hi_domain - lo_domain
|
|
234
|
-
|
|
234
|
+
# Scale G_min with n^0.4 to keep h_crit/bin_width stable as h_crit shrinks.
|
|
235
|
+
# For Gaussian-regime data h_crit ~ n^{-1/5}, bin_width ~ data_range/G_min,
|
|
236
|
+
# so ratio h_crit/bin_width ~ n^{-1/5} * G_min / data_range. Scaling G_min
|
|
237
|
+
# by n^0.4 keeps this ratio non-decreasing in n without hurting small-n speed.
|
|
238
|
+
n_fft = X.shape[0]
|
|
239
|
+
G_min_eff = min(max(G_min, int(G_min * (n_fft / 100_000) ** 0.4)), 262144)
|
|
240
|
+
G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min_eff)
|
|
235
241
|
_domain = (lo_domain, hi_domain)
|
|
236
242
|
pad_factor = 2 # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
237
243
|
N = pad_factor * G_fft
|
|
238
244
|
|
|
239
245
|
with torch.no_grad():
|
|
240
246
|
# Worker 1: precomputed C — hoist histogram + rfft out of bisection.
|
|
241
|
-
# Worker 3: float32 FFT
|
|
247
|
+
# Worker 3: float32 FFT for _refine_hcrit (2× faster; refinement is
|
|
248
|
+
# accuracy-insensitive to histogram dtype). Mode counting requires
|
|
249
|
+
# float64 histogram precision: float32 FFT errors (~1e-4 relative)
|
|
250
|
+
# create spurious sign changes near small h_crit that a relative
|
|
251
|
+
# threshold cannot remove without also killing genuine small lobes.
|
|
242
252
|
C, omega, _domain = precompute_fft(
|
|
243
253
|
X, G=G_fft, domain=_domain, pad_factor=pad_factor, fft_dtype=fft_dtype,
|
|
244
254
|
)
|
|
255
|
+
C_mc, omega_mc, _ = precompute_fft(
|
|
256
|
+
X, G=G_fft, domain=_domain, pad_factor=pad_factor,
|
|
257
|
+
fft_dtype=torch.float64,
|
|
258
|
+
)
|
|
245
259
|
|
|
246
260
|
# Verify bracket using FFT mode count on full X
|
|
247
|
-
count_lo = mode_count_from_C(
|
|
261
|
+
count_lo = mode_count_from_C(C_mc, omega_mc, h_lo, G_fft, N)
|
|
248
262
|
if count_lo <= target_modes:
|
|
249
263
|
h_lo_try = h_lo
|
|
250
264
|
for _ in range(30):
|
|
251
265
|
h_lo_try *= 0.5
|
|
252
266
|
if h_lo_try < 1e-10:
|
|
253
267
|
break
|
|
254
|
-
if mode_count_from_C(
|
|
268
|
+
if mode_count_from_C(C_mc, omega_mc, h_lo_try, G_fft, N) > target_modes:
|
|
255
269
|
h_lo = h_lo_try
|
|
256
270
|
break
|
|
257
271
|
|
|
258
|
-
count_hi = mode_count_from_C(
|
|
272
|
+
count_hi = mode_count_from_C(C_mc, omega_mc, h_hi, G_fft, N)
|
|
259
273
|
if count_hi > target_modes:
|
|
260
274
|
for _ in range(30):
|
|
261
275
|
h_hi *= 2.0
|
|
262
|
-
if mode_count_from_C(
|
|
276
|
+
if mode_count_from_C(C_mc, omega_mc, h_hi, G_fft, N) <= target_modes:
|
|
263
277
|
break
|
|
264
278
|
|
|
265
279
|
# Compile-friendly trisection: lo/hi are 0-d tensors, no .item()
|
|
@@ -267,8 +281,8 @@ def find_h_crit_hard(
|
|
|
267
281
|
# than enough for any bracket). torch.where replaces the Python
|
|
268
282
|
# if/elif/else so the loop body is a pure tensor computation that
|
|
269
283
|
# torch.compile(mode="reduce-overhead") can trace and replay.
|
|
270
|
-
_dtype =
|
|
271
|
-
_dev =
|
|
284
|
+
_dtype = omega_mc.dtype
|
|
285
|
+
_dev = C_mc.device
|
|
272
286
|
lo_t = torch.tensor(h_lo, dtype=_dtype, device=_dev)
|
|
273
287
|
hi_t = torch.tensor(h_hi, dtype=_dtype, device=_dev)
|
|
274
288
|
_target = torch.tensor(target_modes, dtype=torch.long, device=_dev)
|
|
@@ -277,7 +291,7 @@ def find_h_crit_hard(
|
|
|
277
291
|
h1 = lo_t + width * (1.0 / 3.0)
|
|
278
292
|
h2 = lo_t + width * (2.0 / 3.0)
|
|
279
293
|
counts = mode_count_from_C_batch(
|
|
280
|
-
|
|
294
|
+
C_mc, omega_mc, torch.stack([h1, h2]), G_fft, N
|
|
281
295
|
)
|
|
282
296
|
c1 = counts[0]
|
|
283
297
|
c2 = counts[1]
|
|
@@ -299,10 +313,13 @@ def find_h_crit_hard(
|
|
|
299
313
|
# to locate h_crit below the bin-width precision limit.
|
|
300
314
|
# Worker 23-4: pass C and omega from bisection to avoid duplicate O(n)
|
|
301
315
|
# histogram + rfft inside _refine_hcrit (saves ~80 ms at n=10M).
|
|
316
|
+
# Fix: use float64 C_mc/omega_mc instead of float32 C/omega — near h_crit
|
|
317
|
+
# the disappearing lobe peak is ~1e-3×max|f′|; float32 noise (~1e-4×max|f′|)
|
|
318
|
+
# gave SNR≈10 → ~0.10% error for Gaussian data. C_mc is already computed.
|
|
302
319
|
from dcb.fft_kde import _refine_hcrit
|
|
303
320
|
h_crit = _refine_hcrit(
|
|
304
321
|
X, lo_val, hi_val, G_fft, _domain, target_modes,
|
|
305
|
-
C_external=
|
|
322
|
+
C_external=C_mc, omega_external=omega_mc,
|
|
306
323
|
fft_dtype=fft_dtype,
|
|
307
324
|
)
|
|
308
325
|
|
|
@@ -319,6 +336,13 @@ def find_h_crit_hard(
|
|
|
319
336
|
X, G=G_half, domain=_domain,
|
|
320
337
|
pad_factor=pad_factor, fft_dtype=fft_dtype,
|
|
321
338
|
)
|
|
339
|
+
# Mirror of v0.1.9 main-path fix: compute float64 version of
|
|
340
|
+
# C_half so that the Richardson trisection and _refine_hcrit
|
|
341
|
+
# call have float64 SNR, eliminating float32 noise in h_crit_half.
|
|
342
|
+
C_half_mc, omega_half_mc, _ = precompute_fft(
|
|
343
|
+
X, G=G_half, domain=_domain,
|
|
344
|
+
pad_factor=pad_factor, fft_dtype=torch.float64,
|
|
345
|
+
)
|
|
322
346
|
N_half = pad_factor * G_half
|
|
323
347
|
|
|
324
348
|
# Narrow bracket around the fine-grid h_crit.
|
|
@@ -326,8 +350,8 @@ def find_h_crit_hard(
|
|
|
326
350
|
h_hi_r = h_crit * 1.05
|
|
327
351
|
|
|
328
352
|
def _bracket_valid(lo_r, hi_r):
|
|
329
|
-
c_lo = mode_count_from_C(
|
|
330
|
-
c_hi = mode_count_from_C(
|
|
353
|
+
c_lo = mode_count_from_C(C_half_mc, omega_half_mc, lo_r, G_half, N_half)
|
|
354
|
+
c_hi = mode_count_from_C(C_half_mc, omega_half_mc, hi_r, G_half, N_half)
|
|
331
355
|
return (c_lo > target_modes) and (c_hi <= target_modes)
|
|
332
356
|
|
|
333
357
|
valid = _bracket_valid(h_lo_r, h_hi_r)
|
|
@@ -338,8 +362,9 @@ def find_h_crit_hard(
|
|
|
338
362
|
|
|
339
363
|
if valid:
|
|
340
364
|
# Compile-friendly trisection for Richardson half-grid.
|
|
341
|
-
|
|
342
|
-
|
|
365
|
+
# Use float64 C_half_mc/omega_half_mc for SNR.
|
|
366
|
+
_dtype_r = omega_half_mc.dtype
|
|
367
|
+
_dev_r = C_half_mc.device
|
|
343
368
|
lo_rt = torch.tensor(h_lo_r, dtype=_dtype_r, device=_dev_r)
|
|
344
369
|
hi_rt = torch.tensor(h_hi_r, dtype=_dtype_r, device=_dev_r)
|
|
345
370
|
_target_r = torch.tensor(target_modes, dtype=torch.long, device=_dev_r)
|
|
@@ -348,7 +373,7 @@ def find_h_crit_hard(
|
|
|
348
373
|
h1_r = lo_rt + width_r * (1.0 / 3.0)
|
|
349
374
|
h2_r = lo_rt + width_r * (2.0 / 3.0)
|
|
350
375
|
counts_r = mode_count_from_C_batch(
|
|
351
|
-
|
|
376
|
+
C_half_mc, omega_half_mc,
|
|
352
377
|
torch.stack([h1_r, h2_r]), G_half, N_half,
|
|
353
378
|
)
|
|
354
379
|
c1_r = counts_r[0]
|
|
@@ -365,7 +390,7 @@ def find_h_crit_hard(
|
|
|
365
390
|
|
|
366
391
|
h_crit_half = _refine_hcrit(
|
|
367
392
|
X, lo_r, hi_r, G_half, _domain, target_modes,
|
|
368
|
-
C_external=
|
|
393
|
+
C_external=C_half_mc, omega_external=omega_half_mc,
|
|
369
394
|
fft_dtype=fft_dtype,
|
|
370
395
|
)
|
|
371
396
|
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "diffcb"
|
|
7
|
-
version = "0.1.8"
|
|
7
|
+
version = "0.1.10"
|
|
8
8
|
description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { file = "LICENSE" }
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""
|
|
2
|
+
v018_local_bench.py
|
|
3
|
+
Comprehensive benchmark for diffcb v0.1.8 (forward_path='smooth' default).
|
|
4
|
+
Produces 3 CSV files and prints summary tables.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
import math
|
|
10
|
+
import subprocess
|
|
11
|
+
import tempfile
|
|
12
|
+
import statistics
|
|
13
|
+
import csv
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from dcb import DCBLayer
|
|
17
|
+
|
|
18
|
+
# Where the CSVs are written. Defaults to the author's local workspace; set the
# DCB_RESULTS_DIR environment variable to redirect output on other machines,
# where the hard-coded path usually does not exist or is not writable.
RESULTS_DIR = os.environ.get(
    'DCB_RESULTS_DIR',
    "/Users/h/Downloads/DCB-workspace/02_projects/01_dcb_proposal/04_analysis/results",
)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Sample sizes for the timing sweep (Part 1) and the accuracy sweep (Part 2 and
# the summary accuracy table).
SPEED_NS = [1_000, 2_000, 5_000, 10_000, 25_000, 50_000, 100_000,
            500_000, 1_000_000, 5_000_000, 10_000_000]
ACCURACY_NS = [1_000, 5_000, 10_000, 25_000, 50_000, 100_000, 500_000, 1_000_000]

# CPU is always benchmarked; MPS is added when torch reports the backend usable.
DEVICES = ['cpu']
if torch.backends.mps.is_available():
    DEVICES.append('mps')
    print("MPS is available — will benchmark both CPU and MPS.")
else:
    print("MPS not available — benchmarking CPU only.")

# Rscript executable used for the multimode::bw.crit reference values; set the
# RSCRIPT environment variable to point at a different R installation.
R_BINARY = os.environ.get('RSCRIPT', '/usr/local/bin/Rscript')
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ─────────────────────────────────────────────────────────
|
|
36
|
+
# Calibration: T_fft_one_ms per device
|
|
37
|
+
# ─────────────────────────────────────────────────────────
|
|
38
|
+
|
|
39
|
+
def calibrate_fft(device: str, *, size: int = 16384, warmup: int = 500,
                  timed: int = 5000) -> float:
    """Return the mean wall-clock time in ms of one ``torch.fft.rfft`` call.

    The transform is applied to a float32 vector of ones of length *size* on
    *device*. *warmup* untimed calls run first, then *timed* calls are measured
    as one interval and averaged. Asynchronous backends (MPS, CUDA) are
    synchronized before and after the timed loop so the interval covers the
    queued kernels rather than just their launch.

    Args:
        device: torch device string, e.g. ``'cpu'``, ``'mps'``, ``'cuda:0'``.
        size: length of the input vector (default 16384).
        warmup: number of untimed warm-up calls (default 500).
        timed: number of timed calls (default 5000); must be >= 1.

    Returns:
        Mean milliseconds per rfft call.

    Raises:
        ValueError: if *timed* < 1.
    """
    if timed < 1:
        raise ValueError(f"timed must be >= 1, got {timed}")

    dev_type = torch.device(device).type

    def _sync() -> None:
        # Block until queued device work finishes; a no-op on CPU.
        if dev_type == 'mps':
            torch.mps.synchronize()
        elif dev_type == 'cuda':
            torch.cuda.synchronize(device)

    x = torch.ones(size, dtype=torch.float32, device=device)

    for _ in range(warmup):
        torch.fft.rfft(x)
    _sync()  # drain warm-up work so it does not leak into the timed window

    t0 = time.perf_counter()
    for _ in range(timed):
        torch.fft.rfft(x)
    _sync()
    t1 = time.perf_counter()

    return (t1 - t0) * 1000.0 / timed
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
print("\n=== Calibrating FFT units ===")
# Per-device cost of one rfft, used as the normalisation unit for fft_cost.
T_fft = {}
for dev_name in DEVICES:
    ms_per_call = calibrate_fft(dev_name)
    T_fft[dev_name] = ms_per_call
    print(f" {dev_name}: T_fft_one_ms = {ms_per_call:.6f} ms")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# ─────────────────────────────────────────────────────────
|
|
74
|
+
# Part 1 — Speed benchmark
|
|
75
|
+
# ─────────────────────────────────────────────────────────
|
|
76
|
+
|
|
77
|
+
print("\n=== Part 1: Speed Benchmark ===")
speed_rows = []


def _speed_row(device, n, T_fft_one_ms, t_median_ms=float('nan'),
               t_mean_ms=float('nan'), t_std_ms=float('nan'),
               throughput_ns=float('nan'), fft_norm_cost=float('nan')):
    """Return one speed-CSV row; any timing not supplied is recorded as NaN."""
    return {
        'device': device, 'n': n,
        't_median_ms': t_median_ms, 't_mean_ms': t_mean_ms,
        't_std_ms': t_std_ms, 'throughput_ns': throughput_ns,
        'fft_norm_cost': fft_norm_cost, 'T_fft_one_ms': T_fft_one_ms,
    }


for device in DEVICES:
    # One DCBLayer per device, reused across every n (per the original note,
    # the layer is stateless); moved onto MPS when that device is timed.
    on_mps = device == 'mps'
    layer = DCBLayer()
    if on_mps:
        layer = layer.to(device)
    T_fft_one_ms = T_fft[device]

    for n in SPEED_NS:
        # Guard: very large inputs are not attempted on MPS.
        if on_mps and n >= 5_000_000:
            print(f" [mps, n={n:,}] — skipping (too large for MPS RAM guard)")
            speed_rows.append(_speed_row(device, n, T_fft_one_ms))
            continue

        try:
            X = torch.randn(n, device=device)

            # 3 untimed warm-up calls, then drain the MPS queue.
            for _ in range(3):
                _ = layer(X)
            if on_mps:
                torch.mps.synchronize()

            # 15 timed calls; on MPS, synchronize around each call so the
            # interval covers the queued GPU work.
            REPS = 15
            times_ms = []
            for _ in range(REPS):
                if on_mps:
                    torch.mps.synchronize()
                    t0 = time.perf_counter()
                    layer(X)
                    torch.mps.synchronize()
                else:
                    t0 = time.perf_counter()
                    layer(X)
                t1 = time.perf_counter()
                times_ms.append((t1 - t0) * 1000.0)

            t_median_ms = statistics.median(times_ms)
            t_mean_ms = statistics.mean(times_ms)
            t_std_ms = statistics.stdev(times_ms) if len(times_ms) > 1 else 0.0
            throughput_ns = n / (t_median_ms / 1000.0)   # samples per second
            fft_norm_cost = t_median_ms / T_fft_one_ms   # cost in rfft units

            speed_rows.append(_speed_row(
                device, n, T_fft_one_ms,
                t_median_ms=t_median_ms, t_mean_ms=t_mean_ms, t_std_ms=t_std_ms,
                throughput_ns=throughput_ns, fft_norm_cost=fft_norm_cost,
            ))
            print(f" [{device}, n={n:>10,}] median={t_median_ms:.3f}ms tput={throughput_ns/1e6:.2f}M/s fft_cost={fft_norm_cost:.1f}")

        except RuntimeError as e:
            print(f" [{device}, n={n:,}] RuntimeError (OOM?): {e}")
            speed_rows.append(_speed_row(device, n, T_fft_one_ms))

# Write speed CSV
speed_csv = os.path.join(RESULTS_DIR, 'v018_local_speed.csv')
speed_fields = ['device', 'n', 't_median_ms', 't_mean_ms', 't_std_ms',
                'throughput_ns', 'fft_norm_cost', 'T_fft_one_ms']
with open(speed_csv, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=speed_fields)
    writer.writeheader()
    writer.writerows(speed_rows)
print(f"\nSpeed CSV saved: {speed_csv}")
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# ─────────────────────────────────────────────────────────
|
|
159
|
+
# Helper: call R bw.crit on a tensor
|
|
160
|
+
# ─────────────────────────────────────────────────────────
|
|
161
|
+
|
|
162
|
+
def r_bwcrit(X_tensor: torch.Tensor, *, r_binary=None, timeout: float = 300) -> float:
    """Return R's ``multimode::bw.crit(x, mod0=1)`` for *X_tensor*, or NaN on failure.

    The sample is written to a temporary one-column CSV (header ``x``, values
    with 10 decimals) that Rscript reads back. The file path is passed to R as
    a trailing command-line argument (read via ``commandArgs``) instead of
    being spliced into the R source, so quotes or backslashes in the temp path
    (e.g. Windows temp dirs) cannot break the R expression. The temp file is
    removed in every case.

    Args:
        X_tensor: sample to analyse; converted with ``.tolist()``.
        r_binary: Rscript executable; ``None`` means the module-level R_BINARY.
        timeout: seconds before the R subprocess is abandoned (default 300).

    Returns:
        The critical bandwidth reported by R, or ``float('nan')`` (after
        printing a warning) if writing the file, running R, or parsing its
        output fails.
    """
    binary = R_BINARY if r_binary is None else r_binary
    # sprintf("%.17g") prints the full double; plain cat() rounds to 7 digits.
    r_code = ('library(multimode); '
              'x<-read.csv(commandArgs(trailingOnly=TRUE)[1])$x; '
              'cat(sprintf("%.17g", bw.crit(x,mod0=1L)))')
    fname = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f:
            fname = f.name  # record first so cleanup works even if writing fails
            f.write('x\n')
            f.writelines(f'{v:.10f}\n' for v in X_tensor.tolist())
        result = subprocess.run(
            [binary, '--vanilla', '-e', r_code, fname],
            capture_output=True, text=True, timeout=timeout,
        )
        if result.returncode != 0 or not result.stdout.strip():
            print(f" WARNING: R call failed. stderr: {result.stderr.strip()[:200]}")
            return float('nan')
        return float(result.stdout.strip())
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # Missing/unrunnable Rscript, timeout, or unparsable output.
        print(f" WARNING: R call exception: {e}")
        return float('nan')
    finally:
        if fname is not None:
            try:
                os.unlink(fname)
            except OSError:
                pass
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
# ─────────────────────────────────────────────────────────
|
|
190
|
+
# Part 2 — Accuracy: Independent-sample
|
|
191
|
+
# ─────────────────────────────────────────────────────────
|
|
192
|
+
|
|
193
|
+
print("\n=== Part 2: Accuracy — Independent-sample ===")
indep_rows = []

for n in ACCURACY_NS:
    for seed in range(20):
        torch.manual_seed(seed)
        X = torch.randn(n)

        # DCB estimate from a freshly constructed layer; NaN if it raises.
        try:
            layer_cpu = DCBLayer()
            h_dcb = layer_cpu(X).item()
        except Exception as e:
            print(f" WARNING: DCB failed n={n}, seed={seed}: {e}")
            h_dcb = float('nan')

        # R reference, only for n up to 1M.
        # NOTE(review): r_bwcrit receives the very same X as DCB, so despite the
        # "Independent-sample" label this compares on identical data, as Part 3
        # does — confirm the intended design.
        if n <= 1_000_000:
            h_r = r_bwcrit(X)
            if math.isnan(h_r) or math.isnan(h_dcb):
                err_pct = float('nan')
            else:
                err_pct = abs(h_dcb - h_r) / h_r * 100.0
        else:
            h_r = err_pct = float('nan')

        indep_rows.append(dict(n=n, seed=seed, h_dcb=h_dcb, h_r=h_r, err_pct=err_pct))

    # Per-n summary: mean error when R comparisons exist, else the DCB range.
    compared = [row for row in indep_rows
                if row['n'] == n and not math.isnan(row.get('err_pct', float('nan')))]
    if compared:
        mean_e = statistics.mean(row['err_pct'] for row in compared)
        print(f" n={n:>9,}: mean_err={mean_e:.4f}% over {len(compared)} seeds with R comparison")
    else:
        dcb_vals = [row['h_dcb'] for row in indep_rows if row['n'] == n]
        print(f" n={n:>9,}: no R comparison (skipped); h_dcb range [{min(dcb_vals):.4f}, {max(dcb_vals):.4f}]")

indep_csv = os.path.join(RESULTS_DIR, 'v018_local_accuracy_indep.csv')
with open(indep_csv, 'w', newline='') as f:
    writer = csv.DictWriter(f, fieldnames=['n', 'seed', 'h_dcb', 'h_r', 'err_pct'])
    writer.writeheader()
    writer.writerows(indep_rows)
print(f"\nIndep accuracy CSV saved: {indep_csv}")
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
# ─────────────────────────────────────────────────────────
|
|
240
|
+
# Part 3 — Accuracy: Same-sample
|
|
241
|
+
# ─────────────────────────────────────────────────────────
|
|
242
|
+
|
|
243
|
+
print("\n=== Part 3: Accuracy — Same-sample ===")
same_rows = []

SAME_NS = [1_000, 5_000, 10_000, 25_000, 50_000, 100_000, 500_000, 1_000_000]

# R program run for every sample. The CSV path is passed as a trailing
# command-line argument (read with commandArgs) instead of being spliced into
# the R source, so quotes or backslashes in temp paths cannot break it.
# sprintf("%.17g") prints the full double; plain cat() rounds to 7 digits.
_R_BWCRIT_CODE = (
    'library(multimode); '
    'x<-read.csv(commandArgs(trailingOnly=TRUE)[1])$x; '
    'cat(sprintf("%.17g", bw.crit(x,mod0=1L)))'
)


def _same_sample_r_bwcrit(X, n, seed):
    """Run R bw.crit on *X* via a temp CSV; return NaN (after a warning) on failure.

    The temp file is always removed. A failure to remove it is reported but
    does not discard an already-computed result.
    """
    fname = None
    try:
        try:
            with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f:
                fname = f.name
                f.write('x\n')
                f.writelines(f'{v:.10f}\n' for v in X.tolist())
        except OSError as e:
            print(f" WARNING: Temp file error n={n}, seed={seed}: {e}")
            return float('nan')

        try:
            result = subprocess.run(
                [R_BINARY, '--vanilla', '-e', _R_BWCRIT_CODE, fname],
                capture_output=True, text=True, timeout=300,
            )
        except (OSError, subprocess.SubprocessError) as e:
            print(f" WARNING: R call exception n={n}, seed={seed}: {e}")
            return float('nan')

        out = result.stdout.strip()
        if result.returncode != 0 or not out:
            print(f" WARNING: R call failed n={n}, seed={seed}. stderr: {result.stderr.strip()[:200]}")
            return float('nan')
        try:
            return float(out)
        except ValueError as e:
            print(f" WARNING: R call exception n={n}, seed={seed}: {e}")
            return float('nan')
    finally:
        if fname is not None:
            try:
                os.unlink(fname)
            except OSError as e:
                print(f" WARNING: Temp file error n={n}, seed={seed}: {e}")


for n in SAME_NS:
    for seed in range(10):
        torch.manual_seed(seed)
        X = torch.randn(n)

        # DCB runs on the in-memory sample and no longer sits inside the
        # temp-file try block, so a file-system error cannot wipe a valid result.
        try:
            layer_cpu = DCBLayer()
            h_dcb = layer_cpu(X).item()
        except Exception as e:
            print(f" WARNING: DCB failed n={n}, seed={seed}: {e}")
            h_dcb = float('nan')

        # R on exactly the same values, round-tripped through a temp CSV.
        h_r = _same_sample_r_bwcrit(X, n, seed)

        if math.isnan(h_r) or math.isnan(h_dcb):
            err_pct = float('nan')
        else:
            err_pct = abs(h_dcb - h_r) / h_r * 100.0

        same_rows.append({'n': n, 'seed': seed, 'h_dcb': h_dcb,
                          'h_r': h_r, 'err_pct': err_pct})

    valid = [r for r in same_rows if r['n'] == n and not math.isnan(r.get('err_pct', float('nan')))]
    if valid:
        mean_e = statistics.mean(r['err_pct'] for r in valid)
        print(f" n={n:>9,}: mean_err={mean_e:.4f}% over {len(valid)} seeds")
    else:
        print(f" n={n:>9,}: no valid comparisons")

same_csv = os.path.join(RESULTS_DIR, 'v018_local_accuracy_same.csv')
with open(same_csv, 'w', newline='') as f:
    w = csv.DictWriter(f, fieldnames=['n', 'seed', 'h_dcb', 'h_r', 'err_pct'])
    w.writeheader()
    w.writerows(same_rows)
print(f"\nSame-sample accuracy CSV saved: {same_csv}")
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
# ─────────────────────────────────────────────────────────
|
|
316
|
+
# Summary Tables
|
|
317
|
+
# ─────────────────────────────────────────────────────────
|
|
318
|
+
|
|
319
|
+
def fmt_float(x, fmt='.3f'):
    """Format *x* with the format spec *fmt*; NaN is shown as 'N/A' in tables."""
    return 'N/A' if math.isnan(x) else format(x, fmt)
|
|
323
|
+
|
|
324
|
+
print("\n" + "=" * 110)
print("SUMMARY")
print("=" * 110)

# Calibration: per-device cost of one rfft (the unit behind fft_cost).
print("\n--- Calibration Units ---")
print(f"{'Device':<8} | {'T_fft_one_ms':>14}")
print("-" * 28)
for dev_label in DEVICES:
    print(f"{dev_label:<8} | {T_fft[dev_label]:>14.6f}")

# Speed table: one row per n, a CPU column group, plus an MPS group if timed.
print("\n--- Speed Benchmark ---")
col_w = 14
hdr = (f"{'n':>12} | {'CPU t_med(ms)':>{col_w}} | "
       f"{'CPU fft_cost':>{col_w}} | {'CPU n/s':>{col_w}}")
if 'mps' in DEVICES:
    hdr += (f" | {'MPS t_med(ms)':>{col_w}} | "
            f"{'MPS fft_cost':>{col_w}} | {'MPS n/s':>{col_w}}")
print(hdr)
print("-" * len(hdr))

speed_by_n = {(row['n'], row['device']): row for row in speed_rows}


def _speed_cells(row):
    """Return (median, fft-cost, throughput) display strings for one speed row."""
    nan = float('nan')
    median = fmt_float(row.get('t_median_ms', nan))
    cost = fmt_float(row.get('fft_norm_cost', nan), '.1f')
    tput_raw = row.get('throughput_ns', nan)
    tput = 'N/A' if math.isnan(tput_raw) else f"{tput_raw/1e6:.2f}M/s"
    return median, cost, tput


table_devices = ['cpu'] + (['mps'] if 'mps' in DEVICES else [])
for n in SPEED_NS:
    line = f"{n:>12,}"
    for dev_key in table_devices:
        med, cost, tput = _speed_cells(speed_by_n.get((n, dev_key), {}))
        line += f" | {med:>{col_w}} | {cost:>{col_w}} | {tput:>{col_w}}"
    print(line)

# Accuracy table header
print("\n--- Accuracy (mean % error vs R bw.crit) ---")
print(f"{'n':>12} | {'Indep mean_err%':>16} | {'Indep std_err%':>15} | {'Same mean_err%':>15} | {'Same std_err%':>14}")
print("-" * 90)
|
|
374
|
+
|
|
375
|
+
def acc_stats(rows, n_val):
    """Return (mean, sample std) of the usable ``err_pct`` values for ``n == n_val``.

    Rows whose ``err_pct`` is NaN or missing are ignored. With no usable rows
    both values are NaN; with exactly one, the std is reported as 0.0.
    """
    errs = []
    for row in rows:
        if row['n'] != n_val:
            continue
        err = row.get('err_pct', float('nan'))
        if not math.isnan(err):
            errs.append(err)
    if not errs:
        return float('nan'), float('nan')
    spread = statistics.stdev(errs) if len(errs) > 1 else 0.0
    return statistics.mean(errs), spread
|
|
382
|
+
|
|
383
|
+
# One accuracy row per n: independent-sample stats then same-sample stats.
for n in ACCURACY_NS:
    stats = acc_stats(indep_rows, n) + acc_stats(same_rows, n)
    i_mean, i_std, s_mean, s_std = (fmt_float(v, '.4f') for v in stats)
    print(f"{n:>12,} | {i_mean:>16} | {i_std:>15} | "
          f"{s_mean:>15} | {s_std:>14}")

print("\n" + "=" * 110)
print("Benchmark complete.")
print(f"CSVs saved to: {RESULTS_DIR}")
for footer_line in (" Speed: v018_local_speed.csv",
                    " Accuracy indep: v018_local_accuracy_indep.csv",
                    " Accuracy same: v018_local_accuracy_same.csv"):
    print(footer_line)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|