diffcb 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dcb/__init__.py +22 -0
- dcb/diagnostics.py +163 -0
- dcb/fft_kde.py +128 -0
- dcb/kde.py +394 -0
- dcb/layer.py +231 -0
- dcb/solver.py +604 -0
- dcb/utils.py +183 -0
- diffcb-0.1.0.dist-info/METADATA +148 -0
- diffcb-0.1.0.dist-info/RECORD +11 -0
- diffcb-0.1.0.dist-info/WHEEL +4 -0
- diffcb-0.1.0.dist-info/licenses/LICENSE +21 -0
dcb/solver.py
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dcb.solver — IFT Root-Finder and Backward Pass
|
|
3
|
+
|
|
4
|
+
This module implements the two-stage computation that produces a differentiable
h_crit: (1) a root search, run without autograd, that locates h_crit for target
mode count m — by default a monotone bisection on the hard (discrete) mode
count, or, with use_hard_bisection=False, scipy.optimize.brentq on the soft
count M̃(h) minus a threshold — and (2) the Implicit Function Theorem backward
pass that computes ∂h_crit/∂X without unrolling the search iterations. The
stabilized denominator sg(u, δ) = sign(u) · max(|u|, δ) with default δ=1e-4
prevents division by zero when ∂M̃/∂h is near zero (e.g., for distributions
with closely spaced modes). The public interface is
`find_h_crit(X, grid, eps, tau, ...)`, which returns (h_crit, cond_num) and is
suitable for use inside `DCBFunction.forward`, and
`ift_gradient(X, h_crit, grid, eps, tau, grad_output, ...)`, which implements
the IFT formula in `DCBFunction.backward`.
|
|
15
|
+
|
|
16
|
+
Round 15a: `g_brentq` parameter adds a coarse grid (default G=128) for the
|
|
17
|
+
brentq objective, giving 4× fewer KDE evaluations per iteration with negligible
|
|
18
|
+
root-location error. The full main grid is still used for IFT gradient.
|
|
19
|
+
|
|
20
|
+
Round 15b: `use_hard_bisection=True` (default) routes to `find_h_crit_hard`,
|
|
21
|
+
which bisects on the hard (discrete) mode count — provably non-increasing in h
|
|
22
|
+
(Silverman 1981) — eliminating the ~25% false-root rate of the brentq path.
|
|
23
|
+
The IFT backward is unchanged: M̃_cross is evaluated at the confirmed h_crit.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import math
|
|
29
|
+
import warnings
|
|
30
|
+
|
|
31
|
+
import torch
|
|
32
|
+
from torch import Tensor
|
|
33
|
+
|
|
34
|
+
from dcb.kde import (
|
|
35
|
+
soft_mode_count,
|
|
36
|
+
soft_mode_count_cross,
|
|
37
|
+
soft_mode_count_cross_from_derivs,
|
|
38
|
+
kde_derivatives_chunked,
|
|
39
|
+
)
|
|
40
|
+
from dcb.fft_kde import fft_mode_count, adaptive_fft_G
|
|
41
|
+
|
|
42
|
+
# NOTE(review): this constant is not referenced inside this module. Here,
# find_h_crit_hard enables FFT bisection when n > brentq_n_max (find_h_crit's
# default brentq_n_max is also 50_000), so presumably another module reads this
# value — confirm the two thresholds are meant to stay in sync.
_AUTO_FFT_THRESHOLD = 50_000 # n above which FFT bisection activates (use_fft_effective)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def hard_mode_count(f_prime: Tensor, grid: Tensor) -> int:
    """Return the number of local maxima of the KDE sampled on the grid.

    A maximum is recorded at every adjacent pair of grid points where the
    first derivative is strictly positive at the left point and zero or
    negative at the right point.

    Parameters
    ----------
    f_prime : Tensor, shape (G,)
        First derivative of the KDE evaluated on the grid.
    grid : Tensor, shape (G,)
        Uniform evaluation grid. Not read by the computation (shape
        reference only).

    Returns
    -------
    int
        Number of + → (0 or −) transitions of f', i.e. the number of modes.
        Inputs with fewer than two points yield 0.
    """
    before = f_prime[:-1]
    after = f_prime[1:]
    is_peak = torch.logical_and(before > 0, after <= 0)
    return int(torch.count_nonzero(is_peak).item())
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _bisect_hard_count(
    count_fn,
    target_modes: int,
    h_lo: float,
    h_hi: float,
    tol: float,
) -> float:
    """Bracket, then bisect h on a non-increasing integer mode count.

    ``count_fn(h)`` must return the hard mode count at bandwidth h. Returns the
    upper end of the final bracket, i.e. the smallest h found with
    ``count_fn(h) <= target_modes``.
    """
    # Low end must have count > target. If h_lo is already in the target
    # regime, halve it (at most 30 times, never below 1e-10).
    if count_fn(h_lo) <= target_modes:
        h_lo_try = h_lo
        for _ in range(30):
            h_lo_try *= 0.5
            if h_lo_try < 1e-10:
                break
            if count_fn(h_lo_try) > target_modes:
                h_lo = h_lo_try
                break

    # High end must have count <= target. If h_hi is still multi-modal,
    # double it (at most 30 times).
    if count_fn(h_hi) > target_modes:
        for _ in range(30):
            h_hi *= 2.0
            if count_fn(h_hi) <= target_modes:
                break

    # Standard bisection, at most 50 halvings. The stopping width is
    # tol * min(1, hi): identical to an absolute tol when h >= 1, but it
    # shrinks with h so small-scale data (h far below tol) is still resolved
    # instead of stopping after the first few iterations.
    lo, hi = h_lo, h_hi
    for _ in range(50):
        mid = (lo + hi) / 2.0
        if count_fn(mid) <= target_modes:
            hi = mid
        else:
            lo = mid
        if (hi - lo) < tol * min(1.0, hi):
            break
    return float(hi)


