diffcb 0.1.1__tar.gz → 0.1.3__tar.gz

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffcb
3
- Version: 0.1.1
3
+ Version: 0.1.3
4
4
  Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
5
5
  Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
6
6
  Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
@@ -71,10 +71,10 @@ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribu
71
71
  import torch
72
72
  from dcb import DCBLayer
73
73
 
74
- X = torch.randn(256, requires_grad=True) # 1D samples
74
+ X = torch.randn(1000, requires_grad=True) # 1D samples
75
75
  layer = DCBLayer(target_modes=1)
76
- h_crit = layer(X) # differentiable scalar
77
- h_crit.backward() # exact IFT gradients
76
+ h_crit = layer(X) # differentiable scalar
77
+ h_crit.backward() # exact IFT gradients
78
78
  ```
79
79
 
80
80
  ## Installation
@@ -91,34 +91,72 @@ cd differentiable-critical-bandwidth
91
91
  pip install -e ".[dev]"
92
92
  ```
93
93
 
94
- ## Paper
94
+ ## Accuracy vs R's `bw.crit`
95
95
 
96
- > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
96
+ DCB is validated against R's `multimode::bw.crit(data, mod0=1)`, the standard reference implementation of Hall & York (2001). On **identical data**:
97
+
98
+ | n | DCB vs R (same sample) | DCB vs R (independent samples) |
99
+ |---|---|---|
100
+ | 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
101
+ | 1M | **0.005%** | ~0.2% |
102
+ | 10M | **0.004%** | ~0.1% |
103
+
104
+ The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at all tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s while R OOMs.
105
+
106
+ ## Key Parameters
107
+
108
+ ```python
109
+ DCBLayer(
110
+ target_modes=1, # target number of modes
111
+ G=512, # IFT evaluation grid points
112
+ use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
113
+ max_n_exact=1_000_000, # sketch to sketch_size when n exceeds this (None = always exact)
114
+ sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
115
+ safe_backward=False, # clamp IFT denominator near bifurcations
116
+ )
117
+ ```
97
118
 
98
119
  ## Confirmed Experimental Results
99
120
 
100
- All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
121
+ All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
101
122
 
102
123
  | Experiment | Result | Criterion |
103
124
  |---|---|---|
104
- | **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
105
- | **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
125
+ | **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
126
+ | **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
127
+ | **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
106
128
  | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
107
129
  | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
108
130
 
131
+ ## Changelog
132
+
133
+ ### v0.1.1 (2026-05-29)
134
+ - **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
135
+ - **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n⁻²/⁹) convergence of h_crit; 500K sketch matches full-n accuracy.
136
+ - **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
137
+ - **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
138
+ - **Test fixes:** Updated 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
139
+
140
+ ### v0.1.0 (2026-05-28)
141
+ - Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
142
+
109
143
  ## Repository Structure
110
144
 
111
145
  ```
112
- dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
146
+ dcb/ Core PyTorch package
147
+ layer.py DCBLayer nn.Module + DCBFunction autograd
148
+ solver.py IFT root-finder and backward pass
149
+ fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
150
+ kde.py Direct KDE derivatives (small-n path)
151
+ utils.py Grid, Silverman bandwidth, sg() stabilizer
113
152
  experiments/ Reproduction scripts for all paper figures and tables
114
- phase1_validation.py Figure 1: DCB vs reference h_crit scatter
115
- phase1_speedup.py Figure 2: GPU speedup benchmark
116
- phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
117
- phase2_gan.py Figure 3: GAN mode-collapse prevention
118
- phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
119
- tests/ Unit tests (pytest, 35/35 passing)
153
+ phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
154
+ phase2_gan.py GAN mode-collapse prevention (Figure 3)
155
+ phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
156
+ round20_*.py Large-n R comparison and streaming benchmarks
157
+ round21_*.py Accuracy improvement experiments
158
+ tests/ Unit tests (pytest, 45 passed, 1 xfailed)
120
159
  outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
121
- notebooks/ Quickstart and demo notebooks
122
160
  ```
123
161
 
124
162
  ## Reproducing Paper Results
@@ -127,7 +165,6 @@ notebooks/ Quickstart and demo notebooks
127
165
  # Phase 1: validation, speedup, ablation
128
166
  python experiments/phase1_validation.py
129
167
  python experiments/phase1_speedup.py
130
- python experiments/phase1_ablation.py
131
168
 
132
169
  # Phase 2: GAN mode collapse experiment
133
170
  python experiments/phase2_gan.py
@@ -136,13 +173,13 @@ python experiments/phase2_gan.py
136
173
  python experiments/phase3_anomaly.py
137
174
  ```
138
175
 
139
- For GPU runs, use the provided Kaggle kernels:
176
+ For GPU runs, use the Kaggle kernels:
140
177
  - Phase 1–2: `hsingle/dcb-full-experiments`
141
178
  - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
142
179
 
143
- ## Kaggle GPU Notes
180
+ ## Paper
144
181
 
145
- Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
182
+ > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
146
183
 
147
184
  ## License
148
185
 
diffcb-0.1.3/README.md ADDED
@@ -0,0 +1,129 @@
1
+ # DCB — Differentiable Critical Bandwidth
2
+
3
+ [![PyPI](https://img.shields.io/pypi/v/diffcb.svg)](https://pypi.org/project/diffcb/)
4
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
5
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/)
6
+
7
+ A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
8
+
9
+ ## Overview
10
+
11
+ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
12
+
13
+ ```python
14
+ import torch
15
+ from dcb import DCBLayer
16
+
17
+ X = torch.randn(1000, requires_grad=True) # 1D samples
18
+ layer = DCBLayer(target_modes=1)
19
+ h_crit = layer(X) # differentiable scalar
20
+ h_crit.backward() # exact IFT gradients
21
+ ```
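The Implicit Function Theorem step named in the overview can be written out in its generic textbook form. Here $g$ is an assumed symbol (not a name taken from the package) for whatever smooth residual the solver drives to zero at $h_{\text{crit}}$:

$$
g\bigl(h_{\text{crit}}(X),\, X\bigr) = 0
\quad\Longrightarrow\quad
\frac{\partial h_{\text{crit}}}{\partial X_i}
= -\,\left.\frac{\partial g / \partial X_i}{\partial g / \partial h}\right|_{h = h_{\text{crit}}}
$$

Only the converged root and the two partial derivatives enter this expression, so no solver iterations need to be stored for the backward pass. The denominator $\partial g / \partial h$ becomes small near a bifurcation, which is presumably the situation the `safe_backward` option (described in this README as clamping the IFT denominator) is meant to handle.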
22
+
23
+ ## Installation
24
+
25
+ ```bash
26
+ pip install diffcb
27
+ ```
28
+
29
+ Or from source:
30
+
31
+ ```bash
32
+ git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
33
+ cd differentiable-critical-bandwidth
34
+ pip install -e ".[dev]"
35
+ ```
36
+
37
+ ## Accuracy vs R's `bw.crit`
38
+
39
+ DCB is validated against R's `multimode::bw.crit(data, mod0=1)` — the standard reference implementation of Hall & York (2001). On **identical data**:
40
+
41
+ | n | DCB vs R (same sample) | DCB vs R (independent samples) |
42
+ |---|---|---|
43
+ | 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
44
+ | 1M | **0.005%** | ~0.2% |
45
+ | 10M | **0.004%** | ~0.1% |
46
+
47
+ The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at all tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s while R OOMs.
48
+
49
+ ## Key Parameters
50
+
51
+ ```python
52
+ DCBLayer(
53
+ target_modes=1, # target number of modes
54
+ G=512, # IFT evaluation grid points
55
+ use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
56
+ max_n_exact=1_000_000, # sketch to sketch_size when n exceeds this (None = always exact)
57
+ sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
58
+ safe_backward=False, # clamp IFT denominator near bifurcations
59
+ )
60
+ ```
61
+
62
+ ## Confirmed Experimental Results
63
+
64
+ All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
65
+
66
+ | Experiment | Result | Criterion |
67
+ |---|---|---|
68
+ | **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
69
+ | **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
70
+ | **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
71
+ | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
72
+ | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
73
+
74
+ ## Changelog
75
+
76
+ ### v0.1.1 (2026-05-29)
77
+ - **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
78
+ - **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n⁻²/⁹) convergence of h_crit; 500K sketch matches full-n accuracy.
79
+ - **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
80
+ - **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
81
+ - **Test fixes:** Updated 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
82
+
83
+ ### v0.1.0 (2026-05-28)
84
+ - Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
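The binning strategy described in the v0.1.1 MPS fix (search the bin edges, then count indices) can be illustrated with a NumPy analogue. This is a sketch only: the package itself uses `torch.bucketize` and `torch.bincount`, and the helper name `bin_counts` is hypothetical.

```python
import numpy as np


def bin_counts(x, lo, hi, G):
    """Histogram x into G equal bins on [lo, hi] using edge search + bincount."""
    edges = np.linspace(lo, hi, G + 1)
    # side="right" gives index i with edges[i-1] <= v < edges[i];
    # clipping to [1, G] folds v == hi into the last bin, then shift to 0-based.
    idx = np.clip(np.searchsorted(edges, x, side="right"), 1, G) - 1
    return np.bincount(idx, minlength=G)


x = np.array([0.0, 0.5, 1.0, 3.9, 4.0])
print(bin_counts(x, 0.0, 4.0, 4))  # [2 1 0 2]
```

The only intermediate is one integer index per sample, so memory grows with n rather than with n × bins.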
85
+
86
+ ## Repository Structure
87
+
88
+ ```
89
+ dcb/ Core PyTorch package
90
+ layer.py DCBLayer nn.Module + DCBFunction autograd
91
+ solver.py IFT root-finder and backward pass
92
+ fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
93
+ kde.py Direct KDE derivatives (small-n path)
94
+ utils.py Grid, Silverman bandwidth, sg() stabilizer
95
+ experiments/ Reproduction scripts for all paper figures and tables
96
+ phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
97
+ phase2_gan.py GAN mode-collapse prevention (Figure 3)
98
+ phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
99
+ round20_*.py Large-n R comparison and streaming benchmarks
100
+ round21_*.py Accuracy improvement experiments
101
+ tests/ Unit tests (pytest, 45 passed, 1 xfailed)
102
+ outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
103
+ ```
104
+
105
+ ## Reproducing Paper Results
106
+
107
+ ```bash
108
+ # Phase 1: validation, speedup, ablation
109
+ python experiments/phase1_validation.py
110
+ python experiments/phase1_speedup.py
111
+
112
+ # Phase 2: GAN mode collapse experiment
113
+ python experiments/phase2_gan.py
114
+
115
+ # Phase 3: anomaly detection benchmark
116
+ python experiments/phase3_anomaly.py
117
+ ```
118
+
119
+ For GPU runs, use the Kaggle kernels:
120
+ - Phase 1–2: `hsingle/dcb-full-experiments`
121
+ - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
122
+
123
+ ## Paper
124
+
125
+ > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
126
+
127
+ ## License
128
+
129
+ MIT — see [LICENSE](LICENSE).
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "DCBLayer", "DifferentiableCriticalBandwidth",
20
20
  "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
21
21
  ]
22
- __version__ = "0.1.1"
22
+ __version__ = "0.1.3"
@@ -0,0 +1,262 @@
1
+ """
2
+ dcb.fft_kde — FFT-based KDE Mode Counter
3
+
4
+ Implements mode counting via FFT convolution of the histogram with a
5
+ Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
6
+ O(n × G) cost of the direct KDE approach and — crucially — requiring NO
7
+ subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
8
+ that affects the standard bisection path when n > brentq_n_max.
9
+
10
+ Round 18b: forward kernel only. The IFT backward is unchanged (still uses
11
+ the analytical chunked KDE derivatives on all n points).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import math
17
+
18
+ import torch
19
+ from torch import Tensor
20
+
21
+
22
+ def fft_mode_count(
23
+ X: Tensor,
24
+ h: float,
25
+ G: int = 4096,
26
+ pad_factor: int = 4,
27
+ domain: tuple[float, float] | None = None,
28
+ ) -> int:
29
+ """Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
30
+
31
+ Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
32
+ the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
33
+ back-transforms, and counts positive-to-negative sign changes of the
34
+ resulting f' estimate.
35
+
36
+ Parameters
37
+ ----------
38
+ X : Tensor, shape (n,)
39
+ 1D data tensor (may be on CPU or CUDA).
40
+ h : float
41
+ Bandwidth for the Gaussian kernel.
42
+ G : int
43
+ Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
44
+ reliable derivative estimation. Use `adaptive_fft_G` to choose G
45
+ automatically before bisection.
46
+ pad_factor : int
47
+ Zero-padding multiplier (default 4). Mandatory ≥ 2 for circular-wrap
48
+ correctness; 4 is recommended at the largest h encountered.
49
+ domain : (lo, hi) or None
50
+ If provided, use this as the histogram domain instead of computing
51
+ X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
52
+ with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
53
+ so every fft_mode_count call in a bisection loop uses an identical grid.
54
+
55
+ Returns
56
+ -------
57
+ int
58
+ Number of KDE modes (downward zero-crossings of f').
59
+ """
60
+ with torch.no_grad():
61
+ if domain is not None:
62
+ lo, hi = domain
63
+ else:
64
+ # Domain: extend 3σ beyond data range to avoid boundary effects
65
+ sigma = X.std().item()
66
+ if sigma == 0.0:
67
+ sigma = 1.0 # degenerate case: all points identical
68
+ lo = X.min().item() - 3 * sigma
69
+ hi = X.max().item() + 3 * sigma
70
+ data_range = hi - lo
71
+
72
+ if data_range == 0.0:
73
+ return 1 # single-point distribution has 1 mode
74
+
75
+ # Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
76
+ # torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
77
+ # MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
78
+ # the binning step avoids the intermediate and is numerically identical
79
+ # for data within [lo, hi] (guaranteed by the 3σ domain extension above).
80
+ X_cpu = X.float().cpu()
81
+ edges = torch.linspace(lo, hi, G + 1) # (G+1,) CPU
82
+ bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1 # 0-indexed
83
+ counts = torch.bincount(bin_idx, minlength=G).float().to(X.device) # back to device
84
+
85
+ # Zero-pad to pad_factor*G — promote to float64 for FFT precision
86
+ N = pad_factor * G
87
+ counts_padded = torch.zeros(N, dtype=torch.float64, device=X.device)
88
+ counts_padded[:G] = counts.double()
89
+
90
+ # FFT of histogram (float64)
91
+ C = torch.fft.rfft(counts_padded)
92
+
93
+ # Derivative kernel in frequency domain (float64)
94
+ bin_width = data_range / G
95
+ k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float64)
96
+ omega = 2 * math.pi * k / (N * bin_width)
97
+ K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
98
+
99
+ # Convolve and back-transform; cast result back to float32
100
+ f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).float()
101
+
102
+ # Trim to original G grid (discard zero-padded tail)
103
+ f_prime = f_prime_padded[:G]
104
+
105
+ # Count (+→-) sign changes = number of modes
106
+ # A mode is a local max of f, i.e., f' crosses zero from + to -
107
+ # Remove zeros (flat segments) — carry forward last nonzero sign
108
+ nonzero_mask = f_prime != 0
109
+ if not nonzero_mask.any():
110
+ return 0
111
+
112
+ s = f_prime[nonzero_mask]
113
+ transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
114
+ return transitions
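The pipeline in `fft_mode_count` (histogram, zero-pad, multiply by iω·exp(−½(ωh)²) in the frequency domain, back-transform, count positive-to-negative sign changes) can be tried in isolation with a small NumPy sketch. Several things here are assumptions of the sketch rather than the package's code: it uses NumPy instead of torch, the name `fft_mode_count_sketch` is hypothetical, and it ignores values below a relative tolerance `rel_tol` so that floating-point noise in the empty tails cannot register as sign changes. The input is a deterministic two-component sample built from normal quantiles, to keep the result reproducible.

```python
import math
from statistics import NormalDist

import numpy as np


def fft_mode_count_sketch(x, h, G=1024, pad_factor=4, rel_tol=1e-9):
    """Count KDE modes as +/- sign changes of an FFT-smoothed derivative estimate."""
    x = np.asarray(x, dtype=np.float64)
    sigma = x.std() or 1.0                      # degenerate data: fall back to 1.0
    lo, hi = x.min() - 3 * sigma, x.max() + 3 * sigma
    counts, _ = np.histogram(x, bins=G, range=(lo, hi))

    N = pad_factor * G                          # zero-pad against circular wrap-around
    padded = np.zeros(N)
    padded[:G] = counts
    C = np.fft.rfft(padded)

    bin_width = (hi - lo) / G
    omega = 2 * math.pi * np.arange(N // 2 + 1) / (N * bin_width)
    K_deriv = 1j * omega * np.exp(-0.5 * (omega * h) ** 2)
    f_prime = np.fft.irfft(C * K_deriv, n=N)[:G]

    # Sketch-only assumption: drop numerically negligible values before counting.
    keep = np.abs(f_prime) > rel_tol * np.abs(f_prime).max()
    s = f_prime[keep]
    return int(np.sum((s[:-1] > 0) & (s[1:] < 0)))


# Deterministic sample: normal quantiles placed around -2 and +2 with scale 0.5.
nd = NormalDist()
q = np.array([nd.inv_cdf((i + 0.5) / 2000) for i in range(2000)])
x = np.concatenate([-2.0 + 0.5 * q, 2.0 + 0.5 * q])

print(fft_mode_count_sketch(x, h=0.3))  # small bandwidth keeps both modes
print(fft_mode_count_sketch(x, h=3.0))  # large bandwidth merges them
```

The critical bandwidth for one mode lies between these two values of `h`, which is the bracket a bisection on the mode count would narrow.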
115
+
116
+
117
+ def _refine_hcrit(
118
+ X: Tensor,
119
+ h_lo: float,
120
+ h_hi: float,
121
+ G: int,
122
+ domain: tuple[float, float],
123
+ target_modes: int = 1,
124
+ pad_factor: int = 4,
125
+ ) -> float:
126
+ """Sub-bin quadratic refinement of h_crit after bisection converges.
127
+
128
+ Identifies the f′ lobe that disappears at the mode-merging bandwidth and
129
+ fits a quadratic in h to that lobe's peak value, returning the root — the
130
+ h where that peak exactly reaches zero. Reduces the bin-width-limited
131
+ systematic from ~bin_width/h_crit to well below 1e-4.
132
+
133
+ When the incoming bracket [h_lo, h_hi] is tighter than one histogram bin
134
+ width (the common case after 50-step bisection), the function expands the
135
+ bracket outward from h_hi by up to 4× the bin width while maintaining the
136
+ invariant that fft_mode_count > target at the left endpoint and
137
+ <= target at the right endpoint, so the disappearing f′ lobe is visible
138
+ across the bracket.
139
+
140
+ Parameters
141
+ ----------
142
+ X : Tensor — data (may be on any device)
143
+ h_lo, h_hi : float — final bisection bracket; fft_mode_count(X,h_lo) > target,
144
+ fft_mode_count(X,h_hi) <= target
145
+ G, domain, target_modes, pad_factor — same as fft_mode_count
146
+
147
+ Returns
148
+ -------
149
+ float — refined h_crit, guaranteed to lie in [h_lo, h_hi] of the
150
+ (possibly expanded) bracket used for fitting.
151
+ """
152
+ import numpy as np
153
+
154
+ lo_d, hi_d = domain
155
+ data_range = hi_d - lo_d
156
+ if data_range == 0.0:
157
+ return h_hi
158
+
159
+ bin_width = data_range / G
160
+ N = pad_factor * G
161
+ bw = bin_width # histogram bin width
162
+
163
+ # Pre-compute histogram once; reuse C (FFT of counts) for all h evaluations.
164
+ with torch.no_grad():
165
+ X_cpu = X.float().cpu()
166
+ edges = torch.linspace(lo_d, hi_d, G + 1)
167
+ bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
168
+ counts = torch.bincount(bin_idx, minlength=G).float()
169
+ counts_padded = torch.zeros(N, dtype=torch.float64)
170
+ counts_padded[:G] = counts.double()
171
+ C = torch.fft.rfft(counts_padded)
172
+ k = torch.arange(N // 2 + 1, dtype=torch.float64)
173
+ omega_base = 2 * math.pi * k / (N * bw)
174
+
175
+ def fprime(h: float) -> Tensor:
176
+ """Compute f′ array (shape G,) for bandwidth h using cached C (float64)."""
177
+ K_deriv = 1j * omega_base * torch.exp(-0.5 * (omega_base * h) ** 2)
178
+ return torch.fft.irfft(C * K_deriv, n=N).float()[:G]
179
+
180
+ with torch.no_grad():
181
+ # If the bracket is tighter than bin_width, expand it so that the
182
+ # disappearing f′ lobe crosses zero somewhere inside the bracket.
183
+ # Expand the left endpoint leftward by up to 4 bin widths.
184
+ ref_lo = h_lo
185
+ ref_hi = h_hi
186
+
187
+ if (ref_hi - ref_lo) < bw:
188
+ # Try expanding leftward until we find a bin where fp crosses zero
189
+ for mult in [1, 2, 3, 4]:
190
+ cand_lo = max(ref_hi - mult * bw, ref_hi * 0.9)
191
+ fp_cand = fprime(cand_lo)
192
+ fp_hi_ = fprime(ref_hi)
193
+ cm = (fp_cand > 0) & (fp_hi_ <= 0)
194
+ if cm.any():
195
+ ref_lo = cand_lo
196
+ break
197
+ # If still no candidates found, return bisection result unchanged
198
+ fp_lo_ = fprime(ref_lo)
199
+ fp_hi_ = fprime(ref_hi)
200
+ candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
201
+ if not candidate_mask.any():
202
+ return h_hi
203
+ else:
204
+ fp_lo_ = fprime(ref_lo)
205
+ fp_hi_ = fprime(ref_hi)
206
+ candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
207
+ if not candidate_mask.any():
208
+ return h_hi
209
+
210
+ # Pick the bin with the largest positive value at ref_lo that crossed zero
211
+ masked_fp_lo = fp_lo_.clone()
212
+ masked_fp_lo[~candidate_mask] = -float('inf')
213
+ j = int(masked_fp_lo.argmax().item())
214
+
215
+ h_mid = (ref_lo + ref_hi) / 2.0
216
+
217
+ # Evaluate fp[j] at three bandwidths for quadratic fit
218
+ y_lo = fp_lo_[j].item()
219
+ y_mid = fprime(h_mid)[j].item()
220
+ y_hi = fp_hi_[j].item()
221
+
222
+ # Fit quadratic y = a*h² + b*h + c through the three (h, y) pairs
223
+ # and solve for the root in [ref_lo, ref_hi].
224
+ coeffs = np.polyfit([ref_lo, h_mid, ref_hi], [y_lo, y_mid, y_hi], 2)
225
+ roots = np.roots(coeffs)
226
+ real_roots = [
227
+ r.real for r in roots
228
+ if abs(r.imag) < 1e-10 * abs(r.real + 1e-30)
229
+ and ref_lo <= r.real <= ref_hi
230
+ ]
231
+ if real_roots:
232
+ return float(min(real_roots, key=lambda r: abs(r - h_mid)))
233
+ return h_hi
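The last step of `_refine_hcrit`, fitting a quadratic through three (h, y) samples and taking the root inside the bracket, can be shown on its own. The sketch below assumes a callable `y_of_h` standing in for "peak value of the disappearing f′ lobe at bandwidth h"; that callable, the helper name `quadratic_root_in_bracket`, and the toy function are all hypothetical.

```python
import numpy as np


def quadratic_root_in_bracket(h_lo, h_hi, y_of_h):
    """Fit y = a*h^2 + b*h + c through three samples; return its root in [h_lo, h_hi].

    Falls back to h_hi (the unrefined bisection answer) when no real root
    lies inside the bracket.
    """
    h_mid = 0.5 * (h_lo + h_hi)
    hs = [h_lo, h_mid, h_hi]
    coeffs = np.polyfit(hs, [y_of_h(h) for h in hs], 2)
    roots = np.roots(coeffs)
    real_in_bracket = [
        float(r.real) for r in roots
        if abs(r.imag) < 1e-10 and h_lo <= r.real <= h_hi
    ]
    if real_in_bracket:
        # If both roots fall inside, prefer the one nearest the bracket midpoint.
        return min(real_in_bracket, key=lambda r: abs(r - h_mid))
    return h_hi


# Toy stand-in: a lobe peak that decays through zero at h = 0.5.
refined = quadratic_root_in_bracket(0.4, 0.6, lambda h: 1.0 - 4.0 * h * h)
print(refined)
```

Returning `h_hi` on failure keeps the refinement from ever producing a value worse than the bisection bracket it started from.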
234
+
235
+
236
+ def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 16384) -> int:
237
+ """Choose FFT grid size G so that the derivative kernel is well-resolved.
238
+
239
+ Requires h > 8 * bin_width = 8 * data_range / G, equivalently
240
+ G > 8 * data_range / h_hi. We use a factor of 16 for safety margin,
241
+ then round up to the next power of 2 for efficient FFT.
242
+
243
+ Parameters
244
+ ----------
245
+ data_range : float
246
+ hi - lo of the data domain (typically X.max() - X.min() + 6σ).
247
+ h_hi : float
248
+ Upper bracket of the bisection (smallest h needing resolution).
249
+ G_min : int
250
+ Minimum returned G (default 16384).
251
+
252
+ Returns
253
+ -------
254
+ int
255
+ Grid size G, a power of 2, at least G_min.
256
+ """
257
+ needed = 16 * math.ceil(data_range / h_hi)
258
+ # Round up to next power of 2
259
+ p = 1
260
+ while p < needed:
261
+ p <<= 1
262
+ return max(G_min, p)
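To see how the grid-size rule behaves, the same arithmetic can be run standalone. The function below is a copy of the logic for illustration (named `adaptive_fft_G_sketch` to mark it as such), and the input values are made-up examples.

```python
import math


def adaptive_fft_G_sketch(data_range: float, h_hi: float, G_min: int = 16384) -> int:
    """16 bins per bandwidth, rounded up to a power of 2, floored at G_min."""
    needed = 16 * math.ceil(data_range / h_hi)
    p = 1
    while p < needed:
        p <<= 1
    return max(G_min, p)


# Wide bandwidth: only 384 bins are needed, so the floor G_min decides.
print(adaptive_fft_G_sketch(12.0, 0.5))      # 16384

# Narrow bandwidth: 192000 bins are needed, rounded up to 2**18.
G = adaptive_fft_G_sketch(12.0, 0.001)
print(G, 0.001 > 8 * 12.0 / G)               # 262144 True
```

The second print also checks the stated resolution condition h > 8 × bin_width for the chosen G.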
@@ -35,13 +35,13 @@ class DCBFunction(torch.autograd.Function):
35
35
 
36
36
  @staticmethod
37
37
  def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
38
- brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft):
38
+ brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min):
39
39
  """Locate h_crit and save state for the backward pass."""
40
40
  h_crit, cond_num = find_h_crit(
41
41
  X, grid, eps, tau, target_modes,
42
42
  formula=formula, brentq_n_max=brentq_n_max, chunk_size=chunk_size,
43
43
  g_brentq=g_brentq, use_hard_bisection=use_hard_bisection,
44
- use_fft=use_fft,
44
+ use_fft=use_fft, G_min=fft_G_min,
45
45
  )
46
46
  ctx.save_for_backward(X, grid)
47
47
  ctx.h_crit = h_crit
@@ -67,8 +67,8 @@ class DCBFunction(torch.autograd.Function):
67
67
  ctx.denom_abs = ift_gradient.last_denom_abs
68
68
  # Gradients for: X, grid, eps, tau, target_modes, delta, formula,
69
69
  # chunk_size, brentq_n_max, g_brentq, use_hard_bisection,
70
- # safe_backward, use_fft
71
- return grad_X, None, None, None, None, None, None, None, None, None, None, None, None
70
+ # safe_backward, use_fft, fft_G_min
71
+ return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None
72
72
 
73
73
 
74
74
  class DCBLayer(nn.Module):
@@ -133,6 +133,13 @@ class DCBLayer(nn.Module):
133
133
  Number of points to sketch when n > max_n_exact. Default 500_000.
134
134
  A 500K sketch achieves the same mean accuracy as streaming 100M points
135
135
  (validated in Round 20 reservoir experiment).
136
+ fft_G_min : int
137
+ Minimum FFT histogram grid size for the bisection solver (default 16384).
138
+ Controls accuracy of the FFT path (n > 50K). Larger values reduce
139
+ discretisation error at a modest cost: G=16384 gives ~0.004% err vs R;
140
+ G=32768 gives ~0.001% at +9% cost; G=65536 reaches the R-matching floor
141
+ (~0.001%) with no further gain beyond that. Ignored for n ≤ 50K (direct
142
+ KDE path).
136
143
 
137
144
  Examples
138
145
  --------
@@ -162,6 +169,7 @@ class DCBLayer(nn.Module):
162
169
  use_fft: bool = True,
163
170
  max_n_exact: int | None = 1_000_000,
164
171
  sketch_size: int = 500_000,
172
+ fft_G_min: int = 16384,
165
173
  ):
166
174
  super().__init__()
167
175
  self.target_modes = target_modes
@@ -180,6 +188,7 @@ class DCBLayer(nn.Module):
180
188
  self.use_fft = use_fft
181
189
  self.max_n_exact = max_n_exact
182
190
  self.sketch_size = sketch_size
191
+ self.fft_G_min = fft_G_min
183
192
  if use_fft and brentq_n_max != 50_000:
184
193
  raise TypeError(
185
194
  f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
@@ -250,7 +259,7 @@ class DCBLayer(nn.Module):
250
259
  return DCBFunction.apply(
251
260
  X, grid, eps_eff, tau_eff, self.target_modes, self.delta, self.formula,
252
261
  self.chunk_size, self.brentq_n_max, self.g_brentq, self.use_hard_bisection,
253
- self.safe_backward, self.use_fft,
262
+ self.safe_backward, self.use_fft, self.fft_G_min,
254
263
  )
255
264
 
256
265
 
@@ -74,6 +74,7 @@ def find_h_crit_hard(
74
74
  eps: float = 0.1,
75
75
  tau: float = 0.2,
76
76
  use_fft: bool = False,
77
+ G_min: int = 16384,
77
78
  ) -> tuple[float, float]:
78
79
  """Find h_crit via hard-mode-count bisection (monotone, no false roots).
79
80
 
@@ -151,7 +152,7 @@ def find_h_crit_hard(
151
152
  lo_domain = X.min().item() - 3 * sigma
152
153
  hi_domain = X.max().item() + 3 * sigma
153
154
  data_range = hi_domain - lo_domain
154
- G_fft = adaptive_fft_G(data_range, h_hi)
155
+ G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
155
156
  _domain = (lo_domain, hi_domain)
156
157
 
157
158
  with torch.no_grad():
@@ -188,6 +189,11 @@ def find_h_crit_hard(
188
189
 
189
190
  h_crit = float(hi) # smallest h with count <= target_modes
190
191
 
192
+ # Sub-bin refinement: quadratic interpolation on the disappearing f′ lobe
193
+ # to locate h_crit below the bin-width precision limit.
194
+ from dcb.fft_kde import _refine_hcrit
195
+ h_crit = _refine_hcrit(X, lo, hi, G_fft, _domain, target_modes)
196
+
191
197
  else:
192
198
  with torch.no_grad():
193
199
  # Verify bracket: need count > target at h_lo, count <= target at h_hi.
@@ -290,6 +296,7 @@ def find_h_crit(
290
296
  g_brentq: int = 128,
291
297
  use_hard_bisection: bool = True,
292
298
  use_fft: bool = True,
299
+ G_min: int = 16384,
293
300
  ) -> tuple[float, float]:
294
301
  """Find h_crit and return (h_crit, condition_number).
295
302
 
@@ -343,7 +350,7 @@ def find_h_crit(
343
350
  return find_h_crit_hard(
344
351
  X, grid, target_modes, chunk_size, brentq_n_max,
345
352
  h_lo, h_hi, formula=formula, eps=eps, tau=tau,
346
- use_fft=use_fft,
353
+ use_fft=use_fft, G_min=G_min,
347
354
  )
348
355
 
349
356
  from scipy.optimize import brentq
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "diffcb"
7
- version = "0.1.1"
7
+ version = "0.1.3"
8
8
  description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
diffcb-0.1.1/README.md DELETED
@@ -1,92 +0,0 @@
1
- # DCB — Differentiable Critical Bandwidth
2
-
3
- [![PyPI](https://img.shields.io/pypi/v/diffcb.svg)](https://pypi.org/project/diffcb/)
4
- [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
5
- [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/)
6
-
7
- A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
8
-
9
- ## Overview
10
-
11
- The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
12
-
13
- ```python
14
- import torch
15
- from dcb import DCBLayer
16
-
17
- X = torch.randn(256, requires_grad=True) # 1D samples
18
- layer = DCBLayer(target_modes=1)
19
- h_crit = layer(X) # differentiable scalar
20
- h_crit.backward() # exact IFT gradients
21
- ```
22
-
23
- ## Installation
24
-
25
- ```bash
26
- pip install diffcb
27
- ```
28
-
29
- Or from source:
30
-
31
- ```bash
32
- git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
33
- cd differentiable-critical-bandwidth
34
- pip install -e ".[dev]"
35
- ```
36
-
37
- ## Paper
38
-
39
- > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
40
-
41
- ## Confirmed Experimental Results
42
-
43
- All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
44
-
45
- | Experiment | Result | Criterion |
46
- |---|---|---|
47
- | **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
48
- | **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
49
- | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
50
- | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
51
-
52
- ## Repository Structure
53
-
54
- ```
55
- dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
56
- experiments/ Reproduction scripts for all paper figures and tables
57
- phase1_validation.py Figure 1: DCB vs reference h_crit scatter
58
- phase1_speedup.py Figure 2: GPU speedup benchmark
59
- phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
60
- phase2_gan.py Figure 3: GAN mode-collapse prevention
61
- phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
62
- tests/ Unit tests (pytest, 35/35 passing)
63
- outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
64
- notebooks/ Quickstart and demo notebooks
65
- ```
66
-
67
- ## Reproducing Paper Results
68
-
69
- ```bash
70
- # Phase 1: validation, speedup, ablation
71
- python experiments/phase1_validation.py
72
- python experiments/phase1_speedup.py
73
- python experiments/phase1_ablation.py
74
-
75
- # Phase 2: GAN mode collapse experiment
76
- python experiments/phase2_gan.py
77
-
78
- # Phase 3: anomaly detection benchmark
79
- python experiments/phase3_anomaly.py
80
- ```
81
-
82
- For GPU runs, use the provided Kaggle kernels:
83
- - Phase 1–2: `hsingle/dcb-full-experiments`
84
- - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
85
-
86
- ## Kaggle GPU Notes
87
-
88
- Kaggle may assign a P100 (sm_60) instead of a T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when a P100 is detected.
89
-
90
- ## License
91
-
92
- MIT — see [LICENSE](LICENSE).
@@ -1,144 +0,0 @@
1
- """
2
- dcb.fft_kde — FFT-based KDE Mode Counter
3
-
4
- Implements mode counting via FFT convolution of the histogram with a
5
- Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
6
- O(n × G) cost of the direct KDE approach and — crucially — requiring NO
7
- subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
8
- that affects the standard bisection path when n > brentq_n_max.
9
-
10
- Round 18b: forward kernel only. The IFT backward is unchanged (still uses
11
- the analytical chunked KDE derivatives on all n points).
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import math
17
-
18
- import torch
19
- from torch import Tensor
20
-
21
-
22
- def fft_mode_count(
23
-     X: Tensor,
24
-     h: float,
25
-     G: int = 4096,
26
-     pad_factor: int = 4,
27
-     domain: tuple[float, float] | None = None,
28
- ) -> int:
29
-     """Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
30
-
31
-     Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
32
-     the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
33
-     back-transforms, and counts positive-to-negative sign changes of the
34
-     resulting f' estimate.
35
-
36
-     Parameters
37
-     ----------
38
-     X : Tensor, shape (n,)
39
-         1D data tensor (may be on CPU or CUDA).
40
-     h : float
41
-         Bandwidth for the Gaussian kernel.
42
-     G : int
43
-         Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
44
-         reliable derivative estimation. Use `adaptive_fft_G` to choose G
45
-         automatically before bisection.
46
-     pad_factor : int
47
-         Zero-padding multiplier (default 4). Mandatory ≥ 2 for circular-wrap
48
-         correctness; 4 is recommended at the largest h encountered.
49
-     domain : (lo, hi) or None
50
-         If provided, use this as the histogram domain instead of computing
51
-         X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
52
-         with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
53
-         so every fft_mode_count call in a bisection loop uses an identical grid.
54
-
55
-     Returns
56
-     -------
57
-     int
58
-         Number of KDE modes (downward zero-crossings of f').
59
-     """
60
-     with torch.no_grad():
61
-         if domain is not None:
62
-             lo, hi = domain
63
-         else:
64
-             # Domain: extend 3σ beyond data range to avoid boundary effects
65
-             sigma = X.std().item()
66
-             if sigma == 0.0:
67
-                 sigma = 1.0  # degenerate case: all points identical
68
-             lo = X.min().item() - 3 * sigma
69
-             hi = X.max().item() + 3 * sigma
70
-         data_range = hi - lo
71
-
72
-         if data_range == 0.0:
73
-             return 1  # single-point distribution has 1 mode
74
-
75
-         # Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
76
-         # torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
77
-         # MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
78
-         # the binning step avoids the intermediate and is numerically identical
79
-         # for data within [lo, hi] (guaranteed by the 3σ domain extension above).
80
-         X_cpu = X.float().cpu()
81
-         edges = torch.linspace(lo, hi, G + 1)  # (G+1,) CPU
82
-         bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1  # 0-indexed
83
-         counts = torch.bincount(bin_idx, minlength=G).float().to(X.device)  # back to device
84
-
85
-         # Zero-pad to pad_factor*G (4× mandatory for circular wrap correctness at h_hi)
86
-         N = pad_factor * G
87
-         counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
88
-         counts_padded[:G] = counts
89
-
90
-         # FFT of histogram
91
-         C = torch.fft.rfft(counts_padded)
92
-
93
-         # Derivative kernel in frequency domain: iω * exp(-0.5*(ω*h)²)
94
-         # ω_k = 2π*k / (N * bin_width), bin_width = data_range / G
95
-         bin_width = data_range / G
96
-         k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
97
-         omega = 2 * math.pi * k / (N * bin_width)
98
-         K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
99
-
100
-         # Convolve and back-transform
101
-         f_prime_padded = torch.fft.irfft(C * K_deriv, n=N)
102
-
103
-         # Trim to original G grid (discard zero-padded tail)
104
-         f_prime = f_prime_padded[:G]
105
-
106
-         # Count (+→-) sign changes = number of modes
107
-         # A mode is a local max of f, i.e., f' crosses zero from + to -
108
-         # Remove zeros (flat segments) — carry forward last nonzero sign
109
-         nonzero_mask = f_prime != 0
110
-         if not nonzero_mask.any():
111
-             return 0
112
-
113
-         s = f_prime[nonzero_mask]
114
-         transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
115
-         return transitions
116
-
117
-
118
- def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
119
-     """Choose FFT grid size G so that the derivative kernel is well-resolved.
120
-
121
-     Requires h > 8 * bin_width = 8 * data_range / G, equivalently
122
-     G > 8 * data_range / h_hi. We use a factor of 16 for safety margin,
123
-     then round up to the next power of 2 for efficient FFT.
124
-
125
-     Parameters
126
-     ----------
127
-     data_range : float
128
-         hi - lo of the data domain (typically X.max() - X.min() + 6σ).
129
-     h_hi : float
130
-         Upper bracket of the bisection (smallest h needing resolution).
131
-     G_min : int
132
-         Minimum returned G (default 4096).
133
-
134
-     Returns
135
-     -------
136
-     int
137
-         Grid size G, a power of 2, at least G_min.
138
-     """
139
-     needed = 16 * math.ceil(data_range / h_hi)
140
-     # Round up to next power of 2
141
-     p = 1
142
-     while p < needed:
143
-         p <<= 1
144
-     return max(G_min, p)
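Note on the removed `adaptive_fft_G` helper: the grid-size rule in its docstring can be illustrated with a small standalone sketch. The name `grid_size_sketch` and its parameter names are hypothetical and not part of the package; the arithmetic simply restates the removed function so the effect of the factor of 16 and the power-of-two rounding is easy to see.

```python
import math


def grid_size_sketch(data_range: float, h: float, g_min: int = 4096) -> int:
    """Hypothetical restatement of the grid-size rule from the removed module."""
    # Ask for at least 16 bins per bandwidth (twice the stated minimum of 8).
    needed = 16 * math.ceil(data_range / h)
    # Round up to the next power of two.
    p = 1
    while p < needed:
        p <<= 1
    # Never return fewer than g_min bins.
    return max(g_min, p)


# Wide bandwidth: only 160 bins are needed, so the floor g_min wins.
print(grid_size_sketch(10.0, 1.0))      # 4096
# Narrow bandwidth: 16 * 1600 = 25600 bins, rounded up to 32768.
print(grid_size_sketch(100.0, 0.0625))  # 32768
```

Under this rule the grid only grows beyond the default once `data_range / h` exceeds 256, since 16 * 256 = 4096.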
File without changes