foscat 2025.9.5__py3-none-any.whl → 2025.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +14 -9
- foscat/Plot.py +4 -3
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch-old.py +658 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/METADATA +1 -1
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/RECORD +13 -8
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/WHEEL +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/top_level.txt +0 -0
foscat/FoCUS.py
CHANGED
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
import foscat.HealSpline as HS
|
|
7
7
|
from scipy.interpolate import griddata
|
|
8
8
|
|
|
9
|
-
TMPFILE_VERSION = "
|
|
9
|
+
TMPFILE_VERSION = "V10_0"
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class FoCUS:
|
|
@@ -36,7 +36,7 @@ class FoCUS:
|
|
|
36
36
|
mpi_rank=0
|
|
37
37
|
):
|
|
38
38
|
|
|
39
|
-
self.__version__ = "2025.
|
|
39
|
+
self.__version__ = "2025.10.2"
|
|
40
40
|
# P00 coeff for normalization for scat_cov
|
|
41
41
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
42
42
|
self.P1_dic = None
|
|
@@ -1488,7 +1488,7 @@ class FoCUS:
|
|
|
1488
1488
|
if l_kernel == 5:
|
|
1489
1489
|
pw = 0.5
|
|
1490
1490
|
pw2 = 0.5
|
|
1491
|
-
threshold = 2e-
|
|
1491
|
+
threshold = 2e-5
|
|
1492
1492
|
|
|
1493
1493
|
elif l_kernel == 3:
|
|
1494
1494
|
pw = 1.0 / np.sqrt(2)
|
|
@@ -1498,7 +1498,7 @@ class FoCUS:
|
|
|
1498
1498
|
elif l_kernel == 7:
|
|
1499
1499
|
pw = 0.5
|
|
1500
1500
|
pw2 = 0.25
|
|
1501
|
-
threshold =
|
|
1501
|
+
threshold = 2e-5
|
|
1502
1502
|
|
|
1503
1503
|
import foscat.SphericalStencil as hs
|
|
1504
1504
|
import torch
|
|
@@ -1517,14 +1517,19 @@ class FoCUS:
|
|
|
1517
1517
|
n_gauges=self.NORIENT,
|
|
1518
1518
|
gauge_type='cosmo')
|
|
1519
1519
|
|
|
1520
|
-
xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ
|
|
1520
|
+
xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ,self.KERNELSZ)
|
|
1521
1521
|
|
|
1522
|
-
wwr=
|
|
1522
|
+
wwr=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
|
|
1523
1523
|
wwr-=wwr.mean()
|
|
1524
|
-
wwi=
|
|
1524
|
+
wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
|
|
1525
1525
|
wwi-=wwi.mean()
|
|
1526
|
-
|
|
1527
|
-
|
|
1526
|
+
amp=np.sum(abs(wwr+1J*wwi))
|
|
1527
|
+
|
|
1528
|
+
wwr/=amp
|
|
1529
|
+
wwi/=amp
|
|
1530
|
+
|
|
1531
|
+
wwr=hconvol.to_tensor(wwr)
|
|
1532
|
+
wwi=hconvol.to_tensor(wwi)
|
|
1528
1533
|
|
|
1529
1534
|
wavr,indice,mshape=hconvol.make_matrix(wwr)
|
|
1530
1535
|
wavi,indice,mshape=hconvol.make_matrix(wwi)
|
foscat/Plot.py
CHANGED
|
@@ -959,8 +959,9 @@ def conjugate_gradient_normal_equation(data, x0, www, all_idx,
|
|
|
959
959
|
|
|
960
960
|
rs_new = np.dot(r, r)
|
|
961
961
|
|
|
962
|
-
if verbose and i %
|
|
963
|
-
|
|
962
|
+
if verbose and i % 10 == 0:
|
|
963
|
+
v=np.mean((LP(p, www, all_idx)-data)**2)
|
|
964
|
+
print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e},{np.sqrt(v):.3e}")
|
|
964
965
|
|
|
965
966
|
if np.sqrt(rs_new) < tol:
|
|
966
967
|
if verbose:
|
|
@@ -1155,7 +1156,7 @@ def plot_wave(wave,title="spectrum",unit="Amplitude",cmap="viridis"):
|
|
|
1155
1156
|
plt.xlabel(r"$k_x$ [cycles / km]")
|
|
1156
1157
|
plt.ylabel(r"$k_y$ [cycles / km]")
|
|
1157
1158
|
plt.title(title)
|
|
1158
|
-
|
|
1159
|
+
|
|
1159
1160
|
def lonlat_edges_from_ref(shape, ref_lon, ref_lat, dlon, dlat, anchor="center"):
|
|
1160
1161
|
"""
|
|
1161
1162
|
Build lon/lat *edges* (H+1, W+1) for a regular, axis-aligned grid.
|
foscat/healpix_unet_torch.py
CHANGED
|
@@ -982,6 +982,9 @@ def fit(
|
|
|
982
982
|
n_epoch: int = 10,
|
|
983
983
|
view_epoch: int = 10,
|
|
984
984
|
batch_size: int = 16,
|
|
985
|
+
x_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
|
|
986
|
+
y_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
|
|
987
|
+
save_model: bool = False,
|
|
985
988
|
lr: float = 1e-3,
|
|
986
989
|
weight_decay: float = 0.0,
|
|
987
990
|
clip_grad_norm: Optional[float] = None,
|
|
@@ -1005,6 +1008,11 @@ def fit(
|
|
|
1005
1008
|
device = model.runtime_device if hasattr(model, "runtime_device") else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
|
1006
1009
|
model.to(device)
|
|
1007
1010
|
|
|
1011
|
+
if save_model:
|
|
1012
|
+
assert x_valid is None, "If save_mode=True x_valid should not be None"
|
|
1013
|
+
assert y_valid is None, "If save_mode=True y_valid should not be None"
|
|
1014
|
+
best_valid=1E30
|
|
1015
|
+
|
|
1008
1016
|
# Detect variable-length mode
|
|
1009
1017
|
varlen_mode = isinstance(x_train, (list, tuple))
|
|
1010
1018
|
|
|
@@ -1197,6 +1205,14 @@ def fit(
|
|
|
1197
1205
|
history.append(epoch_loss)
|
|
1198
1206
|
# print every view_epoch logical step
|
|
1199
1207
|
if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
|
|
1200
|
-
|
|
1208
|
+
if x_valid is not None:
|
|
1209
|
+
preds=model.predict(model.to_tensor(x_valid)).cpu().numpy()
|
|
1210
|
+
valid_loss=np.mean((preds-y_valid)**2)
|
|
1211
|
+
if save_model:
|
|
1212
|
+
if best_valid>valid_loss:
|
|
1213
|
+
torch.save({"model": self.state_dict(), "cfg": CFG}, os.path.join(CFG["save_dir"], "best.pt"))
|
|
1214
|
+
print(f"[epoch {len(history)}] loss={epoch_loss:.4f} loss_valid={valid_loss:.4f}")
|
|
1215
|
+
else:
|
|
1216
|
+
print(f"[epoch {len(history)}] loss={epoch_loss:.4f}")
|
|
1201
1217
|
|
|
1202
1218
|
return {"loss": history}
|
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
# healpix_vit_skip.py
|
|
2
|
+
# HEALPix ViT U-Net with temporal encoders and Transformer-based skip fusion.
|
|
3
|
+
# - Multi-level HEALPix pyramid built with foscat.SphericalStencil
|
|
4
|
+
# - Per-level temporal encoding (sequence over T_in months) at encoder
|
|
5
|
+
# - Decoder uses cross-attention to fuse upsampled features with encoder skips
|
|
6
|
+
# - Double spherical convolution + GroupNorm + GELU at each encoder/decoder level
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
from typing import List, Optional, Literal
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
import foscat.scat_cov as sc
|
|
17
|
+
import foscat.SphericalStencil as ho
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MLP(nn.Module):
    """Pre-norm feed-forward block applied on the last axis.

    Layers, in order: LayerNorm(d) -> Linear(d, hidden_mult*d) -> GELU ->
    Dropout(drop) -> Linear(hidden_mult*d, d) -> Dropout(drop).
    Input and output both have last dimension ``d``.
    """

    def __init__(self, d: int, hidden_mult: int = 4, drop: float = 0.0):
        super().__init__()
        hidden = hidden_mult * d
        layers = [
            nn.LayerNorm(d),
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, d),
            nn.Dropout(drop),
        ]
        # A single nn.Sequential named ``net`` keeps state_dict keys as ``net.<idx>.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to ``x`` (last dimension must be ``d``)."""
        out = self.net(x)
        return out
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class HealpixViTSkip(nn.Module):
    """HEALPix U-Net with temporal skip encoders and a token-level Transformer.

    Pipeline (see ``forward``):

    1. For every time step, a double spherical convolution (GroupNorm + GELU after
       each conv) embeds the ``n_chan_in`` input channels into ``level_dims[0]``
       channels at ``in_nside``.
    2. Each of the ``len(level_dims) - 1`` encoder levels applies a double conv
       ``level_dims[i] -> level_dims[i+1]`` followed by ``Down`` (``max_poll=True``),
       halving the nside.
    3. The time history of every pre-token level goes through a 2-layer
       ``nn.TransformerEncoder`` over the time axis. It is then averaged over time
       and kept as a skip.
    4. The coarsest level is averaged over time, gets a learned positional
       embedding, and is processed by ``depth_token`` Transformer layers.
    5. The decoder repeats: ``Up``; cross-attention (query = upsampled features,
       key/value = skip); MLP; double-conv refinement.
    6. A spherical-conv head produces ``out_channels`` maps (sigmoid/softmax for
       segmentation). With ``task="global"``, a linear head is applied to the token
       sequence averaged over tokens instead, and the decoder is skipped.
    """

    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        level_dims: List[int],
        depth_token: int,
        num_heads_token: int,
        cell_ids: np.ndarray,
        task: Literal["regression", "segmentation", "global"] = "regression",
        out_channels: int = 1,
        mlp_ratio_token: float = 4.0,
        KERNELSZ: int = 3,
        gauge_type: Literal["cosmo", "phi"] = "cosmo",
        G: int = 1,
        prefer_foscat_gpu: bool = True,
        dropout: float = 0.1,
        dtype: Literal["float32", "float64"] = "float32",
        pos_embed_per_level: bool = True,
    ) -> None:
        """Build the stencils, convolution weights and Transformer modules.

        Args:
            in_nside: nside of the finest (input) HEALPix grid.
            n_chan_in: number of input channels per time step.
            level_dims: channel count per level. ``level_dims[0]`` is the finest
                level and ``level_dims[-1]`` is the token embedding size. There are
                ``len(level_dims) - 1`` Down steps.
            depth_token: number of layers of the token-level Transformer.
            num_heads_token: attention heads of the token Transformer; must divide
                ``level_dims[-1]``.
            cell_ids: cell ids of the finest level (required).
            task: "regression", "segmentation" (adds head GroupNorm and
                sigmoid/softmax), or "global" (linear head on the mean token).
            out_channels: number of output channels.
            mlp_ratio_token: feed-forward width multiplier of the token Transformer.
            KERNELSZ: stencil kernel size; weights have ``KERNELSZ**2`` taps.
            gauge_type: forwarded to ``SphericalStencil``.
            G: number of gauges (``n_gauges``); channel counts must be divisible by it.
            prefer_foscat_gpu: try CUDA (when available) before falling back to CPU.
            dropout: dropout of all Transformer / MLP blocks.
            dtype: "float32" or "float64" for parameters, stencils and inputs.
            pos_embed_per_level: add learned positional embeddings in the decoder.

        Raises:
            ValueError: empty ``level_dims``; a level dim or ``out_channels`` not
                divisible by ``G``; ``level_dims[-1]`` not divisible by
                ``num_heads_token``; ``cell_ids`` is None.
        """
        super().__init__()

        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.level_dims = list(level_dims)
        if not self.level_dims:
            raise ValueError("level_dims must contain at least one channel count.")
        self.token_down = len(self.level_dims) - 1  # number of Down (pooling) steps
        self.C_fine = int(self.level_dims[0])
        self.embed_dim = int(self.level_dims[-1])
        self.depth_token = int(depth_token)
        self.num_heads_token = int(num_heads_token)
        self.mlp_ratio_token = float(mlp_ratio_token)
        self.task = task
        self.out_channels = int(out_channels)
        self.KERNELSZ = int(KERNELSZ)
        self.gauge_type = gauge_type
        self.G = int(G)
        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
        self.dropout = float(dropout)
        self.dtype = dtype
        self.pos_embed_per_level = bool(pos_embed_per_level)

        for d in self.level_dims:
            if d % self.G != 0:
                raise ValueError(f"All level_dims must be divisible by G={self.G}, got {d}.")
        if self.embed_dim % self.num_heads_token != 0:
            raise ValueError("embed_dim must be divisible by num_heads_token.")

        if dtype == "float32":
            self.np_dtype = np.float32
            self.torch_dtype = torch.float32
        else:
            self.np_dtype = np.float64
            # Previously float32 here, which silently ignored dtype="float64".
            self.torch_dtype = torch.float64

        if cell_ids is None:
            raise ValueError("cell_ids (finest) must be provided.")
        self.cell_ids_fine = np.asarray(cell_ids)

        if self.task == "segmentation":
            self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
        else:
            self.final_activation = "none"

        self.f = sc.funct(KERNELSZ=self.KERNELSZ)

        # Build one stencil per encoder level and record the cell ids of every level,
        # obtained by pooling a dummy map.
        self.hconv_levels: List["ho.SphericalStencil"] = []
        self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
        current_nside = self.in_nside
        dummy = self.f.backend.bk_cast(
            np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype)
        )
        for _ in range(self.token_down):
            hc = ho.SphericalStencil(
                current_nside,
                self.KERNELSZ,
                n_gauges=self.G,
                gauge_type=self.gauge_type,
                cell_ids=self.level_cell_ids[-1],
                dtype=self.torch_dtype,
            )
            self.hconv_levels.append(hc)
            dummy, next_ids = hc.Down(
                dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True
            )
            self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
            current_nside //= 2

        # current_nside equals in_nside when token_down == 0.
        self.token_nside = current_nside
        self.token_cell_ids = self.level_cell_ids[-1]

        self.hconv_token = ho.SphericalStencil(
            self.token_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.token_cell_ids,
            dtype=self.torch_dtype,
        )
        self.hconv_head = ho.SphericalStencil(
            self.in_nside,
            self.KERNELSZ,
            n_gauges=self.G,
            gauge_type=self.gauge_type,
            cell_ids=self.cell_ids_fine,
            dtype=self.torch_dtype,
        )

        self.nsides_levels = [self.in_nside // (2 ** i) for i in range(self.token_down + 1)]
        # Full-sphere pixel count per level (used to size positional embeddings).
        self.ntokens_levels = [12 * n * n for n in self.nsides_levels]

        # Patch embedding (double conv at the finest level).
        fine_g = self.C_fine // self.G
        self.pe_w1 = self._make_conv_weight(self.n_chan_in, fine_g)
        self.pe_w2 = self._make_conv_weight(self.C_fine, fine_g)
        self.pe_bn1 = nn.GroupNorm(num_groups=self._num_groups(self.C_fine), num_channels=self.C_fine)
        self.pe_bn2 = nn.GroupNorm(num_groups=self._num_groups(self.C_fine), num_channels=self.C_fine)

        # Encoder double convs: level_dims[i] -> level_dims[i+1].
        self.enc_w1 = nn.ParameterList()
        self.enc_w2 = nn.ParameterList()
        self.enc_bn1 = nn.ModuleList()
        self.enc_bn2 = nn.ModuleList()
        for i in range(self.token_down):
            c_in = self.level_dims[i]
            c_out = self.level_dims[i + 1]
            c_out_g = c_out // self.G
            self.enc_w1.append(self._make_conv_weight(c_in, c_out_g))
            self.enc_w2.append(self._make_conv_weight(c_out, c_out_g))
            self.enc_bn1.append(nn.GroupNorm(num_groups=self._num_groups(c_out), num_channels=c_out))
            self.enc_bn2.append(nn.GroupNorm(num_groups=self._num_groups(c_out), num_channels=c_out))

        # Temporal encoders for the skip levels 0 .. token_down-1.
        self.temporal_encoders = nn.ModuleList([
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=self.level_dims[i],
                    nhead=self._num_heads(self.level_dims[i]),
                    dim_feedforward=2 * self.level_dims[i],
                    dropout=self.dropout,
                    activation="gelu",
                    batch_first=True,
                    norm_first=True,
                ),
                num_layers=2,
            )
            for i in range(self.token_down)
        ])

        # Token-level Transformer.
        self.n_tokens = int(self.token_cell_ids.shape[0])
        self.pos_token = nn.Parameter(torch.zeros(1, self.n_tokens, self.embed_dim))
        nn.init.trunc_normal_(self.pos_token, std=0.02)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=self.num_heads_token,
            dim_feedforward=int(self.embed_dim * self.mlp_ratio_token),
            dropout=self.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder_token = nn.TransformerEncoder(enc_layer, num_layers=self.depth_token)

        # Decoder fusion modules, ordered from coarsest to finest level.
        self.dec_q = nn.ModuleList()
        self.dec_k = nn.ModuleList()
        self.dec_v = nn.ModuleList()
        self.dec_attn = nn.ModuleList()
        self.dec_mlp = nn.ModuleList()
        self.level_pos = nn.ParameterList() if self.pos_embed_per_level else None
        for i in range(self.token_down, 0, -1):
            # The features being upsampled have level_dims[i] channels. The query
            # projection must accept that width (it was level_dims[i-1], which
            # crashed whenever adjacent level dims differ).
            c_coarse = self.level_dims[i]
            d_fuse = self.level_dims[i - 1]
            self.dec_q.append(nn.Linear(c_coarse, d_fuse))
            self.dec_k.append(nn.Linear(d_fuse, d_fuse))
            self.dec_v.append(nn.Linear(d_fuse, d_fuse))
            self.dec_attn.append(
                nn.MultiheadAttention(
                    embed_dim=d_fuse, num_heads=self._num_heads(d_fuse), batch_first=True
                )
            )
            self.dec_mlp.append(MLP(d_fuse, hidden_mult=4, drop=self.dropout))
            if self.pos_embed_per_level:
                p = nn.Parameter(torch.zeros(1, self.ntokens_levels[i - 1], d_fuse))
                nn.init.trunc_normal_(p, std=0.02)
                self.level_pos.append(p)

        # Decoder refinement double convs (level_dims[i-1] -> level_dims[i-1]).
        self.dec_refine_w1 = nn.ParameterList()
        self.dec_refine_w2 = nn.ParameterList()
        self.dec_refine_bn1 = nn.ModuleList()
        self.dec_refine_bn2 = nn.ModuleList()
        for i in range(self.token_down, 0, -1):
            c_fine = self.level_dims[i - 1]
            c_fine_g = c_fine // self.G
            self.dec_refine_w1.append(self._make_conv_weight(c_fine, c_fine_g))
            self.dec_refine_w2.append(self._make_conv_weight(c_fine, c_fine_g))
            self.dec_refine_bn1.append(nn.GroupNorm(num_groups=self._num_groups(c_fine), num_channels=c_fine))
            self.dec_refine_bn2.append(nn.GroupNorm(num_groups=self._num_groups(c_fine), num_channels=c_fine))

        # Output head.
        if self.task == "global":
            self.global_head = nn.Linear(self.embed_dim, self.out_channels)
        else:
            if self.out_channels % self.G != 0:
                raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
            out_g = self.out_channels // self.G
            self.head_w = self._make_conv_weight(self.C_fine, out_g)
            self.head_bn = (
                nn.GroupNorm(num_groups=self._num_groups(self.out_channels), num_channels=self.out_channels)
                if self.task == "segmentation"
                else None
            )

        # Parameters are created as float32; convert when float64 was requested.
        if self.torch_dtype != torch.float32:
            self.to(self.torch_dtype)

        use_cuda = self.prefer_foscat_gpu and torch.cuda.is_available()
        pref = torch.device("cuda" if use_cuda else "cpu")
        self.runtime_device = self._probe_and_set_runtime_device(pref)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _num_groups(channels: int, max_groups: int = 8) -> int:
        """Return the largest group count <= ``max_groups`` that divides ``channels``.

        GroupNorm requires ``num_channels % num_groups == 0``. This matches the
        old ``min(8, C)`` whenever that value was valid, and otherwise picks a
        valid smaller divisor instead of crashing.
        """
        channels = int(channels)
        for g in range(min(max_groups, max(channels, 1)), 0, -1):
            if channels % g == 0:
                return g
        return 1

    @staticmethod
    def _num_heads(d_model: int, max_heads: int = 8, width_per_head: int = 64) -> int:
        """Return a head count for ``d_model`` that divides it.

        Starts from ``max(1, min(max_heads, d_model // width_per_head))`` (the old
        formula) and steps down until it divides ``d_model``. Attention modules
        require that divisibility.
        """
        d_model = int(d_model)
        cap = max(1, min(max_heads, d_model // width_per_head))
        for h in range(cap, 0, -1):
            if d_model % h == 0:
                return h
        return 1

    def _make_conv_weight(self, c_in: int, c_out_per_gauge: int) -> nn.Parameter:
        """Create a Kaiming-uniform initialised stencil weight ``(c_in, c_out_per_gauge, K*K)``."""
        w = nn.Parameter(torch.empty(c_in, c_out_per_gauge, self.KERNELSZ * self.KERNELSZ))
        nn.init.kaiming_uniform_(w.view(c_in * c_out_per_gauge, -1), a=np.sqrt(5))
        return w

    def _all_stencils(self) -> List["ho.SphericalStencil"]:
        """Return every stencil owned by the model (encoder levels, token, head)."""
        return self.hconv_levels + [self.hconv_token, self.hconv_head]

    def _fine_stencil(self) -> "ho.SphericalStencil":
        """Stencil used at the finest level: the first encoder stencil, else the head one."""
        return self.hconv_levels[0] if len(self.hconv_levels) > 0 else self.hconv_head

    def _move_hc(self, hc: "ho.SphericalStencil", device: torch.device) -> None:
        """Best-effort move of ``hc``'s tensor attributes (or lists/tuples of tensors) to ``device``.

        Attributes that cannot be moved are deliberately left unchanged.
        """
        for name, val in list(vars(hc).items()):
            try:
                if torch.is_tensor(val):
                    setattr(hc, name, val.to(device))
                elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
                    setattr(hc, name, type(val)([v.to(device) for v in val]))
            except Exception:
                pass

    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Try to run on ``preferred`` (CUDA); otherwise fall back to CPU.

        Moves the parameters and all stencils to the chosen device, stores it in
        ``self._foscat_device`` and returns it. A failing test convolution on CUDA
        triggers a ``RuntimeWarning`` and the CPU fallback.
        """
        if preferred.type == "cuda":
            try:
                super().to(preferred)
                for hc in self._all_stencils():
                    self._move_hc(hc, preferred)
                n_pix = int(self.cell_ids_fine.shape[0])
                x_try = torch.zeros(1, self.n_chan_in, n_pix, device=preferred, dtype=self.torch_dtype)
                y_try = self._fine_stencil().Convol_torch(x_try, self.pe_w1, cell_ids=self.cell_ids_fine)
                if not torch.is_tensor(y_try):
                    y_try = torch.as_tensor(y_try, device=preferred)
                # .item() synchronises, so asynchronous CUDA errors surface inside this try.
                y_try.sum().item()
                self._foscat_device = preferred
                return preferred
            except Exception as exc:  # deliberate best-effort: fall back to CPU below
                import warnings

                warnings.warn(
                    f"HealpixViTSkip: running on {preferred} failed ({exc!r}); falling back to CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
        cpu = torch.device("cpu")
        super().to(cpu)
        for hc in self._all_stencils():
            self._move_hc(hc, cpu)
        self._foscat_device = cpu
        return cpu

    def _as_tensor_batch(self, x):
        """Unwrap a single-element list (adding a batch axis to 2-D tensors); pass others through.

        Raises:
            ValueError: for lists with more than one element (variable length).
        """
        if isinstance(x, list):
            if len(x) == 1:
                t = x[0]
                return t.unsqueeze(0) if t.dim() == 2 else t
            raise ValueError("Variable-length list not supported here; pass a tensor.")
        return x

    def _to_numpy_ids(self, ids):
        """Return ``ids`` as a NumPy array (detaching and moving tensors to CPU)."""
        if torch.is_tensor(ids):
            return ids.detach().cpu().numpy()
        return np.asarray(ids)

    def _to_tensor(self, z) -> torch.Tensor:
        """Normalise a stencil output: single-element lists are unwrapped, non-tensors
        are converted on ``runtime_device``, tensors are returned as-is."""
        if isinstance(z, list):
            return self._as_tensor_batch(z)
        if not torch.is_tensor(z):
            z = torch.as_tensor(z, device=self.runtime_device)
        return z

    def _conv_norm_act(self, hc, z, weight, norm, cell_ids) -> torch.Tensor:
        """Spherical convolution of ``z`` on ``cell_ids``, then ``norm`` and GELU."""
        z = self._to_tensor(hc.Convol_torch(z, weight, cell_ids=cell_ids))
        return F.gelu(norm(z))

    def _patch_embed_fine(self, x_t: torch.Tensor, cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        """Double conv embedding of one time step at the finest level.

        ``cell_ids`` defaults to the constructor's finest cell ids.
        """
        ids = self.cell_ids_fine if cell_ids is None else cell_ids
        hc0 = self._fine_stencil()
        z = self._conv_norm_act(hc0, x_t, self.pe_w1, self.pe_bn1, ids)
        return self._conv_norm_act(hc0, z, self.pe_w2, self.pe_bn2, ids)

    def _ids_chain(self, fine_ids) -> List[np.ndarray]:
        """Cell ids of every level (finest first), obtained by pooling a dummy map."""
        chain = [np.asarray(fine_ids)]
        nside = self.in_nside
        dummy = self.f.backend.bk_cast(np.zeros((1, 1, chain[0].shape[0]), dtype=self.np_dtype))
        for hc in self.hconv_levels:
            dummy, nxt = hc.Down(dummy, cell_ids=chain[-1], nside=nside, max_poll=True)
            chain.append(self.f.backend.to_numpy(nxt))
            nside //= 2
        return chain

    def _encode(self, x: torch.Tensor, ids_chain: List[np.ndarray]) -> List[torch.Tensor]:
        """Run the encoder per time step and return the history of every level.

        Element ``i`` of the result is the level-``i`` tensor stacked over time on
        axis 1, i.e. ``(B, T_in, level_dims[i], N_i)`` when the stencil ops return
        ``(B, C, N)`` maps (assumed).
        """
        n_steps = x.shape[1]
        feats = torch.stack(
            [self._patch_embed_fine(x[:, t], cell_ids=ids_chain[0]) for t in range(n_steps)],
            dim=1,
        )
        history = [feats]
        nside = self.in_nside
        for i, hc in enumerate(self.hconv_levels):
            ids = ids_chain[i]
            pooled = []
            for t in range(n_steps):
                zt = self._conv_norm_act(hc, feats[:, t], self.enc_w1[i], self.enc_bn1[i], ids)
                zt = self._conv_norm_act(hc, zt, self.enc_w2[i], self.enc_bn2[i], ids)
                zt, _ = hc.Down(zt, cell_ids=ids, nside=nside, max_poll=True)
                pooled.append(self._to_tensor(zt))
            feats = torch.stack(pooled, dim=1)
            history.append(feats)
            nside //= 2
        return history

    def _temporal_skips(self, history: List[torch.Tensor]) -> List[torch.Tensor]:
        """Summarise each pre-token level over time (Transformer, then mean) into ``(B, C, N)``."""
        skips = []
        for i in range(self.token_down):
            b, t, c, n = history[i].shape
            # Every pixel becomes an independent sequence of length T over channels.
            z = history[i].permute(0, 3, 1, 2).reshape(b * n, t, c)
            z = self.temporal_encoders[i](z).mean(dim=1)
            skips.append(z.view(b, n, c).permute(0, 2, 1).contiguous())
        return skips

    def _decode(self, y: torch.Tensor, skips: List[torch.Tensor], ids_chain: List[np.ndarray]) -> torch.Tensor:
        """Upsample, fuse with skips via cross-attention and refine, coarsest to finest."""
        for dec_idx, i in enumerate(range(self.token_down, 0, -1)):
            coarse_ids = ids_chain[i]
            fine_ids = ids_chain[i - 1]
            source_ns = self.nsides_levels[i]
            # Stencil at the target (finer) level: head stencil at in_nside,
            # otherwise the encoder stencil of that level.
            op_fine = self.hconv_head if i - 1 == 0 else self.hconv_levels[i - 1]

            y_up = self._to_tensor(
                op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
            )  # channels: level_dims[i]

            skip = skips[i - 1].permute(0, 2, 1)  # (B, N, level_dims[i-1])
            q = self.dec_q[dec_idx](y_up.permute(0, 2, 1))
            k = self.dec_k[dec_idx](skip)
            v = self.dec_v[dec_idx](skip)
            if self.pos_embed_per_level:
                # NOTE(review): embeddings are sliced by position, not indexed by
                # cell id, so partial-sky inputs use the first N entries.
                pos = self.level_pos[dec_idx][:, : q.shape[1], :]
                q = q + pos
                k = k + pos
            # The output is a weighted sum of projected skip features; the upsampled
            # features only enter through the query.
            z, _ = self.dec_attn[dec_idx](q, k, v)
            z = self.dec_mlp[dec_idx](z)
            z = z.permute(0, 2, 1).contiguous()

            z = self._conv_norm_act(op_fine, z, self.dec_refine_w1[dec_idx], self.dec_refine_bn1[dec_idx], fine_ids)
            y = self._conv_norm_act(op_fine, z, self.dec_refine_w2[dec_idx], self.dec_refine_bn2[dec_idx], fine_ids)
        return y

    def _apply_head(self, y: torch.Tensor, fine_ids) -> torch.Tensor:
        """Final spherical conv, optional GroupNorm (segmentation) and output activation."""
        y = self._to_tensor(self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids))
        if self.task == "segmentation" and self.head_bn is not None:
            y = self.head_bn(y)
        if self.final_activation == "sigmoid":
            y = torch.sigmoid(y)
        elif self.final_activation == "softmax":
            y = torch.softmax(y, dim=1)
        return y

    # ------------------------------------------------------------------ forward

    def forward(self, x: torch.Tensor, runtime_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        """Run the network.

        Args:
            x: input of shape ``(B, T_in, C_in, Npix)`` with ``C_in == n_chan_in``.
                It is moved to ``runtime_device`` and cast to the model dtype.
            runtime_ids: optional finest-level cell ids (array or tensor). When given,
                every stage (encoder, decoder, head) uses them instead of the
                constructor's ``cell_ids``.

        Returns:
            For ``task="global"``, ``global_head`` applied to the mean token,
            i.e. ``(B, out_channels)``. Otherwise the head output on the finest
            grid, presumably ``(B, out_channels, Npix)``.

        Raises:
            ValueError: if ``x`` is not 4-D or its channel count is wrong.
        """
        if x.dim() != 4:
            raise ValueError("Expected input shape (B, T_in, C_in, Npix)")
        C_in = x.shape[2]
        if C_in != self.n_chan_in:
            raise ValueError(f"Expected n_chan_in={self.n_chan_in}, got {C_in}")
        x = x.to(device=self.runtime_device, dtype=self.torch_dtype)

        fine_ids = self.cell_ids_fine if runtime_ids is None else self._to_numpy_ids(runtime_ids)
        # One ids chain for encoder, decoder and head keeps all stages on the same grid.
        ids_chain = self._ids_chain(fine_ids)

        history = self._encode(x, ids_chain)
        skips = self._temporal_skips(history)

        # Token stage: average the coarsest level over time, add positions, attend.
        x_tok = history[-1].mean(dim=1)  # (B, E, Ntok)
        seq = x_tok.permute(0, 2, 1) + self.pos_token[:, : x_tok.shape[2], :]
        seq = self.encoder_token(seq)

        if self.task == "global":
            return self.global_head(seq.mean(dim=1))

        y = self._decode(seq.permute(0, 2, 1), skips, ids_chain)
        return self._apply_head(y, ids_chain[0])
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
if __name__ == "__main__":
    # Smoke test on a full-sky nside=4 grid with random input (B, T_in, C_in, Npix).
    nside_demo = 4
    n_pix = 12 * nside_demo ** 2
    ids_demo = np.arange(n_pix, dtype=np.int64)

    batch, n_steps, n_chan = 2, 3, 4
    x_demo = torch.randn(batch, n_steps, n_chan, n_pix)

    demo_cfg = dict(
        in_nside=nside_demo,
        n_chan_in=n_chan,
        level_dims=[64, 96, 128],
        depth_token=2,
        num_heads_token=4,
        cell_ids=ids_demo,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        dropout=0.1,
    )
    demo_model = HealpixViTSkip(**demo_cfg).eval()

    with torch.no_grad():
        demo_out = demo_model(x_demo)
    print("Output:", tuple(demo_out.shape))
|