foscat-2025.9.5-py3-none-any.whl → foscat-2025.10.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 import foscat.HealSpline as HS
 from scipy.interpolate import griddata
 
-TMPFILE_VERSION = "V9_0"
+TMPFILE_VERSION = "V10_0"
 
 
 class FoCUS:
@@ -36,7 +36,7 @@ class FoCUS:
         mpi_rank=0
     ):
 
-        self.__version__ = "2025.09.5"
+        self.__version__ = "2025.10.2"
         # P00 coeff for normalization for scat_cov
         self.TMPFILE_VERSION = TMPFILE_VERSION
         self.P1_dic = None
@@ -1488,7 +1488,7 @@ class FoCUS:
         if l_kernel == 5:
             pw = 0.5
             pw2 = 0.5
-            threshold = 2e-4
+            threshold = 2e-5
 
         elif l_kernel == 3:
             pw = 1.0 / np.sqrt(2)
@@ -1498,7 +1498,7 @@ class FoCUS:
         elif l_kernel == 7:
             pw = 0.5
             pw2 = 0.25
-            threshold = 4e-5
+            threshold = 2e-5
 
         import foscat.SphericalStencil as hs
         import torch
@@ -1517,14 +1517,19 @@ class FoCUS:
                                      n_gauges=self.NORIENT,
                                      gauge_type='cosmo')
 
-        xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ*self.KERNELSZ)
+        xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ,self.KERNELSZ)
 
-        wwr=hconvol.to_tensor((np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ))
+        wwr=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
         wwr-=wwr.mean()
-        wwi=hconvol.to_tensor((np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ))
+        wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
         wwi-=wwi.mean()
-        wwr/=(abs(wwr+1J*wwi)).sum()
-        wwi/=(abs(wwr+1J*wwi)).sum()
+        amp=np.sum(abs(wwr+1J*wwi))
+
+        wwr/=amp
+        wwi/=amp
+
+        wwr=hconvol.to_tensor(wwr)
+        wwi=hconvol.to_tensor(wwi)
 
         wavr,indice,mshape=hconvol.make_matrix(wwr)
         wavi,indice,mshape=hconvol.make_matrix(wwi)
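
Note on the kernel hunk above: 2025.10.2 corrects two things. First, `xx` is reshaped to a 2-D (KERNELSZ, KERNELSZ) grid, so `xx.T` genuinely supplies the second coordinate of the Gaussian envelope (with the old 1-D shape, transposing was a no-op). Second, both kernel parts are now divided by a single amplitude computed before either is rescaled, whereas the old code normalized `wwr` first and then divided `wwi` by a sum taken over the already-rescaled real part. A standalone sketch of the corrected construction, with a hypothetical `KERNELSZ = 5` and the `l_kernel == 5` constants standing in for the class state, and the final `hconvol.to_tensor` step omitted:

import numpy as np

# Hedged sketch of the corrected Morlet-like kernel (constants for l_kernel == 5).
KERNELSZ = 5
pw, pw2 = 0.5, 0.5

# 2-D tap-offset grid: each row of xx holds the horizontal offsets,
# xx.T supplies the vertical ones, so xx**2 + xx.T**2 is the squared radius.
xx = np.tile(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape(KERNELSZ, KERNELSZ)

envelope = np.exp(-pw2 * (xx**2 + xx.T**2))
wwr = (envelope * np.cos(pw * xx * np.pi)).reshape(1, 1, KERNELSZ * KERNELSZ)
wwr -= wwr.mean()
wwi = (envelope * np.sin(pw * xx * np.pi)).reshape(1, 1, KERNELSZ * KERNELSZ)
wwi -= wwi.mean()

# One shared normalization, computed before either part is rescaled.
amp = np.sum(abs(wwr + 1j * wwi))
wwr /= amp
wwi /= amp

assert np.isclose(np.sum(abs(wwr + 1j * wwi)), 1.0)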
foscat/Plot.py CHANGED
@@ -959,8 +959,9 @@ def conjugate_gradient_normal_equation(data, x0, www, all_idx,
 
         rs_new = np.dot(r, r)
 
-        if verbose and i % 50 == 0:
-            print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e}")
+        if verbose and i % 10 == 0:
+            v=np.mean((LP(p, www, all_idx)-data)**2)
+            print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e},{np.sqrt(v):.3e}")
 
         if np.sqrt(rs_new) < tol:
             if verbose:
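
The hunk above tightens the report cadence from every 50 to every 10 iterations and logs a second quantity next to the normal-equation residual sqrt(r·r): the RMS data-space misfit obtained by pushing the current vector through the forward operator `LP`. A generic, runnable sketch of the same pattern, with a dense matrix A standing in for `LP(..., www, all_idx)`:

import numpy as np

# Minimal CG on the normal equations A^T A x = A^T b, logging both the
# normal-equation residual and the data-space misfit as the new print does.
def cg_normal_equations(A, b, x0, n_iter=200, tol=1e-8, verbose=True):
    x = x0.copy()
    r = A.T @ (b - A @ x)            # residual of the normal equations
    p = r.copy()
    rs_old = r @ r
    for i in range(n_iter):
        Ap = A.T @ (A @ p)
        alpha = rs_old / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if verbose and i % 10 == 0:
            v = np.mean((A @ x - b) ** 2)   # data-space MSE of the current iterate
            print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e},{np.sqrt(v):.3e}")
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# Example: recover x from noisy observations b = A @ x_true.
rng = np.random.default_rng(0)
A = rng.normal(size=(200, 50))
x_true = rng.normal(size=50)
b = A @ x_true + 0.01 * rng.normal(size=200)
x_hat = cg_normal_equations(A, b, np.zeros(50))

Tracking both numbers is useful because the normal-equation residual can shrink long before the data misfit stops improving when A is ill-conditioned.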
@@ -1155,7 +1156,7 @@ def plot_wave(wave,title="spectrum",unit="Amplitude",cmap="viridis"):
     plt.xlabel(r"$k_x$ [cycles / km]")
     plt.ylabel(r"$k_y$ [cycles / km]")
     plt.title(title)
-
+
 def lonlat_edges_from_ref(shape, ref_lon, ref_lat, dlon, dlat, anchor="center"):
     """
     Build lon/lat *edges* (H+1, W+1) for a regular, axis-aligned grid.
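
The hunk above is whitespace-only, but its context exposes the start of `lonlat_edges_from_ref`. For readers unfamiliar with edge grids: pcolormesh-style plotting needs (H+1, W+1) cell boundaries rather than (H, W) centers. A hypothetical standalone version of what the docstring describes, assuming the reference point is the center of pixel (0, 0); the real function's behavior (its `anchor` modes in particular) may differ:

import numpy as np

# Hedged sketch of an (H+1, W+1) edge-grid builder for a regular lon/lat grid;
# parameter names mirror lonlat_edges_from_ref, but the logic is an assumption.
def edges_from_center(shape, ref_lon, ref_lat, dlon, dlat):
    H, W = shape
    lon_edges = ref_lon - dlon / 2 + dlon * np.arange(W + 1)   # (W+1,)
    lat_edges = ref_lat - dlat / 2 + dlat * np.arange(H + 1)   # (H+1,)
    # meshgrid with default 'xy' indexing returns two (H+1, W+1) arrays
    return np.meshgrid(lon_edges, lat_edges)

lon2d, lat2d = edges_from_center((180, 360), ref_lon=0.25, ref_lat=-89.75, dlon=0.5, dlat=0.5)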
@@ -982,6 +982,9 @@ def fit(
     n_epoch: int = 10,
     view_epoch: int = 10,
     batch_size: int = 16,
+    x_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]] = None,
+    y_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]] = None,
+    save_model: bool = False,
     lr: float = 1e-3,
     weight_decay: float = 0.0,
     clip_grad_norm: Optional[float] = None,
@@ -1005,6 +1008,11 @@ def fit(
     device = model.runtime_device if hasattr(model, "runtime_device") else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
     model.to(device)
 
+    if save_model:
+        assert x_valid is not None, "If save_model=True x_valid should not be None"
+        assert y_valid is not None, "If save_model=True y_valid should not be None"
+        best_valid=1E30
+
     # Detect variable-length mode
     varlen_mode = isinstance(x_train, (list, tuple))
 
@@ -1197,6 +1205,15 @@ def fit(
         history.append(epoch_loss)
         # print every view_epoch logical step
         if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
-            print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
+            if x_valid is not None:
+                preds=model.predict(model.to_tensor(x_valid)).cpu().numpy()
+                valid_loss=np.mean((preds-y_valid)**2)
+                if save_model:
+                    if best_valid>valid_loss:
+                        best_valid=valid_loss
+                        torch.save({"model": model.state_dict(), "cfg": CFG}, os.path.join(CFG["save_dir"], "best.pt"))
+                print(f"[epoch {len(history)}] loss={epoch_loss:.4f} loss_valid={valid_loss:.4f}")
+            else:
+                print(f"[epoch {len(history)}] loss={epoch_loss:.4f}")
 
     return {"loss": history}
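
Taken together, the three `fit` hunks add optional validation monitoring (`x_valid`/`y_valid`, evaluated on the `view_epoch` cadence) and `save_model`, which checkpoints the weights whenever the validation loss improves. A self-contained sketch of that checkpoint-on-best-validation pattern; the toy linear model, data, and "best.pt" path are illustrative assumptions, not part of foscat:

import numpy as np
import torch
import torch.nn as nn

# Standalone illustration of the checkpoint-on-best-validation logic the
# fit() hunks implement; model, data, and save path are stand-ins.
model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x_train, y_train = torch.randn(64, 8), torch.randn(64, 1)
x_valid, y_valid = torch.randn(16, 8), torch.randn(16, 1)

best_valid = 1e30
for epoch in range(1, 101):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x_train), y_train)
    loss.backward()
    opt.step()
    if epoch % 10 == 0:                       # same cadence as view_epoch
        with torch.no_grad():
            preds = model(x_valid).cpu().numpy()
        valid_loss = float(np.mean((preds - y_valid.numpy()) ** 2))
        if best_valid > valid_loss:           # keep only the best weights
            best_valid = valid_loss
            torch.save({"model": model.state_dict()}, "best.pt")
        print(f"[epoch {epoch}] loss={loss.item():.4f} loss_valid={valid_loss:.4f}")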
foscat/healpix_vit_skip.py ADDED
@@ -0,0 +1,445 @@
+# healpix_vit_skip.py
+# HEALPix ViT U-Net with temporal encoders and Transformer-based skip fusion.
+# - Multi-level HEALPix pyramid using Foscat.SphericalStencil
+# - Per-level temporal encoding (sequence over T_in months) at encoder
+# - Decoder uses cross-attention to fuse upsampled features with encoder skips
+# - Double spherical convolution + GroupNorm + GELU at each encoder/decoder level
+
+from __future__ import annotations
+from typing import List, Optional, Literal
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import foscat.scat_cov as sc
+import foscat.SphericalStencil as ho
+
+
+class MLP(nn.Module):
+    def __init__(self, d: int, hidden_mult: int = 4, drop: float = 0.0):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.LayerNorm(d),
+            nn.Linear(d, hidden_mult * d),
+            nn.GELU(),
+            nn.Dropout(drop),
+            nn.Linear(hidden_mult * d, d),
+            nn.Dropout(drop),
+        )
+    def forward(self, x):
+        return self.net(x)
+
+
+class HealpixViTSkip(nn.Module):
+    def __init__(
+        self,
+        *,
+        in_nside: int,
+        n_chan_in: int,
+        level_dims: List[int],
+        depth_token: int,
+        num_heads_token: int,
+        cell_ids: np.ndarray,
+        task: Literal["regression","segmentation","global"] = "regression",
+        out_channels: int = 1,
+        mlp_ratio_token: float = 4.0,
+        KERNELSZ: int = 3,
+        gauge_type: Literal["cosmo","phi"] = "cosmo",
+        G: int = 1,
+        prefer_foscat_gpu: bool = True,
+        dropout: float = 0.1,
+        dtype: Literal["float32","float64"] = "float32",
+        pos_embed_per_level: bool = True,
+    ) -> None:
+        super().__init__()
+
+        self.in_nside = int(in_nside)
+        self.n_chan_in = int(n_chan_in)
+        self.level_dims = list(level_dims)
+        self.token_down = len(self.level_dims) - 1
+        assert self.token_down >= 0
+        self.C_fine = int(self.level_dims[0])
+        self.embed_dim = int(self.level_dims[-1])
+        self.depth_token = int(depth_token)
+        self.num_heads_token = int(num_heads_token)
+        self.mlp_ratio_token = float(mlp_ratio_token)
+        self.task = task
+        self.out_channels = int(out_channels)
+        self.KERNELSZ = int(KERNELSZ)
+        self.gauge_type = gauge_type
+        self.G = int(G)
+        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
+        self.dropout = float(dropout)
+        self.dtype = dtype
+        self.pos_embed_per_level = bool(pos_embed_per_level)
+
+        for d in self.level_dims:
+            if d % self.G != 0:
+                raise ValueError(f"All level_dims must be divisible by G={self.G}, got {d}.")
+        if self.embed_dim % self.num_heads_token != 0:
+            raise ValueError("embed_dim must be divisible by num_heads_token.")
+
+        if dtype == "float32":
+            self.np_dtype = np.float32
+            self.torch_dtype = torch.float32
+        else:
+            self.np_dtype = np.float64
+            self.torch_dtype = torch.float32
+
+        if cell_ids is None:
+            raise ValueError("cell_ids (finest) must be provided.")
+        self.cell_ids_fine = np.asarray(cell_ids)
+
+        if self.task == "segmentation":
+            self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
+        else:
+            self.final_activation = "none"
+
+        self.f = sc.funct(KERNELSZ=self.KERNELSZ)
+
+        # Build stencils
+        self.hconv_levels: List[ho.SphericalStencil] = []
+        self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
+        current_nside = self.in_nside
+        dummy = self.f.backend.bk_cast(np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype))
+        for _ in range(self.token_down):
+            hc = ho.SphericalStencil(current_nside, self.KERNELSZ, n_gauges=self.G,
+                                     gauge_type=self.gauge_type, cell_ids=self.level_cell_ids[-1],
+                                     dtype=self.torch_dtype)
+            self.hconv_levels.append(hc)
+            dummy, next_ids = hc.Down(dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True)
+            self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
+            current_nside //= 2
+
+        self.token_nside = current_nside if self.token_down > 0 else self.in_nside
+        self.token_cell_ids = self.level_cell_ids[-1]
+
+        self.hconv_token = ho.SphericalStencil(self.token_nside, self.KERNELSZ, n_gauges=self.G,
+                                               gauge_type=self.gauge_type, cell_ids=self.token_cell_ids, dtype=self.torch_dtype)
+        self.hconv_head = ho.SphericalStencil(self.in_nside, self.KERNELSZ, n_gauges=self.G,
+                                              gauge_type=self.gauge_type, cell_ids=self.cell_ids_fine, dtype=self.torch_dtype)
+
+        self.nsides_levels = [self.in_nside // (2**i) for i in range(self.token_down+1)]
+        self.ntokens_levels = [12 * n * n for n in self.nsides_levels]
+
+        # Patch embed (double conv)
+        fine_g = self.C_fine // self.G
+        self.pe_w1 = nn.Parameter(torch.empty(self.n_chan_in, fine_g, self.KERNELSZ*self.KERNELSZ))
+        nn.init.kaiming_uniform_(self.pe_w1.view(self.n_chan_in * fine_g, -1), a=np.sqrt(5))
+        self.pe_w2 = nn.Parameter(torch.empty(self.C_fine, fine_g, self.KERNELSZ*self.KERNELSZ))
+        nn.init.kaiming_uniform_(self.pe_w2.view(self.C_fine * fine_g, -1), a=np.sqrt(5))
+        self.pe_bn1 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
+        self.pe_bn2 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
+
+        # Encoder double convs
+        self.enc_w1 = nn.ParameterList()
+        self.enc_w2 = nn.ParameterList()
+        self.enc_bn1 = nn.ModuleList()
+        self.enc_bn2 = nn.ModuleList()
+        for i in range(self.token_down):
+            Cin = self.level_dims[i]
+            Cout = self.level_dims[i+1]
+            Cout_g = Cout // self.G
+            w1 = nn.Parameter(torch.empty(Cin, Cout_g, self.KERNELSZ*self.KERNELSZ))
+            nn.init.kaiming_uniform_(w1.view(Cin * Cout_g, -1), a=np.sqrt(5))
+            w2 = nn.Parameter(torch.empty(Cout, Cout_g, self.KERNELSZ*self.KERNELSZ))
+            nn.init.kaiming_uniform_(w2.view(Cout * Cout_g, -1), a=np.sqrt(5))
+            self.enc_w1.append(w1); self.enc_w2.append(w2)
+            self.enc_bn1.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
+            self.enc_bn2.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
+
+        # Temporal encoders per level (fine..pre-token)
+        self.temporal_encoders = nn.ModuleList([
+            nn.TransformerEncoder(
+                nn.TransformerEncoderLayer(
+                    d_model=self.level_dims[i],
+                    nhead=max(1, min(8, self.level_dims[i] // 64)),
+                    dim_feedforward=2*self.level_dims[i],
+                    dropout=self.dropout,
+                    activation='gelu',
+                    batch_first=True,
+                    norm_first=True,
+                ),
+                num_layers=2,
+            )
+            for i in range(self.token_down)
+        ])
+
+        # Token-level Transformer
+        self.n_tokens = int(self.token_cell_ids.shape[0])
+        self.pos_token = nn.Parameter(torch.zeros(1, self.n_tokens, self.embed_dim))
+        nn.init.trunc_normal_(self.pos_token, std=0.02)
+        enc_layer = nn.TransformerEncoderLayer(
+            d_model=self.embed_dim,
+            nhead=self.num_heads_token,
+            dim_feedforward=int(self.embed_dim * self.mlp_ratio_token),
+            dropout=self.dropout,
+            activation='gelu',
+            batch_first=True,
+            norm_first=True,
+        )
+        self.encoder_token = nn.TransformerEncoder(enc_layer, num_layers=self.depth_token)
+
+        # Decoder fusion modules per level (cross-attention)
+        self.dec_q = nn.ModuleList()
+        self.dec_k = nn.ModuleList()
+        self.dec_v = nn.ModuleList()
+        self.dec_attn = nn.ModuleList()
+        self.dec_mlp = nn.ModuleList()
+        self.level_pos = nn.ParameterList() if self.pos_embed_per_level else None
+        for i in range(self.token_down, 0, -1):
+            Cfine = self.level_dims[i-1]
+            d_fuse = Cfine
+            self.dec_q.append(nn.Linear(Cfine, d_fuse))
+            self.dec_k.append(nn.Linear(Cfine, d_fuse))
+            self.dec_v.append(nn.Linear(Cfine, d_fuse))
+            self.dec_attn.append(nn.MultiheadAttention(embed_dim=d_fuse, num_heads=max(1, min(8, d_fuse // 64)), batch_first=True))
+            self.dec_mlp.append(MLP(d_fuse, hidden_mult=4, drop=self.dropout))
+            if self.pos_embed_per_level:
+                n_tok_i = self.ntokens_levels[i-1]
+                p = nn.Parameter(torch.zeros(1, n_tok_i, d_fuse))
+                nn.init.trunc_normal_(p, std=0.02)
+                self.level_pos.append(p)
+
+        # Decoder refinement double convs
+        self.dec_refine_w1 = nn.ParameterList()
+        self.dec_refine_w2 = nn.ParameterList()
+        self.dec_refine_bn1 = nn.ModuleList()
+        self.dec_refine_bn2 = nn.ModuleList()
+        for i in range(self.token_down, 0, -1):
+            Cfine = self.level_dims[i-1]
+            Cfine_g = Cfine // self.G
+            w1 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
+            nn.init.kaiming_uniform_(w1.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
+            w2 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
+            nn.init.kaiming_uniform_(w2.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
+            self.dec_refine_w1.append(w1); self.dec_refine_w2.append(w2)
+            self.dec_refine_bn1.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
+            self.dec_refine_bn2.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
+
+        # Head
+        if self.task == "global":
+            self.global_head = nn.Linear(self.embed_dim, self.out_channels)
+        else:
+            if self.out_channels % self.G != 0:
+                raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
+            out_g = self.out_channels // self.G
+            self.head_w = nn.Parameter(torch.empty(self.C_fine, out_g, self.KERNELSZ*self.KERNELSZ))
+            nn.init.kaiming_uniform_(self.head_w.view(self.C_fine * out_g, -1), a=np.sqrt(5))
+            self.head_bn = nn.GroupNorm(num_groups=min(8, self.out_channels if self.out_channels>1 else 1),
+                                        num_channels=self.out_channels) if self.task=="segmentation" else None
+
+        pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.runtime_device = self._probe_and_set_runtime_device(pref)
+
+    def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
+        for name, val in list(vars(hc).items()):
+            try:
+                if torch.is_tensor(val):
+                    setattr(hc, name, val.to(device))
+                elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
+                    setattr(hc, name, type(val)([v.to(device) for v in val]))
+            except Exception:
+                pass
+
+    @torch.no_grad()
+    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
+        if preferred.type == "cuda":
+            try:
+                super().to(preferred)
+                for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
+                    self._move_hc(hc, preferred)
+                npix0 = int(self.cell_ids_fine.shape[0])
+                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
+                hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
+                y_try = hc0.Convol_torch(x_try, self.pe_w1, cell_ids=self.cell_ids_fine)
+                _ = (y_try if torch.is_tensor(y_try) else torch.as_tensor(y_try, device=preferred)).sum().item()
+                self._foscat_device = preferred
+                return preferred
+            except Exception:
+                pass
+        cpu = torch.device("cpu")
+        super().to(cpu)
+        for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
+            self._move_hc(hc, cpu)
+        self._foscat_device = cpu
+        return cpu
+
+    def _as_tensor_batch(self, x):
+        if isinstance(x, list):
+            if len(x) == 1:
+                t = x[0]
+                return t.unsqueeze(0) if t.dim() == 2 else t
+            raise ValueError("Variable-length list not supported here; pass a tensor.")
+        return x
+
+    def _to_numpy_ids(self, ids):
+        if torch.is_tensor(ids):
+            return ids.detach().cpu().numpy()
+        return np.asarray(ids)
+
+    def _patch_embed_fine(self, x_t: torch.Tensor) -> torch.Tensor:
+        hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
+        z = hc0.Convol_torch(x_t, self.pe_w1, cell_ids=self.cell_ids_fine)
+        z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
+        z = self.pe_bn1(z); z = F.gelu(z)
+        z = hc0.Convol_torch(z, self.pe_w2, cell_ids=self.cell_ids_fine)
+        z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
+        z = self.pe_bn2(z); z = F.gelu(z)
+        return z
+
+    def forward(self, x: torch.Tensor, runtime_ids: Optional[np.ndarray] = None) -> torch.Tensor:
+        if x.dim() != 4:
+            raise ValueError("Expected input shape (B, T_in, C_in, Npix)")
+        B, T_in, C_in, Nf = x.shape
+        if C_in != self.n_chan_in:
+            raise ValueError(f"Expected n_chan_in={self.n_chan_in}, got {C_in}")
+        x = x.to(self.runtime_device)
+
+        fine_ids_runtime = self.cell_ids_fine if runtime_ids is None else self._to_numpy_ids(runtime_ids)
+        ids_chain = [np.asarray(fine_ids_runtime)]
+        nside_tmp = self.in_nside
+        _dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
+        for hc in self.hconv_levels:
+            _dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
+            ids_chain.append(self.f.backend.to_numpy(_next))
+            nside_tmp //= 2
+
+        # Encoder histories per level
+        l_hist: List[torch.Tensor] = []
+        l_ids: List[np.ndarray] = []
+
+        feats_fine = []
+        for t in range(T_in):
+            zt = self._patch_embed_fine(x[:, t, :, :])
+            feats_fine.append(zt.unsqueeze(1))
+        feats_fine = torch.cat(feats_fine, dim=1)  # (B, T_in, C_fine, N_fine)
+        l_hist.append(feats_fine)
+        l_ids.append(self.cell_ids_fine)
+
+        current_nside = self.in_nside
+        l_data_hist = feats_fine
+        for i, hc in enumerate(self.hconv_levels):
+            Cin = self.level_dims[i]
+            Cout = self.level_dims[i+1]
+            w1, w2 = self.enc_w1[i], self.enc_w2[i]
+            feats_next = []
+            for t in range(T_in):
+                zt = l_data_hist[:, t, :, :]
+                zt = hc.Convol_torch(zt, w1, cell_ids=l_ids[-1])
+                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
+                zt = self.enc_bn1[i](zt); zt = F.gelu(zt)
+                zt = hc.Convol_torch(zt, w2, cell_ids=l_ids[-1])
+                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
+                zt = self.enc_bn2[i](zt); zt = F.gelu(zt)
+                feats_next.append(zt.unsqueeze(1))
+            feats_next = torch.cat(feats_next, dim=1)  # (B, T_in, Cout, N_i)
+
+            feats_down = []
+            next_ids_list = None
+            for t in range(T_in):
+                zt, next_ids = hc.Down(feats_next[:, t, :, :], cell_ids=l_ids[-1], nside=current_nside, max_poll=True)
+                zt = self._as_tensor_batch(zt)
+                feats_down.append(zt.unsqueeze(1))
+                next_ids_list = next_ids
+            feats_down = torch.cat(feats_down, dim=1)  # (B, T_in, Cout, N_{i+1})
+
+            l_hist.append(feats_down)
+            l_ids.append(self.f.backend.to_numpy(next_ids_list))
+            l_data_hist = feats_down
+            current_nside //= 2
+
+        # Temporal encoder on skips (levels 0..token_down-1)
+        skips: List[torch.Tensor] = []
+        for i in range(self.token_down):
+            Bx, Tx, Cx, Nx = l_hist[i].shape
+            z = l_hist[i].permute(0, 3, 1, 2).reshape(Bx*Nx, Tx, Cx)
+            z = self.temporal_encoders[i](z)
+            z = z.mean(dim=1)
+            H_i = z.view(Bx, Nx, Cx).permute(0, 2, 1).contiguous()
+            skips.append(H_i)
+
+        # Token-level transformer (spatial)
+        x_tok_hist = l_hist[-1]  # (B, T_in, E, Ntok)
+        x_tok = x_tok_hist.mean(dim=1)  # (B, E, Ntok) (could add temporal encoder here as well)
+        seq = x_tok.permute(0, 2, 1) + self.pos_token[:, :x_tok.shape[2], :]
+        seq = self.encoder_token(seq)
+        y = seq.permute(0, 2, 1)  # (B, E, Ntok)
+
+        if self.task == "global":
+            g = seq.mean(dim=1)
+            return self.global_head(g)
+
+        # Decoder: Up + cross-attn fusion + double conv refinement
+        dec_idx = 0
+        for i in range(self.token_down, 0, -1):
+            coarse_ids = ids_chain[i]
+            fine_ids = ids_chain[i-1]
+            source_ns = self.in_nside // (2 ** i)
+            fine_ns = self.in_nside // (2 ** (i-1))
+            Cfine = self.level_dims[i-1]
+
+            op_fine = self.hconv_head if fine_ns == self.in_nside else self.hconv_levels[self.nsides_levels.index(fine_ns)]
+
+            y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
+            y_up = self._as_tensor_batch(y_up if torch.is_tensor(y_up) else torch.as_tensor(y_up, device=self.runtime_device))  # (B, Cfine, N)
+
+            skip_i = skips[i-1]  # (B, Cfine, N)
+            q = self.dec_q[dec_idx](y_up.permute(0,2,1))
+            k = self.dec_k[dec_idx](skip_i.permute(0,2,1))
+            v = self.dec_v[dec_idx](skip_i.permute(0,2,1))
+            if self.pos_embed_per_level:
+                pos = self.level_pos[dec_idx][:, :q.shape[1], :]
+                q = q + pos; k = k + pos
+            z, _ = self.dec_attn[dec_idx](q, k, v)
+            z = self.dec_mlp[dec_idx](z)
+            z = z.permute(0,2,1).contiguous()  # (B, Cfine, N)
+
+            z = op_fine.Convol_torch(z, self.dec_refine_w1[dec_idx], cell_ids=fine_ids)
+            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
+            z = self.dec_refine_bn1[dec_idx](z); z = F.gelu(z)
+            z = op_fine.Convol_torch(z, self.dec_refine_w2[dec_idx], cell_ids=fine_ids)
+            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
+            z = self.dec_refine_bn2[dec_idx](z); z = F.gelu(z)
+
+            y = z
+            dec_idx += 1
+
+        y = self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids_runtime)
+        y = self._as_tensor_batch(y if torch.is_tensor(y) else torch.as_tensor(y, device=self.runtime_device))
+        if self.task == "segmentation" and self.head_bn is not None:
+            y = self.head_bn(y)
+        if self.final_activation == "sigmoid":
+            y = torch.sigmoid(y)
+        elif self.final_activation == "softmax":
+            y = torch.softmax(y, dim=1)
+        return y
+
+
+if __name__ == "__main__":
+    in_nside = 4
+    npix = 12 * in_nside * in_nside
+    cell_ids = np.arange(npix, dtype=np.int64)
+
+    B, T_in, Cin = 2, 3, 4
+    x = torch.randn(B, T_in, Cin, npix)
+
+    model = HealpixViTSkip(
+        in_nside=in_nside,
+        n_chan_in=Cin,
+        level_dims=[64, 96, 128],
+        depth_token=2,
+        num_heads_token=4,
+        cell_ids=cell_ids,
+        task="regression",
+        out_channels=1,
+        KERNELSZ=3,
+        G=1,
+        dropout=0.1,
+    ).eval()
+
+    with torch.no_grad():
+        y = model(x)
+        print("Output:", tuple(y.shape))
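
To close the loop on the new module: a hedged sketch of a single supervised training step for `HealpixViTSkip`, reusing the `__main__` configuration above with a plain MSE objective. The import path, optimizer choice, and target shape are illustrative assumptions (in practice training would go through the extended `fit` shown earlier in this diff):

import numpy as np
import torch

from foscat.healpix_vit_skip import HealpixViTSkip   # assumed module path

# Mirror the __main__ configuration above.
in_nside = 4
npix = 12 * in_nside * in_nside
model = HealpixViTSkip(
    in_nside=in_nside, n_chan_in=4, level_dims=[64, 96, 128],
    depth_token=2, num_heads_token=4,
    cell_ids=np.arange(npix, dtype=np.int64),
    task="regression", out_channels=1,
)

opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(2, 3, 4, npix)      # (B, T_in, C_in, Npix)
target = torch.randn(2, 1, npix)    # (B, out_channels, Npix)

# One optimization step: forward, MSE loss, backward, update.
pred = model(x)
loss = torch.nn.functional.mse_loss(pred, target)
opt.zero_grad()
loss.backward()
opt.step()
print(f"train step loss={loss.item():.4f}")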