diffcb 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dcb/__init__.py +22 -0
- dcb/diagnostics.py +163 -0
- dcb/fft_kde.py +128 -0
- dcb/kde.py +394 -0
- dcb/layer.py +231 -0
- dcb/solver.py +604 -0
- dcb/utils.py +183 -0
- diffcb-0.1.0.dist-info/METADATA +148 -0
- diffcb-0.1.0.dist-info/RECORD +11 -0
- diffcb-0.1.0.dist-info/WHEEL +4 -0
- diffcb-0.1.0.dist-info/licenses/LICENSE +21 -0
dcb/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb — Differentiable Critical Bandwidth
|
|
3
|
+
|
|
4
|
+
A PyTorch package that makes Silverman's critical bandwidth test (1981) fully
|
|
5
|
+
differentiable via a smooth mode-counting integral and an Implicit Function
|
|
6
|
+
Theorem (IFT) backward pass. The primary public API is the
|
|
7
|
+
`DifferentiableCriticalBandwidth` class, which behaves as a standard
|
|
8
|
+
`torch.nn.Module` and can be used as a loss component or regularizer in any
|
|
9
|
+
gradient-based learning pipeline. Import as `from dcb import DCBLayer` for
|
|
10
|
+
the layer, or `from dcb.kde import gaussian_kde_grid` for lower-level KDE
|
|
11
|
+
utilities. Requires PyTorch >= 2.0, NumPy >= 1.24, and SciPy >= 1.10.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from dcb.layer import DCBLayer, DifferentiableCriticalBandwidth
|
|
15
|
+
from dcb.utils import anneal_eps_tau
|
|
16
|
+
from dcb.kde import soft_mode_count_cross, soft_mode_count
|
|
17
|
+
|
|
18
|
+
__all__ = [
|
|
19
|
+
"DCBLayer", "DifferentiableCriticalBandwidth",
|
|
20
|
+
"anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
|
|
21
|
+
]
|
|
22
|
+
__version__ = "0.1.0"
|
dcb/diagnostics.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.diagnostics — Gradient Stability Diagnostics for DCB
|
|
3
|
+
|
|
4
|
+
Provides `denom_profile()` which maps M̃(h) and ∂M̃/∂h over a bandwidth grid
|
|
5
|
+
to assess gradient conditioning before training. A stable IFT gradient at
|
|
6
|
+
h_crit requires |∂M̃/∂h| > 0 (non-zero denominator in the IFT formula).
|
|
7
|
+
|
|
8
|
+
Use case: call denom_profile() on your dataset before fitting DCBLayer to
|
|
9
|
+
verify that the IFT gradient is well-conditioned at h_crit. If
|
|
10
|
+
stability_mask=False at h_crit, consider using safe_backward=True or
|
|
11
|
+
widening the bandwidth search range.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
from torch import Tensor
|
|
18
|
+
|
|
19
|
+
from dcb.kde import soft_mode_count_cross
|
|
20
|
+
from dcb.utils import make_grid
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def denom_profile(
    X: Tensor,
    h_grid: Tensor,
    formula: str = 'cross',
    eps: float = 0.1,
    tau: float = 0.2,
    chunk_size: int = 50_000,
    guard: float = 0.01,
) -> dict:
    """Compute M̃(h) and ∂M̃/∂h over a bandwidth grid for gradient stability diagnosis.

    Evaluates the soft mode count M̃_cross at each bandwidth in ``h_grid``
    and differentiates it with respect to h by finite differences. The
    ``stability_mask`` marks bandwidths where |∂M̃/∂h| exceeds ``guard``,
    i.e. where the IFT denominator is large enough for a well-conditioned
    gradient.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h_grid : Tensor, shape (H,)
        Bandwidth grid to evaluate. Should cover the expected h_crit.
        Any ordering is accepted, but the finite differences are taken
        between neighbouring grid entries, so a monotone grid is expected.
    formula : str
        Mode-count formula to use. Only 'cross' is supported.
    eps : float
        Passed to ``soft_mode_count_cross`` (sigmoid temperature for the
        zero-crossing detector). Default 0.1.
    tau : float
        Passed to ``soft_mode_count_cross`` (sigmoid temperature for the
        local-max selector). Default 0.2.
    chunk_size : int
        Unused by this function; kept for API consistency.
    guard : float
        Threshold for ``stability_mask``: True where |dM_dh| > guard.
        Default 0.01.

    Returns
    -------
    dict with keys:
        'h_grid'         : Tensor (H,) — the input bandwidth grid, unchanged
        'M_tilde'        : Tensor (H,) — soft mode count at each h
        'dM_dh'          : Tensor (H,) — central differences in the interior,
                           one-sided differences at both ends; NaN when
                           H == 1 because no difference can be formed
        'stability_mask' : BoolTensor (H,) — True where |dM_dh| > guard
        'h_crit_approx'  : float — smallest h in the grid with M̃ ≤ 1.5,
                           or float('nan') if no such h exists

    Raises
    ------
    ValueError
        If ``formula`` is not 'cross'.

    Notes
    -----
    All computation runs under ``torch.no_grad()``; the function is
    diagnostic only and does not build a computation graph.
    """
    if formula != 'cross':
        raise ValueError(f"Only formula='cross' is supported; got {formula!r}")

    H = h_grid.shape[0]
    # Evaluation grid shared by every bandwidth (G=512; see dcb.utils.make_grid).
    grid = make_grid(X, G=512)

    with torch.no_grad():
        M_tilde = torch.zeros(H, dtype=X.dtype, device=X.device)
        for i, h_val in enumerate(h_grid.tolist()):
            M_tilde[i] = soft_mode_count_cross(X, h_val, grid, eps, tau)

        dM_dh = torch.zeros(H, dtype=X.dtype, device=X.device)
        if H >= 2:
            # Bring the bandwidths next to M_tilde so sliced arithmetic works
            # even when h_grid lives on another device.
            h = h_grid.detach().to(X.device)
            # One-sided differences at the two ends of the grid.
            dM_dh[0] = (M_tilde[1] - M_tilde[0]) / (h[1] - h[0])
            dM_dh[-1] = (M_tilde[-1] - M_tilde[-2]) / (h[-1] - h[-2])
            if H > 2:
                # Central differences for every interior point at once.
                dM_dh[1:-1] = (M_tilde[2:] - M_tilde[:-2]) / (h[2:] - h[:-2])
        else:
            # A single bandwidth (or none) has no neighbour to difference
            # against: report the slope as undefined instead of indexing
            # past the end of the grid.
            dM_dh.fill_(float('nan'))

        stability_mask = dM_dh.abs() > guard

        # Smallest bandwidth whose soft mode count is at or below 1.5
        # (threshold for target_modes=1). Taking the minimum rather than the
        # first hit keeps the result correct for descending/unsorted grids.
        below_threshold = M_tilde <= 1.5
        if bool(below_threshold.any()):
            candidates = h_grid.detach().to(M_tilde.device)[below_threshold]
            h_crit_approx = candidates.min().item()
        else:
            h_crit_approx = float('nan')

    return {
        'h_grid': h_grid,
        'M_tilde': M_tilde,
        'dM_dh': dM_dh,
        'stability_mask': stability_mask,
        'h_crit_approx': h_crit_approx,
    }
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def print_stability_report(profile: dict) -> None:
    """Print a human-readable stability report from denom_profile output.

    Parameters
    ----------
    profile : dict
        Output from `denom_profile()`. The keys 'h_grid', 'M_tilde',
        'dM_dh', 'stability_mask' and 'h_crit_approx' are read.

    Notes
    -----
    Writes to standard output only and returns nothing. An empty profile
    (zero bandwidths) produces a short notice instead of raising.
    """
    h_grid = profile['h_grid']
    M_tilde = profile['M_tilde']
    dM_dh = profile['dM_dh']
    stability_mask = profile['stability_mask']
    h_crit_approx = profile['h_crit_approx']

    H = h_grid.shape[0]
    rule = "=" * 60

    print(rule)
    print("DCB Gradient Stability Report")
    print(rule)

    if H == 0:
        # min()/max() of an empty tensor and the percentage below would both
        # fail; there is nothing to summarise, so close the frame and stop.
        print(" Empty bandwidth grid: nothing to report.")
        print(rule)
        return

    n_stable = stability_mask.sum().item()
    pct_stable = 100.0 * n_stable / H
    # NaN is the only float that compares unequal to itself.
    has_h_crit = h_crit_approx == h_crit_approx

    print(f" h_grid range : [{h_grid.min().item():.4f}, {h_grid.max().item():.4f}] (H={H})")
    print(f" M_tilde range : [{M_tilde.min().item():.4f}, {M_tilde.max().item():.4f}]")
    print(f" dM_dh range : [{dM_dh.min().item():.4f}, {dM_dh.max().item():.4f}]")
    if has_h_crit:
        print(f" h_crit_approx : {h_crit_approx:.4f}")
    else:
        print(" h_crit_approx : NaN (M_tilde never <= 1.5 in grid)")
    print(f" Stable points : {n_stable}/{H} ({pct_stable:.1f}%)")

    if has_h_crit:
        # Locate the grid entry closest to the reported h_crit.
        idx = (h_grid - h_crit_approx).abs().argmin().item()
        stable_at_hcrit = stability_mask[idx].item()
        dM_at_hcrit = dM_dh[idx].abs().item()
        print(f" At h_crit : stability={stable_at_hcrit}, |dM_dh|={dM_at_hcrit:.4f}")
        print()
        if stable_at_hcrit:
            print(" OK: IFT gradient is well-conditioned at h_crit.")
        else:
            print(" WARNING: h_crit_approx falls in an UNSTABLE region.")
            print(" IFT gradient may be ill-conditioned at h_crit.")
            print(" Consider: safe_backward=True, wider h_grid, or larger n.")
    print(rule)
