foscat 2025.9.5__py3-none-any.whl → 2025.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +14 -9
- foscat/Plot.py +4 -3
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch-old.py +658 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/METADATA +1 -1
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/RECORD +13 -8
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/WHEEL +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,521 @@
|
|
|
1
|
+
# healpix_vit_varlevels.py
|
|
2
|
+
# HEALPix ViT with level-wise (variable) channel widths and U-Net-style spherical decoder
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import List, Optional, Literal, Tuple, Union
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
import foscat.scat_cov as sc
|
|
12
|
+
import foscat.SphericalStencil as ho
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class HealpixViT(nn.Module):
|
|
16
|
+
"""
|
|
17
|
+
HEALPix Vision Transformer (Foscat-based) with *variable channel widths per level*
|
|
18
|
+
and a U-Net-like spherical decoder.
|
|
19
|
+
|
|
20
|
+
Key idea
|
|
21
|
+
--------
|
|
22
|
+
- Encoder uses a list of channel dimensions `level_dims = [C_fine, C_l1, ..., C_token]`
|
|
23
|
+
that evolve *with depth* (e.g., 128 -> 192 -> 256).
|
|
24
|
+
- At each encoder level (before a Down()), we apply a spherical convolution that
|
|
25
|
+
maps C_i -> C_{i+1}. Down() then reduces the HEALPix resolution by one level.
|
|
26
|
+
- Transformer runs at the token grid with embedding dim = C_token.
|
|
27
|
+
- Decoder upsamples *one level at a time*; after each Up() it concatenates the
|
|
28
|
+
upsampled token features (C_{i+1}) with the corresponding skip (C_i) and applies
|
|
29
|
+
a spherical convolution to fuse (C_{i+1} + C_i) -> C_i.
|
|
30
|
+
- Final head maps C_fine -> out_channels at the finest grid.
|
|
31
|
+
|
|
32
|
+
Shapes (dense tasks)
|
|
33
|
+
--------------------
|
|
34
|
+
Input : (B, Cin, Nfine)
|
|
35
|
+
→ patch-embed (Cin -> C_fine) at finest grid
|
|
36
|
+
→ for i in [0..L-1]: EncConv(C_i->C_{i+1}) → Down() (store skip_i=C_i at grid i)
|
|
37
|
+
→ tokens at grid L with dim C_token
|
|
38
|
+
→ Transformer on tokens (C_token)
|
|
39
|
+
→ for i in [L-1..0]: Up() to grid i → concat(skip_i, up) [C_i + C_{i+1}] → DecConv → C_i
|
|
40
|
+
→ Head: C_fine -> out_channels at finest grid
|
|
41
|
+
|
|
42
|
+
Requirements
|
|
43
|
+
------------
|
|
44
|
+
- level_dims length must be token_down+1, with:
|
|
45
|
+
len(level_dims) = token_down + 1
|
|
46
|
+
level_dims[0] = channels at finest grid after patch embedding
|
|
47
|
+
level_dims[-1] = Transformer embedding dimension
|
|
48
|
+
- Each value in level_dims must be divisible by G (number of gauges).
|
|
49
|
+
- out_channels must be divisible by G.
|
|
50
|
+
|
|
51
|
+
Parameters (main)
|
|
52
|
+
-----------------
|
|
53
|
+
in_nside : input HEALPix nside (nested)
|
|
54
|
+
n_chan_in : input channels at finest grid (Cin)
|
|
55
|
+
level_dims : list of ints, channel width per level from fine to token
|
|
56
|
+
depth : number of Transformer encoder layers
|
|
57
|
+
num_heads : self-attention heads
|
|
58
|
+
cell_ids : finest-level nested indices (Nfine = 12*nside^2)
|
|
59
|
+
task : "regression" | "segmentation" | "global"
|
|
60
|
+
out_channels : output channels for dense tasks
|
|
61
|
+
KERNELSZ : spherical kernel size for Foscat convolutions
|
|
62
|
+
gauge_type : "cosmo" | "phi"
|
|
63
|
+
G : number of gauges
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        level_dims: List[int],  # e.g., [128, 192, 256] (fine -> token)
        depth: int,
        num_heads: int,
        cell_ids: np.ndarray,
        task: Literal["regression", "segmentation", "global"] = "regression",
        out_channels: int = 1,
        mlp_ratio: float = 4.0,
        KERNELSZ: int = 3,
        gauge_type: Literal["cosmo", "phi"] = "cosmo",
        G: int = 1,
        prefer_foscat_gpu: bool = True,
        cls_token: bool = False,
        pos_embed: Literal["learned", "none"] = "learned",
        head_type: Literal["mean", "cls"] = "mean",
        dropout: float = 0.0,
        dtype: Literal["float32", "float64"] = "float32",
    ) -> None:
        """Build the encoder operators, Transformer, decoder and head.

        See the class docstring for the meaning of each parameter. Construction
        order matters: the per-level SphericalStencil operators and their cell-id
        chains are computed first (via dummy Down() passes), because the
        token-grid size determines the Transformer's positional-embedding length.

        Raises
        ------
        ValueError
            If ``level_dims`` is empty, any level dim (or ``out_channels``) is
            not divisible by ``G``, ``embed_dim`` is not divisible by
            ``num_heads``, or ``cell_ids`` is None.
        """
        super().__init__()

        # ---- config (normalize to plain Python scalars) ----
        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.level_dims = list(level_dims)
        self.depth = int(depth)
        self.num_heads = int(num_heads)
        self.task = task
        self.out_channels = int(out_channels)
        self.mlp_ratio = float(mlp_ratio)
        self.KERNELSZ = int(KERNELSZ)
        self.gauge_type = gauge_type
        self.G = int(G)
        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
        self.cls_token_enabled = bool(cls_token)
        self.pos_embed_type = pos_embed
        self.head_type = head_type
        self.dropout = float(dropout)
        self.dtype = dtype

        if len(self.level_dims) < 1:
            raise ValueError("level_dims must have at least one element (fine level).")
        # Number of Down() steps = number of channel transitions.
        self.token_down = len(self.level_dims) - 1
        self.embed_dim = int(self.level_dims[-1])  # Transformer dim

        # Every level width must split evenly across the G gauges.
        for d in self.level_dims:
            if d % self.G != 0:
                raise ValueError(f"Each level dim must be divisible by G={self.G}; got {d}.")
        if self.embed_dim % self.num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads.")

        if dtype == "float32":
            self.np_dtype = np.float32
            self.torch_dtype = torch.float32
        else:
            # float64 only affects the NumPy side (id bookkeeping / casts);
            # the torch model itself intentionally stays in fp32.
            self.np_dtype = np.float64
            self.torch_dtype = torch.float32  # keep model in fp32

        if cell_ids is None:
            raise ValueError("cell_ids (finest) must be provided (nested ordering).")
        self.cell_ids_fine = np.asarray(cell_ids)

        # Default activation: dense segmentation gets sigmoid/softmax,
        # everything else returns raw values.
        if self.task == "segmentation":
            self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
        else:
            self.final_activation = "none"

        # Foscat wrapper (provides the backend used for casts / numpy round-trips).
        self.f = sc.funct(KERNELSZ=self.KERNELSZ)

        # ---- Build operators per level (fine -> ... -> token) and compute ids ----
        # A dummy (1,1,Npix) map is pushed through Down() at each level purely to
        # obtain the coarse cell-id list for the next level.
        self.hconv_levels: List[ho.SphericalStencil] = []
        self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
        current_nside = self.in_nside

        dummy = self.f.backend.bk_cast(
            np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype)
        )

        for _ in range(self.token_down):
            hc = ho.SphericalStencil(
                current_nside,
                self.KERNELSZ,
                n_gauges=self.G,
                gauge_type=self.gauge_type,
                cell_ids=self.level_cell_ids[-1],
                dtype=self.torch_dtype,
            )
            self.hconv_levels.append(hc)

            dummy, next_ids = hc.Down(
                dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True
            )
            self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
            current_nside //= 2

        self.token_nside = current_nside if self.token_down > 0 else self.in_nside
        self.token_cell_ids = self.level_cell_ids[-1]

        # Token and fine-level operators (for convenience)
        self.hconv_token = ho.SphericalStencil(
            self.token_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.token_cell_ids,
            dtype=self.torch_dtype,
        )
        self.hconv_head = ho.SphericalStencil(
            self.in_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.cell_ids_fine,
            dtype=self.torch_dtype,
        )

        # ---------------- Patch embedding (Cin -> C_fine) ----------------
        # Spherical-conv weight layout: (Cin, Cout_per_gauge, K*K);
        # the G gauges multiply Cout_per_gauge back up to C_fine.
        fine_dim = self.level_dims[0]
        fine_g = fine_dim // self.G
        self.patch_w = nn.Parameter(
            torch.empty(self.n_chan_in, fine_g, self.KERNELSZ * self.KERNELSZ)
        )
        nn.init.kaiming_uniform_(self.patch_w.view(self.n_chan_in * fine_g, -1), a=np.sqrt(5))
        self.patch_bn = nn.GroupNorm(num_groups=min(8, fine_dim if fine_dim > 1 else 1),
                                     num_channels=fine_dim)

        # ---------------- Encoder convs per level (C_i -> C_{i+1}) ----------------
        self.enc_w: nn.ParameterList = nn.ParameterList()
        self.enc_bn: nn.ModuleList = nn.ModuleList()
        for i in range(self.token_down):
            Cin = self.level_dims[i]
            Cout = self.level_dims[i+1]
            Cout_g = Cout // self.G
            w = nn.Parameter(torch.empty(Cin, Cout_g, self.KERNELSZ * self.KERNELSZ))
            nn.init.kaiming_uniform_(w.view(Cin * Cout_g, -1), a=np.sqrt(5))
            self.enc_w.append(w)
            self.enc_bn.append(nn.GroupNorm(num_groups=min(8, Cout if Cout > 1 else 1),
                                            num_channels=Cout))

        # ---------------- Transformer at token grid ----------------
        self.n_tokens = int(self.token_cell_ids.shape[0])
        if self.cls_token_enabled:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
            n_pe = self.n_tokens + 1  # +1 positional slot for the cls token
        else:
            self.cls_token = None
            n_pe = self.n_tokens

        if self.pos_embed_type == "learned":
            self.pos_embed = nn.Parameter(torch.zeros(1, n_pe, self.embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.pos_embed = None

        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=int(self.embed_dim * self.mlp_ratio),
            dropout=self.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,  # pre-LN variant
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=self.depth)

        # Projection at token grid (keep C_token)
        self.token_proj = nn.Linear(self.embed_dim, self.embed_dim)

        # ---------------- Decoder convs per level ( (C_{i+1}+C_i) -> C_i ) ----------------
        self.dec_w: nn.ParameterList = nn.ParameterList()
        self.dec_bn: nn.ModuleList = nn.ModuleList()
        for i in range(self.token_down - 1, -1, -1):
            # decoder proceeds from token level back to fine; we create weights in the same order
            Cin_fuse = self.level_dims[i+1] + self.level_dims[i]  # up + skip
            Cout = self.level_dims[i]
            Cout_g = Cout // self.G
            w = nn.Parameter(torch.empty(Cin_fuse, Cout_g, self.KERNELSZ * self.KERNELSZ))
            nn.init.kaiming_uniform_(w.view(Cin_fuse * Cout_g, -1), a=np.sqrt(5))
            self.dec_w.append(w)  # index 0 corresponds to up from token to level L-1
            self.dec_bn.append(nn.GroupNorm(num_groups=min(8, Cout if Cout > 1 else 1),
                                            num_channels=Cout))

        # ---------------- Final head (C_fine -> out_channels) ----------------
        if self.task == "global":
            self.global_head = nn.Linear(self.embed_dim, self.out_channels)
        else:
            self.C_fine = self.level_dims[0]

            if self.out_channels % self.G != 0:
                raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
            out_g = self.C_fine//self.G
            # NOTE(review): head_w is laid out (C_fine//G, out_channels, K*K),
            # unlike patch_w/enc_w which are (Cin, Cout//G, K*K) — confirm this
            # matches what Convol_torch expects for the head weight.
            self.head_w = nn.Parameter(torch.empty(out_g, self.out_channels, self.KERNELSZ * self.KERNELSZ))
            nn.init.kaiming_uniform_(self.head_w.view(self.out_channels * out_g, -1), a=np.sqrt(5))
            self.head_bn = (nn.GroupNorm(num_groups=min(8, self.out_channels if self.out_channels > 1 else 1),
                                         num_channels=self.out_channels)
                            if self.task == "segmentation" else None)

        # ---------------- Device probe ----------------
        # Try CUDA end-to-end (including Foscat ops); fall back to CPU if the
        # dry run fails. See _probe_and_set_runtime_device.
        pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.runtime_device = self._probe_and_set_runtime_device(pref)
|
273
|
+
# ---------------- device helpers ----------------
|
|
274
|
+
def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
|
|
275
|
+
for name, val in list(vars(hc).items()):
|
|
276
|
+
try:
|
|
277
|
+
if torch.is_tensor(val):
|
|
278
|
+
setattr(hc, name, val.to(device))
|
|
279
|
+
elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
|
|
280
|
+
setattr(hc, name, type(val)([v.to(device) for v in val]))
|
|
281
|
+
except Exception:
|
|
282
|
+
pass
|
|
283
|
+
|
|
284
|
+
    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Pick the runtime device by actually exercising a Foscat convolution.

        If *preferred* is CUDA (and ``prefer_foscat_gpu`` is set), the whole
        model plus every SphericalStencil is moved to the GPU and a one-sample
        dry run of the patch-embedding convolution is executed. If anything
        raises, the error is recorded in ``self._gpu_probe_error`` and
        everything is moved back to CPU.

        Returns the device that survived the probe (also stored in
        ``self._foscat_device``).
        """
        if preferred.type == "cuda" and self.prefer_foscat_gpu:
            try:
                # Move torch parameters/buffers, then the Foscat operators.
                super().to(preferred)
                for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
                    self._move_hc(hc, preferred)
                # dry run: one zero map through the fine-level convolution,
                # forced to materialize via .item() so CUDA errors surface here.
                npix0 = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
                hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
                y_try = hc0.Convol_torch(x_try, self.patch_w, cell_ids=self.cell_ids_fine)
                _ = y_try.sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception as e:
                # Keep the failure reason for debugging, then fall back to CPU.
                self._gpu_probe_error = repr(e)
        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu
