foscat 2025.9.4__py3-none-any.whl → 2025.10.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
6
  import foscat.HealSpline as HS
7
7
  from scipy.interpolate import griddata
8
8
 
9
- TMPFILE_VERSION = "V9_0"
9
+ TMPFILE_VERSION = "V10_0"
10
10
 
11
11
 
12
12
  class FoCUS:
@@ -36,7 +36,7 @@ class FoCUS:
36
36
  mpi_rank=0
37
37
  ):
38
38
 
39
- self.__version__ = "2025.09.4"
39
+ self.__version__ = "2025.10.2"
40
40
  # P00 coeff for normalization for scat_cov
41
41
  self.TMPFILE_VERSION = TMPFILE_VERSION
42
42
  self.P1_dic = None
@@ -1488,7 +1488,7 @@ class FoCUS:
1488
1488
  if l_kernel == 5:
1489
1489
  pw = 0.5
1490
1490
  pw2 = 0.5
1491
- threshold = 2e-4
1491
+ threshold = 2e-5
1492
1492
 
1493
1493
  elif l_kernel == 3:
1494
1494
  pw = 1.0 / np.sqrt(2)
@@ -1498,7 +1498,7 @@ class FoCUS:
1498
1498
  elif l_kernel == 7:
1499
1499
  pw = 0.5
1500
1500
  pw2 = 0.25
1501
- threshold = 4e-5
1501
+ threshold = 2e-5
1502
1502
 
1503
1503
  import foscat.SphericalStencil as hs
1504
1504
  import torch
@@ -1517,14 +1517,19 @@ class FoCUS:
1517
1517
  n_gauges=self.NORIENT,
1518
1518
  gauge_type='cosmo')
1519
1519
 
1520
- xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ*self.KERNELSZ)
1520
+ xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ,self.KERNELSZ)
1521
1521
 
1522
- wwr=hconvol.to_tensor((np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ))
1522
+ wwr=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
1523
1523
  wwr-=wwr.mean()
1524
- wwi=hconvol.to_tensor((np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ))
1524
+ wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
1525
1525
  wwi-=wwi.mean()
1526
- wwr/=(abs(wwr+1J*wwi)).sum()
1527
- wwi/=(abs(wwr+1J*wwi)).sum()
1526
+ amp=np.sum(abs(wwr+1J*wwi))
1527
+
1528
+ wwr/=amp
1529
+ wwi/=amp
1530
+
1531
+ wwr=hconvol.to_tensor(wwr)
1532
+ wwi=hconvol.to_tensor(wwi)
1528
1533
 
1529
1534
  wavr,indice,mshape=hconvol.make_matrix(wwr)
1530
1535
  wavi,indice,mshape=hconvol.make_matrix(wwi)
foscat/Plot.py CHANGED
@@ -959,8 +959,9 @@ def conjugate_gradient_normal_equation(data, x0, www, all_idx,
959
959
 
960
960
  rs_new = np.dot(r, r)
961
961
 
962
- if verbose and i % 50 == 0:
963
- print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e}")
962
+ if verbose and i % 10 == 0:
963
+ v=np.mean((LP(p, www, all_idx)-data)**2)
964
+ print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e},{np.sqrt(v):.3e}")
964
965
 
965
966
  if np.sqrt(rs_new) < tol:
966
967
  if verbose:
