foscat 2025.8.4__tar.gz → 2025.9.3__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {foscat-2025.8.4/src/foscat.egg-info → foscat-2025.9.3}/PKG-INFO +1 -1
  2. {foscat-2025.8.4 → foscat-2025.9.3}/pyproject.toml +1 -1
  3. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/BkTorch.py +309 -50
  4. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/FoCUS.py +74 -267
  5. foscat-2025.9.3/src/foscat/HOrientedConvol.py +933 -0
  6. foscat-2025.9.3/src/foscat/HealBili.py +309 -0
  7. foscat-2025.9.3/src/foscat/Plot.py +331 -0
  8. foscat-2025.9.3/src/foscat/SphericalStencil.py +1346 -0
  9. foscat-2025.9.3/src/foscat/UNET.py +491 -0
  10. foscat-2025.9.3/src/foscat/healpix_unet_torch.py +1202 -0
  11. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat_cov.py +3 -1
  12. {foscat-2025.8.4 → foscat-2025.9.3/src/foscat.egg-info}/PKG-INFO +1 -1
  13. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat.egg-info/SOURCES.txt +4 -0
  14. foscat-2025.8.4/src/foscat/HOrientedConvol.py +0 -546
  15. foscat-2025.8.4/src/foscat/UNET.py +0 -200
  16. {foscat-2025.8.4 → foscat-2025.9.3}/LICENSE +0 -0
  17. {foscat-2025.8.4 → foscat-2025.9.3}/README.md +0 -0
  18. {foscat-2025.8.4 → foscat-2025.9.3}/setup.cfg +0 -0
  19. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/BkBase.py +0 -0
  20. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/BkNumpy.py +0 -0
  21. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/BkTensorflow.py +0 -0
  22. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/CNN.py +0 -0
  23. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/CircSpline.py +0 -0
  24. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/GCNN.py +0 -0
  25. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/HealSpline.py +0 -0
  26. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/Softmax.py +0 -0
  27. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/Spline1D.py +0 -0
  28. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/Synthesis.py +0 -0
  29. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/__init__.py +0 -0
  30. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/alm.py +0 -0
  31. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/backend.py +0 -0
  32. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/backend_tens.py +0 -0
  33. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/heal_NN.py +0 -0
  34. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/loss_backend_tens.py +0 -0
  35. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/loss_backend_torch.py +0 -0
  36. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat.py +0 -0
  37. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat1D.py +0 -0
  38. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat2D.py +0 -0
  39. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat_cov1D.py +0 -0
  40. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat_cov2D.py +0 -0
  41. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat_cov_map.py +0 -0
  42. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat/scat_cov_map2D.py +0 -0
  43. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat.egg-info/dependency_links.txt +0 -0
  44. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat.egg-info/requires.txt +0 -0
  45. {foscat-2025.8.4 → foscat-2025.9.3}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: foscat
3
- Version: 2025.8.4
3
+ Version: 2025.9.3
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "foscat"
3
- version = "2025.08.4"
3
+ version = "2025.09.3"
4
4
  description = "Generate synthetic Healpix or 2D data using Cross Scattering Transform"
5
5
  readme = "README.md"
6
6
  license = { text = "BSD-3-Clause" }
@@ -63,70 +63,329 @@ class BkTorch(BackendBase.BackendBase):
63
63
  torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
64
64
  )
65
65
 
66
- import torch
67
-
68
- def binned_mean(self, data, cell_ids):
66
+ # ---------------------------------
67
+ # HEALPix binning utilities (nested)
68
+ # ---------------------------------
69
+ # Robust binned_mean that supports arbitrary subsets (N not divisible by 4)
70
+ # and batched cell_ids of shape [B, N]. It returns compact per-parent means
71
+ # even when some parents are missing (sparse coverage).
72
+
73
+ def binned_mean_old(self, data, cell_ids, *, padded: bool = False, fill_value: float = float("nan")):
74
+ """Average values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
75
+
76
+ Works with full-sky or sparse subsets (no need for N to be divisible by 4).
77
+
78
+ Parameters
79
+ ----------
80
+ data : torch.Tensor or np.ndarray
81
+ Shape ``[..., N]`` or ``[B, ..., N]``.
82
+ cell_ids : torch.LongTensor or np.ndarray
83
+ Shape ``[N]`` or ``[B, N]`` (nested indexing at the *child* resolution).
84
+ padded : bool, optional (default: False)
85
+ Only used when ``cell_ids`` is ``[B, N]``. If ``False``, returns Python
86
+ lists (ragged) of per-batch results. If ``True``, returns padded tensors
87
+ plus a boolean mask of valid bins.
88
+ fill_value : float, optional
89
+ Value used for padding when ``padded=True``.
90
+
91
+ Returns
92
+ -------
93
+ If ``cell_ids`` is ``[N]``:
94
+ mean : torch.Tensor, shape ``[..., n_bins]``
95
+ groups: torch.LongTensor, shape ``[n_bins]`` (sorted unique parents)
96
+
97
+ If ``cell_ids`` is ``[B, N]`` and ``padded=False``:
98
+ means_list : List[torch.Tensor] of length B, each shape ``[T, n_bins_b]``
99
+ where ``T = prod(data.shape[1:-1])`` (or 1 if none).
100
+ groups_list : List[torch.LongTensor] of length B, each shape ``[n_bins_b]``
101
+
102
+ If ``cell_ids`` is ``[B, N]`` and ``padded=True``:
103
+ mean_padded : torch.Tensor, shape ``[B, T, max_bins]`` (or ``[B, max_bins]`` if T==1)
104
+ groups_pad : torch.LongTensor, shape ``[B, max_bins]`` (parents, padded with -1)
105
+ mask : torch.BoolTensor, shape ``[B, max_bins]`` (True where valid)
69
106
  """
