diffcb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dcb/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ dcb — Differentiable Critical Bandwidth
3
+
4
+ A PyTorch package that makes Silverman's critical bandwidth test (1981) fully
5
+ differentiable via a smooth mode-counting integral and an Implicit Function
6
+ Theorem (IFT) backward pass. The primary public API is the
7
+ `DifferentiableCriticalBandwidth` class, which behaves as a standard
8
+ `torch.nn.Module` and can be used as a loss component or regularizer in any
9
+ gradient-based learning pipeline. Import as `from dcb import DCBLayer` for
10
+ the layer, or `from dcb.kde import gaussian_kde_grid` for lower-level KDE
11
+ utilities. Requires PyTorch >= 2.0, NumPy >= 1.24, and SciPy >= 1.10.
12
+ """
13
+
14
+ from dcb.layer import DCBLayer, DifferentiableCriticalBandwidth
15
+ from dcb.utils import anneal_eps_tau
16
+ from dcb.kde import soft_mode_count_cross, soft_mode_count
17
+
18
+ __all__ = [
19
+ "DCBLayer", "DifferentiableCriticalBandwidth",
20
+ "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
21
+ ]
22
+ __version__ = "0.1.0"
dcb/diagnostics.py ADDED
@@ -0,0 +1,163 @@
1
+ """
2
+ dcb.diagnostics — Gradient Stability Diagnostics for DCB
3
+
4
+ Provides `denom_profile()` which maps M̃(h) and ∂M̃/∂h over a bandwidth grid
5
+ to assess gradient conditioning before training. A stable IFT gradient at
6
+ h_crit requires |∂M̃/∂h| > 0 (non-zero denominator in the IFT formula).
7
+
8
+ Use case: call denom_profile() on your dataset before fitting DCBLayer to
9
+ verify that the IFT gradient is well-conditioned at h_crit. If
10
+ stability_mask=False at h_crit, consider using safe_backward=True or
11
+ widening the bandwidth search range.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import torch
17
+ from torch import Tensor
18
+
19
+ from dcb.kde import soft_mode_count_cross
20
+ from dcb.utils import make_grid
21
+
22
+
23
def denom_profile(
    X: Tensor,
    h_grid: Tensor,
    formula: str = 'cross',
    eps: float = 0.1,
    tau: float = 0.2,
    chunk_size: int = 50_000,
    guard: float = 0.01,
) -> dict:
    """Compute M̃(h) and ∂M̃/∂h over a bandwidth grid for gradient stability diagnosis.

    Evaluates the soft mode count M̃_cross at each bandwidth in h_grid, then
    estimates ∂M̃/∂h by finite differences on that grid. The stability_mask
    marks bandwidths where |∂M̃/∂h| exceeds ``guard``, i.e. where the IFT
    denominator is considered large enough for well-conditioned gradients.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h_grid : Tensor, shape (H,)
        Bandwidth grid to evaluate. Should cover the expected h_crit.
        ``h_crit_approx`` takes the *first* grid entry meeting the threshold,
        so it is only "the smallest h" when the grid is sorted ascending.
    formula : str
        Mode-count formula to use. Only 'cross' is supported.
    eps : float
        Forwarded to ``soft_mode_count_cross`` (sigmoid temperature for the
        zero-crossing detector). Default 0.1.
    tau : float
        Forwarded to ``soft_mode_count_cross`` (sigmoid temperature for the
        local-max selector). Default 0.2.
    chunk_size : int
        Currently unused by this function; accepted for API consistency.
    guard : float
        Threshold for stability_mask: True where |dM_dh| > guard. Default 0.01.

    Returns
    -------
    dict with keys:
        'h_grid'         : Tensor (H,)  — the input bandwidth grid (same object)
        'M_tilde'        : Tensor (H,)  — soft mode count at each h
        'dM_dh'          : Tensor (H,)  — central differences in the interior,
                           one-sided differences at the two ends; NaN when
                           H == 1 (a derivative cannot be formed from one point)
        'stability_mask' : BoolTensor (H,) — True where |dM_dh| > guard
        'h_crit_approx'  : float — first h in the grid where M̃ ≤ 1.5;
                           float('nan') if no such entry exists

    Raises
    ------
    ValueError
        If ``formula`` is anything other than 'cross'.

    Notes
    -----
    All computation runs under torch.no_grad() — this function is diagnostic
    only and does not build a computation graph.
    """
    if formula != 'cross':
        raise ValueError(f"Only formula='cross' is supported; got {formula!r}")

    H = h_grid.shape[0]
    grid = make_grid(X, G=512)

    with torch.no_grad():
        M_tilde = torch.zeros(H, dtype=X.dtype, device=X.device)
        for i in range(H):
            M_tilde[i] = soft_mode_count_cross(X, h_grid[i].item(), grid, eps, tau)

        dM_dh = torch.zeros(H, dtype=X.dtype, device=X.device)
        if H == 1:
            # A single bandwidth has no neighbour to difference against.
            # NaN compares False against `guard`, so the point is reported
            # as not stable rather than crashing on M_tilde[1].
            dM_dh.fill_(float('nan'))
        elif H >= 2:
            h = h_grid.to(device=X.device, dtype=X.dtype)
            # Forward difference at the first point, backward at the last.
            dM_dh[0] = (M_tilde[1] - M_tilde[0]) / (h[1] - h[0])
            dM_dh[-1] = (M_tilde[-1] - M_tilde[-2]) / (h[-1] - h[-2])
            # Central differences everywhere else (empty slice when H == 2).
            dM_dh[1:-1] = (M_tilde[2:] - M_tilde[:-2]) / (h[2:] - h[:-2])

        stability_mask = dM_dh.abs() > guard

        # h_crit_approx: first grid entry where M̃ ≤ 1.5
        # (threshold for target_modes=1).
        below_threshold = (M_tilde <= 1.5).nonzero(as_tuple=False)
        if below_threshold.numel() > 0:
            first_idx = below_threshold[0].item()
            h_crit_approx = h_grid[first_idx].item()
        else:
            h_crit_approx = float('nan')

    return {
        'h_grid': h_grid,
        'M_tilde': M_tilde,
        'dM_dh': dM_dh,
        'stability_mask': stability_mask,
        'h_crit_approx': h_crit_approx,
    }
