diffcb 0.1.9__tar.gz → 0.1.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {diffcb-0.1.9 → diffcb-0.1.10}/PKG-INFO +1 -1
  2. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/__init__.py +1 -1
  3. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/fft_kde.py +14 -2
  4. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/solver.py +25 -8
  5. {diffcb-0.1.9 → diffcb-0.1.10}/pyproject.toml +1 -1
  6. {diffcb-0.1.9 → diffcb-0.1.10}/.gitignore +0 -0
  7. {diffcb-0.1.9 → diffcb-0.1.10}/.zenodo.json +0 -0
  8. {diffcb-0.1.9 → diffcb-0.1.10}/LICENSE +0 -0
  9. {diffcb-0.1.9 → diffcb-0.1.10}/README.md +0 -0
  10. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/diagnostics.py +0 -0
  11. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/kde.py +0 -0
  12. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/layer.py +0 -0
  13. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/training.py +0 -0
  14. {diffcb-0.1.9 → diffcb-0.1.10}/dcb/utils.py +0 -0
  15. {diffcb-0.1.9 → diffcb-0.1.10}/notebooks/.gitkeep +0 -0
  16. {diffcb-0.1.9 → diffcb-0.1.10}/round24_cumulative_bench.py +0 -0
  17. {diffcb-0.1.9 → diffcb-0.1.10}/round24_v016_test.py +0 -0
  18. {diffcb-0.1.9 → diffcb-0.1.10}/round25_full_range_sweep.py +0 -0
  19. {diffcb-0.1.9 → diffcb-0.1.10}/round25_write_csv.py +0 -0
  20. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_gradcheck.py +0 -0
  21. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_kde.py +0 -0
  22. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_layer.py +0 -0
  23. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_r18c_denom_audit.py +0 -0
  24. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_r18c_deprecation_warn.py +0 -0
  25. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_r19_default_fft.py +0 -0
  26. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_r19_diagnostics.py +0 -0
  27. {diffcb-0.1.9 → diffcb-0.1.10}/tests/test_solver.py +0 -0
  28. {diffcb-0.1.9 → diffcb-0.1.10}/v018_local_bench.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffcb
3
- Version: 0.1.9
3
+ Version: 0.1.10
4
4
  Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
5
5
  Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
6
6
  Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
@@ -21,4 +21,4 @@ __all__ = [
21
21
  "TrainingLayer",
22
22
  "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
23
23
  ]
