foscat 2025.9.5__py3-none-any.whl → 2025.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +635 -141
- foscat/FoCUS.py +92 -50
- foscat/Plot.py +4 -3
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/scat.py +1 -1
- foscat/scat1D.py +1 -1
- foscat/scat_cov.py +2 -2
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/METADATA +1 -1
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/RECORD +16 -12
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/WHEEL +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/top_level.txt +0 -0
foscat/BkTorch.py
CHANGED
|
@@ -62,158 +62,118 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
62
62
|
self.torch_device = (
|
|
63
63
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
64
64
|
)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
# Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
|
|
70
|
-
# and batched cell_ids of shape [B, N]. It returns compact per-parent means
|
|
71
|
-
# even when some parents are missing (sparse coverage).
|
|
72
|
-
|
|
73
|
-
def binned_mean_old(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
|
|
74
|
-
"""Average values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
|
|
75
|
-
|
|
76
|
-
Works with full-sky or sparse subsets (no need for N to be divisible by 4).
|
|
65
|
+
|
|
66
|
+
def downsample_mean_2x2(self, tim: torch.Tensor) -> torch.Tensor:
    """
    Average-pool ``tim`` over non-overlapping 2x2 spatial blocks.

    Parameters
    ----------
    tim : torch.Tensor
        Tensor of shape [a, N1, N2, b]; axes 1 and 2 are pooled.

    Returns
    -------
    torch.Tensor
        Shape [a, N1 // 2, N2 // 2, b]; each element is the mean of one 2x2
        block. A trailing odd row/column is discarded.
    """
    n_a, rows, cols, n_b = tim.shape
    half_r, half_c = rows // 2, cols // 2

    # Crop to even spatial sizes, then expose each 2x2 block as axes 2 and 4.
    cropped = tim[:, : 2 * half_r, : 2 * half_c, :]
    blocks = cropped.reshape(n_a, half_r, 2, half_c, 2, n_b)
    return blocks.mean(dim=(2, 4))
|
|
91
|
+
|
|
92
|
+
def downsample_median_2x2(self, tim: torch.Tensor) -> torch.Tensor:
    """
    2x2 block median downsampling on the spatial axes (N1, N2).

    Parameters
    ----------
    tim : torch.Tensor
        Shape [a, N1, N2, b], real or complex.

    Returns
    -------
    torch.Tensor
        Shape [a, N1 // 2, N2 // 2, b]. A trailing odd row/column is ignored.
        - Real input: ``torch.median`` over the 4 block values, i.e. the lower
          of the two middle values.
        - Complex input: the 4 samples are ranked by magnitude and the complex
          sample at rank 1 (lower median) is returned.
    """
    n_a, rows, cols, n_b = tim.shape
    half_r, half_c = rows // 2, cols // 2

    # Drop the last row/column when the spatial sizes are odd.
    cropped = tim[:, : 2 * half_r, : 2 * half_c, :]

    # [a, r/2, 2, c/2, 2, b] -> [a, r/2, c/2, b, 2, 2] -> [a, r/2, c/2, b, 4]
    quads = (
        cropped.reshape(n_a, half_r, 2, half_c, 2, n_b)
        .permute(0, 1, 3, 5, 2, 4)
        .reshape(n_a, half_r, half_c, n_b, 4)
    )

    if torch.is_complex(quads):
        # Order the 4 samples by |.| and keep the one at rank 1.
        order = torch.argsort(quads.abs(), dim=-1)
        return torch.gather(quads, dim=-1, index=order[..., 1:2]).squeeze(-1)

    # Real: median along the block axis.
    return torch.median(quads, dim=-1).values
|
|
128
|
+
|
|
129
|
+
def downsample_mean_1d(self, tim: torch.Tensor) -> torch.Tensor:
    """
    Halve the last axis of a 2-D tensor by averaging adjacent pairs.

    Parameters
    ----------
    tim : torch.Tensor
        Shape [a, N1].

    Returns
    -------
    torch.Tensor
        Shape [a, N1 // 2]; element j is the mean of ``tim[:, 2j]`` and
        ``tim[:, 2j + 1]``. A trailing element is discarded when N1 is odd.
    """
    n_rows, length = tim.shape
    n_pairs = length // 2
    pairs = tim[:, : 2 * n_pairs].reshape(n_rows, n_pairs, 2)
    return pairs.mean(dim=-1)
|
|
212
144
|
|
|
145
|
+
def downsample_median_1d(self, tim: torch.Tensor) -> torch.Tensor:
    """
    Halve the last axis of a 2-D tensor by taking the median of adjacent pairs.

    Parameters
    ----------
    tim : torch.Tensor
        Shape [a, N1], real or complex. A trailing element is discarded when
        N1 is odd.

    Returns
    -------
    torch.Tensor
        Shape [a, N1 // 2].
        - Real input: the median of two samples, i.e. their mean.
        - Complex input: the sample of the pair with the smaller magnitude.
    """
    n_rows, length = tim.shape
    n_pairs = length // 2
    pairs = tim[:, : 2 * n_pairs].reshape(n_rows, n_pairs, 2)

    if torch.is_complex(pairs):
        # Rank the pair by |.| and keep the rank-0 sample (lower median).
        order = torch.argsort(pairs.abs(), dim=-1)
        return torch.gather(pairs, dim=-1, index=order[..., :1]).squeeze(-1)

    # For two samples the median equals the mean, whatever their order.
    return pairs.mean(dim=-1)
|
|
168
|
+
# ---------------------------------
|
|
169
|
+
# HEALPix binning utilities (nested)
|
|
170
|
+
# ---------------------------------
|
|
171
|
+
# Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
|
|
172
|
+
# and batched cell_ids of shape [B, N]. It returns compact per-parent means
|
|
173
|
+
# even when some parents are missing (sparse coverage).
|
|
215
174
|
|
|
216
|
-
|
|
175
|
+
|
|
176
|
+
def binned_mean_old(
|
|
217
177
|
self,
|
|
218
178
|
data,
|
|
219
179
|
cell_ids,
|
|
@@ -385,7 +345,268 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
385
345
|
|
|
386
346
|
else:
|
|
387
347
|
raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
|
|
388
|
-
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def binned_mean(  # name kept for backward compatibility
    self,
    data,
    cell_ids,
    *,
    reduce: str = "mean",  # "mean" | "max" | "median"
    padded: bool = False,
    fill_value: float = float("nan"),
):
    """
    Reduce values over parent HEALPix pixels (nested) when downgrading nside -> nside/2.

    Child pixel ``c`` belongs to parent ``c // 4`` and occupies slot ``c % 4``
    inside it. Only parents with at least one child in ``cell_ids`` appear in
    the output, so sparse subsets (``N`` not divisible by 4) are supported.

    Parameters
    ----------
    data : torch.Tensor | np.ndarray
        Shape [..., N] (with 1-D ``cell_ids``) or [B_data, ..., N] (with 2-D
        ``cell_ids``, where ``B_data`` must be a multiple of ``B``). NumPy input
        is converted to float32, or complex64 for complex arrays.
    cell_ids : torch.LongTensor | np.ndarray
        Shape [N] or [B, N] (nested indexing at the child resolution).
    reduce : {"mean", "max", "median"}, default "mean"
        Aggregation within each parent group of (up to) 4 children.
        - "median" ignores NaN children. For real data it averages the two
          middle values (or returns the middle one for an odd count). For
          complex data it returns the lower median by magnitude.
        - "max" is not defined for complex data.
    padded : bool, default False
        Only used when ``cell_ids`` is [B, N]. If False, returns ragged Python
        lists. If True, returns padded tensors plus a validity mask.
    fill_value : float, default NaN
        Padding value when ``padded=True``.

    Returns
    -------
    If ``cell_ids`` is [N]:
        reduced : torch.Tensor, shape [..., n_bins]
        parents : torch.LongTensor, shape [n_bins] (sorted parent ids)
    If ``cell_ids`` is [B, N] and ``padded=False``:
        outs_list : list of B tensors, each [T, n_bins_b], where T is the number
            of flattened data rows sharing ``cell_ids[b]``
        groups_list : list of B LongTensors, each [n_bins_b]
    If ``cell_ids`` is [B, N] and ``padded=True``:
        out_pad : torch.Tensor, shape [B_data, ..., max_bins]
        groups_pad : torch.LongTensor, shape [B, max_bins] (padded with -1)
        mask : torch.BoolTensor, shape [B, max_bins] (True where valid)

    Raises
    ------
    ValueError
        Raised for any of the following:
        - an unknown ``reduce``;
        - a bad ``data`` or ``cell_ids`` rank;
        - mismatched N;
        - incompatible batch sizes;
        - ``reduce="max"`` on complex data.
    """
    if reduce not in ("mean", "max", "median"):
        raise ValueError("reduce must be 'mean', 'max', or 'median'.")

    # ---- Tensorize & device/dtype plumbing ----
    if isinstance(data, np.ndarray):
        # Complex arrays keep their imaginary part (a float32 cast would drop it).
        np_dtype = torch.complex64 if np.iscomplexobj(data) else torch.float32
        data = torch.from_numpy(data).to(
            dtype=np_dtype, device=getattr(self, "torch_device", "cpu")
        )
    if isinstance(cell_ids, np.ndarray):
        cell_ids = torch.from_numpy(cell_ids).to(
            dtype=torch.long, device=getattr(self, "torch_device", data.device)
        )
    data = data.to(device=getattr(self, "torch_device", data.device))
    cell_ids = cell_ids.to(device=data.device, dtype=torch.long)

    if data.ndim < 1:
        raise ValueError("`data` must have at least 1 dimension (last is N).")
    if cell_ids.ndim not in (1, 2):
        raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
    N = data.shape[-1]
    if cell_ids.shape[-1] != N:
        # A shorter index would make scatter_add_ silently use only part of data.
        raise ValueError(
            f"Last dim of cell_ids ({cell_ids.shape[-1]}) must match last dim of data ({N})."
        )
    if reduce == "max" and torch.is_complex(data):
        raise ValueError("reduce='max' is not defined for complex data.")

    orig = data.shape[:-1]
    R = int(np.prod(orig)) if len(orig) > 0 else 1
    data_flat = data.reshape(R, N)  # [R, N]

    # ---- Utilities ----
    def _segment_mean(vals_flat, idx_flat, out_size):
        """out[g] = mean(vals[idx == g]); counts use a real dtype so complex data works."""
        out_sum = torch.zeros(out_size, dtype=vals_flat.dtype, device=vals_flat.device)
        out_sum.scatter_add_(0, idx_flat, vals_flat)
        cnt_dtype = vals_flat.real.dtype if torch.is_complex(vals_flat) else vals_flat.dtype
        out_cnt = torch.zeros(out_size, dtype=cnt_dtype, device=vals_flat.device)
        out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat, dtype=cnt_dtype))
        out_cnt.clamp_(min=1)  # clamp is undefined for complex, hence the real counts
        return out_sum / out_cnt

    def _segment_max(vals_flat, idx_flat, out_size):
        """out[g] = max(vals[idx == g]); vectorized when scatter_reduce_ exists."""
        out = torch.full((out_size,), -float("inf"),
                         dtype=vals_flat.dtype, device=vals_flat.device)
        if hasattr(torch.Tensor, "scatter_reduce_"):
            out.scatter_reduce_(0, idx_flat, vals_flat, reduce="amax", include_self=True)
            return out
        for u in torch.unique(idx_flat).tolist():
            out[u] = torch.max(vals_flat[idx_flat == u])
        return out

    def _median_from_four_real(v4):
        """Median of the non-NaN entries of v4 [..., 4]; NaN where all are NaN."""
        is_nan = torch.isnan(v4)
        # NaN -> +inf so missing slots sort last.
        v_sorted, _ = torch.sort(
            torch.where(is_nan, torch.full_like(v4, float("inf")), v4), dim=-1
        )
        k = torch.sum(~is_nan, dim=-1)  # number of present children
        # Middle indices among the first k sorted values:
        # k=1 -> (0,0), k=2 -> (0,1), k=3 -> (1,1), k=4 -> (1,2).
        idx_lo = torch.clamp((k - 1) // 2, min=0).unsqueeze(-1)
        idx_hi = torch.clamp(k // 2, max=3).unsqueeze(-1)
        lo = torch.gather(v_sorted, -1, idx_lo).squeeze(-1)
        hi = torch.gather(v_sorted, -1, idx_hi).squeeze(-1)
        med = 0.5 * (lo + hi)
        return torch.where(k > 0, med, torch.full_like(med, float("nan")))

    def _median_from_four_complex(v4):
        """Lower median by magnitude of the finite entries of v4 [..., 4]."""
        mags = v4.abs()
        mags_key = torch.where(torch.isnan(mags), torch.full_like(mags, float("inf")), mags)
        _, order = torch.sort(mags_key, dim=-1)
        k = torch.sum(torch.isfinite(mags), dim=-1)
        # Lower-median rank: k=1,2 -> 0 ; k=3,4 -> 1.
        rank = torch.clamp((k - 1) // 2, min=0).unsqueeze(-1)
        pick = torch.gather(order, -1, rank)
        med = torch.gather(v4, -1, pick).squeeze(-1)
        return torch.where(
            k > 0, med, torch.full_like(med, complex(float("nan"), float("nan")))
        )

    def _reduce_rows(vals, ids):
        """Reduce rows ``vals`` [T, N] over the parents of ``ids`` [N] -> ([T, n_bins], parents)."""
        parents, inv = torch.unique(ids // 4, sorted=True, return_inverse=True)
        n_bins = parents.numel()
        T = vals.shape[0]

        if reduce == "median":
            # Scatter every child into slot (id % 4) of a NaN-filled
            # [T, n_bins, 4] array; slots of absent children stay NaN.
            out4 = torch.full((T * n_bins * 4,), torch.nan,
                              dtype=vals.dtype, device=vals.device)
            base = torch.arange(T, device=vals.device).unsqueeze(1) * (n_bins * 4)
            flat_index = base + (inv * 4 + ids % 4).unsqueeze(0)  # [T, N]
            out4.scatter_(0, flat_index.reshape(-1), vals.reshape(-1))
            out4 = out4.view(T, n_bins, 4)
            if torch.is_complex(vals):
                return _median_from_four_complex(out4), parents
            return _median_from_four_real(out4), parents

        row_offsets = torch.arange(T, device=vals.device).unsqueeze(1) * n_bins
        idx_flat = (inv.unsqueeze(0) + row_offsets).reshape(-1)  # [T * N]
        vals_flat = vals.reshape(-1)
        out_size = T * n_bins
        if reduce == "mean":
            reduced = _segment_mean(vals_flat, idx_flat, out_size)
        else:  # "max"
            reduced = _segment_max(vals_flat, idx_flat, out_size)
        return reduced.view(T, n_bins), parents

    # ---- Branch: cell_ids shape [N] (one mapping shared by every row) ----
    if cell_ids.ndim == 1:
        reduced, parents = _reduce_rows(data_flat, cell_ids)
        return reduced.view(*orig, parents.numel()), parents

    # ---- Branch: cell_ids shape [B, N] (one mapping per batch row) ----
    B = cell_ids.shape[0]
    B_data = data.shape[0] if len(orig) > 0 else 1
    if B == 0 or B_data % B != 0:
        raise ValueError(
            f"Leading dim of data ({B_data}) must be a multiple of cell_ids batch ({B})."
        )
    # Flattened rows handled by each cell_ids row. Using R // B (not R // B_data)
    # covers every row when B_data = k * B: leading slices b*k .. b*k+k-1 share
    # cell_ids[b]. When B_data == B both expressions are equal.
    T = R // B

    outs_list, groups_list = [], []
    for b in range(B):
        reduced_bt, parents_b = _reduce_rows(data_flat[b * T:(b + 1) * T], cell_ids[b])
        outs_list.append(reduced_bt)
        groups_list.append(parents_b)

    if not padded:
        return outs_list, groups_list

    # Padded output [B, T, max_bins] plus parent ids and a validity mask.
    max_bins = max(g.numel() for g in groups_list)
    out_pad = torch.full((B, T, max_bins), fill_value,
                         dtype=outs_list[0].dtype, device=data.device)
    groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
    mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
    for b, (o_b, g_b) in enumerate(zip(outs_list, groups_list)):
        nb = g_b.numel()
        out_pad[b, :, :nb] = o_b
        groups_pad[b, :nb] = g_b
        mask[b, :nb] = True

    # [B, T, max_bins] -> [B_data, ..., max_bins]. When B_data == B this is
    # exactly the former [B, ..., max_bins] (or [B, max_bins]) layout.
    out_pad = out_pad.reshape(B_data, *orig[1:], max_bins)
    return out_pad, groups_pad, mask
|
|
609
|
+
|
|
389
610
|
def average_by_cell_group(data, cell_ids):
|
|
390
611
|
"""
|
|
391
612
|
data: tensor of shape [..., N, ...] (ex: [B, N, C])
|
|
@@ -404,6 +625,271 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
404
625
|
|
|
405
626
|
return torch.bincount(group_indices, weights=data) / counts, unique_groups
|
|
406
627
|
|
|
628
|
+
def bk_masked_median(self, x: torch.Tensor, mask: torch.Tensor,
                     max_iter: int = 100, tol: float = 1e-6, eps: float = 1e-12):
    """
    Masked geometric median over the last axis using Weiszfeld iteration (1D case).

    Parameters
    ----------
    x : torch.Tensor
        Shape [a, b, c, N]. Can be real or complex.
    mask : torch.Tensor
        Binary mask broadcastable to x (typically [a, b, 1, N]). Samples where
        mask == 0 are ignored entirely, even if they hold NaN/inf.
    max_iter : int
        Max number of Weiszfeld iterations.
    tol : float
        Convergence tolerance on the max absolute update over all voxels.
    eps : float
        Small value to avoid division by zero in the weights.

    Returns
    -------
    med : torch.Tensor, shape [a, b, c]
        Geometric median of x along the last axis where mask == 1.
        For complex x, distances use |x - y| and the result is complex.
    med2 : torch.Tensor, shape [a, b, c]
        Geometric median of the squared values where mask == 1:
        x**2 for real x, |x|**2 (real) for complex x.
    Voxels with no valid sample return NaN (NaN+NaNj for complex med).
    """

    def _nan_like(y):
        """NaN tensor with the same shape/dtype/device as y."""
        if torch.is_complex(y):
            return torch.full_like(y, complex(float("nan"), float("nan")))
        return torch.full_like(y, float("nan"))

    def _safe_nanmax(t):
        """Max over the non-NaN entries of a real tensor (NaN if there are none)."""
        kept = t[~torch.isnan(t)]
        if kept.numel() == 0:
            return torch.tensor(float("nan"), device=t.device, dtype=t.dtype)
        return torch.max(kept)

    def _weiszfeld(vals, weights, y, zero_valid):
        """Iterate y <- sum(w_i v_i) / sum(w_i), with w_i = weights_i / max(|v_i - y|, eps)."""
        for _ in range(max_iter):
            if torch.all(zero_valid):
                break
            dist = (vals - y.unsqueeze(-1)).abs()  # real, [a,b,c,N]
            w = weights * (1.0 / torch.clamp(dist, min=eps))  # real weights
            y_new = (w * vals).sum(dim=-1) / w.sum(dim=-1).clamp_min(eps)
            y_new = torch.where(zero_valid, _nan_like(y_new), y_new)
            converged = _safe_nanmax((y_new - y).abs()).item() <= tol
            y = y_new
            if converged:
                break
        return y

    # --- prep shapes & mask ---
    mask_bool = mask.to(torch.bool).expand_as(x)  # [a,b,c,N]
    # Zero masked-out samples so NaN/inf there cannot leak in through 0 * NaN.
    x = torch.where(mask_bool, x, torch.zeros_like(x))
    m_float = mask_bool.to(dtype=x.real.dtype)  # weights need a real dtype

    valid_counts = mask_bool.sum(dim=-1)  # [a,b,c]
    zero_valid = valid_counts == 0
    denom = valid_counts.clamp_min(1).to(dtype=x.real.dtype)  # real, never 0

    # --- geometric median of x, started from the masked mean ---
    y0 = (m_float * x).sum(dim=-1) / denom
    y0 = torch.where(zero_valid, _nan_like(y0), y0)
    med = _weiszfeld(x, m_float, y0, zero_valid)

    # --- geometric median of squared values ---
    s = (x.abs() ** 2) if torch.is_complex(x) else (x ** 2)  # real
    z0 = (m_float * s).sum(dim=-1) / denom
    z0 = torch.where(zero_valid, _nan_like(z0), z0)
    med2 = _weiszfeld(s, m_float, z0, zero_valid)

    return med, med2
|
|
750
|
+
|
|
751
|
+
def bk_masked_median_2d_weiszfeld(self,
                                  x: torch.Tensor,
                                  mask: torch.Tensor,
                                  max_iter: int = 100,
                                  tol: float = 1e-6,
                                  eps: float = 1e-12):
    """
    Masked geometric median over 2D spatial axes using Weiszfeld iteration.

    Parameters
    ----------
    x : torch.Tensor
        Input of shape [a, b, c, N1, N2]. Can be real or complex.
    mask : torch.Tensor
        Binary mask broadcastable to x (typically [a, b, 1, N1, N2]). Samples
        where mask == 0 are ignored entirely, even if they hold NaN/inf.
    max_iter : int
        Maximum number of Weiszfeld iterations.
    tol : float
        Stopping tolerance on the max absolute update over all voxels.
    eps : float
        Small positive value to avoid division by zero in the weights.

    Returns
    -------
    med : torch.Tensor, shape [a, b, c]
        Geometric median of x over (N1, N2) where mask == 1. For complex x,
        distances are |x - y| and the returned estimate is complex.
    med2 : torch.Tensor, shape [a, b, c]
        Geometric median of the squared values over (N1, N2) where mask == 1.
        This is x**2 for real x and |x|**2 for complex x; it is always real.

    Notes
    -----
    - Voxels with zero valid samples return NaN (NaN+NaNj for complex med).
    - Weiszfeld update: y_{k+1} = sum_i w_i x_i / sum_i w_i with w_i = 1 / ||x_i - y_k||.
    """

    def _nan_like(y):
        """NaN tensor with the same shape/dtype/device as y."""
        if torch.is_complex(y):
            return torch.full_like(y, complex(float("nan"), float("nan")))
        return torch.full_like(y, float("nan"))

    def _safe_nanmax(t):
        """Max over the non-NaN entries of a real tensor (NaN if there are none)."""
        kept = t[~torch.isnan(t)]
        if kept.numel() == 0:
            return torch.tensor(float("nan"), device=t.device, dtype=t.dtype)
        return torch.max(kept)

    def _weiszfeld(vals, weights, y, zero_valid):
        """Iterate y <- sum(w_i v_i) / sum(w_i), with w_i = weights_i / max(|v_i - y|, eps)."""
        for _ in range(max_iter):
            if torch.all(zero_valid):
                break
            dist = (vals - y.unsqueeze(-1)).abs()  # real, [a,b,c,N]
            w = weights * (1.0 / torch.clamp(dist, min=eps))
            y_new = (w * vals).sum(dim=-1) / w.sum(dim=-1).clamp_min(eps)
            y_new = torch.where(zero_valid, _nan_like(y_new), y_new)
            converged = _safe_nanmax((y_new - y).abs()).item() <= tol
            y = y_new
            if converged:
                break
        return y

    # Broadcast mask to x and flatten the spatial dims to N = N1 * N2.
    a, b, c, N1, N2 = x.shape
    N = N1 * N2
    mask_bool = mask.to(torch.bool).expand_as(x).reshape(a, b, c, N)
    x_flat = x.reshape(a, b, c, N)
    # Zero masked-out samples so NaN/inf there cannot leak in through 0 * NaN.
    x_flat = torch.where(mask_bool, x_flat, torch.zeros_like(x_flat))
    m_float = mask_bool.to(dtype=x_flat.real.dtype)  # real weights

    valid_counts = mask_bool.sum(dim=-1)  # [a,b,c]
    zero_valid = valid_counts == 0
    # Real denominator: casting it to a complex x.dtype made med2 complex.
    denom = valid_counts.clamp_min(1).to(dtype=x_flat.real.dtype)

    # --- geometric median of x, started from the masked mean ---
    y0 = (m_float * x_flat).sum(dim=-1) / denom
    y0 = torch.where(zero_valid, _nan_like(y0), y0)
    med = _weiszfeld(x_flat, m_float, y0, zero_valid)

    # --- geometric median of squared values ---
    s_flat = (x_flat.abs() ** 2) if torch.is_complex(x_flat) else (x_flat ** 2)
    z0 = (m_float * s_flat).sum(dim=-1) / denom
    z0 = torch.where(zero_valid, _nan_like(z0), z0)
    med2 = _weiszfeld(s_flat, m_float, z0, zero_valid)

    return med, med2
|
|
892
|
+
|
|
407
893
|
# ---------------------------------------------−---------
|
|
408
894
|
# -- BACKEND DEFINITION --
|
|
409
895
|
# ---------------------------------------------−---------
|
|
@@ -578,6 +1064,14 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
578
1064
|
return self.backend.mean(data)
|
|
579
1065
|
else:
|
|
580
1066
|
return self.backend.mean(data, axis)
|
|
1067
|
+
|
|
1068
|
+
def bk_reduce_median(self, data, axis=None):
    """
    Median of ``data``, over all elements or along one axis.

    Parameters
    ----------
    data : tensor
        Input tensor.
    axis : int or None, default None
        Axis to reduce; None reduces over all elements.

    Returns
    -------
    tensor
        The median values. The values tensor is returned and indices are
        discarded.

    Notes
    -----
    NOTE(review): assumes ``self.backend`` is the ``torch`` module. In that
    case an even-sized input yields the lower of the two middle values.
    """
    if axis is None:
        # Without a dim, torch.median returns one tensor, not a (values,
        # indices) pair; unpacking it would raise.
        return self.backend.median(data)
    values, _ = self.backend.median(data, axis)
    return values
|
|
581
1075
|
|
|
582
1076
|
def bk_reduce_min(self, data, axis=None):
|
|
583
1077
|
|