diffcb 0.1.3__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {diffcb-0.1.3 → diffcb-0.1.4}/PKG-INFO +1 -1
- {diffcb-0.1.3 → diffcb-0.1.4}/dcb/__init__.py +1 -1
- {diffcb-0.1.3 → diffcb-0.1.4}/dcb/fft_kde.py +136 -59
- {diffcb-0.1.3 → diffcb-0.1.4}/dcb/layer.py +8 -5
- {diffcb-0.1.3 → diffcb-0.1.4}/dcb/solver.py +29 -8
- {diffcb-0.1.3 → diffcb-0.1.4}/pyproject.toml +1 -1
- {diffcb-0.1.3 → diffcb-0.1.4}/.gitignore +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/.zenodo.json +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/LICENSE +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/README.md +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/dcb/diagnostics.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/dcb/kde.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/dcb/utils.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/notebooks/.gitkeep +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/tests/test_kde.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/tests/test_layer.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/tests/test_r18c_denom_audit.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/tests/test_r18c_deprecation_warn.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/tests/test_r19_default_fft.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/tests/test_r19_diagnostics.py +0 -0
- {diffcb-0.1.3 → diffcb-0.1.4}/tests/test_solver.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: diffcb
|
|
3
|
-
Version: 0.1.3
|
|
3
|
+
Version: 0.1.4
|
|
4
4
|
Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
|
|
5
5
|
Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
6
6
|
Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
|
|
@@ -19,11 +19,142 @@ import torch
|
|
|
19
19
|
from torch import Tensor
|
|
20
20
|
|
|
21
21
|
|
|
22
|
+
# Worker 2: device-native histogram
|
|
23
|
+
def _histogram_on_device(X: Tensor, G: int, lo: float, hi: float) -> Tensor:
|
|
24
|
+
"""Compute a G-bin histogram of X on the same device as X."""
|
|
25
|
+
device = X.device
|
|
26
|
+
if device.type == 'cuda':
|
|
27
|
+
return torch.histc(X.float(), bins=G, min=lo, max=hi)
|
|
28
|
+
elif device.type == 'mps':
|
|
29
|
+
bin_idx = ((X.float() - lo) * (G / (hi - lo))).long().clamp_(0, G - 1)
|
|
30
|
+
counts = torch.zeros(G, dtype=torch.float32, device=device)
|
|
31
|
+
counts.scatter_add_(0, bin_idx, torch.ones(X.shape[0], dtype=torch.float32, device=device))
|
|
32
|
+
return counts
|
|
33
|
+
else: # cpu
|
|
34
|
+
X_cpu = X.float()
|
|
35
|
+
edges = torch.linspace(lo, hi, G + 1)
|
|
36
|
+
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
|
|
37
|
+
return torch.bincount(bin_idx, minlength=G).float()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def precompute_fft(
|
|
41
|
+
X: Tensor,
|
|
42
|
+
G: int = 4096,
|
|
43
|
+
domain: tuple[float, float] | None = None,
|
|
44
|
+
pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
45
|
+
fft_dtype: torch.dtype = torch.float32, # Worker 3: float32 FFT
|
|
46
|
+
) -> tuple[Tensor, Tensor, tuple[float, float]]:
|
|
47
|
+
"""Precompute the FFT of the zero-padded histogram of X.
|
|
48
|
+
|
|
49
|
+
This is the bandwidth-independent work shared across a bisection loop on
|
|
50
|
+
h: build the histogram, zero-pad, take rfft, and build the frequency grid
|
|
51
|
+
omega. The per-step kernel K(omega, h) = i*omega*exp(-0.5*(omega*h)**2)
|
|
52
|
+
must be combined with C inside `mode_count_from_C`.
|
|
53
|
+
|
|
54
|
+
Parameters
|
|
55
|
+
----------
|
|
56
|
+
X : Tensor, shape (n,)
|
|
57
|
+
G : int
|
|
58
|
+
Number of histogram bins.
|
|
59
|
+
domain : (lo, hi) or None
|
|
60
|
+
If provided, use as histogram domain; otherwise computed from X
|
|
61
|
+
with a 3*sigma margin.
|
|
62
|
+
pad_factor : int
|
|
63
|
+
Zero-padding multiplier (default 2).
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
C : Tensor, shape (N//2+1,), complex64 if fft_dtype is float32 else complex128
|
|
68
|
+
rfft of the zero-padded histogram, computed in fft_dtype. Empty tensor (degenerate
|
|
69
|
+
zero-range domain) signals the caller to short-circuit to 1 mode.
|
|
70
|
+
omega : Tensor, shape (N//2+1,), dtype fft_dtype
|
|
71
|
+
Angular frequency grid for the FFT.
|
|
72
|
+
domain : (lo, hi)
|
|
73
|
+
Domain tuple actually used.
|
|
74
|
+
"""
|
|
75
|
+
with torch.no_grad():
|
|
76
|
+
if domain is not None:
|
|
77
|
+
lo, hi = domain
|
|
78
|
+
else:
|
|
79
|
+
sigma = X.std().item()
|
|
80
|
+
if sigma == 0.0:
|
|
81
|
+
sigma = 1.0
|
|
82
|
+
lo = X.min().item() - 3 * sigma
|
|
83
|
+
hi = X.max().item() + 3 * sigma
|
|
84
|
+
data_range = hi - lo
|
|
85
|
+
|
|
86
|
+
if data_range == 0.0:
|
|
87
|
+
complex_dtype = torch.complex64 if fft_dtype == torch.float32 else torch.complex128
|
|
88
|
+
empty = torch.zeros(0, dtype=complex_dtype, device=X.device)
|
|
89
|
+
empty_omega = torch.zeros(0, dtype=fft_dtype, device=X.device)
|
|
90
|
+
return empty, empty_omega, (lo, hi)
|
|
91
|
+
|
|
92
|
+
# Histogram (O(n)) — device-native dispatch.
|
|
93
|
+
counts = _histogram_on_device(X, G, lo, hi)
|
|
94
|
+
|
|
95
|
+
N = pad_factor * G
|
|
96
|
+
counts_padded = torch.zeros(N, dtype=fft_dtype, device=X.device)
|
|
97
|
+
counts_padded[:G] = counts.to(fft_dtype)
|
|
98
|
+
|
|
99
|
+
C = torch.fft.rfft(counts_padded)
|
|
100
|
+
|
|
101
|
+
bin_width = data_range / G
|
|
102
|
+
k = torch.arange(N // 2 + 1, device=X.device, dtype=fft_dtype)
|
|
103
|
+
omega = 2 * math.pi * k / (N * bin_width)
|
|
104
|
+
|
|
105
|
+
return C, omega, (lo, hi)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def mode_count_from_C(
|
|
109
|
+
C: Tensor,
|
|
110
|
+
omega: Tensor,
|
|
111
|
+
h: float,
|
|
112
|
+
G: int,
|
|
113
|
+
N: int,
|
|
114
|
+
) -> int:
|
|
115
|
+
"""Per-step mode count: apply Gaussian derivative kernel and count sign changes.
|
|
116
|
+
|
|
117
|
+
Cheap inner loop body for bisection — only the kernel depends on h.
|
|
118
|
+
|
|
119
|
+
Parameters
|
|
120
|
+
----------
|
|
121
|
+
C : Tensor, shape (N//2+1,), complex
|
|
122
|
+
rfft of the zero-padded histogram (from `precompute_fft`).
|
|
123
|
+
omega : Tensor, shape (N//2+1,), real dtype matching the fft_dtype used in `precompute_fft`
|
|
124
|
+
Frequency grid (from `precompute_fft`).
|
|
125
|
+
h : float
|
|
126
|
+
Bandwidth.
|
|
127
|
+
G : int
|
|
128
|
+
Histogram bin count.
|
|
129
|
+
N : int
|
|
130
|
+
Padded FFT length (pad_factor * G).
|
|
131
|
+
|
|
132
|
+
Returns
|
|
133
|
+
-------
|
|
134
|
+
int
|
|
135
|
+
Number of KDE modes.
|
|
136
|
+
"""
|
|
137
|
+
if C.numel() == 0:
|
|
138
|
+
return 1 # degenerate single-point distribution
|
|
139
|
+
|
|
140
|
+
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
141
|
+
f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).real
|
|
142
|
+
f_prime = f_prime_padded[:G]
|
|
143
|
+
|
|
144
|
+
nonzero_mask = f_prime != 0
|
|
145
|
+
if not nonzero_mask.any():
|
|
146
|
+
return 0
|
|
147
|
+
|
|
148
|
+
s = f_prime[nonzero_mask]
|
|
149
|
+
transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
|
|
150
|
+
return transitions
|
|
151
|
+
|
|
152
|
+
|
|
22
153
|
def fft_mode_count(
|
|
23
154
|
X: Tensor,
|
|
24
155
|
h: float,
|
|
25
156
|
G: int = 4096,
|
|
26
|
-
pad_factor: int = 4,
|
|
157
|
+
pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
27
158
|
domain: tuple[float, float] | None = None,
|
|
28
159
|
) -> int:
|
|
29
160
|
"""Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
|
|
@@ -58,60 +189,9 @@ def fft_mode_count(
|
|
|
58
189
|
Number of KDE modes (downward zero-crossings of f').
|
|
59
190
|
"""
|
|
60
191
|
with torch.no_grad():
|
|
61
|
-
|
|
62
|
-
lo, hi = domain
|
|
63
|
-
else:
|
|
64
|
-
# Domain: extend 3σ beyond data range to avoid boundary effects
|
|
65
|
-
sigma = X.std().item()
|
|
66
|
-
if sigma == 0.0:
|
|
67
|
-
sigma = 1.0 # degenerate case: all points identical
|
|
68
|
-
lo = X.min().item() - 3 * sigma
|
|
69
|
-
hi = X.max().item() + 3 * sigma
|
|
70
|
-
data_range = hi - lo
|
|
71
|
-
|
|
72
|
-
if data_range == 0.0:
|
|
73
|
-
return 1 # single-point distribution has 1 mode
|
|
74
|
-
|
|
75
|
-
# Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
|
|
76
|
-
# torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
|
|
77
|
-
# MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
|
|
78
|
-
# the binning step avoids the intermediate and is numerically identical
|
|
79
|
-
# for data within [lo, hi] (guaranteed by the 3σ domain extension above).
|
|
80
|
-
X_cpu = X.float().cpu()
|
|
81
|
-
edges = torch.linspace(lo, hi, G + 1) # (G+1,) CPU
|
|
82
|
-
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1 # 0-indexed
|
|
83
|
-
counts = torch.bincount(bin_idx, minlength=G).float().to(X.device) # back to device
|
|
84
|
-
|
|
85
|
-
# Zero-pad to pad_factor*G — promote to float64 for FFT precision
|
|
192
|
+
C, omega, _ = precompute_fft(X, G=G, domain=domain, pad_factor=pad_factor)
|
|
86
193
|
N = pad_factor * G
|
|
87
|
-
|
|
88
|
-
counts_padded[:G] = counts.double()
|
|
89
|
-
|
|
90
|
-
# FFT of histogram (float64)
|
|
91
|
-
C = torch.fft.rfft(counts_padded)
|
|
92
|
-
|
|
93
|
-
# Derivative kernel in frequency domain (float64)
|
|
94
|
-
bin_width = data_range / G
|
|
95
|
-
k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float64)
|
|
96
|
-
omega = 2 * math.pi * k / (N * bin_width)
|
|
97
|
-
K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
|
|
98
|
-
|
|
99
|
-
# Convolve and back-transform; cast result back to float32
|
|
100
|
-
f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).float()
|
|
101
|
-
|
|
102
|
-
# Trim to original G grid (discard zero-padded tail)
|
|
103
|
-
f_prime = f_prime_padded[:G]
|
|
104
|
-
|
|
105
|
-
# Count (+→-) sign changes = number of modes
|
|
106
|
-
# A mode is a local max of f, i.e., f' crosses zero from + to -
|
|
107
|
-
# Remove zeros (flat segments) — carry forward last nonzero sign
|
|
108
|
-
nonzero_mask = f_prime != 0
|
|
109
|
-
if not nonzero_mask.any():
|
|
110
|
-
return 0
|
|
111
|
-
|
|
112
|
-
s = f_prime[nonzero_mask]
|
|
113
|
-
transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
|
|
114
|
-
return transitions
|
|
194
|
+
return mode_count_from_C(C, omega, h, G, N)
|
|
115
195
|
|
|
116
196
|
|
|
117
197
|
def _refine_hcrit(
|
|
@@ -121,7 +201,7 @@ def _refine_hcrit(
|
|
|
121
201
|
G: int,
|
|
122
202
|
domain: tuple[float, float],
|
|
123
203
|
target_modes: int = 1,
|
|
124
|
-
pad_factor: int = 4
|
|
204
|
+
pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
125
205
|
) -> float:
|
|
126
206
|
"""Sub-bin quadratic refinement of h_crit after bisection converges.
|
|
127
207
|
|
|
@@ -162,10 +242,7 @@ def _refine_hcrit(
|
|
|
162
242
|
|
|
163
243
|
# Pre-compute histogram once; reuse C (FFT of counts) for all h evaluations.
|
|
164
244
|
with torch.no_grad():
|
|
165
|
-
|
|
166
|
-
edges = torch.linspace(lo_d, hi_d, G + 1)
|
|
167
|
-
bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
|
|
168
|
-
counts = torch.bincount(bin_idx, minlength=G).float()
|
|
245
|
+
counts = _histogram_on_device(X, G, lo_d, hi_d).cpu()
|
|
169
246
|
counts_padded = torch.zeros(N, dtype=torch.float64)
|
|
170
247
|
counts_padded[:G] = counts.double()
|
|
171
248
|
C = torch.fft.rfft(counts_padded)
|
|
@@ -35,13 +35,14 @@ class DCBFunction(torch.autograd.Function):
|
|
|
35
35
|
|
|
36
36
|
@staticmethod
|
|
37
37
|
def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
|
|
38
|
-
brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min
|
|
38
|
+
brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min,
|
|
39
|
+
fft_dtype):
|
|
39
40
|
"""Locate h_crit and save state for the backward pass."""
|
|
40
41
|
h_crit, cond_num = find_h_crit(
|
|
41
42
|
X, grid, eps, tau, target_modes,
|
|
42
43
|
formula=formula, brentq_n_max=brentq_n_max, chunk_size=chunk_size,
|
|
43
44
|
g_brentq=g_brentq, use_hard_bisection=use_hard_bisection,
|
|
44
|
-
use_fft=use_fft, G_min=fft_G_min,
|
|
45
|
+
use_fft=use_fft, G_min=fft_G_min, fft_dtype=fft_dtype,
|
|
45
46
|
)
|
|
46
47
|
ctx.save_for_backward(X, grid)
|
|
47
48
|
ctx.h_crit = h_crit
|
|
@@ -67,8 +68,8 @@ class DCBFunction(torch.autograd.Function):
|
|
|
67
68
|
ctx.denom_abs = ift_gradient.last_denom_abs
|
|
68
69
|
# Gradients for: X, grid, eps, tau, target_modes, delta, formula,
|
|
69
70
|
# chunk_size, brentq_n_max, g_brentq, use_hard_bisection,
|
|
70
|
-
# safe_backward, use_fft, fft_G_min
|
|
71
|
-
return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None
|
|
71
|
+
# safe_backward, use_fft, fft_G_min, fft_dtype
|
|
72
|
+
return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
|
72
73
|
|
|
73
74
|
|
|
74
75
|
class DCBLayer(nn.Module):
|
|
@@ -170,6 +171,7 @@ class DCBLayer(nn.Module):
|
|
|
170
171
|
max_n_exact: int | None = 1_000_000,
|
|
171
172
|
sketch_size: int = 500_000,
|
|
172
173
|
fft_G_min: int = 16384,
|
|
174
|
+
fft_dtype: torch.dtype = torch.float32,
|
|
173
175
|
):
|
|
174
176
|
super().__init__()
|
|
175
177
|
self.target_modes = target_modes
|
|
@@ -189,6 +191,7 @@ class DCBLayer(nn.Module):
|
|
|
189
191
|
self.max_n_exact = max_n_exact
|
|
190
192
|
self.sketch_size = sketch_size
|
|
191
193
|
self.fft_G_min = fft_G_min
|
|
194
|
+
self.fft_dtype = fft_dtype
|
|
192
195
|
if use_fft and brentq_n_max != 50_000:
|
|
193
196
|
raise TypeError(
|
|
194
197
|
f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
|
|
@@ -259,7 +262,7 @@ class DCBLayer(nn.Module):
|
|
|
259
262
|
return DCBFunction.apply(
|
|
260
263
|
X, grid, eps_eff, tau_eff, self.target_modes, self.delta, self.formula,
|
|
261
264
|
self.chunk_size, self.brentq_n_max, self.g_brentq, self.use_hard_bisection,
|
|
262
|
-
self.safe_backward, self.use_fft, self.fft_G_min,
|
|
265
|
+
self.safe_backward, self.use_fft, self.fft_G_min, self.fft_dtype,
|
|
263
266
|
)
|
|
264
267
|
|
|
265
268
|
|
|
@@ -37,7 +37,7 @@ from dcb.kde import (
|
|
|
37
37
|
soft_mode_count_cross_from_derivs,
|
|
38
38
|
kde_derivatives_chunked,
|
|
39
39
|
)
|
|
40
|
-
from dcb.fft_kde import fft_mode_count, adaptive_fft_G
|
|
40
|
+
from dcb.fft_kde import fft_mode_count, adaptive_fft_G, precompute_fft, mode_count_from_C
|
|
41
41
|
|
|
42
42
|
_AUTO_FFT_THRESHOLD = 50_000 # n above which FFT bisection activates (use_fft_effective)
|
|
43
43
|
|
|
@@ -75,6 +75,7 @@ def find_h_crit_hard(
|
|
|
75
75
|
tau: float = 0.2,
|
|
76
76
|
use_fft: bool = False,
|
|
77
77
|
G_min: int = 16384,
|
|
78
|
+
fft_dtype: torch.dtype = torch.float32,
|
|
78
79
|
) -> tuple[float, float]:
|
|
79
80
|
"""Find h_crit via hard-mode-count bisection (monotone, no false roots).
|
|
80
81
|
|
|
@@ -154,38 +155,54 @@ def find_h_crit_hard(
|
|
|
154
155
|
data_range = hi_domain - lo_domain
|
|
155
156
|
G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
|
|
156
157
|
_domain = (lo_domain, hi_domain)
|
|
158
|
+
pad_factor = 2 # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
|
|
159
|
+
N = pad_factor * G_fft
|
|
157
160
|
|
|
158
161
|
with torch.no_grad():
|
|
162
|
+
# Worker 1: precomputed C — hoist histogram + rfft out of bisection.
|
|
163
|
+
# Worker 3: float32 FFT by default — 2× faster; _refine_hcrit uses float64 independently.
|
|
164
|
+
C, omega, _domain = precompute_fft(
|
|
165
|
+
X, G=G_fft, domain=_domain, pad_factor=pad_factor, fft_dtype=fft_dtype,
|
|
166
|
+
)
|
|
167
|
+
|
|
159
168
|
# Verify bracket using FFT mode count on full X
|
|
160
|
-
count_lo =
|
|
169
|
+
count_lo = mode_count_from_C(C, omega, h_lo, G_fft, N)
|
|
161
170
|
if count_lo <= target_modes:
|
|
162
171
|
h_lo_try = h_lo
|
|
163
172
|
for _ in range(30):
|
|
164
173
|
h_lo_try *= 0.5
|
|
165
174
|
if h_lo_try < 1e-10:
|
|
166
175
|
break
|
|
167
|
-
if
|
|
176
|
+
if mode_count_from_C(C, omega, h_lo_try, G_fft, N) > target_modes:
|
|
168
177
|
h_lo = h_lo_try
|
|
169
178
|
break
|
|
170
179
|
|
|
171
|
-
count_hi =
|
|
180
|
+
count_hi = mode_count_from_C(C, omega, h_hi, G_fft, N)
|
|
172
181
|
if count_hi > target_modes:
|
|
173
182
|
for _ in range(30):
|
|
174
183
|
h_hi *= 2.0
|
|
175
|
-
if
|
|
184
|
+
if mode_count_from_C(C, omega, h_hi, G_fft, N) <= target_modes:
|
|
176
185
|
break
|
|
177
186
|
|
|
178
|
-
#
|
|
187
|
+
# Adaptive bisection: stop when bracket is localised (relative width < 1e-7)
|
|
188
|
+
# _refine_hcrit provides sub-bin precision afterwards — no need to over-bisect.
|
|
179
189
|
lo, hi = h_lo, h_hi
|
|
180
190
|
for _ in range(50):
|
|
181
191
|
mid = (lo + hi) / 2.0
|
|
182
|
-
count =
|
|
192
|
+
count = mode_count_from_C(C, omega, mid, G_fft, N)
|
|
183
193
|
if count <= target_modes:
|
|
184
194
|
hi = mid
|
|
185
195
|
else:
|
|
186
196
|
lo = mid
|
|
187
197
|
if (hi - lo) < tol:
|
|
188
198
|
break
|
|
199
|
+
# Worker 4: adaptive termination — stop when relative bracket width
|
|
200
|
+
# is small enough that further bisection cannot meaningfully shift
|
|
201
|
+
# _refine_hcrit's quadratic fit. Empirically 1e-7 preserves h_crit
|
|
202
|
+
# to within 1e-6 of the 50-step tol=1e-6 baseline while saving ~10
|
|
203
|
+
# bisection steps in typical cases.
|
|
204
|
+
if hi > 0 and (hi - lo) / hi < 1e-7:
|
|
205
|
+
break
|
|
189
206
|
|
|
190
207
|
h_crit = float(hi) # smallest h with count <= target_modes
|
|
191
208
|
|
|
@@ -222,6 +239,9 @@ def find_h_crit_hard(
|
|
|
222
239
|
break
|
|
223
240
|
|
|
224
241
|
# Standard bisection: 50 iterations → bracket width / 2^50
|
|
242
|
+
# NOTE: non-FFT path has no _refine_hcrit sub-bin refinement, so we keep
|
|
243
|
+
# tight bisection here for gradient stability (IFT test requires h_crit
|
|
244
|
+
# accurate well below FD perturbation delta=1e-3).
|
|
225
245
|
lo, hi = h_lo, h_hi
|
|
226
246
|
for _ in range(50):
|
|
227
247
|
mid = (lo + hi) / 2.0
|
|
@@ -297,6 +317,7 @@ def find_h_crit(
|
|
|
297
317
|
use_hard_bisection: bool = True,
|
|
298
318
|
use_fft: bool = True,
|
|
299
319
|
G_min: int = 16384,
|
|
320
|
+
fft_dtype: torch.dtype = torch.float32,
|
|
300
321
|
) -> tuple[float, float]:
|
|
301
322
|
"""Find h_crit and return (h_crit, condition_number).
|
|
302
323
|
|
|
@@ -350,7 +371,7 @@ def find_h_crit(
|
|
|
350
371
|
return find_h_crit_hard(
|
|
351
372
|
X, grid, target_modes, chunk_size, brentq_n_max,
|
|
352
373
|
h_lo, h_hi, formula=formula, eps=eps, tau=tau,
|
|
353
|
-
use_fft=use_fft, G_min=G_min,
|
|
374
|
+
use_fft=use_fft, G_min=G_min, fft_dtype=fft_dtype,
|
|
354
375
|
)
|
|
355
376
|
|
|
356
377
|
from scipy.optimize import brentq
|
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "diffcb"
|
|
7
|
-
version = "0.1.3"
|
|
7
|
+
version = "0.1.4"
|
|
8
8
|
description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
license = { file = "LICENSE" }
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|