foscat 2025.11.1__py3-none-any.whl → 2026.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -5,8 +5,11 @@ import healpy as hp
5
5
  import numpy as np
6
6
  import foscat.HealSpline as HS
7
7
  from scipy.interpolate import griddata
8
+ from foscat.SphereDownGeo import SphereDownGeo
9
+ from foscat.SphereUpGeo import SphereUpGeo
10
+ import torch
8
11
 
9
- TMPFILE_VERSION = "V10_0"
12
+ TMPFILE_VERSION = "V12_0"
10
13
 
11
14
 
12
15
  class FoCUS:
@@ -36,7 +39,7 @@ class FoCUS:
36
39
  mpi_rank=0
37
40
  ):
38
41
 
39
- self.__version__ = "2025.11.1"
42
+ self.__version__ = "2026.01.1"
40
43
  # P00 coeff for normalization for scat_cov
41
44
  self.TMPFILE_VERSION = TMPFILE_VERSION
42
45
  self.P1_dic = None
@@ -57,7 +60,8 @@ class FoCUS:
57
60
  self.kernelR_conv = {}
58
61
  self.kernelI_conv = {}
59
62
  self.padding_conv = {}
60
-
63
+ self.down = {}
64
+ self.up = {}
61
65
  if not self.silent:
62
66
  print("================================================")
63
67
  print(" START FOSCAT CONFIGURATION")
@@ -648,6 +652,7 @@ class FoCUS:
648
652
  return rim
649
653
 
650
654
  # --------------------------------------------------------
655
+
651
656
  def ud_grade_2(self, im, axis=0, cell_ids=None, nside=None,max_poll=False):
652
657
 
653
658
  if self.use_2D:
@@ -721,6 +726,22 @@ class FoCUS:
721
726
 
722
727
  else:
723
728
  shape = list(im.shape)