24
- __version__ = "0.1.9"
24
+ __version__ = "0.1.10"
@@ -361,8 +361,20 @@ def _refine_hcrit(
361
361
  y_mid = fprime(h_mid)[j].item()
362
362
  y_hi = fp_hi_[j].item()
363
363
 
364
- # Fit quadratic y = a*h² + b*h + c through the three (h, y) pairs
365
- # and solve for the root in [ref_lo, ref_hi].
364
+ # Square-root model: near a saddle-node bifurcation, f'_peak ∝ √(h_crit − h),
365
+ # so f'²_peak is linear in h. Fit y² = a + b·h through 3 points and solve
366
+ # y²=0 → h_crit = −a/b. More accurate than quadratic for Gaussian/unimodal data.
367
+ y_sq = [y_lo ** 2, y_mid ** 2, y_hi ** 2]
368
+ # Ensure y_lo >= 0 (it should be positive by candidate_mask selection)
369
+ if y_lo > 0:
370
+ sqrt_coeffs = np.polyfit([ref_lo, h_mid, ref_hi], y_sq, 1) # degree-1 fit
371
+ b_sqrt, a_sqrt = sqrt_coeffs[0], sqrt_coeffs[1]
372
+ if b_sqrt != 0:
373
+ sqrt_root = -a_sqrt / b_sqrt
374
+ if ref_lo <= sqrt_root <= ref_hi:
375
+ return float(sqrt_root)
376
+
377
+ # Fallback: original quadratic fit
366
378
  coeffs = np.polyfit([ref_lo, h_mid, ref_hi], [y_lo, y_mid, y_hi], 2)
367
379
  roots = np.roots(coeffs)
368
380
  real_roots = [
@@ -231,7 +231,13 @@ def find_h_crit_hard(
231
231
  lo_domain = X.min().item() - 3 * sigma
232
232
  hi_domain = X.max().item() + 3 * sigma
233
233
  data_range = hi_domain - lo_domain
234
- G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
234
+ # Scale G_min with n^0.4 to keep h_crit/bin_width stable as h_crit shrinks.
235
+ # For Gaussian-regime data h_crit ~ n^{-1/5}, bin_width ~ data_range/G_min,
236
+ # so ratio h_crit/bin_width ~ n^{-1/5} * G_min / data_range. Scaling G_min
237
+ # by n^0.4 keeps this ratio non-decreasing in n without hurting small-n speed.
238
+ n_fft = X.shape[0]
239
+ G_min_eff = min(max(G_min, int(G_min * (n_fft / 100_000) ** 0.4)), 262144)
240
+ G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min_eff)
235
241
  _domain = (lo_domain, hi_domain)
236
242
  pad_factor = 2 # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
237
243
  N = pad_factor * G_fft
@@ -307,10 +313,13 @@ def find_h_crit_hard(
307
313
  # to locate h_crit below the bin-width precision limit.
308
314
  # Worker 23-4: pass C and omega from bisection to avoid duplicate O(n)
309
315
  # histogram + rfft inside _refine_hcrit (saves ~80 ms at n=10M).
316
+ # Fix: use float64 C_mc/omega_mc instead of float32 C/omega — near h_crit
317
+ # the disappearing lobe peak is ~1e-3×max|f′|; float32 noise (~1e-4×max|f′|)
318
+ # gave SNR≈10 → ~0.10% error for Gaussian data. C_mc is already computed.
310
319
  from dcb.fft_kde import _refine_hcrit
311
320
  h_crit = _refine_hcrit(
312
321
  X, lo_val, hi_val, G_fft, _domain, target_modes,
313
- C_external=C, omega_external=omega,
322
+ C_external=C_mc, omega_external=omega_mc,
314
323
  fft_dtype=fft_dtype,
315
324
  )
316
325
 
@@ -327,6 +336,13 @@ def find_h_crit_hard(
327
336
  X, G=G_half, domain=_domain,
328
337
  pad_factor=pad_factor, fft_dtype=fft_dtype,
329
338
  )
339
+ # Mirror of v0.1.9 main-path fix: compute float64 version of
340
+ # C_half so that the Richardson trisection and _refine_hcrit
341
+ # call have float64 SNR, eliminating float32 noise in h_crit_half.
342
+ C_half_mc, omega_half_mc, _ = precompute_fft(
343
+ X, G=G_half, domain=_domain,
344
+ pad_factor=pad_factor, fft_dtype=torch.float64,
345
+ )
330
346
  N_half = pad_factor * G_half
331
347
 
332
348
  # Narrow bracket around the fine-grid h_crit.
@@ -334,8 +350,8 @@ def find_h_crit_hard(
334
350
  h_hi_r = h_crit * 1.05
335
351
 
336
352
  def _bracket_valid(lo_r, hi_r):
337
- c_lo = mode_count_from_C(C_half, omega_half, lo_r, G_half, N_half)
338
- c_hi = mode_count_from_C(C_half, omega_half, hi_r, G_half, N_half)
353
+ c_lo = mode_count_from_C(C_half_mc, omega_half_mc, lo_r, G_half, N_half)
354
+ c_hi = mode_count_from_C(C_half_mc, omega_half_mc, hi_r, G_half, N_half)
339
355
  return (c_lo > target_modes) and (c_hi <= target_modes)
340
356
 
341
357
  valid = _bracket_valid(h_lo_r, h_hi_r)
@@ -346,8 +362,9 @@ def find_h_crit_hard(
346
362
 
347
363
  if valid:
348
364
  # Compile-friendly trisection for Richardson half-grid.
349
- _dtype_r = omega_half.dtype
350
- _dev_r = C_half.device
365
+ # Use float64 C_half_mc/omega_half_mc for SNR.
366
+ _dtype_r = omega_half_mc.dtype
367
+ _dev_r = C_half_mc.device
351
368
  lo_rt = torch.tensor(h_lo_r, dtype=_dtype_r, device=_dev_r)
352
369
  hi_rt = torch.tensor(h_hi_r, dtype=_dtype_r, device=_dev_r)
353
370
  _target_r = torch.tensor(target_modes, dtype=torch.long, device=_dev_r)
@@ -356,7 +373,7 @@ def find_h_crit_hard(
356
373
  h1_r = lo_rt + width_r * (1.0 / 3.0)
357
374
  h2_r = lo_rt + width_r * (2.0 / 3.0)
358
375
  counts_r = mode_count_from_C_batch(
359
- C_half, omega_half,
376
+ C_half_mc, omega_half_mc,
360
377
  torch.stack([h1_r, h2_r]), G_half, N_half,
361
378
  )
362
379
  c1_r = counts_r[0]
@@ -373,7 +390,7 @@ def find_h_crit_hard(
373
390
 
374
391
  h_crit_half = _refine_hcrit(
375
392
  X, lo_r, hi_r, G_half, _domain, target_modes,
376
- C_external=C_half, omega_external=omega_half,
393
+ C_external=C_half_mc, omega_external=omega_half_mc,
377
394
  fft_dtype=fft_dtype,
378
395
  )
379
396
 
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "diffcb"
7
- version = "0.1.9"
7
+ version = "0.1.10"
8
8
  description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes