diffcb 0.1.8__tar.gz → 0.1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {diffcb-0.1.8 → diffcb-0.1.10}/PKG-INFO +1 -1
  2. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/__init__.py +1 -1
  3. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/fft_kde.py +26 -8
  4. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/solver.py +41 -16
  5. {diffcb-0.1.8 → diffcb-0.1.10}/pyproject.toml +1 -1
  6. diffcb-0.1.10/v018_local_bench.py +394 -0
  7. {diffcb-0.1.8 → diffcb-0.1.10}/.gitignore +0 -0
  8. {diffcb-0.1.8 → diffcb-0.1.10}/.zenodo.json +0 -0
  9. {diffcb-0.1.8 → diffcb-0.1.10}/LICENSE +0 -0
  10. {diffcb-0.1.8 → diffcb-0.1.10}/README.md +0 -0
  11. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/diagnostics.py +0 -0
  12. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/kde.py +0 -0
  13. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/layer.py +0 -0
  14. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/training.py +0 -0
  15. {diffcb-0.1.8 → diffcb-0.1.10}/dcb/utils.py +0 -0
  16. {diffcb-0.1.8 → diffcb-0.1.10}/notebooks/.gitkeep +0 -0
  17. {diffcb-0.1.8 → diffcb-0.1.10}/round24_cumulative_bench.py +0 -0
  18. {diffcb-0.1.8 → diffcb-0.1.10}/round24_v016_test.py +0 -0
  19. {diffcb-0.1.8 → diffcb-0.1.10}/round25_full_range_sweep.py +0 -0
  20. {diffcb-0.1.8 → diffcb-0.1.10}/round25_write_csv.py +0 -0
  21. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_gradcheck.py +0 -0
  22. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_kde.py +0 -0
  23. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_layer.py +0 -0
  24. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r18c_denom_audit.py +0 -0
  25. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r18c_deprecation_warn.py +0 -0
  26. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r19_default_fft.py +0 -0
  27. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_r19_diagnostics.py +0 -0
  28. {diffcb-0.1.8 → diffcb-0.1.10}/tests/test_solver.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffcb
3
- Version: 0.1.8
3
+ Version: 0.1.10
4
4
  Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
5
5
  Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
6
6
  Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
@@ -21,4 +21,4 @@ __all__ = [
21
21
  "TrainingLayer",
22
22
  "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
23
23
  ]
24
- __version__ = "0.1.8"
24
+ __version__ = "0.1.10"
@@ -137,12 +137,15 @@ def mode_count_from_C(
137
137
  if C.numel() == 0:
138
138
  return 1 # degenerate single-point distribution
139
139
 
140
+ # Caller must pass C computed with fft_dtype=torch.float64 (complex128).
141
+ # Float32 histogram errors (~1e-4 relative) create spurious sign changes
142
+ # near small h_crit that cannot be filtered without killing genuine lobes.
143
+ # A relative threshold of 1e-12 removes machine-epsilon edge-bin noise
144
+ # (empty histogram bins at domain boundaries contribute ~eps_f64 to f′).
140
145
  K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
141
146
  f_prime = torch.fft.irfft(C * K_deriv, n=N).real[:G]
142
- # Exact zeros are measure-zero for smooth KDE on non-degenerate data; strict
143
- # inequalities match the masked path in all practical cases while avoiding
144
- # one host sync (.any()) and two allocations (nonzero_mask, s) per call.
145
- return int(((f_prime[:-1] > 0) & (f_prime[1:] < 0)).sum().item())
147
+ thresh = f_prime.abs().max() * 1e-12
148
+ return int(((f_prime[:-1] > thresh) & (f_prime[1:] < -thresh)).sum().item())
146
149
 
147
150
 
148
151
  def mode_count_from_C_batch(
@@ -186,12 +189,15 @@ def mode_count_from_C_batch(
186
189
  else:
187
190
  h_t = h_batch.to(dtype=omega.dtype, device=omega.device) # (B,)
188
191
 
192
+ # Caller must pass C in float64 (complex128) — same rationale as mode_count_from_C.
189
193
  # Build (B, M) kernel matrix in one vectorised op
190
194
  omega_h = omega.unsqueeze(0) * h_t.unsqueeze(1) # (B, M)
191
195
  K_batch = 1j * omega.unsqueeze(0) * torch.exp(-0.5 * omega_h ** 2) # (B, M)
192
196
  # One batched irfft dispatch instead of B separate calls
193
- f_prime_batch = torch.fft.irfft(C.unsqueeze(0) * K_batch, n=N)[:, :G] # (B, G)
194
- return ((f_prime_batch[:, :-1] > 0) & (f_prime_batch[:, 1:] < 0)).sum(dim=1)
197
+ f_prime_batch = torch.fft.irfft(C.unsqueeze(0) * K_batch, n=N)[:, :G] # (B, G)
198
+ # Per-bandwidth threshold: remove edge-bin machine-epsilon noise.
199
+ thresholds = f_prime_batch.abs().amax(dim=1, keepdim=True) * 1e-12 # (B, 1)
200
+ return ((f_prime_batch[:, :-1] > thresholds) & (f_prime_batch[:, 1:] < -thresholds)).sum(dim=1)
195
201
 
196
202
 
197
203
  def fft_mode_count(
@@ -355,8 +361,20 @@ def _refine_hcrit(
355
361
  y_mid = fprime(h_mid)[j].item()
356
362
  y_hi = fp_hi_[j].item()
357
363
 
358
- # Fit quadratic y = a*h² + b*h + c through the three (h, y) pairs
359
- # and solve for the root in [ref_lo, ref_hi].
364
+ # Square-root model: near a saddle-node bifurcation, f'_peak ∝ √(h_crit − h),
365
+ # so f'²_peak is linear in h. Fit y² = a + b·h through 3 points and solve
366
+ # y²=0 → h_crit = −a/b. More accurate than quadratic for Gaussian/unimodal data.
367
+ y_sq = [y_lo ** 2, y_mid ** 2, y_hi ** 2]
368
+ # Ensure y_lo >= 0 (it should be positive by candidate_mask selection)
369
+ if y_lo > 0:
370
+ sqrt_coeffs = np.polyfit([ref_lo, h_mid, ref_hi], y_sq, 1) # degree-1 fit
371
+ b_sqrt, a_sqrt = sqrt_coeffs[0], sqrt_coeffs[1]
372
+ if b_sqrt != 0:
373
+ sqrt_root = -a_sqrt / b_sqrt
374
+ if ref_lo <= sqrt_root <= ref_hi:
375
+ return float(sqrt_root)
376
+
377
+ # Fallback: original quadratic fit
360
378
  coeffs = np.polyfit([ref_lo, h_mid, ref_hi], [y_lo, y_mid, y_hi], 2)
361
379
  roots = np.roots(coeffs)
362
380
  real_roots = [
@@ -231,35 +231,49 @@ def find_h_crit_hard(
231
231
  lo_domain = X.min().item() - 3 * sigma
232
232
  hi_domain = X.max().item() + 3 * sigma
233
233
  data_range = hi_domain - lo_domain
234
- G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
234
+ # Scale G_min with n^0.4 to keep h_crit/bin_width stable as h_crit shrinks.
235
+ # For Gaussian-regime data h_crit ~ n^{-1/5}, bin_width ~ data_range/G_min,
236
+ # so ratio h_crit/bin_width ~ n^{-1/5} * G_min / data_range. Scaling G_min
237
+ # by n^0.4 keeps this ratio non-decreasing in n without hurting small-n speed.
238
+ n_fft = X.shape[0]
239
+ G_min_eff = min(max(G_min, int(G_min * (n_fft / 100_000) ** 0.4)), 262144)
240
+ G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min_eff)
235
241
  _domain = (lo_domain, hi_domain)
236
242
  pad_factor = 2 # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
237
243
  N = pad_factor * G_fft
238
244
 
239
245
  with torch.no_grad():
240
246
  # Worker 1: precomputed C — hoist histogram + rfft out of bisection.
241
- # Worker 3: float32 FFT by default 2× faster; _refine_hcrit uses float64 independently.
247
+ # Worker 3: float32 FFT (C/omega) is ~2× faster, but bracketing, the
248
+ # trisection and the _refine_hcrit call below all use the float64
249
+ # C_mc/omega_mc: float32 FFT errors (~1e-4 relative) create spurious
250
+ # sign changes near small h_crit that a relative threshold cannot remove
251
+ # without also killing genuine small lobes.
242
252
  C, omega, _domain = precompute_fft(
243
253
  X, G=G_fft, domain=_domain, pad_factor=pad_factor, fft_dtype=fft_dtype,
244
254
  )
255
+ C_mc, omega_mc, _ = precompute_fft(
256
+ X, G=G_fft, domain=_domain, pad_factor=pad_factor,
257
+ fft_dtype=torch.float64,
258
+ )
245
259
 
246
260
  # Verify bracket using FFT mode count on full X
247
- count_lo = mode_count_from_C(C, omega, h_lo, G_fft, N)
261
+ count_lo = mode_count_from_C(C_mc, omega_mc, h_lo, G_fft, N)
248
262
  if count_lo <= target_modes:
249
263
  h_lo_try = h_lo
250
264
  for _ in range(30):
251
265
  h_lo_try *= 0.5
252
266
  if h_lo_try < 1e-10:
253
267
  break
254
- if mode_count_from_C(C, omega, h_lo_try, G_fft, N) > target_modes:
268
+ if mode_count_from_C(C_mc, omega_mc, h_lo_try, G_fft, N) > target_modes:
255
269
  h_lo = h_lo_try
256
270
  break
257
271
 
258
- count_hi = mode_count_from_C(C, omega, h_hi, G_fft, N)
272
+ count_hi = mode_count_from_C(C_mc, omega_mc, h_hi, G_fft, N)
259
273
  if count_hi > target_modes:
260
274
  for _ in range(30):
261
275
  h_hi *= 2.0
262
- if mode_count_from_C(C, omega, h_hi, G_fft, N) <= target_modes:
276
+ if mode_count_from_C(C_mc, omega_mc, h_hi, G_fft, N) <= target_modes:
263
277
  break
264
278
 
265
279
  # Compile-friendly trisection: lo/hi are 0-d tensors, no .item()
@@ -267,8 +281,8 @@ def find_h_crit_hard(
267
281
  # than enough for any bracket). torch.where replaces the Python
268
282
  # if/elif/else so the loop body is a pure tensor computation that
269
283
  # torch.compile(mode="reduce-overhead") can trace and replay.
270
- _dtype = omega.dtype
271
- _dev = C.device
284
+ _dtype = omega_mc.dtype
285
+ _dev = C_mc.device
272
286
  lo_t = torch.tensor(h_lo, dtype=_dtype, device=_dev)
273
287
  hi_t = torch.tensor(h_hi, dtype=_dtype, device=_dev)
274
288
  _target = torch.tensor(target_modes, dtype=torch.long, device=_dev)
@@ -277,7 +291,7 @@ def find_h_crit_hard(
277
291
  h1 = lo_t + width * (1.0 / 3.0)
278
292
  h2 = lo_t + width * (2.0 / 3.0)
279
293
  counts = mode_count_from_C_batch(
280
- C, omega, torch.stack([h1, h2]), G_fft, N
294
+ C_mc, omega_mc, torch.stack([h1, h2]), G_fft, N
281
295
  )
282
296
  c1 = counts[0]
283
297
  c2 = counts[1]
@@ -299,10 +313,13 @@ def find_h_crit_hard(
299
313
  # to locate h_crit below the bin-width precision limit.
300
314
  # Worker 23-4: pass C and omega from bisection to avoid duplicate O(n)
301
315
  # histogram + rfft inside _refine_hcrit (saves ~80 ms at n=10M).
316
+ # Fix: use float64 C_mc/omega_mc instead of float32 C/omega — near h_crit
317
+ # the disappearing lobe peak is ~1e-3×max|f′|; float32 noise (~1e-4×max|f′|)
318
+ # gave SNR≈10 → ~0.10% error for Gaussian data. C_mc is already computed.
302
319
  from dcb.fft_kde import _refine_hcrit
303
320
  h_crit = _refine_hcrit(
304
321
  X, lo_val, hi_val, G_fft, _domain, target_modes,
305
- C_external=C, omega_external=omega,
322
+ C_external=C_mc, omega_external=omega_mc,
306
323
  fft_dtype=fft_dtype,
307
324
  )
308
325
 
@@ -319,6 +336,13 @@ def find_h_crit_hard(
319
336
  X, G=G_half, domain=_domain,
320
337
  pad_factor=pad_factor, fft_dtype=fft_dtype,
321
338
  )
339
+ # Mirror of v0.1.9 main-path fix: compute float64 version of
340
+ # C_half so that the Richardson trisection and _refine_hcrit
341
+ # call have float64 SNR, eliminating float32 noise in h_crit_half.
342
+ C_half_mc, omega_half_mc, _ = precompute_fft(
343
+ X, G=G_half, domain=_domain,
344
+ pad_factor=pad_factor, fft_dtype=torch.float64,
345
+ )
322
346
  N_half = pad_factor * G_half
323
347
 
324
348
  # Narrow bracket around the fine-grid h_crit.
@@ -326,8 +350,8 @@ def find_h_crit_hard(
326
350
  h_hi_r = h_crit * 1.05
327
351
 
328
352
  def _bracket_valid(lo_r, hi_r):
329
- c_lo = mode_count_from_C(C_half, omega_half, lo_r, G_half, N_half)
330
- c_hi = mode_count_from_C(C_half, omega_half, hi_r, G_half, N_half)
353
+ c_lo = mode_count_from_C(C_half_mc, omega_half_mc, lo_r, G_half, N_half)
354
+ c_hi = mode_count_from_C(C_half_mc, omega_half_mc, hi_r, G_half, N_half)
331
355
  return (c_lo > target_modes) and (c_hi <= target_modes)
332
356
 
333
357
  valid = _bracket_valid(h_lo_r, h_hi_r)
@@ -338,8 +362,9 @@ def find_h_crit_hard(
338
362
 
339
363
  if valid:
340
364
  # Compile-friendly trisection for Richardson half-grid.
341
- _dtype_r = omega_half.dtype
342
- _dev_r = C_half.device
365
+ # Use float64 C_half_mc/omega_half_mc for SNR.
366
+ _dtype_r = omega_half_mc.dtype
367
+ _dev_r = C_half_mc.device
343
368
  lo_rt = torch.tensor(h_lo_r, dtype=_dtype_r, device=_dev_r)
344
369
  hi_rt = torch.tensor(h_hi_r, dtype=_dtype_r, device=_dev_r)
345
370
  _target_r = torch.tensor(target_modes, dtype=torch.long, device=_dev_r)
@@ -348,7 +373,7 @@ def find_h_crit_hard(
348
373
  h1_r = lo_rt + width_r * (1.0 / 3.0)
349
374
  h2_r = lo_rt + width_r * (2.0 / 3.0)
350
375
  counts_r = mode_count_from_C_batch(
351
- C_half, omega_half,
376
+ C_half_mc, omega_half_mc,
352
377
  torch.stack([h1_r, h2_r]), G_half, N_half,
353
378
  )
354
379
  c1_r = counts_r[0]
@@ -365,7 +390,7 @@ def find_h_crit_hard(
365
390
 
366
391
  h_crit_half = _refine_hcrit(
367
392
  X, lo_r, hi_r, G_half, _domain, target_modes,
368
- C_external=C_half, omega_external=omega_half,
393
+ C_external=C_half_mc, omega_external=omega_half_mc,
369
394
  fft_dtype=fft_dtype,
370
395
  )
371
396
 
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "diffcb"
7
- version = "0.1.8"
7
+ version = "0.1.10"
8
8
  description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -0,0 +1,394 @@
1
+ """
2
+ v018_local_bench.py
3
+ Comprehensive benchmark for diffcb v0.1.8 (forward_path='smooth' default).
4
+ Produces 3 CSV files and prints summary tables.
5
+ """
6
+
7
+ import os
8
+ import time
9
+ import math
10
+ import subprocess
11
+ import tempfile
12
+ import statistics
13
+ import csv
14
+
15
+ import torch
16
+ from dcb import DCBLayer
17
+
18
import shutil

# Output directory for the three CSVs written by this script.
# Override with the DCB_RESULTS_DIR environment variable.
# The author's original absolute path stays the default when its parent
# directory exists. Otherwise fall back to ./v018_results, so makedirs does
# not fail (PermissionError / FileNotFoundError) on other machines.
_DEFAULT_RESULTS_DIR = "/Users/h/Downloads/DCB-workspace/02_projects/01_dcb_proposal/04_analysis/results"
RESULTS_DIR = os.environ.get("DCB_RESULTS_DIR") or (
    _DEFAULT_RESULTS_DIR
    if os.path.isdir(os.path.dirname(_DEFAULT_RESULTS_DIR))
    else os.path.join(os.getcwd(), "v018_results")
)
os.makedirs(RESULTS_DIR, exist_ok=True)

# Sample sizes for the timing sweep (Part 1) and the accuracy sweep (Part 2).
SPEED_NS = [1_000, 2_000, 5_000, 10_000, 25_000, 50_000, 100_000,
            500_000, 1_000_000, 5_000_000, 10_000_000]
ACCURACY_NS = [1_000, 5_000, 10_000, 25_000, 50_000, 100_000, 500_000, 1_000_000]

# Always benchmark CPU; add MPS when this torch build reports it available.
DEVICES = ['cpu']
if torch.backends.mps.is_available():
    DEVICES.append('mps')
    print("MPS is available — will benchmark both CPU and MPS.")
else:
    print("MPS not available — benchmarking CPU only.")

# Rscript used for the reference bw.crit (R package `multimode`).
# Resolution order: the DCB_RSCRIPT environment variable, then Rscript on
# PATH, then the original hard-coded location.
R_BINARY = (
    os.environ.get("DCB_RSCRIPT")
    or shutil.which("Rscript")
    or '/usr/local/bin/Rscript'
)
33
+
34
+
35
+ # ─────────────────────────────────────────────────────────
36
+ # Calibration: T_fft_one_ms per device
37
+ # ─────────────────────────────────────────────────────────
38
+
39
def calibrate_fft(device: str, n_points: int = 16384, warmup: int = 500,
                  timed: int = 5000) -> float:
    """Return the mean wall time in ms of one ``torch.fft.rfft`` call.

    The probe is ``rfft(ones(n_points))`` in float32 on *device*. The
    benchmark divides its timings by this value to express them in "FFT units".

    Args:
        device: torch device string, e.g. 'cpu', 'mps' or 'cuda'.
        n_points: length of the probe signal (default 16384, the original size).
        warmup: number of untimed calls issued first.
        timed: number of timed calls averaged into the result; must be >= 1.

    Returns:
        Mean milliseconds per call.

    Raises:
        ValueError: if *timed* < 1 (previously a ZeroDivisionError).
    """
    if timed < 1:
        raise ValueError(f"timed must be >= 1, got {timed}")

    x = torch.ones(n_points, dtype=torch.float32, device=device)

    # Asynchronous backends only queue kernels. Synchronize around the timed
    # region so the timer measures completed work, not launch overhead.
    # The original code only handled 'mps'; CUDA is covered too now.
    dev_type = torch.device(device).type
    if dev_type == 'mps':
        sync = torch.mps.synchronize
    elif dev_type == 'cuda':
        sync = torch.cuda.synchronize
    else:
        sync = None

    for _ in range(warmup):
        torch.fft.rfft(x)

    if sync is not None:
        sync()
    t0 = time.perf_counter()
    for _ in range(timed):
        torch.fft.rfft(x)
    if sync is not None:
        sync()
    t1 = time.perf_counter()

    return (t1 - t0) * 1000.0 / timed
64
+
65
+
66
print("\n=== Calibrating FFT units ===")
# Per-device cost of one reference rfft. Part 1 and the summary divide
# timings by this to report them in "FFT units".
T_fft = {}
for dev in DEVICES:
    one_ms = calibrate_fft(dev)
    T_fft[dev] = one_ms
    print(f" {dev}: T_fft_one_ms = {one_ms:.6f} ms")
71
+
72
+
73
+ # ─────────────────────────────────────────────────────────
74
+ # Part 1 — Speed benchmark
75
+ # ─────────────────────────────────────────────────────────
76
+
77
print("\n=== Part 1: Speed Benchmark ===")
speed_rows = []


def _nan_speed_row(device, n, T_fft_one_ms):
    """Build the all-NaN row recorded when a (device, n) point is skipped or fails."""
    nan = float('nan')
    return {
        'device': device, 'n': n,
        't_median_ms': nan, 't_mean_ms': nan,
        't_std_ms': nan, 'throughput_ns': nan,
        'fft_norm_cost': nan, 'T_fft_one_ms': T_fft_one_ms,
    }


for device in DEVICES:
    # One layer per device, reused for every n.
    layer = DCBLayer()
    if device == 'mps':
        layer = layer.to(device)
    on_mps = device == 'mps'
    T_fft_one_ms = T_fft[device]

    for n in SPEED_NS:
        # The largest sizes are skipped on MPS to stay within device memory.
        if on_mps and n >= 5_000_000:
            print(f" [mps, n={n:,}] — skipping (too large for MPS RAM guard)")
            speed_rows.append(_nan_speed_row(device, n, T_fft_one_ms))
            continue

        try:
            X = torch.randn(n, device=device)

            # Untimed warm-up calls.
            for _ in range(3):
                _ = layer(X)
            if on_mps:
                torch.mps.synchronize()

            # Timed calls; MPS is synchronized on both sides of each call.
            REPS = 15
            times_ms = []
            for _ in range(REPS):
                if on_mps:
                    torch.mps.synchronize()
                start = time.perf_counter()
                layer(X)
                if on_mps:
                    torch.mps.synchronize()
                stop = time.perf_counter()
                times_ms.append((stop - start) * 1000.0)

            t_median_ms = statistics.median(times_ms)
            t_mean_ms = statistics.mean(times_ms)
            t_std_ms = statistics.stdev(times_ms) if len(times_ms) > 1 else 0.0
            throughput_ns = n / (t_median_ms / 1000.0)
            fft_norm_cost = t_median_ms / T_fft_one_ms

            speed_rows.append({
                'device': device, 'n': n,
                't_median_ms': t_median_ms, 't_mean_ms': t_mean_ms,
                't_std_ms': t_std_ms, 'throughput_ns': throughput_ns,
                'fft_norm_cost': fft_norm_cost, 'T_fft_one_ms': T_fft_one_ms,
            })
            print(f" [{device}, n={n:>10,}] median={t_median_ms:.3f}ms tput={throughput_ns/1e6:.2f}M/s fft_cost={fft_norm_cost:.1f}")

        except RuntimeError as e:
            print(f" [{device}, n={n:,}] RuntimeError (OOM?): {e}")
            speed_rows.append(_nan_speed_row(device, n, T_fft_one_ms))

# Persist the per-(device, n) timings.
speed_csv = os.path.join(RESULTS_DIR, 'v018_local_speed.csv')
speed_fields = ['device', 'n', 't_median_ms', 't_mean_ms', 't_std_ms',
                'throughput_ns', 'fft_norm_cost', 'T_fft_one_ms']
with open(speed_csv, 'w', newline='') as f:
    w = csv.DictWriter(f, fieldnames=speed_fields)
    w.writeheader()
    w.writerows(speed_rows)
print(f"\nSpeed CSV saved: {speed_csv}")
156
+
157
+
158
+ # ─────────────────────────────────────────────────────────
159
+ # Helper: call R bw.crit on a tensor
160
+ # ─────────────────────────────────────────────────────────
161
+
162
def r_bwcrit(X_tensor: torch.Tensor, r_binary=None) -> float:
    """Return R's ``multimode::bw.crit(x, mod0=1)`` for the values in *X_tensor*.

    The values are written with 10 decimal places to a temporary CSV. An
    ``Rscript --vanilla`` subprocess (300 s timeout) reads that file. The
    temp file is always removed afterwards, including when writing it fails.
    This function never raises: on any failure it prints a warning and
    returns NaN.

    Args:
        X_tensor: the samples. Each ``tolist()`` element is formatted as a
            float, so this assumes a 1-D tensor.
        r_binary: path to ``Rscript``. ``None`` uses the module-level
            ``R_BINARY``.

    Returns:
        The critical bandwidth printed by R, or NaN if R failed, timed out, or
        printed nothing parsable.
    """
    if r_binary is None:
        r_binary = R_BINARY
    fname = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f:
            fname = f.name  # record first so cleanup still happens if a write fails
            f.write('x\n')
            f.writelines(f'{v:.10f}\n' for v in X_tensor.tolist())
        # Forward slashes keep the path valid inside the R string literal on
        # Windows, where backslashes would be parsed as escapes.
        r_path = fname.replace('\\', '/')
        result = subprocess.run(
            [r_binary, '--vanilla', '-e',
             f'library(multimode); x<-read.csv("{r_path}")$x; cat(bw.crit(x,mod0=1L))'],
            capture_output=True, text=True, timeout=300,
        )
        if result.returncode != 0 or not result.stdout.strip():
            print(f" WARNING: R call failed. stderr: {result.stderr.strip()[:200]}")
            return float('nan')
        return float(result.stdout.strip())
    except Exception as e:
        # Best-effort by design: the benchmark records NaN and continues.
        print(f" WARNING: R call exception: {e}")
        return float('nan')
    finally:
        if fname is not None:
            try:
                os.unlink(fname)
            except OSError:
                pass
187
+
188
+
189
+ # ─────────────────────────────────────────────────────────
190
+ # Part 2 — Accuracy: Independent-sample
191
+ # ─────────────────────────────────────────────────────────
192
+
193
print("\n=== Part 2: Accuracy — Independent-sample ===")
indep_rows = []

# NOTE(review): r_bwcrit below receives the same X that DCB saw, so this is
# not an independent-sample comparison despite the label. Confirm the intent.
for n in ACCURACY_NS:
    for seed in range(20):
        torch.manual_seed(seed)
        X = torch.randn(n)

        # DCB estimate, using a fresh CPU layer per draw.
        try:
            layer_cpu = DCBLayer()
            h_dcb = layer_cpu(X).item()
        except Exception as e:
            print(f" WARNING: DCB failed n={n}, seed={seed}: {e}")
            h_dcb = float('nan')

        # R reference. Skipped above 1M samples.
        if n > 1_000_000:
            h_r = float('nan')
            err_pct = float('nan')
        else:
            h_r = r_bwcrit(X)
            missing = math.isnan(h_r) or math.isnan(h_dcb)
            err_pct = float('nan') if missing else abs(h_dcb - h_r) / h_r * 100.0

        indep_rows.append({'n': n, 'seed': seed, 'h_dcb': h_dcb,
                           'h_r': h_r, 'err_pct': err_pct})

    # Per-n progress line: mean relative error where R produced a value.
    errs_n = [row['err_pct'] for row in indep_rows
              if row['n'] == n and not math.isnan(row.get('err_pct', float('nan')))]
    if errs_n:
        print(f" n={n:>9,}: mean_err={statistics.mean(errs_n):.4f}% over {len(errs_n)} seeds with R comparison")
    else:
        dcb_vals = [row['h_dcb'] for row in indep_rows if row['n'] == n]
        print(f" n={n:>9,}: no R comparison (skipped); h_dcb range [{min(dcb_vals):.4f}, {max(dcb_vals):.4f}]")

indep_csv = os.path.join(RESULTS_DIR, 'v018_local_accuracy_indep.csv')
with open(indep_csv, 'w', newline='') as f:
    w = csv.DictWriter(f, fieldnames=['n', 'seed', 'h_dcb', 'h_r', 'err_pct'])
    w.writeheader()
    w.writerows(indep_rows)
print(f"\nIndep accuracy CSV saved: {indep_csv}")
237
+
238
+
239
+ # ─────────────────────────────────────────────────────────
240
+ # Part 3 — Accuracy: Same-sample
241
+ # ─────────────────────────────────────────────────────────
242
+
243
print("\n=== Part 3: Accuracy — Same-sample ===")
same_rows = []

SAME_NS = [1_000, 5_000, 10_000, 25_000, 50_000, 100_000, 500_000, 1_000_000]

for n in SAME_NS:
    for seed in range(10):
        torch.manual_seed(seed)
        X = torch.randn(n)

        h_dcb = float('nan')
        h_r = float('nan')
        fname = None
        try:
            # Write the sample to a temp CSV once. DCB consumes the tensor
            # directly and R reads the file, so both see the same draw
            # (R sees it rounded to 10 decimal places).
            with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f:
                fname = f.name  # record first so cleanup works even if a write fails
                f.write('x\n')
                f.writelines(f'{v:.10f}\n' for v in X.tolist())

            # DCB
            try:
                layer_cpu = DCBLayer()
                h_dcb = layer_cpu(X).item()
            except Exception as e:
                print(f" WARNING: DCB failed n={n}, seed={seed}: {e}")
                h_dcb = float('nan')

            # R, reading the same file. Forward slashes keep the path valid
            # inside the R string literal on Windows.
            r_path = fname.replace('\\', '/')
            try:
                result = subprocess.run(
                    [R_BINARY, '--vanilla', '-e',
                     f'library(multimode); x<-read.csv("{r_path}")$x; cat(bw.crit(x,mod0=1L))'],
                    capture_output=True, text=True, timeout=300,
                )
                if result.returncode != 0 or not result.stdout.strip():
                    print(f" WARNING: R call failed n={n}, seed={seed}. stderr: {result.stderr.strip()[:200]}")
                    h_r = float('nan')
                else:
                    h_r = float(result.stdout.strip())
            except Exception as e:
                print(f" WARNING: R call exception n={n}, seed={seed}: {e}")
                h_r = float('nan')

        except Exception as e:
            print(f" WARNING: Temp file error n={n}, seed={seed}: {e}")
            h_dcb = float('nan')
            h_r = float('nan')
        finally:
            # Always remove the temp file. A failed unlink is reported but no
            # longer discards results that were already computed.
            if fname is not None:
                try:
                    os.unlink(fname)
                except OSError as e:
                    print(f" WARNING: could not remove temp file {fname}: {e}")

        if math.isnan(h_r) or math.isnan(h_dcb):
            err_pct = float('nan')
        else:
            err_pct = abs(h_dcb - h_r) / h_r * 100.0

        same_rows.append({'n': n, 'seed': seed, 'h_dcb': h_dcb,
                          'h_r': h_r, 'err_pct': err_pct})

    # Per-n progress line over the seeds that produced both estimates.
    valid = [r for r in same_rows if r['n'] == n and not math.isnan(r.get('err_pct', float('nan')))]
    if valid:
        mean_e = statistics.mean(r['err_pct'] for r in valid)
        print(f" n={n:>9,}: mean_err={mean_e:.4f}% over {len(valid)} seeds")
    else:
        print(f" n={n:>9,}: no valid comparisons")

same_csv = os.path.join(RESULTS_DIR, 'v018_local_accuracy_same.csv')
with open(same_csv, 'w', newline='') as f:
    w = csv.DictWriter(f, fieldnames=['n', 'seed', 'h_dcb', 'h_r', 'err_pct'])
    w.writeheader()
    w.writerows(same_rows)
print(f"\nSame-sample accuracy CSV saved: {same_csv}")
313
+
314
+
315
+ # ─────────────────────────────────────────────────────────
316
+ # Summary Tables
317
+ # ─────────────────────────────────────────────────────────
318
+
319
def fmt_float(x, fmt='.3f'):
    """Render *x* with format spec *fmt*; NaN becomes 'N/A' for the summary tables."""
    return 'N/A' if math.isnan(x) else format(x, fmt)
323
+
324
rule = "=" * 110
print("\n" + rule)
print("SUMMARY")
print(rule)

# Calibration: the per-device FFT unit measured at start-up.
print("\n--- Calibration Units ---")
print(f"{'Device':<8} | {'T_fft_one_ms':>14}")
print("-" * 28)
for dev in DEVICES:
    print(f"{dev:<8} | {T_fft[dev]:>14.6f}")

# Speed table: CPU columns always, MPS columns only when MPS was benchmarked.
print("\n--- Speed Benchmark ---")
col_w = 14
has_mps = 'mps' in DEVICES
hdr = (f"{'n':>12} | {'CPU t_med(ms)':>{col_w}} | {'CPU fft_cost':>{col_w}} | "
       f"{'CPU n/s':>{col_w}}")
if has_mps:
    hdr += (f" | {'MPS t_med(ms)':>{col_w}} | {'MPS fft_cost':>{col_w}} | "
            f"{'MPS n/s':>{col_w}}")
print(hdr)
print("-" * len(hdr))

# Index rows by (n, device); a later duplicate key overwrites an earlier one.
speed_by_n = {(row['n'], row['device']): row for row in speed_rows}


def _speed_cells(row):
    """Return (median ms, FFT-unit cost, throughput) display strings for one speed row."""
    nan = float('nan')
    tput_raw = row.get('throughput_ns', nan)
    return (
        fmt_float(row.get('t_median_ms', nan)),
        fmt_float(row.get('fft_norm_cost', nan), '.1f'),
        'N/A' if math.isnan(tput_raw) else f"{tput_raw/1e6:.2f}M/s",
    )


for n in SPEED_NS:
    med, fft_cost, tput = _speed_cells(speed_by_n.get((n, 'cpu'), {}))
    line = f"{n:>12,} | {med:>{col_w}} | {fft_cost:>{col_w}} | {tput:>{col_w}}"
    if has_mps:
        med, fft_cost, tput = _speed_cells(speed_by_n.get((n, 'mps'), {}))
        line += f" | {med:>{col_w}} | {fft_cost:>{col_w}} | {tput:>{col_w}}"
    print(line)

# Accuracy table header; the rows are printed further below.
print("\n--- Accuracy (mean % error vs R bw.crit) ---")
print(f"{'n':>12} | {'Indep mean_err%':>16} | {'Indep std_err%':>15} | {'Same mean_err%':>15} | {'Same std_err%':>14}")
print("-" * 90)
374
+
375
def acc_stats(rows, n_val):
    """Return (mean, sample stdev) of the non-NaN 'err_pct' values for rows with n == n_val.

    Rows without an 'err_pct' key are treated as NaN and skipped. Returns
    (nan, nan) when no value remains, and a stdev of 0.0 for a single value.
    """
    errs = []
    for row in rows:
        if row['n'] != n_val:
            continue
        if math.isnan(row.get('err_pct', float('nan'))):
            continue
        errs.append(row['err_pct'])

    if not errs:
        return float('nan'), float('nan')
    spread = 0.0 if len(errs) == 1 else statistics.stdev(errs)
    return statistics.mean(errs), spread
382
+
383
# One accuracy row per n: independent-sample stats, then same-sample stats.
for n in ACCURACY_NS:
    cells = [fmt_float(v, '.4f')
             for v in (*acc_stats(indep_rows, n), *acc_stats(same_rows, n))]
    print(f"{n:>12,} | {cells[0]:>16} | {cells[1]:>15} | "
          f"{cells[2]:>15} | {cells[3]:>14}")

print("\n" + "=" * 110)
print("Benchmark complete.")
print(f"CSVs saved to: {RESULTS_DIR}")
for entry in (" Speed: v018_local_speed.csv",
              " Accuracy indep: v018_local_accuracy_indep.csv",
              " Accuracy same: v018_local_accuracy_same.csv"):
    print(entry)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes