foscat 2025.9.4__py3-none-any.whl → 2025.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,658 @@
1
+ """
2
+ HealpixViT — Vision Transformer on HEALPix with Foscat
3
+ ======================================================
4
+
5
+ This module provides a **Vision Transformer (ViT)** adapted to spherical data laid out on a
6
+ **HEALPix nested grid**. It integrates **Foscat**'s `SphericalStencil` operators to perform
7
+ spherical convolutions for patch embedding, hierarchical **Down/Up** between HEALPix levels,
8
+ and an optional per-pixel spherical head after the Transformer encoder.
9
+
10
+ Why this design?
11
+ ----------------
12
+ - HEALPix provides a hierarchical, equal-area tessellation of the sphere. In **nested** ordering,
13
+ each pixel at level \(L\) has 4 children at level \(L+1\). This makes **tokenization** natural:
14
+ we can repeatedly call `Down()` to move to a coarser grid that serves as the **token grid** (a worked example of this arithmetic follows below).
15
+ - A Transformer encoder then operates on the **sequence of tokens**. For dense outputs, we map the
16
+ token features back to the finest grid with `Up()` and refine with a spherical convolution head.
17
+ - Because we reuse the same Foscat operators as in a HEALPix U-Net, we preserve consistency with
18
+ existing spherical CNN pipelines while gaining the long-range modeling capacity of Transformers.
19
+
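+ Token-grid arithmetic (illustrative)
+ ------------------------------------
+ A short sketch of the pixel/token bookkeeping implied by `token_down`; the numbers below are an
+ assumed example, not defaults of this module::
+
+     in_nside    = 64                          # finest grid: Npix = 12 * 64**2 = 49152 pixels
+     token_down  = 2                           # two Down() steps
+     token_nside = in_nside // 2**token_down   # -> 16
+     n_tokens    = 12 * token_nside**2         # -> 3072 tokens seen by the Transformer encoder
+
+ Each `Down()` step merges the 4 nested children of a parent pixel, so the sequence length shrinks
+ by a factor of 4 per step while the channel count stays at `embed_dim`.
+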
20
+ Typical use cases
21
+ -----------------
22
+ - **Global regression/classification** (e.g., predicting a climate index from full-sky fields).
23
+ - **Dense regression/segmentation** (e.g., SST anomaly prediction, cloud/ice masks) directly on
24
+ HEALPix maps, including **multi-resolution fusion** thanks to nested Down/Up.
25
+
26
+ Notes on `cell_ids`
27
+ -------------------
28
+ - This implementation supports passing **runtime `cell_ids`** to `forward(...)` to match your
29
+ data pipeline (e.g., when per-sample IDs are managed externally). If omitted, it uses the
30
+ `cell_ids` provided at construction.
31
+ - All IDs are assumed to be **nested** and **int64**, with range `[0, 12*nside^2 - 1]` at each level.
32
+ Sanity checks are included to prevent HEALPix `pix2loc` errors.
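+
+ Illustrative construction of full-sky finest-level ids (the `nside` value here is an assumption)::
+
+     import numpy as np
+     in_nside = 16
+     cell_ids = np.arange(12 * in_nside**2, dtype=np.int64)   # nested ids 0 .. 12*nside**2 - 1
+
+ A partial-sky layout would instead pass only the nested ids actually covered by the data; the same
+ dtype and range rules apply.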
33
+ """
34
+
35
+ from __future__ import annotations
36
+ from typing import List, Optional, Literal, Tuple, Union
37
+ import numpy as np
38
+
39
+ import torch
40
+ import torch.nn as nn
41
+ import torch.nn.functional as F
42
+
43
+ import foscat.scat_cov as sc
44
+ import foscat.SphericalStencil as ho
45
+
46
+ # -----------------------------------------------------------------------------
47
+ # Helper: safe type alias
48
+ # -----------------------------------------------------------------------------
49
+ ArrayLikeI64 = Union[np.ndarray, torch.Tensor]
50
+
51
+
52
+ class HealpixViT(nn.Module):
53
+ """Vision Transformer on the HEALPix sphere using Foscat-oriented ops.
54
+
55
+ Parameters
56
+ ----------
57
+ in_nside : int
58
+ Input HEALPix nside at the **finest** level (nested ordering). The number of pixels is
59
+ `Npix = 12 * in_nside**2`.
60
+ n_chan_in : int
61
+ Number of input channels at the finest grid.
62
+ embed_dim : int
63
+ Transformer embedding dimension (also the channel count after patch embedding).
64
+ depth : int
65
+ Number of Transformer encoder layers.
66
+ num_heads : int
67
+ Number of attention heads per layer.
68
+ cell_ids : np.ndarray
69
+ Finest-level **nested** cell indices (shape `[Npix]`, dtype `int64`). These define the
70
+ pixel layout of your input features.
71
+ mlp_ratio : float, default=4.0
72
+ Expansion ratio for the MLP inside each Transformer block.
73
+ token_down : int, default=2
74
+ Number of `Down()` steps to reach the token grid. The token nside is
75
+ `token_nside = in_nside // (2**token_down)`.
76
+ task : {"regression","segmentation","global"}, default="regression"
77
+ - "global": return a vector (pooled tokens → `out_channels`).
78
+ - "regression"/"segmentation": return per-pixel predictions on the finest grid.
79
+ out_channels : int, default=1
80
+ Number of output channels: per-pixel channels for dense tasks, and the length of the pooled output vector for `task="global"`.
81
+ final_activation : {"none","sigmoid","softmax"} | None
82
+ Optional activation applied to the output. If `None`, a default is chosen per task: "none" for regression and global tasks, "sigmoid" (1 channel) or "softmax" (multi-channel) for segmentation.
83
+ KERNELSZ : int, default=3
84
+ Spatial kernel size for spherical convolutions (Foscat oriented conv).
85
+ gauge_type : {"cosmo","phi"}, default="cosmo"
86
+ Orientation/gauge definition in `SphericalStencil`.
87
+ G : int, default=1
88
+ Number of gauges (internal orientation multiplicity). `embed_dim` must be divisible by `G`.
89
+ prefer_foscat_gpu : bool, default=True
90
+ Try Foscat on CUDA if available; fall back to CPU otherwise.
91
+ cls_token : bool, default=False
92
+ Include a `[CLS]` token for global tasks.
93
+ pos_embed : {"learned","none"}, default="learned"
94
+ Positional encoding type for tokens.
95
+ head_type : {"mean","cls"}, default="mean"
96
+ Pooling strategy for global tasks (mean over tokens or CLS vector).
97
+ dtype : {"float32","float64"}, default="float32"
98
+ Numpy dtype used for internal Foscat buffers. Model parameters remain `float32`.
99
+
100
+ Input/Output shapes
101
+ -------------------
102
+ Input: `(B, C_in, Npix)` with `Npix = 12 * in_nside**2`.
103
+ Output: - global task: `(B, out_channels)`
104
+ - dense task: `(B, out_channels, Npix)`
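+
+ Examples
+ --------
+ Illustrative only; the hyperparameter values below are arbitrary assumptions, not recommended
+ settings.
+
+ >>> import numpy as np, torch
+ >>> nside = 8
+ >>> ids = np.arange(12 * nside**2, dtype=np.int64)
+ >>> model = HealpixViT(in_nside=nside, n_chan_in=2, embed_dim=32, depth=1,
+ ...                    num_heads=4, cell_ids=ids, token_down=1, task="regression")
+ >>> y = model(torch.randn(4, 2, 12 * nside**2))   # dense output, shape (4, 1, 768)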
105
+ """
106
+
107
+ # ------------------------------------------------------------------
108
+ # Construction
109
+ # ------------------------------------------------------------------
110
+ def __init__(
111
+ self,
112
+ *,
113
+ in_nside: int,
114
+ n_chan_in: int,
115
+ embed_dim: int,
116
+ depth: int,
117
+ num_heads: int,
118
+ cell_ids: np.ndarray,
119
+ mlp_ratio: float = 4.0,
120
+ token_down: int = 2,
121
+ task: Literal["regression","segmentation","global"] = "regression",
122
+ out_channels: int = 1,
123
+ final_activation: Optional[Literal["none","sigmoid","softmax"]] = None,
124
+ KERNELSZ: int = 3,
125
+ gauge_type: Optional[Literal["cosmo","phi"]] = "cosmo",
126
+ G: int = 1,
127
+ prefer_foscat_gpu: bool = True,
128
+ cls_token: bool = False,
129
+ pos_embed: Literal["learned","none"] = "learned",
130
+ head_type: Literal["mean","cls"] = "mean",
131
+ dtype: Literal["float32","float64"] = "float32",
132
+ ) -> None:
133
+ super().__init__()
134
+
135
+ # ------------------- store config & dtypes -------------------
136
+ self.in_nside = int(in_nside)
137
+ self.n_chan_in = int(n_chan_in)
138
+ self.embed_dim = int(embed_dim)
139
+ self.depth = int(depth)
140
+ self.num_heads = int(num_heads)
141
+ self.mlp_ratio = float(mlp_ratio)
142
+ self.token_down = int(token_down)
143
+ self.task = task
144
+ self.out_channels = int(out_channels)
145
+ self.KERNELSZ = int(KERNELSZ)
146
+ self.gauge_type = gauge_type
147
+ self.G = int(G)
148
+ self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
149
+ self.cls_token_enabled = bool(cls_token)
150
+ self.pos_embed_type = pos_embed
151
+ self.head_type = head_type
152
+
153
+ if dtype == "float32":
154
+ self.np_dtype = np.float32
155
+ self.torch_dtype = torch.float32
156
+ else:
157
+ self.np_dtype = np.float64
158
+ self.torch_dtype = torch.float32 # keep params in fp32
159
+
160
+ # ------------------- validate inputs -------------------
161
+ if cell_ids is None:
162
+ raise ValueError("cell_ids (finest) must be provided.")
163
+ self.cell_ids_fine = np.asarray(cell_ids)
164
+ self._check_ids(self.cell_ids_fine, self.in_nside, name="cell_ids_fine")
165
+
166
+ if self.G < 1:
167
+ raise ValueError("G must be >= 1")
168
+ if self.embed_dim % self.G != 0:
169
+ raise ValueError(f"embed_dim={self.embed_dim} must be divisible by G={self.G}")
170
+ if self.task not in {"regression", "segmentation", "global"}:
171
+ raise ValueError("task must be 'regression', 'segmentation', or 'global'")
172
+
173
+ # Default final activation per task if not specified
174
+ if final_activation is None:
175
+ if self.task == "regression":
176
+ self.final_activation = "none"
177
+ elif self.task == "segmentation":
178
+ self.final_activation = "sigmoid" if out_channels == 1 else "softmax"
179
+ else:
180
+ self.final_activation = "none"
181
+ else:
182
+ self.final_activation = final_activation
183
+
184
+ # ------------------- foscat functional wrapper -------------------
185
+ self.f = sc.funct(KERNELSZ=self.KERNELSZ)
186
+
187
+ # ------------------- build hierarchy (fine → coarse) -------------------
188
+ # We progressively `Down()` to precompute the token grid ids and operators.
189
+ self.hconv_levels: List[ho.SphericalStencil] = [] # op at successive levels
190
+ self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
191
+ current_nside = self.in_nside
192
+
193
+ # dummy buffer to probe Down; lives in Foscat backend dtype
194
+ dummy = self.f.backend.bk_cast(
195
+ np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype)
196
+ )
197
+
198
+ for _ in range(self.token_down):
199
+ hc = ho.SphericalStencil(
200
+ current_nside,
201
+ self.KERNELSZ,
202
+ n_gauges=self.G,
203
+ gauge_type=self.gauge_type,
204
+ cell_ids=self.level_cell_ids[-1],
205
+ dtype=self.torch_dtype,
206
+ )
207
+ self.hconv_levels.append(hc)
208
+
209
+ # Down to get next cell ids
210
+ dummy, next_ids = hc.Down(
211
+ dummy,
212
+ cell_ids=self.level_cell_ids[-1],
213
+ nside=current_nside,
214
+ max_poll=True,
215
+ )
216
+ next_ids = self.f.backend.to_numpy(next_ids)
217
+ current_nside //= 2
218
+ self._check_ids(next_ids, current_nside, name="token_level_cell_ids")
219
+ self.level_cell_ids.append(next_ids)
220
+
221
+ # token grid (where the Transformer runs)
222
+ self.token_nside = current_nside if self.token_down > 0 else self.in_nside
223
+ if self.token_nside < 1:
224
+ raise ValueError(
225
+ f"token_down={self.token_down} too large for in_nside={self.in_nside}"
226
+ )
227
+ self.token_cell_ids = self.level_cell_ids[-1]
228
+
229
+ # Operators at token and fine levels (used for Up and head)
230
+ self.hconv_token = ho.SphericalStencil(
231
+ self.token_nside,
232
+ self.KERNELSZ,
233
+ n_gauges=self.G,
234
+ gauge_type=self.gauge_type,
235
+ cell_ids=self.token_cell_ids,
236
+ dtype=self.torch_dtype,
237
+ )
238
+ self.hconv_head = ho.SphericalStencil(
239
+ self.in_nside,
240
+ self.KERNELSZ,
241
+ n_gauges=self.G,
242
+ gauge_type=self.gauge_type,
243
+ cell_ids=self.cell_ids_fine,
244
+ dtype=self.torch_dtype,
245
+ )
246
+
247
+ # ------------------- patch embedding (finest grid) -------------------
248
+ embed_g = self.embed_dim // self.G
249
+ # weight shapes follow Foscat conv expectations: (Cin, Cout_per_gauge, KERNELSZ*KERNELSZ)
250
+ self.patch_w1 = nn.Parameter(
251
+ torch.empty(self.n_chan_in, embed_g, self.KERNELSZ * self.KERNELSZ)
252
+ )
253
+ nn.init.kaiming_uniform_(self.patch_w1.view(self.n_chan_in * embed_g, -1), a=np.sqrt(5))
254
+ self.patch_bn1 = nn.GroupNorm(
255
+ num_groups=min(8, embed_g if embed_g > 1 else 1), num_channels=self.embed_dim
256
+ )
257
+
258
+ self.patch_w2 = nn.Parameter(
259
+ torch.empty(self.embed_dim, embed_g, self.KERNELSZ * self.KERNELSZ)
260
+ )
261
+ nn.init.kaiming_uniform_(self.patch_w2.view(self.embed_dim * embed_g, -1), a=np.sqrt(5))
262
+ self.patch_bn2 = nn.GroupNorm(
263
+ num_groups=min(8, embed_g if embed_g > 1 else 1), num_channels=self.embed_dim
264
+ )
265
+
266
+ # ------------------- positional encoding -------------------
267
+ self.n_tokens = int(self.token_cell_ids.shape[0])
268
+ if self.cls_token_enabled:
269
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
270
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
271
+ n_pe = self.n_tokens + 1
272
+ else:
273
+ self.cls_token = None
274
+ n_pe = self.n_tokens
275
+
276
+ if self.pos_embed_type == "learned":
277
+ self.pos_embed = nn.Parameter(torch.zeros(1, n_pe, self.embed_dim))
278
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
279
+ else:
280
+ self.pos_embed = None
281
+
282
+ # ------------------- transformer encoder -------------------
283
+ enc_layer = nn.TransformerEncoderLayer(
284
+ d_model=self.embed_dim,
285
+ nhead=self.num_heads,
286
+ dim_feedforward=int(self.embed_dim * self.mlp_ratio),
287
+ dropout=0.0,
288
+ activation="gelu",
289
+ batch_first=True,
290
+ norm_first=True,
291
+ )
292
+ self.encoder = nn.TransformerEncoder(enc_layer, num_layers=self.depth)
293
+
294
+ # ------------------- output heads -------------------
295
+ if self.task == "global":
296
+ # Global head: a single Linear on pooled token features
297
+ self.global_head = nn.Linear(self.embed_dim, self.out_channels)
298
+ else:
299
+ # Dense head: project token embeddings to channels, Up to fine grid, optional conv
300
+ if self.out_channels % self.G != 0:
301
+ raise ValueError(
302
+ f"out_channels={self.out_channels} must be divisible by G={self.G}"
303
+ )
304
+ out_g = self.out_channels // self.G
305
+ self.token_proj = nn.Linear(self.embed_dim, self.G * out_g)
306
+ self.head_w = nn.Parameter(
307
+ torch.empty(self.out_channels, out_g, self.KERNELSZ * self.KERNELSZ)
308
+ )
309
+ nn.init.kaiming_uniform_(
310
+ self.head_w.view(self.out_channels * out_g, -1), a=np.sqrt(5)
311
+ )
312
+ self.head_bn = (
313
+ nn.GroupNorm(
314
+ num_groups=min(8, self.out_channels if self.out_channels > 1 else 1),
315
+ num_channels=self.out_channels,
316
+ )
317
+ if self.task == "segmentation"
318
+ else None
319
+ )
320
+
321
+ # ------------------- device probing (CUDA → CPU fallback) -------------------
322
+ pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
323
+ self.runtime_device = self._probe_and_set_runtime_device(pref)
324
+
325
+ # ------------------------------------------------------------------
326
+ # Internal sanity checks
327
+ # ------------------------------------------------------------------
328
+ @staticmethod
329
+ def _check_ids(ids: ArrayLikeI64, nside: int, name: str = "cell_ids") -> None:
330
+ """Sanity check to avoid HEALPix `pix2loc` errors.
331
+ Ensures dtype=int64, range in [0, 12*nside^2 - 1].
332
+ """
333
+ if isinstance(ids, torch.Tensor):
334
+ ids = ids.detach().cpu().numpy()
335
+ ids = np.asarray(ids)
336
+ if ids.dtype != np.int64:
337
+ raise TypeError(f"{name} must be int64, got {ids.dtype}.")
338
+ npix = 12 * nside * nside
339
+ imin, imax = int(ids.min()), int(ids.max())
340
+ if imin < 0 or imax >= npix:
341
+ raise ValueError(
342
+ f"{name} out of range for nside={nside}: min={imin}, max={imax}, allowed=[0,{npix-1}]"
343
+ )
344
+
345
+ # ------------------------------------------------------------------
346
+ # Device utilities
347
+ # ------------------------------------------------------------------
348
+ def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
349
+ """Move internal tensors of SphericalStencil to the given device.
350
+ This mirrors the plumbing in U-Net-like codebases using Foscat.
351
+ """
352
+ for name, val in list(vars(hc).items()):
353
+ try:
354
+ if torch.is_tensor(val):
355
+ setattr(hc, name, val.to(device))
356
+ elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
357
+ setattr(hc, name, type(val)([v.to(device) for v in val]))
358
+ except Exception:
359
+ # Some attributes may be non-tensors; ignore.
360
+ pass
361
+
362
+ @torch.no_grad()
363
+ def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
364
+ """Try to run on CUDA with Foscat; otherwise, gracefully fall back to CPU.
365
+ Performs a tiny dry-run spherical convolution to ensure compatibility.
366
+ """
367
+ if preferred.type == "cuda" and self.prefer_foscat_gpu:
368
+ try:
369
+ super().to(preferred)
370
+ for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
371
+ self._move_hc(hc, preferred)
372
+ # Dry run: minimal conv on finest grid
373
+ npix0 = int(self.cell_ids_fine.shape[0])
374
+ x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
375
+ hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
376
+ y_try = hc0.Convol_torch(x_try, self.patch_w1)
377
+ _ = (y_try if torch.is_tensor(y_try) else torch.as_tensor(y_try, device=preferred)).sum().item()
378
+ self._foscat_device = preferred
379
+ return preferred
380
+ except Exception as e:
381
+ # Record and fall back
382
+ self._gpu_probe_error = repr(e)
383
+ cpu = torch.device("cpu")
384
+ super().to(cpu)
385
+ for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
386
+ self._move_hc(hc, cpu)
387
+ self._foscat_device = cpu
388
+ return cpu
389
+
390
+ # ------------------------------------------------------------------
391
+ # Forward helpers
392
+ # ------------------------------------------------------------------
393
+ def _patch_embed(self, x: torch.Tensor, cell_ids: Optional[np.ndarray]) -> torch.Tensor:
394
+ """Spherical patch embedding at the **finest** grid.
395
+ Applies two Foscat oriented convolutions with GN+GELU.
396
+ Input `(B, C_in, Nfine)` → Output `(B, embed_dim, Nfine)`.
397
+ """
398
+ hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
399
+ if cell_ids is None:
400
+ # Use constructor-time ids
401
+ y = hc0.Convol_torch(x, self.patch_w1)
402
+ y = self._as_tensor_batch(y)
403
+ y = self.patch_bn1(y)
404
+ y = F.gelu(y)
405
+ y = hc0.Convol_torch(y, self.patch_w2)
406
+ y = self._as_tensor_batch(y)
407
+ y = self.patch_bn2(y)
408
+ y = F.gelu(y)
409
+ return y
410
+ else:
411
+ # Use runtime ids provided by caller
412
+ y = hc0.Convol_torch(x, self.patch_w1, cell_ids=cell_ids)
413
+ y = self._as_tensor_batch(y)
414
+ y = self.patch_bn1(y)
415
+ y = F.gelu(y)
416
+ y = hc0.Convol_torch(y, self.patch_w2, cell_ids=cell_ids)
417
+ y = self._as_tensor_batch(y)
418
+ y = self.patch_bn2(y)
419
+ y = F.gelu(y)
420
+ return y
421
+
422
+ def _down_to_tokens(
423
+ self, x: torch.Tensor, cell_ids: Optional[np.ndarray]
424
+ ) -> Tuple[torch.Tensor, np.ndarray]:
425
+ """Apply `token_down` Down() steps to reach the **token grid**.
426
+ Returns `(x_tokens, token_cell_ids)` where `x_tokens` has shape `(B, C, N_tokens)`.
427
+ If `cell_ids` is provided, uses them as the starting fine-grid ids; otherwise uses
428
+ the constructor-time `self.cell_ids_fine`.
429
+ """
430
+ l_data = x
431
+ l_cell_ids = self.cell_ids_fine if cell_ids is None else np.asarray(cell_ids)
432
+ current_nside = self.in_nside
433
+
434
+ for hc in self.hconv_levels:
435
+ l_data, l_cell_ids = hc.Down(
436
+ l_data, cell_ids=l_cell_ids, nside=current_nside, max_poll=True
437
+ )
438
+ l_data = self._as_tensor_batch(l_data)
439
+ current_nside //= 2
440
+ return l_data, l_cell_ids
441
+
442
+ def _tokens_to_sequence(self, x_tokens: torch.Tensor) -> torch.Tensor:
443
+ """Rearrange `(B, C, Ntok)` → `(B, Ntok(+CLS), C)` and add positional embeddings."""
444
+ B, C, Nt = x_tokens.shape
445
+ seq = x_tokens.permute(0, 2, 1) # (B, Nt, C)
446
+ if self.cls_token_enabled:
447
+ cls = self.cls_token.expand(B, -1, -1)
448
+ seq = torch.cat([cls, seq], dim=1) # (B, 1+Nt, C)
449
+ if self.pos_embed is not None:
450
+ seq = seq + self.pos_embed[:, : seq.shape[1], :]
451
+ return seq
452
+
453
+ def _sequence_to_tokens(
454
+ self, seq: torch.Tensor
455
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
456
+ """Strip CLS if present and return `(tokens_only, cls_vector)`."""
457
+ if self.cls_token_enabled:
458
+ cls_vec = seq[:, 0, :]
459
+ tokens = seq[:, 1:, :]
460
+ return tokens, cls_vec
461
+ return seq, None
462
+
463
+ # ------------------------------------------------------------------
464
+ # Forward
465
+ # ------------------------------------------------------------------
466
+ def forward(self, x: torch.Tensor, cell_ids: Optional[ArrayLikeI64] = None) -> torch.Tensor:
467
+ """Forward pass.
468
+
469
+ Parameters
470
+ ----------
471
+ x : torch.Tensor
472
+ Input tensor of shape `(B, C_in, Npix)` at the finest grid.
473
+ cell_ids : Optional[np.ndarray or torch.Tensor]
474
+ Optional **nested** pixel indices for the input; if provided, they are used throughout
475
+ the pipeline (patch embedding, Down, Up, head conv). If `None`, the constructor-time
476
+ `cell_ids` are used.
477
+ """
478
+ if not isinstance(x, torch.Tensor):
479
+ raise TypeError("x must be a torch.Tensor")
480
+ if x.dim() != 3:
481
+ raise ValueError("Input must be (B, C, Npix)")
482
+ if x.shape[1] != self.n_chan_in:
483
+ raise ValueError(f"Expected {self.n_chan_in} channels, got {x.shape[1]}")
484
+
485
+ # Normalize/validate runtime ids once
486
+ runtime_ids = None
487
+ if cell_ids is not None:
488
+ if isinstance(cell_ids, torch.Tensor):
489
+ cell_ids = cell_ids.detach().cpu().numpy()
490
+ cell_ids = np.asarray(cell_ids)
491
+ # If given per-batch ids (B, Npix), take first row (assume same layout for the batch)
492
+ if cell_ids.ndim == 2:
493
+ cell_ids = cell_ids[0]
494
+ self._check_ids(cell_ids, self.in_nside, name="forward:cell_ids")
495
+ runtime_ids = cell_ids
496
+
497
+ x = x.to(self.runtime_device)
498
+
499
+ # 1) Patch embedding (finest grid)
500
+ x = self._patch_embed(x, runtime_ids) # (B, embed_dim, Nfine)
501
+
502
+ # 2) Down to token grid
503
+ x_tok, token_ids = self._down_to_tokens(x, runtime_ids) # (B, embed_dim, Ntok)
504
+
505
+ # 3) Transformer encoder on token sequence
506
+ seq = self._tokens_to_sequence(x_tok) # (B, Ntok(+1), embed_dim)
507
+ seq = self.encoder(seq) # (B, Ntok(+1), embed_dim)
508
+ tokens, cls_vec = self._sequence_to_tokens(seq)
509
+
510
+ if self.task == "global":
511
+ # Global vector from mean/CLS pooling
512
+ if self.head_type == "cls" and self.cls_token_enabled and cls_vec is not None:
513
+ out = self.global_head(cls_vec) # (B, out_channels)
514
+ else:
515
+ out = self.global_head(tokens.mean(dim=1))
516
+ return out
517
+
518
+ # 4) Project tokens to channels at token grid
519
+ tok_proj = self.token_proj(tokens) # (B, Ntok, out_channels)
520
+ tok_proj = tok_proj.permute(0, 2, 1) # (B, out_channels, Ntok)
521
+ # Sanity: token feature count must match token_ids length
522
+ if isinstance(token_ids, torch.Tensor):
523
+ _tok_ids = token_ids.detach().cpu().numpy()
524
+ else:
525
+ _tok_ids = np.asarray(token_ids)
526
+ assert tok_proj.shape[-1] == _tok_ids.shape[0], (
527
+ f"Ntok mismatch: {tok_proj.shape[-1]} != {_tok_ids.shape[0]}"
528
+ )
529
+
530
+ # 5) Multi-step Up from token grid to finest grid (one HEALPix level at a time).
531
+ # Use constructor-time fine ids by default; override if runtime ids were provided.
532
+ fine_ids = self.cell_ids_fine if runtime_ids is None else runtime_ids
535
+
536
+ # Build the ID chain from fine → ... → token for THIS forward, using runtime ids
537
+ _ids = fine_ids
538
+ nside_tmp = self.in_nside
539
+ ids_chain = [np.asarray(_ids)]
540
+ _dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
541
+ for hc in self.hconv_levels:
542
+ _dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
543
+ ids_chain.append(self.f.backend.to_numpy(_next))
544
+ nside_tmp //= 2
545
+
546
+ # Sanity: token_ids from the actual Down path must match the last element of the chain
547
+ assert np.array_equal(_tok_ids, ids_chain[-1]), "token_ids mismatch with runtime Down() chain"
553
+
554
+ # Precompute nsides represented by hconv_levels (fine→coarse, excluding token level)
555
+ nsides_levels = [self.in_nside // (2 ** k) for k in range(self.token_down)] # e.g., [8, 4] for token_down=2
556
+
557
+ # Now Up step-by-step: token (coarse) → ... → fine
558
+ y_up = tok_proj
559
+ for i in range(len(ids_chain) - 1, 0, -1):
560
+ coarse_ids = ids_chain[i]
561
+ fine_ids_step = ids_chain[i - 1]
562
+ source_nside = self.in_nside // (2 ** i) # e.g., 2, then 4
563
+ fine_nside = self.in_nside // (2 ** (i - 1)) # e.g., 4, then 8
564
+ # pick the operator of the target (fine) level
565
+ if fine_nside == self.in_nside:
566
+ op_fine = self.hconv_head
567
+ else:
568
+ idx = nsides_levels.index(fine_nside)
569
+ op_fine = self.hconv_levels[idx]
570
+ y_up = op_fine.Up(y_up, cell_ids=coarse_ids, o_cell_ids=fine_ids_step, nside=source_nside)
571
+ if not torch.is_tensor(y_up):
572
+ y_up = torch.as_tensor(y_up, device=self.runtime_device)
573
+ y_up = self._as_tensor_batch(y_up) # (B, out_channels, N at this fine level)
574
+
576
+ # 6) Optional spherical head conv for refinement
577
+ y = self.hconv_head.Convol_torch(y_up, self.head_w, cell_ids=fine_ids)
578
+ if not torch.is_tensor(y):
579
+ y = torch.as_tensor(y, device=self.runtime_device)
580
+ y = self._as_tensor_batch(y)
581
+
582
+ if self.task == "segmentation" and self.head_bn is not None:
583
+ y = self.head_bn(y)
584
+
585
+ if self.final_activation == "sigmoid":
586
+ y = torch.sigmoid(y)
587
+ elif self.final_activation == "softmax":
588
+ y = torch.softmax(y, dim=1)
589
+ return y
590
+
591
+ # ------------------------------------------------------------------
592
+ # Misc helpers
593
+ # ------------------------------------------------------------------
594
+ def _as_tensor_batch(self, x):
595
+ """Normalize outputs of Foscat ops into a contiguous batch tensor.
596
+ Foscat may return a tensor or a single-element list of tensors.
597
+ This function ensures we always get a tensor of the expected shape.
598
+ """
599
+ if isinstance(x, list):
600
+ if len(x) == 1:
601
+ t = x[0]
602
+ return t.unsqueeze(0) if t.dim() == 2 else t
603
+ raise ValueError("Variable-length list not supported here; pass a tensor.")
604
+ return x
605
+
606
+ @torch.no_grad()
607
+ def predict(
608
+ self, x: Union[torch.Tensor, np.ndarray], batch_size: int = 8
609
+ ) -> torch.Tensor:
610
+ """Convenience method for batched inference.
611
+
612
+ Parameters
613
+ ----------
614
+ x : Tensor or ndarray
615
+ Input `(B, C_in, Npix)`.
616
+ batch_size : int
617
+ Mini-batch size used during prediction.
618
+ """
619
+ self.eval()
620
+ if isinstance(x, np.ndarray):
621
+ x = torch.from_numpy(x).float()
622
+ outs = []
623
+ for i in range(0, x.shape[0], batch_size):
624
+ xb = x[i : i + batch_size].to(self.runtime_device)
625
+ outs.append(self.forward(xb))
626
+ return torch.cat(outs, dim=0)
627
+
628
+
629
+ # -----------------------------------------------------------------------------
630
+ # Minimal smoke test (requires foscat installed)
631
+ # -----------------------------------------------------------------------------
632
+ if __name__ == "__main__":
633
+ # A tiny grid to validate shapes and device plumbing
634
+ in_nside = 4
635
+ npix = 12 * in_nside * in_nside
636
+ cell_ids = np.arange(npix, dtype=np.int64) # nested, fine-level ids
637
+
638
+ B, Cin = 2, 3
639
+ x = torch.randn(B, Cin, npix)
640
+
641
+ model = HealpixViT(
642
+ in_nside=in_nside,
643
+ n_chan_in=Cin,
644
+ embed_dim=64,
645
+ depth=2,
646
+ num_heads=4,
647
+ cell_ids=cell_ids,
648
+ token_down=2, # token_nside = in_nside // 4 = 1 here
649
+ task="regression",
650
+ out_channels=1,
651
+ KERNELSZ=3,
652
+ G=1,
653
+ cls_token=False,
654
+ )
655
+
656
+ with torch.no_grad():
657
+ y = model(x) # You can also pass `cell_ids=cell_ids` if your pipeline manages them at runtime
658
+ print("Output:", y.shape) # (B, out_channels, npix)