119
+
120
+
121
def print_stability_report(profile: dict) -> None:
    """Print a human-readable stability report from denom_profile output.

    Parameters
    ----------
    profile : dict
        Output from `denom_profile()`. The keys read are 'h_grid',
        'M_tilde', 'dM_dh', 'stability_mask' and 'h_crit_approx'.

    Notes
    -----
    An empty profile (H == 0) prints the banner and a one-line notice instead
    of the statistics, since min/max and the stable-point percentage are
    undefined for zero grid points.
    """
    h_grid = profile['h_grid']
    M_tilde = profile['M_tilde']
    dM_dh = profile['dM_dh']
    stability_mask = profile['stability_mask']
    h_crit_approx = profile['h_crit_approx']

    H = h_grid.shape[0]
    rule = "=" * 60

    print(rule)
    print("DCB Gradient Stability Report")
    print(rule)

    if H == 0:
        # Nothing to summarise; avoids dividing by H and reducing empty tensors.
        print(" h_grid is empty (H=0): nothing to report.")
        print(rule)
        return

    n_stable = stability_mask.sum().item()
    pct_stable = 100.0 * n_stable / H
    # NaN is the only float that is not equal to itself.
    has_h_crit = h_crit_approx == h_crit_approx

    print(f" h_grid range : [{h_grid.min().item():.4f}, {h_grid.max().item():.4f}] (H={H})")
    print(f" M_tilde range : [{M_tilde.min().item():.4f}, {M_tilde.max().item():.4f}]")
    print(f" dM_dh range : [{dM_dh.min().item():.4f}, {dM_dh.max().item():.4f}]")
    if has_h_crit:
        print(f" h_crit_approx : {h_crit_approx:.4f}")
    else:
        print(" h_crit_approx : NaN (M_tilde never <= 1.5 in grid)")
    print(f" Stable points : {n_stable}/{H} ({pct_stable:.1f}%)")

    if has_h_crit:
        # Locate the grid point closest to h_crit_approx.
        idx = (h_grid - h_crit_approx).abs().argmin().item()
        stable_at_hcrit = stability_mask[idx].item()
        dM_at_hcrit = dM_dh[idx].abs().item()
        print(f" At h_crit : stability={stable_at_hcrit}, |dM_dh|={dM_at_hcrit:.4f}")
        print()
        if stable_at_hcrit:
            print(" OK: IFT gradient is well-conditioned at h_crit.")
        else:
            print(" WARNING: h_crit_approx falls in an UNSTABLE region.")
            print(" IFT gradient may be ill-conditioned at h_crit.")
            print(" Consider: safe_backward=True, wider h_grid, or larger n.")
    print(rule)
