foscat 2025.9.4-py3-none-any.whl → 2025.10.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/planar_vit.py ADDED
@@ -0,0 +1,206 @@
+ # planar_vit.py
+ # (Planar Vision Transformer baseline for lat–lon grids)
+ from __future__ import annotations
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # ---------------------------
+ # Building blocks
+ # ---------------------------
+
+ class _MLP(nn.Module):
+     """ViT MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""
+     def __init__(self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.1):
+         super().__init__()
+         hidden = int(dim * mlp_ratio)
+         self.fc1 = nn.Linear(dim, hidden)
+         self.fc2 = nn.Linear(hidden, dim)
+         self.act = nn.GELU()
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.drop(self.act(self.fc1(x)))
+         x = self.drop(self.fc2(x))
+         return x
+
+
+ class _ViTBlock(nn.Module):
+     """
+     Transformer block (Pre-LN):
+         x = x + Drop(MHA(LN(x)))
+         x = x + Drop(MLP(LN(x)))
+     """
+     def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, drop: float = 0.1):
+         super().__init__()
+         assert dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
+         self.norm2 = nn.LayerNorm(dim)
+         self.mlp = _MLP(dim, mlp_ratio, drop)
+         self.drop_path = nn.Dropout(drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Multi-head self-attention
+         x = x + self.drop_path(self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0])
+         # Feed-forward
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ # ---------------------------
+ # Planar ViT (lat–lon images)
+ # ---------------------------
+
+ class PlanarViT(nn.Module):
+     """
+     Vision Transformer for 2D lat–lon grids (planar baseline).
+
+     Input : (B, C=T_in, H, W)
+     Output: (B, out_ch, H, W)   # dense per-pixel prediction
+
+     Pipeline
+     --------
+     1) Patch embedding via Conv2d(kernel_size=patch, stride=patch) -> embed_dim
+     2) Optional CLS token (disabled by default for dense output)
+     3) Learned positional embeddings (or none)
+     4) Stack of Transformer blocks
+     5) Linear head per token, then nearest upsample back to (H, W)
+
+     Notes
+     -----
+     - Keep H, W divisible by `patch`.
+     - For residual-of-persistence training (recommended for monthly SST):
+           pred = x[:, -1:, ...] + model(x)
+       and train the loss on `pred` vs target.
+     """
+     def __init__(
+         self,
+         in_ch: int,                  # e.g., T_in months
+         H: int,
+         W: int,
+         *,
+         embed_dim: int = 384,
+         depth: int = 8,
+         num_heads: int = 12,
+         mlp_ratio: float = 4.0,
+         patch: int = 4,
+         out_ch: int = 1,
+         dropout: float = 0.1,
+         cls_token: bool = False,     # keep False for dense prediction
+         pos_embed: str = "learned",  # or "none"
+     ):
+         super().__init__()
+         assert H % patch == 0 and W % patch == 0, "H and W must be divisible by patch"
+         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+
+         self.H, self.W = H, W
+         self.patch = patch
+         self.embed_dim = embed_dim
+         self.cls_token_enabled = bool(cls_token)
+         self.use_pos_embed = (pos_embed == "learned")
+
+         # 1) Patch embedding (Conv2d with stride=patch) → tokens
+         self.patch_embed = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)
+
+         # 2) Token bookkeeping & positional embeddings
+         Hp, Wp = H // patch, W // patch
+         self.num_tokens = Hp * Wp + (1 if self.cls_token_enabled else 0)
+
+         if self.cls_token_enabled:
+             self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+             nn.init.trunc_normal_(self.cls_token, std=0.02)
+         else:
+             self.cls_token = None
+
+         if self.use_pos_embed:
+             self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
+             nn.init.trunc_normal_(self.pos_embed, std=0.02)
+         else:
+             self.pos_embed = None
+
+         # 3) Transformer encoder
+         self.blocks = nn.ModuleList([
+             _ViTBlock(embed_dim, num_heads, mlp_ratio=mlp_ratio, drop=dropout)
+             for _ in range(depth)
+         ])
+
+         # 4) Patch-wise head (token -> out_ch)
+         self.head = nn.Linear(embed_dim, out_ch)
+
+         # Store for unpatching
+         self.Hp, self.Wp = Hp, Wp
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         x: (B, C, H, W) with H,W fixed to construction-time H,W
+         returns: (B, out_ch, H, W)
+         """
+         B, C, H, W = x.shape
+         if (H != self.H) or (W != self.W):
+             raise ValueError(f"Input H,W must be ({self.H},{self.W}), got ({H},{W}).")
+
+         # Patch embedding → (B, E, Hp, Wp) → (B, Np, E)
+         z = self.patch_embed(x)                     # (B, E, Hp, Wp)
+         z = z.flatten(2).transpose(1, 2)            # (B, Np, E)
+
+         # Optional CLS
+         if self.cls_token_enabled:
+             cls = self.cls_token.expand(B, -1, -1)  # (B,1,E)
+             z = torch.cat([cls, z], dim=1)          # (B,1+Np,E)
+
+         # Positional embedding
+         if self.pos_embed is not None:
+             z = z + self.pos_embed[:, :z.shape[1], :]
+
+         # Transformer
+         for blk in self.blocks:
+             z = blk(z)                              # (B, N, E)
+
+         # Drop CLS for dense output
+         if self.cls_token_enabled:
+             tokens = z[:, 1:, :]                    # (B, Np, E)
+         else:
+             tokens = z
+
+         # Token head → (B, Np, out_ch) → (B, out_ch, Hp, Wp) → upsample to (H, W)
+         y_tok = self.head(tokens).transpose(1, 2)   # (B, out_ch, Np)
+         y = y_tok.reshape(B, -1, self.Hp, self.Wp)  # (B, out_ch, Hp, Wp)
+         y = F.interpolate(y, scale_factor=self.patch, mode="nearest")
+         return y
+
+
+ # ---------------------------
+ # Utilities
+ # ---------------------------
+
+ def count_parameters(model: nn.Module) -> tuple[int, int]:
+     """Return (total_params, trainable_params)."""
+     total = sum(p.numel() for p in model.parameters())
+     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     return total, trainable
+
+
+ # ---------------------------
+ # Smoke test
+ # ---------------------------
+
+ if __name__ == "__main__":
+     # Example: T_in=6, grid 128x256, predict 1 channel
+     B, C, H, W = 2, 6, 128, 256
+     x = torch.randn(B, C, H, W)
+
+     model = PlanarViT(
+         in_ch=C, H=H, W=W,
+         embed_dim=384, depth=8, num_heads=12,
+         mlp_ratio=4.0, patch=4, out_ch=1, dropout=0.1,
+         cls_token=False, pos_embed="learned"
+     )
+     y = model(x)
+     tot, trn = count_parameters(model)
+     print("Output:", tuple(y.shape))
+     print("Params:", f"total={tot:,}", f"trainable={trn:,}")
+
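The docstring above recommends residual-of-persistence training, where the network predicts a correction to the most recent input month rather than the full field. A minimal sketch of that training step is shown below; the optimizer, the MSE loss, and the random `x`/`target` tensors are illustrative assumptions, not part of the package.

```python
import torch
import torch.nn.functional as F

from foscat.planar_vit import PlanarViT, count_parameters

# Illustrative data: 6 past months in, 1 month out, on a 128x256 lat-lon grid
B, T_in, H, W = 2, 6, 128, 256
x = torch.randn(B, T_in, H, W)        # stacked input months
target = torch.randn(B, 1, H, W)      # month to predict (placeholder)

model = PlanarViT(in_ch=T_in, H=H, W=W, patch=4, out_ch=1)
print(count_parameters(model))        # (total, trainable)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Residual-of-persistence: add the model output to the last input month
pred = x[:, -1:, ...] + model(x)      # (B, 1, H, W)
loss = F.mse_loss(pred, target)
loss.backward()
optimizer.step()
```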
@@ -0,0 +1,421 @@
+ from __future__ import annotations
+ from typing import List, Optional, Literal, Tuple
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from contextlib import nullcontext
+
+ class PlanarUNet(nn.Module):
+     """
+     U-Net 2D (images HxW) mirroring the parameterization of the HealpixUNet.
+
+     Key compat points with HealpixUNet:
+       - Same constructor fields: in_nside, n_chan_in, chanlist, KERNELSZ, task,
+         out_channels, final_activation, device, down_type, dtype, head_reduce.
+       - Two convs per level (encoder & decoder), GroupNorm + ReLU after each conv.
+       - Downsampling by factor 2 at each level; upsampling mirrors back.
+       - Head produces `out_channels` with optional BN and final activation.
+
+     Differences vs sphere version:
+       - Operates on regular 2D images of size (3*in_nside, 4*in_nside).
+       - Standard Conv2d instead of custom spherical stencil.
+       - No gauges (G=1 implicit) and no cell_ids.
+
+     Shapes
+     ------
+     Input  : (B, C_in, 3*in_nside, 4*in_nside)
+     Output : (B, C_out, 3*in_nside, 4*in_nside)
+
+     Constraints
+     -----------
+     `in_nside` must be divisible by 2**depth, where depth == len(chanlist).
+     """
+
+     def __init__(
+         self,
+         *,
+         in_nside: int,
+         n_chan_in: int,
+         chanlist: List[int],
+         KERNELSZ: int = 3,
+         task: Literal['regression', 'segmentation'] = 'regression',
+         out_channels: int = 1,
+         final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
+         device: Optional[torch.device | str] = None,
+         down_type: Optional[Literal['mean','max']] = 'max',
+         dtype: Literal['float32','float64'] = 'float32',
+         head_reduce: Literal['mean','learned'] = 'mean',  # kept for API symmetry
+     ) -> None:
+         super().__init__()
+
+         if len(chanlist) == 0:
+             raise ValueError("chanlist must be non-empty (depth >= 1)")
+         self.in_nside = int(in_nside)
+         self.n_chan_in = int(n_chan_in)
+         self.chanlist = list(map(int, chanlist))
+         self.KERNELSZ = int(KERNELSZ)
+         self.task = task
+         self.out_channels = int(out_channels)
+         self.down_type = down_type
+         self.dtype = torch.float32 if dtype == 'float32' else torch.float64
+         self.head_reduce = head_reduce
+
+         # default final activation consistent with HealpixUNet
+         if final_activation is None:
+             if task == 'regression':
+                 self.final_activation = 'none'
+             else:
+                 self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
+         else:
+             self.final_activation = final_activation
+
+         # Resolve device
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = torch.device(device)
+
+         depth = len(self.chanlist)
+         # geometry
+         H0, W0 = 3 * self.in_nside, 4 * self.in_nside
+         # ensure divisibility by 2**depth
+         if (self.in_nside % (2 ** depth)) != 0:
+             raise ValueError(
+                 f"in_nside={self.in_nside} must be divisible by 2**depth where depth={depth}"
+             )
+
+         padding = self.KERNELSZ // 2
+
+         # --- Encoder ---
+         enc_layers = []
+         inC = self.n_chan_in
+         self.skips_channels: List[int] = []
+         for outC in self.chanlist:
+             block = nn.Sequential(
+                 nn.Conv2d(inC, outC, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                 _norm_2d(outC, kind="group"),
+                 nn.ReLU(inplace=True),
+                 nn.Conv2d(outC, outC, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                 _norm_2d(outC, kind="group"),
+                 nn.ReLU(inplace=True),
+             )
+             enc_layers.append(block)
+             inC = outC
+             self.skips_channels.append(outC)
+         self.encoder = nn.ModuleList(enc_layers)
+
+         # Pools
+         if self.down_type == 'max':
+             self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+         else:
+             self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+         # --- Decoder ---
+         dec_layers = []
+         upconvs = []
+         for l in reversed(range(depth)):
+             skipC = self.skips_channels[l]
+             upC = self.skips_channels[l + 1] if (l + 1) < depth else self.skips_channels[l]
+             inC_dec = upC + skipC
+             outC_dec = skipC
+
+             upconvs.append(
+                 nn.ConvTranspose2d(upC, upC, kernel_size=2, stride=2)
+             )
+             dec_layers.append(
+                 nn.Sequential(
+                     nn.Conv2d(inC_dec, outC_dec, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                     _norm_2d(outC_dec, kind="group"),
+                     nn.ReLU(inplace=True),
+                     nn.Conv2d(outC_dec, outC_dec, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                     _norm_2d(outC_dec, kind="group"),
+                     nn.ReLU(inplace=True),
+                 )
+             )
+         self.upconvs = nn.ModuleList(upconvs)
+         self.decoder = nn.ModuleList(dec_layers)
+
+         # --- Head ---
+         head_inC = self.chanlist[0]
+         self.head_conv = nn.Conv2d(head_inC, self.out_channels, kernel_size=self.KERNELSZ, padding=padding)
+         self.head_bn = _norm_2d(self.out_channels, kind="group") if self.task == 'segmentation' else None
+
+         # optional learned mixer kept for API compatibility (no gauges here)
+         self.head_mixer = nn.Identity()
+
+         self.to(self.device, dtype=self.dtype)
+
+     def to_tensor(self, x):
+         return torch.tensor(x, device=self.device)
+
+     def to_numpy(self, x):
+         if isinstance(x, np.ndarray):
+             return x
+         return x.cpu().numpy()
+
+     # -------------------------- forward --------------------------
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """x: (B, C_in, H, W) with H=3*in_nside, W=4*in_nside"""
+         if x.dim() != 4:
+             raise ValueError("Input must be (B, C, H, W)")
+         if x.shape[1] != self.n_chan_in:
+             raise ValueError(f"Expected {self.n_chan_in} input channels, got {x.shape[1]}")
+
+         x = x.to(self.device, dtype=self.dtype)
+
+         skips = []
+         z = x
+         for l, block in enumerate(self.encoder):
+             z = block(z)
+             skips.append(z)
+             if l < len(self.encoder) - 1:
+                 z = self.pool(z)
+
+         # Decoder
+         for d, l in enumerate(reversed(range(len(self.chanlist)))):
+             if l < len(self.chanlist) - 1:
+                 z = self.upconvs[d](z)
+                 # pad if odd due to pooling/upsampling asymmetry (shouldn't happen given divisibility)
+                 sh = skips[l].shape
+                 if z.shape[-2:] != sh[-2:]:
+                     z = _pad_to_match(z, sh[-2], sh[-1])
+             z = torch.cat([skips[l], z], dim=1)
+             z = self.decoder[d](z)
+
+         y = self.head_conv(z)
+         if self.task == 'segmentation' and self.head_bn is not None:
+             y = self.head_bn(y)
+
+         if self.final_activation == 'sigmoid':
+             y = torch.sigmoid(y)
+         elif self.final_activation == 'softmax':
+             y = torch.softmax(y, dim=1)
+         return y
+
+     @torch.no_grad()
+     def predict(
+         self,
+         x: torch.Tensor,
+         batch_size: int = 8,
+         *,
+         amp: bool = False,
+         out_device: Optional[str] = 'cpu',
+         out_dtype: Literal['float32','float16'] = 'float32',
+         show_pbar: bool = False,
+     ) -> torch.Tensor:
+         """Memory-safe prediction.
+         - Streams mini-batches with torch.inference_mode() + optional AMP.
+         - Moves each output batch to `out_device` (CPU by default) to free VRAM.
+         - Checks shapes and raises explicit errors.
+         """
+         self.eval()
+
+         # --- input checks & normalization ---
+         x = x if torch.is_tensor(x) else torch.as_tensor(x)
+         if x.ndim != 4:
+             raise ValueError(f"predict expects (N,C,H,W), got {tuple(getattr(x,'shape',()))}")
+         if x.shape[1] != self.n_chan_in:
+             raise ValueError(f"predict expected {self.n_chan_in} channels, got {x.shape[1]}")
+         n = int(x.shape[0])
+         if n == 0:
+             H, W = int(x.shape[-2]), int(x.shape[-1])
+             return torch.empty((0, self.out_channels, H, W), device=out_device or self.device)
+
+         # --- setup ---
+         dtype_map = {'float32': torch.float32, 'float16': torch.float16}
+         out_dtype_t = dtype_map[out_dtype]
+         use_cuda = (self.device.type == 'cuda')
+         if use_cuda:
+             torch.backends.cudnn.benchmark = True
+
+         from math import ceil
+         nb = ceil(n / batch_size)
+         rng = range(0, n, batch_size)
+         if show_pbar:
+             try:
+                 from tqdm import tqdm  # type: ignore
+                 rng = tqdm(rng, total=nb, desc='predict')
+             except Exception:
+                 pass
+
+         # --- batch-by-batch inference ---
+         out_list: List[torch.Tensor] = []
+         with torch.inference_mode():
+             ctx = (torch.cuda.amp.autocast() if (amp and use_cuda) else nullcontext())
+             for i in rng:
+                 xb = x[i:i+batch_size].to(self.device, dtype=self.dtype, non_blocking=True)
+                 with ctx:
+                     yb = self.forward(xb)
+                 # Move the output to the requested device (CPU by default)
+                 yb = yb.to(out_device, dtype=out_dtype_t) if out_device is not None else yb.to(dtype=out_dtype_t)
+                 out_list.append(yb)
+                 del xb, yb
+                 if use_cuda:
+                     torch.cuda.empty_cache()
+
+         if not out_list:
+             raise RuntimeError(f"predict produced no outputs; check input shape {tuple(x.shape)} and batch_size={batch_size}")
+         return torch.cat(out_list, dim=0)
+
+
+ # -----------------------------
+ # Helpers
+ # -----------------------------
+
+ def _norm_2d(C: int, kind: str = "group", **kwargs) -> nn.Module:
+     if kind == "group":
+         num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1)
+         while C % num_groups != 0 and num_groups > 1:
+             num_groups //= 2
+         return nn.GroupNorm(num_groups=num_groups, num_channels=C)
+     elif kind == "instance":
+         return nn.InstanceNorm2d(C, affine=True, track_running_stats=False)
+     elif kind == "batch":
+         return nn.BatchNorm2d(C)
+     else:
+         raise ValueError(f"Unknown norm kind: {kind}")
+
+
+ def _pad_to_match(x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+     """Pad x (B,C,h,w) with zeros on right/bottom to reach (H,W)."""
+     _, _, h, w = x.shape
+     ph = max(0, H - h)
+     pw = max(0, W - w)
+     if ph == 0 and pw == 0:
+         return x
+     return F.pad(x, (0, pw, 0, ph), mode='constant', value=0)
+
+
+ # -----------------------------
+ # Training utilities (mirror of Healpix fit)
+ # -----------------------------
+ from typing import Union
+ import numpy as np
+ from torch.utils.data import DataLoader, TensorDataset
+
+ def fit(
+     model: nn.Module,
+     x_train: Union[torch.Tensor, np.ndarray],
+     y_train: Union[torch.Tensor, np.ndarray],
+     *,
+     n_epoch: int = 10,
+     view_epoch: int = 10,
+     batch_size: int = 16,
+     lr: float = 1e-3,
+     weight_decay: float = 0.0,
+     clip_grad_norm: Optional[float] = None,
+     verbose: bool = True,
+     optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
+ ) -> dict:
+     """Training loop mirroring `healpix_unet_torch.fit`, adapted to 2D images.
+
+     - Fixed-size inputs: tensors/ndarrays of shape (B, C, H, W) with H=3*nside, W=4*nside
+     - Loss: MSE (regression) / BCE (BCEWithLogitsLoss if final_activation='none') / CrossEntropy (multi-class)
+     - Optimizer: ADAM or LBFGS with closure
+     - Logs: returns {"loss": history}
+     """
+     device = next(model.parameters()).device
+     model.to(device)
+
+     # ---- DataLoader
+     x_t = torch.as_tensor(x_train, dtype=torch.float32, device=device)
+     y_is_class = (getattr(model, 'task', 'regression') != 'regression' and getattr(model, 'out_channels', 1) > 1)
+     y_dtype = torch.long if y_is_class and (not torch.is_tensor(y_train) or y_train.ndim == x_t.ndim - 1) else torch.float32
+     y_t = torch.as_tensor(y_train, dtype=y_dtype, device=device)
+
+     ds = TensorDataset(x_t, y_t)
+     loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+
+     # ---- Loss
+     if getattr(model, 'task', 'regression') == 'regression':
+         criterion = nn.MSELoss(reduction='mean')
+         seg_multiclass = False
+     else:
+         if getattr(model, 'out_channels', 1) == 1:
+             criterion = nn.BCEWithLogitsLoss() if getattr(model, 'final_activation', 'none') == 'none' else nn.BCELoss()
+             seg_multiclass = False
+         else:
+             criterion = nn.CrossEntropyLoss()
+             seg_multiclass = True
+
+     # ---- Optim
+     if optimizer.upper() == 'ADAM':
+         optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+         outer, inner = n_epoch, 1
+     elif optimizer.upper() == 'LBFGS':
+         optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20, history_size=50, line_search_fn='strong_wolfe')
+         outer, inner = max(1, n_epoch // 20), 20
+     else:
+         raise ValueError("optimizer must be 'ADAM' or 'LBFGS'")
+
+     # ---- Train
+     history: List[float] = []
+     model.train()
+
+     for epoch in range(outer):
+         for _ in range(inner):
+             epoch_loss, n_samples = 0.0, 0
+             for xb, yb in loader:
+                 xb = xb.to(device, dtype=torch.float32, non_blocking=True)
+                 yb = yb.to(device, non_blocking=True)
+
+                 if isinstance(optim, torch.optim.LBFGS):
+                     def closure():
+                         optim.zero_grad(set_to_none=True)
+                         preds = model(xb)
+                         loss = criterion(preds, yb)
+                         loss.backward()
+                         return loss
+                     loss_val = float(optim.step(closure).item())
+                 else:
+                     optim.zero_grad(set_to_none=True)
+                     preds = model(xb)
+                     loss = criterion(preds, yb)
+                     loss.backward()
+                     if clip_grad_norm is not None:
+                         torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+                     optim.step()
+                     loss_val = float(loss.item())
+
+                 epoch_loss += loss_val * xb.shape[0]
+                 n_samples += xb.shape[0]
+
+             epoch_loss /= max(1, n_samples)
+             history.append(epoch_loss)
+             if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
+                 print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
+
+     return {"loss": history}
+
+
+ # -----------------------------
+ # Minimal smoke test
+ # -----------------------------
+ if __name__ == "__main__":
+     torch.manual_seed(0)
+     nside = 32
+     chanlist = [16, 32, 64]
+     net = PlanarUNet(
+         in_nside=nside,
+         n_chan_in=3,
+         chanlist=chanlist,
+         KERNELSZ=3,
+         task='regression',
+         out_channels=1,
+     )
+     x = torch.randn(2, 3, 3*nside, 4*nside)
+     y = net(x)
+     print('input:', x.shape, 'output:', y.shape)
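PlanarUNet consumes (B, C_in, 3*in_nside, 4*in_nside) images and ships with a module-level `fit()` helper plus a batched `predict()`. A minimal regression sketch follows; this hunk does not show the module's path inside the wheel, so the import line below is a placeholder assumption (the class and `fit()` can equally be copied into a local file), and the random tensors are purely illustrative.

```python
import torch

# Placeholder import: adjust to wherever this module is installed,
# or paste PlanarUNet and fit() from the diff above into planar_unet.py.
from planar_unet import PlanarUNet, fit

nside = 32                              # must be divisible by 2**len(chanlist)
net = PlanarUNet(in_nside=nside, n_chan_in=3, chanlist=[16, 32, 64],
                 task='regression', out_channels=1, device='cpu')

# Toy regression data on the (3*nside, 4*nside) grid
x = torch.randn(8, 3, 3 * nside, 4 * nside)
y = torch.randn(8, 1, 3 * nside, 4 * nside)

history = fit(net, x, y, n_epoch=2, batch_size=4, lr=1e-3)
print(history["loss"])                  # per-epoch mean MSE

# Memory-safe batched inference; outputs are moved to CPU by default
pred = net.predict(x, batch_size=4)
print(pred.shape)                       # expected: torch.Size([8, 1, 96, 128])
```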
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: foscat
- Version: 2025.9.4
+ Version: 2025.10.2
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -4,12 +4,12 @@ foscat/BkTensorflow.py,sha256=iIdLx6VTOfOEocfZBOGyizQn5geDLTfdWWAwDeQr9YA,20056
  foscat/BkTorch.py,sha256=W3n3XkFw9oecmyfSWt2JbP5L5eKWn8AAu0RZx1yLQb0,31975
  foscat/CNN.py,sha256=4vky7jqTshL1aYLWsc-hQwf7gDjTVjL7I6HZiAsa6x4,5158
  foscat/CircSpline.py,sha256=CXi49FxF8ZoeZ17Ua8c1AZXe2B5ICEC9aCXb97atB3s,4028
- foscat/FoCUS.py,sha256=DFPzRMRHE7eUfXfqoqHiM0f4mAiOs_6GmUieipXplo4,105645
+ foscat/FoCUS.py,sha256=9GHjHUBM8sfmYwZw7u-xGeOOjFCpbbvM-5m3yjPSt5g,105752
  foscat/GCNN.py,sha256=q7yWHCMJpP7-m3WvR3OQnp5taeYWaMxIY2hQ6SIb9gs,4487
  foscat/HOrientedConvol.py,sha256=xMaS-zzoUyXisBCPsHBVpn54tuA9Qv3na-tT86Cwn7U,38744
  foscat/HealBili.py,sha256=YRPk9PO5G8NdwKeb33xiJs3_pMPAgIv5phCs8GT6LN0,12943
  foscat/HealSpline.py,sha256=YRotJ1NQuXYFyFiM8fp6qkATIwRJ8lqIVo4vGXpHO-w,7472
- foscat/Plot.py,sha256=a_NKE91F82qBeLNL3gy0qE06GCL-JbHgeIQ2iwIRPjY,47326
+ foscat/Plot.py,sha256=bpohWGsblTBxMrqE_X-iRvuvT-YyHDcgfWB4iYk5l10,49218
  foscat/Softmax.py,sha256=UDZGrTroYtmGEyokGUVpwNO_cgbICi9QVuRr8Yx52_k,2917
  foscat/SphericalStencil.py,sha256=DFipxQWupYPBa_62CkuMe9K7HkLFiHteF7Q6lAs3TLQ,56714
  foscat/Spline1D.py,sha256=rKzzenduaZZ-yBDJd35it6Gyrj1spqb7hoIaUgISPzY,2983
@@ -20,9 +20,13 @@ foscat/alm.py,sha256=XkK4rFVRoO-oJpr74iBffKt7hdS_iJkR016IlYm10gQ,33832
  foscat/backend.py,sha256=l3aMwDyXP6jURMIvratFMGWCTcQpaR68KnUuuGDezqE,45418
  foscat/backend_tens.py,sha256=9Dp136m9frkclkwifJQLLbIpl3ETI3_txdPUZcKfuMw,1618
  foscat/heal_NN.py,sha256=krEHM9NMZ74T9HUf-qK5td0tFttBA5SbaRgzThM2GYs,16943
- foscat/healpix_unet_torch.py,sha256=LXu5pDptDQN-mn3Fv7-I0Q_g_i4gPoEd5b9BmUTTSdU,51279
+ foscat/healpix_unet_torch.py,sha256=CDdrWyPJqF_gUT6rwB-TnghsbtkJ89WFw-K14m37DDQ,52221
+ foscat/healpix_vit_skip.py,sha256=26qpYoX7W1vCJujqtYUiRPUmwrDf_UJSN5kbL7DVV8I,20359
+ foscat/healpix_vit_torch-old.py,sha256=_PJecWRIWJc2FTQB_rthEeqLKYJ_UIdTN8ib_JuQ_xw,28985
+ foscat/healpix_vit_torch.py,sha256=XOqEOazob6WRptSD5dh5acWM_1_uErKXaVWPm08X-O0,22198
  foscat/loss_backend_tens.py,sha256=dCOVN6faDtIpN3VO78HTmYP2i5fnFAf-Ddy5qVBlGrM,1783
  foscat/loss_backend_torch.py,sha256=k3z18Dj3SaLKK6ZIKcm7GO4U_YKYVP6LtHG1aIbxkYk,1627
+ foscat/planar_vit.py,sha256=lQqwyz_P8G-Dav2vLqgkssDfeSe15YmjFzP5W-otjs0,6888
  foscat/scat.py,sha256=qGYiBIysPt65MdmF07WWA4piVlTfA9-lFDTaicnqC2w,72822
  foscat/scat1D.py,sha256=W5Uu6wdQ4ZsFKXpof0f1OBl-1wjJmW7ruvddRWxe7uM,53726
  foscat/scat2D.py,sha256=boKj0ASqMMSy7uQLK6hPniG87m3hZGJBYBiq5v8F9IQ,532
@@ -31,8 +35,9 @@ foscat/scat_cov1D.py,sha256=XOxsZZ5TYq8f34i2tUgIfzyaqaTDlICB3HzD2l_puro,531
  foscat/scat_cov2D.py,sha256=pAm0fKw8wyXram0TFbtw8tGcc8QPKuPXpQk0kh10r4U,7078
  foscat/scat_cov_map.py,sha256=9MzpwT2g9S3dmnjHEMK7PPLQ27oGQg2VFVsP_TDUU5E,2869
  foscat/scat_cov_map2D.py,sha256=zaIIYshXCqAeZ04I158GhD-Op4aoMlLnLEy7rxckVYY,2842
- foscat-2025.9.4.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
- foscat-2025.9.4.dist-info/METADATA,sha256=MznoPNHG--ttgyql7lKM2YZrrYz3AiHLag3yn2C04iM,7215
- foscat-2025.9.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- foscat-2025.9.4.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
- foscat-2025.9.4.dist-info/RECORD,,
+ foscat/unet_2_d_from_healpix_params.py,sha256=r8hN-s091f3yHYlvAAiBbLOvtsz9vPrdwrWPM0ULR2Q,15949
+ foscat-2025.10.2.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
+ foscat-2025.10.2.dist-info/METADATA,sha256=WTeul0o0mER67oWvadXX0JV9hn1R3yAesPHOgPn8dys,7216
+ foscat-2025.10.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ foscat-2025.10.2.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
+ foscat-2025.10.2.dist-info/RECORD,,