@@ -1065,28 +1066,69 @@ def _spectrum_polar_to_cartesian_core(
1065
1066
  valid = np.isfinite(radial_index)
1066
1067
  try:
1067
1068
  from scipy.ndimage import map_coordinates
1068
- order = 3 if method.lower() == "bicubic" else 1
1069
- coords = np.vstack([radial_index.ravel(), angular_index.ravel()])
1070
- eps = 1e-6
1071
- coords[0, :] = np.where(np.isfinite(coords[0, :]),
1072
- np.clip(coords[0, :], 0.0+eps, (ns-1)-eps), 0.0)
1073
- sampled = map_coordinates(
1074
- w, coords, order=order, mode="wrap", cval=fill_value, prefilter=True
1075
- ).reshape(n_pixels, n_pixels)
1076
- img = np.where(valid, sampled, fill_value)
1069
+
1070
+ if method.lower() == "bicubic":
1071
+ # ===== Bicubic with circular angular wrap =====
1072
+ # Pad the angular axis by K columns on both sides so that the cubic kernel
1073
+ # has valid neighbors across the 0°/360° seam. K=2 is enough for a cubic kernel.
1074
+ K = 2
1075
+ # w has shape (Nscale, Norient)
1076
+ w_pad = np.concatenate([w[:, -K:], w, w[:, :K]], axis=1) # (ns, no+2K)
1077
+
1078
+ # Build coordinates for map_coordinates on the padded array:
1079
+ # - radial index stays the same, clipped to [0, ns-1] (no wrap)
1080
+ # - angular index is shifted by +K and wrapped into [0, no) before querying
1081
+ order = 3
1082
+ coords = np.vstack([radial_index.ravel(), angular_index.ravel()])
1083
+
1084
+ # clip the radial coordinate to the valid [eps, ns-1-eps] band
1085
+ eps = 1e-6
1086
+ coords[0, :] = np.where(
1087
+ np.isfinite(coords[0, :]),
1088
+ np.clip(coords[0, :], 0.0 + eps, (ns - 1) - eps),
1089
+ 0.0,
1090
+ )
1091
+
1092
+ # wrap angular coordinate to [0, no) then shift by +K to address the padded array
1093
+ ang = np.mod(coords[1, :], float(no)) + K # [K, no+K)
1094
+ coords[1, :] = ang
1095
+
1096
+ # Now sample the padded array; 'nearest' mode is fine thanks to explicit padding
1097
+ sampled = map_coordinates(
1098
+ w_pad, coords, order=order, mode="nearest", cval=fill_value, prefilter=True
1099
+ ).reshape(n_pixels, n_pixels)
1100
+
1101
+ img = np.where(valid, sampled, fill_value)
1102
+
1103
+ else:
1104
+ # ===== Bilinear (or other) without special padding =====
1105
+ order = 1
1106
+ coords = np.vstack([radial_index.ravel(), angular_index.ravel()])
1107
+ eps = 1e-6
1108
+ coords[0, :] = np.where(
1109
+ np.isfinite(coords[0, :]),
1110
+ np.clip(coords[0, :], 0.0 + eps, (ns - 1) - eps),
1111
+ 0.0,
1112
+ )
1113
+ # For non-bicubic, SciPy's mode="wrap" is sufficient on the angular axis
1114
+ sampled = map_coordinates(
1115
+ w, coords, order=order, mode="wrap", cval=fill_value, prefilter=True
1116
+ ).reshape(n_pixels, n_pixels)
1117
+ img = np.where(valid, sampled, fill_value)
1118
+
1077
1119
  except Exception:
1078
- # bilinear fallback
1120
+ # ---- Vectorized bilinear fallback with explicit angular wrap ----
1079
1121
  r_idx = np.floor(radial_index).astype(np.int64)
1080
1122
  t_idx = np.floor(angular_index).astype(np.int64)
1081
- r_idx = np.clip(r_idx, 0, ns-2)
1123
+ r_idx = np.clip(r_idx, 0, ns - 2)
1082
1124
  t0 = np.mod(t_idx, no)
1083
- t1 = np.mod(t_idx+1, no)
1125
+ t1 = np.mod(t_idx + 1, no)
1084
1126
  tr = np.clip(radial_index - r_idx, 0.0, 1.0)
1085
1127
  ta = np.clip(angular_index - t_idx, 0.0, 1.0)
1086
1128
  f00 = w[r_idx, t0]
1087
1129
  f01 = w[r_idx, t1]
1088
- f10 = w[r_idx+1, t0]
1089
- f11 = w[r_idx+1, t1]
1130
+ f10 = w[r_idx + 1, t0]
1131
+ f11 = w[r_idx + 1, t1]
1090
1132
  g0 = (1.0 - ta) * f00 + ta * f01
1091
1133
  g1 = (1.0 - ta) * f10 + ta * f11
1092
1134
  img = (1.0 - tr) * g0 + tr * g1
@@ -1114,7 +1156,7 @@ def plot_wave(wave,title="spectrum",unit="Amplitude",cmap="viridis"):
1114
1156
  plt.xlabel(r"$k_x$ [cycles / km]")
1115
1157
  plt.ylabel(r"$k_y$ [cycles / km]")
1116
1158
  plt.title(title)
1117
-
1159
+
1118
1160
  def lonlat_edges_from_ref(shape, ref_lon, ref_lat, dlon, dlat, anchor="center"):
1119
1161
  """
1120
1162
  Build lon/lat *edges* (H+1, W+1) for a regular, axis-aligned grid.
@@ -982,6 +982,9 @@ def fit(
982
982
  n_epoch: int = 10,
983
983
  view_epoch: int = 10,
984
984
  batch_size: int = 16,
985
+ x_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
986
+ y_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
987
+ save_model: bool = False,
985
988
  lr: float = 1e-3,
986
989
  weight_decay: float = 0.0,
987
990
  clip_grad_norm: Optional[float] = None,
@@ -1005,6 +1008,11 @@ def fit(
1005
1008
  device = model.runtime_device if hasattr(model, "runtime_device") else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
1006
1009
  model.to(device)
1007
1010
 
1011
+ if save_model:
1012
+ assert x_valid is None, "If save_mode=True x_valid should not be None"
1013
+ assert y_valid is None, "If save_mode=True y_valid should not be None"
1014
+ best_valid=1E30
1015
+
1008
1016
  # Detect variable-length mode
1009
1017
  varlen_mode = isinstance(x_train, (list, tuple))
1010
1018
 
@@ -1197,6 +1205,14 @@ def fit(
1197
1205
  history.append(epoch_loss)
1198
1206
  # print every view_epoch logical step
1199
1207
  if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
1200
- print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
1208
+ if x_valid is not None:
1209
+ preds=model.predict(model.to_tensor(x_valid)).cpu().numpy()
1210
+ valid_loss=np.mean((preds-y_valid)**2)
1211
+ if save_model:
1212
+ if best_valid>valid_loss:
1213
+ torch.save({"model": self.state_dict(), "cfg": CFG}, os.path.join(CFG["save_dir"], "best.pt"))
1214
+ print(f"[epoch {len(history)}] loss={epoch_loss:.4f} loss_valid={valid_loss:.4f}")
1215
+ else:
1216
+ print(f"[epoch {len(history)}] loss={epoch_loss:.4f}")
1201
1217
 
1202
1218
  return {"loss": history}
@@ -0,0 +1,445 @@
1
+ # healpix_vit_skip.py
2
+ # HEALPix ViT U-Net with temporal encoders and Transformer-based skip fusion.
3
+ # - Multi-level HEALPix pyramid using Foscat.SphericalStencil
4
+ # - Per-level temporal encoding (sequence over T_in months) at encoder
5
+ # - Decoder uses cross-attention to fuse upsampled features with encoder skips
6
+ # - Double spherical convolution + GroupNorm + GELU at each encoder/decoder level
7
+
8
+ from __future__ import annotations
9
+ from typing import List, Optional, Literal
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ import foscat.scat_cov as sc
17
+ import foscat.SphericalStencil as ho
18
+
19
+
20
class MLP(nn.Module):
    """Pre-normalised two-layer feed-forward block that preserves the width.

    The input is layer-normalised, expanded to ``hidden_mult * d`` features,
    passed through GELU and dropout, projected back to ``d`` features and
    passed through dropout once more.

    Args:
        d: Size of the last (feature) dimension of the input and output.
        hidden_mult: Expansion factor of the hidden layer.
        drop: Dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, d: int, hidden_mult: int = 4, drop: float = 0.0):
        super().__init__()
        # Stages are created in execution order so that the sub-module
        # indices (and therefore the state_dict keys) stay net.0 .. net.5.
        stages = [
            nn.LayerNorm(d),
            nn.Linear(d, hidden_mult * d),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_mult * d, d),
            nn.Dropout(drop),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the feed-forward stack to ``x`` along its last dimension."""
        out = self.net(x)
        return out
33
+
34
+
35
class HealpixViTSkip(nn.Module):
    """HEALPix encoder/decoder with temporal encoders and attention skip fusion.

    The network is built from ``len(level_dims)`` resolution levels:

    * a double spherical convolution embeds every input time step at the
      finest level;
    * each encoder level applies a double spherical convolution
      (GroupNorm + GELU) followed by ``SphericalStencil.Down``;
    * a Transformer encoder runs over the time axis of every skip level and
      the result is averaged over time;
    * a Transformer encoder runs over the time-averaged coarsest tokens;
    * each decoder level upsamples with ``SphericalStencil.Up``, fuses the
      result with the matching skip through multi-head cross-attention
      (queries from the upsampled features, keys/values from the skip),
      applies an MLP and a double spherical convolution;
    * the head is a spherical convolution (``"regression"`` /
      ``"segmentation"``) or a linear layer on the token mean (``"global"``).

    Args:
        in_nside: HEALPix nside of the finest level.
        n_chan_in: Number of input channels per time step.
        level_dims: Channel width of every level, finest first. The last
            entry is the token embedding width.
        depth_token: Number of layers of the token-level Transformer.
        num_heads_token: Number of heads of the token-level Transformer.
        cell_ids: Cell indices of the finest level (required).
        task: ``"regression"``, ``"segmentation"`` or ``"global"``.
        out_channels: Number of output channels.
        mlp_ratio_token: Feed-forward expansion of the token Transformer.
        KERNELSZ: Side of the spherical convolution kernel.
        gauge_type: Gauge type forwarded to ``SphericalStencil``.
        G: Number of gauges; every level width (and ``out_channels`` for
            non-global tasks) must be divisible by it.
        prefer_foscat_gpu: Stored on the instance; not otherwise used by the
            code in this class.
        dropout: Dropout probability of the Transformer/MLP blocks.
        dtype: ``"float32"`` or ``"float64"``; selects the NumPy and torch
            dtypes used for stencils, parameters and inputs.
        pos_embed_per_level: Whether to learn a positional embedding for
            every decoder level.

    Raises:
        ValueError: On inconsistent channel/gauge/head configuration or when
            ``cell_ids`` is ``None``.
    """

    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        level_dims: List[int],
        depth_token: int,
        num_heads_token: int,
        cell_ids: np.ndarray,
        task: Literal["regression", "segmentation", "global"] = "regression",
        out_channels: int = 1,
        mlp_ratio_token: float = 4.0,
        KERNELSZ: int = 3,
        gauge_type: Literal["cosmo", "phi"] = "cosmo",
        G: int = 1,
        prefer_foscat_gpu: bool = True,
        dropout: float = 0.1,
        dtype: Literal["float32", "float64"] = "float32",
        pos_embed_per_level: bool = True,
    ) -> None:
        super().__init__()

        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.level_dims = list(level_dims)
        self.token_down = len(self.level_dims) - 1
        assert self.token_down >= 0
        self.C_fine = int(self.level_dims[0])
        self.embed_dim = int(self.level_dims[-1])
        self.depth_token = int(depth_token)
        self.num_heads_token = int(num_heads_token)
        self.mlp_ratio_token = float(mlp_ratio_token)
        self.task = task
        self.out_channels = int(out_channels)
        self.KERNELSZ = int(KERNELSZ)
        self.gauge_type = gauge_type
        self.G = int(G)
        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
        self.dropout = float(dropout)
        self.dtype = dtype
        self.pos_embed_per_level = bool(pos_embed_per_level)

        for d in self.level_dims:
            if d % self.G != 0:
                raise ValueError(f"All level_dims must be divisible by G={self.G}, got {d}.")
        if self.embed_dim % self.num_heads_token != 0:
            raise ValueError("embed_dim must be divisible by num_heads_token.")

        if dtype == "float32":
            self.np_dtype = np.float32
            self.torch_dtype = torch.float32
        else:
            self.np_dtype = np.float64
            # The torch dtype must follow the requested precision; it used to
            # stay float32 here, silently ignoring dtype="float64".
            self.torch_dtype = torch.float64

        if cell_ids is None:
            raise ValueError("cell_ids (finest) must be provided.")
        self.cell_ids_fine = np.asarray(cell_ids)

        if self.task == "segmentation":
            self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
        else:
            self.final_activation = "none"

        self.f = sc.funct(KERNELSZ=self.KERNELSZ)
        ksq = self.KERNELSZ * self.KERNELSZ

        # Build stencils: one per encoder level, plus the cell ids of every
        # level obtained by pushing a dummy map through Down().
        self.hconv_levels: List[ho.SphericalStencil] = []
        self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
        current_nside = self.in_nside
        dummy = self.f.backend.bk_cast(
            np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype)
        )
        for _ in range(self.token_down):
            hc = ho.SphericalStencil(current_nside, self.KERNELSZ, n_gauges=self.G,
                                     gauge_type=self.gauge_type, cell_ids=self.level_cell_ids[-1],
                                     dtype=self.torch_dtype)
            self.hconv_levels.append(hc)
            dummy, next_ids = hc.Down(dummy, cell_ids=self.level_cell_ids[-1],
                                      nside=current_nside, max_poll=True)
            self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
            current_nside //= 2

        self.token_nside = current_nside if self.token_down > 0 else self.in_nside
        self.token_cell_ids = self.level_cell_ids[-1]

        self.hconv_token = ho.SphericalStencil(self.token_nside, self.KERNELSZ, n_gauges=self.G,
                                               gauge_type=self.gauge_type,
                                               cell_ids=self.token_cell_ids, dtype=self.torch_dtype)
        self.hconv_head = ho.SphericalStencil(self.in_nside, self.KERNELSZ, n_gauges=self.G,
                                              gauge_type=self.gauge_type,
                                              cell_ids=self.cell_ids_fine, dtype=self.torch_dtype)

        self.nsides_levels = [self.in_nside // (2**i) for i in range(self.token_down + 1)]
        self.ntokens_levels = [12 * n * n for n in self.nsides_levels]

        # Patch embed (double conv)
        fine_g = self.C_fine // self.G
        self.pe_w1 = nn.Parameter(torch.empty(self.n_chan_in, fine_g, ksq))
        nn.init.kaiming_uniform_(self.pe_w1.view(self.n_chan_in * fine_g, -1), a=np.sqrt(5))
        self.pe_w2 = nn.Parameter(torch.empty(self.C_fine, fine_g, ksq))
        nn.init.kaiming_uniform_(self.pe_w2.view(self.C_fine * fine_g, -1), a=np.sqrt(5))
        fine_groups = self._largest_divisor(self.C_fine, 8)
        self.pe_bn1 = nn.GroupNorm(num_groups=fine_groups, num_channels=self.C_fine)
        self.pe_bn2 = nn.GroupNorm(num_groups=fine_groups, num_channels=self.C_fine)

        # Encoder double convs
        self.enc_w1 = nn.ParameterList()
        self.enc_w2 = nn.ParameterList()
        self.enc_bn1 = nn.ModuleList()
        self.enc_bn2 = nn.ModuleList()
        for i in range(self.token_down):
            Cin = self.level_dims[i]
            Cout = self.level_dims[i + 1]
            Cout_g = Cout // self.G
            w1 = nn.Parameter(torch.empty(Cin, Cout_g, ksq))
            nn.init.kaiming_uniform_(w1.view(Cin * Cout_g, -1), a=np.sqrt(5))
            w2 = nn.Parameter(torch.empty(Cout, Cout_g, ksq))
            nn.init.kaiming_uniform_(w2.view(Cout * Cout_g, -1), a=np.sqrt(5))
            self.enc_w1.append(w1)
            self.enc_w2.append(w2)
            groups = self._largest_divisor(Cout, 8)
            self.enc_bn1.append(nn.GroupNorm(num_groups=groups, num_channels=Cout))
            self.enc_bn2.append(nn.GroupNorm(num_groups=groups, num_channels=Cout))

        # Temporal encoders per level (fine..pre-token)
        self.temporal_encoders = nn.ModuleList([
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=self.level_dims[i],
                    # nhead must divide d_model; keep the historical target
                    # (d // 64, capped at 8) whenever it already does.
                    nhead=self._largest_divisor(
                        self.level_dims[i], max(1, min(8, self.level_dims[i] // 64))
                    ),
                    dim_feedforward=2 * self.level_dims[i],
                    dropout=self.dropout,
                    activation='gelu',
                    batch_first=True,
                    norm_first=True,
                ),
                num_layers=2,
            )
            for i in range(self.token_down)
        ])

        # Token-level Transformer
        self.n_tokens = int(self.token_cell_ids.shape[0])
        self.pos_token = nn.Parameter(torch.zeros(1, self.n_tokens, self.embed_dim))
        nn.init.trunc_normal_(self.pos_token, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads_token,
            dim_feedforward=int(self.embed_dim * self.mlp_ratio_token),
            dropout=self.dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.encoder_token = nn.TransformerEncoder(enc_layer, num_layers=self.depth_token)

        # Decoder fusion modules per level (cross-attention)
        self.dec_q = nn.ModuleList()
        self.dec_k = nn.ModuleList()
        self.dec_v = nn.ModuleList()
        self.dec_attn = nn.ModuleList()
        self.dec_mlp = nn.ModuleList()
        self.level_pos = nn.ParameterList() if self.pos_embed_per_level else None
        for i in range(self.token_down, 0, -1):
            Ccoarse = self.level_dims[i]
            Cfine = self.level_dims[i - 1]
            d_fuse = Cfine
            # Queries come from the upsampled coarse features, which still
            # carry the coarse level's channel count (Up() has no weights),
            # so the query projection maps Ccoarse -> d_fuse.
            self.dec_q.append(nn.Linear(Ccoarse, d_fuse))
            self.dec_k.append(nn.Linear(Cfine, d_fuse))
            self.dec_v.append(nn.Linear(Cfine, d_fuse))
            self.dec_attn.append(
                nn.MultiheadAttention(
                    embed_dim=d_fuse,
                    num_heads=self._largest_divisor(d_fuse, max(1, min(8, d_fuse // 64))),
                    batch_first=True,
                )
            )
            self.dec_mlp.append(MLP(d_fuse, hidden_mult=4, drop=self.dropout))
            if self.pos_embed_per_level:
                n_tok_i = self.ntokens_levels[i - 1]
                p = nn.Parameter(torch.zeros(1, n_tok_i, d_fuse))
                nn.init.trunc_normal_(p, std=0.02)
                self.level_pos.append(p)

        # Decoder refinement double convs
        self.dec_refine_w1 = nn.ParameterList()
        self.dec_refine_w2 = nn.ParameterList()
        self.dec_refine_bn1 = nn.ModuleList()
        self.dec_refine_bn2 = nn.ModuleList()
        for i in range(self.token_down, 0, -1):
            Cfine = self.level_dims[i - 1]
            Cfine_g = Cfine // self.G
            w1 = nn.Parameter(torch.empty(Cfine, Cfine_g, ksq))
            nn.init.kaiming_uniform_(w1.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
            w2 = nn.Parameter(torch.empty(Cfine, Cfine_g, ksq))
            nn.init.kaiming_uniform_(w2.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
            self.dec_refine_w1.append(w1)
            self.dec_refine_w2.append(w2)
            groups = self._largest_divisor(Cfine, 8)
            self.dec_refine_bn1.append(nn.GroupNorm(num_groups=groups, num_channels=Cfine))
            self.dec_refine_bn2.append(nn.GroupNorm(num_groups=groups, num_channels=Cfine))

        # Head
        if self.task == "global":
            self.global_head = nn.Linear(self.embed_dim, self.out_channels)
        else:
            if self.out_channels % self.G != 0:
                raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
            out_g = self.out_channels // self.G
            self.head_w = nn.Parameter(torch.empty(self.C_fine, out_g, ksq))
            nn.init.kaiming_uniform_(self.head_w.view(self.C_fine * out_g, -1), a=np.sqrt(5))
            self.head_bn = nn.GroupNorm(
                num_groups=self._largest_divisor(self.out_channels, 8),
                num_channels=self.out_channels,
            ) if self.task == "segmentation" else None

        # Parameters are created in torch's default dtype; align them with
        # the requested precision so they match the stencils.
        if self.torch_dtype != torch.float32:
            super().to(self.torch_dtype)

        pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.runtime_device = self._probe_and_set_runtime_device(pref)

    @staticmethod
    def _largest_divisor(n: int, cap: int) -> int:
        """Return the largest divisor of ``n`` that is ``<= cap`` (at least 1).

        Used to pick GroupNorm group counts and attention head counts, both
        of which must divide the channel width.
        """
        for cand in range(max(1, min(int(cap), int(n))), 0, -1):
            if n % cand == 0:
                return cand
        return 1

    def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
        """Move every tensor attribute of the stencil ``hc`` to ``device``.

        Tensors, and non-empty lists/tuples whose first element is a tensor,
        are replaced by their ``.to(device)`` copies.
        """
        for name, val in list(vars(hc).items()):
            try:
                if torch.is_tensor(val):
                    setattr(hc, name, val.to(device))
                elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
                    setattr(hc, name, type(val)([v.to(device) for v in val]))
            except Exception:
                # Deliberate best effort: an attribute that cannot be moved
                # or re-assigned is left where it is.
                pass

    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Select the device the model runs on and move everything to it.

        When ``preferred`` is a CUDA device, the module and all stencils are
        moved there and one spherical convolution is executed as a probe; any
        exception makes the method fall back to the CPU.

        Returns:
            The device that was finally selected (also stored in
            ``self._foscat_device``).
        """
        stencils = self.hconv_levels + [self.hconv_token, self.hconv_head]
        if preferred.type == "cuda":
            try:
                super().to(preferred)
                for hc in stencils:
                    self._move_hc(hc, preferred)
                npix0 = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, npix0,
                                    device=preferred, dtype=self.torch_dtype)
                hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
                y_try = hc0.Convol_torch(x_try, self.pe_w1, cell_ids=self.cell_ids_fine)
                # .item() forces the computation to complete so that a
                # failing GPU path raises inside this try block.
                _ = (y_try if torch.is_tensor(y_try)
                     else torch.as_tensor(y_try, device=preferred)).sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception:
                # GPU path unusable: fall through to the CPU set-up below.
                pass
        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in stencils:
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu

    def _as_tensor_batch(self, x):
        """Normalise an operator result to a single tensor.

        A one-element list is unwrapped (a 2-D element gets a leading batch
        axis); any non-list value is returned unchanged.

        Raises:
            ValueError: If ``x`` is a list whose length is not 1.
        """
        if isinstance(x, list):
            if len(x) == 1:
                t = x[0]
                return t.unsqueeze(0) if t.dim() == 2 else t
            raise ValueError("Variable-length list not supported here; pass a tensor.")
        return x

    def _to_numpy_ids(self, ids):
        """Return ``ids`` as a NumPy array (tensors are detached and moved to CPU)."""
        if torch.is_tensor(ids):
            return ids.detach().cpu().numpy()
        return np.asarray(ids)

    def _patch_embed_fine(self, x_t: torch.Tensor, cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        """Embed one time step at the finest level (conv-GroupNorm-GELU, twice).

        Args:
            x_t: Input of one time step, ``(B, n_chan_in, Npix)``.
            cell_ids: Cell ids matching ``x_t``; defaults to the ids given at
                construction time.

        Returns:
            The embedded features with ``C_fine`` channels.
        """
        ids = self.cell_ids_fine if cell_ids is None else cell_ids
        hc0 = self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head
        z = hc0.Convol_torch(x_t, self.pe_w1, cell_ids=ids)
        z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
        z = F.gelu(self.pe_bn1(z))
        z = hc0.Convol_torch(z, self.pe_w2, cell_ids=ids)
        z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
        z = F.gelu(self.pe_bn2(z))
        return z

    def forward(self, x: torch.Tensor, runtime_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        """Run the network on a sequence of spherical maps.

        Args:
            x: Input of shape ``(B, T_in, C_in, Npix)``.
            runtime_ids: Optional cell ids of the finest level for this call;
                defaults to the ids given at construction time.

        Returns:
            For ``task="global"`` the output of the linear head applied to
            the token mean; otherwise the output of the spherical head
            convolution, followed by GroupNorm and sigmoid/softmax for
            ``task="segmentation"``.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel count differs from
                ``n_chan_in``.
        """
        if x.dim() != 4:
            raise ValueError("Expected input shape (B, T_in, C_in, Npix)")
        _, T_in, C_in, _ = x.shape
        if C_in != self.n_chan_in:
            raise ValueError(f"Expected n_chan_in={self.n_chan_in}, got {C_in}")
        x = x.to(device=self.runtime_device, dtype=self.torch_dtype)

        # Cell ids of every level for this call, derived the same way as at
        # construction time (dummy map pushed through Down()).
        fine_ids_runtime = self.cell_ids_fine if runtime_ids is None else self._to_numpy_ids(runtime_ids)
        ids_chain = [np.asarray(fine_ids_runtime)]
        nside_tmp = self.in_nside
        _dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
        for hc in self.hconv_levels:
            _dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
            ids_chain.append(self.f.backend.to_numpy(_next))
            nside_tmp //= 2

        # Encoder histories per level
        l_hist: List[torch.Tensor] = []
        l_ids: List[np.ndarray] = []

        # The encoder uses the same finest-level ids as the decoder and the
        # head, so a call with runtime_ids is consistent end to end.
        feats_fine = []
        for t in range(T_in):
            zt = self._patch_embed_fine(x[:, t, :, :], cell_ids=ids_chain[0])
            feats_fine.append(zt.unsqueeze(1))
        feats_fine = torch.cat(feats_fine, dim=1)  # (B, T_in, C_fine, N_fine)
        l_hist.append(feats_fine)
        l_ids.append(ids_chain[0])

        current_nside = self.in_nside
        l_data_hist = feats_fine
        for i, hc in enumerate(self.hconv_levels):
            w1, w2 = self.enc_w1[i], self.enc_w2[i]
            feats_next = []
            for t in range(T_in):
                zt = l_data_hist[:, t, :, :]
                zt = hc.Convol_torch(zt, w1, cell_ids=l_ids[-1])
                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
                zt = F.gelu(self.enc_bn1[i](zt))
                zt = hc.Convol_torch(zt, w2, cell_ids=l_ids[-1])
                zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
                zt = F.gelu(self.enc_bn2[i](zt))
                feats_next.append(zt.unsqueeze(1))
            feats_next = torch.cat(feats_next, dim=1)  # (B, T_in, Cout, N_i)

            feats_down = []
            next_ids_list = None
            for t in range(T_in):
                zt, next_ids = hc.Down(feats_next[:, t, :, :], cell_ids=l_ids[-1],
                                       nside=current_nside, max_poll=True)
                zt = self._as_tensor_batch(zt)
                feats_down.append(zt.unsqueeze(1))
                next_ids_list = next_ids
            feats_down = torch.cat(feats_down, dim=1)  # (B, T_in, Cout, N_{i+1})

            l_hist.append(feats_down)
            l_ids.append(self.f.backend.to_numpy(next_ids_list))
            l_data_hist = feats_down
            current_nside //= 2

        # Temporal encoder on skips (levels 0..token_down-1): every pixel is
        # treated as an independent sequence of T_in feature vectors.
        skips: List[torch.Tensor] = []
        for i in range(self.token_down):
            Bx, Tx, Cx, Nx = l_hist[i].shape
            z = l_hist[i].permute(0, 3, 1, 2).reshape(Bx * Nx, Tx, Cx)
            z = self.temporal_encoders[i](z)
            z = z.mean(dim=1)
            H_i = z.view(Bx, Nx, Cx).permute(0, 2, 1).contiguous()
            skips.append(H_i)

        # Token-level transformer (spatial)
        x_tok_hist = l_hist[-1]          # (B, T_in, E, Ntok)
        x_tok = x_tok_hist.mean(dim=1)   # (B, E, Ntok) (could add temporal encoder here as well)
        seq = x_tok.permute(0, 2, 1) + self.pos_token[:, :x_tok.shape[2], :]
        seq = self.encoder_token(seq)
        y = seq.permute(0, 2, 1)         # (B, E, Ntok)

        if self.task == "global":
            g = seq.mean(dim=1)
            return self.global_head(g)

        # Decoder: Up + cross-attn fusion + double conv refinement
        dec_idx = 0
        for i in range(self.token_down, 0, -1):
            coarse_ids = ids_chain[i]
            fine_ids = ids_chain[i - 1]
            source_ns = self.in_nside // (2 ** i)
            fine_ns = self.in_nside // (2 ** (i - 1))

            op_fine = (self.hconv_head if fine_ns == self.in_nside
                       else self.hconv_levels[self.nsides_levels.index(fine_ns)])

            y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
            y_up = self._as_tensor_batch(
                y_up if torch.is_tensor(y_up) else torch.as_tensor(y_up, device=self.runtime_device)
            )  # (B, Ccoarse, N)

            skip_i = skips[i - 1]  # (B, Cfine, N)
            q = self.dec_q[dec_idx](y_up.permute(0, 2, 1))
            k = self.dec_k[dec_idx](skip_i.permute(0, 2, 1))
            v = self.dec_v[dec_idx](skip_i.permute(0, 2, 1))
            if self.pos_embed_per_level:
                pos = self.level_pos[dec_idx][:, :q.shape[1], :]
                q = q + pos
                k = k + pos
            z, _ = self.dec_attn[dec_idx](q, k, v)
            z = self.dec_mlp[dec_idx](z)
            z = z.permute(0, 2, 1).contiguous()  # (B, Cfine, N)

            z = op_fine.Convol_torch(z, self.dec_refine_w1[dec_idx], cell_ids=fine_ids)
            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
            z = F.gelu(self.dec_refine_bn1[dec_idx](z))
            z = op_fine.Convol_torch(z, self.dec_refine_w2[dec_idx], cell_ids=fine_ids)
            z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
            z = F.gelu(self.dec_refine_bn2[dec_idx](z))

            y = z
            dec_idx += 1

        y = self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids_runtime)
        y = self._as_tensor_batch(y if torch.is_tensor(y) else torch.as_tensor(y, device=self.runtime_device))
        if self.task == "segmentation" and self.head_bn is not None:
            y = self.head_bn(y)
        if self.final_activation == "sigmoid":
            y = torch.sigmoid(y)
        elif self.final_activation == "softmax":
            y = torch.softmax(y, dim=1)
        return y
419
+
420
+
421
if __name__ == "__main__":
    # Smoke test: build a small three-level model on a full nside=4 sphere
    # and print the shape of its output for a random input.
    in_nside = 4
    npix = 12 * in_nside**2
    cell_ids = np.arange(npix, dtype=np.int64)

    B = 2
    T_in = 3
    Cin = 4
    # The input is drawn before the model is built so that the random
    # stream is consumed in the same order as before.
    x = torch.randn(B, T_in, Cin, npix)

    model_kwargs = dict(
        in_nside=in_nside,
        n_chan_in=Cin,
        level_dims=[64, 96, 128],
        depth_token=2,
        num_heads_token=4,
        cell_ids=cell_ids,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        dropout=0.1,
    )
    model = HealpixViTSkip(**model_kwargs)
    model = model.eval()

    with torch.no_grad():
        y = model(x)
    out_shape = tuple(y.shape)
    print("Output:", out_shape)