foscat 2025.10.2__py3-none-any.whl → 2026.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/BkTorch.py CHANGED
@@ -62,158 +62,118 @@ class BkTorch(BackendBase.BackendBase):
62
62
  self.torch_device = (
63
63
  torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
64
64
  )
65
-
66
- # ---------------------------------
67
- # HEALPix binning utilities (nested)
68
- # ---------------------------------
69
- # Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
70
- # and batched cell_ids of shape [B, N]. It returns compact per-parent means
71
- # even when some parents are missing (sparse coverage).
72
-
73
- def binned_mean_old(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
74
- """Average values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
75
-
76
- Works with full-sky or sparse subsets (no need for N to be divisible by 4).
65
+
66
def downsample_mean_2x2(self, tim: torch.Tensor) -> torch.Tensor:
    """
    Average-pool ``tim`` over non-overlapping 2x2 spatial blocks.

    Parameters
    ----------
    tim : torch.Tensor
        Tensor of shape [a, N1, N2, b].

    Returns
    -------
    torch.Tensor
        Tensor of shape [a, N1//2, N2//2, b]; every entry is the mean of
        one 2x2 block. A trailing row/column is dropped when N1/N2 is odd.
    """
    n_lead, height, width, n_trail = tim.shape
    half_h, half_w = height // 2, width // 2

    # Crop to even extents so the spatial axes split exactly into pairs.
    cropped = tim[:, : 2 * half_h, : 2 * half_w, :]

    # Axes 2 and 4 of the reshaped view index the position inside a block.
    blocks = cropped.reshape(n_lead, half_h, 2, half_w, 2, n_trail)
    return blocks.mean(dim=(2, 4))
91
+
92
def downsample_median_2x2(self, tim: torch.Tensor) -> torch.Tensor:
    """
    Median-pool ``tim`` over non-overlapping 2x2 spatial blocks.

    Parameters
    ----------
    tim : torch.Tensor
        Tensor of shape [a, N1, N2, b], real or complex.

    Returns
    -------
    torch.Tensor
        Tensor of shape [a, N1//2, N2//2, b]. A trailing row/column is
        dropped when N1/N2 is odd.
        - Real input: ``torch.median`` over the 4 block values, i.e. the
          lower of the two middle values.
        - Complex input: the 4 values are ranked by magnitude and the
          complex sample at rank 1 (lower median) is returned.
    """
    n_lead, height, width, n_trail = tim.shape
    half_h = height // 2
    half_w = width // 2

    # Crop to even extents, then expose the two in-block axes.
    cropped = tim[:, : 2 * half_h, : 2 * half_w, :]
    blocks = cropped.reshape(n_lead, half_h, 2, half_w, 2, n_trail)

    # Move the two in-block axes last and merge them: [a, N1/2, N2/2, b, 4].
    quads = blocks.permute(0, 1, 3, 5, 2, 4).reshape(
        n_lead, half_h, half_w, n_trail, 4
    )

    if torch.is_complex(quads):
        # Rank by magnitude (ascending) and keep the sample at rank 1.
        order = torch.sort(quads.abs(), dim=-1).indices
        lower_median_idx = order[..., 1:2]
        return torch.gather(quads, dim=-1, index=lower_median_idx).squeeze(-1)

    return torch.median(quads, dim=-1).values
128
+
129
def downsample_mean_1d(self, tim: torch.Tensor) -> torch.Tensor:
    """
    Halve the last axis of ``tim`` by averaging non-overlapping pairs.

    Parameters
    ----------
    tim : torch.Tensor
        Tensor of shape [a, N1].

    Returns
    -------
    torch.Tensor
        Tensor of shape [a, N1//2]; the last sample is dropped when N1 is odd.
    """
    n_rows, length = tim.shape
    half = length // 2

    # Crop to an even length and view the samples as consecutive pairs.
    pairs = tim[:, : 2 * half].reshape(n_rows, half, 2)
    return pairs.mean(dim=-1)
212
144
 
145
def downsample_median_1d(self, tim: torch.Tensor) -> torch.Tensor:
    """
    Halve the last axis of ``tim`` by taking the median of non-overlapping pairs.

    Parameters
    ----------
    tim : torch.Tensor
        Tensor of shape [a, N1], real or complex.

    Returns
    -------
    torch.Tensor
        Tensor of shape [a, N1//2]; the last sample is dropped when N1 is odd.
        - Real input: the median of two values, i.e. their mean.
        - Complex input: the sample with the smaller magnitude of the pair
          (lower median by magnitude).
    """
    n_rows, length = tim.shape
    half = length // 2
    pairs = tim[:, : 2 * half].reshape(n_rows, half, 2)

    if not torch.is_complex(pairs):
        # The median of two values is their mean; it does not depend on the
        # order of the pair, so no sort is needed.
        return pairs.mean(dim=-1)

    # Only the smallest-magnitude sample is needed, so argmin replaces a sort.
    smallest = torch.argmin(pairs.abs(), dim=-1, keepdim=True)
    return torch.gather(pairs, dim=-1, index=smallest).squeeze(-1)
168
+ # ---------------------------------
169
+ # HEALPix binning utilities (nested)
170
+ # ---------------------------------
171
+ # Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
172
+ # and batched cell_ids of shape [B, N]. It returns compact per-parent means
173
+ # even when some parents are missing (sparse coverage).
215
174
 
216
- def binned_mean(
175
+
176
+ def binned_mean_old(
217
177
  self,
218
178
  data,
219
179
  cell_ids,
@@ -385,7 +345,268 @@ class BkTorch(BackendBase.BackendBase):
385
345
 
386
346
  else:
387
347
  raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
388
-
348
+
349
+
350
def binned_mean(
    self,
    data,
    cell_ids,
    *,
    reduce: str = "mean",  # "mean" | "max" | "median"
    padded: bool = False,
    fill_value: float = float("nan"),
):
    """
    Reduce values over parent HEALPix pixels (nested) when downgrading nside→nside/2.

    Children are grouped by parent id ``cell_ids // 4``. Only parents that
    actually occur are returned (sparse coverage is supported; N does not
    have to be a multiple of 4).

    Parameters
    ----------
    data : torch.Tensor | np.ndarray
        Shape [..., N] or [B, ..., N]. NumPy input is converted to float32.
    cell_ids : torch.LongTensor | np.ndarray
        Shape [N] or [B, N] (nested indexing at the child resolution).
    reduce : {"mean", "max", "median"}, default "mean"
        Aggregation within each parent group (up to 4 children).
        For "median" on real data the true median of the present children
        is returned (mean of the two middle values for an even count); on
        complex data the lower median by magnitude is returned.
    padded : bool, default False
        Only used when ``cell_ids`` is [B, N]. If False, ragged Python lists
        are returned. If True, padded tensors plus a validity mask.
    fill_value : float, default NaN
        Padding value when ``padded=True``.

    Returns
    -------
    If ``cell_ids`` is [N]:
        (reduced, parents) with ``reduced`` of shape [..., n_bins] and
        ``parents`` the sorted unique parent ids, shape [n_bins].
    If ``cell_ids`` is [B, N] and ``padded=False``:
        (outs_list, groups_list): per batch, a [T, n_bins_b] tensor and the
        [n_bins_b] parent ids, where T is the product of the extra leading
        dims of ``data``.
    If ``cell_ids`` is [B, N] and ``padded=True``:
        (out_pad, groups_pad, mask) with ``out_pad`` of shape
        [B, ..., max_bins], ``groups_pad`` [B, max_bins] padded with -1 and
        ``mask`` [B, max_bins] True where a bin is valid.

    Raises
    ------
    ValueError
        If ``reduce`` is unknown, ``data`` is 0-d, ``cell_ids`` is neither
        1-D nor 2-D, or the leading dim of ``data`` is not a multiple of B.
    """
    # Validate up front so an invalid `reduce` fails before any work is done.
    if reduce not in ("mean", "max", "median"):
        raise ValueError("reduce must be 'mean', 'max', or 'median'.")

    # ---- Tensorize & device/dtype plumbing ----
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data).to(
            dtype=torch.float32, device=getattr(self, "torch_device", "cpu")
        )
    if isinstance(cell_ids, np.ndarray):
        cell_ids = torch.from_numpy(cell_ids).to(
            dtype=torch.long, device=getattr(self, "torch_device", data.device)
        )
    data = data.to(device=getattr(self, "torch_device", data.device))
    cell_ids = cell_ids.to(device=data.device, dtype=torch.long)

    if data.ndim < 1:
        raise ValueError("`data` must have at least 1 dimension (last is N).")
    N = data.shape[-1]
    orig = data.shape[:-1]

    # ---- Utilities ----
    def _segment_max(vals_flat, idx_flat, out_size):
        """Compute out[g] = max(vals[idx == g]); vectorized when available."""
        out = torch.full(
            (out_size,), -float("inf"), dtype=vals_flat.dtype, device=vals_flat.device
        )
        if hasattr(torch.Tensor, "scatter_reduce_"):
            out.scatter_reduce_(0, idx_flat, vals_flat, reduce="amax", include_self=True)
            return out
        # Fallback for torch builds without scatter_reduce_: one pass per group.
        for u in torch.unique(idx_flat).tolist():
            out[u] = torch.max(vals_flat[idx_flat == u])
        return out

    def _median_from_four_real(v4):
        """
        v4: [..., 4] real, NaN marks a missing child.
        Returns the median of the k present values (NaN when k == 0).
        """
        missing = torch.isnan(v4)
        # Push missing slots to the end of the sort by keying them as +inf.
        sort_key = torch.where(missing, torch.full_like(v4, float("inf")), v4)
        v_sorted, _ = torch.sort(sort_key, dim=-1)

        k = torch.sum(~missing, dim=-1)  # number of present children

        # Middle ranks among the first k sorted values:
        #   k=1 -> (0, 0), k=2 -> (0, 1), k=3 -> (1, 1), k=4 -> (1, 2).
        # Both ranks stay below k, so the +inf placeholders are never read
        # unless k == 0, which is overwritten with NaN below.
        idx_lo = torch.clamp((k - 1) // 2, min=0).unsqueeze(-1)
        idx_hi = (k // 2).unsqueeze(-1)

        lo = torch.gather(v_sorted, -1, idx_lo).squeeze(-1)
        hi = torch.gather(v_sorted, -1, idx_hi).squeeze(-1)
        med = 0.5 * (lo + hi)
        return torch.where(k > 0, med, torch.full_like(med, float("nan")))

    def _median_from_four_complex(v4):
        """
        v4: [..., 4] complex, NaN marks a missing child.
        Returns the lower median by magnitude among the k finite values
        (NaN+NaNj when k == 0).
        """
        mags = v4.abs()
        mags_key = torch.where(
            torch.isnan(mags), torch.full_like(mags, float("inf")), mags
        )
        _, order = torch.sort(mags_key, dim=-1)
        k = torch.sum(torch.isfinite(mags), dim=-1)
        # Lower-median rank: k=1 -> 0, k=2 -> 0, k=3 -> 1, k=4 -> 1.
        rank = torch.clamp((k - 1) // 2, min=0).unsqueeze(-1)
        pick = torch.gather(order, -1, rank)
        med = torch.gather(v4, -1, pick).squeeze(-1)
        return torch.where(
            k > 0, med, torch.full_like(med, complex(float("nan"), float("nan")))
        )

    def _reduce_rows(rows, ids):
        """Reduce rows [T, N] over the parents of ids [N] -> ([T, n_bins], parents)."""
        parents, inv = torch.unique(ids // 4, sorted=True, return_inverse=True)
        n_rows = rows.shape[0]
        n_bins = parents.numel()
        row_ids = torch.arange(n_rows, device=rows.device).unsqueeze(1)

        if reduce == "median":
            # One slot per (row, parent, child offset); NaN marks absent children.
            slots = torch.full(
                (n_rows * n_bins * 4,),
                float("nan"),
                dtype=rows.dtype,
                device=rows.device,
            )
            flat_index = row_ids * (n_bins * 4) + (inv * 4 + ids % 4).unsqueeze(0)
            slots.scatter_(0, flat_index.reshape(-1), rows.reshape(-1))
            slots = slots.view(n_rows, n_bins, 4)
            if torch.is_complex(rows):
                return _median_from_four_complex(slots), parents
            return _median_from_four_real(slots), parents

        # "mean" / "max": offset the bin index per row so rows stay independent.
        idx_flat = (inv.unsqueeze(0) + row_ids * n_bins).reshape(-1)
        vals_flat = rows.reshape(-1)
        out_size = n_rows * n_bins
        if reduce == "mean":
            out_sum = torch.zeros(out_size, dtype=rows.dtype, device=rows.device)
            out_cnt = torch.zeros_like(out_sum)
            out_sum.scatter_add_(0, idx_flat, vals_flat)
            out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
            out_cnt.clamp_(min=1)
            reduced = out_sum / out_cnt
        else:  # "max"
            reduced = _segment_max(vals_flat, idx_flat, out_size)
        return reduced.view(n_rows, n_bins), parents

    R = int(np.prod(orig)) if len(orig) > 0 else 1
    data_flat = data.reshape(R, N)  # [R, N]

    # ---- Branch: cell_ids shape [N] (mapping shared by all rows) ----
    if cell_ids.ndim == 1:
        reduced, parents = _reduce_rows(data_flat, cell_ids)
        return reduced.reshape(*orig, parents.numel()), parents

    # ---- Branch: cell_ids shape [B, N] (per-batch mapping) ----
    if cell_ids.ndim == 2:
        B = cell_ids.shape[0]
        B_data = data.shape[0] if len(orig) > 0 else 1
        if B_data % B != 0:
            raise ValueError(
                f"Leading dim of data ({B_data}) must be a multiple of cell_ids batch ({B})."
            )
        T = (R // B_data) if B_data > 0 else 1  # rows of data_flat per batch entry

        outs_list, groups_list = [], []
        max_bins = 0
        for b in range(B):
            reduced_bt, parents_b = _reduce_rows(
                data_flat[b * T:(b + 1) * T], cell_ids[b]
            )
            max_bins = max(max_bins, parents_b.numel())
            outs_list.append(reduced_bt)
            groups_list.append(parents_b)

        if not padded:
            return outs_list, groups_list

        # Padded output: [B, T, max_bins] values, [B, max_bins] parents and mask.
        out_pad = torch.full(
            (B, T, max_bins), fill_value, dtype=data.dtype, device=data.device
        )
        groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
        mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
        for b, (o_b, g_b) in enumerate(zip(outs_list, groups_list)):
            nb = g_b.numel()
            out_pad[b, :, :nb] = o_b
            groups_pad[b, :nb] = g_b
            mask[b, :nb] = True

        if len(orig) > 1:
            out_pad = out_pad.view(B, *orig[1:], max_bins)
        else:
            out_pad = out_pad.view(B, max_bins)
        return out_pad, groups_pad, mask

    raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
609
+
389
610
  def average_by_cell_group(data, cell_ids):
390
611
  """
391
612
  data: tensor of shape [..., N, ...] (ex: [B, N, C])
@@ -404,6 +625,271 @@ class BkTorch(BackendBase.BackendBase):
404
625
 
405
626
  return torch.bincount(group_indices, weights=data) / counts, unique_groups
406
627
 
628
def bk_masked_median(self, x: torch.Tensor, mask: torch.Tensor,
                     max_iter: int = 100, tol: float = 1e-6, eps: float = 1e-12):
    """
    Masked geometric median over the last axis using Weiszfeld iteration (1D case).

    Parameters
    ----------
    x : torch.Tensor
        Shape [a, b, c, N]. Can be real or complex.
    mask : torch.Tensor
        Binary mask of shape [a, b, 1, N]; broadcast across 'c'.
    max_iter : int
        Max number of Weiszfeld iterations.
    tol : float
        Convergence tolerance on the max absolute update over all voxels.
    eps : float
        Small value to avoid division-by-zero in the weights.

    Returns
    -------
    med : torch.Tensor, shape [a, b, c]
        Geometric median of x along the last axis where mask == 1.
        For complex x, distances use the complex magnitude |x - y| and the
        returned median is complex.
    med2 : torch.Tensor, shape [a, b, c]
        Geometric median of squared values along the last axis where
        mask == 1: x**2 for real x, |x|**2 (real) for complex x.

    Notes
    -----
    - Voxels with zero valid samples return NaN (NaN+NaNj for complex med).
    - Samples where mask == 0 are ignored entirely, even if they hold NaN
      or inf.
    - Weiszfeld update: y_{k+1} = sum_i w_i x_i / sum_i w_i with
      w_i = 1 / max(|x_i - y_k|, eps), restricted to valid samples.
    """
    # Broadcast mask to x's shape [a, b, c, N]; weights must be real-valued.
    mask_bool = mask.to(torch.bool).expand_as(x)
    m_float = mask_bool.to(dtype=x.real.dtype)

    # Zero the masked-out samples: multiplying by a 0 weight is not enough,
    # since 0 * NaN is NaN and would poison the sums.
    x = torch.where(mask_bool, x, torch.zeros_like(x))

    valid_counts = mask_bool.sum(dim=-1)  # [a, b, c]
    zero_valid = valid_counts == 0
    # clamp_min(1) avoids 0/0 on empty voxels; those are set to NaN anyway.
    denom = valid_counts.clamp_min(1).to(dtype=x.real.dtype)

    def _nan_like(t: torch.Tensor) -> torch.Tensor:
        """Return a NaN tensor with the same shape/dtype/device as t."""
        if torch.is_complex(t):
            return torch.full_like(t, complex(float("nan"), float("nan")))
        return torch.full_like(t, float("nan"))

    def _safe_nanmax(t: torch.Tensor) -> torch.Tensor:
        """Max of a real tensor ignoring NaN entries (NaN if all are NaN)."""
        nan_mask = torch.isnan(t)
        if nan_mask.all():
            return torch.tensor(float("nan"), device=t.device, dtype=t.dtype)
        if nan_mask.any():
            return torch.max(t[~nan_mask])
        return torch.max(t)

    def _weiszfeld(samples: torch.Tensor) -> torch.Tensor:
        """Masked geometric median of samples [a, b, c, N] over the last axis."""
        # Start from the masked mean.
        est = (m_float * samples).sum(dim=-1) / denom
        est = torch.where(zero_valid, _nan_like(est), est)
        if torch.all(zero_valid):
            return est

        for _ in range(max_iter):
            dist = (samples - est.unsqueeze(-1)).abs()  # real, [a, b, c, N]
            w = m_float * (1.0 / torch.clamp(dist, min=eps))
            new = (w * samples).sum(dim=-1) / w.sum(dim=-1).clamp_min(eps)
            new = torch.where(zero_valid, _nan_like(new), new)

            converged = _safe_nanmax((new - est).abs()).item() <= tol
            est = new
            if converged:
                break
        return est

    med = _weiszfeld(x)
    squared = (x.abs() ** 2) if torch.is_complex(x) else (x ** 2)  # real
    med2 = _weiszfeld(squared)
    return med, med2
750
+
751
def bk_masked_median_2d_weiszfeld(self,
                                  x: torch.Tensor,
                                  mask: torch.Tensor,
                                  max_iter: int = 100,
                                  tol: float = 1e-6,
                                  eps: float = 1e-12):
    """
    Masked geometric median over 2D spatial axes using Weiszfeld iteration.

    Parameters
    ----------
    x : torch.Tensor
        Input of shape [a, b, c, N1, N2]. Can be real or complex.
    mask : torch.Tensor
        Binary mask of shape [a, b, 1, N1, N2]; broadcast over 'c'.
    max_iter : int
        Maximum number of Weiszfeld iterations.
    tol : float
        Stopping tolerance on the max absolute update over all voxels.
    eps : float
        Small positive value to avoid division by zero in weights.

    Returns
    -------
    med : torch.Tensor, shape [a, b, c]
        Geometric median of x over (N1, N2) where mask == 1.
        If x is complex, distances are magnitudes |x - y| in the complex
        plane and the returned value is complex (not its magnitude).
    med2 : torch.Tensor, shape [a, b, c]
        Geometric median of squared values over (N1, N2) where mask == 1:
        x**2 for real x, |x|**2 for complex x. Always real-valued.

    Notes
    -----
    - Voxels with zero valid samples return NaN (NaN+NaNj for complex med).
    - Samples where mask == 0 are ignored entirely, even if they hold NaN
      or inf.
    - Weiszfeld update: y_{k+1} = sum_i w_i x_i / sum_i w_i with
      w_i = 1 / max(|x_i - y_k|, eps), restricted to valid samples.
    """
    # Broadcast mask to x's shape, then flatten (N1, N2) into one sample axis.
    mask_bool = mask.to(torch.bool).expand_as(x)
    a, b, c, N1, N2 = x.shape
    n_samples = N1 * N2
    x_flat = x.reshape(a, b, c, n_samples)
    m_bool = mask_bool.reshape(a, b, c, n_samples)

    # Weights and counts are kept real so med2 stays real for complex x.
    real_dtype = x.real.dtype
    m_flat = m_bool.to(dtype=real_dtype)

    # Zero the masked-out samples: multiplying by a 0 weight is not enough,
    # since 0 * NaN is NaN and would poison the sums.
    x_flat = torch.where(m_bool, x_flat, torch.zeros_like(x_flat))

    valid_counts = m_bool.sum(dim=-1)  # [a, b, c]
    zero_valid = valid_counts == 0
    # clamp_min(1) avoids 0/0 on empty voxels; those are set to NaN anyway.
    denom = valid_counts.clamp_min(1).to(dtype=real_dtype)

    def _nan_like(t):
        """Return a NaN tensor with the same shape/dtype/device as t."""
        if torch.is_complex(t):
            return torch.full_like(t, complex(float("nan"), float("nan")))
        return torch.full_like(t, float("nan"))

    def _safe_nanmax(t):
        """Max of a real tensor ignoring NaN entries (NaN if all are NaN)."""
        nan_mask = torch.isnan(t)
        if nan_mask.all():
            return torch.tensor(float("nan"), device=t.device, dtype=t.dtype)
        if nan_mask.any():
            return torch.max(t[~nan_mask])
        return torch.max(t)

    def _weiszfeld(samples):
        """Masked geometric median of samples [a, b, c, N] over the last axis."""
        # Start from the masked mean.
        est = (m_flat * samples).sum(dim=-1) / denom
        est = torch.where(zero_valid, _nan_like(est), est)
        if torch.all(zero_valid):
            return est

        for _ in range(max_iter):
            dist = (samples - est.unsqueeze(-1)).abs()  # real, [a, b, c, N]
            w = m_flat * (1.0 / torch.clamp(dist, min=eps))
            # For complex samples w is real, so (w * samples) stays complex.
            new = (w * samples).sum(dim=-1) / w.sum(dim=-1).clamp_min(eps)
            new = torch.where(zero_valid, _nan_like(new), new)

            converged = _safe_nanmax((new - est).abs()).item() <= tol
            est = new
            if converged:
                break
        return est

    med = _weiszfeld(x_flat)
    if torch.is_complex(x):
        s_flat = x_flat.abs() ** 2  # real
    else:
        s_flat = x_flat ** 2
    med2 = _weiszfeld(s_flat)
    return med, med2
892
+
407
893
  # ---------------------------------------------−---------
408
894
  # -- BACKEND DEFINITION --
409
895
  # ---------------------------------------------−---------
@@ -578,6 +1064,14 @@ class BkTorch(BackendBase.BackendBase):
578
1064
  return self.backend.mean(data)
579
1065
  else:
580
1066
  return self.backend.mean(data, axis)
1067
+
1068
def bk_reduce_median(self, data, axis=None):
    """
    Median of ``data``, over all elements or along ``axis``.

    Parameters
    ----------
    data : tensor
        Input passed to ``self.backend.median``.
    axis : int, optional
        Axis to reduce. When None, the median over all elements is taken.

    Returns
    -------
    tensor
        The median values only; any indices returned by the backend are
        dropped.
    """
    if axis is None:
        res = self.backend.median(data)
    else:
        res = self.backend.median(data, axis)
    # torch.median returns a (values, indices) tuple only when a dim is
    # given; the full reduction returns a bare tensor, so unpacking it
    # unconditionally would fail.
    if isinstance(res, tuple):
        res = res[0]
    return res
581
1075
 
582
1076
  def bk_reduce_min(self, data, axis=None):
583
1077