diffcb 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffcb-0.1.0/.gitignore +11 -0
- diffcb-0.1.0/.zenodo.json +30 -0
- diffcb-0.1.0/LICENSE +21 -0
- diffcb-0.1.0/PKG-INFO +148 -0
- diffcb-0.1.0/README.md +91 -0
- diffcb-0.1.0/dcb/__init__.py +22 -0
- diffcb-0.1.0/dcb/diagnostics.py +163 -0
- diffcb-0.1.0/dcb/fft_kde.py +128 -0
- diffcb-0.1.0/dcb/kde.py +394 -0
- diffcb-0.1.0/dcb/layer.py +231 -0
- diffcb-0.1.0/dcb/solver.py +604 -0
- diffcb-0.1.0/dcb/utils.py +183 -0
- diffcb-0.1.0/notebooks/.gitkeep +0 -0
- diffcb-0.1.0/pyproject.toml +63 -0
- diffcb-0.1.0/tests/test_kde.py +312 -0
- diffcb-0.1.0/tests/test_layer.py +165 -0
- diffcb-0.1.0/tests/test_r18c_denom_audit.py +118 -0
- diffcb-0.1.0/tests/test_r18c_deprecation_warn.py +64 -0
- diffcb-0.1.0/tests/test_r19_default_fft.py +52 -0
- diffcb-0.1.0/tests/test_r19_diagnostics.py +80 -0
- diffcb-0.1.0/tests/test_solver.py +179 -0
diffcb-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
{
|
|
2
|
+
"title": "Differentiable Critical Bandwidth (DCB) v0.1.0",
|
|
3
|
+
"description": "A PyTorch package making Silverman's critical bandwidth test fully differentiable via a smooth mode-counting integral and an Implicit Function Theorem backward pass.",
|
|
4
|
+
"upload_type": "software",
|
|
5
|
+
"license": "MIT",
|
|
6
|
+
"creators": [
|
|
7
|
+
{
|
|
8
|
+
"name": "Zhang, Ruiyu",
|
|
9
|
+
"affiliation": "University of Hong Kong"
|
|
10
|
+
}
|
|
11
|
+
],
|
|
12
|
+
"keywords": [
|
|
13
|
+
"nonparametric statistics",
|
|
14
|
+
"kernel density estimation",
|
|
15
|
+
"differentiable programming",
|
|
16
|
+
"critical bandwidth",
|
|
17
|
+
"mode counting",
|
|
18
|
+
"implicit function theorem",
|
|
19
|
+
"PyTorch",
|
|
20
|
+
"JMLR",
|
|
21
|
+
"Silverman 1981"
|
|
22
|
+
],
|
|
23
|
+
"related_identifiers": [
|
|
24
|
+
{
|
|
25
|
+
"relation": "isSupplementTo",
|
|
26
|
+
"identifier": "10.48550/arXiv.XXXX.XXXXX",
|
|
27
|
+
"scheme": "doi"
|
|
28
|
+
}
|
|
29
|
+
]
|
|
30
|
+
}
|
diffcb-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Ruiyu Zhang
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
diffcb-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: diffcb
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
|
|
5
|
+
Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
6
|
+
Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
7
|
+
Project-URL: Documentation, https://github.com/ryZhangHason/differentiable-critical-bandwidth#readme
|
|
8
|
+
Project-URL: Bug Tracker, https://github.com/ryZhangHason/differentiable-critical-bandwidth/issues
|
|
9
|
+
Author-email: Ruiyu Zhang <dhhhason@gmail.com>
|
|
10
|
+
License: MIT License
|
|
11
|
+
|
|
12
|
+
Copyright (c) 2026 Ruiyu Zhang
|
|
13
|
+
|
|
14
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
15
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
16
|
+
in the Software without restriction, including without limitation the rights
|
|
17
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
18
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
19
|
+
furnished to do so, subject to the following conditions:
|
|
20
|
+
|
|
21
|
+
The above copyright notice and this permission notice shall be included in all
|
|
22
|
+
copies or substantial portions of the Software.
|
|
23
|
+
|
|
24
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
25
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
26
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
27
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
28
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
29
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
30
|
+
SOFTWARE.
|
|
31
|
+
License-File: LICENSE
|
|
32
|
+
Keywords: PyTorch,anomaly detection,critical bandwidth,differentiable programming,generative models,kernel density estimation,mode counting,nonparametric statistics
|
|
33
|
+
Classifier: Development Status :: 3 - Alpha
|
|
34
|
+
Classifier: Intended Audience :: Science/Research
|
|
35
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
36
|
+
Classifier: Programming Language :: Python :: 3
|
|
37
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
38
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
39
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
40
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
41
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
42
|
+
Requires-Python: >=3.9
|
|
43
|
+
Requires-Dist: matplotlib>=3.7.0
|
|
44
|
+
Requires-Dist: numpy>=1.24.0
|
|
45
|
+
Requires-Dist: scikit-learn>=1.3.0
|
|
46
|
+
Requires-Dist: scipy>=1.10.0
|
|
47
|
+
Requires-Dist: torch>=2.0.0
|
|
48
|
+
Provides-Extra: dev
|
|
49
|
+
Requires-Dist: black>=23.0.0; extra == 'dev'
|
|
50
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
|
|
51
|
+
Requires-Dist: pytest>=7.4.0; extra == 'dev'
|
|
52
|
+
Requires-Dist: ruff>=0.1.0; extra == 'dev'
|
|
53
|
+
Provides-Extra: notebooks
|
|
54
|
+
Requires-Dist: ipywidgets>=8.0.0; extra == 'notebooks'
|
|
55
|
+
Requires-Dist: jupyter>=1.0.0; extra == 'notebooks'
|
|
56
|
+
Description-Content-Type: text/markdown
|
|
57
|
+
|
|
58
|
+
# DCB — Differentiable Critical Bandwidth
|
|
59
|
+
|
|
60
|
+
[](LICENSE)
|
|
61
|
+
[](https://www.python.org/)
|
|
62
|
+
|
|
63
|
+
A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
|
|
64
|
+
|
|
65
|
+
## Overview
|
|
66
|
+
|
|
67
|
+
The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
import torch
|
|
71
|
+
from dcb import DCBLayer
|
|
72
|
+
|
|
73
|
+
X = torch.randn(256, requires_grad=True) # 1D samples
|
|
74
|
+
layer = DCBLayer(target_modes=1)
|
|
75
|
+
h_crit = layer(X) # differentiable scalar
|
|
76
|
+
h_crit.backward() # exact IFT gradients
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
## Installation
|
|
80
|
+
|
|
81
|
+
```bash
|
|
82
|
+
pip install dcb
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
Or from source:
|
|
86
|
+
|
|
87
|
+
```bash
|
|
88
|
+
git clone https://github.com/ryZhangHason/dcb
|
|
89
|
+
cd dcb
|
|
90
|
+
pip install -e ".[dev]"
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
## Paper
|
|
94
|
+
|
|
95
|
+
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
96
|
+
|
|
97
|
+
## Confirmed Experimental Results
|
|
98
|
+
|
|
99
|
+
All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
|
|
100
|
+
|
|
101
|
+
| Experiment | Result | Criterion |
|
|
102
|
+
|---|---|---|
|
|
103
|
+
| **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
|
|
104
|
+
| **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
|
|
105
|
+
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
106
|
+
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
107
|
+
|
|
108
|
+
## Repository Structure
|
|
109
|
+
|
|
110
|
+
```
|
|
111
|
+
dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
|
|
112
|
+
experiments/ Reproduction scripts for all paper figures and tables
|
|
113
|
+
phase1_validation.py Figure 1: DCB vs reference h_crit scatter
|
|
114
|
+
phase1_speedup.py Figure 2: GPU speedup benchmark
|
|
115
|
+
phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
|
|
116
|
+
phase2_gan.py Figure 3: GAN mode-collapse prevention
|
|
117
|
+
phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
|
|
118
|
+
tests/ Unit tests (pytest, 35/35 passing)
|
|
119
|
+
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
120
|
+
notebooks/ Quickstart and demo notebooks
|
|
121
|
+
```
|
|
122
|
+
|
|
123
|
+
## Reproducing Paper Results
|
|
124
|
+
|
|
125
|
+
```bash
|
|
126
|
+
# Phase 1: validation, speedup, ablation
|
|
127
|
+
python experiments/phase1_validation.py
|
|
128
|
+
python experiments/phase1_speedup.py
|
|
129
|
+
python experiments/phase1_ablation.py
|
|
130
|
+
|
|
131
|
+
# Phase 2: GAN mode collapse experiment
|
|
132
|
+
python experiments/phase2_gan.py
|
|
133
|
+
|
|
134
|
+
# Phase 3: anomaly detection benchmark
|
|
135
|
+
python experiments/phase3_anomaly.py
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
For GPU runs, use the provided Kaggle kernels:
|
|
139
|
+
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
140
|
+
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
141
|
+
|
|
142
|
+
## Kaggle GPU Notes
|
|
143
|
+
|
|
144
|
+
Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
|
|
145
|
+
|
|
146
|
+
## License
|
|
147
|
+
|
|
148
|
+
MIT — see [LICENSE](LICENSE).
|
diffcb-0.1.0/README.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# DCB — Differentiable Critical Bandwidth
|
|
2
|
+
|
|
3
|
+
[](LICENSE)
|
|
4
|
+
[](https://www.python.org/)
|
|
5
|
+
|
|
6
|
+
A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
|
|
7
|
+
|
|
8
|
+
## Overview
|
|
9
|
+
|
|
10
|
+
The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
|
|
11
|
+
|
|
12
|
+
```python
|
|
13
|
+
import torch
|
|
14
|
+
from dcb import DCBLayer
|
|
15
|
+
|
|
16
|
+
X = torch.randn(256, requires_grad=True) # 1D samples
|
|
17
|
+
layer = DCBLayer(target_modes=1)
|
|
18
|
+
h_crit = layer(X) # differentiable scalar
|
|
19
|
+
h_crit.backward() # exact IFT gradients
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
## Installation
|
|
23
|
+
|
|
24
|
+
```bash
|
|
25
|
+
pip install diffcb
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
Or from source:
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
32
|
+
cd differentiable-critical-bandwidth
|
|
33
|
+
pip install -e ".[dev]"
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
## Paper
|
|
37
|
+
|
|
38
|
+
> Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
|
|
39
|
+
|
|
40
|
+
## Confirmed Experimental Results
|
|
41
|
+
|
|
42
|
+
All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
|
|
43
|
+
|
|
44
|
+
| Experiment | Result | Criterion |
|
|
45
|
+
|---|---|---|
|
|
46
|
+
| **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
|
|
47
|
+
| **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
|
|
48
|
+
| **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
|
|
49
|
+
| **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
|
|
50
|
+
|
|
51
|
+
## Repository Structure
|
|
52
|
+
|
|
53
|
+
```
|
|
54
|
+
dcb/ Core PyTorch package (layer.py, solver.py, kde.py, fft_kde.py, diagnostics.py, utils.py)
|
|
55
|
+
experiments/ Reproduction scripts for all paper figures and tables
|
|
56
|
+
phase1_validation.py Figure 1: DCB vs reference h_crit scatter
|
|
57
|
+
phase1_speedup.py Figure 2: GPU speedup benchmark
|
|
58
|
+
phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
|
|
59
|
+
phase2_gan.py Figure 3: GAN mode-collapse prevention
|
|
60
|
+
phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
|
|
61
|
+
tests/ Unit tests (pytest, 35/35 passing)
|
|
62
|
+
outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
|
|
63
|
+
notebooks/ Quickstart and demo notebooks
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
## Reproducing Paper Results
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
# Phase 1: validation, speedup, ablation
|
|
70
|
+
python experiments/phase1_validation.py
|
|
71
|
+
python experiments/phase1_speedup.py
|
|
72
|
+
python experiments/phase1_ablation.py
|
|
73
|
+
|
|
74
|
+
# Phase 2: GAN mode collapse experiment
|
|
75
|
+
python experiments/phase2_gan.py
|
|
76
|
+
|
|
77
|
+
# Phase 3: anomaly detection benchmark
|
|
78
|
+
python experiments/phase3_anomaly.py
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
For GPU runs, use the provided Kaggle kernels:
|
|
82
|
+
- Phase 1–2: `hsingle/dcb-full-experiments`
|
|
83
|
+
- Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
|
|
84
|
+
|
|
85
|
+
## Kaggle GPU Notes
|
|
86
|
+
|
|
87
|
+
Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
|
|
88
|
+
|
|
89
|
+
## License
|
|
90
|
+
|
|
91
|
+
MIT — see [LICENSE](LICENSE).
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb — Differentiable Critical Bandwidth
|
|
3
|
+
|
|
4
|
+
A PyTorch package that makes Silverman's critical bandwidth test (1981) fully
|
|
5
|
+
differentiable via a smooth mode-counting integral and an Implicit Function
|
|
6
|
+
Theorem (IFT) backward pass. The primary public API is the
|
|
7
|
+
`DifferentiableCriticalBandwidth` class, which behaves as a standard
|
|
8
|
+
`torch.nn.Module` and can be used as a loss component or regularizer in any
|
|
9
|
+
gradient-based learning pipeline. Import as `from dcb import DCBLayer` for
|
|
10
|
+
the layer, or `from dcb.kde import gaussian_kde_grid` for lower-level KDE
|
|
11
|
+
utilities. Requires PyTorch >= 2.0, NumPy >= 1.24, and SciPy >= 1.10.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from dcb.layer import DCBLayer, DifferentiableCriticalBandwidth
|
|
15
|
+
from dcb.utils import anneal_eps_tau
|
|
16
|
+
from dcb.kde import soft_mode_count_cross, soft_mode_count
|
|
17
|
+
|
|
18
|
+
# Public names exported by `from dcb import *`; each one is imported above
# from dcb.layer, dcb.utils or dcb.kde.
__all__ = [
    "DCBLayer", "DifferentiableCriticalBandwidth",
    "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
]
# Package version string exposed as `dcb.__version__`.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.diagnostics — Gradient Stability Diagnostics for DCB
|
|
3
|
+
|
|
4
|
+
Provides `denom_profile()` which maps M̃(h) and ∂M̃/∂h over a bandwidth grid
|
|
5
|
+
to assess gradient conditioning before training. A stable IFT gradient at
|
|
6
|
+
h_crit requires |∂M̃/∂h| > 0 (non-zero denominator in the IFT formula).
|
|
7
|
+
|
|
8
|
+
Use case: call denom_profile() on your dataset before fitting DCBLayer to
|
|
9
|
+
verify that the IFT gradient is well-conditioned at h_crit. If
|
|
10
|
+
stability_mask=False at h_crit, consider using safe_backward=True or
|
|
11
|
+
widening the bandwidth search range.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch import Tensor
|
|
18
|
+
|
|
19
|
+
from dcb.kde import soft_mode_count_cross
|
|
20
|
+
from dcb.utils import make_grid
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def denom_profile(
    X: Tensor,
    h_grid: Tensor,
    formula: str = 'cross',
    eps: float = 0.1,
    tau: float = 0.2,
    chunk_size: int = 50_000,
    guard: float = 0.01,
) -> dict:
    """Compute M̃(h) and ∂M̃/∂h over a bandwidth grid for gradient stability diagnosis.

    Evaluates the soft mode count M̃_cross at each bandwidth in h_grid, then
    estimates ∂M̃/∂h by finite differences. The stability_mask flags the
    bandwidths where the IFT denominator |∂M̃/∂h| exceeds `guard`.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h_grid : Tensor, shape (H,)
        Bandwidth grid to evaluate. Should cover the expected h_crit.
        Neighbouring entries must be distinct, otherwise the finite
        difference divides by zero.
    formula : str
        Mode-count formula to use. Only 'cross' is supported (matches DCBLayer).
    eps : float
        Sigmoid temperature for the zero-crossing detector. Default 0.1.
    tau : float
        Sigmoid temperature for the local-max selector. Default 0.2.
    chunk_size : int
        Unused by this function; kept for API consistency with large-n paths.
    guard : float
        Threshold for stability_mask: True where |dM_dh| > guard. Default 0.01.

    Returns
    -------
    dict with keys:
        'h_grid'         : Tensor (H,) — the input bandwidth grid, returned as passed
        'M_tilde'        : Tensor (H,) — soft mode count at each h
        'dM_dh'          : Tensor (H,) — ∂M̃/∂h; central differences in the
                           interior, one-sided differences at the two ends,
                           NaN everywhere when H < 2 (no neighbour to
                           difference against)
        'stability_mask' : BoolTensor (H,) — True where |dM_dh| > guard
        'h_crit_approx'  : float — smallest h in the grid where M̃ ≤ 1.5
                           (threshold for target_modes=1); float('nan') if
                           no grid point satisfies it

    Raises
    ------
    ValueError
        If `formula` is not 'cross'.

    Notes
    -----
    All computation runs under torch.no_grad() — this function is diagnostic
    only and does not build a computation graph.
    """
    if formula != 'cross':
        raise ValueError(f"Only formula='cross' is supported; got {formula!r}")

    H = h_grid.shape[0]
    grid = make_grid(X, G=512)

    with torch.no_grad():
        # Work on a copy that lives on X's device so sliced differences of
        # h and M_tilde can be combined; the caller's h_grid is returned as-is.
        h = h_grid.to(X.device)

        M_tilde = torch.zeros(H, dtype=X.dtype, device=X.device)
        for i, h_val in enumerate(h.tolist()):
            M_tilde[i] = soft_mode_count_cross(X, h_val, grid, eps, tau)

        if H >= 2:
            dM_dh = torch.zeros(H, dtype=X.dtype, device=X.device)
            # One-sided differences at the two ends of the grid.
            dM_dh[0] = (M_tilde[1] - M_tilde[0]) / (h[1] - h[0])
            dM_dh[-1] = (M_tilde[-1] - M_tilde[-2]) / (h[-1] - h[-2])
            # Central differences for every interior point (empty when H == 2).
            dM_dh[1:-1] = (M_tilde[2:] - M_tilde[:-2]) / (h[2:] - h[:-2])
        else:
            # A derivative cannot be estimated from fewer than two grid points.
            dM_dh = torch.full((H,), float('nan'), dtype=X.dtype, device=X.device)

        # NaN compares False, so undefined derivatives are reported as unstable.
        stability_mask = dM_dh.abs() > guard

        # Take the minimum over qualifying bandwidths rather than the first
        # index, so the result is the smallest h even for a non-ascending grid.
        below_threshold = M_tilde <= 1.5
        if below_threshold.any():
            h_crit_approx = h[below_threshold].min().item()
        else:
            h_crit_approx = float('nan')

    return {
        'h_grid': h_grid,
        'M_tilde': M_tilde,
        'dM_dh': dM_dh,
        'stability_mask': stability_mask,
        'h_crit_approx': h_crit_approx,
    }
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def print_stability_report(profile: dict) -> None:
    """Print a human-readable stability report from denom_profile output.

    Parameters
    ----------
    profile : dict
        Output from `denom_profile()`. The keys 'h_grid', 'M_tilde', 'dM_dh',
        'stability_mask' and 'h_crit_approx' are read.

    Notes
    -----
    An empty grid (H == 0) prints a short notice instead of the statistics,
    because the min/max ranges and the stable-point percentage are undefined
    for zero points.
    """
    h_grid = profile['h_grid']
    M_tilde = profile['M_tilde']
    dM_dh = profile['dM_dh']
    stability_mask = profile['stability_mask']
    h_crit_approx = profile['h_crit_approx']

    H = h_grid.shape[0]
    rule = "=" * 60

    print(rule)
    print("DCB Gradient Stability Report")
    print(rule)

    if H == 0:
        # Nothing to summarise; avoids 0/0 and min()/max() of an empty tensor.
        print(" h_grid is empty; nothing to report.")
        print(rule)
        return

    n_stable = stability_mask.sum().item()
    pct_stable = 100.0 * n_stable / H
    # NaN is the only float that is not equal to itself.
    has_h_crit = h_crit_approx == h_crit_approx

    print(f" h_grid range : [{h_grid.min().item():.4f}, {h_grid.max().item():.4f}] (H={H})")
    print(f" M_tilde range : [{M_tilde.min().item():.4f}, {M_tilde.max().item():.4f}]")
    print(f" dM_dh range : [{dM_dh.min().item():.4f}, {dM_dh.max().item():.4f}]")
    if has_h_crit:
        print(f" h_crit_approx : {h_crit_approx:.4f}")
    else:
        print(" h_crit_approx : NaN (M_tilde never <= 1.5 in grid)")
    print(f" Stable points : {n_stable}/{H} ({pct_stable:.1f}%)")

    if has_h_crit:
        # Locate the grid point closest to h_crit_approx and report its state.
        idx = (h_grid - h_crit_approx).abs().argmin().item()
        stable_at_hcrit = stability_mask[idx].item()
        dM_at_hcrit = dM_dh[idx].abs().item()
        print(f" At h_crit : stability={stable_at_hcrit}, |dM_dh|={dM_at_hcrit:.4f}")
        print()
        if stable_at_hcrit:
            print(" OK: IFT gradient is well-conditioned at h_crit.")
        else:
            print(" WARNING: h_crit_approx falls in an UNSTABLE region.")
            print(" IFT gradient may be ill-conditioned at h_crit.")
            print(" Consider: safe_backward=True, wider h_grid, or larger n.")
    print(rule)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.fft_kde — FFT-based KDE Mode Counter
|
|
3
|
+
|
|
4
|
+
Implements mode counting via FFT convolution of the histogram with a
|
|
5
|
+
Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
|
|
6
|
+
O(n × G) cost of the direct KDE approach and — crucially — requiring NO
|
|
7
|
+
subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
|
|
8
|
+
that affects the standard bisection path when n > brentq_n_max.
|
|
9
|
+
|
|
10
|
+
Round 18b: forward kernel only. The IFT backward is unchanged (still uses
|
|
11
|
+
the analytical chunked KDE derivatives on all n points).
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def fft_mode_count(
    X: Tensor,
    h: float,
    G: int = 4096,
    pad_factor: int = 4,
) -> int:
    """Count KDE modes via FFT convolution — O(n + G log G), no subsampling.

    Bins X into G histogram bins, zero-pads to pad_factor*G, multiplies the
    spectrum by the Gaussian derivative kernel iω·exp(−½(ωh)²),
    back-transforms, and counts positive-to-negative sign changes of the
    resulting f' estimate on the first G grid points.

    Parameters
    ----------
    X : Tensor, shape (n,)
        1D data tensor (may be on CPU or CUDA).
    h : float
        Bandwidth for the Gaussian kernel.
    G : int
        Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
        reliable derivative estimation. Use `adaptive_fft_G` to choose G
        automatically before bisection.
    pad_factor : int
        Zero-padding multiplier (default 4). Should be ≥ 2 so the circular
        convolution does not wrap into the retained grid; the value is not
        validated here.

    Returns
    -------
    int
        Number of KDE modes (downward zero-crossings of f'). Returns 1 when
        the padded domain collapses to zero width, and 0 when f' is exactly
        zero everywhere on the grid.

    Notes
    -----
    A single-sample X is treated like a set of identical points (sigma falls
    back to 1.0); its sample standard deviation would otherwise be NaN and
    propagate into the histogram bounds.

    NOTE(review): sign changes are counted without a noise floor, so regions
    where the true f' is far below float32 FFT round-off (e.g. many
    bandwidths away from any data) may contribute spurious crossings —
    confirm that callers keep h large enough relative to the domain.
    """
    with torch.no_grad():
        # Domain: extend 3σ beyond data range to avoid boundary effects.
        # std() of fewer than two samples is NaN, which would slip past the
        # `== 0.0` check below, so handle that case as "no spread".
        sigma = X.std().item() if X.numel() > 1 else 0.0
        if sigma == 0.0:
            sigma = 1.0  # degenerate case: all points identical (or n == 1)
        lo = X.min().item() - 3 * sigma
        hi = X.max().item() + 3 * sigma
        data_range = hi - lo

        if data_range == 0.0:
            # Only reachable when the ±3σ padding is lost to rounding.
            return 1  # single-point distribution has 1 mode

        # Histogram (O(n), CUDA-native)
        counts = torch.histc(X.float(), bins=G, min=lo, max=hi)

        # Zero-pad to pad_factor*G so circular wrap-around lands in the padding
        N = pad_factor * G
        counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
        counts_padded[:G] = counts

        # FFT of histogram
        C = torch.fft.rfft(counts_padded)

        # Derivative kernel in frequency domain: iω * exp(-0.5*(ω*h)²)
        # ω_k = 2π*k / (N * bin_width), bin_width = data_range / G
        bin_width = data_range / G
        k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
        omega = 2 * math.pi * k / (N * bin_width)
        K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)

        # Convolve and back-transform
        f_prime_padded = torch.fft.irfft(C * K_deriv, n=N)

        # Trim to original G grid (discard zero-padded tail)
        f_prime = f_prime_padded[:G]

        # A mode is a local max of f, i.e., f' crosses zero from + to -.
        # Exact zeros are dropped first so a flat segment does not hide a
        # crossing between its two non-zero neighbours.
        nonzero_mask = f_prime != 0
        if not nonzero_mask.any():
            return 0

        s = f_prime[nonzero_mask]
        transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
        return transitions
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
    """Pick a power-of-two FFT grid size that resolves the derivative kernel.

    The resolution rule is h > 8 * bin_width with bin_width = data_range / G,
    i.e. G > 8 * data_range / h. A factor of 16 is applied instead of 8 as a
    safety margin, the result is rounded up to a power of two for efficient
    FFTs, and it is never allowed to fall below `G_min`.

    Parameters
    ----------
    data_range : float
        Width hi - lo of the data domain (typically X.max() - X.min() + 6σ).
    h_hi : float
        Bandwidth the grid is sized for; must be non-zero.
        NOTE(review): the original description calls this both the upper
        bisection bracket and the smallest h needing resolution; the rule
        above is tightest for the smallest bandwidth evaluated, so confirm
        which bracket callers pass.
    G_min : int
        Lower bound on the returned grid size (default 4096).

    Returns
    -------
    int
        The larger of `G_min` and the smallest power of two that is at least
        16 * ceil(data_range / h_hi).
    """
    bins_required = math.ceil(data_range / h_hi) * 16
    # Smallest power of two >= bins_required; evaluates to 1 for values <= 1.
    pow2 = 1 << max(bins_required - 1, 0).bit_length()
    return pow2 if pow2 > G_min else G_min
|