foscat 2025.10.2__py3-none-any.whl → 2026.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +635 -141
- foscat/FoCUS.py +135 -52
- foscat/SphereDownGeo.py +380 -0
- foscat/SphereUpGeo.py +175 -0
- foscat/SphericalStencil.py +27 -246
- foscat/alm_loc.py +270 -0
- foscat/scat.py +1 -1
- foscat/scat1D.py +1 -1
- foscat/scat_cov.py +24 -24
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/METADATA +1 -69
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/RECORD +14 -11
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/WHEEL +1 -1
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.10.2.dist-info → foscat-2026.1.1.dist-info}/top_level.txt +0 -0
foscat/SphereDownGeo.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
import numpy as np
|
|
5
|
+
import healpy as hp
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SphereDownGeo(nn.Module):
|
|
9
|
+
"""
|
|
10
|
+
Geometric HEALPix downsampling operator (NESTED indexing).
|
|
11
|
+
|
|
12
|
+
This module reduces resolution by a factor 2:
|
|
13
|
+
nside_out = nside_in // 2
|
|
14
|
+
|
|
15
|
+
Input conventions
|
|
16
|
+
-----------------
|
|
17
|
+
- If in_cell_ids is None:
|
|
18
|
+
x is expected to be full-sphere: [B, C, N_in]
|
|
19
|
+
output is [B, C, K_out] with K_out = len(cell_ids_out) (or N_out if None).
|
|
20
|
+
- If in_cell_ids is provided (fine pixels at nside_in, NESTED):
|
|
21
|
+
x can be either:
|
|
22
|
+
* compact: [B, C, K_in] where K_in = len(in_cell_ids), aligned with in_cell_ids order
|
|
23
|
+
* full-sphere: [B, C, N_in] (also supported)
|
|
24
|
+
output is [B, C, K_out] where cell_ids_out is derived as unique(in_cell_ids // 4),
|
|
25
|
+
unless you explicitly pass cell_ids_out (then it will be intersected with the derived set).
|
|
26
|
+
|
|
27
|
+
Modes
|
|
28
|
+
-----
|
|
29
|
+
- mode="smooth": linear downsampling y = M @ x (M sparse)
|
|
30
|
+
- mode="maxpool": non-linear max over available children (fast)
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
def __init__(
|
|
34
|
+
self,
|
|
35
|
+
nside_in: int,
|
|
36
|
+
mode: str = "smooth",
|
|
37
|
+
radius_deg: float | None = None,
|
|
38
|
+
sigma_deg: float | None = None,
|
|
39
|
+
weight_norm: str = "l1",
|
|
40
|
+
cell_ids_out: np.ndarray | list[int] | None = None,
|
|
41
|
+
in_cell_ids: np.ndarray | list[int] | torch.Tensor | None = None,
|
|
42
|
+
use_csr=True,
|
|
43
|
+
device=None,
|
|
44
|
+
dtype: torch.dtype = torch.float32,
|
|
45
|
+
):
|
|
46
|
+
super().__init__()
|
|
47
|
+
|
|
48
|
+
if device is None:
|
|
49
|
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
50
|
+
self.device = device
|
|
51
|
+
self.dtype = dtype
|
|
52
|
+
|
|
53
|
+
self.nside_in = int(nside_in)
|
|
54
|
+
assert (self.nside_in & (self.nside_in - 1)) == 0, "nside_in must be a power of 2."
|
|
55
|
+
self.nside_out = self.nside_in // 2
|
|
56
|
+
assert self.nside_out >= 1, "nside_out must be >= 1."
|
|
57
|
+
|
|
58
|
+
self.N_in = 12 * self.nside_in * self.nside_in
|
|
59
|
+
self.N_out = 12 * self.nside_out * self.nside_out
|
|
60
|
+
|
|
61
|
+
self.mode = str(mode).lower()
|
|
62
|
+
assert self.mode in ("smooth", "maxpool"), "mode must be 'smooth' or 'maxpool'."
|
|
63
|
+
|
|
64
|
+
self.weight_norm = str(weight_norm).lower()
|
|
65
|
+
assert self.weight_norm in ("l1", "l2"), "weight_norm must be 'l1' or 'l2'."
|
|
66
|
+
|
|
67
|
+
# ---- Handle reduced-domain inputs (fine pixels) ----
|
|
68
|
+
self.in_cell_ids = self._validate_in_cell_ids(in_cell_ids)
|
|
69
|
+
self.has_in_subset = self.in_cell_ids is not None
|
|
70
|
+
if self.has_in_subset:
|
|
71
|
+
# derive parents
|
|
72
|
+
derived_out = np.unique(self.in_cell_ids // 4).astype(np.int64)
|
|
73
|
+
if cell_ids_out is None:
|
|
74
|
+
self.cell_ids_out = derived_out
|
|
75
|
+
else:
|
|
76
|
+
req_out = self._validate_cell_ids_out(cell_ids_out)
|
|
77
|
+
# keep only those compatible with derived_out (otherwise they'd be all-zero)
|
|
78
|
+
self.cell_ids_out = np.intersect1d(req_out, derived_out, assume_unique=False)
|
|
79
|
+
if self.cell_ids_out.size == 0:
|
|
80
|
+
raise ValueError(
|
|
81
|
+
"After intersecting cell_ids_out with unique(in_cell_ids//4), "
|
|
82
|
+
"no coarse pixel remains. Check your inputs."
|
|
83
|
+
)
|
|
84
|
+
else:
|
|
85
|
+
self.cell_ids_out = self._validate_cell_ids_out(cell_ids_out)
|
|
86
|
+
|
|
87
|
+
self.K_out = int(self.cell_ids_out.size)
|
|
88
|
+
|
|
89
|
+
# Column basis for smooth matrix:
|
|
90
|
+
# - full sphere: columns are 0..N_in-1
|
|
91
|
+
# - subset: columns are 0..K_in-1 aligned to self.in_cell_ids
|
|
92
|
+
self.K_in = int(self.in_cell_ids.size) if self.has_in_subset else self.N_in
|
|
93
|
+
|
|
94
|
+
if self.mode == "smooth":
|
|
95
|
+
if radius_deg is None:
|
|
96
|
+
# default: include roughly the 4 children footprint
|
|
97
|
+
# (healpy pixel size ~ sqrt(4pi/N), coarse pixel is 4x area)
|
|
98
|
+
radius_deg = 2.0 * hp.nside2resol(self.nside_out, arcmin=True) / 60.0
|
|
99
|
+
if sigma_deg is None:
|
|
100
|
+
sigma_deg = max(radius_deg / 2.0, 1e-6)
|
|
101
|
+
|
|
102
|
+
self.radius_deg = float(radius_deg)
|
|
103
|
+
self.sigma_deg = float(sigma_deg)
|
|
104
|
+
self.radius_rad = self.radius_deg * np.pi / 180.0
|
|
105
|
+
self.sigma_rad = self.sigma_deg * np.pi / 180.0
|
|
106
|
+
|
|
107
|
+
M = self._build_down_matrix() # shape (K_out, K_in or N_in)
|
|
108
|
+
|
|
109
|
+
self.M = M.coalesce()
|
|
110
|
+
|
|
111
|
+
if use_csr:
|
|
112
|
+
self.M = self.M.to_sparse_csr().to(self.device)
|
|
113
|
+
|
|
114
|
+
self.M_size = M.size()
|
|
115
|
+
|
|
116
|
+
else:
|
|
117
|
+
# Precompute children indices for maxpool
|
|
118
|
+
# For subset mode, store mapping from each parent to indices in compact vector,
|
|
119
|
+
# with -1 for missing children.
|
|
120
|
+
children = np.stack(
|
|
121
|
+
[4 * self.cell_ids_out + i for i in range(4)],
|
|
122
|
+
axis=1,
|
|
123
|
+
).astype(np.int64) # [K_out, 4] in fine pixel ids (full indexing)
|
|
124
|
+
|
|
125
|
+
if self.has_in_subset:
|
|
126
|
+
# map each child pixel id to position in in_cell_ids (compact index)
|
|
127
|
+
pos = self._positions_in_sorted(self.in_cell_ids, children.reshape(-1))
|
|
128
|
+
children_compact = pos.reshape(self.K_out, 4).astype(np.int64) # -1 if missing
|
|
129
|
+
self.register_buffer(
|
|
130
|
+
"children_compact",
|
|
131
|
+
torch.tensor(children_compact, dtype=torch.long, device=self.device),
|
|
132
|
+
)
|
|
133
|
+
else:
|
|
134
|
+
self.register_buffer(
|
|
135
|
+
"children_full",
|
|
136
|
+
torch.tensor(children, dtype=torch.long, device=self.device),
|
|
137
|
+
)
|
|
138
|
+
|
|
139
|
+
# expose ids as torch buffers for convenience
|
|
140
|
+
self.register_buffer(
|
|
141
|
+
"cell_ids_out_t",
|
|
142
|
+
torch.tensor(self.cell_ids_out.astype(np.int64), dtype=torch.long, device=self.device),
|
|
143
|
+
)
|
|
144
|
+
if self.has_in_subset:
|
|
145
|
+
self.register_buffer(
|
|
146
|
+
"in_cell_ids_t",
|
|
147
|
+
torch.tensor(self.in_cell_ids.astype(np.int64), dtype=torch.long, device=self.device),
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
# ---------------- validation helpers ----------------
|
|
151
|
+
def _validate_cell_ids_out(self, cell_ids_out):
|
|
152
|
+
"""Return a 1D np.int64 array of coarse cell ids (nside_out)."""
|
|
153
|
+
if cell_ids_out is None:
|
|
154
|
+
return np.arange(self.N_out, dtype=np.int64)
|
|
155
|
+
|
|
156
|
+
arr = np.asarray(cell_ids_out, dtype=np.int64).reshape(-1)
|
|
157
|
+
if arr.size == 0:
|
|
158
|
+
raise ValueError("cell_ids_out is empty: provide at least one coarse pixel id.")
|
|
159
|
+
arr = np.unique(arr)
|
|
160
|
+
if arr.min() < 0 or arr.max() >= self.N_out:
|
|
161
|
+
raise ValueError(f"cell_ids_out must be in [0, {self.N_out-1}] for nside_out={self.nside_out}.")
|
|
162
|
+
return arr
|
|
163
|
+
|
|
164
|
+
def _validate_in_cell_ids(self, in_cell_ids):
|
|
165
|
+
"""Return a 1D np.int64 array of fine cell ids (nside_in) or None."""
|
|
166
|
+
if in_cell_ids is None:
|
|
167
|
+
return None
|
|
168
|
+
if torch.is_tensor(in_cell_ids):
|
|
169
|
+
arr = in_cell_ids.detach().cpu().numpy()
|
|
170
|
+
else:
|
|
171
|
+
arr = np.asarray(in_cell_ids)
|
|
172
|
+
arr = np.asarray(arr, dtype=np.int64).reshape(-1)
|
|
173
|
+
if arr.size == 0:
|
|
174
|
+
raise ValueError("in_cell_ids is empty: provide at least one fine pixel id or None.")
|
|
175
|
+
arr = np.unique(arr)
|
|
176
|
+
if arr.min() < 0 or arr.max() >= self.N_in:
|
|
177
|
+
raise ValueError(f"in_cell_ids must be in [0, {self.N_in-1}] for nside_in={self.nside_in}.")
|
|
178
|
+
return arr
|
|
179
|
+
|
|
180
|
+
@staticmethod
|
|
181
|
+
def _positions_in_sorted(sorted_ids: np.ndarray, query_ids: np.ndarray) -> np.ndarray:
|
|
182
|
+
"""
|
|
183
|
+
For each query_id, return its index in sorted_ids if present, else -1.
|
|
184
|
+
sorted_ids must be sorted ascending unique.
|
|
185
|
+
"""
|
|
186
|
+
q = np.asarray(query_ids, dtype=np.int64)
|
|
187
|
+
idx = np.searchsorted(sorted_ids, q)
|
|
188
|
+
ok = (idx >= 0) & (idx < sorted_ids.size) & (sorted_ids[idx] == q)
|
|
189
|
+
out = np.full(q.shape, -1, dtype=np.int64)
|
|
190
|
+
out[ok] = idx[ok]
|
|
191
|
+
return out
|
|
192
|
+
|
|
193
|
+
# ---------------- weights and matrix build ----------------
|
|
194
|
+
def _normalize_weights(self, w: np.ndarray) -> np.ndarray:
|
|
195
|
+
w = np.asarray(w, dtype=np.float64)
|
|
196
|
+
if w.size == 0:
|
|
197
|
+
return w
|
|
198
|
+
w = np.maximum(w, 0.0)
|
|
199
|
+
|
|
200
|
+
if self.weight_norm == "l1":
|
|
201
|
+
s = w.sum()
|
|
202
|
+
if s <= 0.0:
|
|
203
|
+
return np.ones_like(w) / max(w.size, 1)
|
|
204
|
+
return w / s
|
|
205
|
+
|
|
206
|
+
# l2
|
|
207
|
+
s2 = (w * w).sum()
|
|
208
|
+
if s2 <= 0.0:
|
|
209
|
+
return np.ones_like(w) / max(np.sqrt(w.size), 1.0)
|
|
210
|
+
return w / np.sqrt(s2)
|
|
211
|
+
|
|
212
|
+
def _build_down_matrix(self) -> torch.Tensor:
|
|
213
|
+
"""Construct sparse matrix M (K_out, K_in or N_in) for the selected coarse pixels."""
|
|
214
|
+
nside_in = self.nside_in
|
|
215
|
+
nside_out = self.nside_out
|
|
216
|
+
|
|
217
|
+
radius_rad = self.radius_rad
|
|
218
|
+
sigma_rad = self.sigma_rad
|
|
219
|
+
|
|
220
|
+
rows: list[int] = []
|
|
221
|
+
cols: list[int] = []
|
|
222
|
+
vals: list[float] = []
|
|
223
|
+
|
|
224
|
+
# For subset columns, we use self.in_cell_ids as the basis
|
|
225
|
+
subset_cols = self.has_in_subset
|
|
226
|
+
in_ids = self.in_cell_ids # np.ndarray or None
|
|
227
|
+
|
|
228
|
+
for r, p_out in enumerate(self.cell_ids_out.tolist()):
|
|
229
|
+
theta0, phi0 = hp.pix2ang(nside_out, int(p_out), nest=True)
|
|
230
|
+
vec0 = hp.ang2vec(theta0, phi0)
|
|
231
|
+
|
|
232
|
+
neigh = hp.query_disc(nside_in, vec0, radius_rad, inclusive=True, nest=True)
|
|
233
|
+
neigh = np.asarray(neigh, dtype=np.int64)
|
|
234
|
+
|
|
235
|
+
if subset_cols:
|
|
236
|
+
# keep only valid fine pixels
|
|
237
|
+
# neigh is not sorted; intersect1d expects sorted
|
|
238
|
+
neigh_sorted = np.sort(neigh)
|
|
239
|
+
keep = np.intersect1d(neigh_sorted, in_ids, assume_unique=False)
|
|
240
|
+
neigh = keep
|
|
241
|
+
|
|
242
|
+
# Fallback: if radius query returns nothing in subset mode, at least try the 4 children
|
|
243
|
+
if neigh.size == 0:
|
|
244
|
+
children = (4 * int(p_out) + np.arange(4, dtype=np.int64))
|
|
245
|
+
if subset_cols:
|
|
246
|
+
pos = self._positions_in_sorted(in_ids, children)
|
|
247
|
+
ok = pos >= 0
|
|
248
|
+
if np.any(ok):
|
|
249
|
+
neigh = children[ok]
|
|
250
|
+
else:
|
|
251
|
+
# nothing to connect -> row stays zero
|
|
252
|
+
continue
|
|
253
|
+
else:
|
|
254
|
+
neigh = children
|
|
255
|
+
|
|
256
|
+
theta, phi = hp.pix2ang(nside_in, neigh, nest=True)
|
|
257
|
+
vec = hp.ang2vec(theta, phi)
|
|
258
|
+
|
|
259
|
+
# angular distance via dot product
|
|
260
|
+
dots = np.clip(np.dot(vec, vec0), -1.0, 1.0)
|
|
261
|
+
ang = np.arccos(dots)
|
|
262
|
+
w = np.exp(- 2*(ang / sigma_rad) ** 2)
|
|
263
|
+
|
|
264
|
+
w = self._normalize_weights(w)
|
|
265
|
+
|
|
266
|
+
if subset_cols:
|
|
267
|
+
pos = self._positions_in_sorted(in_ids, neigh)
|
|
268
|
+
# all should be present due to filtering, but guard anyway
|
|
269
|
+
ok = pos >= 0
|
|
270
|
+
neigh_pos = pos[ok]
|
|
271
|
+
w = w[ok]
|
|
272
|
+
if neigh_pos.size == 0:
|
|
273
|
+
continue
|
|
274
|
+
for c, v in zip(neigh_pos.tolist(), w.tolist()):
|
|
275
|
+
rows.append(r)
|
|
276
|
+
cols.append(int(c))
|
|
277
|
+
vals.append(float(v))
|
|
278
|
+
else:
|
|
279
|
+
for c, v in zip(neigh.tolist(), w.tolist()):
|
|
280
|
+
rows.append(r)
|
|
281
|
+
cols.append(int(c))
|
|
282
|
+
vals.append(float(v))
|
|
283
|
+
|
|
284
|
+
if len(rows) == 0:
|
|
285
|
+
# build an all-zero sparse tensor
|
|
286
|
+
indices = torch.zeros((2, 0), dtype=torch.long, device=self.device)
|
|
287
|
+
vals_t = torch.zeros((0,), dtype=self.dtype, device=self.device)
|
|
288
|
+
return torch.sparse_coo_tensor(
|
|
289
|
+
indices, vals_t, size=(self.K_out, self.K_in), device=self.device, dtype=self.dtype
|
|
290
|
+
).coalesce()
|
|
291
|
+
|
|
292
|
+
rows_t = torch.tensor(rows, dtype=torch.long, device=self.device)
|
|
293
|
+
cols_t = torch.tensor(cols, dtype=torch.long, device=self.device)
|
|
294
|
+
vals_t = torch.tensor(vals, dtype=self.dtype, device=self.device)
|
|
295
|
+
|
|
296
|
+
indices = torch.stack([rows_t, cols_t], dim=0)
|
|
297
|
+
M = torch.sparse_coo_tensor(
|
|
298
|
+
indices,
|
|
299
|
+
vals_t,
|
|
300
|
+
size=(self.K_out, self.K_in),
|
|
301
|
+
device=self.device,
|
|
302
|
+
dtype=self.dtype,
|
|
303
|
+
).coalesce()
|
|
304
|
+
return M
|
|
305
|
+
|
|
306
|
+
# ---------------- forward ----------------
|
|
307
|
+
def forward(self, x: torch.Tensor):
|
|
308
|
+
"""
|
|
309
|
+
Parameters
|
|
310
|
+
----------
|
|
311
|
+
x : torch.Tensor
|
|
312
|
+
If has_in_subset:
|
|
313
|
+
- [B,C,K_in] (compact, aligned with in_cell_ids) OR [B,C,N_in] (full sphere)
|
|
314
|
+
Else:
|
|
315
|
+
- [B,C,N_in] (full sphere)
|
|
316
|
+
|
|
317
|
+
Returns
|
|
318
|
+
-------
|
|
319
|
+
y : torch.Tensor
|
|
320
|
+
[B,C,K_out]
|
|
321
|
+
cell_ids_out : torch.Tensor
|
|
322
|
+
[K_out] coarse pixel ids (nside_out), aligned with y last dimension.
|
|
323
|
+
"""
|
|
324
|
+
if x.dim() != 3:
|
|
325
|
+
raise ValueError("x must be [B, C, N]")
|
|
326
|
+
|
|
327
|
+
B, C, N = x.shape
|
|
328
|
+
if self.has_in_subset:
|
|
329
|
+
if N not in (self.K_in, self.N_in):
|
|
330
|
+
raise ValueError(
|
|
331
|
+
f"x last dim must be K_in={self.K_in} (compact) or N_in={self.N_in} (full), got {N}"
|
|
332
|
+
)
|
|
333
|
+
else:
|
|
334
|
+
if N != self.N_in:
|
|
335
|
+
raise ValueError(f"x last dim must be N_in={self.N_in}, got {N}")
|
|
336
|
+
|
|
337
|
+
if self.mode == "smooth":
|
|
338
|
+
|
|
339
|
+
# If x is full-sphere but M is subset-based, gather compact inputs
|
|
340
|
+
if self.has_in_subset and N == self.N_in:
|
|
341
|
+
x_use = x.index_select(dim=2, index=self.in_cell_ids_t.to(x.device))
|
|
342
|
+
else:
|
|
343
|
+
x_use = x
|
|
344
|
+
|
|
345
|
+
# sparse mm expects 2D: (K_out, K_in) @ (K_in, B*C)
|
|
346
|
+
x2 = x_use.reshape(B * C, -1).transpose(0, 1).contiguous()
|
|
347
|
+
y2 = torch.sparse.mm(self.M, x2)
|
|
348
|
+
y = y2.transpose(0, 1).reshape(B, C, self.K_out).contiguous()
|
|
349
|
+
return y, self.cell_ids_out_t.to(x.device)
|
|
350
|
+
|
|
351
|
+
# maxpool
|
|
352
|
+
if self.has_in_subset and N == self.N_in:
|
|
353
|
+
x_use = x.index_select(dim=2, index=self.in_cell_ids_t.to(x.device))
|
|
354
|
+
else:
|
|
355
|
+
x_use = x
|
|
356
|
+
|
|
357
|
+
if self.has_in_subset:
|
|
358
|
+
# children_compact: [K_out, 4] indices in 0..K_in-1 or -1
|
|
359
|
+
ch = self.children_compact.to(x.device) # [K_out,4]
|
|
360
|
+
# gather with masking
|
|
361
|
+
# We build y by iterating 4 children with max
|
|
362
|
+
y = None
|
|
363
|
+
for j in range(4):
|
|
364
|
+
idx = ch[:, j] # [K_out]
|
|
365
|
+
mask = idx >= 0
|
|
366
|
+
# start with very negative so missing children don't win
|
|
367
|
+
tmp = torch.full((B, C, self.K_out), -torch.inf, device=x.device, dtype=x.dtype)
|
|
368
|
+
if mask.any():
|
|
369
|
+
tmp[:, :, mask] = x_use.index_select(dim=2, index=idx[mask]).reshape(B, C, -1)
|
|
370
|
+
y = tmp if y is None else torch.maximum(y, tmp)
|
|
371
|
+
# If a parent had no valid children at all, it is -inf -> set to 0
|
|
372
|
+
y = torch.where(torch.isfinite(y), y, torch.zeros_like(y))
|
|
373
|
+
return y, self.cell_ids_out_t.to(x.device)
|
|
374
|
+
|
|
375
|
+
else:
|
|
376
|
+
ch = self.children_full.to(x.device) # [K_out,4] full indices
|
|
377
|
+
# gather children and max
|
|
378
|
+
xch = x_use.index_select(dim=2, index=ch.reshape(-1)).reshape(B, C, self.K_out, 4)
|
|
379
|
+
y = xch.max(dim=3).values
|
|
380
|
+
return y, self.cell_ids_out_t.to(x.device)
|
foscat/SphereUpGeo.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from foscat.SphereDownGeo import SphereDownGeo
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SphereUpGeo(nn.Module):
    """Geometric HEALPix upsampling operator using the transpose of SphereDownGeo.

    `cell_ids_out` (coarse pixels at nside_out, NESTED) is mandatory.
    Forward expects x of shape [B, C, K_out] aligned with that order.
    Output is a full fine-grid map [B, C, N_in] at nside_in = 2*nside_out,
    returned together with the fine pixel ids arange(N_in).

    Normalization (diagonal corrections):
      - up_norm='adjoint': x_up = M^T x
      - up_norm='col_l1':  x_up = (M^T x) / col_sum, col_sum[i] = sum_k M[k,i]
      - up_norm='diag_l2': x_up = (M^T x) / col_l2,  col_l2[i] = sum_k M[k,i]^2
    where M is the smooth SphereDownGeo matrix restricted to the rows listed in
    cell_ids_out. Both denominators are clamped to at least `eps`.
    """

    def __init__(
        self,
        nside_out: int,
        cell_ids_out,
        radius_deg: float | None = None,
        sigma_deg: float | None = None,
        weight_norm: str = "l1",
        up_norm: str = "col_l1",
        eps: float = 1e-12,
        device=None,
        dtype=torch.float32,
    ):
        """
        Parameters
        ----------
        nside_out : int
            Coarse HEALPix resolution (power of 2); nside_in = 2 * nside_out.
        cell_ids_out : list / np.ndarray / torch.Tensor
            1D, unique, coarse pixel ids; their order defines the input layout.
        radius_deg, sigma_deg, weight_norm
            Forwarded to SphereDownGeo (smooth mode).
        up_norm : {"adjoint", "col_l1", "diag_l2"}
            Diagonal correction applied after M^T.
        eps : float
            Lower clamp of the normalisation denominators.
        device : torch.device or str, optional
            Defaults to CUDA when available, else CPU.
        dtype : torch.dtype
            dtype of the operator.

        Raises
        ------
        ValueError
            On missing, non-1D, empty, out-of-range or duplicated cell_ids_out,
            or an unknown up_norm.
        """
        super().__init__()

        if cell_ids_out is None:
            raise ValueError("cell_ids_out is mandatory (1D list/np/tensor of coarse HEALPix ids at nside_out).")

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.dtype = dtype

        self.nside_out = int(nside_out)
        assert (self.nside_out & (self.nside_out - 1)) == 0, "nside_out must be a power of 2."
        self.nside_in = self.nside_out * 2

        self.N_out = 12 * self.nside_out * self.nside_out
        self.N_in = 12 * self.nside_in * self.nside_in

        up_norm = str(up_norm).lower().strip()
        if up_norm not in ("adjoint", "col_l1", "diag_l2"):
            raise ValueError("up_norm must be 'adjoint', 'col_l1', or 'diag_l2'.")
        self.up_norm = up_norm
        self.eps = float(eps)

        # Coarse ids in user-provided order (must be unique for alignment)
        if isinstance(cell_ids_out, torch.Tensor):
            cell_ids_out_np = cell_ids_out.detach().cpu().numpy().astype(np.int64)
        else:
            cell_ids_out_np = np.asarray(cell_ids_out, dtype=np.int64)

        if cell_ids_out_np.ndim != 1:
            raise ValueError("cell_ids_out must be 1D")
        if cell_ids_out_np.size == 0:
            raise ValueError("cell_ids_out must be non-empty")
        if cell_ids_out_np.min() < 0 or cell_ids_out_np.max() >= self.N_out:
            raise ValueError("cell_ids_out contains out-of-bounds ids for this nside_out")
        if np.unique(cell_ids_out_np).size != cell_ids_out_np.size:
            raise ValueError("cell_ids_out must not contain duplicates (order matters for alignment).")

        self.cell_ids_out_np = cell_ids_out_np
        self.K_out = int(cell_ids_out_np.size)
        self.register_buffer("cell_ids_out_t", torch.as_tensor(cell_ids_out_np, dtype=torch.long, device=self.device))

        # Down operator restricted to the requested coarse rows, in user order: [K_out, N_in]
        M_down_sub = self._build_down_rows(radius_deg, sigma_deg, weight_norm)

        # Store M^T (sparse) so forward is just sparse.mm
        M_up = self._transpose_sparse(M_down_sub)  # [N_in, K_out]
        self.register_buffer("M_indices", M_up.indices())
        self.register_buffer("M_values", M_up.values())
        self.M_size = M_up.size()

        # Diagonal normalizers (length N_in), based on the selected coarse rows only
        fine_cols = M_down_sub.indices()[1]
        vals_sub = M_down_sub.values()

        col_sum = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
        col_l2 = torch.zeros(self.N_in, device=self.device, dtype=self.dtype)
        col_sum.scatter_add_(0, fine_cols, vals_sub)
        col_l2.scatter_add_(0, fine_cols, vals_sub * vals_sub)

        self.register_buffer("col_sum", col_sum)
        self.register_buffer("col_l2", col_l2)

        # Fine ids (full sphere)
        self.register_buffer("cell_ids_in_t", torch.arange(self.N_in, dtype=torch.long, device=self.device))

        self.M_T = torch.sparse_coo_tensor(
            self.M_indices.to(device=self.device),
            self.M_values.to(device=self.device, dtype=self.dtype),
            size=self.M_size,
            device=self.device,
            dtype=self.dtype,
        ).coalesce().to_sparse_csr().to(self.device)

    def _build_down_rows(self, radius_deg, sigma_deg, weight_norm) -> torch.Tensor:
        """Return the smooth down matrix for cell_ids_out only, rows in user order, shape [K_out, N_in].

        Only the requested coarse pixels are built (one query_disc each) instead of
        the whole sphere. SphereDownGeo sorts cell_ids_out ascending, so its row j
        corresponds to the user position order[j] with order = argsort(cell_ids_out).
        """
        tmp_down = SphereDownGeo(
            nside_in=self.nside_in,
            mode="smooth",
            radius_deg=radius_deg,
            sigma_deg=sigma_deg,
            weight_norm=weight_norm,
            cell_ids_out=self.cell_ids_out_np,
            device=self.device,
            dtype=self.dtype,
            use_csr=False,
        )

        order = np.argsort(self.cell_ids_out_np, kind="stable")
        M_sorted = tmp_down.M.coalesce()
        # Remap rows on CPU with numpy (cheap, vectorised) at init time.
        idx = M_sorted.indices().cpu().numpy()
        vals = M_sorted.values().cpu().numpy()
        new_rows = order[idx[0]]

        return torch.sparse_coo_tensor(
            torch.as_tensor(np.stack([new_rows, idx[1]], axis=0), dtype=torch.long),
            torch.as_tensor(vals, dtype=self.dtype),
            size=(self.K_out, self.N_in),
            device=self.device,
            dtype=self.dtype,
        ).coalesce()

    @staticmethod
    def _transpose_sparse(M: torch.Tensor) -> torch.Tensor:
        """Return the coalesced COO transpose of a 2D sparse COO tensor."""
        M = M.coalesce()
        idx = M.indices()
        vals = M.values()
        R, C = M.size()
        idx_T = torch.stack([idx[1], idx[0]], dim=0)
        return torch.sparse_coo_tensor(idx_T, vals, size=(C, R), device=M.device, dtype=M.dtype).coalesce()

    def forward(self, x: torch.Tensor):
        """x: [B, C, K_out] -> (x_up: [B, C, N_in], fine ids arange(N_in))."""
        B, C, K_out = x.shape
        assert K_out == self.K_out, f"Expected K_out={self.K_out}, got {K_out}"

        # M_T is a plain attribute (not a buffer), so module.to() does not move it.
        M_T = self.M_T if self.M_T.device == x.device else self.M_T.to(x.device)

        x_bc = x.reshape(B * C, K_out)
        x_up_bc_T = torch.sparse.mm(M_T, x_bc.T)  # [N_in, B*C]
        x_up = x_up_bc_T.T.reshape(B, C, self.N_in)  # [B, C, N_in]

        if self.up_norm == "col_l1":
            denom = self.col_sum.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
            x_up = x_up / denom.view(1, 1, -1)
        elif self.up_norm == "diag_l2":
            denom = self.col_l2.to(device=x.device, dtype=x.dtype).clamp_min(self.eps)
            x_up = x_up / denom.view(1, 1, -1)

        return x_up, self.cell_ids_in_t.to(device=x.device)
|