diffcb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dcb/kde.py ADDED
@@ -0,0 +1,394 @@
1
+ """
2
+ dcb.kde — Differentiable Gaussian KDE and Derivative Utilities
3
+
4
+ Two computation paths share the same formulas but differ in memory layout:
5
+
6
+ Dense path (n × G matrix, small n)
7
+ gaussian_kde_grid, kde_derivatives, soft_mode_count_cross
8
+ — simple, autograd-friendly, O(n × G) memory.
9
+
10
+ Chunked path (processes X in blocks, large n / GPU)
11
+ kde_derivatives_chunked, soft_mode_count_cross_from_derivs
12
+ — O(chunk × G) peak memory, same arithmetic.
13
+ kde_derivatives_chunked returns (f, f', f'') on the grid by
14
+ accumulating chunk contributions; soft_mode_count_cross_from_derivs
15
+ evaluates M̃_cross given pre-computed derivatives, enabling the
16
+ analytical IFT gradient without materialising the full n × G graph.
17
+
18
+ All functions operate on 1D input tensors (univariate DCB).
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import math
24
+
25
+ import torch
26
+ from torch import Tensor
27
+
28
+
29
def gaussian_kde_grid(X: Tensor, h: float, grid: Tensor) -> Tensor:
    """Return the Gaussian kernel density estimate sampled on ``grid``.

    Each observation contributes a normal density of standard deviation
    ``h`` centred on itself; the estimate is the average of those
    contributions at every grid point.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Evaluation points.

    Returns
    -------
    Tensor, shape (G,)
        f_h evaluated at each grid point.
    """
    # offsets[i, j] = grid[j] - X[i]; broadcasting builds the (n, G) matrix.
    offsets = grid[None, :] - X[:, None]
    gauss_norm = math.sqrt(2 * math.pi) * h
    kernel = torch.exp(-0.5 * (offsets / h) ** 2) / gauss_norm
    # Averaging over the sample axis collapses (n, G) to (G,).
    return kernel.mean(dim=0)
50
+
51
+
52
def kde_derivatives(
    X: Tensor, h: float, grid: Tensor
) -> tuple[Tensor, Tensor, Tensor]:
    """Return the Gaussian KDE together with its first and second derivative.

    The derivatives are the closed-form derivatives of the Gaussian kernel
    with respect to the evaluation point, averaged over the sample; no
    finite differences are involved.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Evaluation points.

    Returns
    -------
    (f, f_prime, f_double_prime) : tuple of Tensor, each shape (G,)
        KDE value, first derivative, and second derivative at each grid
        point. Everything is built from differentiable torch ops, so the
        outputs support autograd.
    """
    # offsets[i, j] = grid[j] - X[i], shape (n, G)
    offsets = grid[None, :] - X[:, None]
    kernel = torch.exp(-0.5 * (offsets / h) ** 2) / (math.sqrt(2 * math.pi) * h)

    density = kernel.mean(0)
    # d/dx of the kernel multiplies it by -(x - X_i) / h^2.
    slope = (-offsets / h ** 2 * kernel).mean(0)
    # d^2/dx^2 of the kernel multiplies it by (x - X_i)^2 / h^4 - 1 / h^2.
    curvature = (((offsets ** 2 / h ** 4) - (1.0 / h ** 2)) * kernel).mean(0)

    return density, slope, curvature
82
+
83
+
84
def kde_derivatives_chunked(
    X: Tensor, h: float, grid: Tensor, chunk_size: int = 50_000
) -> tuple[Tensor, Tensor, Tensor]:
    """Chunked KDE derivatives — O(chunk_size × G) peak memory.

    Identical arithmetic to `kde_derivatives`, but X is consumed in blocks
    of `chunk_size` points whose kernel sums are accumulated into G-length
    buffers, so only one (chunk, G) kernel matrix is formed at a time
    instead of the full (n, G) matrix. The accumulation uses differentiable
    torch ops, so the result still carries a valid autograd graph.

    Note: while autograd is recording (X requires grad), the intermediates
    of every chunk are kept alive for the backward pass, so the memory
    saving is only fully realised for detached inputs / under
    ``torch.no_grad()``.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Evaluation points.
    chunk_size : int
        Number of data points per block; must be positive. With the default
        of 50 000, each (chunk, G) float32 intermediate is ≈100 MB at
        G = 512. Increase for better GPU utilisation; decrease if GPU OOM.

    Returns
    -------
    (f, f_prime, f_double_prime) : tuple of Tensor, each shape (G,)
        Allocated with the dtype and device of X.

    Raises
    ------
    ValueError
        If `chunk_size` is not positive. A negative step would make the
        chunk loop empty and silently return all-zero estimates.
    """
    if chunk_size <= 0:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )

    n, G = X.shape[0], grid.shape[0]
    f = torch.zeros(G, dtype=X.dtype, device=X.device)
    fp = torch.zeros_like(f)
    fpp = torch.zeros_like(f)

    for start in range(0, n, chunk_size):
        Xc = X[start : start + chunk_size]
        diff = grid.unsqueeze(0) - Xc.unsqueeze(1)  # (c, G)
        K = torch.exp(-0.5 * (diff / h) ** 2) / (math.sqrt(2 * math.pi) * h)
        # Accumulate raw sums; the division by n happens once at the end.
        f.add_(K.sum(0))
        fp.add_((-diff / h ** 2 * K).sum(0))
        fpp.add_((((diff ** 2 / h ** 4) - (1.0 / h ** 2)) * K).sum(0))

    # Turn the accumulated sums into sample means.
    f.div_(n)
    fp.div_(n)
    fpp.div_(n)
    return f, fp, fpp
121
+
122
+
123
def soft_mode_count_cross_from_derivs(
    f: Tensor,
    f_prime: Tensor,
    f_double_prime: Tensor,
    grid: Tensor,
    eps: float,
    tau: float,
) -> Tensor:
    """Compute M̃_cross from KDE values and derivatives already on the grid.

    For every pair of neighbouring grid points this multiplies three soft
    indicators and sums the products:

    * a sign-change detector on f' (either direction),
    * a concavity selector on the midpoint average of f'',
    * a density mask that fades out intervals whose mean density is below
      1% of the peak density.

    The temperatures are rescaled by the median of |f''| (and, for ``eps``,
    by the grid spacing), so ``eps`` and ``tau`` act as dimensionless
    fractions. Taking the derivatives as inputs means only G-dimensional
    tensors have to be part of the autograd graph.

    Parameters
    ----------
    f, f_prime, f_double_prime : Tensor, each shape (G,)
    grid : Tensor, shape (G,)
    eps, tau : float

    Returns
    -------
    Tensor, shape ()
    """
    # Grid spacing, cut out of the graph.
    step = ((grid[-1] - grid[0]) / (grid.shape[0] - 1)).detach()
    # Curvature scale; the clamp keeps the temperatures strictly positive.
    curv_scale = f_double_prime.abs().median().clamp(min=1e-10)
    slope_temp = eps * curv_scale * step
    curv_temp = tau * curv_scale

    left = f_prime[:-1] / slope_temp
    right = f_prime[1:] / slope_temp
    falling = torch.sigmoid(left) * torch.sigmoid(-right)    # + → −
    rising = torch.sigmoid(-left) * torch.sigmoid(right)     # − → +
    sign_change = falling + rising

    curv_mid = (f_double_prime[:-1] + f_double_prime[1:]) / 2
    concave = torch.sigmoid(-curv_mid / curv_temp)

    peak = f.max().clamp(min=1e-10)
    dens_mid = (f[:-1] + f[1:]) / 2
    dense_enough = torch.sigmoid((dens_mid - 0.01 * peak) / (0.001 * peak))

    return (sign_change * concave * dense_enough).sum()
168
+
169
+
170
def soft_mode_count_cross(
    X: Tensor, h: float, grid: Tensor, eps: float, tau: float
) -> Tensor:
    """Compute the crossing-count soft mode count M̃_cross(h).

    Dense-path entry point: evaluates (f, f', f'') with `kde_derivatives`
    and hands them to `soft_mode_count_cross_from_derivs`, which holds the
    single implementation of the formula. The full n × G computation stays
    in the autograd graph, so the result is differentiable w.r.t. X.

    Formula (per pair of neighbouring grid points j, j+1):
        M̃_cross = Σ_{j=0}^{G-2}
            [σ(f'_j/ε_eff)·σ(−f'_{j+1}/ε_eff) + σ(−f'_j/ε_eff)·σ(f'_{j+1}/ε_eff)]
            · σ(−f''_mid,j / τ_eff)
            · σ((f_mid,j − 0.01·f_peak) / (0.001·f_peak))

    where f''_mid,j and f_mid,j are midpoint averages, f_peak = max f, and
    σ is the logistic sigmoid. The first bracket responds to sign changes
    of f' in either direction; the second factor keeps only those where
    f'' < 0 (local maxima); the third suppresses low-density tail intervals
    where f' ≈ 0 fluctuates and would otherwise add fractional counts.

    Calibration: ε_eff = eps · median|f''| · Δx and τ_eff = tau · median|f''|,
    so `eps` and `tau` are dimensionless. Only Δx is detached; the median
    curvature scale keeps its gradient, i.e. the self-normalisation is part
    of the differentiated function.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Uniform evaluation grid.
    eps : float
        Dimensionless sigmoid temperature of the zero-crossing detector on
        f' (smaller = sharper).
    tau : float
        Dimensionless sigmoid temperature of the local-max selector on f''.

    Returns
    -------
    Tensor, shape ()
        Scalar soft mode count in [0, G−1], differentiable w.r.t. X.
    """
    f, f_prime, f_double_prime = kde_derivatives(X, h, grid)
    # Delegate so the dense and pre-computed-derivative paths cannot drift
    # apart; the arithmetic is exactly the one previously inlined here.
    return soft_mode_count_cross_from_derivs(
        f, f_prime, f_double_prime, grid, eps, tau
    )
246
+
247
+
248
def soft_mode_count(
    X: Tensor, h: float, grid: Tensor, eps: float, tau: float
) -> Tensor:
    """Compute the integral-form soft (differentiable) mode count M̃(h).

    Riemann-sum approximation on the grid of

        M̃(h) = Δx · Σ_j δ_ε(f'_h(grid_j)) · σ_τ(-f''_h(grid_j)) · |f''_h(grid_j)|

    where δ_ε is a Gaussian of width ε standing in for the Dirac delta
    (it peaks where f' = 0) and σ_τ is a sigmoid that favours f'' < 0,
    i.e. local maxima.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h : float
        Bandwidth (positive scalar).
    grid : Tensor, shape (G,)
        Uniform evaluation grid.
    eps : float
        Width of the Gaussian delta approximation (sharpness of the
        zero-crossing detector on f').
    tau : float
        Temperature of the sigmoid (sharpness of the local-max selector
        on f'').

    Returns
    -------
    Tensor, shape ()
        Scalar soft mode count, differentiable w.r.t. X.
    """
    _, slope, curvature = kde_derivatives(X, h, grid)

    # Narrow Gaussian in f': close to 1/(sqrt(2π)·eps) at stationary points,
    # close to zero elsewhere.
    gauss_norm = math.sqrt(2 * math.pi) * eps
    stationary_weight = torch.exp(-0.5 * (slope / eps) ** 2) / gauss_norm

    # Near 1 where the density is concave, near 0 where it is convex.
    concave_weight = torch.sigmoid(-curvature / tau)

    step = (grid[-1] - grid[0]) / (grid.shape[0] - 1)
    weighted = stationary_weight * concave_weight * curvature.abs()
    return (weighted * step).sum()
294
+
295
+
296
def kde_derivatives_batched(
    X: Tensor,
    h_batch: Tensor,
    grid: Tensor,
    chunk_size: int = 10_000,
) -> tuple[Tensor, Tensor, Tensor]:
    """Evaluate Gaussian KDE derivatives for B bandwidths simultaneously.

    X is consumed in chunks whose kernel sums are accumulated into (B, G)
    buffers, so only one (chunk, B, G) kernel tensor is formed at a time
    instead of the full (n, B, G) tensor.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points.
    h_batch : Tensor, shape (B,)
        Candidate bandwidths.
    grid : Tensor, shape (G,)
        Evaluation points.
    chunk_size : int
        Number of data points per chunk; must be positive. Default 10 000.

    Returns
    -------
    (f, f_prime, f_double_prime) : tuple of Tensor, each shape (B, G)
        KDE value and its first two derivatives, one row per bandwidth.
        Allocated with the dtype and device of X.

    Raises
    ------
    ValueError
        If `chunk_size` is not positive. A negative step would make the
        chunk loop empty and silently return all-zero estimates.
    """
    if chunk_size <= 0:
        raise ValueError(
            f"chunk_size must be a positive integer, got {chunk_size!r}"
        )

    n, G, B = X.shape[0], grid.shape[0], h_batch.shape[0]
    f = torch.zeros(B, G, dtype=X.dtype, device=X.device)
    fp = torch.zeros_like(f)
    fpp = torch.zeros_like(f)

    # h_batch: (B,) → (1, B, 1) for broadcasting against (c, 1, G)
    h = h_batch.view(1, B, 1)
    h2 = h * h
    norm = 1.0 / (math.sqrt(2 * math.pi) * h)  # (1, B, 1)

    for start in range(0, n, chunk_size):
        Xc = X[start : start + chunk_size]  # (c,)
        # diff[i, j] = grid[j] - Xc[i] → (c, 1, G) broadcasts to (c, B, G)
        diff = (grid.unsqueeze(0) - Xc.unsqueeze(1)).unsqueeze(1)
        u = diff / h                          # standardised offsets, (c, B, G)
        K = norm * torch.exp(-0.5 * u * u)    # (c, B, G)
        # Accumulate raw sums; the division by n happens once at the end.
        f.add_(K.sum(0))
        fp.add_((-u / h * K).sum(0))
        fpp.add_((((u * u - 1.0) / h2) * K).sum(0))

    # Turn the accumulated sums into sample means.
    f.div_(n)
    fp.div_(n)
    fpp.div_(n)
    return f, fp, fpp
343
+
344
+
345
def soft_mode_count_cross_batched(
    f: Tensor,
    fp: Tensor,
    fpp: Tensor,
    grid: Tensor,
    eps: float,
    tau: float,
) -> Tensor:
    """Row-wise M̃_cross for B bandwidths at once.

    Each row of the (B, G) inputs is treated as an independent set of KDE
    values and derivatives. Per row, the sign-change detector on f', the
    concavity selector on the midpoint f'' and the 1%-of-peak density mask
    are multiplied for every pair of neighbouring grid points and summed
    along the grid axis. The calibration (median |f''| and peak density)
    is computed separately for every row.

    Parameters
    ----------
    f, fp, fpp : Tensor, each shape (B, G)
    grid : Tensor, shape (G,)
    eps, tau : float

    Returns
    -------
    Tensor, shape (B,)
        One soft mode count per bandwidth.
    """
    step = ((grid[-1] - grid[0]) / (grid.shape[0] - 1)).detach()

    # Per-row curvature scale, kept as a (B, 1) column so it broadcasts
    # against the (B, G-1) interval tensors below.
    curv_scale = fpp.abs().median(dim=1, keepdim=True).values.clamp(min=1e-10)
    slope_temp = eps * curv_scale * step   # (B, 1)
    curv_temp = tau * curv_scale           # (B, 1)

    left = fp[:, :-1] / slope_temp         # (B, G-1)
    right = fp[:, 1:] / slope_temp         # (B, G-1)
    falling = torch.sigmoid(left) * torch.sigmoid(-right)
    rising = torch.sigmoid(-left) * torch.sigmoid(right)
    sign_change = falling + rising

    curv_mid = (fpp[:, :-1] + fpp[:, 1:]) / 2
    concave = torch.sigmoid(-curv_mid / curv_temp)

    peak = f.max(dim=1, keepdim=True).values.clamp(min=1e-10)   # (B, 1)
    dens_mid = (f[:, :-1] + f[:, 1:]) / 2                       # (B, G-1)
    dense_enough = torch.sigmoid((dens_mid - 0.01 * peak) / (0.001 * peak))

    return (sign_change * concave * dense_enough).sum(dim=1)    # (B,)
dcb/layer.py ADDED
@@ -0,0 +1,231 @@
1
+ """
2
+ dcb.layer — DifferentiableCriticalBandwidth PyTorch Module
3
+
4
+ This module defines `DCBLayer`, the central contribution of the DCB package:
5
+ a `torch.nn.Module` that accepts a 1D tensor of samples X and returns the
6
+ critical bandwidth h_crit as a differentiable scalar. The forward pass
7
+ evaluates the soft mode-counting integral M̃(h; X) over a dense evaluation
8
+ grid Ω using Gaussian KDE derivatives from `dcb.kde`, then invokes the
9
+ bisection solver in `dcb.solver` to locate h_crit as the root of
10
+ M̃(h) - target_modes = 0. The backward pass is handled entirely by the
11
+ custom `DCBFunction` (a `torch.autograd.Function`) which applies the IFT
12
+ formula: ∂h_crit/∂X = -(∂M̃/∂h)^{-1} · (∂M̃/∂X)|_{h=h_crit}, bypassing
13
+ the computational graph of the iterative solver and maintaining O(1) memory
14
+ cost relative to the number of solver iterations. Hyperparameters ε and τ
15
+ may be supplied explicitly or computed adaptively via `dcb.utils`.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch import Tensor
23
+
24
+ from dcb.solver import find_h_crit, ift_gradient
25
+ from dcb.utils import make_grid, silverman_bandwidth, adaptive_eps_tau, anneal_eps_tau
26
+
27
+
28
class DCBFunction(torch.autograd.Function):
    """Autograd bridge between the root finder and the IFT gradient.

    `forward` delegates the search for h_crit to `find_h_crit` and stashes
    everything `backward` needs on ``ctx``; `backward` delegates to
    `ift_gradient`, using the same eps, tau and formula as the forward
    call. Only X receives a gradient; every other input is treated as a
    non-differentiable setting.
    """

    @staticmethod
    def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
                brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft):
        """Locate h_crit and record the state required by `backward`."""
        solver_options = dict(
            formula=formula,
            brentq_n_max=brentq_n_max,
            chunk_size=chunk_size,
            g_brentq=g_brentq,
            use_hard_bisection=use_hard_bisection,
            use_fft=use_fft,
        )
        h_crit, cond_num = find_h_crit(
            X, grid, eps, tau, target_modes, **solver_options
        )

        # Tensors go through save_for_backward; scalars and flags are kept
        # as plain attributes on ctx.
        ctx.save_for_backward(X, grid)
        ctx.h_crit, ctx.cond_num = h_crit, cond_num
        ctx.eps, ctx.tau = eps, tau
        ctx.delta = delta
        ctx.formula = formula
        ctx.chunk_size = chunk_size
        ctx.safe_backward = safe_backward

        # The solver result is wrapped in a fresh 0-d tensor matching X's
        # dtype and device.
        return torch.tensor(h_crit, dtype=X.dtype, device=X.device)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ∂h_crit/∂X (scaled by grad_output) via `ift_gradient`."""
        X, grid = ctx.saved_tensors
        grad_X = ift_gradient(
            X, ctx.h_crit, grid, ctx.eps, ctx.tau, grad_output,
            delta=ctx.delta,
            formula=ctx.formula,
            chunk_size=ctx.chunk_size,
            safe_backward=ctx.safe_backward,
        )
        # Diagnostics are read back from attributes set on the
        # ift_gradient function object itself.
        ctx.guard_triggered = ift_gradient.last_guard_triggered
        ctx.denom_abs = ift_gradient.last_denom_abs

        # One slot per forward input after ctx: X gets the gradient; grid,
        # eps, tau, target_modes, delta, formula, chunk_size, brentq_n_max,
        # g_brentq, use_hard_bisection, safe_backward and use_fft get None.
        return (grad_X,) + (None,) * 12
72
+
73
+
74
class DCBLayer(nn.Module):
    """Differentiable Critical Bandwidth as a PyTorch Module.

    Given a 1D sample X, returns the critical bandwidth h_crit — the
    smallest bandwidth at which a Gaussian KDE has exactly `target_modes`
    modes — as a scalar tensor. The computation is routed through
    `DCBFunction`, whose backward pass supplies the gradient w.r.t. X via
    the Implicit Function Theorem.

    `anneal_factor` scales eps and tau together (through `anneal_eps_tau`),
    tightening the soft mode-count approximation M̃ toward the true CBW.
    Use anneal_factor=1.0 (default) during gradient-based training for
    stable IFT gradients; reduce toward 0.05–0.10 for eval-time accuracy
    (in torch.no_grad()). When anneal_factor < 0.1 increase G
    proportionally (e.g. G=2048) so the grid resolves the sharpened
    detector.

    Parameters
    ----------
    target_modes : int
        Number of modes to target (default 1).
    G : int
        Number of evaluation grid points. Default 512.
    delta : float
        Stabilisation floor for the IFT denominator. Default 1e-4.
    eps : float or None
        Zero-crossing detector width / temperature. If None: 0.1 for
        formula='cross', otherwise taken from `adaptive_eps_tau` on each
        call.
    tau : float or None
        Sigmoid temperature of the local-max selector. If None: 0.2 for
        formula='cross', otherwise taken from `adaptive_eps_tau` on each
        call.
    anneal_factor : float
        Scales eps and tau; use 1.0 for training, <1 for eval accuracy.
    formula : str
        'cross' (default) — M̃_cross crossing-count formula, fixes m=1 bias.
        Any other value selects the adaptive eps/tau branch; 'integral' is
        the legacy integral formula (kept for ablation comparison).
    chunk_size : int
        Block size forwarded to the solver and the IFT gradient.
        Default 50 000.
    brentq_n_max : int
        Deprecated. Must stay at 50 000 when use_fft=True (a TypeError is
        raised otherwise); with use_fft=False a non-default value emits a
        DeprecationWarning.
    g_brentq : int
        Grid resolution for the brentq objective (default 128, Round 15a).
        Ignored when use_hard_bisection=True.
    use_hard_bisection : bool
        If True (default, Round 15b), bisect on the hard mode count —
        provably monotone, no false roots. If False, use legacy brentq on
        M̃_cross.
    adaptive_G : bool
        If True (Round 17a), the grid size is recomputed on every forward
        call as max(G, min(32768, int(G * max(1.0, (n/1000)**0.2)))). With
        G=512 this gives 512 at n≤1K, ~706 at n=5K, ~1119 at n=50K and
        ~2038 at n=1M. If False (default), the fixed G is used.
    safe_backward : bool
        If True (Round 17b), clamp the IFT denominator |∂M̃/∂h| to at least
        0.1 (10× larger than the default guard of 0.01) to limit gradient
        magnitude near mode-merging bifurcations. Default False. A warning
        is always emitted when the denominator falls below 0.01,
        regardless of this flag.
    use_fft : bool
        Default True. Uses FFT-based mode counting (O(n + G log G)) for
        n > 50K, eliminating subsampling bias. Falls back to direct KDE for
        n ≤ 50K (no bias at small n). Set False only for legacy/ablation
        comparison.

    Examples
    --------
    >>> layer = DCBLayer(target_modes=1)
    >>> X = torch.cat([torch.randn(50) - 1, torch.randn(50) + 1]).requires_grad_(True)
    >>> h = layer(X)
    >>> h.backward()
    >>> X.grad.shape
    torch.Size([100])
    """

    def __init__(
        self,
        target_modes: int = 1,
        G: int = 512,
        delta: float = 1e-4,
        eps: float | None = None,
        tau: float | None = None,
        anneal_factor: float = 1.0,
        formula: str = 'cross',
        chunk_size: int = 50_000,
        brentq_n_max: int = 50_000,
        g_brentq: int = 128,
        use_hard_bisection: bool = True,
        adaptive_G: bool = False,
        safe_backward: bool = False,
        use_fft: bool = True,
    ):
        super().__init__()

        # Problem definition.
        self.target_modes = target_modes
        self.G = G
        self.adaptive_G = adaptive_G
        self.formula = formula

        # Smoothing hyperparameters (None = resolve inside forward()).
        self._eps = eps
        self._tau = tau
        self.anneal_factor = anneal_factor

        # Solver / backward settings handed on to DCBFunction.
        self.delta = delta
        self.chunk_size = chunk_size
        self.brentq_n_max = brentq_n_max
        self.g_brentq = g_brentq
        self.use_hard_bisection = use_hard_bisection
        self.safe_backward = safe_backward
        self.use_fft = use_fft

        # A non-default brentq_n_max is an error together with the FFT path
        # and merely deprecated on the legacy path.
        if brentq_n_max != 50_000:
            if use_fft:
                raise TypeError(
                    f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
                    "does not subsample and ignores this parameter. Remove brentq_n_max from your "
                    "DCBLayer constructor, or set use_fft=False to use the legacy subsampling path."
                )
            import warnings
            warnings.warn(
                "brentq_n_max is deprecated. When use_fft=True (default), this parameter is "
                "ignored. Set use_fft=True (default) instead of tuning brentq_n_max.",
                DeprecationWarning, stacklevel=2,
            )

    def set_anneal_factor(self, factor: float) -> None:
        """Store a new sharpening factor, coerced to float (call before eval)."""
        self.anneal_factor = float(factor)

    def forward(self, X: Tensor) -> Tensor:
        """Compute h_crit for sample X.

        Parameters
        ----------
        X : Tensor, shape (n,)
            1D sample tensor. Must be on a single device; may require grad.

        Returns
        -------
        Tensor, shape ()
            Scalar h_crit, differentiable w.r.t. X.
        """
        n = X.shape[0]

        # Grid size: fixed, or grown with n**0.2 beyond 1000 samples, never
        # below self.G and capped at 32768.
        grid_points = self.G
        if self.adaptive_G:
            growth = max(1.0, (n / 1000) ** 0.2)
            grid_points = max(self.G, min(32768, int(self.G * growth)))
        # The grid is built from a detached copy, so it carries no gradient.
        grid = make_grid(X.detach(), grid_points)

        eps, tau = self._eps, self._tau
        if self.formula == 'cross':
            # Fixed dimensionless defaults; dcb.kde's M̃_cross routines
            # rescale eps/tau by median |f''| internally.
            if eps is None:
                eps = 0.1
            if tau is None:
                tau = 0.2
        elif eps is None or tau is None:
            # Fill whichever value is missing from the data-driven estimate
            # at the Silverman reference bandwidth.
            h0 = silverman_bandwidth(X.detach())
            eps_auto, tau_auto = adaptive_eps_tau(X.detach(), h0, grid)
            if eps is None:
                eps = eps_auto
            if tau is None:
                tau = tau_auto

        eps_eff, tau_eff = anneal_eps_tau(eps, tau, self.anneal_factor)

        return DCBFunction.apply(
            X, grid, eps_eff, tau_eff,
            self.target_modes, self.delta, self.formula,
            self.chunk_size, self.brentq_n_max, self.g_brentq,
            self.use_hard_bisection, self.safe_backward, self.use_fft,
        )


# Public alias: the descriptive long name refers to the same class.
DifferentiableCriticalBandwidth = DCBLayer