diffcb 0.1.3__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: diffcb
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass.
5
5
  Project-URL: Homepage, https://github.com/ryZhangHason/differentiable-critical-bandwidth
6
6
  Project-URL: Repository, https://github.com/ryZhangHason/differentiable-critical-bandwidth
@@ -19,4 +19,4 @@ __all__ = [
19
19
  "DCBLayer", "DifferentiableCriticalBandwidth",
20
20
  "anneal_eps_tau", "soft_mode_count_cross", "soft_mode_count",
21
21
  ]
22
- __version__ = "0.1.3"
22
+ __version__ = "0.1.4"
@@ -19,11 +19,142 @@ import torch
19
19
  from torch import Tensor
20
20
 
21
21
 
22
+ # Worker 2: device-native histogram
23
+ def _histogram_on_device(X: Tensor, G: int, lo: float, hi: float) -> Tensor:
24
+ """Compute a G-bin histogram of X on the same device as X."""
25
+ device = X.device
26
+ if device.type == 'cuda':
27
+ return torch.histc(X.float(), bins=G, min=lo, max=hi)
28
+ elif device.type == 'mps':
29
+ bin_idx = ((X.float() - lo) * (G / (hi - lo))).long().clamp_(0, G - 1)
30
+ counts = torch.zeros(G, dtype=torch.float32, device=device)
31
+ counts.scatter_add_(0, bin_idx, torch.ones(X.shape[0], dtype=torch.float32, device=device))
32
+ return counts
33
+ else: # cpu
34
+ X_cpu = X.float()
35
+ edges = torch.linspace(lo, hi, G + 1)
36
+ bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
37
+ return torch.bincount(bin_idx, minlength=G).float()
38
+
39
+
40
+ def precompute_fft(
41
+ X: Tensor,
42
+ G: int = 4096,
43
+ domain: tuple[float, float] | None = None,
44
+ pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
45
+ fft_dtype: torch.dtype = torch.float32, # Worker 3: float32 FFT
46
+ ) -> tuple[Tensor, Tensor, tuple[float, float]]:
47
+ """Precompute the FFT of the zero-padded histogram of X.
48
+
49
+ This is the bandwidth-independent work shared across a bisection loop on
50
+ h: build the histogram, zero-pad, take rfft, and build the frequency grid
51
+ omega. The per-step kernel K(omega, h) = i*omega*exp(-0.5*(omega*h)**2)
52
+ must be combined with C inside `mode_count_from_C`.
53
+
54
+ Parameters
55
+ ----------
56
+ X : Tensor, shape (n,)
57
+ G : int
58
+ Number of histogram bins.
59
+ domain : (lo, hi) or None
60
+ If provided, use as histogram domain; otherwise computed from X
61
+ with a 3*sigma margin.
62
+ pad_factor : int
63
+ Zero-padding multiplier (default 2).
64
+
65
+ Returns
66
+ -------
67
+ C : Tensor, shape (N//2+1,), complex64 (default float32 fft_dtype) or complex128 (float64)
68
+ rfft of the zero-padded histogram held in fft_dtype. Empty tensor (degenerate
69
+ zero-range domain) signals the caller to short-circuit to 1 mode.
70
+ omega : Tensor, shape (N//2+1,), dtype fft_dtype
71
+ Angular frequency grid for the FFT.
72
+ domain : (lo, hi)
73
+ Domain tuple actually used.
74
+ """
75
+ with torch.no_grad():
76
+ if domain is not None:
77
+ lo, hi = domain
78
+ else:
79
+ sigma = X.std().item()
80
+ if sigma == 0.0:
81
+ sigma = 1.0
82
+ lo = X.min().item() - 3 * sigma
83
+ hi = X.max().item() + 3 * sigma
84
+ data_range = hi - lo
85
+
86
+ if data_range == 0.0:
87
+ complex_dtype = torch.complex64 if fft_dtype == torch.float32 else torch.complex128
88
+ empty = torch.zeros(0, dtype=complex_dtype, device=X.device)
89
+ empty_omega = torch.zeros(0, dtype=fft_dtype, device=X.device)
90
+ return empty, empty_omega, (lo, hi)
91
+
92
+ # Histogram (O(n)) — device-native dispatch.
93
+ counts = _histogram_on_device(X, G, lo, hi)
94
+
95
+ N = pad_factor * G
96
+ counts_padded = torch.zeros(N, dtype=fft_dtype, device=X.device)
97
+ counts_padded[:G] = counts.to(fft_dtype)
98
+
99
+ C = torch.fft.rfft(counts_padded)
100
+
101
+ bin_width = data_range / G
102
+ k = torch.arange(N // 2 + 1, device=X.device, dtype=fft_dtype)
103
+ omega = 2 * math.pi * k / (N * bin_width)
104
+
105
+ return C, omega, (lo, hi)
106
+
107
+
108
+ def mode_count_from_C(
109
+ C: Tensor,
110
+ omega: Tensor,
111
+ h: float,
112
+ G: int,
113
+ N: int,
114
+ ) -> int:
115
+ """Per-step mode count: apply Gaussian derivative kernel and count sign changes.
116
+
117
+ Cheap inner loop body for bisection — only the kernel depends on h.
118
+
119
+ Parameters
120
+ ----------
121
+ C : Tensor, shape (N//2+1,), complex
122
+ rfft of the zero-padded histogram (from `precompute_fft`).
123
+ omega : Tensor, shape (N//2+1,), real floating dtype as returned by `precompute_fft`
124
+ Frequency grid (from `precompute_fft`).
125
+ h : float
126
+ Bandwidth.
127
+ G : int
128
+ Histogram bin count.
129
+ N : int
130
+ Padded FFT length (pad_factor * G).
131
+
132
+ Returns
133
+ -------
134
+ int
135
+ Number of KDE modes.
136
+ """
137
+ if C.numel() == 0:
138
+ return 1 # degenerate single-point distribution
139
+
140
+ K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
141
+ f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).real
142
+ f_prime = f_prime_padded[:G]
143
+
144
+ nonzero_mask = f_prime != 0
145
+ if not nonzero_mask.any():
146
+ return 0
147
+
148
+ s = f_prime[nonzero_mask]
149
+ transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
150
+ return transitions
151
+
152
+
22
153
  def fft_mode_count(
23
154
  X: Tensor,
24
155
  h: float,
25
156
  G: int = 4096,
26
- pad_factor: int = 4,
157
+ pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
27
158
  domain: tuple[float, float] | None = None,
28
159
  ) -> int:
