foscat 2025.9.4__py3-none-any.whl → 2025.10.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +14 -9
- foscat/Plot.py +59 -17
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch-old.py +658 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/METADATA +1 -1
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/RECORD +13 -8
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/WHEEL +0 -0
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.4.dist-info → foscat-2025.10.2.dist-info}/top_level.txt +0 -0
foscat/FoCUS.py
CHANGED
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
import foscat.HealSpline as HS
|
|
7
7
|
from scipy.interpolate import griddata
|
|
8
8
|
|
|
9
|
-
TMPFILE_VERSION = "
|
|
9
|
+
TMPFILE_VERSION = "V10_0"
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class FoCUS:
|
|
@@ -36,7 +36,7 @@ class FoCUS:
|
|
|
36
36
|
mpi_rank=0
|
|
37
37
|
):
|
|
38
38
|
|
|
39
|
-
self.__version__ = "2025.
|
|
39
|
+
self.__version__ = "2025.10.2"
|
|
40
40
|
# P00 coeff for normalization for scat_cov
|
|
41
41
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
42
42
|
self.P1_dic = None
|
|
@@ -1488,7 +1488,7 @@ class FoCUS:
|
|
|
1488
1488
|
if l_kernel == 5:
|
|
1489
1489
|
pw = 0.5
|
|
1490
1490
|
pw2 = 0.5
|
|
1491
|
-
threshold = 2e-
|
|
1491
|
+
threshold = 2e-5
|
|
1492
1492
|
|
|
1493
1493
|
elif l_kernel == 3:
|
|
1494
1494
|
pw = 1.0 / np.sqrt(2)
|
|
@@ -1498,7 +1498,7 @@ class FoCUS:
|
|
|
1498
1498
|
elif l_kernel == 7:
|
|
1499
1499
|
pw = 0.5
|
|
1500
1500
|
pw2 = 0.25
|
|
1501
|
-
threshold =
|
|
1501
|
+
threshold = 2e-5
|
|
1502
1502
|
|
|
1503
1503
|
import foscat.SphericalStencil as hs
|
|
1504
1504
|
import torch
|
|
@@ -1517,14 +1517,19 @@ class FoCUS:
|
|
|
1517
1517
|
n_gauges=self.NORIENT,
|
|
1518
1518
|
gauge_type='cosmo')
|
|
1519
1519
|
|
|
1520
|
-
xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ
|
|
1520
|
+
xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ,self.KERNELSZ)
|
|
1521
1521
|
|
|
1522
|
-
wwr=
|
|
1522
|
+
wwr=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
|
|
1523
1523
|
wwr-=wwr.mean()
|
|
1524
|
-
wwi=
|
|
1524
|
+
wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
|
|
1525
1525
|
wwi-=wwi.mean()
|
|
1526
|
-
|
|
1527
|
-
|
|
1526
|
+
amp=np.sum(abs(wwr+1J*wwi))
|
|
1527
|
+
|
|
1528
|
+
wwr/=amp
|
|
1529
|
+
wwi/=amp
|
|
1530
|
+
|
|
1531
|
+
wwr=hconvol.to_tensor(wwr)
|
|
1532
|
+
wwi=hconvol.to_tensor(wwi)
|
|
1528
1533
|
|
|
1529
1534
|
wavr,indice,mshape=hconvol.make_matrix(wwr)
|
|
1530
1535
|
wavi,indice,mshape=hconvol.make_matrix(wwi)
|
foscat/Plot.py
CHANGED
|
@@ -959,8 +959,9 @@ def conjugate_gradient_normal_equation(data, x0, www, all_idx,
|
|
|
959
959
|
|
|
960
960
|
rs_new = np.dot(r, r)
|
|
961
961
|
|
|
962
|
-
if verbose and i %
|
|
963
|
-
|
|
962
|
+
if verbose and i % 10 == 0:
|
|
963
|
+
v=np.mean((LP(p, www, all_idx)-data)**2)
|
|
964
|
+
print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e},{np.sqrt(v):.3e}")
|
|
964
965
|
|
|
965
966
|
if np.sqrt(rs_new) < tol:
|
|
966
967
|
if verbose:
|
|
@@ -1065,28 +1066,69 @@ def _spectrum_polar_to_cartesian_core(
|
|
|
1065
1066
|
valid = np.isfinite(radial_index)
|
|
1066
1067
|
try:
|
|
1067
1068
|
from scipy.ndimage import map_coordinates
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
w
|
|
1075
|
-
|
|
1076
|
-
|
|
1069
|
+
|
|
1070
|
+
if method.lower() == "bicubic":
|
|
1071
|
+
# ===== Bicubic with circular angular wrap =====
|
|
1072
|
+
# Pad the angular axis by K columns on both sides so that the cubic kernel
|
|
1073
|
+
# has valid neighbors across the 0°/360° seam. K=2 is enough for a cubic kernel.
|
|
1074
|
+
K = 2
|
|
1075
|
+
# w has shape (Nscale, Norient)
|
|
1076
|
+
w_pad = np.concatenate([w[:, -K:], w, w[:, :K]], axis=1) # (ns, no+2K)
|
|
1077
|
+
|
|
1078
|
+
# Build coordinates for map_coordinates on the padded array:
|
|
1079
|
+
# - radial index stays the same, clipped to [0, ns-1] (no wrap)
|
|
1080
|
+
# - angular index is shifted by +K and wrapped into [0, no) before querying
|
|
1081
|
+
order = 3
|
|
1082
|
+
coords = np.vstack([radial_index.ravel(), angular_index.ravel()])
|
|
1083
|
+
|
|
1084
|
+
# clip the radial coordinate to the valid [eps, ns-1-eps] band
|
|
1085
|
+
eps = 1e-6
|
|
1086
|
+
coords[0, :] = np.where(
|
|
1087
|
+
np.isfinite(coords[0, :]),
|
|
1088
|
+
np.clip(coords[0, :], 0.0 + eps, (ns - 1) - eps),
|
|
1089
|
+
0.0,
|
|
1090
|
+
)
|
|
1091
|
+
|
|
1092
|
+
# wrap angular coordinate to [0, no) then shift by +K to address the padded array
|
|
1093
|
+
ang = np.mod(coords[1, :], float(no)) + K # [K, no+K)
|
|
1094
|
+
coords[1, :] = ang
|
|
1095
|
+
|
|
1096
|
+
# Now sample the padded array; 'nearest' mode is fine thanks to explicit padding
|
|
1097
|
+
sampled = map_coordinates(
|
|
1098
|
+
w_pad, coords, order=order, mode="nearest", cval=fill_value, prefilter=True
|
|
1099
|
+
).reshape(n_pixels, n_pixels)
|
|
1100
|
+
|
|
1101
|
+
img = np.where(valid, sampled, fill_value)
|
|
1102
|
+
|
|
1103
|
+
else:
|
|
1104
|
+
# ===== Bilinear (or other) without special padding =====
|
|
1105
|
+
order = 1
|
|
1106
|
+
coords = np.vstack([radial_index.ravel(), angular_index.ravel()])
|
|
1107
|
+
eps = 1e-6
|
|
1108
|
+
coords[0, :] = np.where(
|
|
1109
|
+
np.isfinite(coords[0, :]),
|
|
1110
|
+
np.clip(coords[0, :], 0.0 + eps, (ns - 1) - eps),
|
|
1111
|
+
0.0,
|
|
1112
|
+
)
|
|
1113
|
+
# For non-bicubic, SciPy's mode="wrap" is sufficient on the angular axis
|
|
1114
|
+
sampled = map_coordinates(
|
|
1115
|
+
w, coords, order=order, mode="wrap", cval=fill_value, prefilter=True
|
|
1116
|
+
).reshape(n_pixels, n_pixels)
|
|
1117
|
+
img = np.where(valid, sampled, fill_value)
|
|
1118
|
+
|
|
1077
1119
|
except Exception:
|
|
1078
|
-
# bilinear fallback
|
|
1120
|
+
# ---- Vectorized bilinear fallback with explicit angular wrap ----
|
|
1079
1121
|
r_idx = np.floor(radial_index).astype(np.int64)
|
|
1080
1122
|
t_idx = np.floor(angular_index).astype(np.int64)
|
|
1081
|
-
r_idx = np.clip(r_idx, 0, ns-2)
|
|
1123
|
+
r_idx = np.clip(r_idx, 0, ns - 2)
|
|
1082
1124
|
t0 = np.mod(t_idx, no)
|
|
1083
|
-
t1 = np.mod(t_idx+1, no)
|
|
1125
|
+
t1 = np.mod(t_idx + 1, no)
|
|
1084
1126
|
tr = np.clip(radial_index - r_idx, 0.0, 1.0)
|
|
1085
1127
|
ta = np.clip(angular_index - t_idx, 0.0, 1.0)
|
|
1086
1128
|
f00 = w[r_idx, t0]
|
|
1087
1129
|
f01 = w[r_idx, t1]
|
|
1088
|
-
f10 = w[r_idx+1,
|
|
1089
|
-
f11 = w[r_idx+1,
|
|
1130
|
+
f10 = w[r_idx + 1, t0]
|
|
1131
|
+
f11 = w[r_idx + 1, t1]
|
|
1090
1132
|
g0 = (1.0 - ta) * f00 + ta * f01
|
|
1091
1133
|
g1 = (1.0 - ta) * f10 + ta * f11
|
|
1092
1134
|
img = (1.0 - tr) * g0 + tr * g1
|
|
@@ -1114,7 +1156,7 @@ def plot_wave(wave,title="spectrum",unit="Amplitude",cmap="viridis"):
|
|
|
1114
1156
|
plt.xlabel(r"$k_x$ [cycles / km]")
|
|
1115
1157
|
plt.ylabel(r"$k_y$ [cycles / km]")
|
|
1116
1158
|
plt.title(title)
|
|
1117
|
-
|
|
1159
|
+
|
|
1118
1160
|
def lonlat_edges_from_ref(shape, ref_lon, ref_lat, dlon, dlat, anchor="center"):
|
|
1119
1161
|
"""
|
|
1120
1162
|
Build lon/lat *edges* (H+1, W+1) for a regular, axis-aligned grid.
|
foscat/healpix_unet_torch.py
CHANGED
|
@@ -982,6 +982,9 @@ def fit(
|
|
|
982
982
|
n_epoch: int = 10,
|
|
983
983
|
view_epoch: int = 10,
|
|
984
984
|
batch_size: int = 16,
|
|
985
|
+
x_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
|
|
986
|
+
y_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
|
|
987
|
+
save_model: bool = False,
|
|
985
988
|
lr: float = 1e-3,
|
|
986
989
|
weight_decay: float = 0.0,
|
|
987
990
|
clip_grad_norm: Optional[float] = None,
|
|
@@ -1005,6 +1008,11 @@ def fit(
|
|
|
1005
1008
|
device = model.runtime_device if hasattr(model, "runtime_device") else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
|
1006
1009
|
model.to(device)
|
|
1007
1010
|
|
|
1011
|
+
if save_model:
|
|
1012
|
+
assert x_valid is None, "If save_mode=True x_valid should not be None"
|
|
1013
|
+
assert y_valid is None, "If save_mode=True y_valid should not be None"
|
|
1014
|
+
best_valid=1E30
|
|
1015
|
+
|
|
1008
1016
|
# Detect variable-length mode
|
|
1009
1017
|
varlen_mode = isinstance(x_train, (list, tuple))
|
|
1010
1018
|
|
|
@@ -1197,6 +1205,14 @@ def fit(
|
|
|
1197
1205
|
history.append(epoch_loss)
|
|
1198
1206
|
# print every view_epoch logical step
|
|
1199
1207
|
if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
|
|
1200
|
-
|
|
1208
|
+
if x_valid is not None:
|
|
1209
|
+
preds=model.predict(model.to_tensor(x_valid)).cpu().numpy()
|
|
1210
|
+
valid_loss=np.mean((preds-y_valid)**2)
|
|
1211
|
+
if save_model:
|
|
1212
|
+
if best_valid>valid_loss:
|
|
1213
|
+
torch.save({"model": self.state_dict(), "cfg": CFG}, os.path.join(CFG["save_dir"], "best.pt"))
|
|
1214
|
+
print(f"[epoch {len(history)}] loss={epoch_loss:.4f} loss_valid={valid_loss:.4f}")
|
|
1215
|
+
else:
|
|
1216
|
+
print(f"[epoch {len(history)}] loss={epoch_loss:.4f}")
|
|
1201
1217
|
|
|
1202
1218
|
return {"loss": history}
|
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
# healpix_vit_skip.py
|
|
2
|
+
# HEALPix ViT U-Net with temporal encoders and Transformer-based skip fusion.
|
|
3
|
+
# - Multi-level HEALPix pyramid built with foscat.SphericalStencil (Down / Up / Convol_torch)
|
|
4
|
+
# - Per-level temporal encoding (sequence over T_in months) at encoder
|
|
5
|
+
# - Decoder uses cross-attention to fuse upsampled features with encoder skips
|
|
6
|
+
# - Double spherical convolution + GroupNorm + GELU at each encoder/decoder level
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
from typing import List, Optional, Literal
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
|
|
16
|
+
import foscat.scat_cov as sc
|
|
17
|
+
import foscat.SphericalStencil as ho
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MLP(nn.Module):
    """Pre-norm feed-forward block.

    LayerNorm -> Linear(d, hidden_mult*d) -> GELU -> Dropout
    -> Linear(hidden_mult*d, d) -> Dropout.  Maps ``(..., d)`` to ``(..., d)``.
    No residual connection is added here.
    """

    def __init__(self, d: int, hidden_mult: int = 4, drop: float = 0.0):
        super().__init__()
        hidden = hidden_mult * d
        layers = [
            nn.LayerNorm(d),
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, d),
            nn.Dropout(drop),
        ]
        # One Sequential attribute named ``net`` keeps state_dict keys as
        # ``net.<index>.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class HealpixViTSkip(nn.Module):
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
*,
|
|
39
|
+
in_nside: int,
|
|
40
|
+
n_chan_in: int,
|
|
41
|
+
level_dims: List[int],
|
|
42
|
+
depth_token: int,
|
|
43
|
+
num_heads_token: int,
|
|
44
|
+
cell_ids: np.ndarray,
|
|
45
|
+
task: Literal["regression","segmentation","global"] = "regression",
|
|
46
|
+
out_channels: int = 1,
|
|
47
|
+
mlp_ratio_token: float = 4.0,
|
|
48
|
+
KERNELSZ: int = 3,
|
|
49
|
+
gauge_type: Literal["cosmo","phi"] = "cosmo",
|
|
50
|
+
G: int = 1,
|
|
51
|
+
prefer_foscat_gpu: bool = True,
|
|
52
|
+
dropout: float = 0.1,
|
|
53
|
+
dtype: Literal["float32","float64"] = "float32",
|
|
54
|
+
pos_embed_per_level: bool = True,
|
|
55
|
+
) -> None:
|
|
56
|
+
super().__init__()
|
|
57
|
+
|
|
58
|
+
self.in_nside = int(in_nside)
|
|
59
|
+
self.n_chan_in = int(n_chan_in)
|
|
60
|
+
self.level_dims = list(level_dims)
|
|
61
|
+
self.token_down = len(self.level_dims) - 1
|
|
62
|
+
assert self.token_down >= 0
|
|
63
|
+
self.C_fine = int(self.level_dims[0])
|
|
64
|
+
self.embed_dim = int(self.level_dims[-1])
|
|
65
|
+
self.depth_token = int(depth_token)
|
|
66
|
+
self.num_heads_token = int(num_heads_token)
|
|
67
|
+
self.mlp_ratio_token = float(mlp_ratio_token)
|
|
68
|
+
self.task = task
|
|
69
|
+
self.out_channels = int(out_channels)
|
|
70
|
+
self.KERNELSZ = int(KERNELSZ)
|
|
71
|
+
self.gauge_type = gauge_type
|
|
72
|
+
self.G = int(G)
|
|
73
|
+
self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
|
|
74
|
+
self.dropout = float(dropout)
|
|
75
|
+
self.dtype = dtype
|
|
76
|
+
self.pos_embed_per_level = bool(pos_embed_per_level)
|
|
77
|
+
|
|
78
|
+
for d in self.level_dims:
|
|
79
|
+
if d % self.G != 0:
|
|
80
|
+
raise ValueError(f"All level_dims must be divisible by G={self.G}, got {d}.")
|
|
81
|
+
if self.embed_dim % self.num_heads_token != 0:
|
|
82
|
+
raise ValueError("embed_dim must be divisible by num_heads_token.")
|
|
83
|
+
|
|
84
|
+
if dtype == "float32":
|
|
85
|
+
self.np_dtype = np.float32
|
|
86
|
+
self.torch_dtype = torch.float32
|
|
87
|
+
else:
|
|
88
|
+
self.np_dtype = np.float64
|
|
89
|
+
self.torch_dtype = torch.float32
|
|
90
|
+
|
|
91
|
+
if cell_ids is None:
|
|
92
|
+
raise ValueError("cell_ids (finest) must be provided.")
|
|
93
|
+
self.cell_ids_fine = np.asarray(cell_ids)
|
|
94
|
+
|
|
95
|
+
if self.task == "segmentation":
|
|
96
|
+
self.final_activation = "sigmoid" if self.out_channels == 1 else "softmax"
|
|
97
|
+
else:
|
|
98
|
+
self.final_activation = "none"
|
|
99
|
+
|
|
100
|
+
self.f = sc.funct(KERNELSZ=self.KERNELSZ)
|
|
101
|
+
|
|
102
|
+
# Build stencils
|
|
103
|
+
self.hconv_levels: List[ho.SphericalStencil] = []
|
|
104
|
+
self.level_cell_ids: List[np.ndarray] = [self.cell_ids_fine]
|
|
105
|
+
current_nside = self.in_nside
|
|
106
|
+
dummy = self.f.backend.bk_cast(np.zeros((1, 1, self.cell_ids_fine.shape[0]), dtype=self.np_dtype))
|
|
107
|
+
for _ in range(self.token_down):
|
|
108
|
+
hc = ho.SphericalStencil(current_nside, self.KERNELSZ, n_gauges=self.G,
|
|
109
|
+
gauge_type=self.gauge_type, cell_ids=self.level_cell_ids[-1],
|
|
110
|
+
dtype=self.torch_dtype)
|
|
111
|
+
self.hconv_levels.append(hc)
|
|
112
|
+
dummy, next_ids = hc.Down(dummy, cell_ids=self.level_cell_ids[-1], nside=current_nside, max_poll=True)
|
|
113
|
+
self.level_cell_ids.append(self.f.backend.to_numpy(next_ids))
|
|
114
|
+
current_nside //= 2
|
|
115
|
+
|
|
116
|
+
self.token_nside = current_nside if self.token_down > 0 else self.in_nside
|
|
117
|
+
self.token_cell_ids = self.level_cell_ids[-1]
|
|
118
|
+
|
|
119
|
+
self.hconv_token = ho.SphericalStencil(self.token_nside, self.KERNELSZ, n_gauges=self.G,
|
|
120
|
+
gauge_type=self.gauge_type, cell_ids=self.token_cell_ids, dtype=self.torch_dtype)
|
|
121
|
+
self.hconv_head = ho.SphericalStencil(self.in_nside, self.KERNELSZ, n_gauges=self.G,
|
|
122
|
+
gauge_type=self.gauge_type, cell_ids=self.cell_ids_fine, dtype=self.torch_dtype)
|
|
123
|
+
|
|
124
|
+
self.nsides_levels = [self.in_nside // (2**i) for i in range(self.token_down+1)]
|
|
125
|
+
self.ntokens_levels = [12 * n * n for n in self.nsides_levels]
|
|
126
|
+
|
|
127
|
+
# Patch embed (double conv)
|
|
128
|
+
fine_g = self.C_fine // self.G
|
|
129
|
+
self.pe_w1 = nn.Parameter(torch.empty(self.n_chan_in, fine_g, self.KERNELSZ*self.KERNELSZ))
|
|
130
|
+
nn.init.kaiming_uniform_(self.pe_w1.view(self.n_chan_in * fine_g, -1), a=np.sqrt(5))
|
|
131
|
+
self.pe_w2 = nn.Parameter(torch.empty(self.C_fine, fine_g, self.KERNELSZ*self.KERNELSZ))
|
|
132
|
+
nn.init.kaiming_uniform_(self.pe_w2.view(self.C_fine * fine_g, -1), a=np.sqrt(5))
|
|
133
|
+
self.pe_bn1 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
|
|
134
|
+
self.pe_bn2 = nn.GroupNorm(num_groups=min(8, self.C_fine if self.C_fine>1 else 1), num_channels=self.C_fine)
|
|
135
|
+
|
|
136
|
+
# Encoder double convs
|
|
137
|
+
self.enc_w1 = nn.ParameterList()
|
|
138
|
+
self.enc_w2 = nn.ParameterList()
|
|
139
|
+
self.enc_bn1 = nn.ModuleList()
|
|
140
|
+
self.enc_bn2 = nn.ModuleList()
|
|
141
|
+
for i in range(self.token_down):
|
|
142
|
+
Cin = self.level_dims[i]
|
|
143
|
+
Cout = self.level_dims[i+1]
|
|
144
|
+
Cout_g = Cout // self.G
|
|
145
|
+
w1 = nn.Parameter(torch.empty(Cin, Cout_g, self.KERNELSZ*self.KERNELSZ))
|
|
146
|
+
nn.init.kaiming_uniform_(w1.view(Cin * Cout_g, -1), a=np.sqrt(5))
|
|
147
|
+
w2 = nn.Parameter(torch.empty(Cout, Cout_g, self.KERNELSZ*self.KERNELSZ))
|
|
148
|
+
nn.init.kaiming_uniform_(w2.view(Cout * Cout_g, -1), a=np.sqrt(5))
|
|
149
|
+
self.enc_w1.append(w1); self.enc_w2.append(w2)
|
|
150
|
+
self.enc_bn1.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
|
|
151
|
+
self.enc_bn2.append(nn.GroupNorm(num_groups=min(8, Cout if Cout>1 else 1), num_channels=Cout))
|
|
152
|
+
|
|
153
|
+
# Temporal encoders per level (fine..pre-token)
|
|
154
|
+
self.temporal_encoders = nn.ModuleList([
|
|
155
|
+
nn.TransformerEncoder(
|
|
156
|
+
nn.TransformerEncoderLayer(
|
|
157
|
+
d_model=self.level_dims[i],
|
|
158
|
+
nhead=max(1, min(8, self.level_dims[i] // 64)),
|
|
159
|
+
dim_feedforward=2*self.level_dims[i],
|
|
160
|
+
dropout=self.dropout,
|
|
161
|
+
activation='gelu',
|
|
162
|
+
batch_first=True,
|
|
163
|
+
norm_first=True,
|
|
164
|
+
),
|
|
165
|
+
num_layers=2,
|
|
166
|
+
)
|
|
167
|
+
for i in range(self.token_down)
|
|
168
|
+
])
|
|
169
|
+
|
|
170
|
+
# Token-level Transformer
|
|
171
|
+
self.n_tokens = int(self.token_cell_ids.shape[0])
|
|
172
|
+
self.pos_token = nn.Parameter(torch.zeros(1, self.n_tokens, self.embed_dim))
|
|
173
|
+
nn.init.trunc_normal_(self.pos_token, std=0.02)
|
|
174
|
+
enc_layer = nn.TransformerEncoderLayer(
|
|
175
|
+
d_model=self.embed_dim,
|
|
176
|
+
nhead=self.num_heads_token,
|
|
177
|
+
dim_feedforward=int(self.embed_dim * self.mlp_ratio_token),
|
|
178
|
+
dropout=self.dropout,
|
|
179
|
+
activation='gelu',
|
|
180
|
+
batch_first=True,
|
|
181
|
+
norm_first=True,
|
|
182
|
+
)
|
|
183
|
+
self.encoder_token = nn.TransformerEncoder(enc_layer, num_layers=self.depth_token)
|
|
184
|
+
|
|
185
|
+
# Decoder fusion modules per level (cross-attention)
|
|
186
|
+
self.dec_q = nn.ModuleList()
|
|
187
|
+
self.dec_k = nn.ModuleList()
|
|
188
|
+
self.dec_v = nn.ModuleList()
|
|
189
|
+
self.dec_attn = nn.ModuleList()
|
|
190
|
+
self.dec_mlp = nn.ModuleList()
|
|
191
|
+
self.level_pos = nn.ParameterList() if self.pos_embed_per_level else None
|
|
192
|
+
for i in range(self.token_down, 0, -1):
|
|
193
|
+
Cfine = self.level_dims[i-1]
|
|
194
|
+
d_fuse = Cfine
|
|
195
|
+
self.dec_q.append(nn.Linear(Cfine, d_fuse))
|
|
196
|
+
self.dec_k.append(nn.Linear(Cfine, d_fuse))
|
|
197
|
+
self.dec_v.append(nn.Linear(Cfine, d_fuse))
|
|
198
|
+
self.dec_attn.append(nn.MultiheadAttention(embed_dim=d_fuse, num_heads=max(1, min(8, d_fuse // 64)), batch_first=True))
|
|
199
|
+
self.dec_mlp.append(MLP(d_fuse, hidden_mult=4, drop=self.dropout))
|
|
200
|
+
if self.pos_embed_per_level:
|
|
201
|
+
n_tok_i = self.ntokens_levels[i-1]
|
|
202
|
+
p = nn.Parameter(torch.zeros(1, n_tok_i, d_fuse))
|
|
203
|
+
nn.init.trunc_normal_(p, std=0.02)
|
|
204
|
+
self.level_pos.append(p)
|
|
205
|
+
|
|
206
|
+
# Decoder refinement double convs
|
|
207
|
+
self.dec_refine_w1 = nn.ParameterList()
|
|
208
|
+
self.dec_refine_w2 = nn.ParameterList()
|
|
209
|
+
self.dec_refine_bn1 = nn.ModuleList()
|
|
210
|
+
self.dec_refine_bn2 = nn.ModuleList()
|
|
211
|
+
for i in range(self.token_down, 0, -1):
|
|
212
|
+
Cfine = self.level_dims[i-1]
|
|
213
|
+
Cfine_g = Cfine // self.G
|
|
214
|
+
w1 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
|
|
215
|
+
nn.init.kaiming_uniform_(w1.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
|
|
216
|
+
w2 = nn.Parameter(torch.empty(Cfine, Cfine_g, self.KERNELSZ*self.KERNELSZ))
|
|
217
|
+
nn.init.kaiming_uniform_(w2.view(Cfine * Cfine_g, -1), a=np.sqrt(5))
|
|
218
|
+
self.dec_refine_w1.append(w1); self.dec_refine_w2.append(w2)
|
|
219
|
+
self.dec_refine_bn1.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
|
|
220
|
+
self.dec_refine_bn2.append(nn.GroupNorm(num_groups=min(8, Cfine if Cfine>1 else 1), num_channels=Cfine))
|
|
221
|
+
|
|
222
|
+
# Head
|
|
223
|
+
if self.task == "global":
|
|
224
|
+
self.global_head = nn.Linear(self.embed_dim, self.out_channels)
|
|
225
|
+
else:
|
|
226
|
+
if self.out_channels % self.G != 0:
|
|
227
|
+
raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
|
|
228
|
+
out_g = self.out_channels // self.G
|
|
229
|
+
self.head_w = nn.Parameter(torch.empty(self.C_fine, out_g, self.KERNELSZ*self.KERNELSZ))
|
|
230
|
+
nn.init.kaiming_uniform_(self.head_w.view(self.C_fine * out_g, -1), a=np.sqrt(5))
|
|
231
|
+
self.head_bn = nn.GroupNorm(num_groups=min(8, self.out_channels if self.out_channels>1 else 1),
|
|
232
|
+
num_channels=self.out_channels) if self.task=="segmentation" else None
|
|
233
|
+
|
|
234
|
+
pref = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
235
|
+
self.runtime_device = self._probe_and_set_runtime_device(pref)
|
|
236
|
+
|
|
237
|
+
def _move_hc(self, hc: ho.SphericalStencil, device: torch.device) -> None:
|
|
238
|
+
for name, val in list(vars(hc).items()):
|
|
239
|
+
try:
|
|
240
|
+
if torch.is_tensor(val):
|
|
241
|
+
setattr(hc, name, val.to(device))
|
|
242
|
+
elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
|
|
243
|
+
setattr(hc, name, type(val)([v.to(device) for v in val]))
|
|
244
|
+
except Exception:
|
|
245
|
+
pass
|
|
246
|
+
|
|
247
|
+
@torch.no_grad()
|
|
248
|
+
def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
|
|
249
|
+
if preferred.type == "cuda":
|
|
250
|
+
try:
|
|
251
|
+
super().to(preferred)
|
|
252
|
+
for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
|
|
253
|
+
self._move_hc(hc, preferred)
|
|
254
|
+
npix0 = int(self.cell_ids_fine.shape[0])
|
|
255
|
+
x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
|
|
256
|
+
hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
|
|
257
|
+
y_try = hc0.Convol_torch(x_try, self.pe_w1, cell_ids=self.cell_ids_fine)
|
|
258
|
+
_ = (y_try if torch.is_tensor(y_try) else torch.as_tensor(y_try, device=preferred)).sum().item()
|
|
259
|
+
self._foscat_device = preferred
|
|
260
|
+
return preferred
|
|
261
|
+
except Exception:
|
|
262
|
+
pass
|
|
263
|
+
cpu = torch.device("cpu")
|
|
264
|
+
super().to(cpu)
|
|
265
|
+
for hc in self.hconv_levels + [self.hconv_token, self.hconv_head]:
|
|
266
|
+
self._move_hc(hc, cpu)
|
|
267
|
+
self._foscat_device = cpu
|
|
268
|
+
return cpu
|
|
269
|
+
|
|
270
|
+
def _as_tensor_batch(self, x):
|
|
271
|
+
if isinstance(x, list):
|
|
272
|
+
if len(x) == 1:
|
|
273
|
+
t = x[0]
|
|
274
|
+
return t.unsqueeze(0) if t.dim() == 2 else t
|
|
275
|
+
raise ValueError("Variable-length list not supported here; pass a tensor.")
|
|
276
|
+
return x
|
|
277
|
+
|
|
278
|
+
def _to_numpy_ids(self, ids):
|
|
279
|
+
if torch.is_tensor(ids):
|
|
280
|
+
return ids.detach().cpu().numpy()
|
|
281
|
+
return np.asarray(ids)
|
|
282
|
+
|
|
283
|
+
def _patch_embed_fine(self, x_t: torch.Tensor) -> torch.Tensor:
|
|
284
|
+
hc0 = self.hconv_levels[0] if len(self.hconv_levels)>0 else self.hconv_head
|
|
285
|
+
z = hc0.Convol_torch(x_t, self.pe_w1, cell_ids=self.cell_ids_fine)
|
|
286
|
+
z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
|
|
287
|
+
z = self.pe_bn1(z); z = F.gelu(z)
|
|
288
|
+
z = hc0.Convol_torch(z, self.pe_w2, cell_ids=self.cell_ids_fine)
|
|
289
|
+
z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
|
|
290
|
+
z = self.pe_bn2(z); z = F.gelu(z)
|
|
291
|
+
return z
|
|
292
|
+
|
|
293
|
+
def forward(self, x: torch.Tensor, runtime_ids: Optional[np.ndarray] = None) -> torch.Tensor:
|
|
294
|
+
if x.dim() != 4:
|
|
295
|
+
raise ValueError("Expected input shape (B, T_in, C_in, Npix)")
|
|
296
|
+
B, T_in, C_in, Nf = x.shape
|
|
297
|
+
if C_in != self.n_chan_in:
|
|
298
|
+
raise ValueError(f"Expected n_chan_in={self.n_chan_in}, got {C_in}")
|
|
299
|
+
x = x.to(self.runtime_device)
|
|
300
|
+
|
|
301
|
+
fine_ids_runtime = self.cell_ids_fine if runtime_ids is None else self._to_numpy_ids(runtime_ids)
|
|
302
|
+
ids_chain = [np.asarray(fine_ids_runtime)]
|
|
303
|
+
nside_tmp = self.in_nside
|
|
304
|
+
_dummy = self.f.backend.bk_cast(np.zeros((1, 1, ids_chain[0].shape[0]), dtype=self.np_dtype))
|
|
305
|
+
for hc in self.hconv_levels:
|
|
306
|
+
_dummy, _next = hc.Down(_dummy, cell_ids=ids_chain[-1], nside=nside_tmp, max_poll=True)
|
|
307
|
+
ids_chain.append(self.f.backend.to_numpy(_next))
|
|
308
|
+
nside_tmp //= 2
|
|
309
|
+
|
|
310
|
+
# Encoder histories per level
|
|
311
|
+
l_hist: List[torch.Tensor] = []
|
|
312
|
+
l_ids: List[np.ndarray] = []
|
|
313
|
+
|
|
314
|
+
feats_fine = []
|
|
315
|
+
for t in range(T_in):
|
|
316
|
+
zt = self._patch_embed_fine(x[:, t, :, :])
|
|
317
|
+
feats_fine.append(zt.unsqueeze(1))
|
|
318
|
+
feats_fine = torch.cat(feats_fine, dim=1) # (B, T_in, C_fine, N_fine)
|
|
319
|
+
l_hist.append(feats_fine)
|
|
320
|
+
l_ids.append(self.cell_ids_fine)
|
|
321
|
+
|
|
322
|
+
current_nside = self.in_nside
|
|
323
|
+
l_data_hist = feats_fine
|
|
324
|
+
for i, hc in enumerate(self.hconv_levels):
|
|
325
|
+
Cin = self.level_dims[i]
|
|
326
|
+
Cout = self.level_dims[i+1]
|
|
327
|
+
w1, w2 = self.enc_w1[i], self.enc_w2[i]
|
|
328
|
+
feats_next = []
|
|
329
|
+
for t in range(T_in):
|
|
330
|
+
zt = l_data_hist[:, t, :, :]
|
|
331
|
+
zt = hc.Convol_torch(zt, w1, cell_ids=l_ids[-1])
|
|
332
|
+
zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
|
|
333
|
+
zt = self.enc_bn1[i](zt); zt = F.gelu(zt)
|
|
334
|
+
zt = hc.Convol_torch(zt, w2, cell_ids=l_ids[-1])
|
|
335
|
+
zt = self._as_tensor_batch(zt if torch.is_tensor(zt) else torch.as_tensor(zt, device=self.runtime_device))
|
|
336
|
+
zt = self.enc_bn2[i](zt); zt = F.gelu(zt)
|
|
337
|
+
feats_next.append(zt.unsqueeze(1))
|
|
338
|
+
feats_next = torch.cat(feats_next, dim=1) # (B, T_in, Cout, N_i)
|
|
339
|
+
|
|
340
|
+
feats_down = []
|
|
341
|
+
next_ids_list = None
|
|
342
|
+
for t in range(T_in):
|
|
343
|
+
zt, next_ids = hc.Down(feats_next[:, t, :, :], cell_ids=l_ids[-1], nside=current_nside, max_poll=True)
|
|
344
|
+
zt = self._as_tensor_batch(zt)
|
|
345
|
+
feats_down.append(zt.unsqueeze(1))
|
|
346
|
+
next_ids_list = next_ids
|
|
347
|
+
feats_down = torch.cat(feats_down, dim=1) # (B, T_in, Cout, N_{i+1})
|
|
348
|
+
|
|
349
|
+
l_hist.append(feats_down)
|
|
350
|
+
l_ids.append(self.f.backend.to_numpy(next_ids_list))
|
|
351
|
+
l_data_hist = feats_down
|
|
352
|
+
current_nside //= 2
|
|
353
|
+
|
|
354
|
+
# Temporal encoder on skips (levels 0..token_down-1)
|
|
355
|
+
skips: List[torch.Tensor] = []
|
|
356
|
+
for i in range(self.token_down):
|
|
357
|
+
Bx, Tx, Cx, Nx = l_hist[i].shape
|
|
358
|
+
z = l_hist[i].permute(0, 3, 1, 2).reshape(Bx*Nx, Tx, Cx)
|
|
359
|
+
z = self.temporal_encoders[i](z)
|
|
360
|
+
z = z.mean(dim=1)
|
|
361
|
+
H_i = z.view(Bx, Nx, Cx).permute(0, 2, 1).contiguous()
|
|
362
|
+
skips.append(H_i)
|
|
363
|
+
|
|
364
|
+
# Token-level transformer (spatial)
|
|
365
|
+
x_tok_hist = l_hist[-1] # (B, T_in, E, Ntok)
|
|
366
|
+
x_tok = x_tok_hist.mean(dim=1) # (B, E, Ntok) (could add temporal encoder here as well)
|
|
367
|
+
seq = x_tok.permute(0, 2, 1) + self.pos_token[:, :x_tok.shape[2], :]
|
|
368
|
+
seq = self.encoder_token(seq)
|
|
369
|
+
y = seq.permute(0, 2, 1) # (B, E, Ntok)
|
|
370
|
+
|
|
371
|
+
if self.task == "global":
|
|
372
|
+
g = seq.mean(dim=1)
|
|
373
|
+
return self.global_head(g)
|
|
374
|
+
|
|
375
|
+
# Decoder: Up + cross-attn fusion + double conv refinement
|
|
376
|
+
dec_idx = 0
|
|
377
|
+
for i in range(self.token_down, 0, -1):
|
|
378
|
+
coarse_ids = ids_chain[i]
|
|
379
|
+
fine_ids = ids_chain[i-1]
|
|
380
|
+
source_ns = self.in_nside // (2 ** i)
|
|
381
|
+
fine_ns = self.in_nside // (2 ** (i-1))
|
|
382
|
+
Cfine = self.level_dims[i-1]
|
|
383
|
+
|
|
384
|
+
op_fine = self.hconv_head if fine_ns == self.in_nside else self.hconv_levels[self.nsides_levels.index(fine_ns)]
|
|
385
|
+
|
|
386
|
+
y_up = op_fine.Up(y, cell_ids=coarse_ids, o_cell_ids=fine_ids, nside=source_ns)
|
|
387
|
+
y_up = self._as_tensor_batch(y_up if torch.is_tensor(y_up) else torch.as_tensor(y_up, device=self.runtime_device)) # (B, Cfine, N)
|
|
388
|
+
|
|
389
|
+
skip_i = skips[i-1] # (B, Cfine, N)
|
|
390
|
+
q = self.dec_q[dec_idx](y_up.permute(0,2,1))
|
|
391
|
+
k = self.dec_k[dec_idx](skip_i.permute(0,2,1))
|
|
392
|
+
v = self.dec_v[dec_idx](skip_i.permute(0,2,1))
|
|
393
|
+
if self.pos_embed_per_level:
|
|
394
|
+
pos = self.level_pos[dec_idx][:, :q.shape[1], :]
|
|
395
|
+
q = q + pos; k = k + pos
|
|
396
|
+
z, _ = self.dec_attn[dec_idx](q, k, v)
|
|
397
|
+
z = self.dec_mlp[dec_idx](z)
|
|
398
|
+
z = z.permute(0,2,1).contiguous() # (B, Cfine, N)
|
|
399
|
+
|
|
400
|
+
z = op_fine.Convol_torch(z, self.dec_refine_w1[dec_idx], cell_ids=fine_ids)
|
|
401
|
+
z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
|
|
402
|
+
z = self.dec_refine_bn1[dec_idx](z); z = F.gelu(z)
|
|
403
|
+
z = op_fine.Convol_torch(z, self.dec_refine_w2[dec_idx], cell_ids=fine_ids)
|
|
404
|
+
z = self._as_tensor_batch(z if torch.is_tensor(z) else torch.as_tensor(z, device=self.runtime_device))
|
|
405
|
+
z = self.dec_refine_bn2[dec_idx](z); z = F.gelu(z)
|
|
406
|
+
|
|
407
|
+
y = z
|
|
408
|
+
dec_idx += 1
|
|
409
|
+
|
|
410
|
+
y = self.hconv_head.Convol_torch(y, self.head_w, cell_ids=fine_ids_runtime)
|
|
411
|
+
y = self._as_tensor_batch(y if torch.is_tensor(y) else torch.as_tensor(y, device=self.runtime_device))
|
|
412
|
+
if self.task == "segmentation" and self.head_bn is not None:
|
|
413
|
+
y = self.head_bn(y)
|
|
414
|
+
if self.final_activation == "sigmoid":
|
|
415
|
+
y = torch.sigmoid(y)
|
|
416
|
+
elif self.final_activation == "softmax":
|
|
417
|
+
y = torch.softmax(y, dim=1)
|
|
418
|
+
return y
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
if __name__ == "__main__":
    # Smoke test: random (B, T_in, C_in, Npix) input on a full nside=4 sphere,
    # three-level pyramid (64 -> 96 -> 128 channels), regression head.
    demo_nside = 4
    demo_npix = 12 * demo_nside ** 2
    demo_ids = np.arange(demo_npix, dtype=np.int64)

    batch, steps, chans = 2, 3, 4
    inputs = torch.randn(batch, steps, chans, demo_npix)

    demo_cfg = dict(
        in_nside=demo_nside,
        n_chan_in=chans,
        level_dims=[64, 96, 128],
        depth_token=2,
        num_heads_token=4,
        cell_ids=demo_ids,
        task="regression",
        out_channels=1,
        KERNELSZ=3,
        G=1,
        dropout=0.1,
    )
    net = HealpixViTSkip(**demo_cfg).eval()

    with torch.no_grad():
        output = net(inputs)
    print("Output:", tuple(output.shape))
|