foscat-2025.9.5-py3-none-any.whl → foscat-2025.10.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,521 @@
# healpix_vit_varlevels.py
# HEALPix ViT with level-wise (variable) channel widths and U-Net-style spherical decoder
from __future__ import annotations
from typing import List, Optional, Literal, Union
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import foscat.scat_cov as sc
import foscat.SphericalStencil as ho


class HealpixViT(nn.Module):
    """
    HEALPix Vision Transformer (Foscat-based) with *variable channel widths per level*
    and a U-Net-like spherical decoder.

    Key idea
    --------
    - Encoder uses a list of channel dimensions `level_dims = [C_fine, C_l1, ..., C_token]`
      that evolve *with depth* (e.g., 128 -> 192 -> 256).
    - At each encoder level (before a Down()), we apply a spherical convolution that
      maps C_i -> C_{i+1}. Down() then reduces the HEALPix resolution by one level.
    - The Transformer runs on the token grid with embedding dim = C_token.
    - Decoder upsamples *one level at a time*; after each Up() it concatenates the
      upsampled features (C_{i+1}) with the corresponding skip (C_i) and applies
      a spherical convolution that fuses (C_{i+1} + C_i) -> C_i.
    - Final head maps C_fine -> out_channels at the finest grid.

    Shapes (dense tasks)
    --------------------
    Input : (B, Cin, Nfine)
    → patch-embed (Cin -> C_fine) at finest grid
    → for i in [0..L-1]: EncConv(C_i -> C_{i+1}) → Down()  (store skip_i = C_i at grid i)
    → tokens at grid L with dim C_token
    → Transformer on tokens (C_token)
    → for i in [L-1..0]: Up() to grid i → concat(skip_i, up) [C_i + C_{i+1}] → DecConv → C_i
    → Head: C_fine -> out_channels at finest grid

    Requirements
    ------------
    - len(level_dims) must equal token_down + 1, with:
        level_dims[0]  = channels at the finest grid after patch embedding
        level_dims[-1] = Transformer embedding dimension
    - Each value in level_dims must be divisible by G (number of gauges).
    - out_channels must be divisible by G.

    Parameters (main)
    -----------------
    in_nside     : input HEALPix nside (nested)
    n_chan_in    : input channels at finest grid (Cin)
    level_dims   : list of ints, channel width per level from fine to token
    depth        : number of Transformer encoder layers
    num_heads    : self-attention heads
    cell_ids     : finest-level nested indices (Nfine = 12*nside^2)
    task         : "regression" | "segmentation" | "global"
    out_channels : output channels for dense tasks
    KERNELSZ     : spherical kernel size for Foscat convolutions
    gauge_type   : "cosmo" | "phi"
    G            : number of gauges
    """

    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        level_dims: List[int],  # e.g., [128, 192, 256] (fine -> token)
        depth: int,
        num_heads: int,
        cell_ids: np.ndarray,
        task: Literal["regression", "segmentation", "global"] = "regression",
        out_channels: int = 1,
        mlp_ratio: float = 4.0,
        KERNELSZ: int = 3,
        gauge_type: Literal["cosmo", "phi"] = "cosmo",
        G: int = 1,
        prefer_foscat_gpu: bool = True,
        cls_token: bool = False,
        pos_embed: Literal["learned", "none"] = "learned",
        head_type: Literal["mean", "cls"] = "mean",
        dropout: float = 0.0,
        dtype: Literal["float32", "float64"] = "float32",
    ) -> None:
        super().__init__()

        # ---- config ----
        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.level_dims = list(level_dims)
        self.depth = int(depth)
        self.num_heads = int(num_heads)
        self.task = task
        self.out_channels = int(out_channels)
        self.mlp_ratio = float(mlp_ratio)
        self.KERNELSZ = int(KERNELSZ)
        self.gauge_type = gauge_type
        self.G = int(G)
        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
        self.cls_token_enabled = bool(cls_token)
        self.pos_embed_type = pos_embed
        self.head_type = head_type
        self.dropout = float(dropout)
        self.dtype = dtype

        if len(self.level_dims) < 1:
            raise ValueError("level_dims must have at least one element (fine level).")
        self.token_down = len(self.level_dims) - 1
        self.embed_dim = int(self.level_dims[-1])  # Transformer dim

        for d in self.level_dims:
            if d % self.G != 0:
                raise ValueError(f"Each level dim must be divisible by G={self.G}; got {d}.")
        if self.embed_dim % self.num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads.")

        if dtype == "float32":
            self.np_dtype = np.float32
            self.torch_dtype = torch.float32
        else:
            self.np_dtype = np.float64
            self.torch_dtype = torch.float32  # model weights stay fp32; float64 only affects NumPy buffers

        if cell_ids is None:
            raise ValueError("cell_ids (finest) must be provided (nested ordering).")
        self.cell_ids_fine = np.asarray(cell_ids)

        # Default activation
        if self.task == "segmentation":
            self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
        else:
            self.final_activation = "none"

        # Foscat wrapper
        self.f = sc.funct(KERNELSZ=self.KERNELSZ)

        # ---- Build operators per level (fine -> ... -> token) and compute ids ----
        self.hconv_levels: List[ho.SphericalStencil] = []
        self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
        current_nside = self.in_nside

        dummy = self.f.backend.bk_cast(
            np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype)
        )

        for _ in range(self.token_down):
            hc = ho.SphericalStencil(
                current_nside,
                self.KERNELSZ,
                n_gauges=self.G,
                gauge_type=self.gauge_type,
                cell_ids=self.level_cell_ids[-1],
                dtype=self.torch_dtype,
            )
            self.hconv_levels.append(hc)

            dummy, next_ids = hc.Down(
                dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True
            )
            self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
            current_nside //= 2

        self.token_nside = current_nside if self.token_down > 0 else self.in_nside
        self.token_cell_ids = self.level_cell_ids[-1]

        # Token and fine-level operators (for convenience)
        self.hconv_token = ho.SphericalStencil(
            self.token_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.token_cell_ids,
            dtype=self.torch_dtype,
        )
        self.hconv_head = ho.SphericalStencil(
            self.in_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.cell_ids_fine,
            dtype=self.torch_dtype,
        )

        # ---------------- Patch embedding (Cin -> C_fine) ----------------
        fine_dim = self.level_dims[0]
        fine_g = fine_dim // self.G
        self.patch_w = nn.Parameter(
            torch.empty(self.n_chan_in, fine_g, self.KERNELSZ * self.KERNELSZ)
        )
        nn.init.kaiming_uniform_(self.patch_w.view(self.n_chan_in * fine_g, -1), a=np.sqrt(5))
        self.patch_bn = nn.GroupNorm(num_groups=min(8, fine_dim if fine_dim > 1 else 1),
                                     num_channels=fine_dim)
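        # Editor's note: every convolution weight in this file uses the layout
        # (C_in, C_out // G, KERNELSZ**2); that Convol_torch expands the
        # per-gauge output across the G gauges is an assumption inferred
        # from this usage, not from foscat documentation.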

        # ---------------- Encoder convs per level (C_i -> C_{i+1}) ----------------
        self.enc_w: nn.ParameterList = nn.ParameterList()
        self.enc_bn: nn.ModuleList = nn.ModuleList()
        for i in range(self.token_down):
            Cin = self.level_dims[i]
            Cout = self.level_dims[i+1]
            Cout_g = Cout // self.G
            w = nn.Parameter(torch.empty(Cin, Cout_g, self.KERNELSZ * self.KERNELSZ))
            nn.init.kaiming_uniform_(w.view(Cin * Cout_g, -1), a=np.sqrt(5))
            self.enc_w.append(w)
            self.enc_bn.append(nn.GroupNorm(num_groups=min(8, Cout if Cout > 1 else 1),
                                            num_channels=Cout))

        # ---------------- Transformer at token grid ----------------
        self.n_tokens = int(self.token_cell_ids.shape[0])
        if self.cls_token_enabled:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
            n_pe = self.n_tokens + 1
        else:
            self.cls_token = None
            n_pe = self.n_tokens

        if self.pos_embed_type == "learned":
            self.pos_embed = nn.Parameter(torch.zeros(1, n_pe, self.embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.pos_embed = None

        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=int(self.embed_dim * self.mlp_ratio),
            dropout=self.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=self.depth)

        # Projection at token grid (keep C_token)
        self.token_proj = nn.Linear(self.embed_dim, self.embed_dim)
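        # Editor's note: the token sequence is short by construction; e.g.
        # in_nside=4 with token_down=2 gives token_nside=1, i.e.
        # n_tokens = 12 * 1 * 1 = 12 (13 with a cls token).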

        # ---------------- Decoder convs per level ( (C_{i+1}+C_i) -> C_i ) ----------------
        self.dec_w: nn.ParameterList = nn.ParameterList()
        self.dec_bn: nn.ModuleList = nn.ModuleList()
        for i in range(self.token_down - 1, -1, -1):
            # decoder proceeds from token level back to fine; we create weights in the same order
            Cin_fuse = self.level_dims[i+1] + self.level_dims[i]  # up + skip
            Cout = self.level_dims[i]
            Cout_g = Cout // self.G
            w = nn.Parameter(torch.empty(Cin_fuse, Cout_g, self.KERNELSZ * self.KERNELSZ))
            nn.init.kaiming_uniform_(w.view(Cin_fuse * Cout_g, -1), a=np.sqrt(5))
            self.dec_w.append(w)  # index 0 corresponds to up from token to level L-1
            self.dec_bn.append(nn.GroupNorm(num_groups=min(8, Cout if Cout > 1 else 1),
                                            num_channels=Cout))
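        # Worked example (editor's note): with level_dims=[64, 96, 128],
        # dec_w[0] fuses (128 + 96) -> 96 on the first Up from the token grid,
        # then dec_w[1] fuses (96 + 64) -> 64 at the finest grid.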

        # ---------------- Final head (C_fine -> out_channels) ----------------
        if self.task == "global":
            self.global_head = nn.Linear(self.embed_dim, self.out_channels)
        else:
            self.C_fine = self.level_dims[0]

            if self.out_channels % self.G != 0:
                raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
            out_g = self.out_channels // self.G
            # weight layout (C_in, C_out // G, K*K), consistent with patch_w/enc_w/dec_w
            self.head_w = nn.Parameter(torch.empty(self.C_fine, out_g, self.KERNELSZ * self.KERNELSZ))
            nn.init.kaiming_uniform_(self.head_w.view(self.C_fine * out_g, -1), a=np.sqrt(5))
            self.head_bn = (nn.GroupNorm(num_groups=min(8, self.out_channels if self.out_channels > 1 else 1),
                                         num_channels=self.out_channels)
                            if self.task == "segmentation" else None)

        # ---------------- Device probe ----------------
        pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.runtime_device = self._probe_and_set_runtime_device(pref)

    # ---------------- device helpers ----------------
    def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
        for name, val in list(vars(hc).items()):
            try:
                if torch.is_tensor(val):
                    setattr(hc, name, val.to(device))
                elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
                    setattr(hc, name, type(val)([v.to(device) for v in val]))
            except Exception:
                pass

    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        if preferred.type == "cuda" and self.prefer_foscat_gpu:
            try:
                super().to(preferred)
                for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
                    self._move_hc(hc, preferred)
                # dry run
                npix0 = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
                hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
                y_try = hc0.Convol_torch(x_try, self.patch_w, cell_ids=self.cell_ids_fine)
                _ = y_try.sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception as e:
                self._gpu_probe_error = repr(e)
        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu

    def _to_numpy_ids(self, ids):
        """Return ids as a NumPy array on CPU (handles torch.Tensor on CUDA)."""
        if torch.is_tensor(ids):
            return ids.detach().cpu().numpy()
        return np.asarray(ids)

    # ---------------- helpers ----------------
    def _as_tensor_batch(self, x):
        if isinstance(x, list):
            if len(x) == 1:
                t = x[0]
                return t.unsqueeze(0) if t.dim() == 2 else t
            raise ValueError("Variable-length list not supported here; pass a tensor.")
        return x

    # ---------------- forward ----------------
    def forward(
        self,
        x: torch.Tensor,
        runtime_ids: Optional[np.ndarray] = None,
    ) -> torch.Tensor:
        """
        x: (B, Cin, Nfine), nested ordering
        runtime_ids: optional fine-level ids to decode onto (defaults to training ids)
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("x must be a torch.Tensor")
        if x.dim() != 3:
            raise ValueError("Input must be (B, Cin, Npix)")
        if x.shape[1] != self.n_chan_in:
            raise ValueError(f"Expected {self.n_chan_in} channels, got {x.shape[1]}")
        if runtime_ids is not None:
            runtime_ids = self._to_numpy_ids(runtime_ids)

        x = x.to(self.runtime_device)

        # -------- Patch embedding Cin -> C_fine --------
        hc_fine0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
        z = hc_fine0.Convol_torch(x, self.patch_w, cell_ids=self.cell_ids_fine)  # (B, C_fine, Nfine)
        if not torch.is_tensor(z):
            z = torch.as_tensor(z, device=self.runtime_device)
        z = self._as_tensor_batch(z)
        z = self.patch_bn(z)
        z = F.gelu(z)

        # -------- Encoder path: for each level i: EncConv(C_i->C_{i+1}) then Down() --------
        skips: List[torch.Tensor] = []
        ids_list: List[np.ndarray] = []

        l_data = z
        l_cell_ids = self.cell_ids_fine if runtime_ids is None else np.asarray(runtime_ids)
        current_nside = self.in_nside

        for i, hc in enumerate(self.hconv_levels):
            # save skip BEFORE going down (channels = C_i, grid = current level)
            skips.append(self._as_tensor_batch(l_data))
            ids_list.append(self._to_numpy_ids(l_cell_ids))

            # conv to next channels C_{i+1} at same grid
            w_enc = self.enc_w[i]
            l_data = hc.Convol_torch(l_data, w_enc, cell_ids=l_cell_ids)  # (B, C_{i+1}, N_current)
            if not torch.is_tensor(l_data):
                l_data = torch.as_tensor(l_data, device=self.runtime_device)
            l_data = self._as_tensor_batch(l_data)
            l_data = self.enc_bn[i](l_data)
            l_data = F.gelu(l_data)

            # Down one level
            l_data, l_cell_ids = hc.Down(l_data, cell_ids=l_cell_ids, nside=current_nside, max_poll=True)
            l_data = self._as_tensor_batch(l_data)
            current_nside //= 2

        # We are now at token grid with channels = C_token
        x_tok = l_data  # (B, C_token, Ntok)
        token_ids = l_cell_ids  # ids at token level
        assert x_tok.shape[1] == self.embed_dim, "Token channels mismatch with embed_dim."

        # -------- Transformer on tokens --------
        seq = x_tok.permute(0, 2, 1)  # (B, Ntok, E)
        if self.cls_token_enabled:
            cls = self.cls_token.expand(seq.size(0), -1, -1)
            seq = torch.cat([cls, seq], dim=1)
        if self.pos_embed is not None:
            seq = seq + self.pos_embed[:, :seq.shape[1], :]

        seq = self.encoder(seq)  # (B, Ntok(+1), E)
        if self.cls_token_enabled:
            tokens = seq[:, 1:, :]  # drop cls for dense
        else:
            tokens = seq

        tok_feat = self.token_proj(tokens).permute(0, 2, 1)  # (B, C_token, Ntok)

        if self.task == "global":
            # use the trained global head (a Linear instantiated inside forward
            # would be re-initialized on every call and could never learn)
            if self.head_type == "cls" and self.cls_token_enabled:
                cls_vec = seq[:, 0, :]
                return self.global_head(cls_vec)
            return self.global_head(tokens.mean(dim=1))

        # -------- Build runtime id chain (fine -> ... -> token) --------
        fine_ids_runtime = self.cell_ids_fine if runtime_ids is None else np.asarray(runtime_ids)
        ids_chain = [np.asarray(fine_ids_runtime)]
        nside_tmp = self.in_nside
        _dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
        for hc in self.hconv_levels:
            _dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
            ids_chain.append(self.f.backend.to_numpy(_next))
            nside_tmp //= 2

        tok_ids_np = self._to_numpy_ids(token_ids)

        assert tok_feat.shape[-1] == tok_ids_np.shape[0], "Token count mismatch."
        assert np.array_equal(tok_ids_np, ids_chain[-1]), "Token ids mismatch with runtime chain."

        # list of nsides at each encoder level (fine -> ... -> pre-token)
        nsides_levels = [self.in_nside // (2 ** k) for k in range(self.token_down)]
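        # e.g. (editor's note): in_nside=4, token_down=2 -> nsides_levels == [4, 2];
        # the token grid itself sits one level lower, at nside=1.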

        # -------- Decoder: Up step-by-step with fusion conv --------
        y = tok_feat  # (B, C_token, Ntok)
        dec_idx = 0  # index in self.dec_w / self.dec_bn (built from token->fine order)
        for i in range(len(ids_chain)-1, 0, -1):
            coarse_ids = ids_chain[i]  # current y grid
            fine_ids = ids_chain[i-1]  # target grid
            source_ns = self.in_nside // (2 ** i)
            fine_ns = self.in_nside // (2 ** (i-1))

            # choose operator for the target fine level
            if fine_ns == self.in_nside:
                op_fine = self.hconv_head
            else:
                idx = nsides_levels.index(fine_ns)
                op_fine = self.hconv_levels[idx]

            # Up one level
            y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
            if not torch.is_tensor(y_up):
                y_up = torch.as_tensor(y_up, device=self.runtime_device)
            y_up = self._as_tensor_batch(y_up)  # (B, C_{i}, N_fine)

            # Skip at this level (channels = C_{i-1})
            skip_i = self._as_tensor_batch(skips[i-1]).to(y_up.device)
            assert np.array_equal(np.asarray(ids_list[i-1]), np.asarray(fine_ids)), "Skip ids misaligned."

            # Concat and fuse: (C_{i} + C_{i-1}) -> C_{i-1}
            y_cat = torch.cat([y_up, skip_i], dim=1)
            y = op_fine.Convol_torch(y_cat, self.dec_w[dec_idx], cell_ids=fine_ids)
            if not torch.is_tensor(y):
                y = torch.as_tensor(y, device=self.runtime_device)
            y = self._as_tensor_batch(y)
            y = self.dec_bn[dec_idx](y)
            y = F.gelu(y)
            if self.dropout > 0:
                y = F.dropout(y, p=self.dropout, training=self.training)
            dec_idx += 1

        # y is now (B, C_fine, Nfine)
        # -------- Final head to out_channels --------
        y = self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids_runtime)
        if not torch.is_tensor(y):
            y = torch.as_tensor(y, device=self.runtime_device)
        y = self._as_tensor_batch(y)
        if self.task == "segmentation" and self.head_bn is not None:
            y = self.head_bn(y)

        if self.final_activation == "sigmoid":
            y = torch.sigmoid(y)
        elif self.final_activation == "softmax":
            y = torch.softmax(y, dim=1)
        return y

    @torch.no_grad()
    def predict(self, x: Union[torch.Tensor, np.ndarray], batch_size: int = 8) -> torch.Tensor:
        self.eval()
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
        outs = []
        for i in range(0, x.shape[0], batch_size):
            xb = x[i : i + batch_size].to(self.runtime_device)
            outs.append(self.forward(xb))
        return torch.cat(outs, dim=0)
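

# ---------------------------------------------------------------------------
# Editor's sketch (illustration only, not part of the foscat package): one
# supervised training step for the dense regression task. The function name
# `_example_train_step` and the MSE objective are assumptions for demonstration;
# the optimizer choice is left to the caller.
def _example_train_step(model: HealpixViT,
                        x: torch.Tensor,
                        target: torch.Tensor,
                        optimizer: torch.optim.Optimizer) -> float:
    """Run one optimization step: (B, Cin, Nfine) -> (B, out_channels, Nfine)."""
    model.train()
    optimizer.zero_grad()
    pred = model(x)                  # dense prediction at the finest grid
    loss = F.mse_loss(pred, target)  # assumed objective for regression maps
    loss.backward()
    optimizer.step()
    return float(loss.item())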


# -------------------------- Smoke test --------------------------
if __name__ == "__main__":
    # nside=4 → Npix=192, 2 down levels → token_nside=1
    in_nside = 4
    npix = 12 * in_nside * in_nside
    cell_ids = np.arange(npix, dtype=np.int64)

    B, Cin = 2, 3
    x = torch.randn(B, Cin, npix)

    # Channel widths per level (fine -> token), divisible by G=1 here
    level_dims = [64, 96, 128]

    model = HealpixViT(  # instantiate the class defined above
        in_nside=in_nside,
        n_chan_in=Cin,
        level_dims=level_dims,  # len=3 => token_down=2
        depth=2,
        num_heads=4,
        cell_ids=cell_ids,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        cls_token=False,
        dropout=0.1,
    ).eval()

    with torch.no_grad():
        y = model(x)
    print("Output:", y.shape)  # (B, Cout, Nfine)