dcb/fft_kde.py ADDED
@@ -0,0 +1,128 @@
1
+ """
2
+ dcb.fft_kde — FFT-based KDE Mode Counter
3
+
4
+ Implements mode counting via FFT convolution of the histogram with a
5
+ Gaussian derivative kernel. Complexity is O(n + G log G), avoiding the
6
+ O(n × G) cost of the direct KDE approach and — crucially — requiring NO
7
+ subsampling. This eliminates the (brentq_n_max / n)^{-1/5} upward bias
8
+ that affects the standard bisection path when n > brentq_n_max.
9
+
10
+ Round 18b: forward kernel only. The IFT backward is unchanged (still uses
11
+ the analytical chunked KDE derivatives on all n points).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import math
17
+
18
+ import torch
19
+ from torch import Tensor
20
+
21
+
22
def fft_mode_count(
    X: Tensor,
    h: float,
    G: int = 4096,
    pad_factor: int = 4,
) -> int:
    """Count KDE modes via FFT convolution of a histogram with a Gaussian derivative.

    Bins X into G histogram bins, zero-pads to pad_factor*G, multiplies the
    spectrum by the Gaussian derivative kernel iω·exp(−½(ωh)²), back-transforms,
    and counts positive-to-negative sign changes of the resulting f' estimate
    on the original G bins.

    Parameters
    ----------
    X : Tensor, shape (n,)
        1D data tensor. Intermediate tensors are created on ``X.device``.
    h : float
        Bandwidth for the Gaussian kernel.
    G : int
        Number of histogram bins. Should satisfy h > 8 * (data_range / G) for
        reliable derivative estimation. Use `adaptive_fft_G` to choose G
        automatically before bisection.
    pad_factor : int
        Zero-padding multiplier (default 4). Values ≥ 2 are needed so that
        the circular convolution does not wrap back onto the data bins;
        this is documented guidance and is not enforced here.

    Returns
    -------
    int
        Number of downward zero-crossings of the f' estimate. Returns 0 for
        an empty X or when the estimate is identically zero.

    Notes
    -----
    The whole computation runs under torch.no_grad() in float32.

    NOTE(review): where the true f' is far below float32 FFT round-off (flat
    tails many bandwidths away from any data), the sign of the estimate is
    presumably dominated by numerical noise and could add spurious crossings.
    Confirm against callers whether a magnitude threshold is needed.
    """
    with torch.no_grad():
        n_points = X.numel()
        if n_points == 0:
            return 0  # no data, hence no modes

        # Domain: extend 3σ beyond the data range to avoid boundary effects.
        # std() of a single element is NaN, so only call it for n >= 2.
        sigma = X.std().item() if n_points > 1 else 0.0
        if not math.isfinite(sigma) or sigma <= 0.0:
            sigma = 1.0  # degenerate case: one point, or all points identical
        lo = X.min().item() - 3 * sigma
        hi = X.max().item() + 3 * sigma
        data_range = hi - lo

        if data_range == 0.0:
            # Defensive: the padding above normally makes this unreachable.
            return 1

        # Histogram of the data over [lo, hi]
        counts = torch.histc(X.float(), bins=G, min=lo, max=hi)

        # Zero-pad to pad_factor*G so the circular convolution has room to
        # spill into empty bins instead of wrapping onto the data.
        N = pad_factor * G
        counts_padded = torch.zeros(N, dtype=torch.float32, device=X.device)
        counts_padded[:G] = counts

        # FFT of histogram
        C = torch.fft.rfft(counts_padded)

        # Derivative kernel in frequency domain: iω * exp(-0.5*(ω*h)²)
        # ω_k = 2π*k / (N * bin_width), bin_width = data_range / G
        bin_width = data_range / G
        k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float32)
        omega = 2 * math.pi * k / (N * bin_width)
        K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)

        # Convolve and back-transform
        f_prime_padded = torch.fft.irfft(C * K_deriv, n=N)

        # Trim to original G grid (discard zero-padded tail)
        f_prime = f_prime_padded[:G]

        # A mode is a local max of f, i.e. f' crosses zero from + to -.
        # Exact zeros are dropped first so a flat sample between a positive
        # and a negative value still counts as one crossing.
        nonzero_mask = f_prime != 0
        if not nonzero_mask.any():
            return 0

        s = f_prime[nonzero_mask]
        transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
        return transitions
100
+
101
+
102
def adaptive_fft_G(data_range: float, h_hi: float, G_min: int = 4096) -> int:
    """Choose FFT grid size G so that the derivative kernel is well-resolved.

    Resolution requires h > 8 * bin_width = 8 * data_range / G, equivalently
    G > 8 * data_range / h. This function uses a factor of 16 for safety
    margin (computed as 16 * ceil(data_range / h_hi)), takes the larger of
    that and G_min, and rounds up to the next power of 2 for efficient FFT.

    Parameters
    ----------
    data_range : float
        hi - lo of the data domain (typically X.max() - X.min() + 6σ).
    h_hi : float
        Bandwidth used to size the grid; must be positive.
        NOTE(review): this was described as the upper bisection bracket and
        also as the "smallest h needing resolution". The constraint
        h > 8 * bin_width is tightest at the smallest bandwidth evaluated,
        so confirm which bracket callers are expected to pass.
    G_min : int
        Lower bound on the returned G (default 4096). If it is not itself a
        power of 2, the result is the next power of 2 above it.

    Returns
    -------
    int
        Grid size G, a power of 2, at least G_min.

    Raises
    ------
    ValueError
        If h_hi is not a positive number.
    """
    if not h_hi > 0:
        # Also rejects NaN; previously 0 raised ZeroDivisionError and a
        # negative value silently fell back to G_min.
        raise ValueError(f"h_hi must be positive; got {h_hi!r}")

    needed = 16 * math.ceil(data_range / h_hi)
    target = max(needed, G_min, 1)

    # Round up to the next power of 2 (applied after the G_min floor so the
    # result is a power of 2 even when G_min is not).
    p = 1
    while p < target:
        p <<= 1
    return p