70
- Compute the mean over groups of 4 nested HEALPix cells (nside → nside/2).
107
+ import torch, numpy as np
71
108
 
72
- Args:
73
- data (torch.Tensor): Tensor of shape [..., N], where N is the number of HEALPix cells.
74
- cell_ids (torch.LongTensor): Tensor of shape [N], with cell indices (nested ordering).
109
+ # ---- Tensorize & device/dtype plumbing ----
110
+ if isinstance(data, np.ndarray):
111
+ data = torch.from_numpy(data).to(dtype=torch.float32, device=getattr(self, 'torch_device', 'cpu'))
112
+ if isinstance(cell_ids, np.ndarray):
113
+ cell_ids = torch.from_numpy(cell_ids).to(dtype=torch.long, device=data.device)
75
114
 
76
- Returns:
77
- torch.Tensor: Tensor of shape [..., n_bins], with averaged values per group of 4 cells.
115
+ data = data.to(device=getattr(self, 'torch_device', data.device))
116
+ cell_ids = cell_ids.to(device=data.device, dtype=torch.long)
117
+
118
+ if data.ndim < 1:
119
+ raise ValueError("`data` must have at least 1 dimension (last is N).")
120
+ N = data.shape[-1]
121
+
122
+ # Flatten leading dims (rows) for scatter convenience
123
+ orig = data.shape[:-1]
124
+ T = int(np.prod(orig[1:])) if len(orig) > 1 else 1 # repeats per batch row
125
+ if cell_ids.ndim == 1:
126
+ # Shared mapping for all rows
127
+ groups = (cell_ids // 4).to(torch.long) # [N]
128
+ # Unique parent ids + inverse indices
129
+ parents, inv = torch.unique(groups, sorted=True, return_inverse=True)
130
+ n_bins = parents.numel()
131
+
132
+ R = int(np.prod(orig)) if len(orig) > 0 else 1
133
+ data_flat = data.reshape(R, N) # [R, N]
134
+
135
+ # Row offsets -> independent bins per row
136
+ row_offsets = torch.arange(R, device=data.device).unsqueeze(1) * n_bins # [R,1]
137
+ idx = inv.unsqueeze(0).expand(R, -1) + row_offsets # [R,N]
138
+
139
+ vals_flat = data_flat.reshape(-1)
140
+ idx_flat = idx.reshape(-1)
141
+
142
+ out_sum = torch.zeros(R * n_bins, dtype=data.dtype, device=data.device)
143
+ out_cnt = torch.zeros_like(out_sum)
144
+ out_sum.scatter_add_(0, idx_flat, vals_flat)
145
+ out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
146
+ out_cnt.clamp_(min=1)
147
+
148
+ mean = (out_sum / out_cnt).view(*orig, n_bins)
149
+ return mean, parents
150
+
151
+ elif cell_ids.ndim == 2:
152
+ B = cell_ids.shape[0]
153
+ if data.shape[0] % B != 0:
154
+ raise ValueError(f"Leading dim of data ({data.shape[0]}) must be a multiple of cell_ids batch ({B}).")
155
+ R = int(np.prod(orig)) if len(orig) > 0 else 1
156
+ data_flat = data.reshape(R, N) # [R, N]
157
+ B_data = data.shape[0]
158
+ T = R // B_data # repeats per batch row (product of extra leading dims)
159
+
160
+ means_list, groups_list = [], []
161
+ max_bins = 0
162
+ # First pass: compute per-batch parents/inv and scatter means
163
+ for b in range(B):
164
+ groups_b = (cell_ids[b] // 4).to(torch.long) # [N]
165
+ parents_b, inv_b = torch.unique(groups_b, sorted=True, return_inverse=True)
166
+ n_bins_b = parents_b.numel()
167
+ max_bins = max(max_bins, n_bins_b)
168
+
169
+ # rows for this batch in data_flat
170
+ start = b * T
171
+ stop = (b + 1) * T
172
+ rows = slice(start, stop) # T rows
173
+
174
+ row_offsets = (torch.arange(T, device=data.device).unsqueeze(1) * n_bins_b)
175
+ idx = inv_b.unsqueeze(0).expand(T, -1) + row_offsets # [T, N]
176
+
177
+ vals_flat = data_flat[rows].reshape(-1)
178
+ idx_flat = idx.reshape(-1)
179
+
180
+ out_sum = torch.zeros(T * n_bins_b, dtype=data.dtype, device=data.device)
181
+ out_cnt = torch.zeros_like(out_sum)
182
+ out_sum.scatter_add_(0, idx_flat, vals_flat)
183
+ out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
184
+ out_cnt.clamp_(min=1)
185
+ mean_bt = (out_sum / out_cnt).view(T, n_bins_b) # [T, n_bins_b]
186
+
187
+ means_list.append(mean_bt)
188
+ groups_list.append(parents_b)
189
+
190
+ if not padded:
191
+ return means_list, groups_list
192
+
193
+ # Padded output
194
+ # mean_padded: [B, T, max_bins]; groups_pad: [B, max_bins]; mask: [B, max_bins]
195
+ mean_pad = torch.full((B, T, max_bins), fill_value, dtype=data.dtype, device=data.device)
196
+ groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
197
+ mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
198
+ for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
199
+ nb = g_b.numel()
200
+ mean_pad[b, :, :nb] = m_b
201
+ groups_pad[b, :nb] = g_b
202
+ mask[b, :nb] = True
203
+
204
+ # Reshape back to [B, (*extra leading dims), max_bins] if needed
205
+ if len(orig) > 1:
206
+ extra = orig[1:] # e.g., (D1, D2, ...)
207
+ mean_pad = mean_pad.view(B, *extra, max_bins)
208
+ else:
209
+ mean_pad = mean_pad.view(B, max_bins)
210
+
211
+ return mean_pad, groups_pad, mask
212
+
213
+ else:
214
+ raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
215
+
216
+ def binned_mean(
217
+ self,
218
+ data,
219
+ cell_ids,
220
+ *,
221
+ reduce: str = "mean", # <-- NEW: "mean" (par défaut) ou "max"
222
+ padded: bool = False,
223
+ fill_value: float = float("nan"),
224
+ ):
225
+ """
226
+ Reduce values over parent HEALPix pixels (nested) when downgrading nside→nside/2.
227
+
228
+ Parameters
229
+ ----------
230
+ data : torch.Tensor | np.ndarray
231
+ Shape [..., N] or [B, ..., N].
232
+ cell_ids : torch.LongTensor | np.ndarray
233
+ Shape [N] or [B, N] (nested indexing at the child resolution).
234
+ reduce : {"mean","max"}, default "mean"
235
+ Aggregation to apply within each parent group of 4 children.
236
+ padded : bool, default False
237
+ Only used when `cell_ids` is [B, N]. If False, returns ragged Python lists.
238
+ If True, returns padded tensors + mask.
239
+ fill_value : float, default NaN
240
+ Padding value when `padded=True`.
241
+
242
+ Returns
243
+ -------
244
+ # idem à ta doc existante, mais la valeur est une moyenne (reduce="mean")
245
+ # ou un maximum (reduce="max").
78
246
  """
247
+
248
+ # ---- Tensorize & device/dtype plumbing ----
79
249
  if isinstance(data, np.ndarray):
80
250
  data = torch.from_numpy(data).to(
81
- dtype=torch.float32, device=self.torch_device
251
+ dtype=torch.float32, device=getattr(self, "torch_device", "cpu")
82
252
  )
83
253
  if isinstance(cell_ids, np.ndarray):
84
254
  cell_ids = torch.from_numpy(cell_ids).to(
85
- dtype=torch.long, device=self.torch_device
255
+ dtype=torch.long, device=data.device
86
256
  )
257
+ data = data.to(device=getattr(self, "torch_device", data.device))
258
+ cell_ids = cell_ids.to(device=data.device, dtype=torch.long)
87
259
 
88
- # Compute supercell ids by grouping 4 nested cells together
89
- groups = cell_ids // 4
90
-
91
- # Get unique group ids and inverse mapping
92
- unique_groups, inverse_indices = torch.unique(groups, return_inverse=True)
93
- n_bins = unique_groups.shape[0]
94
-
95
- # Flatten all leading dimensions into a single batch dimension
96
- original_shape = data.shape[:-1]
260
+ if data.ndim < 1:
261
+ raise ValueError("`data` must have at least 1 dimension (last is N).")
97
262
  N = data.shape[-1]
98
- data_flat = data.reshape(-1, N) # Shape: [B, N]
99
-
100
- # Prepare to compute sums using scatter_add
101
- B = data_flat.shape[0]
102
-
103
- # Repeat inverse indices for each batch element
104
- idx = inverse_indices.repeat(B, 1) # Shape: [B, N]
105
-
106
- # Offset indices to simulate a per-batch scatter into [B * n_bins]
107
- batch_offsets = torch.arange(B, device=data.device).unsqueeze(1) * n_bins
108
- idx_offset = idx + batch_offsets # Shape: [B, N]
109
-
110
- # Flatten everything for scatter
111
- idx_offset_flat = idx_offset.flatten()
112
- data_flat_flat = data_flat.flatten()
113
-
114
- # Accumulate sums per bin
115
- out = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
116
- out = out.scatter_add(0, idx_offset_flat, data_flat_flat)
117
263
 
118
- # Count number of elements per bin (to compute mean)
119
- ones = torch.ones_like(data_flat_flat)
120
- counts = torch.zeros(B * n_bins, dtype=data.dtype, device=data.device)
121
- counts = counts.scatter_add(0, idx_offset_flat, ones)
122
-
123
- # Compute mean
124
- mean = out / counts # Shape: [B * n_bins]
125
- mean = mean.view(B, n_bins)
264
+ # Utilitaires pour 'max' (fallback si scatter_reduce_ indisponible)
265
+ def _segment_max(vals_flat, idx_flat, out_size):
266
+ """Retourne out[out_idx] = max(vals[ idx==out_idx ]), vectorisé si possible."""
267
+ # PyTorch >= 1.12 / 2.0: scatter_reduce_ disponible
268
+ if hasattr(torch.Tensor, "scatter_reduce_"):
269
+ out = torch.full((out_size,), -float("inf"),
270
+ dtype=vals_flat.dtype, device=vals_flat.device)
271
+ out.scatter_reduce_(0, idx_flat, vals_flat, reduce="amax", include_self=True)
272
+ return out
273
+ # Fallback simple (boucle sur indices uniques) – OK pour du downsample
274
+ out = torch.full((out_size,), -float("inf"),
275
+ dtype=vals_flat.dtype, device=vals_flat.device)
276
+ uniq = torch.unique(idx_flat)
277
+ for u in uniq.tolist():
278
+ m = (idx_flat == u)
279
+ # éviter max() sur tensor vide
280
+ if m.any():
281
+ out[u] = torch.max(vals_flat[m])
282
+ return out
283
+
284
+ # ---- Flatten leading dims for scatter convenience ----
285
+ orig = data.shape[:-1]
286
+ if cell_ids.ndim == 1:
287
+ # Shared mapping for all rows
288
+ groups = (cell_ids // 4).to(torch.long) # [N]
289
+ parents, inv = torch.unique(groups, sorted=True, return_inverse=True)
290
+ n_bins = parents.numel()
291
+
292
+ R = int(np.prod(orig)) if len(orig) > 0 else 1
293
+ data_flat = data.reshape(R, N) # [R, N]
294
+ row_offsets = torch.arange(R, device=data.device).unsqueeze(1) * n_bins
295
+ idx = inv.unsqueeze(0).expand(R, -1) + row_offsets # [R, N]
296
+
297
+ vals_flat = data_flat.reshape(-1)
298
+ idx_flat = idx.reshape(-1)
299
+ out_size = R * n_bins
300
+
301
+ if reduce == "mean":
302
+ out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
303
+ out_cnt = torch.zeros_like(out_sum)
304
+ out_sum.scatter_add_(0, idx_flat, vals_flat)
305
+ out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
306
+ out_cnt.clamp_(min=1)
307
+ reduced = out_sum / out_cnt
308
+ elif reduce == "max":
309
+ reduced = _segment_max(vals_flat, idx_flat, out_size)
310
+ else:
311
+ raise ValueError("reduce must be 'mean' or 'max'.")
312
+
313
+ output = reduced.view(*orig, n_bins)
314
+ return output, parents
315
+
316
+ elif cell_ids.ndim == 2:
317
+ # Per-batch mapping
318
+ B = cell_ids.shape[0]
319
+ R = int(np.prod(orig)) if len(orig) > 0 else 1
320
+ data_flat = data.reshape(R, N) # [R, N]
321
+ B_data = data.shape[0] if len(orig) > 0 else 1
322
+ if B_data % B != 0:
323
+ raise ValueError(
324
+ f"Leading dim of data ({B_data}) must be a multiple of cell_ids batch ({B})."
325
+ )
326
+ # T = repeats per batch row (product of extra leading dims)
327
+ T = (R // B_data) if B_data > 0 else 1
328
+
329
+ means_list, groups_list = [], []
330
+ max_bins = 0
331
+
332
+ for b in range(B):
333
+ groups_b = (cell_ids[b] // 4).to(torch.long) # [N]
334
+ parents_b, inv_b = torch.unique(groups_b, sorted=True, return_inverse=True)
335
+ n_bins_b = parents_b.numel()
336
+ max_bins = max(max_bins, n_bins_b)
337
+
338
+ # rows for this batch in data_flat
339
+ start, stop = b * T, (b + 1) * T
340
+ rows = slice(start, stop) # T rows
341
+
342
+ row_offsets = torch.arange(T, device=data.device).unsqueeze(1) * n_bins_b
343
+ idx = inv_b.unsqueeze(0).expand(T, -1) + row_offsets # [T, N]
344
+
345
+ vals_flat = data_flat[rows].reshape(-1)
346
+ idx_flat = idx.reshape(-1)
347
+ out_size = T * n_bins_b
348
+
349
+ if reduce == "mean":
350
+ out_sum = torch.zeros(out_size, dtype=data.dtype, device=data.device)
351
+ out_cnt = torch.zeros_like(out_sum)
352
+ out_sum.scatter_add_(0, idx_flat, vals_flat)
353
+ out_cnt.scatter_add_(0, idx_flat, torch.ones_like(vals_flat))
354
+ out_cnt.clamp_(min=1)
355
+ reduced_bt = (out_sum / out_cnt).view(T, n_bins_b)
356
+ elif reduce == "max":
357
+ reduced_bt = _segment_max(vals_flat, idx_flat, out_size).view(T, n_bins_b)
358
+ else:
359
+ raise ValueError("reduce must be 'mean' or 'max'.")
360
+
361
+ means_list.append(reduced_bt)
362
+ groups_list.append(parents_b)
363
+
364
+ if not padded:
365
+ return means_list, groups_list
366
+
367
+ # Padded output (B, T, max_bins) [+ mask]
368
+ mean_pad = torch.full((B, T, max_bins), fill_value, dtype=data.dtype, device=data.device)
369
+ groups_pad = torch.full((B, max_bins), -1, dtype=torch.long, device=data.device)
370
+ mask = torch.zeros((B, max_bins), dtype=torch.bool, device=data.device)
371
+ for b, (m_b, g_b) in enumerate(zip(means_list, groups_list)):
372
+ nb = g_b.numel()
373
+ mean_pad[b, :, :nb] = m_b
374
+ groups_pad[b, :nb] = g_b
375
+ mask[b, :nb] = True
376
+
377
+ # Reshape back to [B, (*extra dims), max_bins] si besoin
378
+ if len(orig) > 1:
379
+ extra = orig[1:]
380
+ mean_pad = mean_pad.view(B, *extra, max_bins)
381
+ else:
382
+ mean_pad = mean_pad.view(B, max_bins)
126
383
 
127
- # Restore original leading dimensions
128
- return mean.view(*original_shape, n_bins), unique_groups
384
+ return mean_pad, groups_pad, mask
129
385
 
386
+ else:
387
+ raise ValueError("`cell_ids` must be of shape [N] or [B, N].")
388
+
130
389
  def average_by_cell_group(data, cell_ids):
131
390
  """
132
391
  data: tensor of shape [..., N, ...] (ex: [B, N, C])
@@ -154,7 +413,7 @@ class BkTorch(BackendBase.BackendBase):
154
413
  return S.numel()
155
414
 
156
415
  def bk_SparseTensor(self, indice, w, dense_shape=[]):
157
- return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)
416
+ return self.backend.sparse_coo_tensor(indice, w, dense_shape).coalesce().to_sparse_csr().to(self.torch_device)
158
417
 
159
418
  def bk_stack(self, list, axis=0):
160
419
  return self.backend.stack(list, axis=axis).to(self.torch_device)