diffcb 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dcb/__init__.py +22 -0
- dcb/diagnostics.py +163 -0
- dcb/fft_kde.py +128 -0
- dcb/kde.py +394 -0
- dcb/layer.py +231 -0
- dcb/solver.py +604 -0
- dcb/utils.py +183 -0
- diffcb-0.1.0.dist-info/METADATA +148 -0
- diffcb-0.1.0.dist-info/RECORD +11 -0
- diffcb-0.1.0.dist-info/WHEEL +4 -0
- diffcb-0.1.0.dist-info/licenses/LICENSE +21 -0
dcb/kde.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.kde — Differentiable Gaussian KDE and Derivative Utilities
|
|
3
|
+
|
|
4
|
+
Two computation paths share the same formulas but differ in memory layout:
|
|
5
|
+
|
|
6
|
+
Dense path (n × G matrix, small n)
|
|
7
|
+
gaussian_kde_grid, kde_derivatives, soft_mode_count_cross
|
|
8
|
+
— simple, autograd-friendly, O(n × G) memory.
|
|
9
|
+
|
|
10
|
+
Chunked path (processes X in blocks, large n / GPU)
|
|
11
|
+
kde_derivatives_chunked, soft_mode_count_cross_from_derivs
|
|
12
|
+
— O(chunk × G) peak memory, same arithmetic.
|
|
13
|
+
kde_derivatives_chunked returns (f, f', f'') on the grid by
|
|
14
|
+
accumulating chunk contributions; soft_mode_count_cross_from_derivs
|
|
15
|
+
evaluates M̃_cross given pre-computed derivatives, enabling the
|
|
16
|
+
analytical IFT gradient without materialising the full n × G graph.
|
|
17
|
+
|
|
18
|
+
All functions operate on 1D input tensors (univariate DCB).
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import math
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
from torch import Tensor
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def gaussian_kde_grid(X: Tensor, h: float, grid: Tensor) -> Tensor:
    """Return the Gaussian kernel density estimate f_h at every grid point.

    f_h(x) = (1/n) · Σ_i φ((x − X_i) / h) / h, where φ is the standard
    normal pdf.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Evaluation points.

    Returns
    -------
    Tensor, shape (G,)
        f_h evaluated at each grid point (differentiable w.r.t. X).
    """
    normaliser = math.sqrt(2 * math.pi) * h
    # offsets[i, j] = grid[j] - X[i]; broadcasting yields an (n, G) matrix.
    offsets = grid[None, :] - X[:, None]
    weights = torch.exp(-0.5 * (offsets / h) ** 2) / normaliser
    # Average the per-sample kernels over the sample axis.
    return weights.mean(dim=0)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def kde_derivatives(
    X: Tensor, h: float, grid: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    """Return the Gaussian KDE and its first two analytical derivatives.

    With d_ij = grid_j − X_i and K_ij = φ(d_ij / h) / h:

        f   = mean_i K_ij
        f'  = mean_i (−d_ij / h²) · K_ij
        f'' = mean_i (d_ij² / h⁴ − 1 / h²) · K_ij

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Evaluation points.

    Returns
    -------
    (f, f_prime, f_double_prime) : tuple of Tensor, each shape (G,)
        KDE value, first derivative and second derivative at each grid point.
        All outputs support autograd. Memory is O(n × G) (dense path).
    """
    # offsets[i, j] = grid[j] - X[i], shape (n, G)
    offsets = grid[None, :] - X[:, None]
    kernel = torch.exp(-0.5 * (offsets / h) ** 2) / (math.sqrt(2 * math.pi) * h)

    per_sample_terms = (
        kernel,                                                     # f
        -offsets / h ** 2 * kernel,                                 # f'
        ((offsets ** 2 / h ** 4) - (1.0 / h ** 2)) * kernel,        # f''
    )
    # Reduce every term over the sample axis in one pass.
    return tuple(term.mean(0) for term in per_sample_terms)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def kde_derivatives_chunked(
    X: Tensor, h: float, grid: Tensor, chunk_size: int = 50_000
) -> tuple[Tensor, Tensor, Tensor]:
    """Chunked KDE derivatives — O(chunk_size × G) peak memory.

    Same formulas as `kde_derivatives`, but contributions from X are summed
    in blocks of `chunk_size` and divided by n at the end. The peak
    allocation is therefore O(chunk_size × G) instead of O(n × G).

    Autograd: the zero-initialised accumulators are updated in place with
    chunk sums that carry X's autograd history. The returned tensors are
    differentiable w.r.t. X.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points. An empty X yields NaN outputs (0 / 0), the same
        as the dense path's mean over an empty axis.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Evaluation points.
    chunk_size : int
        Number of data points per block; must be positive. The default of
        50 000 uses ≈100 MB at G = 512, float32. Increase it for better GPU
        utilisation; decrease it on GPU OOM.

    Returns
    -------
    (f, f_prime, f_double_prime) : tuple of Tensor, each shape (G,)

    Raises
    ------
    ValueError
        If `chunk_size` is not positive. Previously a negative value made the
        chunk loop empty and silently returned all-zero derivatives.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    n, G = X.shape[0], grid.shape[0]
    norm = math.sqrt(2 * math.pi) * h  # loop-invariant kernel normaliser
    f = torch.zeros(G, dtype=X.dtype, device=X.device)
    fp = torch.zeros_like(f)
    fpp = torch.zeros_like(f)
    for start in range(0, n, chunk_size):
        Xc = X[start : start + chunk_size]
        diff = grid.unsqueeze(0) - Xc.unsqueeze(1)  # (c, G)
        K = torch.exp(-0.5 * (diff / h) ** 2) / norm
        # Accumulate un-normalised sums; the division by n happens once below.
        f.add_(K.sum(0))
        fp.add_((-diff / h ** 2 * K).sum(0))
        fpp.add_((((diff ** 2 / h ** 4) - (1.0 / h ** 2)) * K).sum(0))
    f.div_(n)
    fp.div_(n)
    fpp.div_(n)
    return f, fp, fpp
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def soft_mode_count_cross_from_derivs(
    f: Tensor,
    f_prime: Tensor,
    f_double_prime: Tensor,
    grid: Tensor,
    eps: float,
    tau: float,
) -> Tensor:
    """Evaluate M̃_cross from already-computed KDE derivatives.

    This is the crossing-count soft mode count used by `soft_mode_count_cross`,
    but it takes (f, f', f'') on the grid instead of recomputing them from X.
    In the large-n analytical IFT path, call `kde_derivatives_chunked` once,
    then pass its outputs here with `requires_grad=True`. Autograd then gives
    ∂M̃/∂f' and ∂M̃/∂f'' while only G-dimensional tensors stay in the graph.

    Calibration: with s = median |f''| (clamped at 1e-10) and Δx the grid step,
    ε_eff = eps · s · Δx and τ_eff = tau · s. Only Δx is detached, so s keeps
    its gradient.

    Parameters
    ----------
    f, f_prime, f_double_prime : Tensor, each shape (G,)
        KDE value and its first two derivatives on `grid`.
    grid : Tensor, shape (G,)
        Uniform evaluation grid.
    eps, tau : float
        Dimensionless sharpness of the crossing detector and the local-max
        selector.

    Returns
    -------
    Tensor, shape ()
    """
    step = ((grid[-1] - grid[0]) / (grid.shape[0] - 1)).detach()
    curvature = f_double_prime.abs().median().clamp(min=1e-10)
    eps_eff = eps * curvature * step
    tau_eff = tau * curvature

    # Scaled f' at the left and right end of every adjacent grid pair.
    z_left = f_prime[:-1] / eps_eff
    z_right = f_prime[1:] / eps_eff
    # Soft indicator that f' changes sign inside the pair (either direction).
    crossing = (
        torch.sigmoid(z_left) * torch.sigmoid(-z_right)
        + torch.sigmoid(-z_left) * torch.sigmoid(z_right)
    )
    # Keep only concave pairs (f'' < 0), i.e. maxima rather than minima.
    curvature_mid = (f_double_prime[:-1] + f_double_prime[1:]) / 2
    local_max = torch.sigmoid(-curvature_mid / tau_eff)

    # Suppress tail pairs whose mean density is below 1% of the peak.
    peak = f.max().clamp(min=1e-10)
    density_mid = (f[:-1] + f[1:]) / 2
    density_mask = torch.sigmoid((density_mid - 0.01 * peak) / (0.001 * peak))

    return (crossing * local_max * density_mask).sum()
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def soft_mode_count_cross(
    X: Tensor, h: float, grid: Tensor, eps: float, tau: float
) -> Tensor:
    """Compute the crossing-count soft mode count M̃_cross(h).

    Detects sign changes of f'_h between adjacent grid points that coincide
    with a local maximum (f'' < 0). Sigmoid products smoothly approximate the
    discrete sign-change indicator. Unlike the integral formula
    (`soft_mode_count`), M̃_cross is bounded. It converges to M(h) as ε → 0 on
    a dense grid, without the integral formula's divergence in the unimodal
    regime.

    Formula::

        M̃_cross = Σ_{j=0}^{G-2}
            [σ(f'_j/ε_eff)·σ(−f'_{j+1}/ε_eff) + σ(−f'_j/ε_eff)·σ(f'_{j+1}/ε_eff)]
            · σ(−f''_mid,j / τ_eff) · m_j

    Here f''_mid,j = (f''_j + f''_{j+1}) / 2, σ is the logistic sigmoid and
    m_j is the density mask below. The bracket fires on both + → − and
    − → + crossings. The σ(−f''_mid/τ_eff) factor keeps only local maxima, so
    the sum counts downward zero-crossings of f', i.e. the modes of f_h.

    Calibration of eps and tau:
    - A single crossing of f'_h happens over ~Δx. The value of f' just
      outside the crossing pair is O(|f''_h| · Δx).
    - With s = median_j |f''_j| (clamped at 1e-10), set ε_eff = eps · s · Δx
      and τ_eff = tau · s. Then σ(|f'_crossing| / ε_eff) ≈ σ(1/(2·eps)), which
      is well calibrated for any h.
    - Only Δx is detached. s keeps its gradient, so the full ∂M/∂X and ∂M/∂h
      paths flow through the self-normalising calibration; this is required
      for gradcheck to pass.

    Density mask: m_j = σ((f_mid,j − 0.01·f_peak) / (0.001·f_peak)), where
    f_mid,j = (f_j + f_{j+1}) / 2. In low-density tails f'_h ≈ 0 but
    fluctuates, which would give fractional contributions on every grid pair.
    The mask removes intervals whose mean density is below 1% of the peak.
    Both threshold and sharpness are relative to f_peak, so the mask scales
    automatically with the KDE amplitude at any bandwidth.

    The arithmetic lives in `soft_mode_count_cross_from_derivs`; this function
    only computes the derivatives densely and delegates, so the two entry
    points cannot drift apart.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Uniform evaluation grid.
    eps : float
        Dimensionless sigmoid temperature of the zero-crossing detector on f'
        (smaller = sharper, converges to the true M(h) faster).
    tau : float
        Dimensionless sigmoid temperature of the local-max selector on f''.

    Returns
    -------
    Tensor, shape ()
        Scalar soft mode count in [0, G−1], differentiable w.r.t. X.
    """
    f, f_prime, f_double_prime = kde_derivatives(X, h, grid)
    return soft_mode_count_cross_from_derivs(
        f, f_prime, f_double_prime, grid, eps, tau
    )
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def soft_mode_count(
    X: Tensor, h: float, grid: Tensor, eps: float, tau: float
) -> Tensor:
    """Compute the legacy integral-form soft mode count M̃(h).

    The number of modes of f_h is approximated by the Riemann sum

        M̃(h) = Δx · Σ_j δ_ε(f'_h(grid_j)) · σ_τ(−f''_h(grid_j)) · |f''_h(grid_j)|

    Here δ_ε is a Gaussian of width eps approximating the Dirac delta at
    f' = 0, and σ_τ is a sigmoid with temperature tau that selects local
    maxima (f'' < 0).

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Uniform evaluation grid.
    eps : float
        Width of the Gaussian delta approximation (sharpness of the
        zero-crossing detector on f').
    tau : float
        Sigmoid temperature (sharpness of the local-max selector on f'').

    Returns
    -------
    Tensor, shape ()
        Scalar soft mode count, differentiable w.r.t. X.
    """
    _, f_prime, f_double_prime = kde_derivatives(X, h, grid)
    step = (grid[-1] - grid[0]) / (grid.shape[0] - 1)

    # δ_ε(f'): normalised Gaussian bump centred on f' = 0.
    delta_eps = torch.exp(-0.5 * (f_prime / eps) ** 2) / (
        math.sqrt(2 * math.pi) * eps
    )
    # σ_τ(−f''): close to 1 where the KDE is concave (candidate maxima).
    maxima_gate = torch.sigmoid(-f_double_prime / tau)

    weighted = delta_eps * maxima_gate * f_double_prime.abs()
    return (weighted * step).sum()
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def kde_derivatives_batched(
    X: Tensor,
    h_batch: Tensor,
    grid: Tensor,
    chunk_size: int = 10_000,
) -> tuple[Tensor, Tensor, Tensor]:
    """Evaluate Gaussian KDE derivatives for B bandwidths simultaneously.

    Contributions from X are accumulated in chunks, so peak memory is
    O(chunk_size × B × G) instead of O(n × B × G). The formulas match
    `kde_derivatives`, with u = (grid − X_i) / h:
    f' ∝ −u/h · K and f'' ∝ (u² − 1)/h² · K.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points. An empty X yields NaN outputs (0 / 0).
    h_batch : Tensor, shape (B,)
        Candidate bandwidths. Also accepted: any tensor, Python scalar or
        sequence, which is flattened to (B,). Python scalars use torch's
        default dtype. A tensor already on X's device is used as-is, so its
        autograd history is preserved. A tensor on another device is moved
        to X's device.
    grid : Tensor, shape (G,)
        Evaluation points.
    chunk_size : int
        Number of data points per chunk; must be positive. Default 10 000.

    Returns
    -------
    (f, f_prime, f_double_prime) : tuple of Tensor, each shape (B, G)
        KDE value and its first two derivatives, one row per bandwidth.

    Raises
    ------
    ValueError
        If `chunk_size` is not positive. Previously a negative value made the
        chunk loop empty and silently returned all-zero derivatives.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    h_vec = torch.as_tensor(h_batch, device=X.device).reshape(-1)
    n, G, B = X.shape[0], grid.shape[0], h_vec.shape[0]
    f = torch.zeros(B, G, dtype=X.dtype, device=X.device)
    fp = torch.zeros_like(f)
    fpp = torch.zeros_like(f)

    # (B,) → (1, B, 1) for broadcasting against per-chunk offsets (c, 1, G).
    h = h_vec.reshape(1, B, 1)
    h2 = h * h
    norm = 1.0 / (math.sqrt(2 * math.pi) * h)  # (1, B, 1)

    for start in range(0, n, chunk_size):
        Xc = X[start : start + chunk_size]  # (c,)
        # diff[i, 0, j] = grid[j] - Xc[i] → (c, 1, G), broadcasts to (c, B, G)
        diff = (grid.unsqueeze(0) - Xc.unsqueeze(1)).unsqueeze(1)
        u = diff / h  # (c, B, G)
        K = norm * torch.exp(-0.5 * u * u)  # (c, B, G)
        f.add_(K.sum(0))  # (B, G)
        fp.add_((-u / h * K).sum(0))  # (B, G)
        fpp.add_((((u * u - 1.0) / h2) * K).sum(0))  # (B, G)

    f.div_(n)
    fp.div_(n)
    fpp.div_(n)
    return f, fp, fpp
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def soft_mode_count_cross_batched(
    f: Tensor,
    fp: Tensor,
    fpp: Tensor,
    grid: Tensor,
    eps: float,
    tau: float,
) -> Tensor:
    """Vectorised M̃_cross for B bandwidths at once.

    Row-wise equivalent of `soft_mode_count_cross_from_derivs`. Each row gets
    its own calibration: s_b = median_j |fpp[b, j]| (clamped at 1e-10),
    ε_eff,b = eps · s_b · Δx and τ_eff,b = tau · s_b.

    Parameters
    ----------
    f, fp, fpp : Tensor, each shape (B, G)
        KDE value and first two derivatives, one row per bandwidth.
    grid : Tensor, shape (G,)
        Uniform evaluation grid shared by all rows.
    eps, tau : float
        Dimensionless sharpness of the crossing detector and local-max
        selector.

    Returns
    -------
    Tensor, shape (B,)
        One soft mode count per bandwidth.
    """
    step = ((grid[-1] - grid[0]) / (grid.shape[0] - 1)).detach()

    # Per-row curvature scale; (B,) → (B, 1) so it broadcasts across pairs.
    curvature = fpp.abs().median(dim=1).values.clamp(min=1e-10)
    eps_col = (eps * curvature * step).unsqueeze(1)
    tau_col = (tau * curvature).unsqueeze(1)

    z_left = fp[:, :-1] / eps_col    # (B, G-1)
    z_right = fp[:, 1:] / eps_col    # (B, G-1)
    crossing = (
        torch.sigmoid(z_left) * torch.sigmoid(-z_right)
        + torch.sigmoid(-z_left) * torch.sigmoid(z_right)
    )
    curvature_mid = (fpp[:, :-1] + fpp[:, 1:]) / 2
    local_max = torch.sigmoid(-curvature_mid / tau_col)

    # Tail suppression relative to each row's own peak density.
    peak = f.max(dim=1).values.clamp(min=1e-10)   # (B,)
    density_mid = (f[:, :-1] + f[:, 1:]) / 2       # (B, G-1)
    floor = (0.01 * peak).unsqueeze(1)
    width = (0.001 * peak).unsqueeze(1)
    density_mask = torch.sigmoid((density_mid - floor) / width)

    return (crossing * local_max * density_mask).sum(dim=1)
|
dcb/layer.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.layer — DifferentiableCriticalBandwidth PyTorch Module
|
|
3
|
+
|
|
4
|
+
This module defines `DCBLayer`, the central contribution of the DCB package:
|
|
5
|
+
a `torch.nn.Module` that accepts a 1D tensor of samples X and returns the
|
|
6
|
+
critical bandwidth h_crit as a differentiable scalar. The forward pass
|
|
7
|
+
evaluates the soft mode count M̃(h; X) (crossing-count formula by default) over a dense evaluation
|
|
8
|
+
grid Ω using Gaussian KDE derivatives from `dcb.kde`, then invokes the
|
|
9
|
+
bisection solver in `dcb.solver` to locate h_crit as the root of
|
|
10
|
+
M̃(h) - target_modes = 0. The backward pass is handled entirely by the
|
|
11
|
+
custom `DCBFunction` (a `torch.autograd.Function`) which applies the IFT
|
|
12
|
+
formula: ∂h_crit/∂X = -(∂M̃/∂h)^{-1} · (∂M̃/∂X)|_{h=h_crit}, bypassing
|
|
13
|
+
the computational graph of the iterative solver and maintaining O(1) memory
|
|
14
|
+
cost relative to the number of solver iterations. Hyperparameters ε and τ
|
|
15
|
+
may be supplied explicitly or computed adaptively via `dcb.utils`.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
import torch.nn as nn
|
|
22
|
+
from torch import Tensor
|
|
23
|
+
|
|
24
|
+
from dcb.solver import find_h_crit, ift_gradient
|
|
25
|
+
from dcb.utils import make_grid, silverman_bandwidth, adaptive_eps_tau, anneal_eps_tau
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DCBFunction(torch.autograd.Function):
    """Autograd Function returning h_crit with an IFT-based backward pass.

    `forward` locates h_crit with `find_h_crit`. `backward` calls
    `ift_gradient` with the same (eps, tau), formula ('cross' by default —
    the crossing-count M̃_cross that fixes the m=1 divergence of the legacy
    integral formula) and chunk size. Only X receives a gradient; every other
    input is treated as a constant.
    """

    @staticmethod
    def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
                brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft):
        """Solve for h_crit and stash what the backward pass needs."""
        solver_options = dict(
            formula=formula,
            brentq_n_max=brentq_n_max,
            chunk_size=chunk_size,
            g_brentq=g_brentq,
            use_hard_bisection=use_hard_bisection,
            use_fft=use_fft,
        )
        h_crit, cond_num = find_h_crit(X, grid, eps, tau, target_modes, **solver_options)

        # Tensors go through save_for_backward; plain values live on ctx.
        ctx.save_for_backward(X, grid)
        ctx.h_crit, ctx.cond_num = h_crit, cond_num
        ctx.eps, ctx.tau, ctx.delta = eps, tau, delta
        ctx.formula, ctx.chunk_size = formula, chunk_size
        ctx.safe_backward = safe_backward
        return torch.tensor(h_crit, dtype=X.dtype, device=X.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ∂L/∂X via the IFT; all other inputs get ``None``."""
        X, grid = ctx.saved_tensors
        grad_X = ift_gradient(
            X, ctx.h_crit, grid, ctx.eps, ctx.tau, grad_output,
            delta=ctx.delta,
            formula=ctx.formula,
            chunk_size=ctx.chunk_size,
            safe_backward=ctx.safe_backward,
        )
        # Copy the denominator-guard diagnostics that ift_gradient exposes as
        # function attributes onto ctx, so they can be inspected after
        # backward (the output's grad_fn is this ctx object).
        ctx.guard_triggered = ift_gradient.last_guard_triggered
        ctx.denom_abs = ift_gradient.last_denom_abs
        # forward took 13 inputs: X plus 12 non-differentiable ones (grid,
        # eps, tau, target_modes, delta, formula, chunk_size, brentq_n_max,
        # g_brentq, use_hard_bisection, safe_backward, use_fft).
        return (grad_X,) + (None,) * 12
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class DCBLayer(nn.Module):
    """Differentiable Critical Bandwidth as a PyTorch Module.

    Computes the critical bandwidth h_crit for a 1D sample X — the smallest
    bandwidth at which a Gaussian KDE has exactly `target_modes` modes — and
    returns it as a differentiable scalar. Gradients flow back to X through
    `DCBFunction`, i.e. via the Implicit Function Theorem.

    Before every solve, `anneal_factor` is applied to (eps, tau) through
    `anneal_eps_tau`. This tightens the soft mode-count approximation M̃
    toward the true CBW.
    - Training: use anneal_factor=1.0 (default) for stable IFT gradients.
    - Evaluation: reduce it toward 0.05–0.10 for accuracy, inside
      torch.no_grad().
    - When anneal_factor < 0.1, increase G proportionally (e.g. G=2048) so
      the grid resolves the sharpened approximation.

    Parameters
    ----------
    target_modes : int
        Number of modes to target (default 1).
    G : int
        Number of evaluation grid points. Default 512.
    delta : float
        Stabilisation floor for the IFT denominator. Default 1e-4.
    eps : float or None
        Sharpness of the zero-crossing detector on f'. If None: a fixed
        dimensionless 0.1 for formula='cross', otherwise computed adaptively
        on every call.
    tau : float or None
        Sharpness of the local-max selector on f''. If None: a fixed 0.2 for
        formula='cross', otherwise computed adaptively on every call.
    anneal_factor : float
        Scales eps and tau; use 1.0 for training, <1 for eval accuracy.
    formula : str
        'cross' (default) — M̃_cross crossing-count formula, fixes m=1 bias.
        'integral' — legacy integral formula (kept for ablation comparison).
    chunk_size : int
        Data block size forwarded to the solver and the IFT backward.
        Default 50 000.
    brentq_n_max : int
        Legacy subsampling cap. Must stay at its default 50 000 when
        use_fft=True. A non-default value with use_fft=False emits a
        DeprecationWarning.
    g_brentq : int
        Grid resolution for the brentq objective (default 128, Round 15a).
        Ignored when use_hard_bisection=True.
    use_hard_bisection : bool
        If True (default, Round 15b), bisect on the hard mode count — provably
        monotone, no false roots. If False, use legacy brentq on M̃_cross.
    adaptive_G : bool
        If True (Round 17a), choose the grid resolution per forward call as
        max(G, min(32768, int(G * max(1.0, (n/1000)**0.2)))). For G=512 this
        gives 512 at n≤1K, ~706 at n=5K, ~1119 at n=50K and ~2038 at n=1M.
        If False (default), use the fixed G value.
    safe_backward : bool
        If True (Round 17b), clamp the IFT denominator |∂M̃/∂h| to at least 0.1.
        This is 10× larger than the default guard of 0.01 and limits gradient
        magnitude near mode-merging bifurcations. Default False. A warning is
        always emitted when the denominator falls below 0.01, regardless of
        this flag.
    use_fft : bool
        Default True. Uses FFT-based mode counting (O(n + G log G)) for
        n > 50K, eliminating subsampling bias. Falls back to direct KDE for
        n ≤ 50K (no bias at small n). Set False only for legacy/ablation
        comparison.

    Raises
    ------
    TypeError
        If use_fft=True and brentq_n_max differs from its default 50 000.

    Examples
    --------
    >>> layer = DCBLayer(target_modes=1)
    >>> X = torch.cat([torch.randn(50) - 1, torch.randn(50) + 1]).requires_grad_(True)
    >>> h = layer(X)
    >>> h.backward()
    >>> X.grad.shape
    torch.Size([100])
    """

    def __init__(
        self,
        target_modes: int = 1,
        G: int = 512,
        delta: float = 1e-4,
        eps: float | None = None,
        tau: float | None = None,
        anneal_factor: float = 1.0,
        formula: str = 'cross',
        chunk_size: int = 50_000,
        brentq_n_max: int = 50_000,
        g_brentq: int = 128,
        use_hard_bisection: bool = True,
        adaptive_G: bool = False,
        safe_backward: bool = False,
        use_fft: bool = True,
    ):
        super().__init__()

        # brentq_n_max only matters on the legacy (non-FFT) subsampling path.
        brentq_overridden = brentq_n_max != 50_000
        if brentq_overridden and use_fft:
            raise TypeError(
                f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
                "does not subsample and ignores this parameter. Remove brentq_n_max from your "
                "DCBLayer constructor, or set use_fft=False to use the legacy subsampling path."
            )
        if brentq_overridden:
            # Reached only with use_fft=False.
            import warnings
            warnings.warn(
                "brentq_n_max is deprecated. When use_fft=True (default), this parameter is "
                "ignored. Set use_fft=True (default) instead of tuning brentq_n_max.",
                DeprecationWarning, stacklevel=2,
            )

        self.target_modes = target_modes
        self.G = G
        self.delta = delta
        # User overrides; None means "use the formula-specific default".
        self._eps = eps
        self._tau = tau
        self.anneal_factor = anneal_factor
        self.formula = formula
        self.chunk_size = chunk_size
        self.brentq_n_max = brentq_n_max
        self.g_brentq = g_brentq
        self.use_hard_bisection = use_hard_bisection
        self.adaptive_G = adaptive_G
        self.safe_backward = safe_backward
        self.use_fft = use_fft

    def set_anneal_factor(self, factor: float) -> None:
        """Update the sharpening factor in place (call before eval)."""
        self.anneal_factor = float(factor)

    def _grid_size(self, n: int) -> int:
        """Return the grid resolution for a sample of size n (see adaptive_G)."""
        if not self.adaptive_G:
            return self.G
        growth = max(1.0, (n / 1000) ** 0.2)
        return max(self.G, min(32768, int(self.G * growth)))

    def _resolve_eps_tau(self, X: Tensor, grid: Tensor):
        """Return (eps, tau) before annealing, filling unset values with defaults."""
        if self.formula == 'cross':
            # The 'cross' formula rescales eps/tau internally by the median
            # |f''_h| on the grid (see dcb.kde), so they are dimensionless
            # fractions with fixed defaults.
            eps = 0.1 if self._eps is None else self._eps
            tau = 0.2 if self._tau is None else self._tau
            return eps, tau

        if self._eps is not None and self._tau is not None:
            return self._eps, self._tau

        # Other formulas: data-driven defaults anchored at Silverman's bandwidth.
        X_ref = X.detach()
        h0 = silverman_bandwidth(X_ref)
        eps_a, tau_a = adaptive_eps_tau(X_ref, h0, grid)
        eps = eps_a if self._eps is None else self._eps
        tau = tau_a if self._tau is None else self._tau
        return eps, tau

    def forward(self, X: Tensor) -> Tensor:
        """Compute h_crit for sample X.

        Parameters
        ----------
        X : Tensor, shape (n,)
            1D sample tensor. Must be on a single device; may require grad.

        Returns
        -------
        Tensor, shape ()
            Scalar h_crit, differentiable w.r.t. X.
        """
        # The grid is built from detached data: no gradient flows through it.
        grid = make_grid(X.detach(), self._grid_size(X.shape[0]))
        eps, tau = self._resolve_eps_tau(X, grid)
        eps_eff, tau_eff = anneal_eps_tau(eps, tau, self.anneal_factor)

        return DCBFunction.apply(
            X, grid, eps_eff, tau_eff, self.target_modes, self.delta, self.formula,
            self.chunk_size, self.brentq_n_max, self.g_brentq, self.use_hard_bisection,
            self.safe_backward, self.use_fft,
        )
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
# Public alias: `DifferentiableCriticalBandwidth` is the same class object as
# `DCBLayer`, exported under its long descriptive name.
DifferentiableCriticalBandwidth = DCBLayer
|