def find_h_crit_hard(
    X: Tensor,
    grid: Tensor,
    target_modes: int,
    chunk_size: int,
    brentq_n_max: int,
    h_lo: float,
    h_hi: float,
    formula: str = 'cross',
    tol: float = 1e-6,
    eps: float = 0.1,
    tau: float = 0.2,
    use_fft: bool = False,
) -> tuple[float, float]:
    """Find h_crit via hard-mode-count bisection (monotone, no false roots).

    Uses the hard (discrete) mode count — the number of sign changes of f'
    from + to − — which is provably non-increasing in h (Silverman 1981).
    Bisects h until `hard_mode_count(f'_h) <= target_modes`. This approach
    cannot produce false roots because the objective is monotone.

    The IFT-backward condition number is still computed via M̃_cross evaluated
    at the confirmed h_crit for diagnostic purposes.

    Parameters
    ----------
    X : Tensor, shape (n,)
    grid : Tensor, shape (G,)
    target_modes : int
    chunk_size : int
        Chunk size forwarded to kde_derivatives_chunked.
    brentq_n_max : int
        Subsample X to this size for the bisection loop (non-FFT path) and
        for the condition-number diagnostic.
    h_lo, h_hi : float
        Initial search bracket; expanded automatically if it does not
        straddle the transition.
    formula : str
        Used only for condition-number computation (default 'cross').
    tol : float
        Bisection tolerance on h (default 1e-6). Applied as an absolute width
        for h >= 1 and relative to h for h < 1.
    eps, tau : float
        Soft mode count parameters for condition-number diagnostic.
    use_fft : bool
        If True (Round 18b), use FFT-based mode counting for bisection — no
        subsampling, O(n + G log G) complexity. Only takes effect when
        n > brentq_n_max. If False (default), use the chunked KDE approach on
        a subsample of size brentq_n_max.

    Returns
    -------
    (h_crit, cond_num) : (float, float)
        h_crit: the critical bandwidth (smallest h with hard count <= target_modes).
        cond_num: |∂M̃/∂h| at h_crit (large = well-conditioned IFT).

    Warns
    -----
    UserWarning
        When the KDE path subsamples X (n > brentq_n_max and FFT not in use).
    """
    mode_fn = _get_mode_count_fn(formula)

    with torch.no_grad():
        n = X.shape[0]
        # FFT is only beneficial (and reliable) when n > brentq_n_max.
        # For small n the histogram is too sparse (n/G < 1) and produces
        # spurious sign changes. Fall back to direct KDE — there is no
        # subsampling bias to fix when n ≤ brentq_n_max anyway.
        use_fft_effective = use_fft and (n > brentq_n_max)
        subsampled = (not use_fft_effective) and n > brentq_n_max
        if subsampled:
            idx = torch.randperm(n, device=X.device)[:brentq_n_max]
            X_sub = X[idx]
        else:
            X_sub = X

    if subsampled:
        bias_factor = (brentq_n_max / n) ** (-0.2)
        warnings.warn(
            f"DCB: n={n} > brentq_n_max={brentq_n_max}. "
            f"h_crit estimated on {brentq_n_max}-point subsample; "
            f"expected upward bias ~{bias_factor:.2f}x vs full-data h_crit. "
            "Use use_fft=True to eliminate subsampling bias.",
            UserWarning,
            stacklevel=4,
        )

    if use_fft_effective:
        # Compute adaptive FFT grid size before bisection.
        with torch.no_grad():
            sigma = X.std().item()
            if sigma == 0.0:
                sigma = 1.0
            lo_domain = X.min().item() - 3 * sigma
            hi_domain = X.max().item() + 3 * sigma
            data_range = hi_domain - lo_domain
            G_fft = adaptive_fft_G(data_range, h_hi)

        def count_fn(h: float) -> int:
            # FFT mode count on the full data set (no subsampling bias).
            return fft_mode_count(X, h, G=G_fft)
    else:
        def count_fn(h: float) -> int:
            # Direct KDE first derivative on the (possibly subsampled) data.
            _, fp, _ = kde_derivatives_chunked(X_sub, h, grid, chunk_size)
            return hard_mode_count(fp, grid)

    with torch.no_grad():
        h_crit = _bisect_hard_count(count_fn, target_modes, h_lo, h_hi, tol)

    # Condition number: |∂M̃/∂h| via finite difference at h_crit (for diagnostics).
    # When use_fft=True, X_sub == X (full n points). The soft mode count builds
    # an (n × G) matrix, so we cap to brentq_n_max to avoid O(n×G) OOM at large n.
    dh = h_crit * 1e-4
    with torch.no_grad():
        X_cond = X_sub
        if X_sub.shape[0] > brentq_n_max:
            idx_cond = torch.randperm(X_sub.shape[0], device=X_sub.device)[:brentq_n_max]
            X_cond = X_sub[idx_cond]
        m_plus = mode_fn(X_cond, h_crit + dh, grid, eps, tau).item()
        m_minus = mode_fn(X_cond, h_crit - dh, grid, eps, tau).item()
        cond_num = abs((m_plus - m_minus) / (2 * dh))

    return h_crit, float(cond_num)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def sg(u: Tensor, delta: float = 1e-4) -> Tensor:
    """Clamp |u| from below by ``delta`` while keeping the sign of u.

    Returns sign(u) * max(|u|, delta) element-wise. Zero (including -0.0)
    maps to +delta; NaN entries stay NaN.

    Parameters
    ----------
    u : Tensor
        Values to stabilise.
    delta : float
        Smallest allowed magnitude (default 1e-4).

    Returns
    -------
    Tensor
        Tensor with the same shape as u.
    """
    magnitude = u.abs().clamp(min=delta)
    # Only strictly negative entries flip; zeros therefore become +delta.
    return torch.where(u < 0, -magnitude, magnitude)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _get_mode_count_fn(formula: str):
    """Return the soft-mode-count implementation selected by ``formula``.

    'cross' selects soft_mode_count_cross; any other value (including
    'integral') selects soft_mode_count.
    """
    return soft_mode_count_cross if formula == 'cross' else soft_mode_count
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def find_h_crit(
    X: Tensor,
    grid: Tensor,
    eps: float,
    tau: float,
    target_modes: int = 1,
    h_lo: float | None = None,
    h_hi: float | None = None,
    formula: str = 'cross',
    brentq_n_max: int = 50_000,
    chunk_size: int = 50_000,
    g_brentq: int = 128,
    use_hard_bisection: bool = True,
    use_fft: bool = True,
) -> tuple[float, float]:
    """Find h_crit and return (h_crit, condition_number).

    When use_hard_bisection=True (default, Round 15b), dispatches to
    `find_h_crit_hard` which bisects on the hard (discrete) mode count —
    provably non-increasing in h (Silverman 1981), no false roots.

    When use_hard_bisection=False, locates the root of
    G(h, X) = M̃(h; X) - threshold = 0 via Brent's method on a coarse grid of
    g_brentq points (Round 15a: 4× fewer KDE evaluations per call). The
    threshold is target_modes + 0.5 for 'cross' and target_modes otherwise.
    The full main grid is still used for IFT gradient computation.

    Also returns |∂M̃/∂h| at h_crit as a condition-number diagnostic:
    values near zero indicate bifurcation instability (IFT denominator small).

    Parameters
    ----------
    X : Tensor, shape (n,)
    grid : Tensor, shape (G,)
    eps, tau : float
    target_modes : int
    h_lo, h_hi : float or None
        Search bracket. Defaults are 1e-3·std(X) and 3·std(X); if std(X) is
        zero or not finite (constant data, n == 1), unit scale is used instead.
    formula : str
        'cross' (default) — use M̃_cross (crossing-count, fixes m=1 bias).
        'integral' — use original M̃ (legacy, known m=1 failure).
    brentq_n_max : int
        Subsample size for the root search.
    chunk_size : int
        Chunk size for chunked KDE evaluation (hard-bisection path).
    g_brentq : int
        Grid resolution for the brentq objective (default 128). Ignored when
        use_hard_bisection=True.
    use_hard_bisection : bool
        If True (default), use hard-mode-count bisection (no false roots).
        If False, use legacy brentq on M̃_cross with coarse grid g_brentq.
    use_fft : bool
        Default True. Uses FFT-based mode counting (O(n + G log G)) when
        n > brentq_n_max, eliminating subsampling bias. Falls back to direct
        KDE otherwise (no bias at small n). Set False only for legacy/ablation
        comparison.

    Returns
    -------
    (h_crit, cond_num) : (float, float)
        h_crit: the critical bandwidth.
        cond_num: |∂M̃/∂h| at h_crit (large = well-conditioned IFT). The brentq
        path returns cond_num = 0.0 when no valid bracket can be found.
    """
    if h_lo is None or h_hi is None:
        with torch.no_grad():
            std_val = X.std().item()
        # A constant sample (std == 0) or a single point (std is NaN) would give
        # a zero/NaN bracket and break every KDE evaluation; fall back to unit
        # scale, mirroring the sigma == 0 guard used for the FFT domain.
        if not math.isfinite(std_val) or std_val <= 0.0:
            std_val = 1.0
        if h_lo is None:
            h_lo = 1e-3 * std_val
        if h_hi is None:
            h_hi = 3.0 * std_val

    if use_hard_bisection:
        return find_h_crit_hard(
            X, grid, target_modes, chunk_size, brentq_n_max,
            h_lo, h_hi, formula=formula, eps=eps, tau=tau,
            use_fft=use_fft,
        )

    from scipy.optimize import brentq

    mode_fn = _get_mode_count_fn(formula)

    with torch.no_grad():
        n = X.shape[0]
        if n > brentq_n_max:
            idx = torch.randperm(n, device=X.device)[:brentq_n_max]
            X_brentq = X[idx]
        else:
            X_brentq = X

        # Build a coarse grid for brentq objective evaluation (Round 15a).
        # g_brentq=128 gives 4× fewer KDE evaluations than G=512 with negligible
        # root-location error — sign detection only needs coarse resolution.
        grid_lo = grid[0].item()
        grid_hi = grid[-1].item()
        grid_coarse = torch.linspace(
            grid_lo, grid_hi, g_brentq, dtype=grid.dtype, device=grid.device
        )

    # For 'cross': M̃_cross has a soft floor ~(target_modes + 0.001) in the
    # target-mode regime, so a hard threshold of target_modes never cleanly
    # separates the transition from the plateau. Using target_modes + 0.5
    # places the root at the center of the sharp mode-merging transition.
    # For 'integral': the formula has no such floor; use target_modes directly.
    threshold = target_modes + (0.5 if formula == 'cross' else 0)

    def f(h: float) -> float:
        # brentq calls this outside any caller-provided context; keep the
        # search graph-free even if X requires grad.
        with torch.no_grad():
            return mode_fn(X_brentq, h, grid_coarse, eps, tau).item() - threshold

    f_lo = f(h_lo)
    f_hi = f(h_hi)

    if f_lo <= 0:
        # Scan upward for a point with M̃ above the threshold.
        h_scan = h_lo
        for _ in range(40):
            h_scan *= 1.5
            if h_scan >= h_hi:
                break
            f_scan = f(h_scan)
            if f_scan > 0:
                h_lo = h_scan
                f_lo = f_scan
                break
        else:
            return h_lo, 0.0
        if f_lo <= 0:
            return h_lo, 0.0

    if f_hi > 0:
        h_hi *= 3
        f_hi = f(h_hi)
        if f_hi > 0:
            return h_hi, 0.0

    h_crit = brentq(f, h_lo, h_hi)

    # Condition number: |∂M̃/∂h| via finite difference at h_crit
    dh = h_crit * 1e-4
    with torch.no_grad():
        m_plus = mode_fn(X_brentq, h_crit + dh, grid_coarse, eps, tau).item()
        m_minus = mode_fn(X_brentq, h_crit - dh, grid_coarse, eps, tau).item()
    cond_num = abs((m_plus - m_minus) / (2 * dh))

    return float(h_crit), float(cond_num)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def _analytical_dM_dX(
    X: Tensor,
    h: float,
    grid: Tensor,
    dM_dfp: Tensor,
    dM_dfpp: Tensor,
    chunk_size: int,
) -> Tensor:
    """Chunked analytical ∂M̃/∂X using the KDE chain rule.

    Avoids materialising the full n×G autograd graph: only a
    (chunk_size × G) block of intermediates exists at a time. With
    diff_ij = grid_j − X_i and the Gaussian kernel K_ij of bandwidth h,

        ∂f'(grid_j)/∂X_i  = (1/n) K_ij · (1/h² − diff²_ij/h⁴)
        ∂f''(grid_j)/∂X_i = (1/n) K_ij · diff_ij/h⁴ · (diff²_ij/h² − 3)

    Then ∂M̃/∂X_i = Σ_j [∂M̃/∂f'_j · ∂f'_j/∂X_i + ∂M̃/∂f''_j · ∂f''_j/∂X_i].

    Parameters
    ----------
    X : Tensor, shape (n,)
        Data points.
    h : float
        Bandwidth at which the derivatives are taken.
    grid : Tensor, shape (G,)
        Evaluation grid.
    dM_dfp, dM_dfpp : Tensor, shape (G,)
        Sensitivities of M̃ with respect to f' and f'' at each grid point.
    chunk_size : int
        Number of data points processed per chunk (must be >= 1).

    Returns
    -------
    Tensor, shape (n,)
        ∂M̃/∂X_i for every data point.

    Raises
    ------
    ValueError
        If chunk_size < 1 (a negative chunk size would otherwise silently
        produce an all-zero gradient).
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    n = X.shape[0]
    inv_h2 = 1.0 / (h * h)
    inv_h4 = inv_h2 * inv_h2
    kernel_norm = math.sqrt(2 * math.pi) * h
    out = torch.zeros(n, dtype=X.dtype, device=X.device)
    with torch.no_grad():
        for start in range(0, n, chunk_size):
            stop = start + chunk_size
            diff = grid.unsqueeze(0) - X[start:stop].unsqueeze(1)  # (c, G)
            z2 = diff * diff * inv_h2  # (diff/h)², shared by all three terms
            K = torch.exp(-0.5 * z2) / kernel_norm
            # ∂f'_j/∂X_i (without the 1/n factor)
            gfp = K * inv_h2 * (1.0 - z2)
            # ∂f''_j/∂X_i (without the 1/n factor)
            gfpp = K * diff * inv_h4 * (z2 - 3.0)
            out[start:stop] = (dM_dfp * gfp + dM_dfpp * gfpp).sum(1)
    return out / n
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def ift_gradient(
    X: Tensor,
    h_crit: float,
    grid: Tensor,
    eps: float,
    tau: float,
    grad_output: Tensor,
    delta: float = 1e-4,
    formula: str = 'cross',
    chunk_size: int = 50_000,
    analytical_n_thresh: int = 10_000,
    safe_backward: bool = False,
) -> Tensor:
    """Compute the IFT gradient ∂h_crit/∂X and chain with upstream grad.

    Implements the Implicit Function Theorem formula:

        ∂h_crit/∂X = -(∂M̃/∂h)^{-1} · (∂M̃/∂X)   evaluated at h = h_crit

    using one of two paths:

    * Large-n path (n > analytical_n_thresh and formula == 'cross'): KDE
      derivatives at h_crit are computed in chunks; ∂M̃/∂f' and ∂M̃/∂f'' come
      from a G-dimensional autograd call; ∂M̃/∂h from a central finite
      difference; ∂M̃/∂X from `_analytical_dM_dX`. Peak memory is
      O(chunk_size × G). This path differentiates through the KDE data only
      and holds `grid` fixed.
    * Small-n path (otherwise): a differentiable grid (and, for non-'cross'
      formulas, data-dependent eps/tau) is rebuilt from X, and ∂M̃/∂h and
      ∂M̃/∂X are obtained together from a single autograd call, so the
      dependence of the grid bounds and adaptive scales on X is included.
      The passed `grid` contributes only its length G here.
      NOTE(review): this assumes the caller's grid spans
      [min(X) − 3·std(X), max(X) + 3·std(X)] — confirm against the builder.

    The denominator is stabilized by sg(·, guard):
    guard = 0.1 when safe_backward=True; otherwise 0.01 when |∂M̃/∂h| < 0.01,
    else `delta`. With safe_backward=True and ∂M̃/∂h exactly zero, the
    denominator is copysign(0.1, ∂M̃/∂h). A UserWarning is emitted whenever
    |∂M̃/∂h| < 0.01.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points (need not require grad on entry).
    h_crit : float
        Critical bandwidth found by find_h_crit.
    grid : Tensor, shape (G,)
        Uniform evaluation grid.
    eps : float
        Width of the Gaussian delta approximation (small-n path with a
        non-'cross' formula recomputes it from X instead).
    tau : float
        Sigmoid temperature (same caveat as eps).
    grad_output : Tensor, shape ()
        Upstream scalar gradient from the loss w.r.t. h_crit.
    delta : float
        Stabilisation floor for the sg denominator. Default 1e-4.
    formula : str
        'cross' (default) or another value selecting the legacy soft count.
    chunk_size : int
        Chunk size for the large-n path.
    analytical_n_thresh : int
        n above which the large-n path is used (for formula == 'cross').
    safe_backward : bool
        If True, clamp the denominator magnitude to at least 0.1.

    Returns
    -------
    Tensor, shape (n,)
        Gradient of the loss w.r.t. X.

    Side effects
    ------------
    Sets the function attributes ``last_denom_abs``, ``last_denom_signed``,
    ``last_guard_triggered`` and ``last_grad_norm`` on ``ift_gradient``.
    """
    G = grid.shape[0]
    margin_sigma = 3.0
    eps_coeff = 0.1
    tau_coeff = 0.2
    n = X.shape[0]

    if n > analytical_n_thresh and formula == 'cross':
        # Large-n analytical path — O(chunk×G) peak memory.
        # Avoids building an n×G autograd graph; uses the KDE chain rule instead.

        # Step 1: KDE derivatives at h_crit (chunked, no graph)
        with torch.no_grad():
            f, fp, fpp = kde_derivatives_chunked(X, h_crit, grid, chunk_size)

        # Step 2: ∂M̃/∂fp and ∂M̃/∂fpp via G-dim autograd only
        with torch.enable_grad():
            fp_req = fp.detach().requires_grad_(True)
            fpp_req = fpp.detach().requires_grad_(True)
            M_val = soft_mode_count_cross_from_derivs(
                f.detach(), fp_req, fpp_req, grid, eps, tau
            )
            dM_dfp, dM_dfpp = torch.autograd.grad(
                M_val, [fp_req, fpp_req], retain_graph=False
            )

        # Step 3: ∂M̃/∂h via central finite differences (2 chunked KDE passes)
        dh = h_crit * 1e-4
        with torch.no_grad():
            f_p, fp_p, fpp_p = kde_derivatives_chunked(X, h_crit + dh, grid, chunk_size)
            f_m, fp_m, fpp_m = kde_derivatives_chunked(X, h_crit - dh, grid, chunk_size)
            M_plus = soft_mode_count_cross_from_derivs(f_p, fp_p, fpp_p, grid, eps, tau)
            M_minus = soft_mode_count_cross_from_derivs(f_m, fp_m, fpp_m, grid, eps, tau)
            dM_dh = torch.tensor(
                float((M_plus - M_minus) / (2 * dh)), dtype=X.dtype, device=X.device
            )

        # Step 4: ∂M̃/∂X analytically (chunked chain rule, O(chunk×G) memory)
        dM_dX = _analytical_dM_dX(X, h_crit, grid, dM_dfp, dM_dfpp, chunk_size)

    else:
        # Small-n autograd path — total-derivative IFT through data + grid + eps/tau.
        # torch.autograd.Function.backward runs under no_grad; enable_grad lets
        # the forward pass through the soft mode count build a graph w.r.t.
        # both h_tensor and X_req. The grid (and, for non-'cross' formulas,
        # eps/tau) is rebuilt from X_req so autograd sees every path by which
        # X influences M̃.
        with torch.enable_grad():
            h_tensor = torch.tensor(h_crit, dtype=X.dtype, device=X.device, requires_grad=True)
            X_req = X.detach().requires_grad_(True)

            std_x = X_req.std()
            lo = X_req.min() - margin_sigma * std_x
            hi = X_req.max() + margin_sigma * std_x
            t = torch.linspace(0.0, 1.0, G, dtype=X.dtype, device=X.device)
            grid_diff = lo + (hi - lo) * t

            mode_fn = _get_mode_count_fn(formula)
            if formula == 'cross':
                eps_diff = torch.tensor(eps, dtype=X.dtype, device=X.device)
                tau_diff = torch.tensor(tau, dtype=X.dtype, device=X.device)
            else:
                # Adaptive eps/tau from a pilot KDE at Silverman-style h0.
                h0_diff = 0.9 * std_x * (n ** -0.2)
                u = (grid_diff.unsqueeze(0) - X_req.unsqueeze(1)) / h0_diff
                K0 = torch.exp(-0.5 * u ** 2) / (math.sqrt(2 * math.pi) * h0_diff)
                f_prime0 = (-u / h0_diff * K0).mean(dim=0)
                f_dbl0 = ((u ** 2 - 1.0) / h0_diff ** 2 * K0).mean(dim=0)
                eps_diff = eps_coeff * f_prime0.std()
                tau_diff = tau_coeff * f_dbl0.abs().median()

            M = mode_fn(X_req, h_tensor, grid_diff, eps_diff, tau_diff)
            # One reverse pass yields both partials; no graph needs retaining.
            dM_dh, dM_dX = torch.autograd.grad(M, [h_tensor, X_req])

    # IFT formula with stabilized denominator + denom guard.
    DENOM_GUARD = 0.01
    SAFE_GUARD = 0.1  # 10× stricter clamp when safe_backward=True
    denom_abs = dM_dh.abs().item()
    denom_signed = dM_dh.item()  # raw signed value before any clamping
    ift_gradient.last_denom_abs = denom_abs
    ift_gradient.last_denom_signed = denom_signed
    ift_gradient.last_guard_triggered = denom_abs < DENOM_GUARD

    if denom_abs < DENOM_GUARD:
        warnings.warn(
            f"DCB IFT denominator |∂M̃/∂h|={denom_abs:.2e} < {DENOM_GUARD}. "
            "Gradient may be large. Use safe_backward=True to clamp.",
            stacklevel=3,
        )

    if safe_backward and denom_abs == 0.0:
        # Exact-zero edge case: sg(0, delta) always returns +delta, which can
        # invert the gradient sign if dM_dh approached zero from the negative
        # side. copysign preserves whatever sign the raw value had
        # (copysign(SAFE_GUARD, 0.0) = +SAFE_GUARD by IEEE convention).
        safe_denom = torch.tensor(
            math.copysign(SAFE_GUARD, denom_signed),
            dtype=X.dtype, device=X.device,
        )
        dh_dX = -(1.0 / safe_denom) * dM_dX
    else:
        effective_guard = SAFE_GUARD if safe_backward else (DENOM_GUARD if denom_abs < DENOM_GUARD else delta)
        dh_dX = -(1.0 / sg(dM_dh, effective_guard)) * dM_dX

    grad_X = grad_output * dh_dX
    ift_gradient.last_grad_norm = float(grad_X.norm().item())
    return grad_X
|