|
308
|
+
def _to_numpy_ids(self, ids):
|
|
309
|
+
"""Return ids as a NumPy array on CPU (handles torch.Tensor on CUDA)."""
|
|
310
|
+
if torch.is_tensor(ids):
|
|
311
|
+
return ids.detach().cpu().numpy()
|
|
312
|
+
return np.asarray(ids)
|
|
313
|
+
|
|
314
|
+
# ---------------- helpers ----------------
|
|
315
|
+
def _as_tensor_batch(self, x):
|
|
316
|
+
if isinstance(x, list):
|
|
317
|
+
if len(x) == 1:
|
|
318
|
+
t = x[0]
|
|
319
|
+
return t.unsqueeze(0) if t.dim() == 2 else t
|
|
320
|
+
raise ValueError("Variable-length list not supported here; pass a tensor.")
|
|
321
|
+
return x
|
|
322
|
+
|
|
323
|
+
# ---------------- forward ----------------
|
|
324
|
+
def forward(
|
|
325
|
+
self,
|
|
326
|
+
x: torch.Tensor,
|
|
327
|
+
runtime_ids: Optional[np.ndarray] = None,
|
|
328
|
+
) -> torch.Tensor:
|
|
329
|
+
"""
|
|
330
|
+
x: (B, Cin, Nfine), nested ordering
|
|
331
|
+
runtime_ids: optional fine-level ids to decode onto (defaults to training ids)
|
|
332
|
+
"""
|
|
333
|
+
if not isinstance(x, torch.Tensor):
|
|
334
|
+
raise TypeError("x must be a torch.Tensor")
|
|
335
|
+
if x.dim() != 3:
|
|
336
|
+
raise ValueError("Input must be (B, Cin, Npix)")
|
|
337
|
+
if x.shape[1] != self.n_chan_in:
|
|
338
|
+
raise ValueError(f"Expected {self.n_chan_in} channels, got {x.shape[1]}")
|
|
339
|
+
if runtime_ids is not None:
|
|
340
|
+
runtime_ids = self._to_numpy_ids(runtime_ids)
|
|
341
|
+
|
|
342
|
+
x = x.to(self.runtime_device)
|
|
343
|
+
|
|
344
|
+
# -------- Patch embedding Cin -> C_fine --------
|
|
345
|
+
hc_fine0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
|
|
346
|
+
z = hc_fine0.Convol_torch(x, self.patch_w, cell_ids=self.cell_ids_fine) # (B, C_fine, Nfine)
|
|
347
|
+
if not torch.is_tensor(z):
|
|
348
|
+
z = torch.as_tensor(z, device=self.runtime_device)
|
|
349
|
+
z = self._as_tensor_batch(z)
|
|
350
|
+
z = self.patch_bn(z)
|
|
351
|
+
z = F.gelu(z)
|
|
352
|
+
|
|
353
|
+
# -------- Encoder path: for each level i: EncConv(C_i->C_{i+1}) then Down() --------
|
|
354
|
+
skips: List[torch.Tensor] = []
|
|
355
|
+
ids_list: List[np.ndarray] = []
|
|
356
|
+
|
|
357
|
+
l_data = z
|
|
358
|
+
l_cell_ids = self.cell_ids_fine if runtime_ids is None else np.asarray(runtime_ids)
|
|
359
|
+
current_nside = self.in_nside
|
|
360
|
+
|
|
361
|
+
for i, hc in enumerate(self.hconv_levels):
|
|
362
|
+
# save skip BEFORE going down (channels = C_i, grid = current level)
|
|
363
|
+
skips.append(self._as_tensor_batch(l_data))
|
|
364
|
+
ids_list.append(self._to_numpy_ids(l_cell_ids))
|
|
365
|
+
|
|
366
|
+
# conv to next channels C_{i+1} at same grid
|
|
367
|
+
w_enc = self.enc_w[i]
|
|
368
|
+
l_data = hc.Convol_torch(l_data, w_enc, cell_ids=l_cell_ids) # (B, C_{i+1}, N_current)
|
|
369
|
+
if not torch.is_tensor(l_data):
|
|
370
|
+
l_data = torch.as_tensor(l_data, device=self.runtime_device)
|
|
371
|
+
l_data = self._as_tensor_batch(l_data)
|
|
372
|
+
l_data = self.enc_bn[i](l_data)
|
|
373
|
+
l_data = F.gelu(l_data)
|
|
374
|
+
|
|
375
|
+
# Down one level
|
|
376
|
+
l_data, l_cell_ids = hc.Down(l_data, cell_ids=l_cell_ids, nside=current_nside, max_poll=True)
|
|
377
|
+
l_data = self._as_tensor_batch(l_data)
|
|
378
|
+
current_nside //= 2
|
|
379
|
+
|
|
380
|
+
# We are now at token grid with channels = C_token
|
|
381
|
+
x_tok = l_data # (B, C_token, Ntok)
|
|
382
|
+
token_ids = l_cell_ids # ids at token level
|
|
383
|
+
assert x_tok.shape[1] == self.embed_dim, "Token channels mismatch with embed_dim."
|
|
384
|
+
|
|
385
|
+
# -------- Transformer on tokens --------
|
|
386
|
+
seq = x_tok.permute(0, 2, 1) # (B, Ntok, E)
|
|
387
|
+
if self.cls_token_enabled:
|
|
388
|
+
cls = self.cls_token.expand(seq.size(0), -1, -1)
|
|
389
|
+
seq = torch.cat([cls, seq], dim=1)
|
|
390
|
+
if self.pos_embed is not None:
|
|
391
|
+
seq = seq + self.pos_embed[:, :seq.shape[1], :]
|
|
392
|
+
|
|
393
|
+
seq = self.encoder(seq) # (B, Ntok(+1), E)
|
|
394
|
+
if self.cls_token_enabled:
|
|
395
|
+
tokens = seq[:, 1:, :] # drop cls for dense
|
|
396
|
+
else:
|
|
397
|
+
tokens = seq
|
|
398
|
+
|
|
399
|
+
tok_feat = self.token_proj(tokens).permute(0, 2, 1) # (B, C_token, Ntok)
|
|
400
|
+
|
|
401
|
+
if self.task == "global":
|
|
402
|
+
if self.head_type == "cls" and self.cls_token_enabled:
|
|
403
|
+
cls_vec = seq[:, 0, :]
|
|
404
|
+
return nn.Linear(self.embed_dim, self.out_channels).to(seq.device)(cls_vec)
|
|
405
|
+
else:
|
|
406
|
+
return nn.Linear(self.embed_dim, self.out_channels).to(seq.device)(tokens.mean(dim=1))
|
|
407
|
+
|
|
408
|
+
# -------- Build runtime id chain (fine -> ... -> token) --------
|
|
409
|
+
fine_ids_runtime = self.cell_ids_fine if runtime_ids is None else np.asarray(runtime_ids)
|
|
410
|
+
ids_chain = [np.asarray(fine_ids_runtime)]
|
|
411
|
+
nside_tmp = self.in_nside
|
|
412
|
+
_dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
|
|
413
|
+
for hc in self.hconv_levels:
|
|
414
|
+
_dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
|
|
415
|
+
ids_chain.append(self.f.backend.to_numpy(_next))
|
|
416
|
+
nside_tmp //= 2
|
|
417
|
+
|
|
418
|
+
tok_ids_np = self._to_numpy_ids(token_ids)
|
|
419
|
+
|
|
420
|
+
assert tok_feat.shape[-1] == tok_ids_np.shape[0], "Token count mismatch."
|
|
421
|
+
assert np.array_equal(tok_ids_np, ids_chain[-1]), "Token ids mismatch with runtime chain."
|
|
422
|
+
|
|
423
|
+
# list of nsides at each encoder level (fine -> ... -> pre-token)
|
|
424
|
+
nsides_levels = [self.in_nside // (2 ** k) for k in range(self.token_down)]
|
|
425
|
+
|
|
426
|
+
# -------- Decoder: Up step-by-step with fusion conv --------
|
|
427
|
+
y = tok_feat # (B, C_token, Ntok)
|
|
428
|
+
dec_idx = 0 # index in self.dec_w / self.dec_bn (built from token->fine order)
|
|
429
|
+
for i in range(len(ids_chain)-1, 0, -1):
|
|
430
|
+
coarse_ids = ids_chain[i] # current y grid
|
|
431
|
+
fine_ids = ids_chain[i-1] # target grid
|
|
432
|
+
source_ns = self.in_nside // (2 ** i)
|
|
433
|
+
fine_ns = self.in_nside // (2 ** (i-1))
|
|
434
|
+
|
|
435
|
+
# choose operator for the target fine level
|
|
436
|
+
if fine_ns == self.in_nside:
|
|
437
|
+
op_fine = self.hconv_head
|
|
438
|
+
else:
|
|
439
|
+
idx = nsides_levels.index(fine_ns)
|
|
440
|
+
op_fine = self.hconv_levels[idx]
|
|
441
|
+
|
|
442
|
+
# Up one level
|
|
443
|
+
y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
|
|
444
|
+
if not torch.is_tensor(y_up):
|
|
445
|
+
y_up = torch.as_tensor(y_up, device=self.runtime_device)
|
|
446
|
+
y_up = self._as_tensor_batch(y_up) # (B, C_{i}, N_fine)
|
|
447
|
+
|
|
448
|
+
# Skip at this level (channels = C_{i-1})
|
|
449
|
+
skip_i = self._as_tensor_batch(skips[i-1]).to(y_up.device)
|
|
450
|
+
assert np.array_equal(np.asarray(ids_list[i-1]), np.asarray(fine_ids)), "Skip ids misaligned."
|
|
451
|
+
|
|
452
|
+
# Concat and fuse: (C_{i} + C_{i-1}) -> C_{i-1}
|
|
453
|
+
y_cat = torch.cat([y_up, skip_i], dim=1)
|
|
454
|
+
y = op_fine.Convol_torch(y_cat, self.dec_w[dec_idx], cell_ids=fine_ids)
|
|
455
|
+
if not torch.is_tensor(y):
|
|
456
|
+
y = torch.as_tensor(y, device=self.runtime_device)
|
|
457
|
+
y = self._as_tensor_batch(y)
|
|
458
|
+
y = self.dec_bn[dec_idx](y)
|
|
459
|
+
y = F.gelu(y)
|
|
460
|
+
if self.dropout > 0:
|
|
461
|
+
y = F.dropout(y, p=self.dropout, training=self.training)
|
|
462
|
+
dec_idx += 1
|
|
463
|
+
|
|
464
|
+
# y is now (B, C_fine, Nfine)
|
|
465
|
+
# -------- Final head to out_channels --------
|
|
466
|
+
y = self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids_runtime)
|
|
467
|
+
if not torch.is_tensor(y):
|
|
468
|
+
y = torch.as_tensor(y, device=self.runtime_device)
|
|
469
|
+
y = self._as_tensor_batch(y)
|
|
470
|
+
if self.task == "segmentation" and self.head_bn is not None:
|
|
471
|
+
y = self.head_bn(y)
|
|
472
|
+
|
|
473
|
+
if self.final_activation == "sigmoid":
|
|
474
|
+
y = torch.sigmoid(y)
|
|
475
|
+
elif self.final_activation == "softmax":
|
|
476
|
+
y = torch.softmax(y, dim=1)
|
|
477
|
+
return y
|
|
478
|
+
|
|
479
|
+
@torch.no_grad()
|
|
480
|
+
def predict(self, x: Union[torch.Tensor, np.ndarray], batch_size: int = 8) -> torch.Tensor:
|
|
481
|
+
self.eval()
|
|
482
|
+
if isinstance(x, np.ndarray):
|
|
483
|
+
x = torch.from_numpy(x).float()
|
|
484
|
+
outs = []
|
|
485
|
+
for i in range(0, x.shape[0], batch_size):
|
|
486
|
+
xb = x[i : i + batch_size].to(self.runtime_device)
|
|
487
|
+
outs.append(self.forward(xb))
|
|
488
|
+
return torch.cat(outs, dim=0)
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
# -------------------------- Smoke test --------------------------
|
|
492
|
+
# -------------------------- Smoke test --------------------------
if __name__ == "__main__":
    # nside=4 → Npix=192, 2 down levels → token_nside=1
    in_nside = 4
    npix = 12 * in_nside * in_nside
    cell_ids = np.arange(npix, dtype=np.int64)

    B, Cin = 2, 3
    x = torch.randn(B, Cin, npix)

    # Channel widths per level (fine -> token), divisible by G=1 here
    level_dims = [64, 96, 128]

    # BUG FIX: this module defines HealpixViT; the previous name
    # HealpixViTVarLevels does not exist here and raised NameError.
    model = HealpixViT(
        in_nside=in_nside,
        n_chan_in=Cin,
        level_dims=level_dims,  # len=3 => token_down=2
        depth=2,
        num_heads=4,
        cell_ids=cell_ids,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        cls_token=False,
        dropout=0.1,
    ).eval()

    with torch.no_grad():
        y = model(x)
    print("Output:", y.shape)  # (B, Cout, Nfine)