diffcb 0.1.1__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffcb-0.1.1 → diffcb-0.1.3}/PKG-INFO +58 -21
- diffcb-0.1.3/README.md +129 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/dcb/__init__.py +1 -1
- diffcb-0.1.3/dcb/fft_kde.py +262 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/dcb/layer.py +14 -5
- {diffcb-0.1.1 → diffcb-0.1.3}/dcb/solver.py +9 -2
- {diffcb-0.1.1 → diffcb-0.1.3}/pyproject.toml +1 -1
- diffcb-0.1.1/README.md +0 -92
- diffcb-0.1.1/dcb/fft_kde.py +0 -144
- {diffcb-0.1.1 → diffcb-0.1.3}/.gitignore +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/.zenodo.json +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/LICENSE +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/dcb/diagnostics.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/dcb/kde.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/dcb/utils.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/notebooks/.gitkeep +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/tests/test_kde.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/tests/test_layer.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/tests/test_r18c_denom_audit.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/tests/test_r18c_deprecation_warn.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/tests/test_r19_default_fft.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/tests/test_r19_diagnostics.py +0 -0
- {diffcb-0.1.1 → diffcb-0.1.3}/tests/test_solver.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffcb
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
|
|
5
5
|
Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
6
6
|
Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
@@ -71,10 +71,10 @@ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribu
|
|
|
71
71
|
import torch
|
|
72
72
|
from dcb import DCBLayer
|
|
73
73
|
|
|
74
|
-
X = torch.randn(256, requires_grad=True) # 1D samples
|
|
74
|
+
X = torch.randn(1000, requires_grad=True) # 1D samples
|
|
75
75
|
layer = DCBLayer(target_modes=1)
|
|
76
|
-
h_crit = layer(X)
|
|
77
|
-
h_crit.backward()
|
|
76
|
+
h_crit = layer(X) # differentiable scalar
|
|
77
|
+
h_crit.backward() # exact IFT gradients
|
|
78
78
|
```
|
|
79
79
|
|
|
80
80
|
## Installation
|
|
@@ -91,34 +91,72 @@ cd differentiable-critical-bandwidth
|
|
|
91
91
|
pip install -e ".[dev]"
|
|
92
92
|
```
|
|
93
93
|
|
|
94
|
-
##
|
|
94
|
+
## Accuracy vs R's `bw.crit`
|
|
95
95
|
|
|
96
|
-
|
|
96
|
+
DCB is validated against R's `multimode::bw.crit(data, mod0=1)` — the standard reference implementation of Hall & York (2001). On **identical data**:
|
|
97
|
+
|
|
98
|
+
| n | DCB vs R (same sample) | DCB vs R (independent samples) |
|
|
99
|
+
|---|---|---|
|
|
100
|
+
| 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
|
|
101
|
+
| 1M | **0.005%** | ~0.2% |
|
|
102
|
+
| 10M | **0.004%** | ~0.1% |
|
|
103
|
+
|
|
104
|
+
The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at all tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s, a size at which R runs out of memory.
|
|
105
|
+
|
|
106
|
+
## Key Parameters
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
DCBLayer(
|
|
110
|
+
target_modes=1, # target number of modes
|
|
111
|
+
G=512, # IFT evaluation grid points
|
|
112
|
+
use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
|
|
113
|
+
max_n_exact=1_000_000,# sketch to sketch_size when n exceeds this (None = always exact)
|
|
114
|
+
sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
|
|
115
|
+
safe_backward=False, # clamp IFT denominator near bifurcations
|
|
116
|
+
)
|
|
117
|
+
```
|
|
97
118
|
|
|
98
119
|
## Confirmed Experimental Results
|
|
99
120
|
|
|
100
|
-
All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
|
|
121
|
+
All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
|
|
101
122
|
|
|
102
123
|
| Experiment | Result | Criterion |
|
|
103
124
|
|---|---|---|
|
|
104
|
-
| **
|
|
105
|
-
| **
|
|
125
|
+
| **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
|
|
126
|
+
| **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
|
|
127
|
+
| **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
|
|
106
128
|
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
107
129
|
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
108
130
|
|
|
131
|
+
## Changelog
|
|
132
|
+
|
|
133
|
+
### v0.1.1 (2026-05-29)
|
|
134
|
+
- **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
|
|
135
|
+
- **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n^{-2/9}) convergence of h_crit; 500K sketch matches full-n accuracy.
|
|
136
|
+
- **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
|
|
137
|
+
- **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
|
|
138
|
+
- **Test fixes:** Updated 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
|
|
139
|
+
|
|
140
|
+
### v0.1.0 (2026-05-28)
|
|
141
|
+
- Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
|
|
142
|
+
|
|
109
143
|
## Repository Structure
|
|
110
144
|
|
|
111
145
|
```
|
|
112
|
-
dcb/ Core PyTorch package
|
|
146
|
+
dcb/ Core PyTorch package
|
|
147
|
+
layer.py DCBLayer nn.Module + DCBFunction autograd
|
|
148
|
+
solver.py IFT root-finder and backward pass
|
|
149
|
+
fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
|
|
150
|
+
kde.py Direct KDE derivatives (small-n path)
|
|
151
|
+
utils.py Grid, Silverman bandwidth, sg() stabilizer
|
|
113
152
|
experiments/ Reproduction scripts for all paper figures and tables
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
tests/ Unit tests (pytest,
|
|
153
|
+
phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
|
|
154
|
+
phase2_gan.py GAN mode-collapse prevention (Figure 3)
|
|
155
|
+
phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
|
|
156
|
+
round20_*.py Large-n R comparison and streaming benchmarks
|
|
157
|
+
round21_*.py Accuracy improvement experiments
|
|
158
|
+
tests/ Unit tests (pytest, 45 passed, 1 xfailed)
|
|
120
159
|
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
121
|
-
notebooks/ Quickstart and demo notebooks
|
|
122
160
|
```
|
|
123
161
|
|
|
124
162
|
## Reproducing Paper Results
|
|
@@ -127,7 +165,6 @@ notebooks/ Quickstart and demo notebooks
|
|
|
127
165
|
# Phase 1: validation, speedup, ablation
|
|
128
166
|
python experiments/phase1_validation.py
|
|
129
167
|
python experiments/phase1_speedup.py
|
|
130
|
-
python experiments/phase1_ablation.py
|
|
131
168
|
|
|
132
169
|
# Phase 2: GAN mode collapse experiment
|
|
133
170
|
python experiments/phase2_gan.py
|
|
@@ -136,13 +173,13 @@ python experiments/phase2_gan.py
|
|
|
136
173
|
python experiments/phase3_anomaly.py
|
|
137
174
|
```
|
|
138
175
|
|
|
139
|
-
For GPU runs
|
|
176
|
+
For GPU runs use the Kaggle kernels:
|
|
140
177
|
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
141
178
|
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
142
179
|
|
|
143
|
-
##
|
|
180
|
+
## Paper
|
|
144
181
|
|
|
145
|
-
|
|
182
|
+
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
146
183
|
|
|
147
184
|
## License
|
|
148
185
|
|
diffcb-0.1.3/README.md
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
# DCB — Differentiable Critical Bandwidth
|
|
2
|
+
|
|
3
|
+
[PyPI](https://pypi.org/project/diffcb/)
|
|
4
|
+
[License: MIT](LICENSE)
|
|
5
|
+
[Python](https://www.python.org/)
|
|
6
|
+
|
|
7
|
+
A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
|
|
8
|
+
|
|
9
|
+
## Overview
|
|
10
|
+
|
|
11
|
+
The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
import torch
|
|
15
|
+
from dcb import DCBLayer
|
|
16
|
+
|
|
17
|
+
X = torch.randn(1000, requires_grad=True) # 1D samples
|
|
18
|
+
layer = DCBLayer(target_modes=1)
|
|
19
|
+
h_crit = layer(X) # differentiable scalar
|
|
20
|
+
h_crit.backward() # exact IFT gradients
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Installation
|
|
24
|
+
|
|
25
|
+
```bash
|
|
26
|
+
pip install diffcb
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
Or from source:
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
33
|
+
cd differentiable-critical-bandwidth
|
|
34
|
+
pip install -e ".[dev]"
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Accuracy vs R's `bw.crit`
|
|
38
|
+
|
|
39
|
+
DCB is validated against R's `multimode::bw.crit(data, mod0=1)` — the standard reference implementation of Hall & York (2001). On **identical data**:
|
|
40
|
+
|
|
41
|
+
| n | DCB vs R (same sample) | DCB vs R (independent samples) |
|
|
42
|
+
|---|---|---|
|
|
43
|
+
| 100K | **0.004%** | ~0.5% (MC noise from independent RNG) |
|
|
44
|
+
| 1M | **0.005%** | ~0.2% |
|
|
45
|
+
| 10M | **0.004%** | ~0.1% |
|
|
46
|
+
|
|
47
|
+
The independent-sample figures reflect natural sampling variability (two unbiased estimators drawing different data), not algorithmic error. On identical data, DCB agrees with R to within **0.005%** at all tested n. DCB is 43× faster than R at n=100M (1.1 s vs 50 s) and handles n=2B in 24 s, a size at which R runs out of memory.
|
|
48
|
+
|
|
49
|
+
## Key Parameters
|
|
50
|
+
|
|
51
|
+
```python
|
|
52
|
+
DCBLayer(
|
|
53
|
+
target_modes=1, # target number of modes
|
|
54
|
+
G=512, # IFT evaluation grid points
|
|
55
|
+
use_fft=True, # FFT forward (default); eliminates subsampling bias for n>50K
|
|
56
|
+
max_n_exact=1_000_000,# sketch to sketch_size when n exceeds this (None = always exact)
|
|
57
|
+
sketch_size=500_000, # sketch target; 500K matches full-n accuracy (O(n^{-2/9}) rate)
|
|
58
|
+
safe_backward=False, # clamp IFT denominator near bifurcations
|
|
59
|
+
)
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
## Confirmed Experimental Results
|
|
63
|
+
|
|
64
|
+
All GPU results produced on Kaggle (T4 / P100) — see `experiments/` and `outputs/`.
|
|
65
|
+
|
|
66
|
+
| Experiment | Result | Criterion |
|
|
67
|
+
|---|---|---|
|
|
68
|
+
| **Accuracy vs R (same data, n=100K)** | **0.004%** | < 0.01% ✓ |
|
|
69
|
+
| **Validation (m≥2, Marron-Wand)** | R²=0.91, MAE=0.07, ρ=0.89 | R²≥0.85 ✓ |
|
|
70
|
+
| **Speedup vs scipy (CUDA T4, n=8192)** | **10.5×** | ≥3× ✓ |
|
|
71
|
+
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
72
|
+
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
73
|
+
|
|
74
|
+
## Changelog
|
|
75
|
+
|
|
76
|
+
### v0.1.1 (2026-05-29)
|
|
77
|
+
- **MPS fix:** `torch.histc` on MPS allocated an n×bins intermediate (OOM at n≥5M). Replaced with `bucketize+bincount` on CPU — MPS-safe and numerically identical.
|
|
78
|
+
- **Sketch API:** `DCBLayer(max_n_exact=1_000_000, sketch_size=500_000)` — silently sketches to 500K when n exceeds threshold. Justified by O(n^{-2/9}) convergence of h_crit; 500K sketch matches full-n accuracy.
|
|
79
|
+
- **Consistent bisection domain:** Pre-computed domain passed to all `fft_mode_count` calls in a single bisection, eliminating per-step drift.
|
|
80
|
+
- **Bias warning direction:** Corrected "expected upward bias" to "expected downward bias" on legacy `use_fft=False` path.
|
|
81
|
+
- **Test fixes:** Updated 8 pre-existing test failures (tuple unpacking, bounds, deprecation API).
|
|
82
|
+
|
|
83
|
+
### v0.1.0 (2026-05-28)
|
|
84
|
+
- Initial PyPI release. FFT forward (O(n + G log G)), IFT backward, MPS support.
|
|
85
|
+
|
|
86
|
+
## Repository Structure
|
|
87
|
+
|
|
88
|
+
```
|
|
89
|
+
dcb/ Core PyTorch package
|
|
90
|
+
layer.py DCBLayer nn.Module + DCBFunction autograd
|
|
91
|
+
solver.py IFT root-finder and backward pass
|
|
92
|
+
fft_kde.py FFT-based mode counter (MPS-safe, float64, G=16384)
|
|
93
|
+
kde.py Direct KDE derivatives (small-n path)
|
|
94
|
+
utils.py Grid, Silverman bandwidth, sg() stabilizer
|
|
95
|
+
experiments/ Reproduction scripts for all paper figures and tables
|
|
96
|
+
phase1_*.py Validation, speedup, ablation (Figures 1–2, S1–S2)
|
|
97
|
+
phase2_gan.py GAN mode-collapse prevention (Figure 3)
|
|
98
|
+
phase3_anomaly.py Anomaly detection (Table 2, Figure 5)
|
|
99
|
+
round20_*.py Large-n R comparison and streaming benchmarks
|
|
100
|
+
round21_*.py Accuracy improvement experiments
|
|
101
|
+
tests/ Unit tests (pytest, 45 passed, 1 xfailed)
|
|
102
|
+
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
## Reproducing Paper Results
|
|
106
|
+
|
|
107
|
+
```bash
|
|
108
|
+
# Phase 1: validation, speedup, ablation
|
|
109
|
+
python experiments/phase1_validation.py
|
|
110
|
+
python experiments/phase1_speedup.py
|
|
111
|
+
|
|
112
|
+
# Phase 2: GAN mode collapse experiment
|
|
113
|
+
python experiments/phase2_gan.py
|
|
114
|
+
|
|
115
|
+
# Phase 3: anomaly detection benchmark
|
|
116
|
+
python experiments/phase3_anomaly.py
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
For GPU runs use the Kaggle kernels:
|
|
120
|
+
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
121
|
+
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
122
|
+
|
|
123
|
+
## Paper
|
|
124
|
+
|
|
125
|
+
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
126
|
+
|
|
127
|
+
## License
|
|
128
|
+
|
|
129
|
+
MIT — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.fft_kde — FFT-based KDE Mode Counter
|
|
3
|
+
|
|
4
|
+
Implements mode counting via FFT convolution of the histogram with a
|
|
5
|
+
Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
|
|
6
|
+
O(n × G) cost of the direct KDE approach and — crucially — requiring NO
|
|
7
|
+
subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
|
|
8
|
+
that affects the standard bisection path when n > brentq_n_max.
|
|
9
|
+
|
|
10
|
+
Round 18b: forward kernel only. The IFT backward is unchanged (still uses
|
|
11
|
+
the analytical chunked KDE derivatives on all n points).
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def fft_mode_count(
|
|
23
|
+
X: Tensor,
|
|
24
|
+
h: float,
|
|
25
|
+
G: int = 4096,
|
|
26
|
+
pad_factor: int = 4,
|
|
27
|
+
domain: tuple[float, float] | None = None,
|
|
28
|
+
) -> int:
|
|
29
|
+
"""Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
|
|
30
|
+
|
|
31
|
+
Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
|
|
32
|
+
the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
|
|
33
|
+
back-transforms, and counts positive-to-negative sign changes of the
|
|
34
|
+
resulting f' estimate.
|
|
35
|
+
|
|
36
|
+
Parameters
|
|
37
|
+
----------
|
|
38
|
+
X : Tensor, shape (n,)
|
|
39
|
+
1D data tensor (may be on CPU or CUDA).
|
|
40
|
+
h : float
|
|
41
|
+
Bandwidth for the Gaussian kernel.
|
|
42
|
+
G : int
|
|
43
|
+
Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
|
|
44
|
+
reliable derivative estimation. Use `adaptive_fft_G` to choose G
|
|
45
|
+
automatically before bisection.
|
|
46
|
+
pad_factor : int
|
|
47
|
+
Zero-padding multiplier (default 4). Mandatory ≥ 2 for circular-wrap
|
|
48
|
+
correctness; 4 is recommended at the largest h encountered.
|
|
49
|
+
domain : (lo, hi) or None
|
|
50
|
+
If provided, use this as the histogram domain instead of computing
|
|
51
|
+
X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
|
|
52
|
+
with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
|
|
53
|
+
so every fft_mode_count call in a bisection loop uses an identical grid.
|
|
54
|
+
|
|
55
|
+
Returns
|
|
56
|
+
-------
|
|
57
|
+
int
|
|
58
|
+
Number of KDE modes (downward zero-crossings of f').
|
|
59
|
+
"""
|
|
60
|
+
with torch.no_grad():
|
|
61
|
+
if domain is not None:
|
|
62
|
+
lo, hi = domain
|
|
63
|
+
else:
|
|
64
|
+
# Domain: extend 3σ beyond data range to avoid boundary effects
|
|
65
|
+
sigma = X.std().item()
|
|
66
|
+
if sigma == 0.0:
|
|
67
|
+
sigma = 1.0 # degenerate case: all points identical
|
|
68
|
+
lo = X.min().item() - 3 * sigma
|
|
69
|
+
hi = X.max().item() + 3 * sigma
|
|
70
|
+
data_range = hi - lo
|
|
71
|
+
|
|
72
|
+
if data_range == 0.0:
|
|
73
|
+
return 1 # single-point distribution has 1 mode
|
|
74
|
+
|
|
75
|
+
# Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
|
|
76
|
+
# torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
|
|
77
|
+
# MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
|
|
78
|
+
# the binning step avoids the intermediate and is numerically identical
|
|
79
|
+
# for data within [lo, hi] (guaranteed by the 3σ domain extension above).
|
|
80
|
+
X_cpu = X.float().cpu()
|
|
81
|
+
edges = torch.linspace(lo, hi, G + 1) # (G+1,) CPU
|
|
82
|
+
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1 # 0-indexed
|
|
83
|
+
counts = torch.bincount(bin_idx, minlength=G).float().to(X.device) # back to device
|
|
84
|
+
|
|
85
|
+
# Zero-pad to pad_factor*G — promote to float64 for FFT precision
|
|
86
|
+
N = pad_factor * G
|
|
87
|
+
counts_padded = torch.zeros(N, dtype=torch.float64, device=X.device)
|
|
88
|
+
counts_padded[:G] = counts.double()
|
|
89
|
+
|
|
90
|
+
# FFT of histogram (float64)
|
|
91
|
+
C = torch.fft.rfft(counts_padded)
|
|
92
|
+
|
|
93
|
+
# Derivative kernel in frequency domain (float64)
|
|
94
|
+
bin_width = data_range / G
|
|
95
|
+
k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float64)
|
|
96
|
+
omega = 2 * math.pi * k / (N * bin_width)
|
|
97
|
+
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
98
|
+
|
|
99
|
+
# Convolve and back-transform; cast result back to float32
|
|
100
|
+
f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).float()
|
|
101
|
+
|
|
102
|
+
# Trim to original G grid (discard zero-padded tail)
|
|
103
|
+
f_prime = f_prime_padded[:G]
|
|
104
|
+
|
|
105
|
+
# Count (+→-) sign changes = number of modes
|
|
106
|
+
# A mode is a local max of f, i.e., f' crosses zero from + to -
|
|
107
|
+
# Remove zeros (flat segments) — carry forward last nonzero sign
|
|
108
|
+
nonzero_mask = f_prime != 0
|
|
109
|
+
if not nonzero_mask.any():
|
|
110
|
+
return 0
|
|
111
|
+
|
|
112
|
+
s = f_prime[nonzero_mask]
|
|
113
|
+
transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
|
|
114
|
+
return transitions
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _refine_hcrit(
|
|
118
|
+
X: Tensor,
|
|
119
|
+
h_lo: float,
|
|
120
|
+
h_hi: float,
|
|
121
|
+
G: int,
|
|
122
|
+
domain: tuple[float, float],
|
|
123
|
+
target_modes: int = 1,
|
|
124
|
+
pad_factor: int = 4,
|
|
125
|
+
) -> float:
|
|
126
|
+
"""Sub-bin quadratic refinement of h_crit after bisection converges.
|
|
127
|
+
|
|
128
|
+
Identifies the f′ lobe that disappears at the mode-merging bandwidth and
|
|
129
|
+
fits a quadratic in h to that lobe's peak value, returning the root — the
|
|
130
|
+
h where that peak exactly reaches zero. Reduces the bin-width-limited
|
|
131
|
+
systematic from ~bin_width/h_crit to well below 1e-4.
|
|
132
|
+
|
|
133
|
+
When the incoming bracket [h_lo, h_hi] is tighter than one histogram bin
|
|
134
|
+
width (the common case after 50-step bisection), the function expands the
|
|
135
|
+
bracket outward from h_hi by up to 4× the bin width while maintaining the
|
|
136
|
+
invariant that fft_mode_count > target at the left endpoint and
|
|
137
|
+
<= target at the right endpoint, so the disappearing f′ lobe is visible
|
|
138
|
+
across the bracket.
|
|
139
|
+
|
|
140
|
+
Parameters
|
|
141
|
+
----------
|
|
142
|
+
X : Tensor — data (may be on any device)
|
|
143
|
+
h_lo, h_hi : float — final bisection bracket; fft_mode_count(X,h_lo) > target,
|
|
144
|
+
fft_mode_count(X,h_hi) <= target
|
|
145
|
+
G, domain, target_modes, pad_factor — same as fft_mode_count
|
|
146
|
+
|
|
147
|
+
Returns
|
|
148
|
+
-------
|
|
149
|
+
float — refined h_crit, guaranteed to lie in [h_lo, h_hi] of the
|
|
150
|
+
(possibly expanded) bracket used for fitting.
|
|
151
|
+
"""
|
|
152
|
+
import numpy as np
|
|
153
|
+
|
|
154
|
+
lo_d, hi_d = domain
|
|
155
|
+
data_range = hi_d - lo_d
|
|
156
|
+
if data_range == 0.0:
|
|
157
|
+
return h_hi
|
|
158
|
+
|
|
159
|
+
bin_width = data_range / G
|
|
160
|
+
N = pad_factor * G
|
|
161
|
+
bw = bin_width # histogram bin width
|
|
162
|
+
|
|
163
|
+
# Pre-compute histogram once; reuse C (FFT of counts) for all h evaluations.
|
|
164
|
+
with torch.no_grad():
|
|
165
|
+
X_cpu = X.float().cpu()
|
|
166
|
+
edges = torch.linspace(lo_d, hi_d, G + 1)
|
|
167
|
+
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
|
|
168
|
+
counts = torch.bincount(bin_idx, minlength=G).float()
|
|
169
|
+
counts_padded = torch.zeros(N, dtype=torch.float64)
|
|
170
|
+
counts_padded[:G] = counts.double()
|
|
171
|
+
C = torch.fft.rfft(counts_padded)
|
|
172
|
+
k = torch.arange(N // 2 + 1, dtype=torch.float64)
|
|
173
|
+
omega_base = 2 * math.pi * k / (N * bw)
|
|
174
|
+
|
|
175
|
+
def fprime(h: float) -> Tensor:
|
|
176
|
+
"""Compute f′ array (shape G,) for bandwidth h using cached C (float64)."""
|
|
177
|
+
K_deriv = 1j * omega_base * torch.exp(-0.5 * (omega_base * h) ** 2)
|
|
178
|
+
return torch.fft.irfft(C * K_deriv, n=N).float()[:G]
|
|
179
|
+
|
|
180
|
+
with torch.no_grad():
|
|
181
|
+
# If the bracket is tighter than bin_width, expand it so that the
|
|
182
|
+
# disappearing f′ lobe crosses zero somewhere inside the bracket.
|
|
183
|
+
# Expand the left endpoint leftward by up to 4 bin widths.
|
|
184
|
+
ref_lo = h_lo
|
|
185
|
+
ref_hi = h_hi
|
|
186
|
+
|
|
187
|
+
if (ref_hi - ref_lo) < bw:
|
|
188
|
+
# Try expanding leftward until we find a bin where fp crosses zero
|
|
189
|
+
for mult in [1, 2, 3, 4]:
|
|
190
|
+
cand_lo = max(ref_hi - mult * bw, ref_hi * 0.9)
|
|
191
|
+
fp_cand = fprime(cand_lo)
|
|
192
|
+
fp_hi_ = fprime(ref_hi)
|
|
193
|
+
cm = (fp_cand > 0) & (fp_hi_ <= 0)
|
|
194
|
+
if cm.any():
|
|
195
|
+
ref_lo = cand_lo
|
|
196
|
+
break
|
|
197
|
+
# If still no candidates found, return bisection result unchanged
|
|
198
|
+
fp_lo_ = fprime(ref_lo)
|
|
199
|
+
fp_hi_ = fprime(ref_hi)
|
|
200
|
+
candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
|
|
201
|
+
if not candidate_mask.any():
|
|
202
|
+
return h_hi
|
|
203
|
+
else:
|
|
204
|
+
fp_lo_ = fprime(ref_lo)
|
|
205
|
+
fp_hi_ = fprime(ref_hi)
|
|
206
|
+
candidate_mask = (fp_lo_ > 0) & (fp_hi_ <= 0)
|
|
207
|
+
if not candidate_mask.any():
|
|
208
|
+
return h_hi
|
|
209
|
+
|
|
210
|
+
# Pick the bin with the largest positive value at ref_lo that crossed zero
|
|
211
|
+
masked_fp_lo = fp_lo_.clone()
|
|
212
|
+
masked_fp_lo[~candidate_mask] = -float('inf')
|
|
213
|
+
j = int(masked_fp_lo.argmax().item())
|
|
214
|
+
|
|
215
|
+
h_mid = (ref_lo + ref_hi) / 2.0
|
|
216
|
+
|
|
217
|
+
# Evaluate fp[j] at three bandwidths for quadratic fit
|
|
218
|
+
y_lo = fp_lo_[j].item()
|
|
219
|
+
y_mid = fprime(h_mid)[j].item()
|
|
220
|
+
y_hi = fp_hi_[j].item()
|
|
221
|
+
|
|
222
|
+
# Fit quadratic y = a*h² + b*h + c through the three (h, y) pairs
|
|
223
|
+
# and solve for the root in [ref_lo, ref_hi].
|
|
224
|
+
coeffs = np.polyfit([ref_lo, h_mid, ref_hi], [y_lo, y_mid, y_hi], 2)
|
|
225
|
+
roots = np.roots(coeffs)
|
|
226
|
+
real_roots = [
|
|
227
|
+
r.real for r in roots
|
|
228
|
+
if abs(r.imag) < 1e-10 * abs(r.real + 1e-30)
|
|
229
|
+
and ref_lo <= r.real <= ref_hi
|
|
230
|
+
]
|
|
231
|
+
if real_roots:
|
|
232
|
+
return float(min(real_roots, key=lambda r: abs(r - h_mid)))
|
|
233
|
+
return h_hi
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 16384) -> int:
    """Choose FFT grid size G so that the derivative kernel is well-resolved.

    Requires h > 8 * bin_width = 8 * data_range / G, equivalently
    G > 8 * data_range / h_hi. A factor of 16 is used for safety margin
    (16 * ceil(data_range / h_hi)), and the result is rounded up to a power
    of 2 for efficient FFT.

    Parameters
    ----------
    data_range : float
        hi - lo of the histogram domain.
    h_hi : float
        Bandwidth used to size the grid; must be non-zero.
    G_min : int
        Lower bound on the returned G (default 16384). If it is not itself
        a power of 2, the result is rounded up past it so the return value
        is always a power of 2.

    Returns
    -------
    int
        Grid size G, a power of 2, at least G_min.
    """
    needed = 16 * math.ceil(data_range / h_hi)
    # Round max(needed, G_min) up to a power of two. Rounding only `needed`
    # and then taking max() with G_min would hand back a non-power-of-two
    # grid whenever G_min is the larger of the two and is not a power of two.
    target = max(needed, math.ceil(G_min), 1)
    return 1 << (target - 1).bit_length()
|
|
@@ -35,13 +35,13 @@ class DCBFunction(torch.autograd.Function):
|
|
|
35
35
|
|
|
36
36
|
@staticmethod
|
|
37
37
|
def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
|
|
38
|
-
brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft):
|
|
38
|
+
brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min):
|
|
39
39
|
"""Locate h_crit and save state for the backward pass."""
|
|
40
40
|
h_crit, cond_num = find_h_crit(
|
|
41
41
|
X, grid, eps, tau, target_modes,
|
|
42
42
|
formula=formula, brentq_n_max=brentq_n_max, chunk_size=chunk_size,
|
|
43
43
|
g_brentq=g_brentq, use_hard_bisection=use_hard_bisection,
|
|
44
|
-
use_fft=use_fft,
|
|
44
|
+
use_fft=use_fft, G_min=fft_G_min,
|
|
45
45
|
)
|
|
46
46
|
ctx.save_for_backward(X, grid)
|
|
47
47
|
ctx.h_crit = h_crit
|
|
@@ -67,8 +67,8 @@ class DCBFunction(torch.autograd.Function):
|
|
|
67
67
|
ctx.denom_abs = ift_gradient.last_denom_abs
|
|
68
68
|
# Gradients for: X, grid, eps, tau, target_modes, delta, formula,
|
|
69
69
|
# chunk_size, brentq_n_max, g_brentq, use_hard_bisection,
|
|
70
|
-
# safe_backward, use_fft
|
|
71
|
-
return grad_X, None, None, None, None, None, None, None, None, None, None, None, None
|
|
70
|
+
# safe_backward, use_fft, fft_G_min
|
|
71
|
+
return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None
|
|
72
72
|
|
|
73
73
|
|
|
74
74
|
class DCBLayer(nn.Module):
|
|
@@ -133,6 +133,13 @@ class DCBLayer(nn.Module):
|
|
|
133
133
|
Number of points to sketch when n > max_n_exact. Default 500_000.
|
|
134
134
|
A 500K sketch achieves the same mean accuracy as streaming 100M points
|
|
135
135
|
(validated in Round 20 reservoir experiment).
|
|
136
|
+
fft_G_min : int
|
|
137
|
+
Minimum FFT histogram grid size for the bisection solver (default 16384).
|
|
138
|
+
Controls accuracy of the FFT path (n > 50K). Larger values reduce
|
|
139
|
+
discretisation error at a modest cost: G=16384 gives ~0.004% err vs R;
|
|
140
|
+
G=32768 gives ~0.001% at +9% cost; G=65536 reaches the R-matching floor
|
|
141
|
+
(~0.001%) with no further gain beyond that. Ignored for n ≤ 50K (direct
|
|
142
|
+
KDE path).
|
|
136
143
|
|
|
137
144
|
Examples
|
|
138
145
|
--------
|
|
@@ -162,6 +169,7 @@ class DCBLayer(nn.Module):
|
|
|
162
169
|
use_fft: bool = True,
|
|
163
170
|
max_n_exact: int | None = 1_000_000,
|
|
164
171
|
sketch_size: int = 500_000,
|
|
172
|
+
fft_G_min: int = 16384,
|
|
165
173
|
):
|
|
166
174
|
super().__init__()
|
|
167
175
|
self.target_modes = target_modes
|
|
@@ -180,6 +188,7 @@ class DCBLayer(nn.Module):
|
|
|
180
188
|
self.use_fft = use_fft
|
|
181
189
|
self.max_n_exact = max_n_exact
|
|
182
190
|
self.sketch_size = sketch_size
|
|
191
|
+
self.fft_G_min = fft_G_min
|
|
183
192
|
if use_fft and brentq_n_max != 50_000:
|
|
184
193
|
raise TypeError(
|
|
185
194
|
f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
|
|
@@ -250,7 +259,7 @@ class DCBLayer(nn.Module):
|
|
|
250
259
|
return DCBFunction.apply(
|
|
251
260
|
X, grid, eps_eff, tau_eff, self.target_modes, self.delta, self.formula,
|
|
252
261
|
self.chunk_size, self.brentq_n_max, self.g_brentq, self.use_hard_bisection,
|
|
253
|
-
self.safe_backward, self.use_fft,
|
|
262
|
+
self.safe_backward, self.use_fft, self.fft_G_min,
|
|
254
263
|
)
|
|
255
264
|
|
|
256
265
|
|
|
@@ -74,6 +74,7 @@ def find_h_crit_hard(
|
|
|
74
74
|
eps: float = 0.1,
|
|
75
75
|
tau: float = 0.2,
|
|
76
76
|
use_fft: bool = False,
|
|
77
|
+
G_min: int = 16384,
|
|
77
78
|
) -> tuple[float, float]:
|
|
78
79
|
"""Find h_crit via hard-mode-count bisection (monotone, no false roots).
|
|
79
80
|
|
|
@@ -151,7 +152,7 @@ def find_h_crit_hard(
|
|
|
151
152
|
lo_domain = X.min().item() - 3 * sigma
|
|
152
153
|
hi_domain = X.max().item() + 3 * sigma
|
|
153
154
|
data_range = hi_domain - lo_domain
|
|
154
|
-
G_fft = adaptive_fft_G(data_range, h_hi)
|
|
155
|
+
G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
|
|
155
156
|
_domain = (lo_domain, hi_domain)
|
|
156
157
|
|
|
157
158
|
with torch.no_grad():
|
|
@@ -188,6 +189,11 @@ def find_h_crit_hard(
|
|
|
188
189
|
|
|
189
190
|
h_crit = float(hi) # smallest h with count <= target_modes
|
|
190
191
|
|
|
192
|
+
# Sub-bin refinement: quadratic interpolation on the disappearing f′ lobe
|
|
193
|
+
# to locate h_crit below the bin-width precision limit.
|
|
194
|
+
from dcb.fft_kde import _refine_hcrit
|
|
195
|
+
h_crit = _refine_hcrit(X, lo, hi, G_fft, _domain, target_modes)
|
|
196
|
+
|
|
191
197
|
else:
|
|
192
198
|
with torch.no_grad():
|
|
193
199
|
# Verify bracket: need count > target at h_lo, count <= target at h_hi.
|
|
@@ -290,6 +296,7 @@ def find_h_crit(
|
|
|
290
296
|
g_brentq: int = 128,
|
|
291
297
|
use_hard_bisection: bool = True,
|
|
292
298
|
use_fft: bool = True,
|
|
299
|
+
G_min: int = 16384,
|
|
293
300
|
) -> tuple[float, float]:
|
|
294
301
|
"""Find h_crit and return (h_crit, condition_number).
|
|
295
302
|
|
|
@@ -343,7 +350,7 @@ def find_h_crit(
|
|
|
343
350
|
return find_h_crit_hard(
|
|
344
351
|
X, grid, target_modes, chunk_size, brentq_n_max,
|
|
345
352
|
h_lo, h_hi, formula=formula, eps=eps, tau=tau,
|
|
346
|
-
use_fft=use_fft,
|
|
353
|
+
use_fft=use_fft, G_min=G_min,
|
|
347
354
|
)
|
|
348
355
|
|
|
349
356
|
from scipy.optimize import brentq
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "diffcb"
|
|
7
|
-
version = "0.1.1"
|
|
7
|
+
version = "0.1.3"
|
|
8
8
|
description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { file = "LICENSE" }
|
diffcb-0.1.1/README.md
DELETED
|
@@ -1,92 +0,0 @@
|
|
|
1
|
-
# DCB — Differentiable Critical Bandwidth
|
|
2
|
-
|
|
3
|
-
[PyPI](https://pypi.org/project/diffcb/)
|
|
4
|
-
[License: MIT](LICENSE)
|
|
5
|
-
[Python](https://www.python.org/)
|
|
6
|
-
|
|
7
|
-
A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
|
|
8
|
-
|
|
9
|
-
## Overview
|
|
10
|
-
|
|
11
|
-
The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
|
|
12
|
-
|
|
13
|
-
```python
|
|
14
|
-
import torch
|
|
15
|
-
from dcb import DCBLayer
|
|
16
|
-
|
|
17
|
-
X = torch.randn(256, requires_grad=True) # 1D samples
|
|
18
|
-
layer = DCBLayer(target_modes=1)
|
|
19
|
-
h_crit = layer(X) # differentiable scalar
|
|
20
|
-
h_crit.backward() # exact IFT gradients
|
|
21
|
-
```
|
|
22
|
-
|
|
23
|
-
## Installation
|
|
24
|
-
|
|
25
|
-
```bash
|
|
26
|
-
pip install diffcb
|
|
27
|
-
```
|
|
28
|
-
|
|
29
|
-
Or from source:
|
|
30
|
-
|
|
31
|
-
```bash
|
|
32
|
-
git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
33
|
-
cd differentiable-critical-bandwidth
|
|
34
|
-
pip install -e ".[dev]"
|
|
35
|
-
```
|
|
36
|
-
|
|
37
|
-
## Paper
|
|
38
|
-
|
|
39
|
-
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
40
|
-
|
|
41
|
-
## Confirmed Experimental Results
|
|
42
|
-
|
|
43
|
-
All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
|
|
44
|
-
|
|
45
|
-
| Experiment | Result | Criterion |
|
|
46
|
-
|---|---|---|
|
|
47
|
-
| **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
|
|
48
|
-
| **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
|
|
49
|
-
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
50
|
-
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
51
|
-
|
|
52
|
-
## Repository Structure
|
|
53
|
-
|
|
54
|
-
```
|
|
55
|
-
dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
|
|
56
|
-
experiments/ Reproduction scripts for all paper figures and tables
|
|
57
|
-
phase1_validation.py Figure 1: DCB vs reference h_crit scatter
|
|
58
|
-
phase1_speedup.py Figure 2: GPU speedup benchmark
|
|
59
|
-
phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
|
|
60
|
-
phase2_gan.py Figure 3: GAN mode-collapse prevention
|
|
61
|
-
phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
|
|
62
|
-
tests/ Unit tests (pytest, 35/35 passing)
|
|
63
|
-
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
64
|
-
notebooks/ Quickstart and demo notebooks
|
|
65
|
-
```
|
|
66
|
-
|
|
67
|
-
## Reproducing Paper Results
|
|
68
|
-
|
|
69
|
-
```bash
|
|
70
|
-
# Phase 1: validation, speedup, ablation
|
|
71
|
-
python experiments/phase1_validation.py
|
|
72
|
-
python experiments/phase1_speedup.py
|
|
73
|
-
python experiments/phase1_ablation.py
|
|
74
|
-
|
|
75
|
-
# Phase 2: GAN mode collapse experiment
|
|
76
|
-
python experiments/phase2_gan.py
|
|
77
|
-
|
|
78
|
-
# Phase 3: anomaly detection benchmark
|
|
79
|
-
python experiments/phase3_anomaly.py
|
|
80
|
-
```
|
|
81
|
-
|
|
82
|
-
For GPU runs, use the provided Kaggle kernels:
|
|
83
|
-
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
84
|
-
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
85
|
-
|
|
86
|
-
## Kaggle GPU Notes
|
|
87
|
-
|
|
88
|
-
Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
|
|
89
|
-
|
|
90
|
-
## License
|
|
91
|
-
|
|
92
|
-
MIT — see [LICENSE](LICENSE).
|
diffcb-0.1.1/dcb/fft_kde.py
DELETED
|
@@ -1,144 +0,0 @@
|
|
|
1
|
-
"""
|
|
2
|
-
dcb.fft_kde — FFT-based KDE Mode Counter
|
|
3
|
-
|
|
4
|
-
Implements mode counting via FFT convolution of the histogram with a
|
|
5
|
-
Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
|
|
6
|
-
O(n × G) cost of the direct KDE approach and — crucially — requiring NO
|
|
7
|
-
subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
|
|
8
|
-
that affects the standard bisection path when n > brentq_n_max.
|
|
9
|
-
|
|
10
|
-
Round 18b: forward kernel only. The IFT backward is unchanged (still uses
|
|
11
|
-
the analytical chunked KDE derivatives on all n points).
|
|
12
|
-
"""
|
|
13
|
-
|
|
14
|
-
from __future__ import annotations
|
|
15
|
-
|
|
16
|
-
import math
|
|
17
|
-
|
|
18
|
-
import torch
|
|
19
|
-
from torch import Tensor
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def fft_mode_count(
|
|
23
|
-
X: Tensor,
|
|
24
|
-
h: float,
|
|
25
|
-
G: int = 4096,
|
|
26
|
-
pad_factor: int = 4,
|
|
27
|
-
domain: tuple[float, float] | None = None,
|
|
28
|
-
) -> int:
|
|
29
|
-
"""Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
|
|
30
|
-
|
|
31
|
-
Bins X into G histogram bins, zero-pads to pad_factor*G, convolves with
|
|
32
|
-
the Gaussian derivative kernel in the frequency domain (applying iω·exp(−½(ωh)²)),
|
|
33
|
-
back-transforms, and counts positive-to-negative sign changes of the
|
|
34
|
-
resulting f' estimate.
|
|
35
|
-
|
|
36
|
-
Parameters
|
|
37
|
-
----------
|
|
38
|
-
X : Tensor, shape (n,)
|
|
39
|
-
1D data tensor (may be on CPU or CUDA).
|
|
40
|
-
h : float
|
|
41
|
-
Bandwidth for the Gaussian kernel.
|
|
42
|
-
G : int
|
|
43
|
-
Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
|
|
44
|
-
reliable derivative estimation. Use `adaptive_fft_G` to choose G
|
|
45
|
-
automatically before bisection.
|
|
46
|
-
pad_factor : int
|
|
47
|
-
Zero-padding multiplier (default 4). Mandatory ≥ 2 for circular-wrap
|
|
48
|
-
correctness; 4 is recommended at the largest h encountered.
|
|
49
|
-
domain : (lo, hi) or None
|
|
50
|
-
If provided, use this as the histogram domain instead of computing
|
|
51
|
-
X.min() - 3σ … X.max() + 3σ. Allows the caller to align the domain
|
|
52
|
-
with the bisection bracket (e.g., X.min() - 2*h_hi … X.max() + 2*h_hi)
|
|
53
|
-
so every fft_mode_count call in a bisection loop uses an identical grid.
|
|
54
|
-
|
|
55
|
-
Returns
|
|
56
|
-
-------
|
|
57
|
-
int
|
|
58
|
-
Number of KDE modes (downward zero-crossings of f').
|
|
59
|
-
"""
|
|
60
|
-
with torch.no_grad():
|
|
61
|
-
if domain is not None:
|
|
62
|
-
lo, hi = domain
|
|
63
|
-
else:
|
|
64
|
-
# Domain: extend 3σ beyond data range to avoid boundary effects
|
|
65
|
-
sigma = X.std().item()
|
|
66
|
-
if sigma == 0.0:
|
|
67
|
-
sigma = 1.0 # degenerate case: all points identical
|
|
68
|
-
lo = X.min().item() - 3 * sigma
|
|
69
|
-
hi = X.max().item() + 3 * sigma
|
|
70
|
-
data_range = hi - lo
|
|
71
|
-
|
|
72
|
-
if data_range == 0.0:
|
|
73
|
-
return 1 # single-point distribution has 1 mode
|
|
74
|
-
|
|
75
|
-
# Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
|
|
76
|
-
# torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
|
|
77
|
-
# MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
|
|
78
|
-
# the binning step avoids the intermediate and is numerically identical
|
|
79
|
-
# for data within [lo, hi] (guaranteed by the 3σ domain extension above).
|
|
80
|
-
X_cpu = X.float().cpu()
|
|
81
|
-
edges = torch.linspace(lo, hi, G + 1) # (G+1,) CPU
|
|
82
|
-
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1 # 0-indexed
|
|
83
|
-
counts = torch.bincount(bin_idx, minlength=G).float().to(X.device) # back to device
|
|
84
|
-
|
|
85
|
-
# Zero-pad to pad_factor*G (4× mandatory for circular wrap correctness at h_hi)
|
|
86
|
-
N = pad_factor * G
|
|
87
|
-
counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
|
|
88
|
-
counts_padded[:G] = counts
|
|
89
|
-
|
|
90
|
-
# FFT of histogram
|
|
91
|
-
C = torch.fft.rfft(counts_padded)
|
|
92
|
-
|
|
93
|
-
# Derivative kernel in frequency domain: iω * exp(-0.5*(ω*h)²)
|
|
94
|
-
# ω_k = 2π*k / (N * bin_width), bin_width = data_range / G
|
|
95
|
-
bin_width = data_range / G
|
|
96
|
-
k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
|
|
97
|
-
omega = 2 * math.pi * k / (N * bin_width)
|
|
98
|
-
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
99
|
-
|
|
100
|
-
# Convolve and back-transform
|
|
101
|
-
f_prime_padded = torch.fft.irfft(C * K_deriv, n=N)
|
|
102
|
-
|
|
103
|
-
# Trim to original G grid (discard zero-padded tail)
|
|
104
|
-
f_prime = f_prime_padded[:G]
|
|
105
|
-
|
|
106
|
-
# Count (+→-) sign changes = number of modes
|
|
107
|
-
# A mode is a local max of f, i.e., f' crosses zero from + to -
|
|
108
|
-
# Remove zeros (flat segments) — carry forward last nonzero sign
|
|
109
|
-
nonzero_mask = f_prime != 0
|
|
110
|
-
if not nonzero_mask.any():
|
|
111
|
-
return 0
|
|
112
|
-
|
|
113
|
-
s = f_prime[nonzero_mask]
|
|
114
|
-
transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
|
|
115
|
-
return transitions
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
|
|
119
|
-
"""Choose FFT grid size G so that the derivative kernel is well-resolved.
|
|
120
|
-
|
|
121
|
-
Requires h > 8 * bin_width = 8 * data_range / G, equivalently
|
|
122
|
-
G > 8 * data_range / h_hi. We use a factor of 16 for safety margin,
|
|
123
|
-
then round up to the next power of 2 for efficient FFT.
|
|
124
|
-
|
|
125
|
-
Parameters
|
|
126
|
-
----------
|
|
127
|
-
data_range : float
|
|
128
|
-
hi - lo of the data domain (typically X.max() - X.min() + 6σ).
|
|
129
|
-
h_hi : float
|
|
130
|
-
Upper bracket of the bisection (smallest h needing resolution).
|
|
131
|
-
G_min : int
|
|
132
|
-
Minimum returned G (default 4096).
|
|
133
|
-
|
|
134
|
-
Returns
|
|
135
|
-
-------
|
|
136
|
-
int
|
|
137
|
-
Grid size G, a power of 2, at least G_min.
|
|
138
|
-
"""
|
|
139
|
-
needed = 16 * math.ceil(data_range / h_hi)
|
|
140
|
-
# Round up to next power of 2
|
|
141
|
-
p = 1
|
|
142
|
-
while p < needed:
|
|
143
|
-
p <<= 1
|
|
144
|
-
return max(G_min, p)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|