foscat 2025.8.4__py3-none-any.whl → 2025.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +309 -50
- foscat/FoCUS.py +74 -267
- foscat/HOrientedConvol.py +517 -130
- foscat/HealBili.py +309 -0
- foscat/Plot.py +331 -0
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +470 -179
- foscat/healpix_unet_torch.py +1202 -0
- foscat/scat_cov.py +3 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/METADATA +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/RECORD +14 -10
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/WHEEL +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/top_level.txt +0 -0
foscat/BkTorch.py
CHANGED
|
@@ -63,70 +63,329 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
63
63
|
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
64
64
|
)
|
|
65
65
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
66
|
+
# ---------------------------------
|
|
67
|
+
# HEALPix binning utilities (nested)
|
|
68
|
+
# ---------------------------------
|
|
69
|
+
# Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
|
|
70
|
+
# and batched cell_ids of shape [B, N]. It returns compact per-parent means
|
|
71
|
+
# even when some parents are missing (sparse coverage).
|
|
72
|
+
|
|
73
|
+
def binned_mean_old(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
|
|
74
|
+
"""Average values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
|
|
75
|
+
|
|
76
|
+
Works with full-sky or sparse subsets (no need for N to be divisible by 4).
|
|
77
|
+
|
|
78
|
+
Parameters
|
|
79
|
+
----------
|
|
80
|
+
data : torch.Tensor or np.ndarray
|
|
81
|
+
Shape ``[..., N]`` or ``[B, ..., N]``.
|
|
82
|
+
cell_ids : torch.LongTensor or np.ndarray
|
|
83
|
+
Shape ``[N]`` or ``[B, N]`` (nested indexing at the *child* resolution).
|
|
84
|
+
padded : bool, optional (default: False)
|
|
85
|
+
Only used when ``cell_ids`` is ``[B, N]``. If ``False``, returns Python
|
|
86
|
+
lists (ragged) of per-batch results. If ``True``, returns padded tensors
|
|
87
|
+
plus a boolean mask of valid bins.
|
|
88
|
+
fill_value : float, optional
|
|
89
|
+
Value used for padding when ``padded=True``.
|
|
90
|
+
|
|
91
|
+
Returns
|
|
92
|
+
-------
|
|
93
|
+
If ``cell_ids`` is ``[N]``:
|
|
94
|
+
mean : torch.Tensor, shape ``[..., n_bins]``
|
|
95
|
+
groups: torch.LongTensor, shape ``[n_bins]`` (sorted unique parents)
|
|
96
|
+
|
|
97
|
+
If ``cell_ids`` is ``[B, N]`` and ``padded=False``:
|
|
98
|
+
means_list : List[torch.Tensor] of length B, each shape ``[T, n_bins_b]``
|
|
99
|
+
where ``T = prod(data.shape[1:-1])`` (or 1 if none).
|
|
100
|
+
groups_list : List[torch.LongTensor] of length B, each shape ``[n_bins_b]``
|
|
101
|
+
|
|
102
|
+
If ``cell_ids`` is ``[B, N]`` and ``padded=True``:
|
|
103
|
+
mean_padded : torch.Tensor, shape ``[B, T, max_bins]`` (or ``[B, max_bins]`` if T==1)
|
|
104
|
+
groups_pad : torch.LongTensor, shape ``[B, max_bins]`` (parents, padded with -1)
|
|
105
|
+
mask : torch.BoolTensor, shape ``[B, max_bins]`` (True where valid)
|
|
69
106
|
"""
|
|
70
|
-
|
|
107
|
+
import torch, numpy as np
|
|
71
108
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
109
|
+
# ---- Tensorize & device/dtype plumbing ----
|
|
110
|
+
if isinstance(data, np.ndarray):
|
|
111
|
+
data = torch.from_numpy(data).to(dtype=torch.float32, device=getattr(self, 'torch_device', 'cpu'))
|
|
112
|
+
if isinstance(cell_ids, np.ndarray):
|
|
113
|
+
cell_ids = torch.from_numpy(cell_ids).to(dtype=torch.long, device=data.device)
|
|
75
114
|
|
|
76
|
-
|
|
77
|
-
|
|
115
|
+
data = data.to(device=getattr(self, 'torch_device', data.device))
|
|
116
|
+
cell_ids = cell_ids.to(device=data.device, dtype=torch.long)
|
|
117
|
+
|
|
118
|
+
if data.ndim < 1:
|
|
119
|
+
raise ValueError("`data` must have at least 1 dimension (last is N).")
|
|
120
|
+
N = data.shape[-1]
|
|
121
|
+
|
|
122
|
+
# Flatten leading dims (rows) for scatter convenience
|
|
123
|
+
orig = data.shape[:-1]
|
|
124
|
+
T = int(np.prod(orig[1:])) if len(orig) > 1 else 1 # repeats per batch row
|
|
125
|
+
if cell_ids.ndim == 1:
|
|
126
|
+
# Shared mapping for all rows
|
|
127
|
+
groups = (cell_ids // 4).to(torch.long) # [N]
|
|
128
|
+
# Unique parent ids + inverse indices
|
|
129
|
+
parents, inv = torch.unique(groups, sorted=True, return_inverse=True)
|
|
130
|
+
n_bins = parents.numel()
|
|
131
|
+
|
|
132
|
+
R = int(np.prod(orig)) if len(orig) > 0 else 1
|
|
133
|
+
data_flat = data.reshape(R, N) # [R, N]
|
|
134
|
+
|
|
135
|
+
# Row offsets -> independent bins per row
|
|
136
|
+
row_offsets = torch.arange(R, device=data.device).unsqueeze(1) * n_bins # [R,1]
|
|
137
|
+
idx = inv.unsqueeze(0).expand(R, -1) + row_offsets # [R,N]
|
|
138
|
+
|
|
139
|
+
vals_flat = data_flat.reshape(-1)
|
|
140
|
+
idx_flat = idx.reshape(-1)
|
|
141
|
+
|
|
142
|
+
out_sum = torch.zeros(R * n_bins, dtype=data.dtype, device=data.device)
|
|
143
|
+
out_cnt = torch.zeros_like(out_sum)
|
|
144
|
+
out_sum.scatter_add_(0, idx_flat, vals_flat)
|
|
145
|
+
out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
|
|
146
|
+
out_cnt.clamp_(min=1)
|
|
147
|
+
|
|
148
|
+
mean = (out_sum / out_cnt).view(*orig, n_bins)
|
|
149
|
+
return mean, parents
|
|
150
|
+
|
|
151
|
+
elif cell_ids.ndim == 2:
|
|
152
|
+
B = cell_ids.shape[0]
|
|
153
|
+
if data.shape[0] % B != 0:
|
|
154
|
+
raise ValueError(f"Leading dim of data ({data.shape[0]}) must be a multiple of cell_ids batch ({B}).")
|
|
155
|
+
R = int(np.prod(orig)) if len(orig) > 0 else 1
|
|
156
|
+
data_flat = data.reshape(R, N) # [R, N]
|
|
157
|
+
B_data = data.shape[0]
|
|
158
|
+
T = R // B_data # repeats per batch row (product of extra leading dims)
|
|
159
|
+
|
|
160
|
+
means_list, groups_list = [], []
|
|
161
|
+
max_bins = 0
|
|
162
|
+
# First pass: compute per-batch parents/inv and scatter means
|
|
163
|
+
for b in range(B):
|
|
164
|
+
groups_b = (cell_ids[b] // 4).to(torch.long) # [N]
|
|
165
|
+
parents_b, inv_b = torch.unique(groups_b, sorted=True, return_inverse=True)
|
|
166
|
+
n_bins_b = parents_b.numel()
|
|
167
|
+
max_bins = max(max_bins, n_bins_b)
|
|
168
|
+
|
|
169
|
+
# rows for this batch in data_flat
|
|
170
|
+
start = b * T
|
|
171
|
+
stop = (b + 1) * T
|
|
172
|
+
rows = slice(start, stop) # T rows
|
|
173
|
+
|
|
174
|
+
row_offsets = (torch.arange(T, device=data.device).unsqueeze(1) * n_bins_b)
|
|
175
|
+
idx = inv_b.unsqueeze(0).expand(T, -1) + row_offsets # [T, N]
|
|
176
|
+
|
|
177
|
+
vals_flat = data_flat[rows].reshape(-1)
|
|
178
|
+
idx_flat = idx.reshape(-1)
|
|
179
|
+
|
|
180
|
+
out_sum = torch.zeros(T * n_bins_b, dtype=data.dtype, device=data.device)
|
|
181
|
+
out_cnt = torch.zeros_like(out_sum)
|
|
182
|
+
out_sum.scatter_add_(0, idx_flat, vals_flat)
|
|
183
|
+
out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
|
|
184
|
+
out_cnt.clamp_(min=1)
|
|
185
|
+
mean_bt = (out_sum / out_cnt).view(T, n_bins_b) # [T, n_bins_b]
|
|
186
|
+
|
|
187
|
+
means_list.append(mean_bt)
|
|
188
|
+
groups_list.append(parents_b)
|
|
189
|
+
|
|
190
|
+
if not padded:
|
|
191
|
+
return means_list, groups_list
|
|
192
|
+
|
|
193
|
+
# Padded output
|
|
194
|
+
# mean_padded: [B, T, max_bins]; groups_pad: [B, max_bins]; mask: [B, max_bins]
|
|
195
|
+
mean_pad = torch.full((B, T, max_bins), fill_value, dtype=data.dtype, device=data.device)
|
|
196
|
+
groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
|
|
197
|
+
mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
|
|
198
|
+
for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
|
|
199
|
+
nb = g_b.numel()
|
|
200
|
+
mean_pad[b, :, :nb] = m_b
|
|
201
|
+
groups_pad[b, :nb] = g_b
|
|
202
|
+
mask[b, :nb] = True
|
|
203
|
+
|
|
204
|
+
# Reshape back to [B, (*extra leading dims), max_bins] if needed
|
|
205
|
+
if len(orig) > 1:
|
|
206
|
+
extra = orig[1:] # e.g., (D1, D2, ...)
|
|
207
|
+
mean_pad = mean_pad.view(B, *extra, max_bins)
|
|
208
|
+
else:
|
|
209
|
+
mean_pad = mean_pad.view(B, max_bins)
|
|
210
|
+
|
|
211
|
+
return mean_pad, groups_pad, mask
|
|
212
|
+
|
|
213
|
+
else:
|
|
214
|
+
raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
|
|
215
|
+
|
|
216
|
+
def binned_mean(
|
|
217
|
+
self,
|
|
218
|
+
data,
|
|
219
|
+
cell_ids,
|
|
220
|
+
*,
|
|
221
|
+
reduce: str = "mean", # <-- NEW: "mean" (par défaut) ou "max"
|
|
222
|
+
padded: bool = False,
|
|
223
|
+
fill_value: float = float("nan"),
|
|
224
|
+
):
|
|
225
|
+
"""
|
|
226
|
+
Reduce values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
|
|
227
|
+
|
|
228
|
+
Parameters
|
|
229
|
+
----------
|
|
230
|
+
data : torch.Tensor | np.ndarray
|
|
231
|
+
Shape [..., N] or [B, ..., N].
|
|
232
|
+
cell_ids : torch.LongTensor | np.ndarray
|
|
233
|
+
Shape [N] or [B, N] (nested indexing at the child resolution).
|
|
234
|
+
reduce : {"mean","max"}, default "mean"
|
|
235
|
+
Aggregation to apply within each parent group of 4 children.
|
|
236
|
+
padded : bool, default False
|
|
237
|
+
Only used when `cell_ids` is [B, N]. If False, returns ragged Python lists.
|
|
238
|
+
If True, returns padded tensors + mask.
|
|
239
|
+
fill_value : float, default NaN
|
|
240
|
+
Padding value when `padded=True`.
|
|
241
|
+
|
|
242
|
+
Returns
|
|
243
|
+
-------
|
|
244
|
+
# idem à ta doc existante, mais la valeur est une moyenne (reduce="mean")
|
|
245
|
+
# ou un maximum (reduce="max").
|
|
78
246
|
"""
|
|
247
|
+
|
|
248
|
+
# ---- Tensorize & device/dtype plumbing ----
|
|
79
249
|
if isinstance(data, np.ndarray):
|
|
80
250
|
data = torch.from_numpy(data).to(
|
|
81
|
-
dtype=torch.float32, device=self
|
|
251
|
+
dtype=torch.float32, device=getattr(self, "torch_device", "cpu")
|
|
82
252
|
)
|
|
83
253
|
if isinstance(cell_ids, np.ndarray):
|
|
84
254
|
cell_ids = torch.from_numpy(cell_ids).to(
|
|
85
|
-
dtype=torch.long, device=
|
|
255
|
+
dtype=torch.long, device=data.device
|
|
86
256
|
)
|
|
257
|
+
data = data.to(device=getattr(self, "torch_device", data.device))
|
|
258
|
+
cell_ids = cell_ids.to(device=data.device, dtype=torch.long)
|
|
87
259
|
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
# Get unique group ids and inverse mapping
|
|
92
|
-
unique_groups, inverse_indices = torch.unique(groups, return_inverse=True)
|
|
93
|
-
n_bins = unique_groups.shape[0]
|
|
94
|
-
|
|
95
|
-
# Flatten all leading dimensions into a single batch dimension
|
|
96
|
-
original_shape = data.shape[:-1]
|
|
260
|
+
if data.ndim < 1:
|
|
261
|
+
raise ValueError("`data` must have at least 1 dimension (last is N).")
|
|
97
262
|
N = data.shape[-1]
|
|
98
|
-
data_flat = data.reshape(-1, N) # Shape: [B, N]
|
|
99
|
-
|
|
100
|
-
# Prepare to compute sums using scatter_add
|
|
101
|
-
B = data_flat.shape[0]
|
|
102
|
-
|
|
103
|
-
# Repeat inverse indices for each batch element
|
|
104
|
-
idx = inverse_indices.repeat(B, 1) # Shape: [B, N]
|
|
105
|
-
|
|
106
|
-
# Offset indices to simulate a per-batch scatter into [B * n_bins]
|
|
107
|
-
batch_offsets = torch.arange(B, device=data.device).unsqueeze(1) * n_bins
|
|
108
|
-
idx_offset = idx + batch_offsets # Shape: [B, N]
|
|
109
|
-
|
|
110
|
-
# Flatten everything for scatter
|
|
111
|
-
idx_offset_flat = idx_offset.flatten()
|
|
112
|
-
data_flat_flat = data_flat.flatten()
|
|
113
|
-
|
|
114
|
-
# Accumulate sums per bin
|
|
115
|
-
out = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
|
|
116
|
-
out = out.scatter_add(0, idx_offset_flat, data_flat_flat)
|
|
117
263
|
|
|
118
|
-
#
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
264
|
+
# Utilitaires pour 'max' (fallback si scatter_reduce_ indisponible)
|
|
265
|
+
def _segment_max(vals_flat, idx_flat, out_size):
|
|
266
|
+
"""Retourne out[out_idx] = max(vals[ idx==out_idx ]), vectorisé si possible."""
|
|
267
|
+
# PyTorch >= 1.12 / 2.0: scatter_reduce_ disponible
|
|
268
|
+
if hasattr(torch.Tensor, "scatter_reduce_"):
|
|
269
|
+
out = torch.full((out_size,), -float("inf"),
|
|
270
|
+
dtype=vals_flat.dtype, device=vals_flat.device)
|
|
271
|
+
out.scatter_reduce_(0, idx_flat, vals_flat, reduce="amax", include_self=True)
|
|
272
|
+
return out
|
|
273
|
+
# Fallback simple (boucle sur indices uniques) – OK pour du downsample
|
|
274
|
+
out = torch.full((out_size,), -float("inf"),
|
|
275
|
+
dtype=vals_flat.dtype, device=vals_flat.device)
|
|
276
|
+
uniq = torch.unique(idx_flat)
|
|
277
|
+
for u in uniq.tolist():
|
|
278
|
+
m = (idx_flat == u)
|
|
279
|
+
# éviter max() sur tensor vide
|
|
280
|
+
if m.any():
|
|
281
|
+
out[u] = torch.max(vals_flat[m])
|
|
282
|
+
return out
|
|
283
|
+
|
|
284
|
+
# ---- Flatten leading dims for scatter convenience ----
|
|
285
|
+
orig = data.shape[:-1]
|
|
286
|
+
if cell_ids.ndim == 1:
|
|
287
|
+
# Shared mapping for all rows
|
|
288
|
+
groups = (cell_ids // 4).to(torch.long) # [N]
|
|
289
|
+
parents, inv = torch.unique(groups, sorted=True, return_inverse=True)
|
|
290
|
+
n_bins = parents.numel()
|
|
291
|
+
|
|
292
|
+
R = int(np.prod(orig)) if len(orig) > 0 else 1
|
|
293
|
+
data_flat = data.reshape(R, N) # [R, N]
|
|
294
|
+
row_offsets = torch.arange(R, device=data.device).unsqueeze(1) * n_bins
|
|
295
|
+
idx = inv.unsqueeze(0).expand(R, -1) + row_offsets # [R, N]
|
|
296
|
+
|
|
297
|
+
vals_flat = data_flat.reshape(-1)
|
|
298
|
+
idx_flat = idx.reshape(-1)
|
|
299
|
+
out_size = R * n_bins
|
|
300
|
+
|
|
301
|
+
if reduce == "mean":
|
|
302
|
+
out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
|
|
303
|
+
out_cnt = torch.zeros_like(out_sum)
|
|
304
|
+
out_sum.scatter_add_(0, idx_flat, vals_flat)
|
|
305
|
+
out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
|
|
306
|
+
out_cnt.clamp_(min=1)
|
|
307
|
+
reduced = out_sum / out_cnt
|
|
308
|
+
elif reduce == "max":
|
|
309
|
+
reduced = _segment_max(vals_flat, idx_flat, out_size)
|
|
310
|
+
else:
|
|
311
|
+
raise ValueError("reduce must be 'mean' or 'max'.")
|
|
312
|
+
|
|
313
|
+
output = reduced.view(*orig, n_bins)
|
|
314
|
+
return output, parents
|
|
315
|
+
|
|
316
|
+
elif cell_ids.ndim == 2:
|
|
317
|
+
# Per-batch mapping
|
|
318
|
+
B = cell_ids.shape[0]
|
|
319
|
+
R = int(np.prod(orig)) if len(orig) > 0 else 1
|
|
320
|
+
data_flat = data.reshape(R, N) # [R, N]
|
|
321
|
+
B_data = data.shape[0] if len(orig) > 0 else 1
|
|
322
|
+
if B_data % B != 0:
|
|
323
|
+
raise ValueError(
|
|
324
|
+
f"Leading dim of data ({B_data}) must be a multiple of cell_ids batch ({B})."
|
|
325
|
+
)
|
|
326
|
+
# T = repeats per batch row (product of extra leading dims)
|
|
327
|
+
T = (R // B_data) if B_data > 0 else 1
|
|
328
|
+
|
|
329
|
+
means_list, groups_list = [], []
|
|
330
|
+
max_bins = 0
|
|
331
|
+
|
|
332
|
+
for b in range(B):
|
|
333
|
+
groups_b = (cell_ids[b] // 4).to(torch.long) # [N]
|
|
334
|
+
parents_b, inv_b = torch.unique(groups_b, sorted=True, return_inverse=True)
|
|
335
|
+
n_bins_b = parents_b.numel()
|
|
336
|
+
max_bins = max(max_bins, n_bins_b)
|
|
337
|
+
|
|
338
|
+
# rows for this batch in data_flat
|
|
339
|
+
start, stop = b * T, (b + 1) * T
|
|
340
|
+
rows = slice(start, stop) # T rows
|
|
341
|
+
|
|
342
|
+
row_offsets = torch.arange(T, device=data.device).unsqueeze(1) * n_bins_b
|
|
343
|
+
idx = inv_b.unsqueeze(0).expand(T, -1) + row_offsets # [T, N]
|
|
344
|
+
|
|
345
|
+
vals_flat = data_flat[rows].reshape(-1)
|
|
346
|
+
idx_flat = idx.reshape(-1)
|
|
347
|
+
out_size = T * n_bins_b
|
|
348
|
+
|
|
349
|
+
if reduce == "mean":
|
|
350
|
+
out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
|
|
351
|
+
out_cnt = torch.zeros_like(out_sum)
|
|
352
|
+
out_sum.scatter_add_(0, idx_flat, vals_flat)
|
|
353
|
+
out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
|
|
354
|
+
out_cnt.clamp_(min=1)
|
|
355
|
+
reduced_bt = (out_sum / out_cnt).view(T, n_bins_b)
|
|
356
|
+
elif reduce == "max":
|
|
357
|
+
reduced_bt = _segment_max(vals_flat, idx_flat, out_size).view(T, n_bins_b)
|
|
358
|
+
else:
|
|
359
|
+
raise ValueError("reduce must be 'mean' or 'max'.")
|
|
360
|
+
|
|
361
|
+
means_list.append(reduced_bt)
|
|
362
|
+
groups_list.append(parents_b)
|
|
363
|
+
|
|
364
|
+
if not padded:
|
|
365
|
+
return means_list, groups_list
|
|
366
|
+
|
|
367
|
+
# Padded output (B, T, max_bins) [+ mask]
|
|
368
|
+
mean_pad = torch.full((B, T, max_bins), fill_value, dtype=data.dtype, device=data.device)
|
|
369
|
+
groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
|
|
370
|
+
mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
|
|
371
|
+
for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
|
|
372
|
+
nb = g_b.numel()
|
|
373
|
+
mean_pad[b, :, :nb] = m_b
|
|
374
|
+
groups_pad[b, :nb] = g_b
|
|
375
|
+
mask[b, :nb] = True
|
|
376
|
+
|
|
377
|
+
# Reshape back to [B, (*extra dims), max_bins] si besoin
|
|
378
|
+
if len(orig) > 1:
|
|
379
|
+
extra = orig[1:]
|
|
380
|
+
mean_pad = mean_pad.view(B, *extra, max_bins)
|
|
381
|
+
else:
|
|
382
|
+
mean_pad = mean_pad.view(B, max_bins)
|
|
126
383
|
|
|
127
|
-
|
|
128
|
-
return mean.view(*original_shape, n_bins), unique_groups
|
|
384
|
+
return mean_pad, groups_pad, mask
|
|
129
385
|
|
|
386
|
+
else:
|
|
387
|
+
raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
|
|
388
|
+
|
|
130
389
|
def average_by_cell_group(data, cell_ids):
|
|
131
390
|
"""
|
|
132
391
|
data: tensor of shape [..., N, ...] (ex: [B, N, C])
|
|
@@ -154,7 +413,7 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
154
413
|
return S.numel()
|
|
155
414
|
|
|
156
415
|
def bk_SparseTensor(self, indice, w, dense_shape=[]):
|
|
157
|
-
return self.backend.sparse_coo_tensor(indice
|
|
416
|
+
return self.backend.sparse_coo_tensor(indice, w, dense_shape).coalesce().to_sparse_csr().to(self.torch_device)
|
|
158
417
|
|
|
159
418
|
def bk_stack(self, list, axis=0):
|
|
160
419
|
return self.backend.stack(list, axis=axis).to(self.torch_device)
|