foscat 2025.9.4-py3-none-any.whl → 2025.10.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +14 -9
- foscat/Plot.py +59 -17
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch-old.py +658 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/METADATA +1 -1
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/RECORD +13 -8
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/WHEEL +0 -0
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/top_level.txt +0 -0
foscat/healpix_vit_torch-old.py (new file):

@@ -0,0 +1,658 @@

"""
HealpixViT — Vision Transformer on HEALPix with Foscat
======================================================

This module provides a **Vision Transformer (ViT)** adapted to spherical data laid out on a
**HEALPix nested grid**. It integrates **Foscat**'s `SphericalStencil` operators to perform
spherical convolutions for patch embedding, hierarchical **Down/Up** between HEALPix levels,
and an optional per-pixel spherical head after the Transformer encoder.

Why this design?
----------------
- HEALPix provides a hierarchical, equal-area tessellation of the sphere. In **nested** ordering,
  each pixel at level \(L\) has 4 children at level \(L+1\). This makes **tokenization** natural:
  we can repeatedly call `Down()` to move to a coarser grid that serves as the **token grid**.
- A Transformer encoder then operates on the **sequence of tokens**. For dense outputs, we map the
  token features back to the finest grid with `Up()` and refine with a spherical convolution head.
- Because we reuse the same Foscat operators as in a HEALPix U-Net, we preserve consistency with
  existing spherical CNN pipelines while gaining the long-range modeling capacity of Transformers.

Typical use cases
-----------------
- **Global regression/classification** (e.g., predicting a climate index from full-sky fields).
- **Dense regression/segmentation** (e.g., SST anomaly prediction, cloud/ice masks) directly on
  HEALPix maps, including **multi-resolution fusion** thanks to nested Down/Up.

Notes on `cell_ids`
-------------------
- This implementation supports passing **runtime `cell_ids`** to `forward(...)` to match your
  data pipeline (e.g., when per-sample IDs are managed externally). If omitted, it uses the
  `cell_ids` provided at construction.
- All IDs are assumed to be **nested** and **int64**, with range `[0, 12*nside^2 - 1]` at each level.
  Sanity checks are included to prevent HEALPix `pix2loc` errors.
"""

from __future__ import annotations
from typing import List, Optional, Literal, Tuple, Union
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import foscat.scat_cov as sc
import foscat.SphericalStencil as ho

# -----------------------------------------------------------------------------
# Helper: safe type alias
# -----------------------------------------------------------------------------
ArrayLikeI64 = Union[np.ndarray, torch.Tensor]