729
+ if nside is None:
730
+ l_nside=int(np.sqrt(shape[-1]//12))
731
+ else:
732
+ l_nside=nside
733
+
734
+ nbatch=1
735
+ for k in range(len(shape)-1):
736
+ nbatch*=shape[k]
737
+ if l_nside not in self.down:
738
+ print('initialise down', l_nside)
739
+ self.down[l_nside] = SphereDownGeo(nside_in=l_nside, dtype=self.all_bk_type,mode="smooth", in_cell_ids=cell_ids)
740
+
741
+ res,out_cell=self.down[l_nside](self.backend.bk_reshape(im,[nbatch,1,shape[-1]]))
742
+
743
+ return self.backend.bk_reshape(res,shape[:-1]+[out_cell.shape[0]]),out_cell
744
+ '''
724
745
  if self.use_median:
725
746
  if cell_ids is not None:
726
747
  sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='median')
@@ -747,6 +768,7 @@ class FoCUS:
747
768
  return self.backend.bk_reduce_mean(
748
769
  self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
749
770
  ),None
771
+ '''
750
772
 
751
773
  # --------------------------------------------------------
752
774
  def up_grade(self, im, nout,
@@ -836,6 +858,7 @@ class FoCUS:
836
858
  else:
837
859
  lout = nside
838
860
 
861
+ '''
839
862
  if (lout,nout) not in self.pix_interp_val or force_init_index:
840
863
  if not self.silent:
841
864
  print("compute lout nout", lout, nout)
@@ -926,12 +949,32 @@ class FoCUS:
926
949
 
927
950
  del w
928
951
  del p
929
-
930
- if lout == nout:
931
- imout = im
932
- else:
933
- # work only on the last column
934
-
952
+ '''
953
+ shape=list(im.shape)
954
+ nbatch=1
955
+ for k in range(len(shape)-1):
956
+ nbatch*=shape[k]
957
+
958
+ im=self.backend.bk_reshape(im,[nbatch,1,shape[-1]])
959
+
960
+ while lout<nout:
961
+ if lout not in self.up:
962
+ if o_cell_ids is None:
963
+ l_o_cell_ids=torch.tensor(np.arange(12*(lout**2),dtype='int'),device=im.device)
964
+ else:
965
+ l_o_cell_ids=o_cell_ids
966
+ self.up[lout] = SphereUpGeo(nside_out=lout,
967
+ dtype=self.all_bk_type,
968
+ cell_ids_out=l_o_cell_ids,
969
+ up_norm="col_l1")
970
+ im, fine_ids = self.up[lout](self.backend.bk_cast(im))
971
+ lout*=2
972
+ if lout<nout and o_cell_ids is not None:
973
+ o_cell_ids=torch.repeat(fine_ids,4)*4+ \
974
+ torch.tile(torch.tensor([0,1,2,3],device=fine_ids.device,dtype=fine_ids.dtype),fine_ids.shape[0])
975
+
976
+ return self.backend.bk_reshape(im,shape[:-1]+[im.shape[-1]])
977
+ '''
935
978
  ndata = 1
936
979
  for k in range(len(ishape)-1):
937
980
  ndata = ndata * ishape[k]
@@ -960,6 +1003,7 @@ class FoCUS:
960
1003
  return self.backend.bk_reshape(
961
1004
  imout, ishape[0:-1]+[imout.shape[-1]]
962
1005
  )
1006
+ '''
963
1007
  return imout
964
1008
 
965
1009
  # --------------------------------------------------------
@@ -1354,7 +1398,9 @@ class FoCUS:
1354
1398
  else:
1355
1399
  l_cell_ids=cell_ids
1356
1400
 
1357
- nvalid=self.KERNELSZ**2
1401
+ nvalid=4*self.KERNELSZ**2
1402
+ if nvalid>12*nside**2:
1403
+ nvalid=12*nside**2
1358
1404
  idxEB=hconvol.idx_nn[:,0:nvalid]
1359
1405
  tmpEB=np.zeros([self.NORIENT,4,l_cell_ids.shape[0],nvalid],dtype='complex')
1360
1406
  tmpS=np.zeros([4,l_cell_ids.shape[0],nvalid],dtype='float')
@@ -1500,7 +1546,7 @@ class FoCUS:
1500
1546
 
1501
1547
  else:
1502
1548
  if l_kernel == 5:
1503
- pw = 0.5
1549
+ pw = 0.75
1504
1550
  pw2 = 0.5
1505
1551
  threshold = 2e-5
1506
1552
 
@@ -0,0 +1,380 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import healpy as hp
6
+
7
+
8
+ class SphereDownGeo(nn.Module):
9
+ """
10
+ Geometric HEALPix downsampling operator (NESTED indexing).
11
+
12
+ This module reduces resolution by a factor 2:
13
+ nside_out = nside_in // 2
14
+
15
+ Input conventions
16
+ -----------------
17
+ - If in_cell_ids is None:
18
+ x is expected to be full-sphere: [B, C, N_in]
19
+ output is [B, C, K_out] with K_out = len(cell_ids_out) (or N_out if None).
20
+ - If in_cell_ids is provided (fine pixels at nside_in, NESTED):
21
+ x can be either:
22
+ * compact: [B, C, K_in] where K_in = len(in_cell_ids), aligned with in_cell_ids order
23
+ * full-sphere: [B, C, N_in] (also supported)
24
+ output is [B, C, K_out] where cell_ids_out is derived as unique(in_cell_ids // 4),
25
+ unless you explicitly pass cell_ids_out (then it will be intersected with the derived set).
26
+
27
+ Modes
28
+ -----
29
+ - mode="smooth": linear downsampling y = M @ x (M sparse)
30
+ - mode="maxpool": non-linear max over available children (fast)
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ nside_in: int,
36
+ mode: str = "smooth",
37
+ radius_deg: float | None = None,
38
+ sigma_deg: float | None = None,
39
+ weight_norm: str = "l1",
40
+ cell_ids_out: np.ndarray | list[int] | None = None,
41
+ in_cell_ids: np.ndarray | list[int] | torch.Tensor | None = None,
42
+ use_csr=True,
43
+ device=None,
44
+ dtype: torch.dtype = torch.float32,
45
+ ):
46
+ super().__init__()
47
+
48
+ if device is None:
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ self.device = device
51
+ self.dtype = dtype
52
+
53
+ self.nside_in = int(nside_in)
54
+ assert (self.nside_in & (self.nside_in - 1)) == 0, "nside_in must be a power of 2."
55
+ self.nside_out = self.nside_in // 2
56
+ assert self.nside_out >= 1, "nside_out must be >= 1."
57
+
58
+ self.N_in = 12 * self.nside_in * self.nside_in
59
+ self.N_out = 12 * self.nside_out * self.nside_out
60
+
61
+ self.mode = str(mode).lower()
62
+ assert self.mode in ("smooth", "maxpool"), "mode must be 'smooth' or 'maxpool'."
63
+
64
+ self.weight_norm = str(weight_norm).lower()
65
+ assert self.weight_norm in ("l1", "l2"), "weight_norm must be 'l1' or 'l2'."
66
+
67
+ # ---- Handle reduced-domain inputs (fine pixels) ----
68
+ self.in_cell_ids = self._validate_in_cell_ids(in_cell_ids)
69
+ self.has_in_subset = self.in_cell_ids is not None
70
+ if self.has_in_subset:
71
+ # derive parents
72
+ derived_out = np.unique(self.in_cell_ids // 4).astype(np.int64)
73
+ if cell_ids_out is None:
74
+ self.cell_ids_out = derived_out
75
+ else:
76
+ req_out = self._validate_cell_ids_out(cell_ids_out)
77
+ # keep only those compatible with derived_out (otherwise they'd be all-zero)
78
+ self.cell_ids_out = np.intersect1d(req_out, derived_out, assume_unique=False)
79
+ if self.cell_ids_out.size == 0:
80
+ raise ValueError(
81
+ "After intersecting cell_ids_out with unique(in_cell_ids//4), "
82
+ "no coarse pixel remains. Check your inputs."
83
+ )
84
+ else:
85
+ self.cell_ids_out = self._validate_cell_ids_out(cell_ids_out)
86
+
87
+ self.K_out = int(self.cell_ids_out.size)
88
+
89
+ # Column basis for smooth matrix:
90
+ # - full sphere: columns are 0..N_in-1
91
+ # - subset: columns are 0..K_in-1 aligned to self.in_cell_ids
92
+ self.K_in = int(self.in_cell_ids.size) if self.has_in_subset else self.N_in
93
+
94
+ if self.mode == "smooth":
95
+ if radius_deg is None:
96
+ # default: include roughly the 4 children footprint
97
+ # (healpy pixel size ~ sqrt(4pi/N), coarse pixel is 4x area)
98
+ radius_deg = 2.0 * hp.nside2resol(self.nside_out, arcmin=True) / 60.0
99
+ if sigma_deg is None:
100
+ sigma_deg = max(radius_deg / 2.0, 1e-6)
101
+
102
+ self.radius_deg = float(radius_deg)
103
+ self.sigma_deg = float(sigma_deg)
104
+ self.radius_rad = self.radius_deg * np.pi / 180.0
105
+ self.sigma_rad = self.sigma_deg * np.pi / 180.0
106
+
107
+ M = self._build_down_matrix() # shape (K_out, K_in or N_in)
108
+
109
+ self.M = M.coalesce()
110
+
111
+ if use_csr:
112
+ self.M = self.M.to_sparse_csr().to(self.device)
113
+
114
+ self.M_size = M.size()
115
+
116
+ else:
117
+ # Precompute children indices for maxpool
118
+ # For subset mode, store mapping from each parent to indices in compact vector,
119
+ # with -1 for missing children.
120
+ children = np.stack(
121
+ [4 * self.cell_ids_out + i for i in range(4)],
122
+ axis=1,
123
+ ).astype(np.int64) # [K_out, 4] in fine pixel ids (full indexing)
124
+
125
+ if self.has_in_subset:
126
+ # map each child pixel id to position in in_cell_ids (compact index)
127
+ pos = self._positions_in_sorted(self.in_cell_ids, children.reshape(-1))
128
+ children_compact = pos.reshape(self.K_out, 4).astype(np.int64) # -1 if missing
129
+ self.register_buffer(
130
+ "children_compact",
131
+ torch.tensor(children_compact, dtype=torch.long, device=self.device),
132
+ )
133
+ else:
134
+ self.register_buffer(
135
+ "children_full",
136
+ torch.tensor(children, dtype=torch.long, device=self.device),
137
+ )
138
+
139
+ # expose ids as torch buffers for convenience
140
+ self.register_buffer(
141
+ "cell_ids_out_t",
142
+ torch.tensor(self.cell_ids_out.astype(np.int64), dtype=torch.long, device=self.device),
143
+ )
144
+ if self.has_in_subset:
145
+ self.register_buffer(
146
+ "in_cell_ids_t",
147
+ torch.tensor(self.in_cell_ids.astype(np.int64), dtype=torch.long, device=self.device),
148
+ )
149
+
150
+ # ---------------- validation helpers ----------------
151
+ def _validate_cell_ids_out(self, cell_ids_out):
152
+ """Return a 1D np.int64 array of coarse cell ids (nside_out)."""
153
+ if cell_ids_out is None:
154
+ return np.arange(self.N_out, dtype=np.int64)
155
+
156
+ arr = np.asarray(cell_ids_out, dtype=np.int64).reshape(-1)
157
+ if arr.size == 0:
158
+ raise ValueError("cell_ids_out is empty: provide at least one coarse pixel id.")
159
+ arr = np.unique(arr)
160
+ if arr.min() < 0 or arr.max() >= self.N_out:
161
+ raise ValueError(f"cell_ids_out must be in [0, {self.N_out-1}] for nside_out={self.nside_out}.")
162
+ return arr
163
+
164
+ def _validate_in_cell_ids(self, in_cell_ids):
165
+ """Return a 1D np.int64 array of fine cell ids (nside_in) or None."""
166
+ if in_cell_ids is None:
167
+ return None
168
+ if torch.is_tensor(in_cell_ids):
169
+ arr = in_cell_ids.detach().cpu().numpy()
170
+ else:
171
+ arr = np.asarray(in_cell_ids)
172
+ arr = np.asarray(arr, dtype=np.int64).reshape(-1)
173
+ if arr.size == 0:
174
+ raise ValueError("in_cell_ids is empty: provide at least one fine pixel id or None.")
175
+ arr = np.unique(arr)
176
+ if arr.min() < 0 or arr.max() >= self.N_in:
177
+ raise ValueError(f"in_cell_ids must be in [0, {self.N_in-1}] for nside_in={self.nside_in}.")
178
+ return arr
179
+
180
+ @staticmethod
181
+ def _positions_in_sorted(sorted_ids: np.ndarray, query_ids: np.ndarray) -> np.ndarray:
182
+ """
183
+ For each query_id, return its index in sorted_ids if present, else -1.
184
+ sorted_ids must be sorted ascending unique.
185
+ """
186
+ q = np.asarray(query_ids, dtype=np.int64)
187
+ idx = np.searchsorted(sorted_ids, q)
188
+ ok = (idx >= 0) & (idx < sorted_ids.size) & (sorted_ids[idx] == q)
189
+ out = np.full(q.shape, -1, dtype=np.int64)
190
+ out[ok] = idx[ok]
191
+ return out
192
+
193
+ # ---------------- weights and matrix build ----------------
194
+ def _normalize_weights(self, w: np.ndarray) -> np.ndarray:
195
+ w = np.asarray(w, dtype=np.float64)
196
+ if w.size == 0:
197
+ return w
198
+ w = np.maximum(w, 0.0)
199
+
200
+ if self.weight_norm == "l1":
201
+ s = w.sum()
202
+ if s <= 0.0:
203
+ return np.ones_like(w) / max(w.size, 1)
204
+ return w / s
205
+
206
+ # l2
207
+ s2 = (w * w).sum()
208
+ if s2 <= 0.0:
209
+ return np.ones_like(w) / max(np.sqrt(w.size), 1.0)
210
+ return w / np.sqrt(s2)
211
+
212
+ def _build_down_matrix(self) -> torch.Tensor:
213
+ """Construct sparse matrix M (K_out, K_in or N_in) for the selected coarse pixels."""
214
+ nside_in = self.nside_in
215
+ nside_out = self.nside_out
216
+
217
+ radius_rad = self.radius_rad
218
+ sigma_rad = self.sigma_rad
219
+
220
+ rows: list[int] = []
221
+ cols: list[int] = []
222
+ vals: list[float] = []
223
+
224
+ # For subset columns, we use self.in_cell_ids as the basis
225
+ subset_cols = self.has_in_subset
226
+ in_ids = self.in_cell_ids # np.ndarray or None
227
+
228
+ for r, p_out in enumerate(self.cell_ids_out.tolist()):
229
+ theta0, phi0 = hp.pix2ang(nside_out, int(p_out), nest=True)
230
+ vec0 = hp.ang2vec(theta0, phi0)
231
+
232
+ neigh = hp.query_disc(nside_in, vec0, radius_rad, inclusive=True, nest=True)
233
+ neigh = np.asarray(neigh, dtype=np.int64)
234
+
235
+ if subset_cols:
236
+ # keep only valid fine pixels
237
+ # neigh is not sorted; intersect1d expects sorted
238
+ neigh_sorted = np.sort(neigh)
239
+ keep = np.intersect1d(neigh_sorted, in_ids, assume_unique=False)
240
+ neigh = keep
241
+
242
+ # Fallback: if radius query returns nothing in subset mode, at least try the 4 children
243
+ if neigh.size == 0:
244
+ children = (4 * int(p_out) + np.arange(4, dtype=np.int64))
245
+ if subset_cols:
246
+ pos = self._positions_in_sorted(in_ids, children)
247
+ ok = pos >= 0
248
+ if np.any(ok):
249
+ neigh = children[ok]
250
+ else:
251
+ # nothing to connect -> row stays zero
252
+ continue
253
+ else:
254
+ neigh = children
255
+
256
+ theta, phi = hp.pix2ang(nside_in, neigh, nest=True)
257
+ vec = hp.ang2vec(theta, phi)
258
+
259
+ # angular distance via dot product
260
+ dots = np.clip(np.dot(vec, vec0), -1.0, 1.0)
261
+ ang = np.arccos(dots)
262
+ w = np.exp(- 2*(ang / sigma_rad) ** 2)
263
+
264
+ w = self._normalize_weights(w)
265
+
266
+ if subset_cols:
267
+ pos = self._positions_in_sorted(in_ids, neigh)
268
+ # all should be present due to filtering, but guard anyway
269
+ ok = pos >= 0
270
+ neigh_pos = pos[ok]
271
+ w = w[ok]
272
+ if neigh_pos.size == 0:
273
+ continue
274
+ for c, v in zip(neigh_pos.tolist(), w.tolist()):
275
+ rows.append(r)
276
+ cols.append(int(c))
277
+ vals.append(float(v))
278
+ else:
279
+ for c, v in zip(neigh.tolist(), w.tolist()):
280
+ rows.append(r)
281
+ cols.append(int(c))
282
+ vals.append(float(v))
283
+
284
+ if len(rows) == 0:
285
+ # build an all-zero sparse tensor
286
+ indices = torch.zeros((2, 0), dtype=torch.long, device=self.device)
287
+ vals_t = torch.zeros((0,), dtype=self.dtype, device=self.device)
288
+ return torch.sparse_coo_tensor(
289
+ indices, vals_t, size=(self.K_out, self.K_in), device=self.device, dtype=self.dtype
290
+ ).coalesce()
291
+
292
+ rows_t = torch.tensor(rows, dtype=torch.long, device=self.device)
293
+ cols_t = torch.tensor(cols, dtype=torch.long, device=self.device)
294
+ vals_t = torch.tensor(vals, dtype=self.dtype, device=self.device)
295
+
296
+ indices = torch.stack([rows_t, cols_t], dim=0)
297
+ M = torch.sparse_coo_tensor(
298
+ indices,
299
+ vals_t,
300
+ size=(self.K_out, self.K_in),
301
+ device=self.device,
302
+ dtype=self.dtype,
303
+ ).coalesce()
304
+ return M
305
+
306
+ # ---------------- forward ----------------
307
+ def forward(self, x: torch.Tensor):
308
+ """
309
+ Parameters
310
+ ----------
311
+ x : torch.Tensor
312
+ If has_in_subset:
313
+ - [B,C,K_in] (compact, aligned with in_cell_ids) OR [B,C,N_in] (full sphere)
314
+ Else:
315
+ - [B,C,N_in] (full sphere)
316
+
317
+ Returns
318
+ -------
319
+ y : torch.Tensor
320
+ [B,C,K_out]
321
+ cell_ids_out : torch.Tensor
322
+ [K_out] coarse pixel ids (nside_out), aligned with y last dimension.
323
+ """
324
+ if x.dim() != 3:
325
+ raise ValueError("x must be [B, C, N]")
326
+
327
+ B, C, N = x.shape
328
+ if self.has_in_subset:
329
+ if N not in (self.K_in, self.N_in):
330
+ raise ValueError(
331
+ f"x last dim must be K_in={self.K_in} (compact) or N_in={self.N_in} (full), got {N}"
332
+ )
333
+ else:
334
+ if N != self.N_in:
335
+ raise ValueError(f"x last dim must be N_in={self.N_in}, got {N}")
336
+
337
+ if self.mode == "smooth":
338
+
339
+ # If x is full-sphere but M is subset-based, gather compact inputs
340
+ if self.has_in_subset and N == self.N_in:
341
+ x_use = x.index_select(dim=2, index=self.in_cell_ids_t.to(x.device))
342
+ else:
343
+ x_use = x
344
+
345
+ # sparse mm expects 2D: (K_out, K_in) @ (K_in, B*C)
346
+ x2 = x_use.reshape(B * C, -1).transpose(0, 1).contiguous()
347
+ y2 = torch.sparse.mm(self.M, x2)
348
+ y = y2.transpose(0, 1).reshape(B, C, self.K_out).contiguous()
349
+ return y, self.cell_ids_out_t.to(x.device)
350
+
351
+ # maxpool
352
+ if self.has_in_subset and N == self.N_in:
353
+ x_use = x.index_select(dim=2, index=self.in_cell_ids_t.to(x.device))
354
+ else:
355
+ x_use = x
356
+
357
+ if self.has_in_subset:
358
+ # children_compact: [K_out, 4] indices in 0..K_in-1 or -1
359
+ ch = self.children_compact.to(x.device) # [K_out,4]
360
+ # gather with masking
361
+ # We build y by iterating 4 children with max
362
+ y = None
363
+ for j in range(4):
364
+ idx = ch[:, j] # [K_out]
365
+ mask = idx >= 0
366
+ # start with very negative so missing children don't win
367
+ tmp = torch.full((B, C, self.K_out), -torch.inf, device=x.device, dtype=x.dtype)
368
+ if mask.any():
369
+ tmp[:, :, mask] = x_use.index_select(dim=2, index=idx[mask]).reshape(B, C, -1)
370
+ y = tmp if y is None else torch.maximum(y, tmp)
371
+ # If a parent had no valid children at all, it is -inf -> set to 0
372
+ y = torch.where(torch.isfinite(y), y, torch.zeros_like(y))
373
+ return y, self.cell_ids_out_t.to(x.device)
374
+
375
+ else:
376
+ ch = self.children_full.to(x.device) # [K_out,4] full indices
377
+ # gather children and max
378
+ xch = x_use.index_select(dim=2, index=ch.reshape(-1)).reshape(B, C, self.K_out, 4)
379
+ y = xch.max(dim=3).values
380
+ return y, self.cell_ids_out_t.to(x.device)
foscat/SphereUpGeo.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ from foscat.SphereDownGeo import SphereDownGeo
6
+
7
+
8
class SphereUpGeo(nn.Module):
    """Geometric HEALPix upsampling operator using the transpose of SphereDownGeo.

    `cell_ids_out` (coarse pixels at nside_out, NESTED) is mandatory.
    Forward expects x of shape [B, C, K_out] aligned with that order.
    Output is a full fine-grid map [B, C, N_in] at nside_in = 2*nside_out.

    Normalization (diagonal corrections):
      - up_norm='adjoint': x_up = M^T x
      - up_norm='col_l1': x_up = (M^T x) / col_sum, col_sum[i] = sum_k M[k,i]
      - up_norm='diag_l2': x_up = (M^T x) / col_l2, col_l2[i] = sum_k M[k,i]^2
    """

    def __init__(
        self,
        nside_out: int,
        cell_ids_out,
        radius_deg: float | None = None,
        sigma_deg: float | None = None,
        weight_norm: str = "l1",
        up_norm: str = "col_l1",
        eps: float = 1e-12,
        device=None,
        dtype=torch.float32,
    ):
        """Build the sparse up-sampling operator M^T from a full SphereDownGeo.

        Parameters
        ----------
        nside_out : int
            Coarse resolution (power of 2); fine resolution is 2*nside_out.
        cell_ids_out : 1D list / np.ndarray / torch.Tensor
            Coarse NESTED pixel ids; order defines the alignment of forward input.
        radius_deg, sigma_deg, weight_norm :
            Passed through to the underlying SphereDownGeo (smooth mode).
        up_norm : str
            One of 'adjoint', 'col_l1', 'diag_l2' (see class docstring).
        eps : float
            Lower clamp on the normalization denominators.
        device, dtype :
            Torch placement; defaults to CUDA when available, float32.
        """
        super().__init__()

        if cell_ids_out is None:
            raise ValueError("cell_ids_out is mandatory (1D list/np/tensor of coarse HEALPix ids at nside_out).")

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.dtype = dtype

        self.nside_out = int(nside_out)
        assert (self.nside_out & (self.nside_out - 1)) == 0, "nside_out must be a power of 2."
        self.nside_in = self.nside_out * 2

        # Full-sphere pixel counts (HEALPix: 12*nside^2).
        self.N_out = 12 * self.nside_out * self.nside_out
        self.N_in = 12 * self.nside_in * self.nside_in

        up_norm = str(up_norm).lower().strip()
        if up_norm not in ("adjoint", "col_l1", "diag_l2"):
            raise ValueError("up_norm must be 'adjoint', 'col_l1', or 'diag_l2'.")
        self.up_norm = up_norm
        self.eps = float(eps)

        # Coarse ids in user-provided order (must be unique for alignment)
        if isinstance(cell_ids_out, torch.Tensor):
            cell_ids_out_np = cell_ids_out.detach().cpu().numpy().astype(np.int64)
        else:
            cell_ids_out_np = np.asarray(cell_ids_out, dtype=np.int64)

        if cell_ids_out_np.ndim != 1:
            raise ValueError("cell_ids_out must be 1D")
        if cell_ids_out_np.size == 0:
            raise ValueError("cell_ids_out must be non-empty")
        if cell_ids_out_np.min() < 0 or cell_ids_out_np.max() >= self.N_out:
            raise ValueError("cell_ids_out contains out-of-bounds ids for this nside_out")
        if np.unique(cell_ids_out_np).size != cell_ids_out_np.size:
            raise ValueError("cell_ids_out must not contain duplicates (order matters for alignment).")

        self.cell_ids_out_np = cell_ids_out_np
        self.K_out = int(cell_ids_out_np.size)
        self.register_buffer("cell_ids_out_t", torch.as_tensor(cell_ids_out_np, dtype=torch.long, device=self.device))

        # Build the FULL down operator at fine resolution (nside_in -> nside_out)
        # use_csr=False so we can read indices()/values() off the COO tensor below.
        tmp_down = SphereDownGeo(
            nside_in=self.nside_in,
            mode="smooth",
            radius_deg=radius_deg,
            sigma_deg=sigma_deg,
            weight_norm=weight_norm,
            device=self.device,
            dtype=self.dtype,
            use_csr=False,
        )

        M_down_full = torch.sparse_coo_tensor(
            tmp_down.M.indices(),
            tmp_down.M.values(),
            size=(tmp_down.N_out, tmp_down.N_in),
            device=self.device,
            dtype=self.dtype,
        ).coalesce()

        # Extract ONLY the requested coarse rows, in the provided order.
        # We do this on CPU with numpy for simplicity and speed at init.
        idx = M_down_full.indices().cpu().numpy()
        vals = M_down_full.values().cpu().numpy()
        rows = idx[0]
        cols = idx[1]

        # Map original row id -> new row position [0..K_out-1]
        row_map = {int(r): i for i, r in enumerate(cell_ids_out_np.tolist())}
        mask = np.fromiter((r in row_map for r in rows), dtype=bool, count=rows.size)

        rows_sel = rows[mask]
        cols_sel = cols[mask]
        vals_sel = vals[mask]

        new_rows = np.fromiter((row_map[int(r)] for r in rows_sel), dtype=np.int64, count=rows_sel.size)

        # Row-subset of the down operator: [K_out, N_in], rows in user order.
        M_down_sub = torch.sparse_coo_tensor(
            torch.as_tensor(np.stack([new_rows, cols_sel], axis=0), dtype=torch.long),
            torch.as_tensor(vals_sel, dtype=self.dtype),
            size=(self.K_out, self.N_in),
            device=self.device,
            dtype=self.dtype,
        ).coalesce()

        # Store M^T (sparse) so forward is just sparse.mm
        M_up = self._transpose_sparse(M_down_sub)  # [N_in, K_out]
        self.register_buffer("M_indices", M_up.indices())
        self.register_buffer("M_values", M_up.values())
        self.M_size = M_up.size()

        # Diagonal normalizers (length N_in), based on the selected coarse rows only
        idx_sub = M_down_sub.indices()
        vals_sub = M_down_sub.values()
        fine_cols = idx_sub[1]

        col_sum = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
        col_l2 = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
        col_sum.scatter_add_(0, fine_cols, vals_sub)
        col_l2.scatter_add_(0, fine_cols, vals_sub * vals_sub)

        self.register_buffer("col_sum", col_sum)
        self.register_buffer("col_l2", col_l2)

        # Fine ids (full sphere)
        self.register_buffer("cell_ids_in_t", torch.arange(self.N_in, dtype=torch.long, device=self.device))

        # Materialize M^T once as CSR for fast repeated sparse.mm in forward().
        # NOTE(review): plain attribute (not a buffer) — it presumably must be
        # rebuilt rather than loaded from a state_dict; confirm intended.
        self.M_T = torch.sparse_coo_tensor(
            self.M_indices.to(device=self.device),
            self.M_values.to(device=self.device, dtype=self.dtype),
            size=self.M_size,
            device=self.device,
            dtype=self.dtype,
        ).coalesce().to_sparse_csr().to(self.device)

    @staticmethod
    def _transpose_sparse(M: torch.Tensor) -> torch.Tensor:
        """Return the transpose of a 2D sparse COO tensor by swapping its index rows."""
        M = M.coalesce()
        idx = M.indices()
        vals = M.values()
        R, C = M.size()
        idx_T = torch.stack([idx[1], idx[0]], dim=0)
        return torch.sparse_coo_tensor(idx_T, vals, size=(C, R), device=M.device, dtype=M.dtype).coalesce()

    def forward(self, x: torch.Tensor):
        """x: [B, C, K_out] -> x_up: [B, C, N_in].

        Returns (x_up, cell_ids_in): the upsampled full fine-grid map and the
        fine pixel ids [0..N_in-1] aligned with its last dimension.
        """
        B, C, K_out = x.shape
        assert K_out == self.K_out, f"Expected K_out={self.K_out}, got {K_out}"

        # Flatten batch/channel so the product is a single 2D sparse.mm.
        x_bc = x.reshape(B * C, K_out)
        x_up_bc_T = torch.sparse.mm(self.M_T, x_bc.T)  # [N_in, B*C]
        x_up = x_up_bc_T.T.reshape(B, C, self.N_in)  # [B, C, N_in]

        # Per-fine-pixel normalization; eps clamp avoids divide-by-zero on
        # fine pixels untouched by the selected coarse rows.
        if self.up_norm == "col_l1":
            denom = self.col_sum.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
            x_up = x_up / denom.view(1, 1, -1)
        elif self.up_norm == "diag_l2":
            denom = self.col_l2.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
            x_up = x_up / denom.view(1, 1, -1)

        return x_up, self.cell_ids_in_t.to(device=x.device)