diffcb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dcb/solver.py ADDED
@@ -0,0 +1,604 @@
1
+ """
2
+ dcb.solver — IFT Root-Finder and Backward Pass
3
+
4
This module implements the two-stage computation that produces a differentiable
h_crit: (1) a no-grad root search that locates h_crit for target mode count m —
by default a bisection on the hard (discrete) mode count (see Round 15b below),
or, on the legacy path, Brent's method (scipy.optimize.brentq) applied to
M̃(h) - m = 0 — and (2) the Implicit Function Theorem backward pass that
computes ∂h_crit/∂X without unrolling the search iterations. The stabilized
denominator sg(u, δ) = sign(u) · max(|u|, δ) (with sg(0, δ) = δ and default
δ = 1e-4) prevents division by zero when ∂M̃/∂h is near zero (e.g., for
distributions with closely spaced modes). The public interface is
`find_h_crit(X, grid, eps, tau, target_modes, ...)`, which returns
(h_crit, cond_num) for use inside `DCBFunction.forward`, and
`ift_gradient(X, h_crit, grid, eps, tau, grad_output, ...)`, which implements
the IFT formula for `DCBFunction.backward`.
15
+
16
+ Round 15a: `g_brentq` parameter adds a coarse grid (default G=128) for the
17
+ brentq objective, giving 4× fewer KDE evaluations per iteration with negligible
18
+ root-location error. The full main grid is still used for IFT gradient.
19
+
20
+ Round 15b: `use_hard_bisection=True` (default) routes to `find_h_crit_hard`,
21
+ which bisects on the hard (discrete) mode count — provably non-increasing in h
22
+ (Silverman 1981) — eliminating the ~25% false-root rate of the brentq path.
23
+ The IFT backward is unchanged: M̃_cross is evaluated at the confirmed h_crit.
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import math
29
+ import warnings
30
+
31
+ import torch
32
+ from torch import Tensor
33
+
34
+ from dcb.kde import (
35
+ soft_mode_count,
36
+ soft_mode_count_cross,
37
+ soft_mode_count_cross_from_derivs,
38
+ kde_derivatives_chunked,
39
+ )
40
+ from dcb.fft_kde import fft_mode_count, adaptive_fft_G
41
+
42
# NOTE(review): this constant is not referenced anywhere in this module's
# visible code. FFT bisection in find_h_crit_hard actually switches on when
# `use_fft` is set AND n > brentq_n_max. The value here equals find_h_crit's
# default brentq_n_max (50_000). Confirm whether other modules import it
# before changing or removing it.
_AUTO_FFT_THRESHOLD = 50_000  # documents the default n above which FFT bisection activates
43
+
44
+
45
def hard_mode_count(f_prime: Tensor, grid: Tensor) -> int:
    """Return the number of local maxima implied by a sampled KDE derivative.

    A maximum is counted wherever f' is strictly positive at one grid node
    and non-positive at the next node.

    Parameters
    ----------
    f_prime : Tensor, shape (G,)
        First derivative of the KDE sampled on the grid.
    grid : Tensor, shape (G,)
        The evaluation grid. It is not read here and is kept for API symmetry
        with the other mode-count helpers.

    Returns
    -------
    int
        Number of + → (≤ 0) transitions of f', i.e. the hard mode count.
    """
    positive_here = f_prime[:-1] > 0
    nonpositive_next = f_prime[1:] <= 0
    peak_total = (positive_here & nonpositive_next).sum()
    return int(peak_total.item())
62
+
63
+
64
def _bracket_and_bisect(
    count_fn,
    h_lo: float,
    h_hi: float,
    target_modes: int,
    tol: float,
) -> float:
    """Return the smallest h (to within ``tol``) with ``count_fn(h) <= target_modes``.

    ``count_fn`` maps a bandwidth (float) to an integer mode count and is
    expected to be non-increasing in h.

    The bracket is repaired first:
    - If count(h_lo) is already <= target_modes, h_lo is halved up to 30 times,
      stopping once it drops below 1e-10. It stops early at the first value
      whose count exceeds the target.
    - If count(h_hi) > target_modes, h_hi is doubled up to 30 times until the
      count no longer exceeds the target.

    Then at most 50 bisection steps are taken, stopping early once the bracket
    is narrower than ``tol``. The returned value is the upper end of the final
    bracket.
    """
    if count_fn(h_lo) <= target_modes:
        # h_lo is already in the target regime, so shrink it.
        h_try = h_lo
        for _ in range(30):
            h_try *= 0.5
            if h_try < 1e-10:
                break
            if count_fn(h_try) > target_modes:
                h_lo = h_try
                break

    if count_fn(h_hi) > target_modes:
        # h_hi is still in the multi-mode regime, so grow it.
        for _ in range(30):
            h_hi *= 2.0
            if count_fn(h_hi) <= target_modes:
                break

    # Standard bisection: 50 iterations → bracket width / 2^50
    lo, hi = h_lo, h_hi
    for _ in range(50):
        mid = (lo + hi) / 2.0
        if count_fn(mid) <= target_modes:
            hi = mid
        else:
            lo = mid
        if (hi - lo) < tol:
            break
    return float(hi)  # smallest h with count <= target_modes


def find_h_crit_hard(
    X: Tensor,
    grid: Tensor,
    target_modes: int,
    chunk_size: int,
    brentq_n_max: int,
    h_lo: float,
    h_hi: float,
    formula: str = 'cross',
    tol: float = 1e-6,
    eps: float = 0.1,
    tau: float = 0.2,
    use_fft: bool = False,
) -> tuple[float, float]:
    """Find h_crit via hard-mode-count bisection (monotone, no false roots).

    The objective is the hard (discrete) mode count. On the direct-KDE path
    this is the number of + → (≤ 0) sign changes of f' on ``grid``; on the FFT
    path it is ``fft_mode_count``. The count is non-increasing in h for a
    Gaussian KDE (Silverman 1981). Bisection therefore targets the smallest h
    with count <= target_modes and cannot lock onto a false root.

    After the root is found, |∂M̃/∂h| is estimated at h_crit by a central
    finite difference of the soft mode count (step h_crit·1e-4). The result
    is for diagnostic purposes only.

    Parameters
    ----------
    X : Tensor, shape (n,)
    grid : Tensor, shape (G,)
        Evaluation grid for the direct-KDE count and the condition number.
    target_modes : int
    chunk_size : int
        Chunk size forwarded to ``kde_derivatives_chunked``.
    brentq_n_max : int
        On the direct-KDE path, X is randomly subsampled to this size for the
        bisection loop. It also caps the size used for the condition number.
    h_lo, h_hi : float
        Initial search bracket. It is widened automatically if it does not
        straddle the transition.
    formula : str
        Soft-mode-count formula used only for the condition number
        (default 'cross').
    tol : float
        Absolute bisection tolerance on h (default 1e-6).
    eps, tau : float
        Soft mode count parameters for the condition-number diagnostic.
    use_fft : bool
        If True (Round 18b) and n > brentq_n_max, use FFT-based mode counting
        on the full X (no subsampling, O(n + G log G)). Otherwise use the
        chunked KDE, on a subsample of size brentq_n_max when n is larger.

    Returns
    -------
    (h_crit, cond_num) : (float, float)
        h_crit: the critical bandwidth (smallest h with hard count <= target_modes).
        cond_num: |∂M̃/∂h| at h_crit (large = well-conditioned IFT). It is 0.0
        when h_crit is 0, because the finite difference is undefined there.

    Warns
    -----
    UserWarning
        When the direct-KDE path has to subsample X (n > brentq_n_max).
    """
    mode_fn = _get_mode_count_fn(formula)

    with torch.no_grad():
        n = X.shape[0]
        # FFT is only beneficial (and reliable) when n > brentq_n_max.
        # For small n the histogram is too sparse (n/G < 1) and produces
        # spurious sign changes. Fall back to direct KDE — there is no
        # subsampling bias to fix when n ≤ brentq_n_max anyway.
        use_fft_effective = use_fft and (n > brentq_n_max)
        if not use_fft_effective and n > brentq_n_max:
            idx = torch.randperm(n, device=X.device)[:brentq_n_max]
            X_sub = X[idx]
            bias_factor = (brentq_n_max / n) ** (-0.2)
            warnings.warn(
                f"DCB: n={n} > brentq_n_max={brentq_n_max}. "
                f"h_crit estimated on {brentq_n_max}-point subsample; "
                f"expected upward bias ~{bias_factor:.2f}x vs full-data h_crit. "
                "Use use_fft=True to eliminate subsampling bias.",
                UserWarning,
                stacklevel=4,
            )
        else:
            X_sub = X

        if use_fft_effective:
            # The adaptive FFT grid size is fixed from the *initial* h_hi,
            # before any bracket growth.
            sigma = X.std().item()
            if sigma == 0.0:
                sigma = 1.0
            lo_domain = X.min().item() - 3 * sigma
            hi_domain = X.max().item() + 3 * sigma
            G_fft = adaptive_fft_G(hi_domain - lo_domain, h_hi)

            def count_fn(h: float) -> int:
                # FFT mode count on the full (unsubsampled) data.
                return fft_mode_count(X, h, G=G_fft)
        else:
            def count_fn(h: float) -> int:
                # Hard mode count from the chunked direct-KDE first derivative.
                _, fp, _ = kde_derivatives_chunked(X_sub, h, grid, chunk_size)
                return hard_mode_count(fp, grid)

        h_crit = _bracket_and_bisect(count_fn, h_lo, h_hi, target_modes, tol)

    # Condition number: |∂M̃/∂h| via central finite difference at h_crit.
    dh = h_crit * 1e-4
    if dh == 0.0:
        # h_crit == 0 (e.g. a degenerate zero-width bracket). The central
        # difference would divide by zero, so report it as ill-conditioned.
        return h_crit, 0.0

    # When use_fft=True, X_sub == X (full n points). soft_mode_count_cross builds an
    # (n × G) matrix, so we cap to brentq_n_max to avoid O(n×G) OOM at large n.
    with torch.no_grad():
        X_cond = X_sub
        if X_sub.shape[0] > brentq_n_max:
            idx_cond = torch.randperm(X_sub.shape[0], device=X_sub.device)[:brentq_n_max]
            X_cond = X_sub[idx_cond]
        m_plus = mode_fn(X_cond, h_crit + dh, grid, eps, tau).item()
        m_minus = mode_fn(X_cond, h_crit - dh, grid, eps, tau).item()
    cond_num = abs((m_plus - m_minus) / (2 * dh))

    return h_crit, float(cond_num)
242
+
243
+
244
def sg(u: Tensor, delta: float = 1e-4) -> Tensor:
    """Sign-preserving magnitude floor, used as a safe IFT denominator.

    Elementwise result is ``sign(u) * max(|u|, delta)``. Exact zeros
    (including ``-0.0``) map to ``+delta`` instead of ``0``.

    Parameters
    ----------
    u : Tensor
        Values to stabilise.
    delta : float
        Smallest magnitude allowed in the output, default 1e-4.

    Returns
    -------
    Tensor
        Tensor with the same shape as ``u``.
    """
    floored = u.abs().clamp(min=delta)
    # Only strictly negative entries flip sign; zeros keep the positive floor.
    return torch.where(u < 0, -floored, floored)
266
+
267
+
268
def _get_mode_count_fn(formula: str):
    """Map a formula name to its soft-mode-count implementation.

    ``'cross'`` selects ``soft_mode_count_cross``. Any other value falls back
    to ``soft_mode_count`` (the legacy ``'integral'`` formula).
    """
    return soft_mode_count_cross if formula == 'cross' else soft_mode_count
272
+
273
+
274
def find_h_crit(
    X: Tensor,
    grid: Tensor,
    eps: float,
    tau: float,
    target_modes: int = 1,
    h_lo: float | None = None,
    h_hi: float | None = None,
    formula: str = 'cross',
    brentq_n_max: int = 50_000,
    chunk_size: int = 50_000,
    g_brentq: int = 128,
    use_hard_bisection: bool = True,
    use_fft: bool = True,
) -> tuple[float, float]:
    """Find h_crit and return (h_crit, condition_number).

    When use_hard_bisection=True (default, Round 15b), dispatches to
    `find_h_crit_hard`, which bisects on the hard (discrete) mode count.
    That count is non-increasing in h (Silverman 1981), so there are no
    false roots.

    When use_hard_bisection=False, locates the root of
    G(h, X) = M̃(h; X) - threshold = 0 via Brent's method on a coarse grid
    of g_brentq points spanning [grid[0], grid[-1]] (Round 15a). For
    formula='cross' the threshold is target_modes + 0.5; otherwise it is
    target_modes. On this path X is randomly subsampled to brentq_n_max
    points when it is larger. If no sign change can be bracketed, the
    function returns the bracket end and a condition number of 0.0.

    Also returns |∂M̃/∂h| at h_crit as a condition-number diagnostic:
    values near zero indicate bifurcation instability (IFT denominator small).

    Parameters
    ----------
    X : Tensor, shape (n,)
    grid : Tensor, shape (G,)
    eps, tau : float
    target_modes : int
    h_lo, h_hi : float or None
        Search bracket. Defaults are 1e-3·std(X) and 3·std(X). When std(X)
        is zero or not finite (constant data, n < 2), a unit scale is used
        instead so the default bracket never collapses to h = 0.
    formula : str
        'cross' (default) — use M̃_cross (crossing-count, fixes m=1 bias).
        'integral' — use original M̃ (legacy, known m=1 failure).
    brentq_n_max : int
        Subsample size for the root search (see find_h_crit_hard for the
        hard-bisection path).
    chunk_size : int
        Chunk size for the chunked KDE on the hard-bisection path.
    g_brentq : int
        Grid resolution for the brentq objective (default 128). Ignored when
        use_hard_bisection=True.
    use_hard_bisection : bool
        If True (default), use hard-mode-count bisection (no false roots).
        If False, use legacy brentq on M̃ with coarse grid g_brentq.
    use_fft : bool
        Default True. Uses FFT-based mode counting (O(n + G log G)) on the
        hard-bisection path when n > brentq_n_max, eliminating subsampling
        bias. Falls back to direct KDE otherwise (no bias at small n). Set
        False only for legacy/ablation comparison.

    Returns
    -------
    (h_crit, cond_num) : (float, float)
        h_crit: the critical bandwidth.
        cond_num: |∂M̃/∂h| at h_crit (large = well-conditioned IFT).
    """
    with torch.no_grad():
        std_val = X.std().item()
    if not math.isfinite(std_val) or std_val <= 0.0:
        # Constant data gives std 0, and a single sample gives NaN. Either
        # would make the default bracket h = 0 or NaN. KDE evaluation at h = 0
        # is undefined, and the finite-difference condition number would
        # divide by zero. Fall back to unit scale, as the FFT path of
        # find_h_crit_hard does.
        std_val = 1.0
    if h_lo is None:
        h_lo = 1e-3 * std_val
    if h_hi is None:
        h_hi = 3.0 * std_val

    if use_hard_bisection:
        return find_h_crit_hard(
            X, grid, target_modes, chunk_size, brentq_n_max,
            h_lo, h_hi, formula=formula, eps=eps, tau=tau,
            use_fft=use_fft,
        )

    from scipy.optimize import brentq

    mode_fn = _get_mode_count_fn(formula)

    with torch.no_grad():
        n = X.shape[0]
        if n > brentq_n_max:
            idx = torch.randperm(n, device=X.device)[:brentq_n_max]
            X_brentq = X[idx]
        else:
            X_brentq = X

        # Build a coarse grid for brentq objective evaluation (Round 15a).
        # g_brentq=128 gives 4× fewer KDE evaluations than G=512 with negligible
        # root-location error — sign detection only needs coarse resolution.
        grid_coarse = torch.linspace(
            grid[0].item(), grid[-1].item(), g_brentq,
            dtype=grid.dtype, device=grid.device,
        )

        # For 'cross': M̃_cross has a soft floor ~(target_modes + 0.001) in the
        # target-mode regime, so a hard threshold of target_modes never cleanly
        # separates the transition from the plateau. Using target_modes + 0.5
        # places the root at the center of the sharp mode-merging transition.
        # For 'integral': the formula has no such floor; use target_modes directly.
        threshold = target_modes + (0.5 if formula == 'cross' else 0)

        def f(h: float) -> float:
            return mode_fn(X_brentq, h, grid_coarse, eps, tau).item() - threshold

        f_lo = f(h_lo)
        f_hi = f(h_hi)

        if f_lo <= 0:
            # h_lo is not in the multi-mode regime: scan upwards (×1.5 steps,
            # at most 40, never reaching h_hi) for a point with f > 0.
            h_scan = h_lo
            for _ in range(40):
                h_scan *= 1.5
                if h_scan >= h_hi:
                    break
                f_scan = f(h_scan)
                if f_scan > 0:
                    h_lo = h_scan
                    f_lo = f_scan
                    break
            else:
                return h_lo, 0.0
            if f_lo <= 0:
                return h_lo, 0.0

        if f_hi > 0:
            # One attempt to widen the upper end of the bracket.
            h_hi *= 3
            f_hi = f(h_hi)
            if f_hi > 0:
                return h_hi, 0.0

        h_crit = brentq(f, h_lo, h_hi)

    # Condition number: |∂M̃/∂h| via central finite difference at h_crit
    dh = h_crit * 1e-4
    if dh == 0.0:
        # Only reachable with a caller-supplied bracket touching 0; avoid
        # a ZeroDivisionError and report the IFT as ill-conditioned.
        return float(h_crit), 0.0
    with torch.no_grad():
        m_plus = mode_fn(X_brentq, h_crit + dh, grid_coarse, eps, tau).item()
        m_minus = mode_fn(X_brentq, h_crit - dh, grid_coarse, eps, tau).item()
    cond_num = abs((m_plus - m_minus) / (2 * dh))

    return float(h_crit), float(cond_num)
410
+
411
+
412
def _analytical_dM_dX(
    X: Tensor,
    h: float,
    grid: Tensor,
    dM_dfp: Tensor,
    dM_dfpp: Tensor,
    chunk_size: int,
) -> Tensor:
    """Chunked analytical ∂M̃/∂X for a Gaussian KDE via the chain rule.

    Avoids materialising the full n×G autograd graph. Only a
    (chunk_size × G) block is alive at any time. The KDE is taken to be
    f(x) = (1/n) Σ_i K_h(x − X_i) with the Gaussian kernel built below.
    With diff_ij = grid_j − X_i and K_ij = φ(diff_ij / h) / h:

        ∂f'(grid_j)/∂X_i  = (1/n) K_ij · (1/h² − diff²_ij/h⁴)
        ∂f''(grid_j)/∂X_i = (1/n) K_ij · diff_ij/h⁴ · (diff²_ij/h² − 3)

    and therefore
        ∂M̃/∂X_i = Σ_j [∂M̃/∂f'_j · ∂f'_j/∂X_i + ∂M̃/∂f''_j · ∂f''_j/∂X_i].

    Parameters
    ----------
    X : Tensor, shape (n,)
        Data points.
    h : float
        Bandwidth (must be non-zero).
    grid : Tensor, shape (G,)
        Grid on which f' and f'' were evaluated.
    dM_dfp, dM_dfpp : Tensor, shape (G,)
        ∂M̃/∂f' and ∂M̃/∂f'' at each grid point (broadcast along rows).
    chunk_size : int
        Number of data points processed per chunk.

    Returns
    -------
    Tensor, shape (n,)
        ∂M̃/∂X_i for every data point.
    """
    n = X.shape[0]
    h2 = h * h
    h4 = h2 * h2
    kernel_norm = math.sqrt(2 * math.pi) * h  # loop-invariant Gaussian normaliser
    out = torch.zeros(n, dtype=X.dtype, device=X.device)
    with torch.no_grad():
        for start in range(0, n, chunk_size):
            stop = start + chunk_size
            Xc = X[start:stop]
            diff = grid.unsqueeze(0) - Xc.unsqueeze(1)  # (c, G)
            diff2 = diff ** 2
            K = torch.exp(-0.5 * (diff / h) ** 2) / kernel_norm
            # ∂f'_j/∂X_i (without the 1/n factor)
            gfp = K * (1.0 / h2 - diff2 / h4)  # (c, G)
            # ∂f''_j/∂X_i (without the 1/n factor)
            gfpp = K * diff / h4 * (diff2 / h2 - 3.0)  # (c, G)
            out[start:stop] = (dM_dfp * gfp + dM_dfpp * gfpp).sum(1)
    return out / n
445
+
446
+
447
def ift_gradient(
    X: Tensor,
    h_crit: float,
    grid: Tensor,
    eps: float,
    tau: float,
    grad_output: Tensor,
    delta: float = 1e-4,
    formula: str = 'cross',
    chunk_size: int = 50_000,
    analytical_n_thresh: int = 10_000,
    safe_backward: bool = False,
) -> Tensor:
    """Compute the IFT gradient ∂h_crit/∂X and chain it with the upstream grad.

    Implements the Implicit Function Theorem formula

        ∂h_crit/∂X = -(∂M̃/∂h)^{-1} · (∂M̃/∂X)   evaluated at h = h_crit

    with the denominator stabilized by ``sg``. Two evaluation paths exist.

    * ``formula == 'cross'`` and ``n > analytical_n_thresh`` (large-n path):
      - ∂M̃/∂f' and ∂M̃/∂f'' come from autograd on G-length tensors.
      - ∂M̃/∂h comes from a central finite difference (step h_crit·1e-4).
      - ∂M̃/∂X comes from the chunked chain rule in ``_analytical_dM_dX``.
      The supplied ``grid``, ``eps`` and ``tau`` are used as-is.
      NOTE(review): f is passed detached and the grid is treated as constant
      in X, unlike the small-n path. Confirm this approximation is intended.
    * Otherwise (small-n path): a single differentiable forward pass, then
      two ``torch.autograd.grad`` calls, for ∂M̃/∂h and ∂M̃/∂X.
      - The grid is rebuilt from X as
        ``linspace(min(X) - 3·std(X), max(X) + 3·std(X), G)``, so only G is
        taken from ``grid``. This presumably mirrors how the forward pass
        built its grid (TODO confirm).
      - For formula != 'cross', eps and tau are also rebuilt from X:
        0.1·std(f'₀) and 0.2·median|f''₀| at pilot bandwidth
        0.9·std(X)·n^(-1/5). The passed eps/tau are then ignored.

    Denominator handling:
    * A warning is issued when |∂M̃/∂h| < 0.01, or when ∂M̃/∂h is not finite.
    * With ``safe_backward=True``, |∂M̃/∂h| is floored at max(0.1, delta).
      An exact zero takes the sign of the raw value via ``math.copysign``.
    * Otherwise the floor is max(0.01, delta) when |∂M̃/∂h| < 0.01, and
      ``delta`` when it is not.

    Diagnostics are stored as attributes on this function:
    ``last_denom_abs``, ``last_denom_signed``, ``last_guard_triggered`` and
    ``last_grad_norm``.

    Parameters
    ----------
    X : Tensor, shape (n,)
        Observed data points (need not require grad on entry).
    h_crit : float
        Critical bandwidth found by find_h_crit.
    grid : Tensor, shape (G,)
        Uniform evaluation grid (see the path notes for how it is used).
    eps : float
        Width of the Gaussian delta approximation.
    tau : float
        Sigmoid temperature.
    grad_output : Tensor, shape ()
        Upstream scalar gradient from the loss w.r.t. h_crit.
    delta : float
        Stabilisation floor for the sg denominator. Default 1e-4.
    formula : str
        'cross' (default) or the legacy integral formula (any other value).
    chunk_size : int
        Chunk size for the large-n chunked KDE passes.
    analytical_n_thresh : int
        n above which the large-n analytical path is used (formula 'cross').
    safe_backward : bool
        Use the stricter denominator floor described above.

    Returns
    -------
    Tensor, shape (n,)
        Gradient of the loss w.r.t. X.
    """
    # torch.autograd.Function.backward runs under no_grad; we need enable_grad
    # so that the forward pass through soft_mode_count builds a graph that lets
    # us differentiate w.r.t. both h_tensor and X_req.
    #
    # Total-derivative IFT (small-n path): h_crit is a function of X both as
    # KDE data and as the source of the evaluation grid Ω(X). Rebuilding a
    # differentiable grid from X_req lets autograd see both paths.
    G = grid.shape[0]
    margin_sigma = 3.0  # grid margin in units of std(X) for the rebuilt grid
    eps_coeff = 0.1
    tau_coeff = 0.2
    n = X.shape[0]

    if n > analytical_n_thresh and formula == 'cross':
        # Large-n analytical path — O(chunk×G) peak memory.
        # Avoids building an n×G autograd graph; uses the KDE chain rule instead.

        # Step 1: KDE derivatives at h_crit (chunked, no graph)
        with torch.no_grad():
            f, fp, fpp = kde_derivatives_chunked(X, h_crit, grid, chunk_size)

        # Step 2: ∂M̃/∂fp and ∂M̃/∂fpp via G-dim autograd only
        with torch.enable_grad():
            fp_req = fp.detach().requires_grad_(True)
            fpp_req = fpp.detach().requires_grad_(True)
            M_val = soft_mode_count_cross_from_derivs(
                f.detach(), fp_req, fpp_req, grid, eps, tau
            )
            dM_dfp, dM_dfpp = torch.autograd.grad(
                M_val, [fp_req, fpp_req], retain_graph=False
            )

        # Step 3: ∂M̃/∂h via central finite differences (2 chunked KDE passes)
        dh = h_crit * 1e-4
        with torch.no_grad():
            f_p, fp_p, fpp_p = kde_derivatives_chunked(X, h_crit + dh, grid, chunk_size)
            f_m, fp_m, fpp_m = kde_derivatives_chunked(X, h_crit - dh, grid, chunk_size)
            M_plus = soft_mode_count_cross_from_derivs(f_p, fp_p, fpp_p, grid, eps, tau)
            M_minus = soft_mode_count_cross_from_derivs(f_m, fp_m, fpp_m, grid, eps, tau)
            dM_dh = torch.tensor(
                float((M_plus - M_minus) / (2 * dh)), dtype=X.dtype, device=X.device
            )

        # Step 4: ∂M̃/∂X analytically (chunked chain rule, O(chunk×G) memory)
        dM_dX = _analytical_dM_dX(X, h_crit, grid, dM_dfp, dM_dfpp, chunk_size)

    else:
        # Small-n autograd path — total-derivative IFT through data + grid + eps/tau.
        with torch.enable_grad():
            h_tensor = torch.tensor(h_crit, dtype=X.dtype, device=X.device, requires_grad=True)
            X_req = X.detach().requires_grad_(True)
            n_req = X_req.shape[0]

            std_x = X_req.std()
            lo = X_req.min() - margin_sigma * std_x
            hi = X_req.max() + margin_sigma * std_x
            t = torch.linspace(0.0, 1.0, G, dtype=X.dtype, device=X.device)
            grid_diff = lo + (hi - lo) * t

            mode_fn = _get_mode_count_fn(formula)
            if formula == 'cross':
                eps_diff = torch.tensor(eps, dtype=X.dtype, device=X.device)
                tau_diff = torch.tensor(tau, dtype=X.dtype, device=X.device)
            else:
                # Adaptive eps/tau rebuilt from X at a pilot bandwidth so their
                # dependence on X also enters ∂M̃/∂X.
                h0_diff = 0.9 * std_x * (n_req ** -0.2)
                u = (grid_diff.unsqueeze(0) - X_req.unsqueeze(1)) / h0_diff
                K0 = torch.exp(-0.5 * u ** 2) / (math.sqrt(2 * math.pi) * h0_diff)
                f_prime0 = (-u / h0_diff * K0).mean(dim=0)
                f_dbl0 = ((u ** 2 - 1.0) / h0_diff ** 2 * K0).mean(dim=0)
                eps_diff = eps_coeff * f_prime0.std()
                tau_diff = tau_coeff * f_dbl0.abs().median()

            M = mode_fn(X_req, h_tensor, grid_diff, eps_diff, tau_diff)
            dM_dh = torch.autograd.grad(M, h_tensor, retain_graph=True, create_graph=False)[0]
            dM_dX = torch.autograd.grad(M, X_req, retain_graph=False, create_graph=False)[0]

    # IFT formula with stabilized denominator + denom guard.
    DENOM_GUARD = 0.01
    SAFE_GUARD = 0.1  # 10× stricter clamp when safe_backward=True
    denom_abs = dM_dh.abs().item()
    denom_signed = dM_dh.item()  # raw signed value before any clamping
    ift_gradient.last_denom_abs = denom_abs
    ift_gradient.last_denom_signed = denom_signed
    ift_gradient.last_guard_triggered = denom_abs < DENOM_GUARD

    if not math.isfinite(denom_signed):
        # NaN/inf would otherwise flow silently into the returned gradient.
        warnings.warn(
            f"DCB IFT denominator ∂M̃/∂h={denom_signed} is not finite; "
            "the returned gradient is unreliable.",
            stacklevel=3,
        )
    elif denom_abs < DENOM_GUARD:
        if safe_backward:
            advice = "Denominator clamped (safe_backward=True)."
        else:
            advice = "Gradient may be large. Use safe_backward=True to clamp."
        warnings.warn(
            f"DCB IFT denominator |∂M̃/∂h|={denom_abs:.2e} < {DENOM_GUARD}. " + advice,
            stacklevel=3,
        )

    # Floors never fall below `delta`, so a larger user-supplied delta cannot
    # produce a non-monotone clamp. With the default delta these equal the
    # historical 0.1 / 0.01 / delta floors.
    safe_floor = max(SAFE_GUARD, delta)
    if safe_backward and denom_abs == 0.0:
        # Exact-zero edge case: sg(0, ·) always returns +floor, which can invert
        # the gradient sign if dM_dh is -0.0. copysign keeps the raw sign bit
        # (copysign(x, 0.0) = +x by IEEE convention).
        safe_denom = torch.tensor(
            math.copysign(safe_floor, denom_signed),
            dtype=X.dtype, device=X.device,
        )
        dh_dX = -(1.0 / safe_denom) * dM_dX
    else:
        if safe_backward:
            floor = safe_floor
        elif denom_abs < DENOM_GUARD:
            floor = max(DENOM_GUARD, delta)
        else:
            floor = delta
        dh_dX = -(1.0 / sg(dM_dh, floor)) * dM_dX

    grad_X = grad_output * dh_dX
    ift_gradient.last_grad_norm = float(grad_X.norm().item())
    return grad_X