foscat 2025.9.5__py3-none-any.whl → 2025.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +635 -141
- foscat/FoCUS.py +92 -50
- foscat/Plot.py +4 -3
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/scat.py +1 -1
- foscat/scat1D.py +1 -1
- foscat/scat_cov.py +2 -2
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/METADATA +1 -1
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/RECORD +16 -12
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/WHEEL +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
# healpix_vit_skip.py
|
|
2
|
+
# HEALPix ViT U-Net with temporal encoders and Transformer-based skip fusion.
|
|
3
|
+
# - Multi-level HEALPix pyramid using Foscat.SphericalStencil
|
|
4
|
+
# - Per-level temporal encoding (sequence over T_in months) at encoder
|
|
5
|
+
# - Decoder uses cross-attention to fuse upsampled features with encoder skips
|
|
6
|
+
# - Double spherical convolution + GroupNorm + GELU at each encoder/decoder level
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
from typing import List, Optional, Literal
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
import foscat.scat_cov as sc
|
|
17
|
+
import foscat.SphericalStencil as ho
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MLP(nn.Module):
    """Pre-norm transformer feed-forward block.

    LayerNorm -> Linear(d, hidden_mult*d) -> GELU -> Dropout ->
    Linear(hidden_mult*d, d) -> Dropout.  Input and output share the
    same feature dimension ``d``.
    """

    def __init__(self, d: int, hidden_mult: int = 4, drop: float = 0.0):
        super().__init__()
        # Layers are created in the same order as a standard transformer
        # FFN, which also keeps parameter-initialization RNG draws stable.
        layers = [
            nn.LayerNorm(d),
            nn.Linear(d, hidden_mult * d),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_mult * d, d),
            nn.Dropout(drop),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack; the input shape is preserved."""
        return self.net(x)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class HealpixViTSkip(nn.Module):
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
*,
|
|
39
|
+
in_nside: int,
|
|
40
|
+
n_chan_in: int,
|
|
41
|
+
level_dims: List[int],
|
|
42
|
+
depth_token: int,
|
|
43
|
+
num_heads_token: int,
|
|
44
|
+
cell_ids: np.ndarray,
|
|
45
|
+
task: Literal["regression","segmentation","global"] = "regression",
|
|
46
|
+
out_channels: int = 1,
|
|
47
|
+
mlp_ratio_token: float = 4.0,
|
|
48
|
+
KERNELSZ: int = 3,
|
|
49
|
+
gauge_type: Literal["cosmo","phi"] = "cosmo",
|
|
50
|
+
G: int = 1,
|
|
51
|
+
prefer_foscat_gpu: bool = True,
|
|
52
|
+
dropout: float = 0.1,
|
|
53
|
+
dtype: Literal["float32","float64"] = "float32",
|
|
54
|
+
pos_embed_per_level: bool = True,
|
|
55
|
+
) -> None:
|
|
56
|
+
super().__init__()
|
|
57
|
+
|
|
58
|
+
self.in_nside = int(in_nside)
|
|
59
|
+
self.n_chan_in = int(n_chan_in)
|
|
60
|
+
self.level_dims = list(level_dims)
|
|
61
|
+
self.token_down = len(self.level_dims) - 1
|
|
62
|
+
assert self.token_down >= 0
|
|
63
|
+
self.C_fine = int(self.level_dims[0])
|
|
64
|
+
self.embed_dim = int(self.level_dims[-1])
|
|
65
|
+
self.depth_token = int(depth_token)
|
|
66
|
+
self.num_heads_token = int(num_heads_token)
|
|
67
|
+
self.mlp_ratio_token = float(mlp_ratio_token)
|
|
68
|
+
self.task = task
|
|
69
|
+
self.out_channels = int(out_channels)
|
|
70
|
+
self.KERNELSZ = int(KERNELSZ)
|
|
71
|
+
self.gauge_type = gauge_type
|
|
72
|
+
self.G = int(G)
|
|
73
|
+
self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
|
|
74
|
+
self.dropout = float(dropout)
|
|
75
|
+
self.dtype = dtype
|
|
76
|
+
self.pos_embed_per_level = bool(pos_embed_per_level)
|
|
77
|
+
|
|
78
|
+
for d in self.level_dims:
|
|
79
|
+
if d % self.G != 0:
|
|
80
|
+
raise ValueError(f"All level_dims must be divisible by G={self.G}, got {d}.")
|
|
81
|
+
if self.embed_dim % self.num_heads_token != 0:
|
|
82
|
+
raise ValueError("embed_dim must be divisible by num_heads_token.")
|
|
83
|
+
|
|
84
|
+
if dtype == "float32":
|
|
85
|
+
self.np_dtype = np.float32
|
|
86
|
+
self.torch_dtype = torch.float32
|
|
87
|
+
else:
|
|
88
|
+
self.np_dtype = np.float64
|
|
89
|
+
self.torch_dtype = torch.float32
|
|
90
|
+
|
|
91
|
+
if cell_ids is None:
|
|
92
|
+
raise ValueError("cell_ids (finest) must be provided.")
|
|
93
|
+
self.cell_ids_fine = np.asarray(cell_ids)
|
|
94
|
+
|
|
95
|
+
if self.task == "segmentation":
|
|
96
|
+
self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
|
|
97
|
+
else:
|
|
98
|
+
self.final_activation = "none"
|
|
99
|
+
|
|
100
|
+
self.f = sc.funct(KERNELSZ=self.KERNELSZ)
|
|
101
|
+
|
|
102
|
+
# Build stencils
|
|
103
|
+
self.hconv_levels: List[ho.SphericalStencil] = []
|
|
104
|
+
self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
|
|
105
|
+
current_nside = self.in_nside
|
|
106
|
+
dummy = self.f.backend.bk_cast(np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype))
|
|
107
|
+
for _ in range(self.token_down):
|
|
108
|
+
hc = ho.SphericalStencil(current_nside, self.KERNELSZ, n_gauges=self.G,
|
|
109
|
+
gauge_type=self.gauge_type, cell_ids=self.level_cell_ids[-1],
|
|
110
|
+
dtype=self.torch_dtype)
|
|
111
|
+
self.hconv_levels.append(hc)
|
|
112
|
+
dummy, next_ids = hc.Down(dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True)
|
|
113
|
+
self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
|
|
114
|
+
current_nside //= 2
|
|
115
|
+
|
|
116
|
+
self.token_nside = current_nside if self.token_down > 0 else self.in_nside
|
|
117
|
+
self.token_cell_ids = self.level_cell_ids[-1]
|
|
118
|
+
|
|
119
|
+
self.hconv_token = ho.SphericalStencil(self.token_nside, self.KERNELSZ, n_gauges=self.G,
|
|
120
|
+
gauge_type=self.gauge_type, cell_ids=self.token_cell_ids, dtype=self.torch_dtype)
|
|
121
|
+
self.hconv_head = ho.SphericalStencil(self.in_nside, self.KERNELSZ, n_gauges=self.G,
|
|
122
|
+
gauge_type=self.gauge_type, cell_ids=self.cell_ids_fine, dtype=self.torch_dtype)
|
|
123
|
+
|
|
124
|
+
self.nsides_levels = [self.in_nside // (2**i) for i in range(self.token_down+1)]
|
|
125
|
+
self.ntokens_levels = [12 * n * n for n in self.nsides_levels]
|
|
126
|
+
|
|
127
|
+
# Patch embed (double conv)
|
|
128
|
+
fine_g = self.C_fine // self.G
|
|
129
|
+
self.pe_w1 = nn.Parameter(torch.empty(self.n_chan_in, fine_g, self.KERNELSZ*self.KERNELSZ))
|
|
130
|
+
nn.init.kaiming_uniform_(self.pe_w1.view(self.n_chan_in * fine_g, -1), a=np.sqrt(5))
|
|
131
|
+
self.pe_w2 = nn.Parameter(torch.empty(self.C_fine, fine_g, self.KERNELSZ*self.KERNELSZ))
|
|
132
|
+
nn.init.kaiming_uniform_(self.pe_w2.view(self.C_fine * fine_g, -1), a=np.sqrt(5))
|
|
133
|
+
self.pe_bn1 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
|
|
134
|
+
self.pe_bn2 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
|
|
135
|
+
|
|
136
|
+
# Encoder double convs
|
|
137
|
+
self.enc_w1 = nn.ParameterList()
|
|
138
|
+
self.enc_w2 = nn.ParameterList()
|
|
139
|
+
self.enc_bn1 = nn.ModuleList()
|
|
140
|
+
self.enc_bn2 = nn.ModuleList()
|
|
141
|
+
for i in range(self.token_down):
|
|
142
|
+
Cin = self.level_dims[i]
|
|
143
|
+
Cout = self.level_dims[i+1]
|
|
144
|
+
Cout_g = Cout // self.G
|
|
145
|
+
w1 = nn.Parameter(torch.empty(Cin, Cout_g, self.KERNELSZ*self.KERNELSZ))
|
|
146
|
+
nn.init.kaiming_uniform_(w1.view(Cin * Cout_g, -1), a=np.sqrt(5))
|
|
147
|
+
w2 = nn.Parameter(torch.empty(Cout, Cout_g, self.KERNELSZ*self.KERNELSZ))
|
|
148
|
+
nn.init.kaiming_uniform_(w2.view(Cout * Cout_g, -1), a=np.sqrt(5))
|
|
149
|
+
self.enc_w1.append(w1); self.enc_w2.append(w2)
|
|
150
|
+
self.enc_bn1.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
|
|
151
|
+
self.enc_bn2.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
|
|
152
|
+
|
|
153
|
+
# Temporal encoders per level (fine..pre-token)
|
|
154
|
+
self.temporal_encoders = nn.ModuleList([
|
|
155
|
+
nn.TransformerEncoder(
|
|
156
|
+
nn.TransformerEncoderLayer(
|
|
157
|
+
d_model=self.level_dims[i],
|
|
158
|
+
nhead=max(1, min(8, self.level_dims[i] // 64)),
|
|
159
|
+
dim_feedforward=2*self.level_dims[i],
|
|
160
|
+
dropout=self.dropout,
|
|
161
|
+
activation='gelu',
|
|
162
|
+
batch_first=True,
|
|
163
|
+
norm_first=True,
|
|
164
|
+
),
|
|
165
|
+
num_layers=2,
|
|
166
|
+
)
|
|
167
|
+
for i in range(self.token_down)
|
|
168
|
+
])
|
|
169
|
+
|
|
170
|
+
# Token-level Transformer
|
|
171
|
+
self.n_tokens = int(self.token_cell_ids.shape[0])
|
|
172
|
+
self.pos_token = nn.Parameter(torch.zeros(1, self.n_tokens, self.embed_dim))
|
|
173
|
+
nn.init.trunc_normal_(self.pos_token, std=0.02)
|
|
174
|
+
enc_layer = nn.TransformerEncoderLayer(
|
|
175
|
+
d_model=self.embed_dim,
|
|
176
|
+
nhead=self.num_heads_token,
|
|
177
|
+
dim_feedforward=int(self.embed_dim * self.mlp_ratio_token),
|
|
178
|
+
dropout=self.dropout,
|
|
179
|
+
activation='gelu',
|
|
180
|
+
batch_first=True,
|
|
181
|
+
norm_first=True,
|
|
182
|
+
)
|
|
183
|
+
self.encoder_token = nn.TransformerEncoder(enc_layer, num_layers=self.depth_token)
|
|
184
|
+
|
|
185
|
+
# Decoder fusion modules per level (cross-attention)
|
|
186
|
+
self.dec_q = nn.ModuleList()
|
|
187
|
+
self.dec_k = nn.ModuleList()
|
|
188
|
+
self.dec_v = nn.ModuleList()
|
|
189
|
+
self.dec_attn = nn.ModuleList()
|
|
190
|
+
self.dec_mlp = nn.ModuleList()
|
|
191
|
+
self.level_pos = nn.ParameterList() if self.pos_embed_per_level else None
|
|
192
|
+
for i in range(self.token_down, 0, -1):
|
|
193
|
+
Cfine = self.level_dims[i-1]
|
|
194
|
+
d_fuse = Cfine
|
|
195
|
+
self.dec_q.append(nn.Linear(Cfine, d_fuse))
|
|
196
|
+
self.dec_k.append(nn.Linear(Cfine, d_fuse))
|
|
197
|
+
self.dec_v.append(nn.Linear(Cfine, d_fuse))
|
|
198
|
+
self.dec_attn.append(nn.MultiheadAttention(embed_dim=d_fuse, num_heads=max(1, min(8, d_fuse // 64)), batch_first=True))
|
|
199
|
+
self.dec_mlp.append(MLP(d_fuse, hidden_mult=4, drop=self.dropout))
|
|
200
|
+
if self.pos_embed_per_level:
|
|
201
|
+
n_tok_i = self.ntokens_levels[i-1]
|
|
202
|
+
p = nn.Parameter(torch.zeros(1, n_tok_i, d_fuse))
|
|
203
|
+
nn.init.trunc_normal_(p, std=0.02)
|
|
204
|
+
self.level_pos.append(p)
|
|
205
|
+
|
|
206
|
+
# Decoder refinement double convs
|
|
207
|
+
self.dec_refine_w1 = nn.ParameterList()
|
|
208
|
+
self.dec_refine_w2 = nn.ParameterList()
|
|
209
|
+
self.dec_refine_bn1 = nn.ModuleList()
|
|
210
|
+
self.dec_refine_bn2 = nn.ModuleList()
|
|
211
|
+
for i in range(self.token_down, 0, -1):
|
|
212
|
+
Cfine = self.level_dims[i-1]
|
|
213
|
+
Cfine_g = Cfine // self.G
|
|
214
|
+
w1 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
|
|
215
|
+
nn.init.kaiming_uniform_(w1.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
|
|
216
|
+
w2 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
|
|
217
|
+
nn.init.kaiming_uniform_(w2.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
|
|
218
|
+
self.dec_refine_w1.append(w1); self.dec_refine_w2.append(w2)
|
|
219
|
+
self.dec_refine_bn1.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
|
|
220
|
+
self.dec_refine_bn2.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
|
|
221
|
+
|
|
222
|
+
# Head
|
|
223
|
+
if self.task == "global":
|
|
224
|
+
self.global_head = nn.Linear(self.embed_dim, self.out_channels)
|
|
225
|
+
else:
|
|
226
|
+
if self.out_channels % self.G != 0:
|
|
227
|
+
raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
|
|
228
|
+
out_g = self.out_channels // self.G
|
|
229
|
+
self.head_w = nn.Parameter(torch.empty(self.C_fine, out_g, self.KERNELSZ*self.KERNELSZ))
|
|
230
|
+
nn.init.kaiming_uniform_(self.head_w.view(self.C_fine * out_g, -1), a=np.sqrt(5))
|
|
231
|
+
self.head_bn = nn.GroupNorm(num_groups=min(8, self.out_channels if self.out_channels>1 else 1),
|
|
232
|
+
num_channels=self.out_channels) if self.task=="segmentation" else None
|
|
233
|
+
|
|
234
|
+
pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
235
|
+
self.runtime_device = self._probe_and_set_runtime_device(pref)
|
|
236
|
+
|
|
237
|
+
def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
|
|
238
|
+
for name, val in list(vars(hc).items()):
|
|
239
|
+
try:
|
|
240
|
+
if torch.is_tensor(val):
|
|
241
|
+
setattr(hc, name, val.to(device))
|
|
242
|
+
elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
|
|
243
|
+
setattr(hc, name, type(val)([v.to(device) for v in val]))
|
|
244
|
+
except Exception:
|
|
245
|
+
pass
|
|
246
|
+
|
|
247
|
+
    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Select the runtime device, falling back to CPU on any failure.

        If *preferred* is CUDA, the model and all foscat stencils are moved
        there and a tiny convolution is run as a probe; on success that
        device is recorded in ``self._foscat_device`` and returned.  On any
        exception (or a non-CUDA *preferred*), everything is moved to CPU
        instead.
        """
        if preferred.type == "cuda":
            try:
                # Move torch parameters/buffers, then each foscat stencil.
                super().to(preferred)
                for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
                    self._move_hc(hc, preferred)
                # Probe: one convolution on a zero map. `.item()` forces a
                # synchronization so device errors surface here, not later
                # inside forward().
                npix0 = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
                hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
                y_try = hc0.Convol_torch(x_try, self.pe_w1, cell_ids=self.cell_ids_fine)
                _ = (y_try if torch.is_tensor(y_try) else torch.as_tensor(y_try, device=preferred)).sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception:
                # Deliberate best-effort: any failure (OOM, unsupported op,
                # stencil not movable) triggers the CPU fallback below.
                pass
        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu
|
|
269
|
+
|
|
270
|
+
def _as_tensor_batch(self, x):
|
|
271
|
+
if isinstance(x, list):
|
|
272
|
+
if len(x) == 1:
|
|
273
|
+
t = x[0]
|
|
274
|
+
return t.unsqueeze(0) if t.dim() == 2 else t
|
|
275
|
+
raise ValueError("Variable-length list not supported here; pass a tensor.")
|
|
276
|
+
return x
|
|
277
|
+
|
|
278
|
+
def _to_numpy_ids(self, ids):
|
|
279
|
+
if torch.is_tensor(ids):
|
|
280
|
+
return ids.detach().cpu().numpy()
|
|
281
|
+
return np.asarray(ids)
|
|
282
|
+
|
|
283
|
+
def _patch_embed_fine(self, x_t: torch.Tensor) -> torch.Tensor:
|
|
284
|
+
hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
|
|
285
|
+
z = hc0.Convol_torch(x_t, self.pe_w1, cell_ids=self.cell_ids_fine)
|
|
286
|
+
z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
|
|
287
|
+
z = self.pe_bn1(z); z = F.gelu(z)
|
|
288
|
+
z = hc0.Convol_torch(z, self.pe_w2, cell_ids=self.cell_ids_fine)
|
|
289
|
+
z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
|
|
290
|
+
z = self.pe_bn2(z); z = F.gelu(z)
|
|
291
|
+
return z
|
|
292
|
+
|
|
293
|
+
    def forward(self, x: torch.Tensor, runtime_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        """Run encoder -> token transformer -> decoder over a map sequence.

        Parameters
        ----------
        x : tensor of shape (B, T_in, C_in, Npix) — T_in maps per sample.
        runtime_ids : optional finest-level cell ids overriding the
            construction-time ids (e.g. partial-sky input); the coarse id
            chain is then rebuilt on the fly via dummy Down() passes.

        Returns
        -------
        (B, out_channels) for task "global"; otherwise a map of shape
        (B, out_channels, Npix) on the finest grid, with sigmoid/softmax
        applied for segmentation.
        """
        if x.dim() != 4:
            raise ValueError("Expected input shape (B, T_in, C_in, Npix)")
        B, T_in, C_in, Nf = x.shape
        if C_in != self.n_chan_in:
            raise ValueError(f"Expected n_chan_in={self.n_chan_in}, got {C_in}")
        x = x.to(self.runtime_device)

        # Rebuild the per-level cell-id chain for the runtime ids; a dummy
        # zero map is run through Down() only to obtain each coarser id set.
        fine_ids_runtime = self.cell_ids_fine if runtime_ids is None else self._to_numpy_ids(runtime_ids)
        ids_chain = [np.asarray(fine_ids_runtime)]
        nside_tmp = self.in_nside
        _dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
        for hc in self.hconv_levels:
            _dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
            ids_chain.append(self.f.backend.to_numpy(_next))
            nside_tmp //= 2

        # Encoder histories per level: l_hist[i] holds the full time series
        # of features at level i; l_ids[i] the matching cell ids.
        l_hist: List[torch.Tensor] = []
        l_ids: List[np.ndarray] = []

        # Patch-embed every time step independently at the finest level.
        feats_fine = []
        for t in range(T_in):
            zt = self._patch_embed_fine(x[:, t, :, :])
            feats_fine.append(zt.unsqueeze(1))
        feats_fine = torch.cat(feats_fine, dim=1)  # (B, T_in, C_fine, N_fine)
        l_hist.append(feats_fine)
        l_ids.append(self.cell_ids_fine)

        # Encoder: per level, double conv on every time step, then Down().
        current_nside = self.in_nside
        l_data_hist = feats_fine
        for i, hc in enumerate(self.hconv_levels):
            # Cin/Cout kept for readability; widths come from level_dims.
            Cin = self.level_dims[i]
            Cout = self.level_dims[i+1]
            w1, w2 = self.enc_w1[i], self.enc_w2[i]
            feats_next = []
            for t in range(T_in):
                zt = l_data_hist[:, t, :, :]
                zt = hc.Convol_torch(zt, w1, cell_ids=l_ids[-1])
                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
                zt = self.enc_bn1[i](zt); zt = F.gelu(zt)
                zt = hc.Convol_torch(zt, w2, cell_ids=l_ids[-1])
                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
                zt = self.enc_bn2[i](zt); zt = F.gelu(zt)
                feats_next.append(zt.unsqueeze(1))
            feats_next = torch.cat(feats_next, dim=1)  # (B, T_in, Cout, N_i)

            # Downsample each time step; every Down() at a level returns the
            # same coarse ids, so keeping the last one is sufficient.
            feats_down = []
            next_ids_list = None
            for t in range(T_in):
                zt, next_ids = hc.Down(feats_next[:, t, :, :], cell_ids=l_ids[-1], nside=current_nside, max_poll=True)
                zt = self._as_tensor_batch(zt)
                feats_down.append(zt.unsqueeze(1))
                next_ids_list = next_ids
            feats_down = torch.cat(feats_down, dim=1)  # (B, T_in, Cout, N_{i+1})

            l_hist.append(feats_down)
            l_ids.append(self.f.backend.to_numpy(next_ids_list))
            l_data_hist = feats_down
            current_nside //= 2

        # Temporal encoder on skips (levels 0..token_down-1): each pixel's
        # T_in-step history is encoded, then mean-pooled over time.
        skips: List[torch.Tensor] = []
        for i in range(self.token_down):
            Bx, Tx, Cx, Nx = l_hist[i].shape
            z = l_hist[i].permute(0, 3, 1, 2).reshape(Bx*Nx, Tx, Cx)
            z = self.temporal_encoders[i](z)
            z = z.mean(dim=1)
            H_i = z.view(Bx, Nx, Cx).permute(0, 2, 1).contiguous()
            skips.append(H_i)

        # Token-level transformer (spatial attention over coarse pixels).
        x_tok_hist = l_hist[-1] # (B, T_in, E, Ntok)
        x_tok = x_tok_hist.mean(dim=1) # (B, E, Ntok) (could add temporal encoder here as well)
        seq = x_tok.permute(0, 2, 1) + self.pos_token[:, :x_tok.shape[2], :]
        seq = self.encoder_token(seq)
        y = seq.permute(0, 2, 1) # (B, E, Ntok)

        if self.task == "global":
            # Pool over tokens and project to the global output.
            g = seq.mean(dim=1)
            return self.global_head(g)

        # Decoder: Up + cross-attn fusion with the skip + double-conv refine.
        dec_idx = 0
        for i in range(self.token_down, 0, -1):
            coarse_ids = ids_chain[i]
            fine_ids = ids_chain[i-1]
            source_ns = self.in_nside // (2 ** i)
            fine_ns = self.in_nside // (2 ** (i-1))
            Cfine = self.level_dims[i-1]

            # Stencil operating at the finer target resolution.
            op_fine = self.hconv_head if fine_ns == self.in_nside else self.hconv_levels[self.nsides_levels.index(fine_ns)]

            y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
            y_up = self._as_tensor_batch(y_up if torch.is_tensor(y_up) else torch.as_tensor(y_up, device=self.runtime_device)) # (B, Cfine, N)

            # Cross-attention: upsampled features query the encoder skip.
            skip_i = skips[i-1] # (B, Cfine, N)
            q = self.dec_q[dec_idx](y_up.permute(0,2,1))
            k = self.dec_k[dec_idx](skip_i.permute(0,2,1))
            v = self.dec_v[dec_idx](skip_i.permute(0,2,1))
            if self.pos_embed_per_level:
                # Slice the full-sphere embedding down to the runtime tokens.
                pos = self.level_pos[dec_idx][:, :q.shape[1], :]
                q = q + pos; k = k + pos
            z, _ = self.dec_attn[dec_idx](q, k, v)
            z = self.dec_mlp[dec_idx](z)
            z = z.permute(0,2,1).contiguous() # (B, Cfine, N)

            # Refinement: double spherical conv at the finer resolution.
            z = op_fine.Convol_torch(z, self.dec_refine_w1[dec_idx], cell_ids=fine_ids)
            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
            z = self.dec_refine_bn1[dec_idx](z); z = F.gelu(z)
            z = op_fine.Convol_torch(z, self.dec_refine_w2[dec_idx], cell_ids=fine_ids)
            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
            z = self.dec_refine_bn2[dec_idx](z); z = F.gelu(z)

            y = z
            dec_idx += 1

        # Output head at the finest level (+ optional norm/activation).
        y = self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids_runtime)
        y = self._as_tensor_batch(y if torch.is_tensor(y) else torch.as_tensor(y, device=self.runtime_device))
        if self.task == "segmentation" and self.head_bn is not None:
            y = self.head_bn(y)
        if self.final_activation == "sigmoid":
            y = torch.sigmoid(y)
        elif self.final_activation == "softmax":
            y = torch.softmax(y, dim=1)
        return y
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
if __name__ == "__main__":
    # Quick self-check: tiny sphere (nside=4), random input, one forward pass.
    nside = 4
    n_pix = 12 * nside * nside
    ids = np.arange(n_pix, dtype=np.int64)

    batch, n_steps, n_chan = 2, 3, 4
    sample = torch.randn(batch, n_steps, n_chan, n_pix)

    model = HealpixViTSkip(
        in_nside=nside,
        n_chan_in=n_chan,
        level_dims=[64, 96, 128],
        depth_token=2,
        num_heads_token=4,
        cell_ids=ids,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        dropout=0.1,
    )
    model.eval()

    with torch.no_grad():
        out = model(sample)
    print("Output:", tuple(out.shape))
|