|
dcb/fft_kde.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.fft_kde — FFT-based KDE Mode Counter
|
|
3
|
+
|
|
4
|
+
Implements mode counting via FFT convolution of the histogram with a
|
|
5
|
+
Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
|
|
6
|
+
O(n × G) cost of the direct KDE approach and — crucially — requiring NO
|
|
7
|
+
subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
|
|
8
|
+
that affects the standard bisection path when n > brentq_n_max.
|
|
9
|
+
|
|
10
|
+
Round 18b: forward kernel only. The IFT backward is unchanged (still uses
|
|
11
|
+
the analytical chunked KDE derivatives on all n points).
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def fft_mode_count(
    X: Tensor,
    h: float,
    G: int = 4096,
    pad_factor: int = 4,
    rel_tol: float = 0.0,
) -> int:
    """Count KDE modes via FFT convolution — O(n + G log G), no subsampling.

    Bins X into G histogram bins, zero-pads to pad_factor*G, multiplies the
    spectrum by the Gaussian derivative kernel iω·exp(−½(ωh)²),
    back-transforms, and counts positive-to-negative sign changes of the
    resulting f' estimate on the first G samples.

    Parameters
    ----------
    X : Tensor, shape (n,)
        1D data tensor (may be on CPU or CUDA). Integer tensors are
        converted to float before the statistics are taken.
    h : float
        Bandwidth for the Gaussian kernel. Must be positive.
    G : int
        Number of histogram bins. Must satisfy h > 8 * (data_range / G) for
        reliable derivative estimation. Use `adaptive_fft_G` to choose G
        automatically before bisection.
    pad_factor : int
        Zero-padding multiplier (default 4). Values ≥ 2 are needed for
        circular-wrap correctness; 4 is recommended at the largest h
        encountered. Only pad_factor ≥ 1 is enforced here.
    rel_tol : float
        Relative noise floor for the sign test. Samples with
        |f'| ≤ rel_tol · max|f'| are treated as flat and skipped, exactly
        like exact zeros. The default 0.0 skips exact zeros only.
        NOTE(review): float32 round-off may give near-zero tail samples an
        arbitrary sign; a small positive value (e.g. 1e-4) is presumably
        more robust there — confirm against the bisection caller.

    Returns
    -------
    int
        Number of KDE modes (downward zero-crossings of f'). Degenerate
        data (a single point, or all points identical) always yields 1.

    Raises
    ------
    ValueError
        If h is not positive, G or pad_factor is below 1, rel_tol is
        negative, X is empty, or X contains non-finite values.
    """
    h = float(h)
    if not h > 0.0:
        raise ValueError(f"h must be a positive bandwidth; got {h!r}")
    if G < 1:
        raise ValueError(f"G must be at least 1; got {G!r}")
    if pad_factor < 1:
        raise ValueError(f"pad_factor must be at least 1; got {pad_factor!r}")
    if rel_tol < 0.0:
        raise ValueError(f"rel_tol must be non-negative; got {rel_tol!r}")
    if X.numel() == 0:
        raise ValueError("X must contain at least one data point")

    with torch.no_grad():
        # std() is undefined for integer tensors, so take statistics on a
        # floating view; floating inputs keep their own precision.
        data = X if X.is_floating_point() else X.float()
        x_min = data.min().item()
        x_max = data.max().item()

        # A single point has an undefined (NaN) sample std, and identical
        # points have zero spread. Either way the KDE is one Gaussian bump.
        if data.numel() == 1 or x_min == x_max:
            return 1

        sigma = data.std().item()
        if sigma == 0.0:
            sigma = 1.0  # spread underflowed: fall back to a unit margin

        # Domain: extend 3σ beyond the data range to avoid boundary effects.
        lo = x_min - 3 * sigma
        hi = x_max + 3 * sigma
        if not (math.isfinite(lo) and math.isfinite(hi)):
            raise ValueError("X must contain only finite values")
        data_range = hi - lo

        # Histogram of the data on the extended domain.
        counts = torch.histc(X.float(), bins=G, min=lo, max=hi)

        # Zero-pad so the circular convolution does not wrap onto the data.
        N = pad_factor * G
        counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
        counts_padded[:G] = counts

        C = torch.fft.rfft(counts_padded)

        # Derivative kernel in the frequency domain: iω * exp(-0.5*(ω*h)²),
        # with ω_k = 2π*k / (N * bin_width) and bin_width = data_range / G.
        bin_width = data_range / G
        k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
        omega = 2 * math.pi * k / (N * bin_width)
        K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)

        # Convolve, back-transform, and drop the zero-padded tail.
        f_prime = torch.fft.irfft(C * K_deriv, n=N)[:G]

        # Discard flat samples so the last informative sign carries forward
        # across them.
        if rel_tol > 0.0:
            magnitude = f_prime.abs()
            informative = magnitude > rel_tol * magnitude.max()
        else:
            informative = f_prime != 0
        if not informative.any():
            return 0

        s = f_prime[informative]
        # A mode is a local maximum of f: f' goes from positive to negative.
        return int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
    """Choose FFT grid size G so that the derivative kernel is well-resolved.

    The resolution requirement is h > 8 * bin_width = 8 * data_range / G,
    equivalently G > 8 * data_range / h. A factor of 16 is used for safety
    margin, and the result is rounded up to the next power of 2 for
    efficient FFT.

    Parameters
    ----------
    data_range : float
        hi - lo of the data domain (typically X.max() - X.min() + 6σ).
    h_hi : float
        Bandwidth the grid has to resolve; must be positive.
        NOTE(review): callers are described as passing the upper bracket of
        the bisection, but the requirement above is tightest at the
        smallest bandwidth evaluated — confirm which bracket is intended.
    G_min : int
        Lower bound on the returned G (default 4096). If it is not a power
        of 2 it is rounded up to one, so the result is always a power of 2.

    Returns
    -------
    int
        Grid size G, a power of 2, at least G_min.

    Raises
    ------
    ValueError
        If h_hi is not a positive number.
    """
    if not h_hi > 0:
        raise ValueError(f"h_hi must be a positive bandwidth; got {h_hi!r}")

    needed = 16 * math.ceil(data_range / h_hi)
    # Fold the floor in before rounding so the power-of-two guarantee holds
    # for any G_min, not only for values that already are powers of two.
    target = max(int(math.ceil(G_min)), needed, 1)
    # Smallest power of two >= target.
    return 1 << (target - 1).bit_length()
|