diffcb 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *.egg-info/
4
+ dist/
5
+ build/
6
+ .venv/
7
+ venv/
8
+ .ipynb_checkpoints/
9
+ .DS_Store
10
+ outputs/
11
+ .claude/
@@ -0,0 +1,30 @@
1
+ {
2
+ "title": "Differentiable Critical Bandwidth (DCB) v0.1.0",
3
+ "description": "A PyTorch package making Silverman's critical bandwidth test fully differentiable via a smooth mode-counting integral and an Implicit Function Theorem backward pass.",
4
+ "upload_type": "software",
5
+ "license": "MIT",
6
+ "creators": [
7
+ {
8
+ "name": "Zhang, Ruiyu",
9
+ "affiliation": "University of Hong Kong"
10
+ }
11
+ ],
12
+ "keywords": [
13
+ "nonparametric statistics",
14
+ "kernel density estimation",
15
+ "differentiable programming",
16
+ "critical bandwidth",
17
+ "mode counting",
18
+ "implicit function theorem",
19
+ "PyTorch",
20
+ "JMLR",
21
+ "Silverman 1981"
22
+ ],
23
+ "related_identifiers": [
24
+ {
25
+ "relation": "isSupplementTo",
26
+ "identifier": "10.48550/arXiv.XXXX.XXXXX",
27
+ "scheme": "doi"
28
+ }
29
+ ]
30
+ }
diffcb-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Ruiyu Zhang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
diffcb-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,148 @@
1
+ Metadata-Version: 2.4
2
+ Name: diffcb
3
+ Version: 0.1.0
4
+ Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
5
+ Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
6
+ Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
7
+ Project-URL: Documentation, https://github.com/ryZhangHason/differentiable-critical-bandwidth#readme
8
+ Project-URL: Bug Tracker, https://github.com/ryZhangHason/differentiable-critical-bandwidth/issues
9
+ Author-email: Ruiyu Zhang <dhhhason@gmail.com>
10
+ License: MIT License
11
+
12
+ Copyright (c) 2026 Ruiyu Zhang
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ The above copyright notice and this permission notice shall be included in all
22
+ copies or substantial portions of the Software.
23
+
24
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ SOFTWARE.
31
+ License-File: LICENSE
32
+ Keywords: PyTorch,anomaly detection,critical bandwidth,differentiable programming,generative models,kernel density estimation,mode counting,nonparametric statistics
33
+ Classifier: Development Status :: 3 - Alpha
34
+ Classifier: Intended Audience :: Science/Research
35
+ Classifier: License :: OSI Approved :: MIT License
36
+ Classifier: Programming Language :: Python :: 3
37
+ Classifier: Programming Language :: Python :: 3.9
38
+ Classifier: Programming Language :: Python :: 3.10
39
+ Classifier: Programming Language :: Python :: 3.11
40
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
41
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
42
+ Requires-Python: >=3.9
43
+ Requires-Dist: matplotlib>=3.7.0
44
+ Requires-Dist: numpy>=1.24.0
45
+ Requires-Dist: scikit-learn>=1.3.0
46
+ Requires-Dist: scipy>=1.10.0
47
+ Requires-Dist: torch>=2.0.0
48
+ Provides-Extra: dev
49
+ Requires-Dist: black>=23.0.0; extra == 'dev'
50
+ Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
51
+ Requires-Dist: pytest>=7.4.0; extra == 'dev'
52
+ Requires-Dist: ruff>=0.1.0; extra == 'dev'
53
+ Provides-Extra: notebooks
54
+ Requires-Dist: ipywidgets>=8.0.0; extra == 'notebooks'
55
+ Requires-Dist: jupyter>=1.0.0; extra == 'notebooks'
56
+ Description-Content-Type: text/markdown
57
+
58
+ # DCB — Differentiable Critical Bandwidth
59
+
60
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
61
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/)
62
+
63
+ A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
64
+
65
+ ## Overview
66
+
67
+ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
68
+
69
+ ```python
70
+ import torch
71
+ from dcb import DCBLayer
72
+
73
+ X = torch.randn(256, requires_grad=True) # 1D samples
74
+ layer = DCBLayer(target_modes=1)
75
+ h_crit = layer(X) # differentiable scalar
76
+ h_crit.backward() # exact IFT gradients
77
+ ```
78
+
79
+ ## Installation
80
+
81
+ ```bash
82
+ pip install diffcb
83
+ ```
84
+
85
+ Or from source:
86
+
87
+ ```bash
88
+ git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
88
+ cd differentiable-critical-bandwidth
90
+ pip install -e ".[dev]"
91
+ ```
92
+
93
+ ## Paper
94
+
95
+ > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
96
+
97
+ ## Confirmed Experimental Results
98
+
99
+ All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
100
+
101
+ | Experiment | Result | Criterion |
102
+ |---|---|---|
103
+ | **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
104
+ | **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
105
+ | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
106
+ | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
107
+
108
+ ## Repository Structure
109
+
110
+ ```
111
+ dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
112
+ experiments/ Reproduction scripts for all paper figures and tables
113
+ phase1_validation.py Figure 1: DCB vs reference h_crit scatter
114
+ phase1_speedup.py Figure 2: GPU speedup benchmark
115
+ phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
116
+ phase2_gan.py Figure 3: GAN mode-collapse prevention
117
+ phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
118
+ tests/ Unit tests (pytest, 35/35 passing)
119
+ outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
120
+ notebooks/ Quickstart and demo notebooks
121
+ ```
122
+
123
+ ## Reproducing Paper Results
124
+
125
+ ```bash
126
+ # Phase 1: validation, speedup, ablation
127
+ python experiments/phase1_validation.py
128
+ python experiments/phase1_speedup.py
129
+ python experiments/phase1_ablation.py
130
+
131
+ # Phase 2: GAN mode collapse experiment
132
+ python experiments/phase2_gan.py
133
+
134
+ # Phase 3: anomaly detection benchmark
135
+ python experiments/phase3_anomaly.py
136
+ ```
137
+
138
+ For GPU runs, use the provided Kaggle kernels:
139
+ - Phase 1–2: `hsingle/dcb-full-experiments`
140
+ - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
141
+
142
+ ## Kaggle GPU Notes
143
+
144
+ Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
145
+
146
+ ## License
147
+
148
+ MIT — see [LICENSE](LICENSE).
diffcb-0.1.0/README.md ADDED
@@ -0,0 +1,91 @@
1
+ # DCB — Differentiable Critical Bandwidth
2
+
3
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
4
+ [![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/)
5
+
6
+ A PyTorch package that makes **Silverman's critical bandwidth test (1981)** fully differentiable, enabling end-to-end gradient-based optimization over the modal structure of continuous distributions.
7
+
8
+ ## Overview
9
+
10
+ The critical bandwidth `h_crit` is the minimum KDE bandwidth at which a distribution appears to have at most `m` modes — a classical nonparametric statistic for modality testing. DCB replaces every non-differentiable operation in its computation with a smooth surrogate, then uses the **Implicit Function Theorem** to compute exact gradients through the root-finding step at O(1) memory cost.
11
+
12
+ ```python
13
+ import torch
14
+ from dcb import DCBLayer
15
+
16
+ X = torch.randn(256, requires_grad=True) # 1D samples
17
+ layer = DCBLayer(target_modes=1)
18
+ h_crit = layer(X) # differentiable scalar
19
+ h_crit.backward() # exact IFT gradients
20
+ ```
21
+
22
+ ## Installation
23
+
24
+ ```bash
25
+ pip install diffcb
26
+ ```
27
+
28
+ Or from source:
29
+
30
+ ```bash
31
+ git clone https://github.com/ryZhangHason/differentiable-critical-bandwidth
31
+ cd differentiable-critical-bandwidth
33
+ pip install -e ".[dev]"
34
+ ```
35
+
36
+ ## Paper
37
+
38
+ > Ruiyu Zhang. "Differentiable Critical Bandwidth: Making Silverman's Modality Test End-to-End Trainable." *Journal of Machine Learning Research*, 2026 (in preparation).
39
+
40
+ ## Confirmed Experimental Results
41
+
42
+ All results produced on Kaggle GPU (T4 / P100) — see `experiments/` and `outputs/`.
43
+
44
+ | Experiment | Result | Criterion |
45
+ |---|---|---|
46
+ | **Validation (m≥2)** | R²=0.91, MAE=0.07, Spearman ρ=0.89 | R²≥0.85, MAE≤0.10 ✓ |
47
+ | **Speedup vs scipy (n=8192)** | **10.5×** on T4 | ≥3× ✓ |
48
+ | **GAN mode preservation** | h_crit=1.232 >> 0.3 | h_crit>0.3 ✓ |
49
+ | **Anomaly AUC (KDDCup99)** | DCB=**0.9982** vs IF=0.9867 | DCB≥IF ✓ |
50
+
51
+ ## Repository Structure
52
+
53
+ ```
54
+ dcb/ Core PyTorch package (layer.py, solver.py, kde.py, utils.py)
55
+ experiments/ Reproduction scripts for all paper figures and tables
56
+ phase1_validation.py Figure 1: DCB vs reference h_crit scatter
57
+ phase1_speedup.py Figure 2: GPU speedup benchmark
58
+ phase1_ablation.py Figures S1–S2: ε/τ sensitivity heatmaps
59
+ phase2_gan.py Figure 3: GAN mode-collapse prevention
60
+ phase3_anomaly.py Table 2 + Figure 5: anomaly detection benchmark
61
+ tests/ Unit tests (pytest, 35/35 passing)
62
+ outputs/ All generated figures and tables (PDFs, PNGs, CSVs)
63
+ notebooks/ Quickstart and demo notebooks
64
+ ```
65
+
66
+ ## Reproducing Paper Results
67
+
68
+ ```bash
69
+ # Phase 1: validation, speedup, ablation
70
+ python experiments/phase1_validation.py
71
+ python experiments/phase1_speedup.py
72
+ python experiments/phase1_ablation.py
73
+
74
+ # Phase 2: GAN mode collapse experiment
75
+ python experiments/phase2_gan.py
76
+
77
+ # Phase 3: anomaly detection benchmark
78
+ python experiments/phase3_anomaly.py
79
+ ```
80
+
81
+ For GPU runs, use the provided Kaggle kernels:
82
+ - Phase 1–2: `hsingle/dcb-full-experiments`
83
+ - Phase 3: `hsingle/dcb-phase-3-anomaly-detection`
84
+
85
+ ## Kaggle GPU Notes
86
+
87
+ Kaggle may assign a P100 (sm_60) instead of T4. The Phase 3 kernel handles this automatically by installing `torch==2.2.2+cu118` (the earliest PyTorch release with both Python 3.12 and sm_60 support) when P100 is detected.
88
+
89
+ ## License
90
+
91
+ MIT — see [LICENSE](LICENSE).
@@ -0,0 +1,22 @@
1
+ """
2
+ dcb — Differentiable Critical Bandwidth
3
+
4
+ A PyTorch package that makes Silverman's critical bandwidth test (1981) fully
5
+ differentiable via a smooth mode-counting integral and an Implicit Function
6
+ Theorem (IFT) backward pass. The primary public API is the
7
+ `DifferentiableCriticalBandwidth` class, which behaves as a standard
8
+ `torch.nn.Module` and can be used as a loss component or regularizer in any
9
+ gradient-based learning pipeline. Import as `from dcb import DCBLayer` for
10
+ the layer, or `from dcb.kde import gaussian_kde_grid` for lower-level KDE
11
+ utilities. Requires PyTorch >= 2.0, NumPy >= 1.24, and SciPy >= 1.10.
12
+ """
13
+
14
+ from dcb.layer import DCBLayer, DifferentiableCriticalBandwidth
15
+ from dcb.utils import anneal_eps_tau
16
+ from dcb.kde import soft_mode_count_cross, soft_mode_count
17
+
18
+ __all__ = [
19
+ "DCBLayer", "DifferentiableCriticalBandwidth",
20
+ "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
21
+ ]
22
+ __version__ = "0.1.0"
@@ -0,0 +1,163 @@
1
+ """
2
+ dcb.diagnostics — Gradient Stability Diagnostics for DCB
3
+
4
+ Provides `denom_profile()` which maps M̃(h) and ∂M̃/∂h over a bandwidth grid
5
+ to assess gradient conditioning before training. A stable IFT gradient at
6
+ h_crit requires |∂M̃/∂h| > 0 (non-zero denominator in the IFT formula).
7
+
8
+ Use case: call denom_profile() on your dataset before fitting DCBLayer to
9
+ verify that the IFT gradient is well-conditioned at h_crit. If
10
+ stability_mask=False at h_crit, consider using safe_backward=True or
11
+ widening the bandwidth search range.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import torch
17
+ from torch import Tensor
18
+
19
+ from dcb.kde import soft_mode_count_cross
20
+ from dcb.utils import make_grid
21
+
22
+
23
def denom_profile(
    X: Tensor,
    h_grid: Tensor,
    formula: str = 'cross',
    eps: float = 0.1,
    tau: float = 0.2,
    chunk_size: int = 50_000,
    guard: float = 0.01,
) -> dict:
    """Profile the soft mode count and its bandwidth derivative over a grid.

    For every bandwidth in ``h_grid`` the soft mode count is evaluated with
    ``soft_mode_count_cross`` on a 512-point grid from ``make_grid``; the
    derivative with respect to ``h`` is then estimated by finite differences
    (one-sided at the two ends, central in the interior).

    Parameters
    ----------
    X : Tensor
        Observed data points. Its dtype and device are used for the
        returned ``M_tilde`` and ``dM_dh`` tensors.
    h_grid : Tensor, shape (H,)
        Bandwidths to evaluate. ``h_crit_approx`` is the *first* grid entry
        whose soft mode count is <= 1.5, so it is the smallest such
        bandwidth only when ``h_grid`` is sorted in ascending order.
    formula : str
        Must be ``'cross'``; any other value raises ``ValueError``.
    eps : float
        Forwarded unchanged to ``soft_mode_count_cross``.
    tau : float
        Forwarded unchanged to ``soft_mode_count_cross``.
    chunk_size : int
        Accepted for API compatibility; not used by this function.
    guard : float
        Threshold for ``stability_mask`` (True where ``|dM_dh| > guard``).

    Returns
    -------
    dict with keys:
        'h_grid'         : the ``h_grid`` tensor that was passed in
        'M_tilde'        : Tensor (H,), soft mode count at each bandwidth
        'dM_dh'          : Tensor (H,), finite-difference derivative; NaN
                           when H == 1 (no neighbour to difference against)
        'stability_mask' : BoolTensor (H,), ``|dM_dh| > guard``
        'h_crit_approx'  : float, first bandwidth with ``M_tilde <= 1.5``,
                           or ``float('nan')`` if there is none

    Raises
    ------
    ValueError
        If ``formula`` is not ``'cross'``.

    Notes
    -----
    The mode-count evaluation and the finite differences both run under
    ``torch.no_grad()``, so the returned tensors carry no autograd history.
    """
    if formula != 'cross':
        raise ValueError(f"Only formula='cross' is supported; got {formula!r}")

    H = h_grid.shape[0]
    grid = make_grid(X, G=512)

    with torch.no_grad():
        M_tilde = torch.zeros(H, dtype=X.dtype, device=X.device)
        # One host transfer for the whole grid instead of one .item() per entry.
        for i, h_val in enumerate(h_grid.tolist()):
            M_tilde[i] = soft_mode_count_cross(X, h_val, grid, eps, tau)

        # Difference against bandwidths on the same device as M_tilde so a
        # CPU h_grid can be combined with CUDA data (and vice versa).
        h = h_grid.detach().to(M_tilde.device)

        dM_dh = torch.zeros_like(M_tilde)
        if H >= 2:
            # Forward difference at the left edge, backward at the right.
            dM_dh[0] = (M_tilde[1] - M_tilde[0]) / (h[1] - h[0])
            dM_dh[-1] = (M_tilde[-1] - M_tilde[-2]) / (h[-1] - h[-2])
            # Central differences for interior points (empty when H == 2).
            dM_dh[1:-1] = (M_tilde[2:] - M_tilde[:-2]) / (h[2:] - h[:-2])
        elif H == 1:
            # A single bandwidth has no neighbour; the slope is undefined.
            # NaN compares False against guard, so the point is reported
            # as not stable rather than raising IndexError.
            dM_dh.fill_(float('nan'))

        stability_mask = dM_dh.abs() > guard

        # 1.5 is the midpoint between one and two modes.
        below_threshold = (M_tilde <= 1.5).nonzero(as_tuple=False)
        if below_threshold.numel() > 0:
            first_idx = below_threshold[0].item()
            h_crit_approx = h_grid[first_idx].item()
        else:
            h_crit_approx = float('nan')

    return {
        'h_grid': h_grid,
        'M_tilde': M_tilde,
        'dM_dh': dM_dh,
        'stability_mask': stability_mask,
        'h_crit_approx': h_crit_approx,
    }
119
+
120
+
121
def print_stability_report(profile: dict) -> None:
    """Write a plain-text summary of a ``denom_profile()`` result to stdout.

    Parameters
    ----------
    profile : dict
        Mapping with the keys 'h_grid', 'M_tilde', 'dM_dh',
        'stability_mask' and 'h_crit_approx', as returned by
        ``denom_profile()``.

    The report lists the ranges of the bandwidth grid, the soft mode count
    and its derivative, the share of grid points flagged stable, and — when
    'h_crit_approx' is not NaN — whether the grid point nearest to it is
    flagged stable.
    """
    bandwidths = profile['h_grid']
    mode_counts = profile['M_tilde']
    slopes = profile['dM_dh']
    stable = profile['stability_mask']
    h_star = profile['h_crit_approx']

    H = bandwidths.shape[0]
    n_stable = stable.sum().item()
    pct_stable = 100.0 * n_stable / H

    # NaN is the only float that is not equal to itself.
    has_crossing = h_star == h_star
    rule = "=" * 60

    print(rule)
    print("DCB Gradient Stability Report")
    print(rule)
    print(f" h_grid range : [{bandwidths.min().item():.4f}, {bandwidths.max().item():.4f}] (H={H})")
    print(f" M_tilde range : [{mode_counts.min().item():.4f}, {mode_counts.max().item():.4f}]")
    print(f" dM_dh range : [{slopes.min().item():.4f}, {slopes.max().item():.4f}]")
    if has_crossing:
        print(f" h_crit_approx : {h_star:.4f}")
    else:
        print(" h_crit_approx : NaN (M_tilde never <= 1.5 in grid)")
    print(f" Stable points : {n_stable}/{H} ({pct_stable:.1f}%)")

    if has_crossing:
        # Locate the grid point closest to the approximate critical bandwidth.
        nearest = (bandwidths - h_star).abs().argmin().item()
        stable_here = stable[nearest].item()
        slope_here = slopes[nearest].abs().item()
        print(f" At h_crit : stability={stable_here}, |dM_dh|={slope_here:.4f}")
        print()
        if stable_here:
            print(" OK: IFT gradient is well-conditioned at h_crit.")
        else:
            print(" WARNING: h_crit_approx falls in an UNSTABLE region.")
            print(" IFT gradient may be ill-conditioned at h_crit.")
            print(" Consider: safe_backward=True, wider h_grid, or larger n.")
    print(rule)
@@ -0,0 +1,128 @@
1
+ """
2
+ dcb.fft_kde — FFT-based KDE Mode Counter
3
+
4
+ Implements mode counting via FFT convolution of the histogram with a
5
+ Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
6
+ O(n × G) cost of the direct KDE approach and — crucially — requiring NO
7
+ subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
8
+ that affects the standard bisection path when n > brentq_n_max.
9
+
10
+ Round 18b: forward kernel only. The IFT backward is unchanged (still uses
11
+ the analytical chunked KDE derivatives on all n points).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import math
17
+
18
+ import torch
19
+ from torch import Tensor
20
+
21
+
22
def fft_mode_count(
    X: Tensor,
    h: float,
    G: int = 4096,
    pad_factor: int = 4,
) -> int:
    """Count modes of a Gaussian KDE of ``X`` using an FFT-based derivative.

    The data are binned into ``G`` histogram bins over a domain extended by
    three standard deviations on each side, zero-padded to ``pad_factor * G``
    samples, multiplied in the frequency domain by ``i*w*exp(-0.5*(w*h)**2)``
    (a Gaussian of bandwidth ``h`` followed by differentiation), transformed
    back, and trimmed to the first ``G`` samples. The result is the number
    of places where that derivative estimate goes from positive to negative.

    Parameters
    ----------
    X : Tensor
        1D data tensor. It is cast to float32 for the histogram; the work
        arrays are allocated on ``X.device``.
    h : float
        Bandwidth of the Gaussian kernel.
    G : int
        Number of histogram bins.
    pad_factor : int
        Zero-padding multiplier; the FFT length is ``pad_factor * G``.

    Returns
    -------
    int
        Number of positive-to-negative sign changes of the derivative
        estimate; 0 if the estimate is identically zero.

    Notes
    -----
    Runs under ``torch.no_grad()``. With fewer than two samples the sample
    standard deviation is undefined (NaN); such input is treated like the
    zero-spread case and uses a unit-width margin instead.
    """
    with torch.no_grad():
        # X.std() is NaN for a single sample, which would turn lo/hi into
        # NaN and break the histogram; treat n < 2 as zero spread.
        sigma = X.std().item() if X.numel() > 1 else 0.0
        if sigma == 0.0:
            sigma = 1.0  # degenerate case: no spread, fall back to a unit margin

        # Extend the domain 3 sigma beyond the data on both sides.
        lo = X.min().item() - 3 * sigma
        hi = X.max().item() + 3 * sigma
        data_range = hi - lo

        if data_range == 0.0:
            return 1  # defensive guard: zero-width domain counts as one mode

        # Histogram of the data on the extended domain.
        counts = torch.histc(X.float(), bins=G, min=lo, max=hi)

        # Zero-pad so the circular convolution does not wrap into the
        # retained first G samples.
        N = pad_factor * G
        counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
        counts_padded[:G] = counts

        C = torch.fft.rfft(counts_padded)

        # Angular frequency of rfft bin k for samples spaced bin_width apart.
        bin_width = data_range / G
        k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
        omega = 2 * math.pi * k / (N * bin_width)
        K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)

        # Back-transform and keep only the original G-sample domain.
        f_prime = torch.fft.irfft(C * K_deriv, n=N)[:G]

        # Drop exact zeros so a flat stretch does not hide a sign change.
        # NOTE(review): far from the data the true derivative can fall below
        # float32 round-off; any sign flips there would be counted as
        # modes — confirm whether a magnitude threshold is needed.
        nonzero_mask = f_prime != 0
        if not nonzero_mask.any():
            return 0

        s = f_prime[nonzero_mask]
        # A mode is a local maximum of f: f' goes from + to -.
        return int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
100
+
101
+
102
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
    """Pick a power-of-two FFT grid size from the domain width and a bandwidth.

    The target is ``16 * ceil(data_range / h_hi)`` bins, i.e. a bin width of
    at most ``h_hi / 16``. That target is rounded up to the next power of
    two, and the result is never smaller than ``G_min``.

    Parameters
    ----------
    data_range : float
        Width of the histogram domain.
    h_hi : float
        Bandwidth the grid is sized for.
        NOTE(review): the name suggests the upper bisection bracket, while
        a bin-width constraint of the form ``h > c * bin_width`` is tightest
        at the smallest bandwidth — confirm which bracket callers pass.
    G_min : int
        Lower bound on the returned grid size (default 4096).

    Returns
    -------
    int
        ``max(G_min, p)`` where ``p`` is the smallest power of two that is
        at least the target bin count (``p`` is 1 when the target is <= 1).
    """
    needed = 16 * math.ceil(data_range / h_hi)
    if needed <= 1:
        pow2 = 1
    else:
        # Smallest power of two >= needed, without an explicit loop.
        pow2 = 1 << (needed - 1).bit_length()
    return max(G_min, pow2)