foscat-2025.9.5-py3-none-any.whl → foscat-2025.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,445 @@
1
+ # healpix_vit_skip.py
2
+ # HEALPix ViT U-Net with temporal encoders and Transformer-based skip fusion.
3
+ # - Multi-level HEALPix pyramid using Foscat.SphericalStencil
4
+ # - Per-level temporal encoding (sequence over T_in months) at encoder
5
+ # - Decoder uses cross-attention to fuse upsampled features with encoder skips
6
+ # - Double spherical convolution + GroupNorm + GELU at each encoder/decoder level
7
+
8
+ from __future__ import annotations
9
+ from typing import List, Optional, Literal
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ import foscat.scat_cov as sc
17
+ import foscat.SphericalStencil as ho
18
+
19
+
20
class MLP(nn.Module):
    """Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    The hidden width is ``hidden_mult * d``; dropout is applied after the
    activation and after the output projection. Input and output both have
    feature size ``d`` on the last dimension.
    """

    def __init__(self, d: int, hidden_mult: int = 4, drop: float = 0.0):
        super().__init__()
        hidden = hidden_mult * d
        layers = [
            nn.LayerNorm(d),
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, d),
            nn.Dropout(drop),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; shape is preserved."""
        return self.net(x)
33
+
34
+
35
+ class HealpixViTSkip(nn.Module):
36
+ def __init__(
37
+ self,
38
+ *,
39
+ in_nside: int,
40
+ n_chan_in: int,
41
+ level_dims: List[int],
42
+ depth_token: int,
43
+ num_heads_token: int,
44
+ cell_ids: np.ndarray,
45
+ task: Literal["regression","segmentation","global"] = "regression",
46
+ out_channels: int = 1,
47
+ mlp_ratio_token: float = 4.0,
48
+ KERNELSZ: int = 3,
49
+ gauge_type: Literal["cosmo","phi"] = "cosmo",
50
+ G: int = 1,
51
+ prefer_foscat_gpu: bool = True,
52
+ dropout: float = 0.1,
53
+ dtype: Literal["float32","float64"] = "float32",
54
+ pos_embed_per_level: bool = True,
55
+ ) -> None:
56
+ super().__init__()
57
+
58
+ self.in_nside = int(in_nside)
59
+ self.n_chan_in = int(n_chan_in)
60
+ self.level_dims = list(level_dims)
61
+ self.token_down = len(self.level_dims) - 1
62
+ assert self.token_down >= 0
63
+ self.C_fine = int(self.level_dims[0])
64
+ self.embed_dim = int(self.level_dims[-1])
65
+ self.depth_token = int(depth_token)
66
+ self.num_heads_token = int(num_heads_token)
67
+ self.mlp_ratio_token = float(mlp_ratio_token)
68
+ self.task = task
69
+ self.out_channels = int(out_channels)
70
+ self.KERNELSZ = int(KERNELSZ)
71
+ self.gauge_type = gauge_type
72
+ self.G = int(G)
73
+ self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
74
+ self.dropout = float(dropout)
75
+ self.dtype = dtype
76
+ self.pos_embed_per_level = bool(pos_embed_per_level)
77
+
78
+ for d in self.level_dims:
79
+ if d % self.G != 0:
80
+ raise ValueError(f"All level_dims must be divisible by G={self.G}, got {d}.")
81
+ if self.embed_dim % self.num_heads_token != 0:
82
+ raise ValueError("embed_dim must be divisible by num_heads_token.")
83
+
84
+ if dtype == "float32":
85
+ self.np_dtype = np.float32
86
+ self.torch_dtype = torch.float32
87
+ else:
88
+ self.np_dtype = np.float64
89
+ self.torch_dtype = torch.float32
90
+
91
+ if cell_ids is None:
92
+ raise ValueError("cell_ids (finest) must be provided.")
93
+ self.cell_ids_fine = np.asarray(cell_ids)
94
+
95
+ if self.task == "segmentation":
96
+ self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
97
+ else:
98
+ self.final_activation = "none"
99
+
100
+ self.f = sc.funct(KERNELSZ=self.KERNELSZ)
101
+
102
+ # Build stencils
103
+ self.hconv_levels: List[ho.SphericalStencil] = []
104
+ self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
105
+ current_nside = self.in_nside
106
+ dummy = self.f.backend.bk_cast(np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype))
107
+ for _ in range(self.token_down):
108
+ hc = ho.SphericalStencil(current_nside, self.KERNELSZ, n_gauges=self.G,
109
+ gauge_type=self.gauge_type, cell_ids=self.level_cell_ids[-1],
110
+ dtype=self.torch_dtype)
111
+ self.hconv_levels.append(hc)
112
+ dummy, next_ids = hc.Down(dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True)
113
+ self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
114
+ current_nside //= 2
115
+
116
+ self.token_nside = current_nside if self.token_down > 0 else self.in_nside
117
+ self.token_cell_ids = self.level_cell_ids[-1]
118
+
119
+ self.hconv_token = ho.SphericalStencil(self.token_nside, self.KERNELSZ, n_gauges=self.G,
120
+ gauge_type=self.gauge_type, cell_ids=self.token_cell_ids, dtype=self.torch_dtype)
121
+ self.hconv_head = ho.SphericalStencil(self.in_nside, self.KERNELSZ, n_gauges=self.G,
122
+ gauge_type=self.gauge_type, cell_ids=self.cell_ids_fine, dtype=self.torch_dtype)
123
+
124
+ self.nsides_levels = [self.in_nside // (2**i) for i in range(self.token_down+1)]
125
+ self.ntokens_levels = [12 * n * n for n in self.nsides_levels]
126
+
127
+ # Patch embed (double conv)
128
+ fine_g = self.C_fine // self.G
129
+ self.pe_w1 = nn.Parameter(torch.empty(self.n_chan_in, fine_g, self.KERNELSZ*self.KERNELSZ))
130
+ nn.init.kaiming_uniform_(self.pe_w1.view(self.n_chan_in * fine_g, -1), a=np.sqrt(5))
131
+ self.pe_w2 = nn.Parameter(torch.empty(self.C_fine, fine_g, self.KERNELSZ*self.KERNELSZ))
132
+ nn.init.kaiming_uniform_(self.pe_w2.view(self.C_fine * fine_g, -1), a=np.sqrt(5))
133
+ self.pe_bn1 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
134
+ self.pe_bn2 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
135
+
136
+ # Encoder double convs
137
+ self.enc_w1 = nn.ParameterList()
138
+ self.enc_w2 = nn.ParameterList()
139
+ self.enc_bn1 = nn.ModuleList()
140
+ self.enc_bn2 = nn.ModuleList()
141
+ for i in range(self.token_down):
142
+ Cin = self.level_dims[i]
143
+ Cout = self.level_dims[i+1]
144
+ Cout_g = Cout // self.G
145
+ w1 = nn.Parameter(torch.empty(Cin, Cout_g, self.KERNELSZ*self.KERNELSZ))
146
+ nn.init.kaiming_uniform_(w1.view(Cin * Cout_g, -1), a=np.sqrt(5))
147
+ w2 = nn.Parameter(torch.empty(Cout, Cout_g, self.KERNELSZ*self.KERNELSZ))
148
+ nn.init.kaiming_uniform_(w2.view(Cout * Cout_g, -1), a=np.sqrt(5))
149
+ self.enc_w1.append(w1); self.enc_w2.append(w2)
150
+ self.enc_bn1.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
151
+ self.enc_bn2.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
152
+
153
+ # Temporal encoders per level (fine..pre-token)
154
+ self.temporal_encoders = nn.ModuleList([
155
+ nn.TransformerEncoder(
156
+ nn.TransformerEncoderLayer(
157
+ d_model=self.level_dims[i],
158
+ nhead=max(1, min(8, self.level_dims[i] // 64)),
159
+ dim_feedforward=2*self.level_dims[i],
160
+ dropout=self.dropout,
161
+ activation='gelu',
162
+ batch_first=True,
163
+ norm_first=True,
164
+ ),
165
+ num_layers=2,
166
+ )
167
+ for i in range(self.token_down)
168
+ ])
169
+
170
+ # Token-level Transformer
171
+ self.n_tokens = int(self.token_cell_ids.shape[0])
172
+ self.pos_token = nn.Parameter(torch.zeros(1, self.n_tokens, self.embed_dim))
173
+ nn.init.trunc_normal_(self.pos_token, std=0.02)
174
+ enc_layer = nn.TransformerEncoderLayer(
175
+ d_model=self.embed_dim,
176
+ nhead=self.num_heads_token,
177
+ dim_feedforward=int(self.embed_dim * self.mlp_ratio_token),
178
+ dropout=self.dropout,
179
+ activation='gelu',
180
+ batch_first=True,
181
+ norm_first=True,
182
+ )
183
+ self.encoder_token = nn.TransformerEncoder(enc_layer, num_layers=self.depth_token)
184
+
185
+ # Decoder fusion modules per level (cross-attention)
186
+ self.dec_q = nn.ModuleList()
187
+ self.dec_k = nn.ModuleList()
188
+ self.dec_v = nn.ModuleList()
189
+ self.dec_attn = nn.ModuleList()
190
+ self.dec_mlp = nn.ModuleList()
191
+ self.level_pos = nn.ParameterList() if self.pos_embed_per_level else None
192
+ for i in range(self.token_down, 0, -1):
193
+ Cfine = self.level_dims[i-1]
194
+ d_fuse = Cfine
195
+ self.dec_q.append(nn.Linear(Cfine, d_fuse))
196
+ self.dec_k.append(nn.Linear(Cfine, d_fuse))
197
+ self.dec_v.append(nn.Linear(Cfine, d_fuse))
198
+ self.dec_attn.append(nn.MultiheadAttention(embed_dim=d_fuse, num_heads=max(1, min(8, d_fuse // 64)), batch_first=True))
199
+ self.dec_mlp.append(MLP(d_fuse, hidden_mult=4, drop=self.dropout))
200
+ if self.pos_embed_per_level:
201
+ n_tok_i = self.ntokens_levels[i-1]
202
+ p = nn.Parameter(torch.zeros(1, n_tok_i, d_fuse))
203
+ nn.init.trunc_normal_(p, std=0.02)
204
+ self.level_pos.append(p)
205
+
206
+ # Decoder refinement double convs
207
+ self.dec_refine_w1 = nn.ParameterList()
208
+ self.dec_refine_w2 = nn.ParameterList()
209
+ self.dec_refine_bn1 = nn.ModuleList()
210
+ self.dec_refine_bn2 = nn.ModuleList()
211
+ for i in range(self.token_down, 0, -1):
212
+ Cfine = self.level_dims[i-1]
213
+ Cfine_g = Cfine // self.G
214
+ w1 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
215
+ nn.init.kaiming_uniform_(w1.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
216
+ w2 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
217
+ nn.init.kaiming_uniform_(w2.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
218
+ self.dec_refine_w1.append(w1); self.dec_refine_w2.append(w2)
219
+ self.dec_refine_bn1.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
220
+ self.dec_refine_bn2.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
221
+
222
+ # Head
223
+ if self.task == "global":
224
+ self.global_head = nn.Linear(self.embed_dim, self.out_channels)
225
+ else:
226
+ if self.out_channels % self.G != 0:
227
+ raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
228
+ out_g = self.out_channels // self.G
229
+ self.head_w = nn.Parameter(torch.empty(self.C_fine, out_g, self.KERNELSZ*self.KERNELSZ))
230
+ nn.init.kaiming_uniform_(self.head_w.view(self.C_fine * out_g, -1), a=np.sqrt(5))
231
+ self.head_bn = nn.GroupNorm(num_groups=min(8, self.out_channels if self.out_channels>1 else 1),
232
+ num_channels=self.out_channels) if self.task=="segmentation" else None
233
+
234
+ pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
235
+ self.runtime_device = self._probe_and_set_runtime_device(pref)
236
+
237
+ def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
238
+ for name, val in list(vars(hc).items()):
239
+ try:
240
+ if torch.is_tensor(val):
241
+ setattr(hc, name, val.to(device))
242
+ elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
243
+ setattr(hc, name, type(val)([v.to(device) for v in val]))
244
+ except Exception:
245
+ pass
246
+
247
    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Pick the device the model (and foscat stencils) will actually run on.

        If *preferred* is CUDA, move the module and all stencils there and run
        one tiny spherical convolution as a probe; on any failure fall back to
        CPU. The chosen device is stored in ``self._foscat_device`` and
        returned.
        """
        if preferred.type == "cuda":
            try:
                super().to(preferred)
                # Foscat stencils hold raw tensors, so move them explicitly.
                for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
                    self._move_hc(hc, preferred)
                # Probe: one convolution at the finest level must succeed end-to-end.
                npix0 = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
                hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
                y_try = hc0.Convol_torch(x_try, self.pe_w1, cell_ids=self.cell_ids_fine)
                # .item() forces synchronization so a latent CUDA error surfaces here.
                _ = (y_try if torch.is_tensor(y_try) else torch.as_tensor(y_try, device=preferred)).sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception:
                # Any failure (OOM, unsupported op, ...) => fall through to CPU.
                pass
        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu
269
+
270
+ def _as_tensor_batch(self, x):
271
+ if isinstance(x, list):
272
+ if len(x) == 1:
273
+ t = x[0]
274
+ return t.unsqueeze(0) if t.dim() == 2 else t
275
+ raise ValueError("Variable-length list not supported here; pass a tensor.")
276
+ return x
277
+
278
+ def _to_numpy_ids(self, ids):
279
+ if torch.is_tensor(ids):
280
+ return ids.detach().cpu().numpy()
281
+ return np.asarray(ids)
282
+
283
+ def _patch_embed_fine(self, x_t: torch.Tensor) -> torch.Tensor:
284
+ hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
285
+ z = hc0.Convol_torch(x_t, self.pe_w1, cell_ids=self.cell_ids_fine)
286
+ z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
287
+ z = self.pe_bn1(z); z = F.gelu(z)
288
+ z = hc0.Convol_torch(z, self.pe_w2, cell_ids=self.cell_ids_fine)
289
+ z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
290
+ z = self.pe_bn2(z); z = F.gelu(z)
291
+ return z
292
+
293
    def forward(self, x: torch.Tensor, runtime_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        """Run the encoder pyramid, token Transformer and decoder.

        Parameters
        ----------
        x : tensor of shape (B, T_in, C_in, Npix) — a sequence of T_in maps
            at the finest HEALPix level.
        runtime_ids : optional override for the finest-level cell ids
            (defaults to the ids supplied at construction).

        Returns
        -------
        For task "global": (B, out_channels). Otherwise a map whose leading
        dims are (B, out_channels, ...) at the finest level (exact trailing
        shape comes from the final foscat convolution — confirm against
        Convol_torch).
        """
        if x.dim() != 4:
            raise ValueError("Expected input shape (B, T_in, C_in, Npix)")
        B, T_in, C_in, Nf = x.shape
        if C_in != self.n_chan_in:
            raise ValueError(f"Expected n_chan_in={self.n_chan_in}, got {C_in}")
        x = x.to(self.runtime_device)

        # Recompute the per-level cell-id chain for the runtime ids by running
        # a dummy map through the same Down() cascade used at construction.
        fine_ids_runtime = self.cell_ids_fine if runtime_ids is None else self._to_numpy_ids(runtime_ids)
        ids_chain = [np.asarray(fine_ids_runtime)]
        nside_tmp = self.in_nside
        _dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
        for hc in self.hconv_levels:
            _dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
            ids_chain.append(self.f.backend.to_numpy(_next))
            nside_tmp //= 2

        # Encoder histories per level
        l_hist: List[torch.Tensor] = []
        l_ids: List[np.ndarray] = []

        # Level 0: patch-embed each time step independently.
        feats_fine = []
        for t in range(T_in):
            zt = self._patch_embed_fine(x[:, t, :, :])
            feats_fine.append(zt.unsqueeze(1))
        feats_fine = torch.cat(feats_fine, dim=1)  # (B, T_in, C_fine, N_fine)
        l_hist.append(feats_fine)
        l_ids.append(self.cell_ids_fine)

        # Encoder: per level, double conv on each time step, then Down-pool.
        current_nside = self.in_nside
        l_data_hist = feats_fine
        for i, hc in enumerate(self.hconv_levels):
            Cin = self.level_dims[i]
            Cout = self.level_dims[i+1]
            w1, w2 = self.enc_w1[i], self.enc_w2[i]
            feats_next = []
            for t in range(T_in):
                zt = l_data_hist[:, t, :, :]
                zt = hc.Convol_torch(zt, w1, cell_ids=l_ids[-1])
                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
                zt = self.enc_bn1[i](zt); zt = F.gelu(zt)
                zt = hc.Convol_torch(zt, w2, cell_ids=l_ids[-1])
                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
                zt = self.enc_bn2[i](zt); zt = F.gelu(zt)
                feats_next.append(zt.unsqueeze(1))
            feats_next = torch.cat(feats_next, dim=1)  # (B, T_in, Cout, N_i)

            # Down-pool each time step; all steps yield the same next_ids,
            # so keeping only the last one is sufficient.
            feats_down = []
            next_ids_list = None
            for t in range(T_in):
                zt, next_ids = hc.Down(feats_next[:, t, :, :], cell_ids=l_ids[-1], nside=current_nside, max_poll=True)
                zt = self._as_tensor_batch(zt)
                feats_down.append(zt.unsqueeze(1))
                next_ids_list = next_ids
            feats_down = torch.cat(feats_down, dim=1)  # (B, T_in, Cout, N_{i+1})

            l_hist.append(feats_down)
            l_ids.append(self.f.backend.to_numpy(next_ids_list))
            l_data_hist = feats_down
            current_nside //= 2

        # Temporal encoder on skips (levels 0..token_down-1): each pixel's
        # T_in-long feature sequence is encoded, then mean-pooled over time.
        skips: List[torch.Tensor] = []
        for i in range(self.token_down):
            Bx, Tx, Cx, Nx = l_hist[i].shape
            z = l_hist[i].permute(0, 3, 1, 2).reshape(Bx*Nx, Tx, Cx)
            z = self.temporal_encoders[i](z)
            z = z.mean(dim=1)
            H_i = z.view(Bx, Nx, Cx).permute(0, 2, 1).contiguous()
            skips.append(H_i)

        # Token-level transformer (spatial)
        x_tok_hist = l_hist[-1]  # (B, T_in, E, Ntok)
        x_tok = x_tok_hist.mean(dim=1)  # (B, E, Ntok) (could add temporal encoder here as well)
        seq = x_tok.permute(0, 2, 1) + self.pos_token[:, :x_tok.shape[2], :]
        seq = self.encoder_token(seq)
        y = seq.permute(0, 2, 1)  # (B, E, Ntok)

        # "global" task: mean-pool tokens and classify/regress directly.
        if self.task == "global":
            g = seq.mean(dim=1)
            return self.global_head(g)

        # Decoder: Up + cross-attn fusion + double conv refinement
        dec_idx = 0
        for i in range(self.token_down, 0, -1):
            coarse_ids = ids_chain[i]
            fine_ids = ids_chain[i-1]
            source_ns = self.in_nside // (2 ** i)
            fine_ns = self.in_nside // (2 ** (i-1))
            Cfine = self.level_dims[i-1]

            # Pick the stencil that operates at the target (finer) resolution.
            op_fine = self.hconv_head if fine_ns == self.in_nside else self.hconv_levels[self.nsides_levels.index(fine_ns)]

            y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
            y_up = self._as_tensor_batch(y_up if torch.is_tensor(y_up) else torch.as_tensor(y_up, device=self.runtime_device))  # (B, Cfine, N)

            # Cross-attention: queries from the upsampled path, keys/values
            # from the temporally-encoded skip at the same level.
            skip_i = skips[i-1]  # (B, Cfine, N)
            q = self.dec_q[dec_idx](y_up.permute(0,2,1))
            k = self.dec_k[dec_idx](skip_i.permute(0,2,1))
            v = self.dec_v[dec_idx](skip_i.permute(0,2,1))
            if self.pos_embed_per_level:
                # Shared positional embedding added to q and k (not v).
                pos = self.level_pos[dec_idx][:, :q.shape[1], :]
                q = q + pos; k = k + pos
            z, _ = self.dec_attn[dec_idx](q, k, v)
            z = self.dec_mlp[dec_idx](z)
            z = z.permute(0,2,1).contiguous()  # (B, Cfine, N)

            # Refinement: double spherical conv with GroupNorm + GELU.
            z = op_fine.Convol_torch(z, self.dec_refine_w1[dec_idx], cell_ids=fine_ids)
            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
            z = self.dec_refine_bn1[dec_idx](z); z = F.gelu(z)
            z = op_fine.Convol_torch(z, self.dec_refine_w2[dec_idx], cell_ids=fine_ids)
            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
            z = self.dec_refine_bn2[dec_idx](z); z = F.gelu(z)

            y = z
            dec_idx += 1

        # Output head at the finest level, plus optional segmentation squashing.
        y = self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids_runtime)
        y = self._as_tensor_batch(y if torch.is_tensor(y) else torch.as_tensor(y, device=self.runtime_device))
        if self.task == "segmentation" and self.head_bn is not None:
            y = self.head_bn(y)
        # final_activation is "none" outside segmentation, so these are no-ops then.
        if self.final_activation == "sigmoid":
            y = torch.sigmoid(y)
        elif self.final_activation == "softmax":
            y = torch.softmax(y, dim=1)
        return y
419
+
420
+
421
if __name__ == "__main__":
    # Smoke test on a full-sky HEALPix grid at nside = 4.
    nside = 4
    n_pixels = 12 * nside * nside
    pixel_ids = np.arange(n_pixels, dtype=np.int64)

    batch, n_steps, n_channels = 2, 3, 4
    sample = torch.randn(batch, n_steps, n_channels, n_pixels)

    net = HealpixViTSkip(
        in_nside=nside,
        n_chan_in=n_channels,
        level_dims=[64, 96, 128],
        depth_token=2,
        num_heads_token=4,
        cell_ids=pixel_ids,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        dropout=0.1,
    ).eval()

    with torch.no_grad():
        prediction = net(sample)
    print("Output:", tuple(prediction.shape))