29
160
  """Count KDE modes via FFT convolution — O(n + G log G), no subsampling.
@@ -58,60 +189,9 @@ def fft_mode_count(
58
189
  Number of KDE modes (downward zero-crossings of f').
59
190
  """
60
191
  with torch.no_grad():
61
- if domain is not None:
62
- lo, hi = domain
63
- else:
64
- # Domain: extend 3σ beyond data range to avoid boundary effects
65
- sigma = X.std().item()
66
- if sigma == 0.0:
67
- sigma = 1.0 # degenerate case: all points identical
68
- lo = X.min().item() - 3 * sigma
69
- hi = X.max().item() + 3 * sigma
70
- data_range = hi - lo
71
-
72
- if data_range == 0.0:
73
- return 1 # single-point distribution has 1 mode
74
-
75
- # Histogram (O(n)) — MPS-safe via bucketize+bincount on CPU.
76
- # torch.histc on MPS allocates an n × bins float32 intermediate (PyTorch
77
- # MPS bug); at n=5M, bins=512 this is ~9.5 GiB → OOM. Moving to CPU for
78
- # the binning step avoids the intermediate and is numerically identical
79
- # for data within [lo, hi] (guaranteed by the 3σ domain extension above).
80
- X_cpu = X.float().cpu()
81
- edges = torch.linspace(lo, hi, G + 1) # (G+1,) CPU
82
- bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1 # 0-indexed
83
- counts = torch.bincount(bin_idx, minlength=G).float().to(X.device) # back to device
84
-
85
- # Zero-pad to pad_factor*G — promote to float64 for FFT precision
192
+ C, omega, _ = precompute_fft(X, G=G, domain=domain, pad_factor=pad_factor)
86
193
  N = pad_factor * G
