foscat-2025.9.5-py3-none-any.whl → foscat-2025.10.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +14 -9
- foscat/Plot.py +4 -3
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch-old.py +658 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/METADATA +1 -1
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/RECORD +13 -8
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/WHEEL +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/top_level.txt +0 -0
foscat/planar_vit.py
ADDED
@@ -0,0 +1,206 @@
+# planar_vit.py
+# (Planar Vision Transformer baseline for lat–lon grids)
+from __future__ import annotations
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# ---------------------------
+# Building blocks
+# ---------------------------
+
+class _MLP(nn.Module):
+    """ViT MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""
+    def __init__(self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.1):
+        super().__init__()
+        hidden = int(dim * mlp_ratio)
+        self.fc1 = nn.Linear(dim, hidden)
+        self.fc2 = nn.Linear(hidden, dim)
+        self.act = nn.GELU()
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.drop(self.act(self.fc1(x)))
+        x = self.drop(self.fc2(x))
+        return x
+
+
+class _ViTBlock(nn.Module):
+    """
+    Transformer block (Pre-LN):
+        x = x + Drop(MHA(LN(x)))
+        x = x + Drop(MLP(LN(x)))
+    """
+    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, drop: float = 0.1):
+        super().__init__()
+        assert dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
+        self.norm2 = nn.LayerNorm(dim)
+        self.mlp = _MLP(dim, mlp_ratio, drop)
+        # plain Dropout on each residual branch (not stochastic depth)
+        self.drop_path = nn.Dropout(drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Multi-head self-attention (LayerNorm computed once, shared by Q, K, V)
+        h = self.norm1(x)
+        x = x + self.drop_path(self.attn(h, h, h)[0])
+        # Feed-forward
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+# ---------------------------
+# Planar ViT (lat–lon images)
+# ---------------------------
+
+class PlanarViT(nn.Module):
+    """
+    Vision Transformer for 2D lat–lon grids (planar baseline).
+
+    Input : (B, C=T_in, H, W)
+    Output: (B, out_ch, H, W)   # dense per-pixel prediction
+
+    Pipeline
+    --------
+    1) Patch embedding via Conv2d(kernel_size=patch, stride=patch) -> embed_dim
+    2) Optional CLS token (disabled by default for dense output)
+    3) Learned positional embeddings (or none)
+    4) Stack of Transformer blocks
+    5) Linear head per token, then nearest upsample back to (H, W)
+
+    Notes
+    -----
+    - Keep H, W divisible by `patch`.
+    - For residual-of-persistence training (recommended for monthly SST):
+          pred = x[:, -1:, ...] + model(x)
+      and train the loss on `pred` vs target.
+    """
+    def __init__(
+        self,
+        in_ch: int,                  # e.g., T_in months
+        H: int,
+        W: int,
+        *,
+        embed_dim: int = 384,
+        depth: int = 8,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        patch: int = 4,
+        out_ch: int = 1,
+        dropout: float = 0.1,
+        cls_token: bool = False,     # keep False for dense prediction
+        pos_embed: str = "learned",  # or "none"
+    ):
+        super().__init__()
+        assert H % patch == 0 and W % patch == 0, "H and W must be divisible by patch"
+        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+
+        self.H, self.W = H, W
+        self.patch = patch
+        self.embed_dim = embed_dim
+        self.cls_token_enabled = bool(cls_token)
+        self.use_pos_embed = (pos_embed == "learned")
+
+        # 1) Patch embedding (Conv2d with stride=patch) → tokens
+        self.patch_embed = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)
+
+        # 2) Token bookkeeping & positional embeddings
+        Hp, Wp = H // patch, W // patch
+        self.num_tokens = Hp * Wp + (1 if self.cls_token_enabled else 0)
+
+        if self.cls_token_enabled:
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+            nn.init.trunc_normal_(self.cls_token, std=0.02)
+        else:
+            self.cls_token = None
+
+        if self.use_pos_embed:
+            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
+            nn.init.trunc_normal_(self.pos_embed, std=0.02)
+        else:
+            self.pos_embed = None
+
+        # 3) Transformer encoder
+        self.blocks = nn.ModuleList([
+            _ViTBlock(embed_dim, num_heads, mlp_ratio=mlp_ratio, drop=dropout)
+            for _ in range(depth)
+        ])
+
+        # 4) Patch-wise head (token -> out_ch)
+        self.head = nn.Linear(embed_dim, out_ch)
+
+        # Store for unpatching
+        self.Hp, self.Wp = Hp, Wp
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: (B, C, H, W) with H, W fixed to construction-time H, W
+        returns: (B, out_ch, H, W)
+        """
+        B, C, H, W = x.shape
+        if (H != self.H) or (W != self.W):
+            raise ValueError(f"Input H,W must be ({self.H},{self.W}), got ({H},{W}).")
+
+        # Patch embedding → (B, E, Hp, Wp) → (B, Np, E)
+        z = self.patch_embed(x)           # (B, E, Hp, Wp)
+        z = z.flatten(2).transpose(1, 2)  # (B, Np, E)
+
+        # Optional CLS
+        if self.cls_token_enabled:
+            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, E)
+            z = torch.cat([cls, z], dim=1)          # (B, 1+Np, E)
+
+        # Positional embedding
+        if self.pos_embed is not None:
+            z = z + self.pos_embed[:, :z.shape[1], :]
+
+        # Transformer
+        for blk in self.blocks:
+            z = blk(z)  # (B, N, E)
+
+        # Drop CLS for dense output
+        if self.cls_token_enabled:
+            tokens = z[:, 1:, :]  # (B, Np, E)
+        else:
+            tokens = z
+
+        # Token head → (B, Np, out_ch) → (B, out_ch, Hp, Wp) → upsample to (H, W)
+        y_tok = self.head(tokens).transpose(1, 2)   # (B, out_ch, Np)
+        y = y_tok.reshape(B, -1, self.Hp, self.Wp)  # (B, out_ch, Hp, Wp)
+        y = F.interpolate(y, scale_factor=self.patch, mode="nearest")
+        return y
+
+
+# ---------------------------
+# Utilities
+# ---------------------------
+
+def count_parameters(model: nn.Module) -> tuple[int, int]:
+    """Return (total_params, trainable_params)."""
+    total = sum(p.numel() for p in model.parameters())
+    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return total, trainable
+
+
+# ---------------------------
+# Smoke test
+# ---------------------------
+
+if __name__ == "__main__":
+    # Example: T_in=6, grid 128x256, predict 1 channel
+    B, C, H, W = 2, 6, 128, 256
+    x = torch.randn(B, C, H, W)
+
+    model = PlanarViT(
+        in_ch=C, H=H, W=W,
+        embed_dim=384, depth=8, num_heads=12,
+        mlp_ratio=4.0, patch=4, out_ch=1, dropout=0.1,
+        cls_token=False, pos_embed="learned"
+    )
+    y = model(x)
+    tot, trn = count_parameters(model)
+    print("Output:", tuple(y.shape))
+    print("Params:", f"total={tot:,}", f"trainable={trn:,}")
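
The PlanarViT docstring above recommends residual-of-persistence training for monthly SST. A minimal sketch of that training step, assuming the module is importable as foscat.planar_vit (the path recorded in this wheel); the batch, target, and optimizer below are illustrative placeholders, not package code:

    import torch
    import torch.nn.functional as F
    from foscat.planar_vit import PlanarViT, count_parameters

    # 6 input months on a 128x256 lat-lon grid; patch=4 gives 32*64 = 2048 tokens
    model = PlanarViT(in_ch=6, H=128, W=256, patch=4, out_ch=1)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)

    x = torch.randn(8, 6, 128, 256)       # (B, T_in, H, W), hypothetical batch
    target = torch.randn(8, 1, 128, 256)  # hypothetical next-month field

    pred = x[:, -1:, ...] + model(x)      # persistence + learned residual
    loss = F.mse_loss(pred, target)
    loss.backward()
    opt.step()
    print(count_parameters(model))        # (total, trainable)
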
foscat/unet_2_d_from_healpix_params.py
ADDED
@@ -0,0 +1,421 @@
+from __future__ import annotations
+from typing import List, Optional, Literal, Tuple
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from contextlib import nullcontext
+
+class PlanarUNet(nn.Module):
+    """
+    2D U-Net (H x W images) mirroring the parameterization of the HealpixUNet.
+
+    Key compat points with HealpixUNet:
+      - Same constructor fields: in_nside, n_chan_in, chanlist, KERNELSZ, task,
+        out_channels, final_activation, device, down_type, dtype, head_reduce.
+      - Two convs per level (encoder & decoder), GroupNorm + ReLU after each conv.
+      - Downsampling by a factor of 2 at each level; upsampling mirrors back.
+      - Head produces `out_channels` with optional norm and final activation.
+
+    Differences vs the sphere version:
+      - Operates on regular 2D images of size (3*in_nside, 4*in_nside).
+      - Standard Conv2d instead of the custom spherical stencil.
+      - No gauges (G=1 implicit) and no cell_ids.
+
+    Shapes
+    ------
+    Input  : (B, C_in, 3*in_nside, 4*in_nside)
+    Output : (B, C_out, 3*in_nside, 4*in_nside)
+
+    Constraints
+    -----------
+    `in_nside` must be divisible by 2**depth, where depth == len(chanlist).
+    """
+
+    def __init__(
+        self,
+        *,
+        in_nside: int,
+        n_chan_in: int,
+        chanlist: List[int],
+        KERNELSZ: int = 3,
+        task: Literal['regression', 'segmentation'] = 'regression',
+        out_channels: int = 1,
+        final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
+        device: Optional[torch.device | str] = None,
+        down_type: Optional[Literal['mean', 'max']] = 'max',
+        dtype: Literal['float32', 'float64'] = 'float32',
+        head_reduce: Literal['mean', 'learned'] = 'mean',  # kept for API symmetry
+    ) -> None:
+        super().__init__()
+
+        if len(chanlist) == 0:
+            raise ValueError("chanlist must be non-empty (depth >= 1)")
+        self.in_nside = int(in_nside)
+        self.n_chan_in = int(n_chan_in)
+        self.chanlist = list(map(int, chanlist))
+        self.KERNELSZ = int(KERNELSZ)
+        self.task = task
+        self.out_channels = int(out_channels)
+        self.down_type = down_type
+        self.dtype = torch.float32 if dtype == 'float32' else torch.float64
+        self.head_reduce = head_reduce
+
+        # default final activation consistent with HealpixUNet
+        if final_activation is None:
+            if task == 'regression':
+                self.final_activation = 'none'
+            else:
+                self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
+        else:
+            self.final_activation = final_activation
+
+        # Resolve device
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = torch.device(device)
+
+        depth = len(self.chanlist)
+        # geometry
+        H0, W0 = 3 * self.in_nside, 4 * self.in_nside
+        # ensure divisibility by 2**depth
+        if (self.in_nside % (2 ** depth)) != 0:
+            raise ValueError(
+                f"in_nside={self.in_nside} must be divisible by 2**depth where depth={depth}"
+            )
+
+        padding = self.KERNELSZ // 2
+
+        # --- Encoder ---
+        enc_layers = []
+        inC = self.n_chan_in
+        self.skips_channels: List[int] = []
+        for outC in self.chanlist:
+            block = nn.Sequential(
+                nn.Conv2d(inC, outC, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                _norm_2d(outC, kind="group"),
+                nn.ReLU(inplace=True),
+                nn.Conv2d(outC, outC, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                _norm_2d(outC, kind="group"),
+                nn.ReLU(inplace=True),
+            )
+            enc_layers.append(block)
+            inC = outC
+            self.skips_channels.append(outC)
+        self.encoder = nn.ModuleList(enc_layers)
+
+        # Pools
+        if self.down_type == 'max':
+            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        else:
+            self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+        # --- Decoder ---
+        dec_layers = []
+        upconvs = []
+        for l in reversed(range(depth)):
+            skipC = self.skips_channels[l]
+            upC = self.skips_channels[l + 1] if (l + 1) < depth else self.skips_channels[l]
+            inC_dec = upC + skipC
+            outC_dec = skipC
+
+            upconvs.append(
+                nn.ConvTranspose2d(upC, upC, kernel_size=2, stride=2)
+            )
+            dec_layers.append(
+                nn.Sequential(
+                    nn.Conv2d(inC_dec, outC_dec, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                    _norm_2d(outC_dec, kind="group"),
+                    nn.ReLU(inplace=True),
+                    nn.Conv2d(outC_dec, outC_dec, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                    _norm_2d(outC_dec, kind="group"),
+                    nn.ReLU(inplace=True),
+                )
+            )
+        self.upconvs = nn.ModuleList(upconvs)
+        self.decoder = nn.ModuleList(dec_layers)
+
+        # --- Head ---
+        head_inC = self.chanlist[0]
+        self.head_conv = nn.Conv2d(head_inC, self.out_channels, kernel_size=self.KERNELSZ, padding=padding)
+        self.head_bn = _norm_2d(self.out_channels, kind="group") if self.task == 'segmentation' else None
+
+        # optional learned mixer kept for API compatibility (no gauges here)
+        self.head_mixer = nn.Identity()
+
+        self.to(self.device, dtype=self.dtype)
+
+    def to_tensor(self, x):
+        return torch.tensor(x, device=self.device)
+
+    def to_numpy(self, x):
+        if isinstance(x, np.ndarray):
+            return x
+        return x.cpu().numpy()
+
+    # -------------------------- forward --------------------------
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """x: (B, C_in, H, W) with H=3*in_nside, W=4*in_nside"""
+        if x.dim() != 4:
+            raise ValueError("Input must be (B, C, H, W)")
+        if x.shape[1] != self.n_chan_in:
+            raise ValueError(f"Expected {self.n_chan_in} input channels, got {x.shape[1]}")
+
+        x = x.to(self.device, dtype=self.dtype)
+
+        skips = []
+        z = x
+        for l, block in enumerate(self.encoder):
+            z = block(z)
+            skips.append(z)
+            if l < len(self.encoder) - 1:
+                z = self.pool(z)
+
+        # Decoder
+        for d, l in enumerate(reversed(range(len(self.chanlist)))):
+            if l < len(self.chanlist) - 1:
+                z = self.upconvs[d](z)
+                # pad if odd due to pooling/upsampling asymmetry (shouldn't happen given divisibility)
+                sh = skips[l].shape
+                if z.shape[-2:] != sh[-2:]:
+                    z = _pad_to_match(z, sh[-2], sh[-1])
+            z = torch.cat([skips[l], z], dim=1)
+            z = self.decoder[d](z)
+
+        y = self.head_conv(z)
+        if self.task == 'segmentation' and self.head_bn is not None:
+            y = self.head_bn(y)
+
+        if self.final_activation == 'sigmoid':
+            y = torch.sigmoid(y)
+        elif self.final_activation == 'softmax':
+            y = torch.softmax(y, dim=1)
+        return y
+
+    # NOTE: this first predict() is shadowed by the extended definition below;
+    # it is kept as released, but only the memory-safe version is bound to the class.
+    @torch.no_grad()
+    def predict(self, x: torch.Tensor, batch_size: int = 8) -> torch.Tensor:
+        self.eval()
+        outs = []
+        for i in range(0, x.shape[0], batch_size):
+            xb = x[i:i+batch_size]
+            outs.append(self.forward(xb))
+        return torch.cat(outs, dim=0)
+
+    @torch.no_grad()
+    def predict(
+        self,
+        x: torch.Tensor,
+        batch_size: int = 8,
+        *,
+        amp: bool = False,
+        out_device: Optional[str] = 'cpu',
+        out_dtype: Literal['float32', 'float16'] = 'float32',
+        show_pbar: bool = False,
+    ) -> torch.Tensor:
+        """Memory-safe prediction.
+        - Streams mini-batches with torch.inference_mode() + optional AMP.
+        - Moves each output batch to `out_device` (CPU by default) to free VRAM.
+        - Checks shapes and raises explicit errors.
+        """
+        self.eval()
+
+        # --- input checks & normalization ---
+        x = x if torch.is_tensor(x) else torch.as_tensor(x)
+        if x.ndim != 4:
+            raise ValueError(f"predict expects (N,C,H,W), got {tuple(getattr(x, 'shape', ()))}")
+        if x.shape[1] != self.n_chan_in:
+            raise ValueError(f"predict expected {self.n_chan_in} channels, got {x.shape[1]}")
+        n = int(x.shape[0])
+        if n == 0:
+            H, W = int(x.shape[-2]), int(x.shape[-1])
+            return torch.empty((0, self.out_channels, H, W), device=out_device or self.device)
+
+        # --- setup ---
+        dtype_map = {'float32': torch.float32, 'float16': torch.float16}
+        out_dtype_t = dtype_map[out_dtype]
+        use_cuda = (self.device.type == 'cuda')
+        if use_cuda:
+            torch.backends.cudnn.benchmark = True
+
+        from math import ceil
+        nb = ceil(n / batch_size)
+        rng = range(0, n, batch_size)
+        if show_pbar:
+            try:
+                from tqdm import tqdm  # type: ignore
+                rng = tqdm(rng, total=nb, desc='predict')
+            except Exception:
+                pass
+
+        # --- batch-by-batch inference ---
+        out_list: List[torch.Tensor] = []
+        with torch.inference_mode():
+            ctx = (torch.cuda.amp.autocast() if (amp and use_cuda) else nullcontext())
+            for i in rng:
+                xb = x[i:i+batch_size].to(self.device, dtype=self.dtype, non_blocking=True)
+                with ctx:
+                    yb = self.forward(xb)
+                # Move the output to the requested device (CPU by default)
+                yb = yb.to(out_device, dtype=out_dtype_t) if out_device is not None else yb.to(dtype=out_dtype_t)
+                out_list.append(yb)
+                del xb, yb
+                if use_cuda:
+                    torch.cuda.empty_cache()
+
+        if not out_list:
+            raise RuntimeError(f"predict produced no outputs; check input shape {tuple(x.shape)} and batch_size={batch_size}")
+        return torch.cat(out_list, dim=0)
+
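
A usage sketch for the class above, assuming the import path foscat.unet_2_d_from_healpix_params recorded in this wheel; the random tensors are illustrative placeholders:

    import torch
    from foscat.unet_2_d_from_healpix_params import PlanarUNet

    nside = 32  # must be divisible by 2**len(chanlist) = 8 here
    net = PlanarUNet(in_nside=nside, n_chan_in=3, chanlist=[16, 32, 64],
                     task='regression', out_channels=1, device='cpu')
    x = torch.randn(10, 3, 3 * nside, 4 * nside)  # (B, C_in, 96, 128)
    y = net.predict(x, batch_size=4, out_device='cpu')
    print(y.shape)                                # torch.Size([10, 1, 96, 128])
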
+
+# -----------------------------
+# Helpers
+# -----------------------------
+
+def _norm_2d(C: int, kind: str = "group", **kwargs) -> nn.Module:
+    if kind == "group":
+        num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1)
+        while C % num_groups != 0 and num_groups > 1:
+            num_groups //= 2
+        return nn.GroupNorm(num_groups=num_groups, num_channels=C)
+    elif kind == "instance":
+        return nn.InstanceNorm2d(C, affine=True, track_running_stats=False)
+    elif kind == "batch":
+        return nn.BatchNorm2d(C)
+    else:
+        raise ValueError(f"Unknown norm kind: {kind}")
+
+
+def _pad_to_match(x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+    """Pad x (B,C,h,w) with zeros on right/bottom to reach (H,W)."""
+    _, _, h, w = x.shape
+    ph = max(0, H - h)
+    pw = max(0, W - w)
+    if ph == 0 and pw == 0:
+        return x
+    return F.pad(x, (0, pw, 0, ph), mode='constant', value=0)
+
+
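
The group count picked by _norm_2d can be checked by hand; the values below are worked from the code above, not from package tests:

    # C=16: min(8, max(1, 16 // 8)) = 2 and 16 % 2 == 0 -> GroupNorm(2, 16)
    # C=24: min(8, max(1, 24 // 8)) = 3 and 24 % 3 == 0 -> GroupNorm(3, 24)
    # C=64: min(8, max(1, 64 // 8)) = 8 and 64 % 8 == 0 -> GroupNorm(8, 64)
    # C=4 : min(8, max(1, 4 // 8))  = 1                 -> GroupNorm(1, 4), LayerNorm-like
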
+# -----------------------------
+# Training utilities (mirror of Healpix fit)
+# -----------------------------
+from typing import Union
+from torch.utils.data import DataLoader, TensorDataset
+
+def fit(
+    model: nn.Module,
+    x_train: Union[torch.Tensor, np.ndarray],
+    y_train: Union[torch.Tensor, np.ndarray],
+    *,
+    n_epoch: int = 10,
+    view_epoch: int = 10,
+    batch_size: int = 16,
+    lr: float = 1e-3,
+    weight_decay: float = 0.0,
+    clip_grad_norm: Optional[float] = None,
+    verbose: bool = True,
+    optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
+) -> dict:
+    """Training loop mirroring `healpix_unet_torch.fit`, adapted to 2D images.
+
+    - Fixed-size inputs: tensors/ndarrays of shape (B, C, H, W) with H=3*nside, W=4*nside
+    - Loss: MSE (regression) / BCE (BCEWithLogits if final_activation='none') / CrossEntropy (multi-class)
+    - Optimizer: ADAM or LBFGS with closure
+    - Logs: returns {"loss": history}
+    """
+    device = next(model.parameters()).device
+    model.to(device)
+
+    # ---- DataLoader
+    x_t = torch.as_tensor(x_train, dtype=torch.float32, device=device)
+    y_is_class = (getattr(model, 'task', 'regression') != 'regression' and getattr(model, 'out_channels', 1) > 1)
+    y_dtype = torch.long if y_is_class and (not torch.is_tensor(y_train) or y_train.ndim == x_t.ndim - 1) else torch.float32
+    y_t = torch.as_tensor(y_train, dtype=y_dtype, device=device)
+
+    ds = TensorDataset(x_t, y_t)
+    loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+
+    # ---- Loss
+    if getattr(model, 'task', 'regression') == 'regression':
+        criterion = nn.MSELoss(reduction='mean')
+        seg_multiclass = False
+    else:
+        if getattr(model, 'out_channels', 1) == 1:
+            criterion = nn.BCEWithLogitsLoss() if getattr(model, 'final_activation', 'none') == 'none' else nn.BCELoss()
+            seg_multiclass = False
+        else:
+            criterion = nn.CrossEntropyLoss()
+            seg_multiclass = True
+
+    # ---- Optim
+    if optimizer.upper() == 'ADAM':
+        optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+        outer, inner = n_epoch, 1
+    elif optimizer.upper() == 'LBFGS':
+        optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20, history_size=50, line_search_fn='strong_wolfe')
+        outer, inner = max(1, n_epoch // 20), 20
+    else:
+        raise ValueError("optimizer must be 'ADAM' or 'LBFGS'")
+
+    # ---- Train
+    history: List[float] = []
+    model.train()
+
+    for epoch in range(outer):
+        for _ in range(inner):
+            epoch_loss, n_samples = 0.0, 0
+            for xb, yb in loader:
+                xb = xb.to(device, dtype=torch.float32, non_blocking=True)
+                yb = yb.to(device, non_blocking=True)
+
+                if isinstance(optim, torch.optim.LBFGS):
+                    def closure():
+                        optim.zero_grad(set_to_none=True)
+                        preds = model(xb)
+                        # criterion already encodes the regression/segmentation choice
+                        loss = criterion(preds, yb)
+                        loss.backward()
+                        return loss
+                    loss_val = float(optim.step(closure).item())
+                else:
+                    optim.zero_grad(set_to_none=True)
+                    preds = model(xb)
+                    loss = criterion(preds, yb)
+                    loss.backward()
+                    if clip_grad_norm is not None:
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+                    optim.step()
+                    loss_val = float(loss.item())
+
+                epoch_loss += loss_val * xb.shape[0]
+                n_samples += xb.shape[0]
+
+            epoch_loss /= max(1, n_samples)
+            history.append(epoch_loss)
+            if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
+                print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
+
+    return {"loss": history}
+
+
+# -----------------------------
+# Minimal smoke test
+# -----------------------------
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    nside = 32
+    chanlist = [16, 32, 64]
+    net = PlanarUNet(
+        in_nside=nside,
+        n_chan_in=3,
+        chanlist=chanlist,
+        KERNELSZ=3,
+        task='regression',
+        out_channels=1,
+    )
+    x = torch.randn(2, 3, 3*nside, 4*nside)
+    y = net(x)
+    print('input:', x.shape, 'output:', y.shape)
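
A short sketch composing fit with the smoke-test setup above (same assumed import path; data sizes and epoch counts are illustrative):

    import torch
    from foscat.unet_2_d_from_healpix_params import PlanarUNet, fit

    nside = 32
    net = PlanarUNet(in_nside=nside, n_chan_in=3, chanlist=[16, 32, 64], device='cpu')
    x = torch.randn(16, 3, 3 * nside, 4 * nside)  # illustrative inputs
    y = torch.randn(16, 1, 3 * nside, 4 * nside)  # illustrative regression targets
    log = fit(net, x, y, n_epoch=5, batch_size=4, lr=1e-3, optimizer='ADAM')
    print(log["loss"][-1])                        # final mean epoch loss
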
{foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: foscat
-Version: 2025.9.5
+Version: 2025.10.2
 Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
 Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
 Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
{foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/RECORD
@@ -4,12 +4,12 @@ foscat/BkTensorflow.py,sha256=iIdLx6VTOfOEocfZBOGyizQn5geDLTfdWWAwDeQr9YA,20056
 foscat/BkTorch.py,sha256=W3n3XkFw9oecmyfSWt2JbP5L5eKWn8AAu0RZx1yLQb0,31975
 foscat/CNN.py,sha256=4vky7jqTshL1aYLWsc-hQwf7gDjTVjL7I6HZiAsa6x4,5158
 foscat/CircSpline.py,sha256=CXi49FxF8ZoeZ17Ua8c1AZXe2B5ICEC9aCXb97atB3s,4028
-foscat/FoCUS.py,sha256=
+foscat/FoCUS.py,sha256=9GHjHUBM8sfmYwZw7u-xGeOOjFCpbbvM-5m3yjPSt5g,105752
 foscat/GCNN.py,sha256=q7yWHCMJpP7-m3WvR3OQnp5taeYWaMxIY2hQ6SIb9gs,4487
 foscat/HOrientedConvol.py,sha256=xMaS-zzoUyXisBCPsHBVpn54tuA9Qv3na-tT86Cwn7U,38744
 foscat/HealBili.py,sha256=YRPk9PO5G8NdwKeb33xiJs3_pMPAgIv5phCs8GT6LN0,12943
 foscat/HealSpline.py,sha256=YRotJ1NQuXYFyFiM8fp6qkATIwRJ8lqIVo4vGXpHO-w,7472
-foscat/Plot.py,sha256=
+foscat/Plot.py,sha256=bpohWGsblTBxMrqE_X-iRvuvT-YyHDcgfWB4iYk5l10,49218
 foscat/Softmax.py,sha256=UDZGrTroYtmGEyokGUVpwNO_cgbICi9QVuRr8Yx52_k,2917
 foscat/SphericalStencil.py,sha256=DFipxQWupYPBa_62CkuMe9K7HkLFiHteF7Q6lAs3TLQ,56714
 foscat/Spline1D.py,sha256=rKzzenduaZZ-yBDJd35it6Gyrj1spqb7hoIaUgISPzY,2983
@@ -20,9 +20,13 @@ foscat/alm.py,sha256=XkK4rFVRoO-oJpr74iBffKt7hdS_iJkR016IlYm10gQ,33832
 foscat/backend.py,sha256=l3aMwDyXP6jURMIvratFMGWCTcQpaR68KnUuuGDezqE,45418
 foscat/backend_tens.py,sha256=9Dp136m9frkclkwifJQLLbIpl3ETI3_txdPUZcKfuMw,1618
 foscat/heal_NN.py,sha256=krEHM9NMZ74T9HUf-qK5td0tFttBA5SbaRgzThM2GYs,16943
-foscat/healpix_unet_torch.py,sha256=
+foscat/healpix_unet_torch.py,sha256=CDdrWyPJqF_gUT6rwB-TnghsbtkJ89WFw-K14m37DDQ,52221
+foscat/healpix_vit_skip.py,sha256=26qpYoX7W1vCJujqtYUiRPUmwrDf_UJSN5kbL7DVV8I,20359
+foscat/healpix_vit_torch-old.py,sha256=_PJecWRIWJc2FTQB_rthEeqLKYJ_UIdTN8ib_JuQ_xw,28985
+foscat/healpix_vit_torch.py,sha256=XOqEOazob6WRptSD5dh5acWM_1_uErKXaVWPm08X-O0,22198
 foscat/loss_backend_tens.py,sha256=dCOVN6faDtIpN3VO78HTmYP2i5fnFAf-Ddy5qVBlGrM,1783
 foscat/loss_backend_torch.py,sha256=k3z18Dj3SaLKK6ZIKcm7GO4U_YKYVP6LtHG1aIbxkYk,1627
+foscat/planar_vit.py,sha256=lQqwyz_P8G-Dav2vLqgkssDfeSe15YmjFzP5W-otjs0,6888
 foscat/scat.py,sha256=qGYiBIysPt65MdmF07WWA4piVlTfA9-lFDTaicnqC2w,72822
 foscat/scat1D.py,sha256=W5Uu6wdQ4ZsFKXpof0f1OBl-1wjJmW7ruvddRWxe7uM,53726
 foscat/scat2D.py,sha256=boKj0ASqMMSy7uQLK6hPniG87m3hZGJBYBiq5v8F9IQ,532
@@ -31,8 +35,9 @@ foscat/scat_cov1D.py,sha256=XOxsZZ5TYq8f34i2tUgIfzyaqaTDlICB3HzD2l_puro,531
 foscat/scat_cov2D.py,sha256=pAm0fKw8wyXram0TFbtw8tGcc8QPKuPXpQk0kh10r4U,7078
 foscat/scat_cov_map.py,sha256=9MzpwT2g9S3dmnjHEMK7PPLQ27oGQg2VFVsP_TDUU5E,2869
 foscat/scat_cov_map2D.py,sha256=zaIIYshXCqAeZ04I158GhD-Op4aoMlLnLEy7rxckVYY,2842
-foscat
-foscat-2025.
-foscat-2025.
-foscat-2025.
-foscat-2025.
+foscat/unet_2_d_from_healpix_params.py,sha256=r8hN-s091f3yHYlvAAiBbLOvtsz9vPrdwrWPM0ULR2Q,15949
+foscat-2025.10.2.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
+foscat-2025.10.2.dist-info/METADATA,sha256=WTeul0o0mER67oWvadXX0JV9hn1R3yAesPHOgPn8dys,7216
+foscat-2025.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+foscat-2025.10.2.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
+foscat-2025.10.2.dist-info/RECORD,,

{foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/WHEEL
File without changes

{foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/licenses/LICENSE
File without changes

{foscat-2025.9.5.dist-info → foscat-2025.10.2.dist-info}/top_level.txt
File without changes