foscat 2025.9.5-py3-none-any.whl → 2025.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/planar_vit.py ADDED
@@ -0,0 +1,206 @@
+ # planar_vit.py
+ # (Planar Vision Transformer baseline for lat–lon grids)
+ from __future__ import annotations
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # ---------------------------
+ # Building blocks
+ # ---------------------------
+
+ class _MLP(nn.Module):
+     """ViT MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""
+     def __init__(self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.1):
+         super().__init__()
+         hidden = int(dim * mlp_ratio)
+         self.fc1 = nn.Linear(dim, hidden)
+         self.fc2 = nn.Linear(hidden, dim)
+         self.act = nn.GELU()
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.drop(self.act(self.fc1(x)))
+         x = self.drop(self.fc2(x))
+         return x
+
+
+ class _ViTBlock(nn.Module):
+     """
+     Transformer block (Pre-LN):
+         x = x + Drop(MHA(LN(x)))
+         x = x + Drop(MLP(LN(x)))
+     """
+     def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0, drop: float = 0.1):
+         super().__init__()
+         assert dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
+         self.norm2 = nn.LayerNorm(dim)
+         self.mlp = _MLP(dim, mlp_ratio, drop)
+         self.drop_path = nn.Dropout(drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Multi-head self-attention
+         x = x + self.drop_path(self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0])
+         # Feed-forward
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ # ---------------------------
+ # Planar ViT (lat–lon images)
+ # ---------------------------
+
+ class PlanarViT(nn.Module):
+     """
+     Vision Transformer for 2D lat–lon grids (planar baseline).
+
+     Input : (B, C=T_in, H, W)
+     Output: (B, out_ch, H, W)   # dense per-pixel prediction
+
+     Pipeline
+     --------
+     1) Patch embedding via Conv2d(kernel_size=patch, stride=patch) -> embed_dim
+     2) Optional CLS token (disabled by default for dense output)
+     3) Learned positional embeddings (or none)
+     4) Stack of Transformer blocks
+     5) Linear head per token, then nearest upsample back to (H, W)
+
+     Notes
+     -----
+     - Keep H, W divisible by `patch`.
+     - For residual-of-persistence training (recommended for monthly SST):
+           pred = x[:, -1:, ...] + model(x)
+       and train the loss on `pred` vs target.
+     """
+     def __init__(
+         self,
+         in_ch: int,                  # e.g., T_in months
+         H: int,
+         W: int,
+         *,
+         embed_dim: int = 384,
+         depth: int = 8,
+         num_heads: int = 12,
+         mlp_ratio: float = 4.0,
+         patch: int = 4,
+         out_ch: int = 1,
+         dropout: float = 0.1,
+         cls_token: bool = False,     # keep False for dense prediction
+         pos_embed: str = "learned",  # or "none"
+     ):
+         super().__init__()
+         assert H % patch == 0 and W % patch == 0, "H and W must be divisible by patch"
+         assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+
+         self.H, self.W = H, W
+         self.patch = patch
+         self.embed_dim = embed_dim
+         self.cls_token_enabled = bool(cls_token)
+         self.use_pos_embed = (pos_embed == "learned")
+
+         # 1) Patch embedding (Conv2d with stride=patch) → tokens
+         self.patch_embed = nn.Conv2d(in_ch, embed_dim, kernel_size=patch, stride=patch)
+
+         # 2) Token bookkeeping & positional embeddings
+         Hp, Wp = H // patch, W // patch
+         self.num_tokens = Hp * Wp + (1 if self.cls_token_enabled else 0)
+
+         if self.cls_token_enabled:
+             self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+             nn.init.trunc_normal_(self.cls_token, std=0.02)
+         else:
+             self.cls_token = None
+
+         if self.use_pos_embed:
+             self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
+             nn.init.trunc_normal_(self.pos_embed, std=0.02)
+         else:
+             self.pos_embed = None
+
+         # 3) Transformer encoder
+         self.blocks = nn.ModuleList([
+             _ViTBlock(embed_dim, num_heads, mlp_ratio=mlp_ratio, drop=dropout)
+             for _ in range(depth)
+         ])
+
+         # 4) Patch-wise head (token -> out_ch)
+         self.head = nn.Linear(embed_dim, out_ch)
+
+         # Store for unpatching
+         self.Hp, self.Wp = Hp, Wp
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         x: (B, C, H, W) with H,W fixed to construction-time H,W
+         returns: (B, out_ch, H, W)
+         """
+         B, C, H, W = x.shape
+         if (H != self.H) or (W != self.W):
+             raise ValueError(f"Input H,W must be ({self.H},{self.W}), got ({H},{W}).")
+
+         # Patch embedding → (B, E, Hp, Wp) → (B, Np, E)
+         z = self.patch_embed(x)           # (B, E, Hp, Wp)
+         z = z.flatten(2).transpose(1, 2)  # (B, Np, E)
+
+         # Optional CLS
+         if self.cls_token_enabled:
+             cls = self.cls_token.expand(B, -1, -1)  # (B,1,E)
+             z = torch.cat([cls, z], dim=1)          # (B,1+Np,E)
+
+         # Positional embedding
+         if self.pos_embed is not None:
+             z = z + self.pos_embed[:, :z.shape[1], :]
+
+         # Transformer
+         for blk in self.blocks:
+             z = blk(z)  # (B, N, E)
+
+         # Drop CLS for dense output
+         if self.cls_token_enabled:
+             tokens = z[:, 1:, :]  # (B, Np, E)
+         else:
+             tokens = z
+
+         # Token head → (B, Np, out_ch) → (B, out_ch, Hp, Wp) → upsample to (H, W)
+         y_tok = self.head(tokens).transpose(1, 2)   # (B, out_ch, Np)
+         y = y_tok.reshape(B, -1, self.Hp, self.Wp)  # (B, out_ch, Hp, Wp)
+         y = F.interpolate(y, scale_factor=self.patch, mode="nearest")
+         return y
+
+
+ # ---------------------------
+ # Utilities
+ # ---------------------------
+
+ def count_parameters(model: nn.Module) -> tuple[int, int]:
+     """Return (total_params, trainable_params)."""
+     total = sum(p.numel() for p in model.parameters())
+     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     return total, trainable
+
+
+ # ---------------------------
+ # Smoke test
+ # ---------------------------
+
+ if __name__ == "__main__":
+     # Example: T_in=6, grid 128x256, predict 1 channel
+     B, C, H, W = 2, 6, 128, 256
+     x = torch.randn(B, C, H, W)
+
+     model = PlanarViT(
+         in_ch=C, H=H, W=W,
+         embed_dim=384, depth=8, num_heads=12,
+         mlp_ratio=4.0, patch=4, out_ch=1, dropout=0.1,
+         cls_token=False, pos_embed="learned"
+     )
+     y = model(x)
+     tot, trn = count_parameters(model)
+     print("Output:", tuple(y.shape))
+     print("Params:", f"total={tot:,}", f"trainable={trn:,}")
+
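A minimal usage sketch (not part of the released file) for the new PlanarViT baseline, following the residual-of-persistence recipe quoted in its docstring; the tensors x and target are illustrative random data and the Adam settings are assumptions:

    import torch
    import torch.nn.functional as F
    from foscat.planar_vit import PlanarViT, count_parameters

    model = PlanarViT(in_ch=6, H=128, W=256, patch=4, out_ch=1)   # 6 input months, 1 output field
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)     # assumed optimizer settings

    x = torch.randn(2, 6, 128, 256)        # (B, T_in, H, W) input stack
    target = torch.randn(2, 1, 128, 256)   # field at the next time step

    pred = x[:, -1:, ...] + model(x)       # persistence + learned residual, as in the docstring
    loss = F.mse_loss(pred, target)
    loss.backward()
    optimizer.step()
    print(count_parameters(model))         # (total_params, trainable_params)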
foscat/scat.py CHANGED
@@ -1659,7 +1659,7 @@ class funct(FOC.FoCUS):
  s2j2 = None
  l2_image = None
  for j1 in range(jmax):
-     if j1 < jmax - self.OSTEP: # stop to add scales
+     if j1 < jmax: # stop to add scales
          # Convol image along the axis defined by 'axis' using the wavelet defined at
          # the foscat initialisation
          # c_image_real is [....,Npix_j1,....,Norient]
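The effect of this one-line change (repeated in scat1D.py below) is that the scattering loop now processes every scale up to jmax instead of skipping the last self.OSTEP scales; a standalone sketch with illustrative values:

    # jmax and OSTEP are illustrative values only
    jmax, OSTEP = 8, 2
    old_scales = [j1 for j1 in range(jmax) if j1 < jmax - OSTEP]   # [0, 1, 2, 3, 4, 5]
    new_scales = [j1 for j1 in range(jmax) if j1 < jmax]           # [0, 1, 2, 3, 4, 5, 6, 7]
    print(old_scales, new_scales)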
foscat/scat1D.py CHANGED
@@ -1282,7 +1282,7 @@ class funct(FOC.FoCUS):
  l2_image = None

  for j1 in range(jmax):
-     if j1 < jmax - self.OSTEP: # stop to add scales
+     if j1 < jmax: # stop to add scales
          # Convol image along the axis defined by 'axis' using the wavelet defined at
          # the foscat initialisation
          # c_image_real is [....,Npix_j1,....,Norient]
foscat/scat_cov.py CHANGED
@@ -350,7 +350,7 @@ class scat_cov:
  self.backend.bk_flattenR(self.S3[0]),
  self.backend.bk_flattenR(self.S4[0]),
  ],
- 0,
+ 0,
  )
  else:
  tmp = self.backend.bk_concat(
@@ -2636,7 +2636,7 @@ class funct(FOC.FoCUS):
  if mask is not None:
      if self.use_2D:
          if (
-             image1.shape[-2] != mask.shape[-1]
+             image1.shape[-2] != mask.shape[-2]
              or image1.shape[-1] != mask.shape[-1]
          ):
              print(
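The second hunk makes the 2D mask check compare like axes: the image height was previously tested against the mask width, which falsely reports a mismatch on non-square grids; a small illustration with made-up shapes:

    import numpy as np

    image1 = np.zeros((10, 96, 128))   # (batch, H, W), illustrative shapes
    mask = np.zeros((1, 96, 128))
    # old check: image1.shape[-2] != mask.shape[-1]  ->  96 != 128, a spurious mismatch
    # new check: image1.shape[-2] != mask.shape[-2]  ->  96 == 96, shapes agree
    ok = not (image1.shape[-2] != mask.shape[-2] or image1.shape[-1] != mask.shape[-1])
    print(ok)   # True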
@@ -0,0 +1,421 @@
+ from __future__ import annotations
+ from typing import List, Optional, Literal, Tuple
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from contextlib import nullcontext
+
+ class PlanarUNet(nn.Module):
+     """
+     2D U-Net (H x W images) mirroring the parameterization of the HealpixUNet.
+
+     Key compat points with HealpixUNet:
+       - Same constructor fields: in_nside, n_chan_in, chanlist, KERNELSZ, task,
+         out_channels, final_activation, device, down_type, dtype, head_reduce.
+       - Two convs per level (encoder & decoder), GroupNorm + ReLU after each conv.
+       - Downsampling by factor 2 at each level; upsampling mirrors back.
+       - Head produces `out_channels` with optional BN and final activation.
+
+     Differences vs sphere version:
+       - Operates on regular 2D images of size (3*in_nside, 4*in_nside).
+       - Standard Conv2d instead of custom spherical stencil.
+       - No gauges (G=1 implicit) and no cell_ids.
+
+     Shapes
+     ------
+     Input  : (B, C_in, 3*in_nside, 4*in_nside)
+     Output : (B, C_out, 3*in_nside, 4*in_nside)
+
+     Constraints
+     -----------
+     `in_nside` must be divisible by 2**depth, where depth == len(chanlist).
+     """
+
+     def __init__(
+         self,
+         *,
+         in_nside: int,
+         n_chan_in: int,
+         chanlist: List[int],
+         KERNELSZ: int = 3,
+         task: Literal['regression', 'segmentation'] = 'regression',
+         out_channels: int = 1,
+         final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
+         device: Optional[torch.device | str] = None,
+         down_type: Optional[Literal['mean','max']] = 'max',
+         dtype: Literal['float32','float64'] = 'float32',
+         head_reduce: Literal['mean','learned'] = 'mean',  # kept for API symmetry
+     ) -> None:
+         super().__init__()
+
+         if len(chanlist) == 0:
+             raise ValueError("chanlist must be non-empty (depth >= 1)")
+         self.in_nside = int(in_nside)
+         self.n_chan_in = int(n_chan_in)
+         self.chanlist = list(map(int, chanlist))
+         self.KERNELSZ = int(KERNELSZ)
+         self.task = task
+         self.out_channels = int(out_channels)
+         self.down_type = down_type
+         self.dtype = torch.float32 if dtype == 'float32' else torch.float64
+         self.head_reduce = head_reduce
+
+         # default final activation consistent with HealpixUNet
+         if final_activation is None:
+             if task == 'regression':
+                 self.final_activation = 'none'
+             else:
+                 self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
+         else:
+             self.final_activation = final_activation
+
+         # Resolve device
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = torch.device(device)
+
+         depth = len(self.chanlist)
+         # geometry
+         H0, W0 = 3 * self.in_nside, 4 * self.in_nside
+         # ensure divisibility by 2**depth
+         if (self.in_nside % (2 ** depth)) != 0:
+             raise ValueError(
+                 f"in_nside={self.in_nside} must be divisible by 2**depth where depth={depth}"
+             )
+
+         padding = self.KERNELSZ // 2
+
+         # --- Encoder ---
+         enc_layers = []
+         inC = self.n_chan_in
+         self.skips_channels: List[int] = []
+         for outC in self.chanlist:
+             block = nn.Sequential(
+                 nn.Conv2d(inC, outC, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                 _norm_2d(outC, kind="group"),
+                 nn.ReLU(inplace=True),
+                 nn.Conv2d(outC, outC, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                 _norm_2d(outC, kind="group"),
+                 nn.ReLU(inplace=True),
+             )
+             enc_layers.append(block)
+             inC = outC
+             self.skips_channels.append(outC)
+         self.encoder = nn.ModuleList(enc_layers)
+
+         # Pools
+         if self.down_type == 'max':
+             self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+         else:
+             self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
+
+         # --- Decoder ---
+         dec_layers = []
+         upconvs = []
+         for l in reversed(range(depth)):
+             skipC = self.skips_channels[l]
+             upC = self.skips_channels[l + 1] if (l + 1) < depth else self.skips_channels[l]
+             inC_dec = upC + skipC
+             outC_dec = skipC
+
+             upconvs.append(
+                 nn.ConvTranspose2d(upC, upC, kernel_size=2, stride=2)
+             )
+             dec_layers.append(
+                 nn.Sequential(
+                     nn.Conv2d(inC_dec, outC_dec, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                     _norm_2d(outC_dec, kind="group"),
+                     nn.ReLU(inplace=True),
+                     nn.Conv2d(outC_dec, outC_dec, kernel_size=self.KERNELSZ, padding=padding, bias=False),
+                     _norm_2d(outC_dec, kind="group"),
+                     nn.ReLU(inplace=True),
+                 )
+             )
+         self.upconvs = nn.ModuleList(upconvs)
+         self.decoder = nn.ModuleList(dec_layers)
+
+         # --- Head ---
+         head_inC = self.chanlist[0]
+         self.head_conv = nn.Conv2d(head_inC, self.out_channels, kernel_size=self.KERNELSZ, padding=padding)
+         self.head_bn = _norm_2d(self.out_channels, kind="group") if self.task == 'segmentation' else None
+
+         # optional learned mixer kept for API compatibility (no gauges here)
+         self.head_mixer = nn.Identity()
+
+         self.to(self.device, dtype=self.dtype)
+
+     def to_tensor(self, x):
+         return torch.tensor(x, device=self.device)
+
+     def to_numpy(self, x):
+         if isinstance(x, np.ndarray):
+             return x
+         return x.cpu().numpy()
+
+     # -------------------------- forward --------------------------
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """x: (B, C_in, H, W) with H=3*in_nside, W=4*in_nside"""
+         if x.dim() != 4:
+             raise ValueError("Input must be (B, C, H, W)")
+         if x.shape[1] != self.n_chan_in:
+             raise ValueError(f"Expected {self.n_chan_in} input channels, got {x.shape[1]}")
+
+         x = x.to(self.device, dtype=self.dtype)
+
+         skips = []
+         z = x
+         for l, block in enumerate(self.encoder):
+             z = block(z)
+             skips.append(z)
+             if l < len(self.encoder) - 1:
+                 z = self.pool(z)
+
+         # Decoder
+         for d, l in enumerate(reversed(range(len(self.chanlist)))):
+             if l < len(self.chanlist) - 1:
+                 z = self.upconvs[d](z)
+                 # pad if odd due to pooling/upsampling asymmetry (shouldn't happen given divisibility)
+                 sh = skips[l].shape
+                 if z.shape[-2:] != sh[-2:]:
+                     z = _pad_to_match(z, sh[-2], sh[-1])
+             z = torch.cat([skips[l], z], dim=1)
+             z = self.decoder[d](z)
+
+         y = self.head_conv(z)
+         if self.task == 'segmentation' and self.head_bn is not None:
+             y = self.head_bn(y)
+
+         if self.final_activation == 'sigmoid':
+             y = torch.sigmoid(y)
+         elif self.final_activation == 'softmax':
+             y = torch.softmax(y, dim=1)
+         return y
+
+     @torch.no_grad()
+     def predict(
+         self,
+         x: torch.Tensor,
+         batch_size: int = 8,
+         *,
+         amp: bool = False,
+         out_device: Optional[str] = 'cpu',
+         out_dtype: Literal['float32','float16'] = 'float32',
+         show_pbar: bool = False,
+     ) -> torch.Tensor:
+         """Memory-safe prediction.
+         - Streams mini-batches with torch.inference_mode() and optional AMP.
+         - Moves each output batch to `out_device` (CPU by default) to free VRAM.
+         - Validates shapes and raises explicit errors.
+         """
+         self.eval()
+
+         # --- input checks & normalisation ---
+         x = x if torch.is_tensor(x) else torch.as_tensor(x)
+         if x.ndim != 4:
+             raise ValueError(f"predict expects (N,C,H,W), got {tuple(getattr(x, 'shape', ()))}")
+         if x.shape[1] != self.n_chan_in:
+             raise ValueError(f"predict expected {self.n_chan_in} channels, got {x.shape[1]}")
+         n = int(x.shape[0])
+         if n == 0:
+             H, W = int(x.shape[-2]), int(x.shape[-1])
+             return torch.empty((0, self.out_channels, H, W), device=out_device or self.device)
+
+         # --- setup ---
+         dtype_map = {'float32': torch.float32, 'float16': torch.float16}
+         out_dtype_t = dtype_map[out_dtype]
+         use_cuda = (self.device.type == 'cuda')
+         if use_cuda:
+             torch.backends.cudnn.benchmark = True
+
+         from math import ceil
+         nb = ceil(n / batch_size)
+         rng = range(0, n, batch_size)
+         if show_pbar:
+             try:
+                 from tqdm import tqdm  # type: ignore
+                 rng = tqdm(rng, total=nb, desc='predict')
+             except Exception:
+                 pass
+
+         # --- batch-by-batch inference ---
+         out_list: List[torch.Tensor] = []
+         with torch.inference_mode():
+             ctx = (torch.cuda.amp.autocast() if (amp and use_cuda) else nullcontext())
+             for i in rng:
+                 xb = x[i:i+batch_size].to(self.device, dtype=self.dtype, non_blocking=True)
+                 with ctx:
+                     yb = self.forward(xb)
+                 # move the output to the requested device (CPU by default)
+                 yb = yb.to(out_device, dtype=out_dtype_t) if out_device is not None else yb.to(dtype=out_dtype_t)
+                 out_list.append(yb)
+                 del xb, yb
+                 if use_cuda:
+                     torch.cuda.empty_cache()
+
+         if not out_list:
+             raise RuntimeError(f"predict produced no outputs; check input shape {tuple(x.shape)} and batch_size={batch_size}")
+         return torch.cat(out_list, dim=0)
+
+ # -----------------------------
+ # Helpers
+ # -----------------------------
+
+ def _norm_2d(C: int, kind: str = "group", **kwargs) -> nn.Module:
+     if kind == "group":
+         num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1)
+         while C % num_groups != 0 and num_groups > 1:
+             num_groups //= 2
+         return nn.GroupNorm(num_groups=num_groups, num_channels=C)
+     elif kind == "instance":
+         return nn.InstanceNorm2d(C, affine=True, track_running_stats=False)
+     elif kind == "batch":
+         return nn.BatchNorm2d(C)
+     else:
+         raise ValueError(f"Unknown norm kind: {kind}")
+
+
+ def _pad_to_match(x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+     """Pad x (B,C,h,w) with zeros on right/bottom to reach (H,W)."""
+     _, _, h, w = x.shape
+     ph = max(0, H - h)
+     pw = max(0, W - w)
+     if ph == 0 and pw == 0:
+         return x
+     return F.pad(x, (0, pw, 0, ph), mode='constant', value=0)
+
+
+ # -----------------------------
+ # Training utilities (mirror of Healpix fit)
+ # -----------------------------
+ from typing import Union
+ import numpy as np
+ from torch.utils.data import DataLoader, TensorDataset
+
+ def fit(
+     model: nn.Module,
+     x_train: Union[torch.Tensor, np.ndarray],
+     y_train: Union[torch.Tensor, np.ndarray],
+     *,
+     n_epoch: int = 10,
+     view_epoch: int = 10,
+     batch_size: int = 16,
+     lr: float = 1e-3,
+     weight_decay: float = 0.0,
+     clip_grad_norm: Optional[float] = None,
+     verbose: bool = True,
+     optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
+ ) -> dict:
+     """Training loop mirroring `healpix_unet_torch.fit`, adapted to 2D images.
+
+     - Fixed-size inputs: tensors/ndarrays of shape (B, C, H, W) with H=3*nside, W=4*nside
+     - Loss: MSE (regression) / BCE (BCEWithLogits if final_activation='none') / CrossEntropy (multi-class)
+     - Optimizer: ADAM or LBFGS with closure
+     - Logs: returns {"loss": history}
+     """
+     device = next(model.parameters()).device
+     model.to(device)
+
+     # ---- DataLoader
+     x_t = torch.as_tensor(x_train, dtype=torch.float32, device=device)
+     y_is_class = (getattr(model, 'task', 'regression') != 'regression' and getattr(model, 'out_channels', 1) > 1)
+     y_dtype = torch.long if y_is_class and (not torch.is_tensor(y_train) or y_train.ndim == x_t.ndim - 1) else torch.float32
+     y_t = torch.as_tensor(y_train, dtype=y_dtype, device=device)
+
+     ds = TensorDataset(x_t, y_t)
+     loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+
+     # ---- Loss
+     if getattr(model, 'task', 'regression') == 'regression':
+         criterion = nn.MSELoss(reduction='mean')
+         seg_multiclass = False
+     else:
+         if getattr(model, 'out_channels', 1) == 1:
+             criterion = nn.BCEWithLogitsLoss() if getattr(model, 'final_activation', 'none') == 'none' else nn.BCELoss()
+             seg_multiclass = False
+         else:
+             criterion = nn.CrossEntropyLoss()
+             seg_multiclass = True
+
+     # ---- Optim
+     if optimizer.upper() == 'ADAM':
+         optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+         outer, inner = n_epoch, 1
+     elif optimizer.upper() == 'LBFGS':
+         optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20, history_size=50, line_search_fn='strong_wolfe')
+         outer, inner = max(1, n_epoch // 20), 20
+     else:
+         raise ValueError("optimizer must be 'ADAM' or 'LBFGS'")
+
+     # ---- Train
+     history: List[float] = []
+     model.train()
+
+     for epoch in range(outer):
+         for _ in range(inner):
+             epoch_loss, n_samples = 0.0, 0
+             for xb, yb in loader:
+                 xb = xb.to(device, dtype=torch.float32, non_blocking=True)
+                 yb = yb.to(device, non_blocking=True)
+
+                 if isinstance(optim, torch.optim.LBFGS):
+                     def closure():
+                         optim.zero_grad(set_to_none=True)
+                         preds = model(xb)
+                         if seg_multiclass:
+                             loss = criterion(preds, yb)
+                         else:
+                             loss = criterion(preds, yb)
+                         loss.backward()
+                         return loss
+                     loss_val = float(optim.step(closure).item())
+                 else:
+                     optim.zero_grad(set_to_none=True)
+                     preds = model(xb)
+                     if seg_multiclass:
+                         loss = criterion(preds, yb)
+                     else:
+                         loss = criterion(preds, yb)
+                     loss.backward()
+                     if clip_grad_norm is not None:
+                         torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+                     optim.step()
+                     loss_val = float(loss.item())
+
+                 epoch_loss += loss_val * xb.shape[0]
+                 n_samples += xb.shape[0]
+
+             epoch_loss /= max(1, n_samples)
+             history.append(epoch_loss)
+             if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
+                 print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
+
+     return {"loss": history}
+
+
+ # -----------------------------
+ # Minimal smoke test
+ # -----------------------------
+ if __name__ == "__main__":
+     torch.manual_seed(0)
+     nside = 32
+     chanlist = [16, 32, 64]
+     net = PlanarUNet(
+         in_nside=nside,
+         n_chan_in=3,
+         chanlist=chanlist,
+         KERNELSZ=3,
+         task='regression',
+         out_channels=1,
+     )
+     x = torch.randn(2, 3, 3*nside, 4*nside)
+     y = net(x)
+     print('input:', x.shape, 'output:', y.shape)
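A usage sketch (not shipped in this diff) for the planar U-Net added above, exercising the module-level fit() helper and the memory-safe predict(); the import path is an assumption because the new file's name is not shown in this diff, and the data is random:

    import torch
    from foscat.planar_unet import PlanarUNet, fit   # assumed module name

    nside = 32
    net = PlanarUNet(in_nside=nside, n_chan_in=3, chanlist=[16, 32, 64],
                     task='regression', out_channels=1, device='cpu')

    x = torch.randn(8, 3, 3 * nside, 4 * nside)   # (B, C, 3*nside, 4*nside)
    y = torch.randn(8, 1, 3 * nside, 4 * nside)

    history = fit(net, x, y, n_epoch=2, batch_size=4, lr=1e-3, verbose=False)
    pred = net.predict(x, batch_size=4, out_device='cpu')
    print(pred.shape, history['loss'][-1])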
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: foscat
- Version: 2025.9.5
+ Version: 2025.11.1
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>