87
- counts_padded = torch.zeros(N, dtype=torch.float64, device=X.device)
88
- counts_padded[:G] = counts.double()
89
-
90
- # FFT of histogram (float64)
91
- C = torch.fft.rfft(counts_padded)
92
-
93
- # Derivative kernel in frequency domain (float64)
94
- bin_width = data_range / G
95
- k = torch.arange(N // 2 + 1, device=X.device, dtype=torch.float64)
96
- omega = 2 * math.pi * k / (N * bin_width)
97
- K_deriv = 1j * omega * torch.exp(-0.5 * (omega * h) ** 2)
98
-
99
- # Convolve and back-transform; cast result back to float32
100
- f_prime_padded = torch.fft.irfft(C * K_deriv, n=N).float()
101
-
102
- # Trim to original G grid (discard zero-padded tail)
103
- f_prime = f_prime_padded[:G]
104
-
105
- # Count (+→-) sign changes = number of modes
106
- # A mode is a local max of f, i.e., f' crosses zero from + to -
107
- # Remove zeros (flat segments) — carry forward last nonzero sign
108
- nonzero_mask = f_prime != 0
109
- if not nonzero_mask.any():
110
- return 0
111
-
112
- s = f_prime[nonzero_mask]
113
- transitions = int(((s[:-1] > 0) & (s[1:] < 0)).sum().item())
114
- return transitions
194
+ return mode_count_from_C(C, omega, h, G, N)
115
195
 
116
196
 
117
197
  def _refine_hcrit(
@@ -121,7 +201,7 @@ def _refine_hcrit(
121
201
  G: int,
122
202
  domain: tuple[float, float],
123
203
  target_modes: int = 1,
124
- pad_factor: int = 4,
204
+ pad_factor: int = 2, # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
125
205
  ) -> float:
126
206
  """Sub-bin quadratic refinement of h_crit after bisection converges.
127
207
 
@@ -162,10 +242,7 @@ def _refine_hcrit(
162
242
 
163
243
  # Pre-compute histogram once; reuse C (FFT of counts) for all h evaluations.
164
244
  with torch.no_grad():
165
- X_cpu = X.float().cpu()
166
- edges = torch.linspace(lo_d, hi_d, G + 1)
167
- bin_idx = torch.bucketize(X_cpu, edges, right=True).clamp(1, G) - 1
168
- counts = torch.bincount(bin_idx, minlength=G).float()
245
+ counts = _histogram_on_device(X, G, lo_d, hi_d).cpu()
169
246
  counts_padded = torch.zeros(N, dtype=torch.float64)
170
247
  counts_padded[:G] = counts.double()
171
248
  C = torch.fft.rfft(counts_padded)
@@ -35,13 +35,14 @@ class DCBFunction(torch.autograd.Function):
35
35
 
36
36
  @staticmethod
37
37
  def forward(ctx, X, grid, eps, tau, target_modes, delta, formula, chunk_size,
38
- brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min):
38
+ brentq_n_max, g_brentq, use_hard_bisection, safe_backward, use_fft, fft_G_min,
39
+ fft_dtype):
39
40
  """Locate h_crit and save state for the backward pass."""
40
41
  h_crit, cond_num = find_h_crit(
41
42
  X, grid, eps, tau, target_modes,
42
43
  formula=formula, brentq_n_max=brentq_n_max, chunk_size=chunk_size,
43
44
  g_brentq=g_brentq, use_hard_bisection=use_hard_bisection,
44
- use_fft=use_fft, G_min=fft_G_min,
45
+ use_fft=use_fft, G_min=fft_G_min, fft_dtype=fft_dtype,
45
46
  )
46
47
  ctx.save_for_backward(X, grid)
47
48
  ctx.h_crit = h_crit
@@ -67,8 +68,8 @@ class DCBFunction(torch.autograd.Function):
67
68
  ctx.denom_abs = ift_gradient.last_denom_abs
68
69
  # Gradients for: X, grid, eps, tau, target_modes, delta, formula,
69
70
  # chunk_size, brentq_n_max, g_brentq, use_hard_bisection,
70
- # safe_backward, use_fft, fft_G_min
71
- return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None
71
+ # safe_backward, use_fft, fft_G_min, fft_dtype
72
+ return grad_X, None, None, None, None, None, None, None, None, None, None, None, None, None, None
72
73
 
73
74
 
74
75
  class DCBLayer(nn.Module):
@@ -170,6 +171,7 @@ class DCBLayer(nn.Module):
170
171
  max_n_exact: int | None = 1_000_000,
171
172
  sketch_size: int = 500_000,
172
173
  fft_G_min: int = 16384,
174
+ fft_dtype: torch.dtype = torch.float32,
173
175
  ):
174
176
  super().__init__()
175
177
  self.target_modes = target_modes
@@ -189,6 +191,7 @@ class DCBLayer(nn.Module):
189
191
  self.max_n_exact = max_n_exact
190
192
  self.sketch_size = sketch_size
191
193
  self.fft_G_min = fft_G_min
194
+ self.fft_dtype = fft_dtype
192
195
  if use_fft and brentq_n_max != 50_000:
193
196
  raise TypeError(
194
197
  f"brentq_n_max={brentq_n_max} is meaningless when use_fft=True: the FFT path "
@@ -259,7 +262,7 @@ class DCBLayer(nn.Module):
259
262
  return DCBFunction.apply(
260
263
  X, grid, eps_eff, tau_eff, self.target_modes, self.delta, self.formula,
261
264
  self.chunk_size, self.brentq_n_max, self.g_brentq, self.use_hard_bisection,
262
- self.safe_backward, self.use_fft, self.fft_G_min,
265
+ self.safe_backward, self.use_fft, self.fft_G_min, self.fft_dtype,
263
266
  )
264
267
 
265
268
 
@@ -37,7 +37,7 @@ from dcb.kde import (
37
37
  soft_mode_count_cross_from_derivs,
38
38
  kde_derivatives_chunked,
39
39
  )
40
- from dcb.fft_kde import fft_mode_count, adaptive_fft_G
40
+ from dcb.fft_kde import fft_mode_count, adaptive_fft_G, precompute_fft, mode_count_from_C
41
41
 
42
42
  _AUTO_FFT_THRESHOLD = 50_000 # n above which FFT bisection activates (use_fft_effective)
43
43
 
@@ -75,6 +75,7 @@ def find_h_crit_hard(
75
75
  tau: float = 0.2,
76
76
  use_fft: bool = False,
77
77
  G_min: int = 16384,
78
+ fft_dtype: torch.dtype = torch.float32,
78
79
  ) -> tuple[float, float]:
79
80
  """Find h_crit via hard-mode-count bisection (monotone, no false roots).
80
81
 
@@ -154,38 +155,54 @@ def find_h_crit_hard(
154
155
  data_range = hi_domain - lo_domain
155
156
  G_fft = adaptive_fft_G(data_range, h_hi, G_min=G_min)
156
157
  _domain = (lo_domain, hi_domain)
158
+ pad_factor = 2 # Worker 5: pad_factor=2 (was 4) — safe for h ≤ 3σ, halves irfft size
159
+ N = pad_factor * G_fft
157
160
 
158
161
  with torch.no_grad():
162
+ # Worker 1: precomputed C — hoist histogram + rfft out of bisection.
163
+ # Worker 3: float32 FFT by default — 2× faster; _refine_hcrit uses float64 independently.
164
+ C, omega, _domain = precompute_fft(
165
+ X, G=G_fft, domain=_domain, pad_factor=pad_factor, fft_dtype=fft_dtype,
166
+ )
167
+
159
168
  # Verify bracket using FFT mode count on full X
160
- count_lo = fft_mode_count(X, h_lo, G=G_fft, domain=_domain)
169
+ count_lo = mode_count_from_C(C, omega, h_lo, G_fft, N)
161
170
  if count_lo <= target_modes:
162
171
  h_lo_try = h_lo
163
172
  for _ in range(30):
164
173
  h_lo_try *= 0.5
165
174
  if h_lo_try < 1e-10:
166
175
  break
167
- if fft_mode_count(X, h_lo_try, G=G_fft, domain=_domain) > target_modes:
176
+ if mode_count_from_C(C, omega, h_lo_try, G_fft, N) > target_modes:
168
177
  h_lo = h_lo_try
169
178
  break
170
179
 
171
- count_hi = fft_mode_count(X, h_hi, G=G_fft, domain=_domain)
180
+ count_hi = mode_count_from_C(C, omega, h_hi, G_fft, N)
172
181
  if count_hi > target_modes:
173
182
  for _ in range(30):
174
183
  h_hi *= 2.0
175
- if fft_mode_count(X, h_hi, G=G_fft, domain=_domain) <= target_modes:
184
+ if mode_count_from_C(C, omega, h_hi, G_fft, N) <= target_modes:
176
185
  break
177
186
 
178
- # Standard bisection: 50 iterations bracket width / 2^50
187
+ # Adaptive bisection: stop when bracket is localised (relative width < 1e-7)
188
+ # _refine_hcrit provides sub-bin precision afterwards — no need to over-bisect.
179
189
  lo, hi = h_lo, h_hi
180
190
  for _ in range(50):
181
191
  mid = (lo + hi) / 2.0
182
- count = fft_mode_count(X, mid, G=G_fft, domain=_domain)
192
+ count = mode_count_from_C(C, omega, mid, G_fft, N)
183
193
  if count <= target_modes:
184
194
  hi = mid
185
195
  else:
186
196
  lo = mid
187
197
  if (hi - lo) < tol:
188
198
  break
199
+ # Worker 4: adaptive termination — stop when relative bracket width
200
+ # is small enough that further bisection cannot meaningfully shift
201
+ # _refine_hcrit's quadratic fit. Empirically 1e-7 preserves h_crit
202
+ # to within 1e-6 of the 50-step tol=1e-6 baseline while saving ~10
203
+ # bisection steps in typical cases.
204
+ if hi > 0 and (hi - lo) / hi < 1e-7:
205
+ break
189
206
 
190
207
  h_crit = float(hi) # smallest h with count <= target_modes
191
208
 
@@ -222,6 +239,9 @@ def find_h_crit_hard(
222
239
  break
223
240
 
224
241
  # Standard bisection: 50 iterations → bracket width / 2^50
242
+ # NOTE: non-FFT path has no _refine_hcrit sub-bin refinement, so we keep
243
+ # tight bisection here for gradient stability (IFT test requires h_crit
244
+ # accurate well below FD perturbation delta=1e-3).
225
245
  lo, hi = h_lo, h_hi
226
246
  for _ in range(50):
227
247
  mid = (lo + hi) / 2.0
@@ -297,6 +317,7 @@ def find_h_crit(
297
317
  use_hard_bisection: bool = True,
298
318
  use_fft: bool = True,
299
319
  G_min: int = 16384,
320
+ fft_dtype: torch.dtype = torch.float32,
300
321
  ) -> tuple[float, float]:
301
322
  """Find h_crit and return (h_crit, condition_number).
302
323
 
@@ -350,7 +371,7 @@ def find_h_crit(
350
371
  return find_h_crit_hard(
351
372
  X, grid, target_modes, chunk_size, brentq_n_max,
352
373
  h_lo, h_hi, formula=formula, eps=eps, tau=tau,
353
- use_fft=use_fft, G_min=G_min,
374
+ use_fft=use_fft, G_min=G_min, fft_dtype=fft_dtype,
354
375
  )
355
376
 
356
377
  from scipy.optimize import brentq
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "diffcb"
7
- version = "0.1.3"
7
+ version = "0.1.4"
8
8
  description = "Differentiable Critical Bandwidth: Silverman's modality test as a differentiable PyTorch layer with IFT backward pass."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes