diffcb 0.1.1__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffcb-0.1.1 → diffcb-0.1.4}/PKG-INFO +58 -21
- diffcb-0.1.4/README.md +129 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/dcb/__init__.py +1 -1
- diffcb-0.1.4/dcb/fft_kde.py +339 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/dcb/layer.py +17 -5
- {diffcb-0.1.1 → diffcb-0.1.4}/dcb/solver.py +37 -9
- {diffcb-0.1.1 → diffcb-0.1.4}/pyproject.toml +1 -1
- diffcb-0.1.1/README.md +0 -92
- diffcb-0.1.1/dcb/fft_kde.py +0 -144
- {diffcb-0.1.1 → diffcb-0.1.4}/.gitignore +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/.zenodo.json +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/LICENSE +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/dcb/diagnostics.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/dcb/kde.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/dcb/utils.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/notebooks/.gitkeep +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/tests/test_kde.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/tests/test_layer.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/tests/test_r18c_denom_audit.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/tests/test_r18c_deprecation_warn.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/tests/test_r19_default_fft.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/tests/test_r19_diagnostics.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.4}/tests/test_solver.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffcb
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.4
|
|
4
4
|
Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
|
|
5
5
|
Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
6
6
|
Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
@@ -71,10 +71,10 @@ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribu
|
|
|
71
71
|
import torch
|
|
72
72
|
from dcb import DCBLayer
|
|
73
73
|
|
|
74
|
-
X = torch.randn(
|
|
74
|
+
X = torch.randn(1000, requires_grad=True) # 1D samples
|
|
75
75
|
layer = DCBLayer(target_modes=1)
|
|
76
|
-
h_crit = layer(X)
|
|
77
|
-
h_crit.backward()
|
|
76
|
+
h_crit = layer(X) # differentiable scalar
|
|
77
|
+
h_crit.backward() # exact IFT gradients
|
|
78
78
|
```
|
|
79
79
|
|
|
80
80
|
## Installation
|
|
@@ -91,34 +91,72 @@ cd differentiable-critical-bandwidth
|
|
|
91
91
|
pip install -e ".[dev]"
|
|
92
92
|
```
|
|
93
93
|
|
|
94
|
-
##
|
|
94
|
+
## Accuracy vs R's `bw.crit`
|
|
95
95
|
|
|
96
|
-
|
|
96
|
+
DCB is validated against R's `multimode::bw.crit(data, mod0=1)` — the standard reference implementation of Hall & York (2001). On **identical data**:
|
|
97
|
+
|
|
98
|
+
| n | DCB vs R (same sample) | DCB vs R (independent samples) |
|
|
99
|
+
|---|---|---|
|
|
100
|
+
| 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
|
|
101
|
+
| 1M | **0.005%** | ~0.2% |
|
|
102
|
+
| 10M | **0.004%** | ~0.1% |
|
|
103
|
+
|
|
104
|
+
The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at all tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s while R OOMs.
|
|
105
|
+
|
|
106
|
+
## Key Parameters
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
DCBLayer(
|
|
110
|
+
target_modes=1, # target number of modes
|
|
111
|
+
G=512, # IFT evaluation grid points
|
|
112
|
+
use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
|
|
113
|
+
max_n_exact=1_000_000,# sketch to sketch_size when n exceeds this (None = always exact)
|
|
114
|
+
sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
|
|
115
|
+
safe_backward=False, # clamp IFT denominator near bifurcations
|
|
116
|
+
)
|
|
117
|
+
```
|
|
97
118
|
|
|
98
119
|
## Confirmed Experimental Results
|
|
99
120
|
|
|
100
|
-
All results produced on Kaggle
|
|
121
|
+
All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
|
|
101
122
|
|
|
102
123
|
| Experiment | Result | Criterion |
|
|
103
124
|
|---|---|---|
|
|
104
|
-
| **
|
|
105
|
-
| **
|
|
125
|
+
| **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
|
|
126
|
+
| **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
|
|
127
|
+
| **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
|
|
106
128
|
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
107
129
|
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
108
130
|
|
|
131
|
+
## Changelog
|
|
132
|
+
|
|
133
|
+
### v0.1.1 (2026-05-29)
|
|
134
|
+
- **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
|
|
135
|
+
- **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n⁻²/⁹) convergence of h_crit; 500K sketch matches full-n accuracy.
|
|
136
|
+
- **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
|
|
137
|
+
- **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
|
|
138
|
+
- **Test fixes:** Updated 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
|
|
139
|
+
|
|
140
|
+
### v0.1.0 (2026-05-28)
|
|
141
|
+
- Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
|
|
142
|
+
|
|
109
143
|
## Repository Structure
|
|
110
144
|
|
|
111
145
|
```
|
|
112
|
-
dcb/ Core PyTorch package
|
|
146
|
+
dcb/ Core PyTorch package
|
|
147
|
+
layer.py DCBLayer nn.Module + DCBFunction autograd
|
|
148
|
+
solver.py IFT root-finder and backward pass
|
|
149
|
+
fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
|
|
150
|
+
kde.py Direct KDE derivatives (small-n path)
|
|
151
|
+
utils.py Grid, Silverman bandwidth, sg() stabilizer
|
|
113
152
|
experiments/ Reproduction scripts for all paper figures and tables
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
tests/ Unit tests (pytest,
|
|
153
|
+
phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
|
|
154
|
+
phase2_gan.py GAN mode-collapse prevention (Figure 3)
|
|
155
|
+
phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
|
|
156
|
+
round20_*.py Large-n R comparison and streaming benchmarks
|
|
157
|
+
round21_*.py Accuracy improvement experiments
|
|
158
|
+
tests/ Unit tests (pytest, 45 passed, 1 xfailed)
|
|
120
159
|
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
121
|
-
notebooks/ Quickstart and demo notebooks
|
|
122
160
|
```
|
|
123
161
|
|
|
124
162
|
## Reproducing Paper Results
|
|
@@ -127,7 +165,6 @@ notebooks/ Quickstart and demo notebooks
|
|
|
127
165
|
# Phase 1: validation, speedup, ablation
|
|
128
166
|
python experiments/phase1_validation.py
|
|
129
167
|
python experiments/phase1_speedup.py
|
|
130
|
-
python experiments/phase1_ablation.py
|
|
131
168
|
|
|
132
169
|
# Phase 2: GAN mode collapse experiment
|
|
133
170
|
python experiments/phase2_gan.py
|
|
@@ -136,13 +173,13 @@ python experiments/phase2_gan.py
|
|
|
136
173
|
python experiments/phase3_anomaly.py
|
|
137
174
|
```
|
|
138
175
|
|
|
139
|
-
For GPU runs
|
|
176
|
+
For GPU runs use the Kaggle kernels:
|
|
140
177
|
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
141
178
|
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
142
179
|
|
|
143
|
-
##
|
|
180
|
+
## Paper
|
|
144
181
|
|
|
145
|
-
|
|
182
|
+
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
146
183
|
|
|
147
184
|
## License
|
|
148
185
|
|
diffcb-0.1.4/README.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# DCB — Differentiable Critical Bandwidth
|
|
2
|
+
|
|
3
|
+
[](https://pypi.org/project/diffcb/)
|
|
4
|
+
[](LICENSE)
|
|
5
|
+
[](https://www.python.org/)
|
|
6
|
+
|
|
7
|
+
A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
|
|
8
|
+
|
|
9
|
+
## Overview
|
|
10
|
+
|
|
11
|
+
The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
import torch
|
|
15
|
+
from dcb import DCBLayer
|
|
16
|
+
|
|
17
|
+
X = torch.randn(1000, requires_grad=True) # 1D samples
|
|
18
|
+
layer = DCBLayer(target_modes=1)
|
|
19
|
+
h_crit = layer(X) # differentiable scalar
|
|
20
|
+
h_crit.backward() # exact IFT gradients
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Installation
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
pip install diffcb
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
Or from source:
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
33
|
+
cd differentiable-critical-bandwidth
|
|
34
|
+
pip install -e ".[dev]"
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Accuracy vs R's `bw.crit`
|
|
38
|
+
|
|
39
|
+
DCB is validated against R's `multimode::bw.crit(data, mod0=1)` — the standard reference implementation of Hall & York (2001). On **identical data**:
|
|
40
|
+
|
|
41
|
+
| n | DCB vs R (same sample) | DCB vs R (independent samples) |
|
|
42
|
+
|---|---|---|
|
|
43
|
+
| 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
|
|
44
|
+
| 1M | **0.005%** | ~0.2% |
|
|
45
|
+
| 10M | **0.004%** | ~0.1% |
|
|
46
|
+
|
|
47
|
+
The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at all tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s while R OOMs.
|
|
48
|
+
|
|
49
|
+
## Key Parameters
|
|
50
|
+
|
|
51
|
+
```python
|
|
52
|
+
DCBLayer(
|
|
53
|
+
target_modes=1, # target number of modes
|
|
54
|
+
G=512, # IFT evaluation grid points
|
|
55
|
+
use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
|
|
56
|
+
max_n_exact=1_000_000,# sketch to sketch_size when n exceeds this (None = always exact)
|
|
57
|
+
sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
|
|
58
|
+
safe_backward=False, # clamp IFT denominator near bifurcations
|
|
59
|
+
)
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
## Confirmed Experimental Results
|
|
63
|
+
|
|
64
|
+
All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
|
|
65
|
+
|
|
66
|
+
| Experiment | Result | Criterion |
|
|
67
|
+
|---|---|---|
|
|
68
|
+
| **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
|
|
69
|
+
| **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
|
|
70
|
+
| **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
|
|
71
|
+
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
72
|
+
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
73
|
+
|
|
74
|
+
## Changelog
|
|
75
|
+
|
|
76
|
+
### v0.1.1 (2026-05-29)
|
|
77
|
+
- **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
|
|
78
|
+
- **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n⁻²/⁹) convergence of h_crit; 500K sketch matches full-n accuracy.
|
|
79
|
+
- **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
|
|
80
|
+
- **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
|
|
81
|
+
- **Test fixes:** Updated 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
|
|
82
|
+
|
|
83
|
+
### v0.1.0 (2026-05-28)
|
|
84
|
+
- Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
|
|
85
|
+
|
|
86
|
+
## Repository Structure
|
|
87
|
+
|
|
88
|
+
```
|
|
89
|
+
dcb/ Core PyTorch package
|
|
90
|
+
layer.py DCBLayer nn.Module + DCBFunction autograd
|
|
91
|
+
solver.py IFT root-finder and backward pass
|
|
92
|
+
fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
|
|
93
|
+
kde.py Direct KDE derivatives (small-n path)
|
|
94
|
+
utils.py Grid, Silverman bandwidth, sg() stabilizer
|
|
95
|
+
experiments/ Reproduction scripts for all paper figures and tables
|
|
96
|
+
phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
|
|
97
|
+
phase2_gan.py GAN mode-collapse prevention (Figure 3)
|
|
98
|
+
phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
|
|
99
|
+
round20_*.py Large-n R comparison and streaming benchmarks
|
|
100
|
+
round21_*.py Accuracy improvement experiments
|
|
101
|
+
tests/ Unit tests (pytest, 45 passed, 1 xfailed)
|
|
102
|
+
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
## Reproducing Paper Results
|
|
106
|
+
|
|
107
|
+
```bash
|
|
108
|
+
# Phase 1: validation, speedup, ablation
|
|
109
|
+
python experiments/phase1_validation.py
|
|
110
|
+
python experiments/phase1_speedup.py
|
|
111
|
+
|
|
112
|
+
# Phase 2: GAN mode collapse experiment
|
|
113
|
+
python experiments/phase2_gan.py
|
|
114
|
+
|
|
115
|
+
# Phase 3: anomaly detection benchmark
|
|
116
|
+
python experiments/phase3_anomaly.py
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
For GPU runs use the Kaggle kernels:
|
|
120
|
+
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
121
|
+
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
122
|
+
|
|
123
|
+
## Paper
|
|
124
|
+
|
|
125
|
+
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
126
|
+
|
|
127
|
+
## License
|
|
128
|
+
|
|
129
|
+
MIT — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.fft_kde — FFT-based KDE Mode Counter
|
|
3
|
+
|
|
4
|
+
Implements mode counting via FFT convolution of the histogram with a
|
|
5
|
+
Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
|
|
6
|
+
O(n × G) cost of the direct KDE approach and — crucially — requiring NO
|
|
7
|
+
subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
|
|
8
|
+
that affects the standard bisection path when n > brentq_n_max.
|
|
9
|
+
|
|
10
|
+
Round 18b: forward kernel only. The IFT backward is unchanged (still uses
|
|
11
|
+
the analytical chunked KDE derivatives on all n points).
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Worker 2: device-native histogram
|
|
23
|
+
def _histogram_on_device(X: Tensor, G: int, lo: float, hi: float) -> Tensor:
|
|
24
|
+
"""Compute a G-bin histogram of X on the same device as X."""
|
|
25
|
+
device = X.device
|
|
26
|
+
if device.type == 'cuda':
|
|
27
|
+
return torch.histc(X.float(), bins=G, min=lo, max=hi)
|
|
28
|
+
elif device.type == 'mps':
|
|
29
|
+
bin_idx = ((X.float() - lo) * (G / (hi - lo))).long().clamp_(0, G - 1)
|
|
30
|
+
counts = torch.zeros(G, dtype=torch.float32, device=device)
|
|
31
|
+
counts.scatter_add_(0, bin_idx, torch.ones(X.shape[0], dtype=torch.float32, device=device))
|
|
32
|
+
return counts
|
|
33
|
+
else: # cpu
|
|
34
|
+
X_cpu = X.float()
|
|
35
|
+
edges = torch.linspace(lo, hi, G + 1)
|
|
36
|
+
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
|
|
37
|
+
return torch.bincount(bin_idx, minlength=G).float()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def precompute_fft(
|
|
41
|
+
X: Tensor,
|
|
42
|
+
G: int = 4096,
|
|
43
|
+
domain: tuple[float, float] | None = None,
|
|
44
|
+
pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
45
|
+
fft_dtype: torch.dtype = torch.float32, # Worker 3: float32 FFT
|
|
46
|
+
) -> tuple[Tensor, Tensor, tuple[float, float]]:
|
|
47
|
+
"""Precompute the FFT of the zero-padded histogram of X.
|
|
48
|
+
|
|
49
|
+
This is the bandwidth-independent work shared across a bisection loop on
|
|
50
|
+
h: build the histogram, zero-pad, take rfft, and build the frequency grid
|
|
51
|
+
omega. The per-step kernel K(omega, h) = i*omega*exp(-0.5*(omega*h)**2)
|
|
52
|
+
must be combined with C inside `mode_count_from_C`.
|
|
53
|
+
|
|
54
|
+
Parameters
|
|
55
|
+
----------
|
|
56
|
+
X : Tensor, shape (n,)
|
|
57
|
+
G : int
|
|
58
|
+
Number of histogram bins.
|
|
59
|
+
domain : (lo, hi) or None
|
|
60
|
+
If provided, use as histogram domain; otherwise computed from X
|
|
61
|
+
with a 3*sigma margin.
|
|
62
|
+
pad_factor : int
|
|
63
|
+
Zero-padding multiplier (default 4).
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
C : Tensor, shape (N//2+1,), complex128
|
|
68
|
+
rfft of the zero-padded float64 histogram. Empty tensor (degenerate
|
|
69
|
+
zero-range domain) signals the caller to short-circuit to 1 mode.
|
|
70
|
+
omega : Tensor, shape (N//2+1,), float64
|
|
71
|
+
Angular frequency grid for the FFT.
|
|
72
|
+
domain : (lo, hi)
|
|
73
|
+
Domain tuple actually used.
|
|
74
|
+
"""
|
|
75
|
+
with torch.no_grad():
|
|
76
|
+
if domain is not None:
|
|
77
|
+
lo, hi = domain
|
|
78
|
+
else:
|
|
79
|
+
sigma = X.std().item()
|
|
80
|
+
if sigma == 0.0:
|
|
81
|
+
sigma = 1.0
|
|
82
|
+
lo = X.min().item() - 3 * sigma
|
|
83
|
+
hi = X.max().item() + 3 * sigma
|
|
84
|
+
data_range = hi - lo
|
|
85
|
+
|
|
86
|
+
if data_range == 0.0:
|
|
87
|
+
complex_dtype = torch.complex64 if fft_dtype == torch.float32 else torch.complex128
|
|
88
|
+
empty = torch.zeros(0, dtype=complex_dtype, device=X.device)
|
|
89
|
+
empty_omega = torch.zeros(0, dtype=fft_dtype, device=X.device)
|
|
90
|
+
return empty, empty_omega, (lo, hi)
|
|
91
|
+
|
|
92
|
+
# Histogram (O(n)) — device-native dispatch.
|
|
93
|
+
counts = _histogram_on_device(X, G, lo, hi)
|
|
94
|
+
|
|
95
|
+
N = pad_factor * G
|
|
96
|
+
counts_padded = torch.zeros(N, dtype=fft_dtype, device=X.device)
|
|
97
|
+
counts_padded[:G] = counts.to(fft_dtype)
|
|
98
|
+
|
|
99
|
+
C = torch.fft.rfft(counts_padded)
|
|
100
|
+
|
|
101
|
+
bin_width = data_range / G
|
|
102
|
+
k = torch.arange(N // 2 + 1, device=X.device, dtype=fft_dtype)
|
|
103
|
+
omega = 2 * math.pi * k / (N * bin_width)
|
|
104
|
+
|
|
105
|
+
return C, omega, (lo, hi)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def mode_count_from_C(
|
|
109
|
+
C: Tensor,
|
|
110
|
+
omega: Tensor,
|
|
111
|
+
h: float,
|
|
112
|
+
G: int,
|
|
113
|
+
N: int,
|
|
114
|
+
) -> int:
|
|
115
|
+
"""Per-step mode count: apply Gaussian derivative kernel and count sign changes.
|
|
116
|
+
|
|
117
|
+
Cheap inner loop body for bisection — only the kernel depends on h.
|
|
118
|
+
|
|
119
|
+
Parameters
|
|
120
|
+
----------
|
|
121
|
+
C : Tensor, shape (N//2+1,), complex
|
|
122
|
+
rfft of the zero-padded histogram (from `precompute_fft`).
|
|
123
|
+
omega : Tensor, shape (N//2+1,), float64
|
|
124
|
+
Frequency grid (from `precompute_fft`).
|
|
125
|
+
h : float
|
|
126
|
+
Bandwidth.
|
|
127
|
+
G : int
|
|
128
|
+
Histogram bin count.
|
|
129
|
+
N : int
|
|
130
|
+
Padded FFT length (pad_factor * G).
|
|
131
|
+
|
|
132
|
+
Returns
|
|
133
|
+
-------
|
|
134
|
+
int
|
|
135
|
+
Number of KDE modes.
|
|
136
|
+
"""
|
|
137
|
+
if C.numel() == 0:
|
|
138
|
+
return 1 # degenerate single-point distribution
|
|
139
|
+
|
|
140
|
+
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
141
|
+
f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).real
|
|
142
|
+
f_prime = f_prime_padded[:G]
|
|
143
|
+
|
|
144
|
+
nonzero_mask = f_prime != 0
|
|
145
|
+
if not nonzero_mask.any():
|
|
146
|
+
return 0
|
|
147
|
+
|
|
148
|
+
s = f_prime[nonzero_mask]
|
|
149
|
+
transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
|
|
150
|
+
return transitions
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def fft_mode_count(
|
|
154
|
+
X: Tensor,
|
|
155
|
+
h: float,
|
|
156
|
+
G: int = 4096,
|
|
157
|
+
pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
158
|
+
domain: tuple[float, float] | None = None,
|
|
159
|
+
) -> int:
|
|
160
|
+
"""Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
|
|
161
|
+
|
|
162
|
+
Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
|
|
163
|
+
the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
|
|
164
|
+
back-transforms, and counts positive-to-negative sign changes of the
|
|
165
|
+
resulting f' estimate.
|
|
166
|
+
|
|
167
|
+
Parameters
|
|
168
|
+
----------
|
|
169
|
+
X : Tensor, shape (n,)
|
|
170
|
+
1D data tensor (may be on CPU or CUDA).
|
|
171
|
+
h : float
|
|
172
|
+
Bandwidth for the Gaussian kernel.
|
|
173
|
+
G : int
|
|
174
|
+
Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
|
|
175
|
+
reliable derivative estimation. Use `adaptive_fft_G` to choose G
|
|
176
|
+
automatically before bisection.
|
|
177
|
+
pad_factor : int
|
|
178
|
+
Zero-padding multiplier (default 4). Mandatory ≥ 2 for circular-wrap
|
|
179
|
+
correctness; 4 is recommended at the largest h encountered.
|
|
180
|
+
domain : (lo, hi) or None
|
|
181
|
+
If provided, use this as the histogram domain instead of computing
|
|
182
|
+
X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
|
|
183
|
+
with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
|
|
184
|
+
so every fft_mode_count call in a bisection loop uses an identical grid.
|
|
185
|
+
|
|
186
|
+
Returns
|
|
187
|
+
-------
|
|
188
|
+
int
|
|
189
|
+
Number of KDE modes (downward zero-crossings of f').
|
|
190
|
+
"""
|
|
191
|
+
with torch.no_grad():
|
|
192
|
+
C, omega, _ = precompute_fft(X, G=G, domain=domain, pad_factor=pad_factor)
|
|
193
|
+
N = pad_factor * G
|
|
194
|
+
return mode_count_from_C(C, omega, h, G, N)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def _refine_hcrit(
|
|
198
|
+
X: Tensor,
|
|
199
|
+
h_lo: float,
|
|
200
|
+
h_hi: float,
|
|
201
|
+
G: int,
|
|
202
|
+
domain: tuple[float, float],
|
|
203
|
+
target_modes: int = 1,
|
|
204
|
+
pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
205
|
+
) -> float:
|
|
206
|
+
"""Sub-bin quadratic refinement of h_crit after bisection converges.
|
|
207
|
+
|
|
208
|
+
Identifies the f′ lobe that disappears at the mode-merging bandwidth and
|
|
209
|
+
fits a quadratic in h to that lobe's peak value, returning the root — the
|
|
210
|
+
h where that peak exactly reaches zero. Reduces the bin-width-limited
|
|
211
|
+
systematic from ~bin_width/h_crit to well below 1e-4.
|
|
212
|
+
|
|
213
|
+
When the incoming bracket [h_lo, h_hi] is tighter than one histogram bin
|
|
214
|
+
width (the common case after 50-step bisection), the function expands the
|
|
215
|
+
bracket outward from h_hi by up to 4× the bin width while maintaining the
|
|
216
|
+
invariant that fft_mode_count > target at the left endpoint and
|
|
217
|
+
<= target at the right endpoint, so the disappearing f′ lobe is visible
|
|
218
|
+
across the bracket.
|
|
219
|
+
|
|
220
|
+
Parameters
|
|
221
|
+
----------
|
|
222
|
+
X : Tensor — data (may be on any device)
|
|
223
|
+
h_lo, h_hi : float — final bisection bracket; fft_mode_count(X,h_lo) > target,
|
|
224
|
+
fft_mode_count(X,h_hi) <= target
|
|
225
|
+
G, domain, target_modes, pad_factor — same as fft_mode_count
|
|
226
|
+
|
|
227
|
+
Returns
|
|
228
|
+
-------
|
|
229
|
+
float — refined h_crit, guaranteed to lie in [h_lo, h_hi] of the
|
|
230
|
+
(possibly expanded) bracket used for fitting.
|
|
231
|
+
"""
|
|
232
|
+
import numpy as np
|
|
233
|
+
|
|
234
|
+
lo_d, hi_d = domain
|
|
235
|
+
data_range = hi_d - lo_d
|
|
236
|
+
if data_range == 0.0:
|
|
237
|
+
return h_hi
|
|
238
|
+
|
|
239
|
+
bin_width = data_range / G
|
|
240
|
+
N = pad_factor * G
|
|
241
|
+
bw = bin_width # histogram bin width
|
|
242
|
+
|
|
243
|
+
# Pre-compute histogram once; reuse C (FFT of counts) for all h evaluations.
|
|
244
|
+
with torch.no_grad():
|
|
245
|
+
counts = _histogram_on_device(X, G, lo_d, hi_d).cpu()
|
|
246
|
+
counts_padded = torch.zeros(N, dtype=torch.float64)
|
|
247
|
+
counts_padded[:G] = counts.double()
|
|
248
|
+
C = torch.fft.rfft(counts_padded)
|
|
249
|
+
k = torch.arange(N // 2 + 1, dtype=torch.float64)
|
|
250
|
+
omega_base = 2 * math.pi * k / (N * bw)
|
|
251
|
+
|
|
252
|
+
def fprime(h: float) -> Tensor:
|
|
253
|
+
"""Compute f′ array (shape G,) for bandwidth h using cached C (float64)."""
|
|
254
|
+
K_deriv = 1j * omega_base * torch.exp(-0.5 * (omega_base * h) ** 2)
|
|
255
|
+
return torch.fft.irfft(C * K_deriv, n=N).float()[:G]
|
|
256
|
+
|
|
257
|
+
with torch.no_grad():
|
|
258
|
+
# If the bracket is tighter than bin_width, expand it so that the
|
|
259
|
+
# disappearing f′ lobe crosses zero somewhere inside the bracket.
|
|
260
|
+
# Expand the left endpoint leftward by up to 4 bin widths.
|
|
261
|
+
ref_lo = h_lo
|
|
262
|
+
ref_hi = h_hi
|
|
263
|
+
|
|
264
|
+
if (ref_hi - ref_lo) < bw:
|
|
265
|
+
# Try expanding leftward until we find a bin where fp crosses zero
|
|
266
|
+
for mult in [1, 2, 3, 4]:
|
|
267
|
+
cand_lo = max(ref_hi - mult * bw, ref_hi * 0.9)
|
|
268
|
+
fp_cand = fprime(cand_lo)
|
|
269
|
+
fp_hi_ = fprime(ref_hi)
|
|
270
|
+
cm = (fp_cand > 0) & (fp_hi_ <= 0)
|
|
271
|
+
if cm.any():
|
|
272
|
+
ref_lo = cand_lo
|
|
273
|
+
break
|
|
274
|
+
# If still no candidates found, return bisection result unchanged
|
|
275
|
+
fp_lo_ = fprime(ref_lo)
|
|
276
|
+
fp_hi_ = fprime(ref_hi)
|
|
277
|
+
candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
|
|
278
|
+
if not candidate_mask.any():
|
|
279
|
+
return h_hi
|
|
280
|
+
else:
|
|
281
|
+
fp_lo_ = fprime(ref_lo)
|
|
282
|
+
fp_hi_ = fprime(ref_hi)
|
|
283
|
+
candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
|
|
284
|
+
if not candidate_mask.any():
|
|
285
|
+
return h_hi
|
|
286
|
+
|
|
287
|
+
# Pick the bin with the largest positive value at ref_lo that crossed zero
|
|
288
|
+
masked_fp_lo = fp_lo_.clone()
|
|
289
|
+
masked_fp_lo[~candidate_mask] = -float('inf')
|
|
290
|
+
j = int(masked_fp_lo.argmax().item())
|
|
291
|
+
|
|
292
|
+
h_mid = (ref_lo + ref_hi) / 2.0
|
|
293
|
+
|
|
294
|
+
# Evaluate fp[j] at three bandwidths for quadratic fit
|
|
295
|
+
y_lo = fp_lo_[j].item()
|
|
296
|
+
y_mid = fprime(h_mid)[j].item()
|
|
297
|
+
y_hi = fp_hi_[j].item()
|
|
298
|
+
|
|
299
|
+
# Fit quadratic y = a*h² + b*h + c through the three (h, y) pairs
|
|
300
|
+
# and solve for the root in [ref_lo, ref_hi].
|
|
301
|
+
coeffs = np.polyfit([ref_lo, h_mid, ref_hi], [y_lo, y_mid, y_hi], 2)
|
|
302
|
+
roots = np.roots(coeffs)
|
|
303
|
+
real_roots = [
|
|
304
|
+
r.real for r in roots
|
|
305
|
+
if abs(r.imag) < 1e-10 * abs(r.real + 1e-30)
|
|
306
|
+
and ref_lo <= r.real <= ref_hi
|
|
307
|
+
]
|
|
308
|
+
if real_roots:
|
|
309
|
+
return float(min(real_roots, key=lambda r: abs(r - h_mid)))
|
|
310
|
+
return h_hi
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 16384) -> int:
|
|
314
|
+
"""Choose FFT grid size G so that the derivative kernel is well-resolved.
|
|
315
|
+
|
|
316
|
+
Requires h > 8 * bin_width = 8 * data_range / G, equivalently
|
|
317
|
+
G > 8 * data_range / h_hi. We use a factor of 16 for safety margin,
|
|
318
|
+
then round up to the next power of 2 for efficient FFT.
|
|
319
|
+
|
|
320
|
+
Parameters
|
|
321
|
+
----------
|
|
322
|
+
data_range : float
|
|
323
|
+
hi - lo of the data domain (typically X.max() - X.min() + 6σ).
|
|
324
|
+
h_hi : float
|
|
325
|
+
Upper bracket of the bisection (smallest h needing resolution).
|
|
326
|
+
G_min : int
|
|
327
|
+
Minimum returned G (default 16384).
|
|
328
|
+
|
|
329
|
+
Returns
|
|
330
|
+
-------
|
|
331
|
+
int
|
|
332
|
+
Grid size G, a power of 2, at least G_min.
|
|
333
|
+
"""
|
|
334
|
+
needed = 16 * math.ceil(data_range / h_hi)
|
|
335
|
+
# Round up to next power of 2
|
|
336
|
+
p = 1
|
|
337
|
+
while p < needed:
|
|
338
|
+
p <<= 1
|
|
339
|
+
return max(G_min, p)
|
|
@@ -35,13 +35,14 @@ class DCBFunction(torch.autograd.Function):
|
|
|
35
35
|
|
|
36
36
|
@staticmethod
|
|
37
37
|
def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
|
|
38
|
-
brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft
|
|
38
|
+
brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min,
|
|
39
|
+
fft_dtype):
|
|
39
40
|
"""Locate h_crit and save state for the backward pass."""
|
|
40
41
|
h_crit, cond_num = find_h_crit(
|
|
41
42
|
X, grid, eps, tau, target_modes,
|
|
42
43
|
formula=formula, brentq_n_max=brentq_n_max, chunk_size=chunk_size,
|
|
43
44
|
g_brentq=g_brentq, use_hard_bisection=use_hard_bisection,
|
|
44
|
-
use_fft=use_fft,
|
|
45
|
+
use_fft=use_fft, G_min=fft_G_min, fft_dtype=fft_dtype,
|
|
45
46
|
)
|
|
46
47
|
ctx.save_for_backward(X, grid)
|
|
47
48
|
ctx.h_crit = h_crit
|
|
@@ -67,8 +68,8 @@ class DCBFunction(torch.autograd.Function):
|
|
|
67
68
|
ctx.denom_abs = ift_gradient.last_denom_abs
|
|
68
69
|
# Gradients for: X, grid, eps, tau, target_modes, delta, formula,
|
|
69
70
|
# chunk_size, brentq_n_max, g_brentq, use_hard_bisection,
|
|
70
|
-
# safe_backward, use_fft
|
|
71
|
-
return grad_X, None, None, None, None, None, None, None, None, None, None, None, None
|
|
71
|
+
# safe_backward, use_fft, fft_G_min, fft_dtype
|
|
72
|
+
return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
|
72
73
|
|
|
73
74
|
|
|
74
75
|
class DCBLayer(nn.Module):
|
|
@@ -133,6 +134,13 @@ class DCBLayer(nn.Module):
|
|
|
133
134
|
Number of points to sketch when n > max_n_exact. Default 500_000.
|
|
134
135
|
A 500K sketch achieves the same mean accuracy as streaming 100M points
|
|
135
136
|
(validated in Round 20 reservoir experiment).
|
|
137
|
+
fft_G_min : int
|
|
138
|
+
Minimum FFT histogram grid size for the bisection solver (default 16384).
|
|
139
|
+
Controls accuracy of the FFT path (n > 50K). Larger values reduce
|
|
140
|
+
discretisation error at a modest cost: G=16384 gives ~0.004% err vs R;
|
|
141
|
+
G=32768 gives ~0.001% at +9% cost; G=65536 reaches the R-matching floor
|
|
142
|
+
(~0.001%) with no further gain beyond that. Ignored for n ≤ 50K (direct
|
|
143
|
+
KDE path).
|
|
136
144
|
|
|
137
145
|
Examples
|
|
138
146
|
--------
|
|
@@ -162,6 +170,8 @@ class DCBLayer(nn.Module):
|
|
|
162
170
|
use_fft: bool = True,
|
|
163
171
|
max_n_exact: int | None = 1_000_000,
|
|
164
172
|
sketch_size: int = 500_000,
|
|
173
|
+
fft_G_min: int = 16384,
|
|
174
|
+
fft_dtype: torch.dtype = torch.float32,
|
|
165
175
|
):
|
|
166
176
|
super().__init__()
|
|
167
177
|
self.target_modes = target_modes
|
|
@@ -180,6 +190,8 @@ class DCBLayer(nn.Module):
|
|
|
180
190
|
self.use_fft = use_fft
|
|
181
191
|
self.max_n_exact = max_n_exact
|
|
182
192
|
self.sketch_size = sketch_size
|
|
193
|
+
self.fft_G_min = fft_G_min
|
|
194
|
+
self.fft_dtype = fft_dtype
|
|
183
195
|
if use_fft and brentq_n_max != 50_000:
|
|
184
196
|
raise TypeError(
|
|
185
197
|
f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
|
|
@@ -250,7 +262,7 @@ class DCBLayer(nn.Module):
|
|
|
250
262
|
return DCBFunction.apply(
|
|
251
263
|
X, grid, eps_eff, tau_eff, self.target_modes, self.delta, self.formula,
|
|
252
264
|
self.chunk_size, self.brentq_n_max, self.g_brentq, self.use_hard_bisection,
|
|
253
|
-
self.safe_backward, self.use_fft,
|
|
265
|
+
self.safe_backward, self.use_fft, self.fft_G_min, self.fft_dtype,
|
|
254
266
|
)
|
|
255
267
|
|
|
256
268
|
|
|
@@ -37,7 +37,7 @@ from dcb.kde import (
|
|
|
37
37
|
soft_mode_count_cross_from_derivs,
|
|
38
38
|
kde_derivatives_chunked,
|
|
39
39
|
)
|
|
40
|
-
from dcb.fft_kde import fft_mode_count, adaptive_fft_G
|
|
40
|
+
from dcb.fft_kde import fft_mode_count, adaptive_fft_G, precompute_fft, mode_count_from_C
|
|
41
41
|
|
|
42
42
|
_AUTO_FFT_THRESHOLD = 50_000 # n above which FFT bisection activates (use_fft_effective)
|
|
43
43
|
|
|
@@ -74,6 +74,8 @@ def find_h_crit_hard(
|
|
|
74
74
|
eps: float = 0.1,
|
|
75
75
|
tau: float = 0.2,
|
|
76
76
|
use_fft: bool = False,
|
|
77
|
+
G_min: int = 16384,
|
|
78
|
+
fft_dtype: torch.dtype = torch.float32,
|
|
77
79
|
) -> tuple[float, float]:
|
|
78
80
|
"""Find h_crit via hard-mode-count bisection (monotone, no false roots).
|
|
79
81
|
|
|
@@ -151,43 +153,64 @@ def find_h_crit_hard(
|
|
|
151
153
|
lo_domain = X.min().item() - 3 * sigma
|
|
152
154
|
hi_domain = X.max().item() + 3 * sigma
|
|
153
155
|
data_range = hi_domain - lo_domain
|
|
154
|
-
G_fft = adaptive_fft_G(data_range, h_hi)
|
|
156
|
+
G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
|
|
155
157
|
_domain = (lo_domain, hi_domain)
|
|
158
|
+
pad_factor = 2 # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
159
|
+
N = pad_factor * G_fft
|
|
156
160
|
|
|
157
161
|
with torch.no_grad():
|
|
162
|
+
# Worker 1: precomputed C — hoist histogram + rfft out of bisection.
|
|
163
|
+
# Worker 3: float32 FFT by default — 2× faster; _refine_hcrit uses float64 independently.
|
|
164
|
+
C, omega, _domain = precompute_fft(
|
|
165
|
+
X, G=G_fft, domain=_domain, pad_factor=pad_factor, fft_dtype=fft_dtype,
|
|
166
|
+
)
|
|
167
|
+
|
|
158
168
|
# Verify bracket using FFT mode count on full X
|
|
159
|
-
count_lo =
|
|
169
|
+
count_lo = mode_count_from_C(C, omega, h_lo, G_fft, N)
|
|
160
170
|
if count_lo <= target_modes:
|
|
161
171
|
h_lo_try = h_lo
|
|
162
172
|
for _ in range(30):
|
|
163
173
|
h_lo_try *= 0.5
|
|
164
174
|
if h_lo_try < 1e-10:
|
|
165
175
|
break
|
|
166
|
-
if
|
|
176
|
+
if mode_count_from_C(C, omega, h_lo_try, G_fft, N) > target_modes:
|
|
167
177
|
h_lo = h_lo_try
|
|
168
178
|
break
|
|
169
179
|
|
|
170
|
-
count_hi =
|
|
180
|
+
count_hi = mode_count_from_C(C, omega, h_hi, G_fft, N)
|
|
171
181
|
if count_hi > target_modes:
|
|
172
182
|
for _ in range(30):
|
|
173
183
|
h_hi *= 2.0
|
|
174
|
-
if
|
|
184
|
+
if mode_count_from_C(C, omega, h_hi, G_fft, N) <= target_modes:
|
|
175
185
|
break
|
|
176
186
|
|
|
177
|
-
#
|
|
187
|
+
# Adaptive bisection: stop when bracket is localised (relative width < 1e-3)
|
|
188
|
+
# _refine_hcrit provides sub-bin precision afterwards — no need to over-bisect.
|
|
178
189
|
lo, hi = h_lo, h_hi
|
|
179
190
|
for _ in range(50):
|
|
180
191
|
mid = (lo + hi) / 2.0
|
|
181
|
-
count =
|
|
192
|
+
count = mode_count_from_C(C, omega, mid, G_fft, N)
|
|
182
193
|
if count <= target_modes:
|
|
183
194
|
hi = mid
|
|
184
195
|
else:
|
|
185
196
|
lo = mid
|
|
186
197
|
if (hi - lo) < tol:
|
|
187
198
|
break
|
|
199
|
+
# Worker 4: adaptive termination — stop when relative bracket width
|
|
200
|
+
# is small enough that further bisection cannot meaningfully shift
|
|
201
|
+
# _refine_hcrit's quadratic fit. Empirically 1e-7 preserves h_crit
|
|
202
|
+
# to within 1e-6 of the 50-step tol=1e-6 baseline while saving ~10
|
|
203
|
+
# bisection steps in typical cases.
|
|
204
|
+
if hi > 0 and (hi - lo) / hi < 1e-7:
|
|
205
|
+
break
|
|
188
206
|
|
|
189
207
|
h_crit = float(hi) # smallest h with count <= target_modes
|
|
190
208
|
|
|
209
|
+
# Sub-bin refinement: quadratic interpolation on the disappearing f′ lobe
|
|
210
|
+
# to locate h_crit below the bin-width precision limit.
|
|
211
|
+
from dcb.fft_kde import _refine_hcrit
|
|
212
|
+
h_crit = _refine_hcrit(X, lo, hi, G_fft, _domain, target_modes)
|
|
213
|
+
|
|
191
214
|
else:
|
|
192
215
|
with torch.no_grad():
|
|
193
216
|
# Verify bracket: need count > target at h_lo, count <= target at h_hi.
|
|
@@ -216,6 +239,9 @@ def find_h_crit_hard(
|
|
|
216
239
|
break
|
|
217
240
|
|
|
218
241
|
# Standard bisection: 50 iterations → bracket width / 2^50
|
|
242
|
+
# NOTE: non-FFT path has no _refine_hcrit sub-bin refinement, so we keep
|
|
243
|
+
# tight bisection here for gradient stability (IFT test requires h_crit
|
|
244
|
+
# accurate well below FD perturbation delta=1e-3).
|
|
219
245
|
lo, hi = h_lo, h_hi
|
|
220
246
|
for _ in range(50):
|
|
221
247
|
mid = (lo + hi) / 2.0
|
|
@@ -290,6 +316,8 @@ def find_h_crit(
|
|
|
290
316
|
g_brentq: int = 128,
|
|
291
317
|
use_hard_bisection: bool = True,
|
|
292
318
|
use_fft: bool = True,
|
|
319
|
+
G_min: int = 16384,
|
|
320
|
+
fft_dtype: torch.dtype = torch.float32,
|
|
293
321
|
) -> tuple[float, float]:
|
|
294
322
|
"""Find h_crit and return (h_crit, condition_number).
|
|
295
323
|
|
|
@@ -343,7 +371,7 @@ def find_h_crit(
|
|
|
343
371
|
return find_h_crit_hard(
|
|
344
372
|
X, grid, target_modes, chunk_size, brentq_n_max,
|
|
345
373
|
h_lo, h_hi, formula=formula, eps=eps, tau=tau,
|
|
346
|
-
use_fft=use_fft,
|
|
374
|
+
use_fft=use_fft, G_min=G_min, fft_dtype=fft_dtype,
|
|
347
375
|
)
|
|
348
376
|
|
|
349
377
|
from scipy.optimize import brentq
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "diffcb"
|
|
7
|
-
version = "0.1.
|
|
7
|
+
version = "0.1.4"
|
|
8
8
|
description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { file = "LICENSE" }
|
diffcb-0.1.1/README.md
DELETED
|
@@ -1,92 +0,0 @@
|
|
|
1
|
-
# DCB — Differentiable Critical Bandwidth
|
|
2
|
-
|
|
3
|
-
[](https://pypi.org/project/diffcb/)
|
|
4
|
-
[](LICENSE)
|
|
5
|
-
[](https://www.python.org/)
|
|
6
|
-
|
|
7
|
-
A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
|
|
8
|
-
|
|
9
|
-
## Overview
|
|
10
|
-
|
|
11
|
-
The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
|
|
12
|
-
|
|
13
|
-
```python
|
|
14
|
-
import torch
|
|
15
|
-
from dcb import DCBLayer
|
|
16
|
-
|
|
17
|
-
X = torch.randn(256, requires_grad=True) # 1D samples
|
|
18
|
-
layer = DCBLayer(target_modes=1)
|
|
19
|
-
h_crit = layer(X) # differentiable scalar
|
|
20
|
-
h_crit.backward() # exact IFT gradients
|
|
21
|
-
```
|
|
22
|
-
|
|
23
|
-
## Installation
|
|
24
|
-
|
|
25
|
-
```bash
|
|
26
|
-
pip install diffcb
|
|
27
|
-
```
|
|
28
|
-
|
|
29
|
-
Or from source:
|
|
30
|
-
|
|
31
|
-
```bash
|
|
32
|
-
git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
33
|
-
cd differentiable-critical-bandwidth
|
|
34
|
-
pip install -e ".[dev]"
|
|
35
|
-
```
|
|
36
|
-
|
|
37
|
-
## Paper
|
|
38
|
-
|
|
39
|
-
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
40
|
-
|
|
41
|
-
## Confirmed Experimental Results
|
|
42
|
-
|
|
43
|
-
All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
|
|
44
|
-
|
|
45
|
-
| Experiment | Result | Criterion |
|
|
46
|
-
|---|---|---|
|
|
47
|
-
| **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
|
|
48
|
-
| **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
|
|
49
|
-
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
50
|
-
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
51
|
-
|
|
52
|
-
## Repository Structure
|
|
53
|
-
|
|
54
|
-
```
|
|
55
|
-
dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
|
|
56
|
-
experiments/ Reproduction scripts for all paper figures and tables
|
|
57
|
-
phase1_validation.py Figure 1: DCB vs reference h_crit scatter
|
|
58
|
-
phase1_speedup.py Figure 2: GPU speedup benchmark
|
|
59
|
-
phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
|
|
60
|
-
phase2_gan.py Figure 3: GAN mode-collapse prevention
|
|
61
|
-
phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
|
|
62
|
-
tests/ Unit tests (pytest, 35/35 passing)
|
|
63
|
-
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
64
|
-
notebooks/ Quickstart and demo notebooks
|
|
65
|
-
```
|
|
66
|
-
|
|
67
|
-
## Reproducing Paper Results
|
|
68
|
-
|
|
69
|
-
```bash
|
|
70
|
-
# Phase 1: validation, speedup, ablation
|
|
71
|
-
python experiments/phase1_validation.py
|
|
72
|
-
python experiments/phase1_speedup.py
|
|
73
|
-
python experiments/phase1_ablation.py
|
|
74
|
-
|
|
75
|
-
# Phase 2: GAN mode collapse experiment
|
|
76
|
-
python experiments/phase2_gan.py
|
|
77
|
-
|
|
78
|
-
# Phase 3: anomaly detection benchmark
|
|
79
|
-
python experiments/phase3_anomaly.py
|
|
80
|
-
```
|
|
81
|
-
|
|
82
|
-
For GPU runs, use the provided Kaggle kernels:
|
|
83
|
-
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
84
|
-
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
85
|
-
|
|
86
|
-
## Kaggle GPU Notes
|
|
87
|
-
|
|
88
|
-
Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
|
|
89
|
-
|
|
90
|
-
## License
|
|
91
|
-
|
|
92
|
-
MIT — see [LICENSE](LICENSE).
|
diffcb-0.1.1/dcb/fft_kde.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
dcb.fft_kde — FFT-based KDE Mode Counter
|
|
3
|
-
|
|
4
|
-
Implements mode counting via FFT convolution of the histogram with a
|
|
5
|
-
Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
|
|
6
|
-
O(n × G) cost of the direct KDE approach and — crucially — requiring NO
|
|
7
|
-
subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
|
|
8
|
-
that affects the standard bisection path when n > brentq_n_max.
|
|
9
|
-
|
|
10
|
-
Round 18b: forward kernel only. The IFT backward is unchanged (still uses
|
|
11
|
-
the analytical chunked KDE derivatives on all n points).
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
from __future__ import annotations
|
|
15
|
-
|
|
16
|
-
import math
|
|
17
|
-
|
|
18
|
-
import torch
|
|
19
|
-
from torch import Tensor
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def fft_mode_count(
|
|
23
|
-
X: Tensor,
|
|
24
|
-
h: float,
|
|
25
|
-
G: int = 4096,
|
|
26
|
-
pad_factor: int = 4,
|
|
27
|
-
domain: tuple[float, float] | None = None,
|
|
28
|
-
) -> int:
|
|
29
|
-
"""Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
|
|
30
|
-
|
|
31
|
-
Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
|
|
32
|
-
the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
|
|
33
|
-
back-transforms, and counts positive-to-negative sign changes of the
|
|
34
|
-
resulting f' estimate.
|
|
35
|
-
|
|
36
|
-
Parameters
|
|
37
|
-
----------
|
|
38
|
-
X : Tensor, shape (n,)
|
|
39
|
-
1D data tensor (may be on CPU or CUDA).
|
|
40
|
-
h : float
|
|
41
|
-
Bandwidth for the Gaussian kernel.
|
|
42
|
-
G : int
|
|
43
|
-
Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
|
|
44
|
-
reliable derivative estimation. Use `adaptive_fft_G` to choose G
|
|
45
|
-
automatically before bisection.
|
|
46
|
-
pad_factor : int
|
|
47
|
-
Zero-padding multiplier (default 4). Mandatory ≥ 2 for circular-wrap
|
|
48
|
-
correctness; 4 is recommended at the largest h encountered.
|
|
49
|
-
domain : (lo, hi) or None
|
|
50
|
-
If provided, use this as the histogram domain instead of computing
|
|
51
|
-
X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
|
|
52
|
-
with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
|
|
53
|
-
so every fft_mode_count call in a bisection loop uses an identical grid.
|
|
54
|
-
|
|
55
|
-
Returns
|
|
56
|
-
-------
|
|
57
|
-
int
|
|
58
|
-
Number of KDE modes (downward zero-crossings of f').
|
|
59
|
-
"""
|
|
60
|
-
with torch.no_grad():
|
|
61
|
-
if domain is not None:
|
|
62
|
-
lo, hi = domain
|
|
63
|
-
else:
|
|
64
|
-
# Domain: extend 3σ beyond data range to avoid boundary effects
|
|
65
|
-
sigma = X.std().item()
|
|
66
|
-
if sigma == 0.0:
|
|
67
|
-
sigma = 1.0 # degenerate case: all points identical
|
|
68
|
-
lo = X.min().item() - 3 * sigma
|
|
69
|
-
hi = X.max().item() + 3 * sigma
|
|
70
|
-
data_range = hi - lo
|
|
71
|
-
|
|
72
|
-
if data_range == 0.0:
|
|
73
|
-
return 1 # single-point distribution has 1 mode
|
|
74
|
-
|
|
75
|
-
# Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
|
|
76
|
-
# torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
|
|
77
|
-
# MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
|
|
78
|
-
# the binning step avoids the intermediate and is numerically identical
|
|
79
|
-
# for data within [lo, hi] (guaranteed by the 3σ domain extension above).
|
|
80
|
-
X_cpu = X.float().cpu()
|
|
81
|
-
edges = torch.linspace(lo, hi, G + 1) # (G+1,) CPU
|
|
82
|
-
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1 # 0-indexed
|
|
83
|
-
counts = torch.bincount(bin_idx, minlength=G).float().to(X.device) # back to device
|
|
84
|
-
|
|
85
|
-
# Zero-pad to pad_factor*G (4× mandatory for circular wrap correctness at h_hi)
|
|
86
|
-
N = pad_factor * G
|
|
87
|
-
counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
|
|
88
|
-
counts_padded[:G] = counts
|
|
89
|
-
|
|
90
|
-
# FFT of histogram
|
|
91
|
-
C = torch.fft.rfft(counts_padded)
|
|
92
|
-
|
|
93
|
-
# Derivative kernel in frequency domain: iω * exp(-0.5*(ω*h)²)
|
|
94
|
-
# ω_k = 2π*k / (N * bin_width), bin_width = data_range / G
|
|
95
|
-
bin_width = data_range / G
|
|
96
|
-
k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
|
|
97
|
-
omega = 2 * math.pi * k / (N * bin_width)
|
|
98
|
-
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
99
|
-
|
|
100
|
-
# Convolve and back-transform
|
|
101
|
-
f_prime_padded = torch.fft.irfft(C * K_deriv, n=N)
|
|
102
|
-
|
|
103
|
-
# Trim to original G grid (discard zero-padded tail)
|
|
104
|
-
f_prime = f_prime_padded[:G]
|
|
105
|
-
|
|
106
|
-
# Count (+→-) sign changes = number of modes
|
|
107
|
-
# A mode is a local max of f, i.e., f' crosses zero from + to -
|
|
108
|
-
# Remove zeros (flat segments) — carry forward last nonzero sign
|
|
109
|
-
nonzero_mask = f_prime != 0
|
|
110
|
-
if not nonzero_mask.any():
|
|
111
|
-
return 0
|
|
112
|
-
|
|
113
|
-
s = f_prime[nonzero_mask]
|
|
114
|
-
transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
|
|
115
|
-
return transitions
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
|
|
119
|
-
"""Choose FFT grid size G so that the derivative kernel is well-resolved.
|
|
120
|
-
|
|
121
|
-
Requires h > 8 * bin_width = 8 * data_range / G, equivalently
|
|
122
|
-
G > 8 * data_range / h_hi. We use a factor of 16 for safety margin,
|
|
123
|
-
then round up to the next power of 2 for efficient FFT.
|
|
124
|
-
|
|
125
|
-
Parameters
|
|
126
|
-
----------
|
|
127
|
-
data_range : float
|
|
128
|
-
hi - lo of the data domain (typically X.max() - X.min() + 6σ).
|
|
129
|
-
h_hi : float
|
|
130
|
-
Upper bracket of the bisection (smallest h needing resolution).
|
|
131
|
-
G_min : int
|
|
132
|
-
Minimum returned G (default 4096).
|
|
133
|
-
|
|
134
|
-
Returns
|
|
135
|
-
-------
|
|
136
|
-
int
|
|
137
|
-
Grid size G, a power of 2, at least G_min.
|
|
138
|
-
"""
|
|
139
|
-
needed = 16 * math.ceil(data_range / h_hi)
|
|
140
|
-
# Round up to next power of 2
|
|
141
|
-
p = 1
|
|
142
|
-
while p < needed:
|
|
143
|
-
p <<= 1
|
|
144
|
-
return max(G_min, p)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|