class HealpixViT(nn.Module):
    """Vision Transformer on the HEALPix sphere using Foscat-oriented ops.

    Parameters
    ----------
    in_nside : int
        Input HEALPix nside at the **finest** level (nested ordering). The number of pixels is
        `Npix = 12 * in_nside**2`.
    n_chan_in : int
        Number of input channels at the finest grid.
    embed_dim : int
        Transformer embedding dimension (also the channel count after patch embedding).
    depth : int
        Number of Transformer encoder layers.
    num_heads : int
        Number of attention heads per layer.
    cell_ids : np.ndarray
        Finest-level **nested** cell indices (shape `[Npix]`, dtype `int64`). These define the
        pixel layout of your input features.
    mlp_ratio : float, default=4.0
        Expansion ratio for the MLP inside each Transformer block.
    token_down : int, default=2
        Number of `Down()` steps to reach the token grid. The token nside is
        `token_nside = in_nside // (2**token_down)`.
    task : {"regression","segmentation","global"}, default="regression"
        - "global": return a vector (pooled tokens → `out_channels`).
        - "regression"/"segmentation": return per-pixel predictions on the finest grid.
    out_channels : int, default=1
        Output channels for dense tasks (ignored for `task="global"`).
    final_activation : {"none","sigmoid","softmax"} | None
        Optional activation for the output. If `None`, sensible defaults are chosen per task.
    KERNELSZ : int, default=3
        Spatial kernel size for spherical convolutions (Foscat oriented conv).
    gauge_type : {"cosmo","phi"}, default="cosmo"
        Orientation/gauge definition in `SphericalStencil`.
    G : int, default=1
        Number of gauges (internal orientation multiplicity). `embed_dim` must be divisible by `G`.
    prefer_foscat_gpu : bool, default=True
        Try Foscat on CUDA if available; fall back to CPU otherwise.
    cls_token : bool, default=False
        Include a `[CLS]` token for global tasks.
    pos_embed : {"learned","none"}, default="learned"
        Positional encoding type for tokens.
    head_type : {"mean","cls"}, default="mean"
        Pooling strategy for global tasks (mean over tokens or CLS vector).
    dtype : {"float32","float64"}, default="float32"
        Numpy dtype used for internal Foscat buffers. Model parameters remain `float32`.

    Input/Output shapes
    -------------------
    Input: `(B, C_in, Npix)` with `Npix = 12 * in_nside**2`.
    Output: - global task: `(B, out_channels)`
            - dense task: `(B, out_channels, Npix)`
    """

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        embed_dim: int,
        depth: int,
        num_heads: int,
        cell_ids: np.ndarray,
        mlp_ratio: float = 4.0,
        token_down: int = 2,
        task: Literal["regression", "segmentation", "global"] = "regression",
        out_channels: int = 1,
        final_activation: Optional[Literal["none", "sigmoid", "softmax"]] = None,
        KERNELSZ: int = 3,
        gauge_type: Optional[Literal["cosmo", "phi"]] = "cosmo",
        G: int = 1,
        prefer_foscat_gpu: bool = True,
        cls_token: bool = False,
        pos_embed: Literal["learned", "none"] = "learned",
        head_type: Literal["mean", "cls"] = "mean",
        dtype: Literal["float32", "float64"] = "float32",
    ) -> None:
        super().__init__()

        # ------------------- store config & dtypes -------------------
        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.embed_dim = int(embed_dim)
        self.depth = int(depth)
        self.num_heads = int(num_heads)
        self.mlp_ratio = float(mlp_ratio)
        self.token_down = int(token_down)
        self.task = task
        self.out_channels = int(out_channels)
        self.KERNELSZ = int(KERNELSZ)
        self.gauge_type = gauge_type
        self.G = int(G)
        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
        self.cls_token_enabled = bool(cls_token)
        self.pos_embed_type = pos_embed
        self.head_type = head_type

        if dtype == "float32":
            self.np_dtype = np.float32
            self.torch_dtype = torch.float32
        else:
            self.np_dtype = np.float64
            self.torch_dtype = torch.float32  # keep params in fp32

        # ------------------- validate inputs -------------------
        if cell_ids is None:
            raise ValueError("cell_ids (finest) must be provided.")
        self.cell_ids_fine = np.asarray(cell_ids)
        self._check_ids(self.cell_ids_fine, self.in_nside, name="cell_ids_fine")

        if self.G < 1:
            raise ValueError("G must be >= 1")
        if self.embed_dim % self.G != 0:
            raise ValueError(f"embed_dim={self.embed_dim} must be divisible by G={self.G}")
        if self.task not in {"regression", "segmentation", "global"}:
            raise ValueError("task must be 'regression', 'segmentation', or 'global'")

        # Default final activation per task if not specified
        if final_activation is None:
            if self.task == "regression":
                self.final_activation = "none"
            elif self.task == "segmentation":
                self.final_activation = "sigmoid" if out_channels == 1 else "softmax"
            else:
                self.final_activation = "none"
        else:
            self.final_activation = final_activation

        # ------------------- foscat functional wrapper -------------------
        self.f = sc.funct(KERNELSZ=self.KERNELSZ)

        # ------------------- build hierarchy (fine → coarse) -------------------
        # We progressively `Down()` to precompute the token grid ids and operators.
        self.hconv_levels: List[ho.SphericalStencil] = []  # op at successive levels
        self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
        current_nside = self.in_nside

        # dummy buffer to probe Down; lives in Foscat backend dtype
        dummy = self.f.backend.bk_cast(
            np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype)
        )

        for _ in range(self.token_down):
            hc = ho.SphericalStencil(
                current_nside,
                self.KERNELSZ,
                n_gauges=self.G,
                gauge_type=self.gauge_type,
                cell_ids=self.level_cell_ids[-1],
                dtype=self.torch_dtype,
            )
            self.hconv_levels.append(hc)

            # Down to get next cell ids
            dummy, next_ids = hc.Down(
                dummy,
                cell_ids=self.level_cell_ids[-1],
                nside=current_nside,
                max_poll=True,
            )
            next_ids = self.f.backend.to_numpy(next_ids)
            current_nside //= 2
            self._check_ids(next_ids, current_nside, name="token_level_cell_ids")
            self.level_cell_ids.append(next_ids)

        # token grid (where the Transformer runs)
        self.token_nside = current_nside if self.token_down > 0 else self.in_nside
        if self.token_nside < 1:
            raise ValueError(
                f"token_down={self.token_down} too large for in_nside={self.in_nside}"
            )
        self.token_cell_ids = self.level_cell_ids[-1]

        # Operators at token and fine levels (used for Up and head)
        self.hconv_token = ho.SphericalStencil(
            self.token_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.token_cell_ids,
            dtype=self.torch_dtype,
        )
        self.hconv_head = ho.SphericalStencil(
            self.in_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.cell_ids_fine,
            dtype=self.torch_dtype,
        )

        # ------------------- patch embedding (finest grid) -------------------
        embed_g = self.embed_dim // self.G
        # weight shapes follow Foscat conv expectations: (Cin, Cout_per_gauge, KERNELSZ*KERNELSZ)
        self.patch_w1 = nn.Parameter(
            torch.empty(self.n_chan_in, embed_g, self.KERNELSZ * self.KERNELSZ)
        )
        nn.init.kaiming_uniform_(self.patch_w1.view(self.n_chan_in * embed_g, -1), a=np.sqrt(5))
        self.patch_bn1 = nn.GroupNorm(
            num_groups=min(8, embed_g if embed_g > 1 else 1), num_channels=self.embed_dim
        )

        self.patch_w2 = nn.Parameter(
            torch.empty(self.embed_dim, embed_g, self.KERNELSZ * self.KERNELSZ)
        )
        nn.init.kaiming_uniform_(self.patch_w2.view(self.embed_dim * embed_g, -1), a=np.sqrt(5))
        self.patch_bn2 = nn.GroupNorm(
            num_groups=min(8, embed_g if embed_g > 1 else 1), num_channels=self.embed_dim
        )

        # ------------------- positional encoding -------------------
        self.n_tokens = int(self.token_cell_ids.shape[0])
        if self.cls_token_enabled:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)
            n_pe = self.n_tokens + 1
        else:
            self.cls_token = None
            n_pe = self.n_tokens

        if self.pos_embed_type == "learned":
            self.pos_embed = nn.Parameter(torch.zeros(1, n_pe, self.embed_dim))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.pos_embed = None

        # ------------------- transformer encoder -------------------
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads,
            dim_feedforward=int(self.embed_dim * self.mlp_ratio),
            dropout=0.0,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=self.depth)

        # ------------------- output heads -------------------
        if self.task == "global":
            # Global head: a single Linear on pooled token features
            self.global_head = nn.Linear(self.embed_dim, self.out_channels)
        else:
            # Dense head: project token embeddings to channels, Up to fine grid, optional conv
            if self.out_channels % self.G != 0:
                raise ValueError(
                    f"out_channels={self.out_channels} must be divisible by G={self.G}"
                )
            out_g = self.out_channels // self.G
            self.token_proj = nn.Linear(self.embed_dim, self.G * out_g)
            self.head_w = nn.Parameter(
                torch.empty(self.out_channels, out_g, self.KERNELSZ * self.KERNELSZ)
            )
            nn.init.kaiming_uniform_(
                self.head_w.view(self.out_channels * out_g, -1), a=np.sqrt(5)
            )
            self.head_bn = (
                nn.GroupNorm(
                    num_groups=min(8, self.out_channels if self.out_channels > 1 else 1),
                    num_channels=self.out_channels,
                )
                if self.task == "segmentation"
                else None
            )

        # ------------------- device probing (CUDA → CPU fallback) -------------------
        pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.runtime_device = self._probe_and_set_runtime_device(pref)

    # ------------------------------------------------------------------
    # Internal sanity checks
    # ------------------------------------------------------------------
    @staticmethod
    def _check_ids(ids: ArrayLikeI64, nside: int, name: str = "cell_ids") -> None:
        """Sanity check to avoid HEALPix `pix2loc` errors.
        Ensures dtype=int64, range in [0, 12*nside^2 - 1].
        """
        if isinstance(ids, torch.Tensor):
            ids = ids.detach().cpu().numpy()
        ids = np.asarray(ids)
        if ids.dtype != np.int64:
            raise TypeError(f"{name} must be int64, got {ids.dtype}.")
        npix = 12 * nside * nside
        imin, imax = int(ids.min()), int(ids.max())
        if imin < 0 or imax >= npix:
            raise ValueError(
                f"{name} out of range for nside={nside}: min={imin}, max={imax}, allowed=[0,{npix-1}]"
            )

    # ------------------------------------------------------------------
    # Device utilities
    # ------------------------------------------------------------------
    def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
        """Move internal tensors of SphericalStencil to the given device.
        This mirrors the plumbing in U-Net-like codebases using Foscat.
        """
        for name, val in list(vars(hc).items()):
            try:
                if torch.is_tensor(val):
                    setattr(hc, name, val.to(device))
                elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
                    setattr(hc, name, type(val)([v.to(device) for v in val]))
            except Exception:
                # Some attributes may be non-tensors; ignore.
                pass

    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Try to run on CUDA with Foscat; otherwise, gracefully fall back to CPU.
        Performs a tiny dry-run spherical convolution to ensure compatibility.
        """
        if preferred.type == "cuda" and self.prefer_foscat_gpu:
            try:
                super().to(preferred)
                for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
                    self._move_hc(hc, preferred)
                # Dry run: minimal conv on finest grid
                npix0 = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
                hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
                y_try = hc0.Convol_torch(x_try, self.patch_w1)
                _ = (y_try if torch.is_tensor(y_try) else torch.as_tensor(y_try, device=preferred)).sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception as e:
                # Record and fall back
                self._gpu_probe_error = repr(e)
        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu

    # ------------------------------------------------------------------
    # Forward helpers
    # ------------------------------------------------------------------
    def _patch_embed(self, x: torch.Tensor, cell_ids: Optional[np.ndarray]) -> torch.Tensor:
        """Spherical patch embedding at the **finest** grid.
        Applies two Foscat oriented convolutions with GN+GELU.
        Input `(B, C_in, Nfine)` → Output `(B, embed_dim, Nfine)`.
        """
        hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
        if cell_ids is None:
            # Use constructor-time ids
            y = hc0.Convol_torch(x, self.patch_w1)
            y = self._as_tensor_batch(y)
            y = self.patch_bn1(y)
            y = F.gelu(y)
            y = hc0.Convol_torch(y, self.patch_w2)
            y = self._as_tensor_batch(y)
            y = self.patch_bn2(y)
            y = F.gelu(y)
            return y
        else:
            # Use runtime ids provided by caller
            y = hc0.Convol_torch(x, self.patch_w1, cell_ids=cell_ids)
            y = self._as_tensor_batch(y)
            y = self.patch_bn1(y)
            y = F.gelu(y)
            y = hc0.Convol_torch(y, self.patch_w2, cell_ids=cell_ids)
            y = self._as_tensor_batch(y)
            y = self.patch_bn2(y)
            y = F.gelu(y)
            return y

    def _down_to_tokens(
        self, x: torch.Tensor, cell_ids: Optional[np.ndarray]
    ) -> Tuple[torch.Tensor, np.ndarray]:
        """Apply `token_down` Down() steps to reach the **token grid**.
        Returns `(x_tokens, token_cell_ids)` where `x_tokens` has shape `(B, C, N_tokens)`.
        If `cell_ids` is provided, uses them as the starting fine-grid ids; otherwise uses
        the constructor-time `self.cell_ids_fine`.
        """
        l_data = x
        l_cell_ids = self.cell_ids_fine if cell_ids is None else np.asarray(cell_ids)
        current_nside = self.in_nside

        for hc in self.hconv_levels:
            l_data, l_cell_ids = hc.Down(
                l_data, cell_ids=l_cell_ids, nside=current_nside, max_poll=True
            )
            l_data = self._as_tensor_batch(l_data)
            current_nside //= 2
        return l_data, l_cell_ids

    def _tokens_to_sequence(self, x_tokens: torch.Tensor) -> torch.Tensor:
        """Rearrange `(B, C, Ntok)` → `(B, Ntok(+CLS), C)` and add positional embeddings."""
        B, C, Nt = x_tokens.shape
        seq = x_tokens.permute(0, 2, 1)  # (B, Nt, C)
        if self.cls_token_enabled:
            cls = self.cls_token.expand(B, -1, -1)
            seq = torch.cat([cls, seq], dim=1)  # (B, 1+Nt, C)
        if self.pos_embed is not None:
            seq = seq + self.pos_embed[:, : seq.shape[1], :]
        return seq

    def _sequence_to_tokens(
        self, seq: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Strip CLS if present and return `(tokens_only, cls_vector)`."""
        if self.cls_token_enabled:
            cls_vec = seq[:, 0, :]
            tokens = seq[:, 1:, :]
            return tokens, cls_vec
        return seq, None

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(self, x: torch.Tensor, cell_ids: Optional[ArrayLikeI64] = None) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape `(B, C_in, Npix)` at the finest grid.
        cell_ids : Optional[np.ndarray or torch.Tensor]
            Optional **nested** pixel indices for the input; if provided, they are used throughout
            the pipeline (patch embedding, Down, Up, head conv). If `None`, the constructor-time
            `cell_ids` are used.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("x must be a torch.Tensor")
        if x.dim() != 3:
            raise ValueError("Input must be (B, C, Npix)")
        if x.shape[1] != self.n_chan_in:
            raise ValueError(f"Expected {self.n_chan_in} channels, got {x.shape[1]}")

        # Normalize/validate runtime ids once
        runtime_ids = None
        if cell_ids is not None:
            if isinstance(cell_ids, torch.Tensor):
                cell_ids = cell_ids.detach().cpu().numpy()
            cell_ids = np.asarray(cell_ids)
            # If given per-batch ids (B, Npix), take first row (assume same layout for the batch)
            if cell_ids.ndim == 2:
                cell_ids = cell_ids[0]
            self._check_ids(cell_ids, self.in_nside, name="forward:cell_ids")
            runtime_ids = cell_ids

        x = x.to(self.runtime_device)

        # 1) Patch embedding (finest grid)
        x = self._patch_embed(x, runtime_ids)  # (B, embed_dim, Nfine)

        # 2) Down to token grid
        x_tok, token_ids = self._down_to_tokens(x, runtime_ids)  # (B, embed_dim, Ntok)

        # 3) Transformer encoder on token sequence
        seq = self._tokens_to_sequence(x_tok)  # (B, Ntok(+1), embed_dim)
        seq = self.encoder(seq)  # (B, Ntok(+1), embed_dim)
        tokens, cls_vec = self._sequence_to_tokens(seq)

        if self.task == "global":
            # Global vector from mean/CLS pooling
            if self.head_type == "cls" and self.cls_token_enabled and cls_vec is not None:
                out = self.global_head(cls_vec)  # (B, out_channels)
            else:
                out = self.global_head(tokens.mean(dim=1))
            return out

        # 4) Project tokens to channels at token grid
        tok_proj = self.token_proj(tokens)  # (B, Ntok, out_channels)
        tok_proj = tok_proj.permute(0, 2, 1)  # (B, out_channels, Ntok)
        # Sanity: token feature count must match token_ids length
        if isinstance(token_ids, torch.Tensor):
            _tok_ids = token_ids.detach().cpu().numpy()
        else:
            _tok_ids = np.asarray(token_ids)
        assert tok_proj.shape[-1] == _tok_ids.shape[0], (
            f"Ntok mismatch: {tok_proj.shape[-1]} != {_tok_ids.shape[0]}"
        )

        # 5) Multi-step Up from token grid to finest grid (one HEALPix level at a time).
        # Use constructor-time fine ids by default; override if runtime ids provided.
        fine_ids = self.cell_ids_fine if runtime_ids is None else runtime_ids

        # Build the ID chain from fine → ... → token for THIS forward, using runtime ids
        nside_tmp = self.in_nside
        ids_chain = [np.asarray(fine_ids)]
        _dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
        for hc in self.hconv_levels:
            _dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
            ids_chain.append(self.f.backend.to_numpy(_next))
            nside_tmp //= 2

        # Sanity: token_ids from the actual Down path must match the last element of the chain
        assert np.array_equal(_tok_ids, ids_chain[-1]), "token_ids mismatch with runtime Down() chain"

        # Precompute nsides represented by hconv_levels (fine→coarse, excluding token level)
        nsides_levels = [self.in_nside // (2 ** k) for k in range(self.token_down)]  # e.g., [8, 4] for token_down=2

        # Now Up step-by-step: token (coarse) → ... → fine
        y_up = tok_proj
        for i in range(len(ids_chain) - 1, 0, -1):
            coarse_ids = ids_chain[i]
            fine_ids_step = ids_chain[i - 1]
            source_nside = self.in_nside // (2 ** i)  # e.g., 2, then 4
            fine_nside = self.in_nside // (2 ** (i - 1))  # e.g., 4, then 8
            # pick the operator of the target (fine) level
            if fine_nside == self.in_nside:
                op_fine = self.hconv_head
            else:
                idx = nsides_levels.index(fine_nside)
                op_fine = self.hconv_levels[idx]
            y_up = op_fine.Up(y_up, cell_ids=coarse_ids, o_cell_ids=fine_ids_step, nside=source_nside)
            if not torch.is_tensor(y_up):
                y_up = torch.as_tensor(y_up, device=self.runtime_device)
            y_up = self._as_tensor_batch(y_up)  # (B, out_channels, N at this fine level)

        # 6) Optional spherical head conv for refinement
        y = self.hconv_head.Convol_torch(y_up, self.head_w, cell_ids=fine_ids)
        if not torch.is_tensor(y):
            y = torch.as_tensor(y, device=self.runtime_device)
        y = self._as_tensor_batch(y)

        if self.task == "segmentation" and self.head_bn is not None:
            y = self.head_bn(y)

        if self.final_activation == "sigmoid":
            y = torch.sigmoid(y)
        elif self.final_activation == "softmax":
            y = torch.softmax(y, dim=1)
        return y

    # ------------------------------------------------------------------
    # Misc helpers
    # ------------------------------------------------------------------
    def _as_tensor_batch(self, x):
        """Normalize outputs of Foscat ops into a contiguous batch tensor.
        Foscat may return a tensor or a single-element list of tensors.
        This function ensures we always get a tensor of the expected shape.
        """
        if isinstance(x, list):
            if len(x) == 1:
                t = x[0]
                return t.unsqueeze(0) if t.dim() == 2 else t
            raise ValueError("Variable-length list not supported here; pass a tensor.")
        return x

    @torch.no_grad()
    def predict(
        self, x: Union[torch.Tensor, np.ndarray], batch_size: int = 8
    ) -> torch.Tensor:
        """Convenience method for batched inference.

        Parameters
        ----------
        x : Tensor or ndarray
            Input `(B, C_in, Npix)`.
        batch_size : int
            Mini-batch size used during prediction.
        """
        self.eval()
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).float()
        outs = []
        for i in range(0, x.shape[0], batch_size):
            xb = x[i : i + batch_size].to(self.runtime_device)
            outs.append(self.forward(xb))
        return torch.cat(outs, dim=0)


# -----------------------------------------------------------------------------
# Minimal smoke test (requires foscat installed)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # A tiny grid to validate shapes and device plumbing
    in_nside = 4
    npix = 12 * in_nside * in_nside
    cell_ids = np.arange(npix, dtype=np.int64)  # nested, fine-level ids

    B, Cin = 2, 3
    x = torch.randn(B, Cin, npix)

    model = HealpixViT(
        in_nside=in_nside,
        n_chan_in=Cin,
        embed_dim=64,
        depth=2,
        num_heads=4,
        cell_ids=cell_ids,
        token_down=2,  # token_nside = in_nside // 4 = 1 here
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        cls_token=False,
    )

    with torch.no_grad():
        y = model(x)  # You can also pass `cell_ids=cell_ids` if your pipeline manages them at runtime
    print("Output:", y.shape)  # (B, out_channels, npix)