diffcb 0.1.1__tar.gz → 0.1.4__tar.gz

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffcb
3
- Version: 0.1.1
3
+ Version: 0.1.4
4
4
  Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
5
5
  Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
6
6
  Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
@@ -71,10 +71,10 @@ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribu
71
71
  import torch
72
72
  from dcb import DCBLayer
73
73
 
74
- X = torch.randn(256, requires_grad=True) # 1D samples
74
+ X = torch.randn(1000, requires_grad=True) # 1D samples
75
75
  layer = DCBLayer(target_modes=1)
76
- h_crit = layer(X) # differentiable scalar
77
- h_crit.backward() # exact IFT gradients
76
+ h_crit = layer(X) # differentiable scalar
77
+ h_crit.backward() # exact IFT gradients
78
78
  ```
79
79
 
80
80
  ## Installation
@@ -91,34 +91,72 @@ cd differentiable-critical-bandwidth
91
91
  pip install -e ".[dev]"
92
92
  ```
93
93
 
94
- ## Paper
94
+ ## Accuracy vs R's `bw.crit`
95
95
 
96
- > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
96
+ DCB is validated against R's `multimode::bw.crit(data, mod0=1)`, the standard reference implementation of Hall & York (2001). On **identical data**:
97
+
98
+ | n | DCB vs R (same sample) | DCB vs R (independent samples) |
99
+ |---|---|---|
100
+ | 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
101
+ | 1M | **0.005%** | ~0.2% |
102
+ | 10M | **0.004%** | ~0.1% |
103
+
104
+ The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at every tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s, whereas R runs out of memory.
105
+
106
+ ## Key Parameters
107
+
108
+ ```python
109
+ DCBLayer(
110
+ target_modes=1, # target number of modes
111
+ G=512, # IFT evaluation grid points
112
+ use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
113
+ max_n_exact=1_000_000, # sketch to sketch_size when n exceeds this (None = always exact)
114
+ sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
115
+ safe_backward=False, # clamp IFT denominator near bifurcations
116
+ )
117
+ ```
97
118
 
98
119
  ## Confirmed Experimental Results
99
120
 
100
- All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
121
+ All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
101
122
 
102
123
  | Experiment | Result | Criterion |
103
124
  |---|---|---|
104
- | **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
105
- | **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
125
+ | **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
126
+ | **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
127
+ | **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
106
128
  | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
107
129
  | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
108
130
 
131
+ ## Changelog
132
+
133
+ ### v0.1.1 (2026-05-29)
134
+ - **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
135
+ - **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n⁻²/⁹) convergence of h_crit; 500K sketch matches full-n accuracy.
136
+ - **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
137
+ - **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
138
+ - **Test fixes:** Fixed 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
139
+
140
+ ### v0.1.0 (2026-05-28)
141
+ - Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
142
+
109
143
  ## Repository Structure
110
144
 
111
145
  ```
112
- dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
146
+ dcb/ Core PyTorch package
147
+ layer.py DCBLayer nn.Module + DCBFunction autograd
148
+ solver.py IFT root-finder and backward pass
149
+ fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
150
+ kde.py Direct KDE derivatives (small-n path)
151
+ utils.py Grid, Silverman bandwidth, sg() stabilizer
113
152
  experiments/ Reproduction scripts for all paper figures and tables
114
- phase1_validation.py Figure 1: DCB vs reference h_crit scatter
115
- phase1_speedup.py Figure 2: GPU speedup benchmark
116
- phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
117
- phase2_gan.py Figure 3: GAN mode-collapse prevention
118
- phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
119
- tests/ Unit tests (pytest, 35/35 passing)
153
+ phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
154
+ phase2_gan.py GAN mode-collapse prevention (Figure 3)
155
+ phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
156
+ round20_*.py Large-n R comparison and streaming benchmarks
157
+ round21_*.py Accuracy improvement experiments
158
+ tests/ Unit tests (pytest, 45 passed, 1 xfailed)
120
159
  outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
121
- notebooks/ Quickstart and demo notebooks
122
160
  ```
123
161
 
124
162
  ## Reproducing Paper Results
@@ -127,7 +165,6 @@ notebooks/ Quickstart and demo notebooks
127
165
  # Phase 1: validation, speedup, ablation
128
166
  python experiments/phase1_validation.py
129
167
  python experiments/phase1_speedup.py
130
- python experiments/phase1_ablation.py
131
168
 
132
169
  # Phase 2: GAN mode collapse experiment
133
170
  python experiments/phase2_gan.py
@@ -136,13 +173,13 @@ python experiments/phase2_gan.py
136
173
  python experiments/phase3_anomaly.py
137
174
  ```
138
175
 
139
- For GPU runs, use the provided Kaggle kernels:
176
+ For GPU runs, use the Kaggle kernels:
140
177
  - Phase 1–2: `hsingle/dcb-full-experiments`
141
178
  - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
142
179
 
143
- ## Kaggle GPU Notes
180
+ ## Paper
144
181
 
145
- Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
182
+ > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
146
183
 
147
184
  ## License
148
185
 
diffcb-0.1.4/README.md ADDED
@@ -0,0 +1,129 @@
1
+ # DCB — Differentiable Critical Bandwidth
2
+
3
+ [![PyPI](https://img.shields.io/pypi/v/diffcb.svg)](https://pypi.org/project/diffcb/)
4
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
5
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/)
6
+
7
+ A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
8
+
9
+ ## Overview
10
+
11
+ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
12
+
13
+ ```python
14
+ import torch
15
+ from dcb import DCBLayer
16
+
17
+ X = torch.randn(1000, requires_grad=True) # 1D samples
18
+ layer = DCBLayer(target_modes=1)
19
+ h_crit = layer(X) # differentiable scalar
20
+ h_crit.backward() # exact IFT gradients
21
+ ```
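To make the definition of `h_crit` above concrete, here is a minimal pure-Python sketch of a hard mode-count bisection. It is not the package's implementation: it evaluates the (unnormalised) Gaussian KDE derivative directly on a grid, counts positive-to-negative sign changes, and bisects on `h`. The names `kde_derivative`, `count_modes`, and `critical_bandwidth`, as well as the grid and the bracket, are assumptions made for illustration.

```python
import math

def kde_derivative(data, grid, h):
    """Unnormalised derivative of a Gaussian KDE with bandwidth h at each grid point."""
    return [
        sum(-(x - xi) * math.exp(-0.5 * ((x - xi) / h) ** 2) for xi in data)
        for x in grid
    ]

def count_modes(data, grid, h):
    """Count positive-to-negative sign changes of f', skipping exact zeros."""
    signs = [v > 0 for v in kde_derivative(data, grid, h) if v != 0.0]
    return sum(1 for a, b in zip(signs, signs[1:]) if a and not b)

def critical_bandwidth(data, grid, target_modes=1, h_lo=0.1, h_hi=10.0, steps=40):
    """Bisect for the smallest h whose KDE has at most target_modes modes."""
    lo, hi = h_lo, h_hi
    for _ in range(steps):
        mid = 0.5 * (lo + hi)
        if count_modes(data, grid, mid) <= target_modes:
            hi = mid
        else:
            lo = mid
    return hi

# Two point masses at -2 and +2: the two Gaussian bumps merge once h reaches about 2.
data = [-2.0, 2.0]
grid = [-10.0 + 0.01 * i for i in range(2001)]
h_crit = critical_bandwidth(data, grid)
print(round(h_crit, 2))
```

This direct evaluation costs O(n × G) per bandwidth, which is the cost the FFT path in `dcb.fft_kde` is designed to avoid.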
22
+
23
+ ## Installation
24
+
25
+ ```bash
26
+ pip install diffcb
27
+ ```
28
+
29
+ Or from source:
30
+
31
+ ```bash
32
+ git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
33
+ cd differentiable-critical-bandwidth
34
+ pip install -e ".[dev]"
35
+ ```
36
+
37
+ ## Accuracy vs R's `bw.crit`
38
+
39
+ DCB is validated against R's `multimode::bw.crit(data, mod0=1)` — the standard reference implementation of Hall & York (2001). On **identical data**:
40
+
41
+ | n | DCB vs R (same sample) | DCB vs R (independent samples) |
42
+ |---|---|---|
43
+ | 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
44
+ | 1M | **0.005%** | ~0.2% |
45
+ | 10M | **0.004%** | ~0.1% |
46
+
47
+ The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at every tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s, whereas R runs out of memory.
48
+
49
+ ## Key Parameters
50
+
51
+ ```python
52
+ DCBLayer(
53
+ target_modes=1, # target number of modes
54
+ G=512, # IFT evaluation grid points
55
+ use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
56
+ max_n_exact=1_000_000, # sketch to sketch_size when n exceeds this (None = always exact)
57
+ sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
58
+ safe_backward=False, # clamp IFT denominator near bifurcations
59
+ )
60
+ ```
61
+
62
+ ## Confirmed Experimental Results
63
+
64
+ All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
65
+
66
+ | Experiment | Result | Criterion |
67
+ |---|---|---|
68
+ | **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
69
+ | **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
70
+ | **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
71
+ | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
72
+ | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
73
+
74
+ ## Changelog
75
+
76
+ ### v0.1.1 (2026-05-29)
77
+ - **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
78
+ - **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n⁻²/⁹) convergence of h_crit; 500K sketch matches full-n accuracy.
79
+ - **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
80
+ - **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
81
+ - **Test fixes:** Fixed 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
82
+
83
+ ### v0.1.0 (2026-05-28)
84
+ - Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
85
+
86
+ ## Repository Structure
87
+
88
+ ```
89
+ dcb/ Core PyTorch package
90
+ layer.py DCBLayer nn.Module + DCBFunction autograd
91
+ solver.py IFT root-finder and backward pass
92
+ fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
93
+ kde.py Direct KDE derivatives (small-n path)
94
+ utils.py Grid, Silverman bandwidth, sg() stabilizer
95
+ experiments/ Reproduction scripts for all paper figures and tables
96
+ phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
97
+ phase2_gan.py GAN mode-collapse prevention (Figure 3)
98
+ phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
99
+ round20_*.py Large-n R comparison and streaming benchmarks
100
+ round21_*.py Accuracy improvement experiments
101
+ tests/ Unit tests (pytest, 45 passed, 1 xfailed)
102
+ outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
103
+ ```
104
+
105
+ ## Reproducing Paper Results
106
+
107
+ ```bash
108
+ # Phase 1: validation, speedup, ablation
109
+ python experiments/phase1_validation.py
110
+ python experiments/phase1_speedup.py
111
+
112
+ # Phase 2: GAN mode collapse experiment
113
+ python experiments/phase2_gan.py
114
+
115
+ # Phase 3: anomaly detection benchmark
116
+ python experiments/phase3_anomaly.py
117
+ ```
118
+
119
+ For GPU runs, use the Kaggle kernels:
120
+ - Phase 1–2: `hsingle/dcb-full-experiments`
121
+ - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
122
+
123
+ ## Paper
124
+
125
+ > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
126
+
127
+ ## License
128
+
129
+ MIT — see [LICENSE](LICENSE).
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "DCBLayer", "DifferentiableCriticalBandwidth",
20
20
  "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
21
21
  ]
22
- __version__ = "0.1.1"
22
+ __version__ = "0.1.4"
@@ -0,0 +1,339 @@
1
+ """
2
+ dcb.fft_kde — FFT-based KDE Mode Counter
3
+
4
+ Implements mode counting via FFT convolution of the histogram with a
5
+ Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
6
+ O(n × G) cost of the direct KDE approach and — crucially — requiring NO
7
+ subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
8
+ that affects the standard bisection path when n > brentq_n_max.
9
+
10
+ Round 18b: forward kernel only. The IFT backward is unchanged (still uses
11
+ the analytical chunked KDE derivatives on all n points).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import math
17
+
18
+ import torch
19
+ from torch import Tensor
20
+
21
+
22
+ # Worker 2: device-native histogram
23
+ def _histogram_on_device(X: Tensor, G: int, lo: float, hi: float) -> Tensor:
24
+ """Compute a G-bin histogram of X on the same device as X."""
25
+ device = X.device
26
+ if device.type == 'cuda':
27
+ return torch.histc(X.float(), bins=G, min=lo, max=hi)
28
+ elif device.type == 'mps':
29
+ bin_idx = ((X.float() - lo) * (G / (hi - lo))).long().clamp_(0, G - 1)
30
+ counts = torch.zeros(G, dtype=torch.float32, device=device)
31
+ counts.scatter_add_(0, bin_idx, torch.ones(X.shape[0], dtype=torch.float32, device=device))
32
+ return counts
33
+ else: # cpu
34
+ X_cpu = X.float()
35
+ edges = torch.linspace(lo, hi, G + 1)
36
+ bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
37
+ return torch.bincount(bin_idx, minlength=G).float()
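The binning rule used in the MPS branch above (scale, truncate, clamp) can be illustrated without PyTorch. The following is a pure-Python sketch; `histogram_counts` is a hypothetical helper introduced here, not a function in `dcb`.

```python
def histogram_counts(xs, G, lo, hi):
    """G-bin histogram over [lo, hi]; out-of-range values are clamped to the edge bins."""
    scale = G / (hi - lo)
    counts = [0] * G
    for x in xs:
        idx = int((x - lo) * scale)      # truncates toward zero, like .long()
        idx = max(0, min(G - 1, idx))    # clamp, like .clamp_(0, G - 1)
        counts[idx] += 1
    return counts

counts = histogram_counts([0.0, 0.1, 0.26, 0.5, 0.99, 1.0], G=4, lo=0.0, hi=1.0)
print(counts)
```

One difference worth keeping in mind: `torch.histc` ignores elements outside `[min, max]` rather than clamping them. That should not matter when the domain is built with a margin around `X.min()` and `X.max()`, as in `precompute_fft`.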
38
+
39
+
40
+ def precompute_fft(
41
+ X: Tensor,
42
+ G: int = 4096,
43
+ domain: tuple[float, float] | None = None,
44
+ pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
45
+ fft_dtype: torch.dtype = torch.float32, # Worker 3: float32 FFT
46
+ ) -> tuple[Tensor, Tensor, tuple[float, float]]:
47
+ """Precompute the FFT of the zero-padded histogram of X.
48
+
49
+ This is the bandwidth-independent work shared across a bisection loop on
50
+ h: build the histogram, zero-pad, take rfft, and build the frequency grid
51
+ omega. The per-step kernel K(omega, h) = i*omega*exp(-0.5*(omega*h)**2)
52
+ must be combined with C inside `mode_count_from_C`.
53
+
54
+ Parameters
55
+ ----------
56
+ X : Tensor, shape (n,)
57
+ G : int
58
+ Number of histogram bins.
59
+ domain : (lo, hi) or None
60
+ If provided, use as histogram domain; otherwise computed from X
61
+ with a 3*sigma margin.
62
+ pad_factor : int
63
+ Zero-padding multiplier (default 2).
64
+
65
+ Returns
66
+ -------
67
+ C : Tensor, shape (N//2+1,), complex64 or complex128 (set by fft_dtype)
68
+ rfft of the zero-padded histogram, computed in fft_dtype. Empty tensor (degenerate
69
+ zero-range domain) signals the caller to short-circuit to 1 mode.
70
+ omega : Tensor, shape (N//2+1,), same real dtype as fft_dtype
71
+ Angular frequency grid for the FFT.
72
+ domain : (lo, hi)
73
+ Domain tuple actually used.
74
+ """
75
+ with torch.no_grad():
76
+ if domain is not None:
77
+ lo, hi = domain
78
+ else:
79
+ sigma = X.std().item()
80
+ if sigma == 0.0:
81
+ sigma = 1.0
82
+ lo = X.min().item() - 3 * sigma
83
+ hi = X.max().item() + 3 * sigma
84
+ data_range = hi - lo
85
+
86
+ if data_range == 0.0:
87
+ complex_dtype = torch.complex64 if fft_dtype == torch.float32 else torch.complex128
88
+ empty = torch.zeros(0, dtype=complex_dtype, device=X.device)
89
+ empty_omega = torch.zeros(0, dtype=fft_dtype, device=X.device)
90
+ return empty, empty_omega, (lo, hi)
91
+
92
+ # Histogram (O(n)) — device-native dispatch.
93
+ counts = _histogram_on_device(X, G, lo, hi)
94
+
95
+ N = pad_factor * G
96
+ counts_padded = torch.zeros(N, dtype=fft_dtype, device=X.device)
97
+ counts_padded[:G] = counts.to(fft_dtype)
98
+
99
+ C = torch.fft.rfft(counts_padded)
100
+
101
+ bin_width = data_range / G
102
+ k = torch.arange(N // 2 + 1, device=X.device, dtype=fft_dtype)
103
+ omega = 2 * math.pi * k / (N * bin_width)
104
+
105
+ return C, omega, (lo, hi)
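For intuition about the `omega` grid built at the end of `precompute_fft`, the same formula can be evaluated in plain Python. This is only a sketch; `angular_frequencies` and the small input values are assumptions chosen for illustration.

```python
import math

def angular_frequencies(G, pad_factor, data_range):
    """Angular frequencies 2*pi*k / (N * bin_width) for k = 0 .. N//2."""
    N = pad_factor * G
    bin_width = data_range / G
    return [2 * math.pi * k / (N * bin_width) for k in range(N // 2 + 1)]

omega = angular_frequencies(G=8, pad_factor=2, data_range=4.0)
# For even N, the last entry is the Nyquist frequency pi / bin_width.
print(len(omega), omega[-1])
```

Zero-padding does not change the highest frequency (that is fixed by the bin width); it only makes the frequency grid finer.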
106
+
107
+
108
+ def mode_count_from_C(
109
+ C: Tensor,
110
+ omega: Tensor,
111
+ h: float,
112
+ G: int,
113
+ N: int,
114
+ ) -> int:
115
+ """Per-step mode count: apply Gaussian derivative kernel and count sign changes.
116
+
117
+ Cheap inner loop body for bisection — only the kernel depends on h.
118
+
119
+ Parameters
120
+ ----------
121
+ C : Tensor, shape (N//2+1,), complex
122
+ rfft of the zero-padded histogram (from `precompute_fft`).
123
+ omega : Tensor, shape (N//2+1,), float32 or float64 (matches C)
124
+ Frequency grid (from `precompute_fft`).
125
+ h : float
126
+ Bandwidth.
127
+ G : int
128
+ Histogram bin count.
129
+ N : int
130
+ Padded FFT length (pad_factor * G).
131
+
132
+ Returns
133
+ -------
134
+ int
135
+ Number of KDE modes.
136
+ """
137
+ if C.numel() == 0:
138
+ return 1 # degenerate single-point distribution
139
+
140
+ K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
141
+ f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).real
142
+ f_prime = f_prime_padded[:G]
143
+
144
+ nonzero_mask = f_prime != 0
145
+ if not nonzero_mask.any():
146
+ return 0
147
+
148
+ s = f_prime[nonzero_mask]
149
+ transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
150
+ return transitions
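The counting step at the end of `mode_count_from_C` can be sketched on a plain list. The helper name `count_downward_crossings` is hypothetical; the logic mirrors the masking and comparison shown above (drop exact zeros, then count positive-to-negative neighbours).

```python
def count_downward_crossings(values):
    """Count positive-to-negative transitions after dropping exact zeros."""
    nonzero = [v for v in values if v != 0]
    return sum(1 for a, b in zip(nonzero, nonzero[1:]) if a > 0 and b < 0)

print(count_downward_crossings([1, 2, 0, -1, -2, 3, -1]))
print(count_downward_crossings([0, 0, 0]))
```

Dropping zeros first means a crossing that passes through an exact zero (for example `2, 0, -1`) is still counted once.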
151
+
152
+
153
+ def fft_mode_count(
154
+ X: Tensor,
155
+ h: float,
156
+ G: int = 4096,
157
+ pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
158
+ domain: tuple[float, float] | None = None,
159
+ ) -> int:
160
+ """Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
161
+
162
+ Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
163
+ the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
164
+ back-transforms, and counts positive-to-negative sign changes of the
165
+ resulting f' estimate.
166
+
167
+ Parameters
168
+ ----------
169
+ X : Tensor, shape (n,)
170
+ 1D data tensor (CPU, CUDA, or MPS).
171
+ h : float
172
+ Bandwidth for the Gaussian kernel.
173
+ G : int
174
+ Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
175
+ reliable derivative estimation. Use `adaptive_fft_G` to choose G
176
+ automatically before bisection.
177
+ pad_factor : int
178
+ Zero-padding multiplier (default 2). Must be ≥ 2 for circular-wrap
179
+ correctness; 2 is safe for h ≤ 3σ, and 4 is recommended for larger h.
180
+ domain : (lo, hi) or None
181
+ If provided, use this as the histogram domain instead of computing
182
+ X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
183
+ with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
184
+ so every fft_mode_count call in a bisection loop uses an identical grid.
185
+
186
+ Returns
187
+ -------
188
+ int
189
+ Number of KDE modes (downward zero-crossings of f').
190
+ """
191
+ with torch.no_grad():
192
+ C, omega, _ = precompute_fft(X, G=G, domain=domain, pad_factor=pad_factor)
193
+ N = pad_factor * G
194
+ return mode_count_from_C(C, omega, h, G, N)
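The frequency-domain kernel iω·exp(−½(ωh)²) named in the docstring above follows from two standard Fourier identities. The derivation below omits normalisation constants, since only the sign of f′ matters for mode counting. The symbols $c$ (histogram counts), $\hat c$ (their transform) and $\phi_h$ (Gaussian kernel) are notation introduced here, with the convention $\mathcal{F}[g](\omega) = \int g(x)\,e^{-i\omega x}\,dx$.

```latex
\begin{aligned}
\hat f_h(x) &\propto (c * \phi_h)(x),
  \qquad \phi_h(x) = \frac{1}{h\sqrt{2\pi}}\, e^{-x^2/(2h^2)},\\
\mathcal{F}[\phi_h](\omega) &= e^{-\frac{1}{2}(\omega h)^2},
  \qquad \mathcal{F}[g'](\omega) = i\omega\,\mathcal{F}[g](\omega),\\
\mathcal{F}[\hat f_h'](\omega) &\propto i\omega\, e^{-\frac{1}{2}(\omega h)^2}\, \hat c(\omega).
\end{aligned}
```

Only the factor $e^{-\frac{1}{2}(\omega h)^2}$ depends on $h$, which is why $\hat c$ can be computed once and reused across a bisection.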
195
+
196
+
197
+ def _refine_hcrit(
198
+ X: Tensor,
199
+ h_lo: float,
200
+ h_hi: float,
201
+ G: int,
202
+ domain: tuple[float, float],
203
+ target_modes: int = 1,
204
+ pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
205
+ ) -> float:
206
+ """Sub-bin quadratic refinement of h_crit after bisection converges.
207
+
208
+ Identifies the f′ lobe that disappears at the mode-merging bandwidth and
209
+ fits a quadratic in h to that lobe's peak value, returning the root — the
210
+ h where that peak exactly reaches zero. Reduces the bin-width-limited
211
+ systematic from ~bin_width/h_crit to well below 1e-4.
212
+
213
+ When the incoming bracket [h_lo, h_hi] is tighter than one histogram bin
214
+ width (the common case after 50-step bisection), the function expands the
215
+ bracket outward from h_hi by up to 4× the bin width while maintaining the
216
+ invariant that fft_mode_count > target at the left endpoint and
217
+ <= target at the right endpoint, so the disappearing f′ lobe is visible
218
+ across the bracket.
219
+
220
+ Parameters
221
+ ----------
222
+ X : Tensor — data (may be on any device)
223
+ h_lo, h_hi : float — final bisection bracket; fft_mode_count(X,h_lo) > target,
224
+ fft_mode_count(X,h_hi) <= target
225
+ G, domain, target_modes, pad_factor — same as fft_mode_count
226
+
227
+ Returns
228
+ -------
229
+ float — refined h_crit, guaranteed to lie in [h_lo, h_hi] of the
230
+ (possibly expanded) bracket used for fitting.
231
+ """
232
+ import numpy as np
233
+
234
+ lo_d, hi_d = domain
235
+ data_range = hi_d - lo_d
236
+ if data_range == 0.0:
237
+ return h_hi
238
+
239
+ bin_width = data_range / G
240
+ N = pad_factor * G
241
+ bw = bin_width # histogram bin width
242
+
243
+ # Pre-compute histogram once; reuse C (FFT of counts) for all h evaluations.
244
+ with torch.no_grad():
245
+ counts = _histogram_on_device(X, G, lo_d, hi_d).cpu()
246
+ counts_padded = torch.zeros(N, dtype=torch.float64)
247
+ counts_padded[:G] = counts.double()
248
+ C = torch.fft.rfft(counts_padded)
249
+ k = torch.arange(N // 2 + 1, dtype=torch.float64)
250
+ omega_base = 2 * math.pi * k / (N * bw)
251
+
252
+ def fprime(h: float) -> Tensor:
253
+ """Compute f′ array (shape G,) for bandwidth h using cached C (float64)."""
254
+ K_deriv = 1j * omega_base * torch.exp(-0.5 * (omega_base * h) ** 2)
255
+ return torch.fft.irfft(C * K_deriv, n=N).float()[:G]
256
+
257
+ with torch.no_grad():
258
+ # If the bracket is tighter than bin_width, expand it so that the
259
+ # disappearing f′ lobe crosses zero somewhere inside the bracket.
260
+ # Expand the left endpoint leftward by up to 4 bin widths.
261
+ ref_lo = h_lo
262
+ ref_hi = h_hi
263
+
264
+ if (ref_hi - ref_lo) < bw:
265
+ # Try expanding leftward until we find a bin where fp crosses zero
266
+ for mult in [1, 2, 3, 4]:
267
+ cand_lo = max(ref_hi - mult * bw, ref_hi * 0.9)
268
+ fp_cand = fprime(cand_lo)
269
+ fp_hi_ = fprime(ref_hi)
270
+ cm = (fp_cand > 0) & (fp_hi_ <= 0)
271
+ if cm.any():
272
+ ref_lo = cand_lo
273
+ break
274
+ # If still no candidates found, return bisection result unchanged
275
+ fp_lo_ = fprime(ref_lo)
276
+ fp_hi_ = fprime(ref_hi)
277
+ candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
278
+ if not candidate_mask.any():
279
+ return h_hi
280
+ else:
281
+ fp_lo_ = fprime(ref_lo)
282
+ fp_hi_ = fprime(ref_hi)
283
+ candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
284
+ if not candidate_mask.any():
285
+ return h_hi
286
+
287
+ # Pick the bin with the largest positive value at ref_lo that crossed zero
288
+ masked_fp_lo = fp_lo_.clone()
289
+ masked_fp_lo[~candidate_mask] = -float('inf')
290
+ j = int(masked_fp_lo.argmax().item())
291
+
292
+ h_mid = (ref_lo + ref_hi) / 2.0
293
+
294
+ # Evaluate fp[j] at three bandwidths for quadratic fit
295
+ y_lo = fp_lo_[j].item()
296
+ y_mid = fprime(h_mid)[j].item()
297
+ y_hi = fp_hi_[j].item()
298
+
299
+ # Fit quadratic y = a*h² + b*h + c through the three (h, y) pairs
300
+ # and solve for the root in [ref_lo, ref_hi].
301
+ coeffs = np.polyfit([ref_lo, h_mid, ref_hi], [y_lo, y_mid, y_hi], 2)
302
+ roots = np.roots(coeffs)
303
+ real_roots = [
304
+ r.real for r in roots
305
+ if abs(r.imag) < 1e-10 * abs(r.real + 1e-30)
306
+ and ref_lo <= r.real <= ref_hi
307
+ ]
308
+ if real_roots:
309
+ return float(min(real_roots, key=lambda r: abs(r - h_mid)))
310
+ return h_hi
311
+
312
+
313
+ def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 16384) -> int:
314
+ """Choose FFT grid size G so that the derivative kernel is well-resolved.
315
+
316
+ Requires h > 8 * bin_width = 8 * data_range / G, equivalently
317
+ G > 8 * data_range / h_hi. We use a factor of 16 for safety margin,
318
+ then round up to the next power of 2 for efficient FFT.
319
+
320
+ Parameters
321
+ ----------
322
+ data_range : float
323
+ hi - lo of the data domain (typically X.max() - X.min() + 6σ).
324
+ h_hi : float
325
+ Upper bracket of the bisection (smallest h needing resolution).
326
+ G_min : int
327
+ Minimum returned G (default 16384).
328
+
329
+ Returns
330
+ -------
331
+ int
332
+ Grid size G, a power of 2, at least G_min.
333
+ """
334
+ needed = 16 * math.ceil(data_range / h_hi)
335
+ # Round up to next power of 2
336
+ p = 1
337
+ while p < needed:
338
+ p <<= 1
339
+ return max(G_min, p)
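A quick worked example of the sizing rule, with the function body copied from above into plain Python so it runs without the package. The input values are illustrative only.

```python
import math

def adaptive_fft_G(data_range, h_hi, G_min=16384):
    needed = 16 * math.ceil(data_range / h_hi)
    p = 1
    while p < needed:  # round up to the next power of 2
        p <<= 1
    return max(G_min, p)

print(adaptive_fft_G(12.0, 1.0))            # needed = 192 -> 256, raised to G_min = 16384
print(adaptive_fft_G(1000.0, 0.5))          # needed = 32000 -> 32768
print(adaptive_fft_G(12.0, 1.0, G_min=64))  # 256
```

With the default `G_min`, the power-of-two rule only takes effect once `data_range / h_hi` exceeds 1024; below that, `G_min` sets the grid size.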
@@ -35,13 +35,14 @@ class DCBFunction(torch.autograd.Function):
35
35
 
36
36
  @staticmethod
37
37
  def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
38
- brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft):
38
+ brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min,
39
+ fft_dtype):
39
40
  """Locate h_crit and save state for the backward pass."""
40
41
  h_crit, cond_num = find_h_crit(
41
42
  X, grid, eps, tau, target_modes,
42
43
  formula=formula, brentq_n_max=brentq_n_max, chunk_size=chunk_size,
43
44
  g_brentq=g_brentq, use_hard_bisection=use_hard_bisection,
44
- use_fft=use_fft,
45
+ use_fft=use_fft, G_min=fft_G_min, fft_dtype=fft_dtype,
45
46
  )
46
47
  ctx.save_for_backward(X, grid)
47
48
  ctx.h_crit = h_crit
@@ -67,8 +68,8 @@ class DCBFunction(torch.autograd.Function):
67
68
  ctx.denom_abs = ift_gradient.last_denom_abs
68
69
  # Gradients for: X, grid, eps, tau, target_modes, delta, formula,
69
70
  # chunk_size, brentq_n_max, g_brentq, use_hard_bisection,
70
- # safe_backward, use_fft
71
- return grad_X, None, None, None, None, None, None, None, None, None, None, None, None
71
+ # safe_backward, use_fft, fft_G_min, fft_dtype
72
+ return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None, None
72
73
 
73
74
 
74
75
  class DCBLayer(nn.Module):
@@ -133,6 +134,13 @@ class DCBLayer(nn.Module):
133
134
  Number of points to sketch when n > max_n_exact. Default 500_000.
134
135
  A 500K sketch achieves the same mean accuracy as streaming 100M points
135
136
  (validated in Round 20 reservoir experiment).
137
+ fft_G_min : int
138
+ Minimum FFT histogram grid size for the bisection solver (default 16384).
139
+ Controls accuracy of the FFT path (n > 50K). Larger values reduce
140
+ discretisation error at a modest cost: G=16384 gives ~0.004% error vs R;
141
+ G=32768 gives ~0.001% at +9% cost; G=65536 reaches the R-matching floor
142
+ (~0.001%) with no further gain beyond that. Ignored for n ≤ 50K (direct
143
+ KDE path).
136
144
 
137
145
  Examples
138
146
  --------
@@ -162,6 +170,8 @@ class DCBLayer(nn.Module):
162
170
  use_fft: bool = True,
163
171
  max_n_exact: int | None = 1_000_000,
164
172
  sketch_size: int = 500_000,
173
+ fft_G_min: int = 16384,
174
+ fft_dtype: torch.dtype = torch.float32,
165
175
  ):
166
176
  super().__init__()
167
177
  self.target_modes = target_modes
@@ -180,6 +190,8 @@ class DCBLayer(nn.Module):
180
190
  self.use_fft = use_fft
181
191
  self.max_n_exact = max_n_exact
182
192
  self.sketch_size = sketch_size
193
+ self.fft_G_min = fft_G_min
194
+ self.fft_dtype = fft_dtype
183
195
  if use_fft and brentq_n_max != 50_000:
184
196
  raise TypeError(
185
197
  f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
@@ -250,7 +262,7 @@ class DCBLayer(nn.Module):
250
262
  return DCBFunction.apply(
251
263
  X, grid, eps_eff, tau_eff, self.target_modes, self.delta, self.formula,
252
264
  self.chunk_size, self.brentq_n_max, self.g_brentq, self.use_hard_bisection,
253
- self.safe_backward, self.use_fft,
265
+ self.safe_backward, self.use_fft, self.fft_G_min, self.fft_dtype,
254
266
  )
255
267
 
256
268
 
@@ -37,7 +37,7 @@ from dcb.kde import (
37
37
  soft_mode_count_cross_from_derivs,
38
38
  kde_derivatives_chunked,
39
39
  )
40
- from dcb.fft_kde import fft_mode_count, adaptive_fft_G
40
+ from dcb.fft_kde import fft_mode_count, adaptive_fft_G, precompute_fft, mode_count_from_C
41
41
 
42
42
  _AUTO_FFT_THRESHOLD = 50_000 # n above which FFT bisection activates (use_fft_effective)
43
43
 
@@ -74,6 +74,8 @@ def find_h_crit_hard(
74
74
  eps: float = 0.1,
75
75
  tau: float = 0.2,
76
76
  use_fft: bool = False,
77
+ G_min: int = 16384,
78
+ fft_dtype: torch.dtype = torch.float32,
77
79
  ) -> tuple[float, float]:
78
80
  """Find h_crit via hard-mode-count bisection (monotone, no false roots).
79
81
 
@@ -151,43 +153,64 @@ def find_h_crit_hard(
151
153
  lo_domain = X.min().item() - 3 * sigma
152
154
  hi_domain = X.max().item() + 3 * sigma
153
155
  data_range = hi_domain - lo_domain
154
- G_fft = adaptive_fft_G(data_range, h_hi)
156
+ G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
155
157
  _domain = (lo_domain, hi_domain)
158
+ pad_factor = 2 # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
159
+ N = pad_factor * G_fft
156
160
 
157
161
  with torch.no_grad():
162
+ # Worker 1: precomputed C — hoist histogram + rfft out of bisection.
163
+ # Worker 3: float32 FFT by default — 2× faster; _refine_hcrit uses float64 independently.
164
+ C, omega, _domain = precompute_fft(
165
+ X, G=G_fft, domain=_domain, pad_factor=pad_factor, fft_dtype=fft_dtype,
166
+ )
167
+
158
168
  # Verify bracket using FFT mode count on full X
159
- count_lo = fft_mode_count(X, h_lo, G=G_fft, domain=_domain)
169
+ count_lo = mode_count_from_C(C, omega, h_lo, G_fft, N)
160
170
  if count_lo <= target_modes:
161
171
  h_lo_try = h_lo
162
172
  for _ in range(30):
163
173
  h_lo_try *= 0.5
164
174
  if h_lo_try < 1e-10:
165
175
  break
166
- if fft_mode_count(X, h_lo_try, G=G_fft, domain=_domain) > target_modes:
176
+ if mode_count_from_C(C, omega, h_lo_try, G_fft, N) > target_modes:
167
177
  h_lo = h_lo_try
168
178
  break
169
179
 
170
- count_hi = fft_mode_count(X, h_hi, G=G_fft, domain=_domain)
180
+ count_hi = mode_count_from_C(C, omega, h_hi, G_fft, N)
171
181
  if count_hi > target_modes:
172
182
  for _ in range(30):
173
183
  h_hi *= 2.0
174
- if fft_mode_count(X, h_hi, G=G_fft, domain=_domain) <= target_modes:
184
+ if mode_count_from_C(C, omega, h_hi, G_fft, N) <= target_modes:
175
185
  break
176
186
 
177
- # Standard bisection: 50 iterations bracket width / 2^50
187
+ # Adaptive bisection: stop when bracket is localised (relative width < 1e-7)
188
+ # _refine_hcrit provides sub-bin precision afterwards — no need to over-bisect.
178
189
  lo, hi = h_lo, h_hi
179
190
  for _ in range(50):
180
191
  mid = (lo + hi) / 2.0
181
- count = fft_mode_count(X, mid, G=G_fft, domain=_domain)
192
+ count = mode_count_from_C(C, omega, mid, G_fft, N)
182
193
  if count <= target_modes:
183
194
  hi = mid
184
195
  else:
185
196
  lo = mid
186
197
  if (hi - lo) < tol:
187
198
  break
199
+ # Worker 4: adaptive termination — stop when relative bracket width
200
+ # is small enough that further bisection cannot meaningfully shift
201
+ # _refine_hcrit's quadratic fit. Empirically 1e-7 preserves h_crit
202
+ # to within 1e-6 of the 50-step tol=1e-6 baseline while saving ~10
203
+ # bisection steps in typical cases.
204
+ if hi > 0 and (hi - lo) / hi < 1e-7:
205
+ break
188
206
 
189
207
  h_crit = float(hi) # smallest h with count <= target_modes
190
208
 
209
+ # Sub-bin refinement: quadratic interpolation on the disappearing f′ lobe
210
+ # to locate h_crit below the bin-width precision limit.
211
+ from dcb.fft_kde import _refine_hcrit
212
+ h_crit = _refine_hcrit(X, lo, hi, G_fft, _domain, target_modes)
213
+
191
214
  else:
192
215
  with torch.no_grad():
193
216
  # Verify bracket: need count > target at h_lo, count <= target at h_hi.
@@ -216,6 +239,9 @@ def find_h_crit_hard(
216
239
  break
217
240
 
218
241
  # Standard bisection: 50 iterations → bracket width / 2^50
242
+ # NOTE: non-FFT path has no _refine_hcrit sub-bin refinement, so we keep
243
+ # tight bisection here for gradient stability (IFT test requires h_crit
244
+ # accurate well below FD perturbation delta=1e-3).
219
245
  lo, hi = h_lo, h_hi
220
246
  for _ in range(50):
221
247
  mid = (lo + hi) / 2.0
@@ -290,6 +316,8 @@ def find_h_crit(
290
316
  g_brentq: int = 128,
291
317
  use_hard_bisection: bool = True,
292
318
  use_fft: bool = True,
319
+ G_min: int = 16384,
320
+ fft_dtype: torch.dtype = torch.float32,
293
321
  ) -> tuple[float, float]:
294
322
  """Find h_crit and return (h_crit, condition_number).
295
323
 
@@ -343,7 +371,7 @@ def find_h_crit(
343
371
  return find_h_crit_hard(
344
372
  X, grid, target_modes, chunk_size, brentq_n_max,
345
373
  h_lo, h_hi, formula=formula, eps=eps, tau=tau,
346
- use_fft=use_fft,
374
+ use_fft=use_fft, G_min=G_min, fft_dtype=fft_dtype,
347
375
  )
348
376
 
349
377
  from scipy.optimize import brentq
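Editor's note on the hunks above: the FFT path now computes the histogram spectrum once (`precompute_fft`) and reuses it for every bandwidth probed, and the bisection loop gains a relative-width exit next to the existing absolute `tol` check. The added comment mentions a relative width of 1e-3, while the code tests `1e-7`; the sketch below follows the code. It is a minimal pure-Python illustration, not the package's implementation: `make_counter`, `bisect_bracket`, the toy threshold 0.37, and the `tol=1e-6` default are all assumptions made for the example.

```python
def make_counter(threshold):
    """Toy stand-in for a mode counter that reuses precomputed state.

    In the diff, the expensive part (histogram + rfft) is hoisted out of the
    loop; here the "precomputed state" is just the hypothetical threshold.
    """
    calls = {"n": 0}  # number of cheap per-bandwidth evaluations

    def count_at(h):
        calls["n"] += 1
        # Hypothetical monotone model: 2 modes below threshold, 1 at or above.
        return 1 if h >= threshold else 2

    return count_at, calls


def bisect_bracket(count_at, lo, hi, target_modes=1,
                   tol=1e-6, rel_tol=1e-7, max_iter=50):
    """Shrink [lo, hi] around the smallest h with count_at(h) <= target_modes."""
    for _ in range(max_iter):
        mid = (lo + hi) / 2.0
        if count_at(mid) <= target_modes:
            hi = mid  # mid already has few enough modes: move upper end down
        else:
            lo = mid  # still too many modes: move lower end up
        if (hi - lo) < tol:
            break  # absolute-width exit (present before this change)
        if hi > 0 and (hi - lo) / hi < rel_tol:
            break  # relative-width exit (added in this change)
    return lo, hi


count_at, calls = make_counter(threshold=0.37)
lo, hi = bisect_bracket(count_at, lo=0.01, hi=2.0)
print(lo, hi, calls["n"])
```

The loop keeps the invariant that `lo` has too many modes and `hi` does not, so `hi` is always a valid upper estimate of the critical bandwidth; the refinement step in the diff then works inside that bracket.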
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "diffcb"
7
- version = "0.1.1"
7
+ version = "0.1.4"
8
8
  description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
diffcb-0.1.1/README.md DELETED
@@ -1,92 +0,0 @@
1
- # DCB — Differentiable Critical Bandwidth
2
-
3
- [![PyPI](https://img.shields.io/pypi/v/diffcb.svg)](https://pypi.org/project/diffcb/)
4
- [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
5
- [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/)
6
-
7
- A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
8
-
9
- ## Overview
10
-
11
- The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
12
-
13
- ```python
14
- import torch
15
- from dcb import DCBLayer
16
-
17
- X = torch.randn(256, requires_grad=True) # 1D samples
18
- layer = DCBLayer(target_modes=1)
19
- h_crit = layer(X) # differentiable scalar
20
- h_crit.backward() # exact IFT gradients
21
- ```
22
-
23
- ## Installation
24
-
25
- ```bash
26
- pip install diffcb
27
- ```
28
-
29
- Or from source:
30
-
31
- ```bash
32
- git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
33
- cd differentiable-critical-bandwidth
34
- pip install -e ".[dev]"
35
- ```
36
-
37
- ## Paper
38
-
39
- > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
40
-
41
- ## Confirmed Experimental Results
42
-
43
- All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
44
-
45
- | Experiment | Result | Criterion |
46
- |---|---|---|
47
- | **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
48
- | **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
49
- | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
50
- | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
51
-
52
- ## Repository Structure
53
-
54
- ```
55
- dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
56
- experiments/ Reproduction scripts for all paper figures and tables
57
- phase1_validation.py Figure 1: DCB vs reference h_crit scatter
58
- phase1_speedup.py Figure 2: GPU speedup benchmark
59
- phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
60
- phase2_gan.py Figure 3: GAN mode-collapse prevention
61
- phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
62
- tests/ Unit tests (pytest, 35/35 passing)
63
- outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
64
- notebooks/ Quickstart and demo notebooks
65
- ```
66
-
67
- ## Reproducing Paper Results
68
-
69
- ```bash
70
- # Phase 1: validation, speedup, ablation
71
- python experiments/phase1_validation.py
72
- python experiments/phase1_speedup.py
73
- python experiments/phase1_ablation.py
74
-
75
- # Phase 2: GAN mode collapse experiment
76
- python experiments/phase2_gan.py
77
-
78
- # Phase 3: anomaly detection benchmark
79
- python experiments/phase3_anomaly.py
80
- ```
81
-
82
- For GPU runs, use the provided Kaggle kernels:
83
- - Phase 1–2: `hsingle/dcb-full-experiments`
84
- - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
85
-
86
- ## Kaggle GPU Notes
87
-
88
- Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
89
-
90
- ## License
91
-
92
- MIT — see [LICENSE](LICENSE).
@@ -1,144 +0,0 @@
1
- """
2
- dcb.fft_kde — FFT-based KDE Mode Counter
3
-
4
- Implements mode counting via FFT convolution of the histogram with a
5
- Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
6
- O(n × G) cost of the direct KDE approach and — crucially — requiring NO
7
- subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
8
- that affects the standard bisection path when n > brentq_n_max.
9
-
10
- Round 18b: forward kernel only. The IFT backward is unchanged (still uses
11
- the analytical chunked KDE derivatives on all n points).
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import math
17
-
18
- import torch
19
- from torch import Tensor
20
-
21
-
22
- def fft_mode_count(
23
- X: Tensor,
24
- h: float,
25
- G: int = 4096,
26
- pad_factor: int = 4,
27
- domain: tuple[float, float] | None = None,
28
- ) -> int:
29
- """Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
30
-
31
- Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
32
- the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
33
- back-transforms, and counts positive-to-negative sign changes of the
34
- resulting f' estimate.
35
-
36
- Parameters
37
- ----------
38
- X : Tensor, shape (n,)
39
- 1D data tensor (may be on CPU or CUDA).
40
- h : float
41
- Bandwidth for the Gaussian kernel.
42
- G : int
43
- Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
44
- reliable derivative estimation. Use `adaptive_fft_G` to choose G
45
- automatically before bisection.
46
- pad_factor : int
47
- Zero-padding multiplier (default 4). Mandatory ≥ 2 for circular-wrap
48
- correctness; 4 is recommended at the largest h encountered.
49
- domain : (lo, hi) or None
50
- If provided, use this as the histogram domain instead of computing
51
- X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
52
- with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
53
- so every fft_mode_count call in a bisection loop uses an identical grid.
54
-
55
- Returns
56
- -------
57
- int
58
- Number of KDE modes (downward zero-crossings of f').
59
- """
60
- with torch.no_grad():
61
- if domain is not None:
62
- lo, hi = domain
63
- else:
64
- # Domain: extend 3σ beyond data range to avoid boundary effects
65
- sigma = X.std().item()
66
- if sigma == 0.0:
67
- sigma = 1.0 # degenerate case: all points identical
68
- lo = X.min().item() - 3 * sigma
69
- hi = X.max().item() + 3 * sigma
70
- data_range = hi - lo
71
-
72
- if data_range == 0.0:
73
- return 1 # single-point distribution has 1 mode
74
-
75
- # Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
76
- # torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
77
- # MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
78
- # the binning step avoids the intermediate and is numerically identical
79
- # for data within [lo, hi] (guaranteed by the 3σ domain extension above).
80
- X_cpu = X.float().cpu()
81
- edges = torch.linspace(lo, hi, G + 1) # (G+1,) CPU
82
- bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1 # 0-indexed
83
- counts = torch.bincount(bin_idx, minlength=G).float().to(X.device) # back to device
84
-
85
- # Zero-pad to pad_factor*G (4× mandatory for circular wrap correctness at h_hi)
86
- N = pad_factor * G
87
- counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
88
- counts_padded[:G] = counts
89
-
90
- # FFT of histogram
91
- C = torch.fft.rfft(counts_padded)
92
-
93
- # Derivative kernel in frequency domain: iω * exp(-0.5*(ω*h)²)
94
- # ω_k = 2π*k / (N * bin_width), bin_width = data_range / G
95
- bin_width = data_range / G
96
- k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
97
- omega = 2 * math.pi * k / (N * bin_width)
98
- K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
99
-
100
- # Convolve and back-transform
101
- f_prime_padded = torch.fft.irfft(C * K_deriv, n=N)
102
-
103
- # Trim to original G grid (discard zero-padded tail)
104
- f_prime = f_prime_padded[:G]
105
-
106
- # Count (+→-) sign changes = number of modes
107
- # A mode is a local max of f, i.e., f' crosses zero from + to -
108
- # Remove zeros (flat segments) — carry forward last nonzero sign
109
- nonzero_mask = f_prime != 0
110
- if not nonzero_mask.any():
111
- return 0
112
-
113
- s = f_prime[nonzero_mask]
114
- transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
115
- return transitions
116
-
117
-
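Editor's note: the final step of the removed `fft_mode_count` counts modes as positive-to-negative sign changes of the derivative estimate after dropping exact zeros. The sketch below isolates that counting rule in plain Python on a hand-written list; `count_modes` is a hypothetical name and the input values are made up for illustration.

```python
def count_modes(f_prime):
    """Count local maxima of f as (+ to -) sign changes of its derivative."""
    # Dropping exact zeros is equivalent to carrying the last nonzero sign
    # forward across flat segments.
    s = [v for v in f_prime if v != 0]
    if not s:
        return 0  # derivative is identically zero on the grid
    return sum(1 for a, b in zip(s, s[1:]) if a > 0 and b < 0)


# Rises, falls, rises again, falls again: two local maxima.
example = [1.0, 2.0, 0.0, -1.0, -2.0, 1.0, 0.0, 0.0, -3.0]
print(count_modes(example))
```

A derivative that only goes from negative to positive marks a local minimum, so it contributes nothing to the count.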
118
- def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
119
- """Choose FFT grid size G so that the derivative kernel is well-resolved.
120
-
121
- Requires h > 8 * bin_width = 8 * data_range / G, equivalently
122
- G > 8 * data_range / h_hi. We use a factor of 16 for safety margin,
123
- then round up to the next power of 2 for efficient FFT.
124
-
125
- Parameters
126
- ----------
127
- data_range : float
128
- hi - lo of the data domain (typically X.max() - X.min() + 6σ).
129
- h_hi : float
130
- Upper bracket of the bisection (smallest h needing resolution).
131
- G_min : int
132
- Minimum returned G (default 4096).
133
-
134
- Returns
135
- -------
136
- int
137
- Grid size G, a power of 2, at least G_min.
138
- """
139
- needed = 16 * math.ceil(data_range / h_hi)
140
- # Round up to next power of 2
141
- p = 1
142
- while p < needed:
143
- p <<= 1
144
- return max(G_min, p)
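Editor's note: the grid-size rule in `adaptive_fft_G` above can be checked by hand. The sketch below re-derives it in plain Python with a bit-length shortcut in place of the doubling loop; `choose_grid_size` is a hypothetical name, and the input values are chosen to be exact in binary floating point so the arithmetic is easy to follow.

```python
import math


def choose_grid_size(data_range, h_hi, g_min=4096):
    """Smallest power of two >= 16 * ceil(data_range / h_hi), floored at g_min."""
    needed = 16 * math.ceil(data_range / h_hi)
    # Next power of two >= needed (same result as doubling p while p < needed).
    p = 1 << (max(needed, 1) - 1).bit_length()
    return max(g_min, p)


# 12 / 0.0625 = 192, 16 * 192 = 3072, next power of two is 4096.
g_small = choose_grid_size(12.0, 0.0625)
# 12 / 2**-9 = 6144, 16 * 6144 = 98304, next power of two is 131072.
g_large = choose_grid_size(12.0, 2.0 ** -9)
# The new find_h_crit default G_min=16384 raises the floor.
g_floor = choose_grid_size(12.0, 0.0625, g_min=16384)

# Resolution condition from the docstring: h > 8 * bin_width.
bin_width = 12.0 / g_small
print(g_small, g_large, g_floor, 0.0625 > 8 * bin_width)
```

Because the grid is sized from `h_hi`, smaller bandwidths probed on the same grid span fewer bins, so a larger `G_min` leaves more headroom for them.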
File without changes