foscat 2025.8.4__py3-none-any.whl → 2025.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1202 @@
+ """
+ HEALPix U-Net (nested) with Foscat + PyTorch niceties
+ ----------------------------------------------------
+ GPU by default (when available), with graceful CPU fallback if Foscat ops are CPU-only.
+
+ - ReLU + normalization (GroupNorm by default) after each convolution (encoder & decoder)
+ - Segmentation/regression heads with optional final activation
+ - PyTorch-ified: inherits from nn.Module, standard state_dict
+ - Device management: tries CUDA first; if Foscat SphericalStencil cannot run on CUDA, falls back to CPU
+
+ Shape convention: (B, C, Npix)
+
+ Requirements: foscat (scat_cov.funct + SphericalStencil.Convol_torch must be differentiable on torch tensors)
+ """
+ from __future__ import annotations
+ from typing import List, Optional, Literal, Tuple, Union
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import healpy as hp
+
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
+
+ import foscat.scat_cov as sc
+ import foscat.SphericalStencil as ho
+ import matplotlib.pyplot as plt
+
+ class HealpixUNet(nn.Module):
+     """U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.
+
+     Parameters
+     ----------
+     in_nside : int
+         Input HEALPix nside (nested scheme).
+     n_chan_in : int
+         Number of input channels.
+     chanlist : list[int]
+         Channels per encoder level (depth = len(chanlist)). Example: [16, 32, 64].
+     cell_ids : np.ndarray
+         Cell indices for the finest resolution (nside = in_nside) in nested scheme.
+     KERNELSZ : int, default 3
+         Spatial kernel size K (K x K) for the oriented convolution.
+     gauge_type : {'cosmo','phi'}, default 'cosmo'
+         Gauge definition. 'cosmo' uses the same definition as
+         https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html
+         'phi' is defined at the pole; it can be a better choice for Earth
+         observation, which rarely makes intensive use of the poles.
+     G : int, default 1
+         Number of gauges for the orientation definition.
+     task : {'regression','segmentation'}, default 'regression'
+         Chooses the head and default activation.
+     out_channels : int, default 1
+         Number of output channels (e.g. num_classes for segmentation).
+     final_activation : {'none','sigmoid','softmax'} | None
+         If None, uses a sensible default per task: 'none' for regression, 'softmax' for
+         multi-class segmentation, 'sigmoid' for segmentation when out_channels == 1.
+     device : str | torch.device | None, default: 'cuda' if available else 'cpu'
+         Preferred device. The module probes whether Foscat ops can run on CUDA; if not,
+         it falls back to CPU and keeps all parameters/buffers on CPU for consistency.
+     down_type : {'mean','max'}, default 'max'
+         Pooling used for downsampling ('max' is the equivalent of max pooling).
+     prefer_foscat_gpu : bool, default True
+         When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry run.
+         If the dry run fails, everything falls back to CPU.
+     dtype : {'float32','float64'}, default 'float32'
+         Numerical precision used for NumPy/torch tensors.
+     head_reduce : {'mean','learned'}, default 'mean'
+         How the G gauge outputs are reduced at the head: averaged, or mixed by a learned 1x1 convolution.
+
+     Notes
+     -----
+     - Two oriented convolutions per level. After each conv: normalization (GroupNorm by default) + ReLU.
+     - Downsampling uses ``SphericalStencil.Down``; upsampling uses ``SphericalStencil.Up``.
+     - Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``SphericalStencil.Convol_torch``.
+     - The Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
+     """
+
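+     # Usage sketch (illustrative, not part of the package): a full-sky nested
+     # grid at nside=16, where `cell_ids` is simply the identity pixel mapping.
+     #
+     #     net = HealpixUNet(in_nside=16, n_chan_in=1, chanlist=[8, 16],
+     #                       cell_ids=np.arange(12 * 16 * 16))
+     #     x = torch.randn(2, 1, 12 * 16 * 16)   # (B, C, Npix)
+     #     y = net(x)                            # -> (2, 1, 12 * 16 * 16)
+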
+     def __init__(
+         self,
+         *,
+         in_nside: int,
+         n_chan_in: int,
+         chanlist: List[int],
+         cell_ids: np.ndarray,
+         KERNELSZ: int = 3,
+         task: Literal['regression', 'segmentation'] = 'regression',
+         out_channels: int = 1,
+         final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
+         device: Optional[torch.device | str] = None,
+         prefer_foscat_gpu: bool = True,
+         gauge_type: Optional[Literal['cosmo', 'phi']] = 'cosmo',
+         G: int = 1,
+         down_type: Optional[Literal['mean', 'max']] = 'max',
+         dtype: Literal['float32', 'float64'] = 'float32',
+         head_reduce: Literal['mean', 'learned'] = 'mean',
+     ) -> None:
+         super().__init__()
+
+         self.dtype = dtype
+         if dtype == 'float32':
+             self.np_dtype = np.float32
+             self.torch_dtype = torch.float32
+         else:
+             self.np_dtype = np.float64
+             self.torch_dtype = torch.float64
+
+         self.gauge_type = gauge_type
+         self.G = int(G)
+
+         if self.G < 1:
+             raise ValueError("G must be >= 1")
+
+         if cell_ids is None:
+             raise ValueError("cell_ids must be provided for the finest resolution.")
+         if len(chanlist) == 0:
+             raise ValueError("chanlist must be non-empty (depth >= 1).")
+
+         self.in_nside = int(in_nside)
+         self.n_chan_in = int(n_chan_in)
+         self.chanlist = [int(c) * self.G for c in chanlist]
+         self.KERNELSZ = int(KERNELSZ)
+         self.task = task
+         self.out_channels = int(out_channels) * self.G
+         self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
+         self.head_reduce = head_reduce
+         self.max_poll = (down_type == 'max')
+
+         # Choose default final activation if not given
+         if final_activation is None:
+             if task == 'regression':
+                 self.final_activation = 'none'
+             else:  # segmentation
+                 self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
+         else:
+             self.final_activation = final_activation
+
+         # Resolve preferred device
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = torch.device(device)
+
+         # Foscat functional wrapper (backend + grade ops)
+         self.f = sc.funct(KERNELSZ=self.KERNELSZ)
+
+         # ---------- Build multi-resolution bookkeeping ----------
+         depth = len(self.chanlist)
+         self.l_cell_ids: List[np.ndarray] = [None] * (depth + 1)  # per encoder level + bottom
+         self.l_cell_ids[0] = np.asarray(cell_ids)
+
+         enc_nsides: List[int] = [self.in_nside]
+         current_nside = self.in_nside
+
+         # ---------- Oriented convolutions per level (encoder & decoder) ----------
+         self.hconv_enc: List[ho.SphericalStencil] = []
+         self.hconv_dec: List[ho.SphericalStencil] = []
+
+         # dummy data to propagate shapes/ids through the downsampling operator
+         l_data = self.f.backend.bk_cast(np.zeros((1, 1, cell_ids.shape[0]), dtype=self.np_dtype))
+
+         for l in range(depth):
+             # operator at encoder level l
+             hc = ho.SphericalStencil(current_nside,
+                                      self.KERNELSZ,
+                                      n_gauges=self.G,
+                                      gauge_type=self.gauge_type,
+                                      cell_ids=self.l_cell_ids[l],
+                                      dtype=self.torch_dtype)
+
+             self.hconv_enc.append(hc)
+
+             # downsample once to get next level ids and new data shape
+             l_data, next_ids = hc.Down(
+                 l_data, cell_ids=self.l_cell_ids[l], nside=current_nside, max_poll=self.max_poll
+             )
+             self.l_cell_ids[l + 1] = self.f.backend.to_numpy(next_ids)
+             current_nside //= 2
+             enc_nsides.append(current_nside)
+
+         # encoder conv weights and normalization layers
+         self.enc_w1 = nn.ParameterList()
+         self.enc_bn1 = nn.ModuleList()
+         self.enc_w2 = nn.ParameterList()
+         self.enc_bn2 = nn.ModuleList()
+
+         self.enc_nsides = enc_nsides  # [in, in/2, ..., in/2**depth]
+
+         inC = self.n_chan_in
+         for l, outC in enumerate(self.chanlist):
+             if outC % self.G != 0:
+                 raise ValueError(f"chanlist[{l}] = {outC} must be divisible by G={self.G}")
+             outC_g = outC // self.G
+
+             # conv1: inC -> outC (multi-gauge => kernel of shape (inC, outC_g, P))
+             w1 = torch.empty(inC, outC_g, self.KERNELSZ * self.KERNELSZ)
+             nn.init.kaiming_uniform_(w1.view(inC * outC_g, -1), a=np.sqrt(5))
+             self.enc_w1.append(nn.Parameter(w1))
+             self.enc_bn1.append(self._norm_1d(outC, kind="group"))
+
+             # conv2: outC -> outC (input = total outC; kernel of shape (outC, outC_g, P))
+             w2 = torch.empty(outC, outC_g, self.KERNELSZ * self.KERNELSZ)
+             nn.init.kaiming_uniform_(w2.view(outC * outC_g, -1), a=np.sqrt(5))
+             self.enc_w2.append(nn.Parameter(w2))
+             self.enc_bn2.append(self._norm_1d(outC, kind="group"))
+
+             inC = outC  # next layer sees total channels
+
+         # decoder conv weights and normalization layers (mirrored levels)
+         self.dec_w1 = nn.ParameterList()
+         self.dec_bn1 = nn.ModuleList()
+         self.dec_w2 = nn.ParameterList()
+         self.dec_bn2 = nn.ModuleList()
+
+         for d in range(depth):
+             level = depth - 1 - d  # encoder level we are going back to
+             hc = ho.SphericalStencil(self.enc_nsides[level],
+                                      self.KERNELSZ,
+                                      n_gauges=self.G,
+                                      gauge_type=self.gauge_type,
+                                      cell_ids=self.l_cell_ids[level],
+                                      dtype=self.torch_dtype)
+             # hc.make_idx_weights()
+             self.hconv_dec.append(hc)
+
+             upC = self.chanlist[level + 1] if level + 1 < depth else self.chanlist[level]
+             skipC = self.chanlist[level]
+             # The bottom step (d == 0) has no upsample/concat in `forward`, so its
+             # first conv only sees the encoder output channels.
+             inC_dec = upC + skipC if level < depth - 1 else skipC  # total input channels
+             outC_dec = skipC  # total output channels
+
+             if outC_dec % self.G != 0:
+                 raise ValueError(f"decoder outC at level {level} = {outC_dec} must be divisible by G={self.G}")
+             outC_dec_g = outC_dec // self.G
+
+             w1 = torch.empty(inC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
+             nn.init.kaiming_uniform_(w1.view(inC_dec * outC_dec_g, -1), a=np.sqrt(5))
+             self.dec_w1.append(nn.Parameter(w1))
+             self.dec_bn1.append(self._norm_1d(outC_dec, kind="group"))
+
+             w2 = torch.empty(outC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
+             nn.init.kaiming_uniform_(w2.view(outC_dec * outC_dec_g, -1), a=np.sqrt(5))
+             self.dec_w2.append(nn.Parameter(w2))
+             self.dec_bn2.append(self._norm_1d(outC_dec, kind="group"))
+
+         # Output head (on finest grid, channels = chanlist[0])
+         self.head_hconv = ho.SphericalStencil(self.in_nside,
+                                               self.KERNELSZ,
+                                               n_gauges=self.G,  # mandatory for the output
+                                               gauge_type=self.gauge_type,
+                                               cell_ids=self.l_cell_ids[0],
+                                               dtype=self.torch_dtype)
+
+         head_inC = self.chanlist[0]
+         if self.out_channels % self.G != 0:
+             raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
+         outC_head_g = self.out_channels // self.G
+
+         self.head_w = nn.Parameter(
+             torch.empty(head_inC, outC_head_g, self.KERNELSZ * self.KERNELSZ)
+         )
+         nn.init.kaiming_uniform_(self.head_w.view(head_inC * outC_head_g, -1), a=np.sqrt(5))
+         self.head_bn = self._norm_1d(self.out_channels, kind="group") if self.task == 'segmentation' else None
+
+         # How to reduce across gauges at the head: 'mean' (default) averages the
+         # G gauge outputs; 'learned' mixes them with a 1x1 convolution.
+         if self.head_reduce == 'learned':
+             # Mixer maps G*outC_head_g -> out_channels (pixel-wise 1x1)
+             self.head_mixer = nn.Conv1d(self.G * outC_head_g, self.out_channels, kernel_size=1, bias=True)
+         else:
+             self.head_mixer = None
+
+         # ---- Decide runtime device (probe Foscat on CUDA, else CPU) ----
+         self.runtime_device = self._probe_and_set_runtime_device(self.device)
+
+     # -------------------------- local group/batch normalization --------------------------
+     def _norm_1d(self, C: int, kind: str = "group", **kwargs) -> nn.Module:
+         """
+         Return a normalization layer for (B, C, N) tensors.
+         kind: "group" | "instance" | "batch"
+         kwargs: extra args (e.g., num_groups for GroupNorm)
+         """
+         if kind == "group":
+             num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1)
+             # make sure num_groups divides C
+             while C % num_groups != 0 and num_groups > 1:
+                 num_groups //= 2
+             return nn.GroupNorm(num_groups=num_groups, num_channels=C)
+         elif kind == "instance":
+             return nn.InstanceNorm1d(C, affine=True, track_running_stats=False)
+         elif kind == "batch":
+             return nn.BatchNorm1d(C)
+         else:
+             raise ValueError(f"Unknown norm kind: {kind}")
+
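+     # Normalization sketch (illustrative): for C=32 the default picks
+     # min(8, max(1, 32 // 8)) = 4 groups, i.e. GroupNorm(4, 32); for C=12 it
+     # falls back to a single group (akin to a LayerNorm over channels).
+     #
+     #     net._norm_1d(32)                  # -> GroupNorm(4, 32)
+     #     net._norm_1d(64, kind="batch")    # -> BatchNorm1d(64)
+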
+     # -------------------------- device plumbing --------------------------
+     def _move_hconv_tensors(self, hc: ho.SphericalStencil, device: torch.device) -> None:
+         """Best-effort: move any torch.Tensor attribute of SphericalStencil to `device`."""
+         for name, val in list(vars(hc).items()):
+             try:
+                 if torch.is_tensor(val):
+                     setattr(hc, name, val.to(device))
+                 elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
+                     setattr(hc, name, type(val)([v.to(device) for v in val]))
+             except Exception:
+                 # silently ignore non-tensor or protected attributes
+                 pass
+
+     @torch.no_grad()
+     def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
+         """Try to run a tiny Foscat conv on the preferred device; fall back to CPU if it fails."""
+         if preferred.type == 'cuda' and self.prefer_foscat_gpu:
+             try:
+                 # move module params/buffers first
+                 super().to(preferred)
+                 # move Foscat operator internals
+                 for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
+                     self._move_hconv_tensors(hc, preferred)
+                 # dry run on level 0
+                 npix0 = int(len(self.l_cell_ids[0]))
+                 x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
+                 y_try = self.hconv_enc[0].Convol_torch(x_try, self.enc_w1[0])
+                 # success -> stay on CUDA
+                 self._foscat_device = preferred
+                 return preferred
+             except Exception as e:
+                 # fall back to CPU; keep the error for inspection
+                 self._gpu_probe_error = repr(e)
+         # CPU fallback
+         cpu = torch.device('cpu')
+         super().to(cpu)
+         for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
+             self._move_hconv_tensors(hc, cpu)
+         self._foscat_device = cpu
+         return cpu
+
+     def set_device(self, device: torch.device | str) -> torch.device:
+         """Request a new device; probes Foscat and returns the actual runtime device used."""
+         device = torch.device(device)
+         self.device = device
+         self.runtime_device = self._probe_and_set_runtime_device(device)
+         return self.runtime_device
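+
+     # Device sketch (illustrative): `set_device` returns the device actually in
+     # use after the Foscat CUDA probe, which may differ from the one requested.
+     #
+     #     dev = net.set_device('cuda')   # -> cuda:0, or cpu if the probe failed
+     #     if dev.type == 'cpu' and hasattr(net, '_gpu_probe_error'):
+     #         print('Foscat CUDA probe failed:', net._gpu_probe_error)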
+
+     # --- single-sample forward helper ---
+     def _forward_one(self, x1: torch.Tensor, cell_ids1=None) -> torch.Tensor:
+         """
+         Single-sample forward. x1: (1, C_in, Npix_1). Returns (1, out_channels, Npix_1).
+         `cell_ids1` can be None or a 1D array (Npix_1,) for this sample.
+         """
+         if x1.dim() != 3 or x1.shape[0] != 1:
+             raise ValueError(f"_forward_one expects (1, C, Npix), got {tuple(x1.shape)}")
+         # Reuse the batched forward with B=1 (it already supports per-sample ids)
+         if cell_ids1 is None:
+             return self.forward(x1)
+         # normalize ids to a 2D numpy array
+         if isinstance(cell_ids1, torch.Tensor):
+             cell_ids1 = cell_ids1.detach().cpu().numpy()
+         elif isinstance(cell_ids1, list):
+             cell_ids1 = np.asarray(cell_ids1)
+         if cell_ids1.ndim == 1:
+             ci = cell_ids1[None, :]  # (1, Npix_1) so the batched code path is happy
+         else:
+             ci = cell_ids1
+         return self.forward(x1, cell_ids=ci)
+
+     def _as_tensor_batch(self, x):
+         """
+         Ensure a (B, C, N) tensor.
+         - If x is a list of tensors, concatenate if all N are equal.
+         - If len == 1, keep a batch dim (1, C, N).
+         - If x is already a tensor, return as-is.
+         """
+         if isinstance(x, list):
+             if len(x) == 1:
+                 t = x[0]
+                 # If t is (C, N) -> make it (1, C, N)
+                 return t.unsqueeze(0) if t.dim() == 2 else t
+             # all same length -> concat along batch
+             Ns = [t.shape[-1] for t in x]
+             if all(n == Ns[0] for n in Ns):
+                 return torch.cat([t if t.dim() == 3 else t.unsqueeze(0) for t in x], dim=0)
+             # variable-length with B>1 not supported in a single tensor
+             raise ValueError("Variable-length batch detected; use batch_size=1 or loop per-sample.")
+         return x
+
+     # --- dispatcher over fixed-grid tensors and variable-length lists ---
+     def forward_any(self, x, cell_ids: Optional[np.ndarray] = None):
+         """
+         If `x` is a Tensor (B, C, N): standard batched path (requires the same N for all samples).
+         If `x` is a list of Tensors: variable-length per-sample path, returns a list of outputs.
+         """
+         # Variable-length list path
+         if isinstance(x, (list, tuple)):
+             outs = []
+             if cell_ids is None or isinstance(cell_ids, (list, tuple)):
+                 cids = cell_ids if isinstance(cell_ids, (list, tuple)) else [None] * len(x)
+             else:
+                 raise ValueError("When x is a list, cell_ids must be a list of the same length or None.")
+
+             for xb, cb in zip(x, cids):
+                 if not torch.is_tensor(xb):
+                     xb = torch.as_tensor(xb, dtype=torch.float32, device=self.runtime_device)
+                 if xb.dim() == 2:
+                     xb = xb.unsqueeze(0)  # (1, C, Nb)
+                 elif xb.dim() != 3 or xb.shape[0] != 1:
+                     raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
+
+                 yb = self._forward_one(xb.to(self.runtime_device), cell_ids1=cb)  # (1, Co, Nb)
+                 outs.append(yb.squeeze(0))  # -> (Co, Nb)
+             return outs  # List[Tensor] (each of length Nb)
+
+         # Fixed-length tensor path
+         return self.forward(x, cell_ids=cell_ids)
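+
+     # Variable-length sketch (illustrative): two partial-sky samples with
+     # different pixel counts; `ids_a` and `ids_b` stand for hypothetical
+     # per-sample nested cell-id arrays.
+     #
+     #     xs  = [torch.randn(1, 100), torch.randn(1, 150)]   # (C, N_b) each
+     #     ys  = net.forward_any(xs, cell_ids=[ids_a, ids_b]) # list of (Co, N_b)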
+
+     # -------------------------- forward --------------------------
+     def forward(self, x: torch.Tensor, cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
+         """Forward pass.
+
+         Parameters
+         ----------
+         x : torch.Tensor, shape (B, C_in, Npix)
+             Input tensor on the `in_nside` grid.
+         cell_ids : np.ndarray (B, Npix), optional
+             Per-call cell ids overriding the ones given at construction.
+             If None, the initial cell_ids are used.
+         """
+         if not isinstance(x, torch.Tensor):
+             raise TypeError("Input must be a torch.Tensor")
+         if x.dim() != 3:
+             raise ValueError("Input must be (B, C, Npix)")
+
+         # Ensure input lives on the runtime (probed) device
+         x = x.to(self.runtime_device)
+
+         B, C, N = x.shape
+         if C != self.n_chan_in:
+             raise ValueError(f"Expected {self.n_chan_in} input channels, got {C}")
+
+         # Encoder
+         skips: List[torch.Tensor] = []
+         l_data = x
+         current_nside = self.in_nside
+         l_cell_ids = cell_ids
+
+         if cell_ids is not None:
+             t_cell_ids = {}
+             t_cell_ids[0] = l_cell_ids
+         else:
+             t_cell_ids = self.l_cell_ids
+
+         for l, outC in enumerate(self.chanlist):
+             # conv1 + norm + ReLU
+             l_data = self.hconv_enc[l].Convol_torch(l_data,
+                                                     self.enc_w1[l],
+                                                     cell_ids=l_cell_ids)
+             l_data = self._as_tensor_batch(l_data)
+             l_data = self.enc_bn1[l](l_data)
+             l_data = F.relu(l_data, inplace=True)
+
+             # conv2 + norm + ReLU
+             l_data = self.hconv_enc[l].Convol_torch(l_data,
+                                                     self.enc_w2[l],
+                                                     cell_ids=l_cell_ids)
+             l_data = self._as_tensor_batch(l_data)
+             l_data = self.enc_bn2[l](l_data)
+             l_data = F.relu(l_data, inplace=True)
+
+             # save skip at this resolution
+             skips.append(l_data)
+
+             # downsample (except bottom level) -> ensure output is on runtime_device
+             if l < len(self.chanlist) - 1:
+                 l_data, l_cell_ids = self.hconv_enc[l].Down(
+                     l_data, cell_ids=t_cell_ids[l], nside=current_nside, max_poll=self.max_poll
+                 )
+                 l_data = self._as_tensor_batch(l_data)
+                 if cell_ids is not None:
+                     t_cell_ids[l + 1] = l_cell_ids
+                 else:
+                     l_cell_ids = None
+
+                 if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
+                     l_data = l_data.to(self.runtime_device)
+                 current_nside //= 2
+
+         # Decoder
+         for d in range(len(self.chanlist)):
+             level = len(self.chanlist) - 1 - d  # encoder level we are going back to
+
+             if level < len(self.chanlist) - 1:
+                 # upsample: from encoder level (level+1) [coarser] -> level [finer]
+                 src_nside = self.enc_nsides[level + 1]  # coarse
+
+                 # Use the decoder operator at this step (consistent with the hconv_dec stack)
+                 l_data = self.hconv_dec[d].Up(
+                     l_data,
+                     cell_ids=t_cell_ids[level + 1],   # source/coarse IDs
+                     o_cell_ids=t_cell_ids[level],     # target/fine IDs
+                     nside=src_nside,
+                 )
+                 l_data = self._as_tensor_batch(l_data)
+
+                 if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
+                     l_data = l_data.to(self.runtime_device)
+
+                 # concat with skip features at this resolution
+                 concat = self.f.backend.bk_concat([skips[level], l_data], 1)
+                 l_data = concat.to(self.runtime_device) if torch.is_tensor(concat) else concat
+
+             # choose the right cell_ids for convolutions at this resolution
+             l_cell_ids = t_cell_ids[level] if (cell_ids is not None) else None
+
+             # apply decoder convs on this grid using the matching decoder operator
+             hc = self.hconv_dec[d]
+             l_data = hc.Convol_torch(l_data, self.dec_w1[d], cell_ids=l_cell_ids)
+             l_data = self._as_tensor_batch(l_data)
+             l_data = self.dec_bn1[d](l_data)
+             l_data = F.relu(l_data, inplace=True)
+
+             l_data = hc.Convol_torch(l_data, self.dec_w2[d], cell_ids=l_cell_ids)
+             l_data = self._as_tensor_batch(l_data)
+             l_data = self.dec_bn2[d](l_data)
+             l_data = F.relu(l_data, inplace=True)
+
+         # Head on finest grid
+         # y_head_raw: (B, G*outC_head_g, N) with N = Npix
+         y_head_raw = self.head_hconv.Convol_torch(l_data, self.head_w, cell_ids=l_cell_ids)
+
+         B, Ctot, N = y_head_raw.shape
+         outC_head_g = int(self.out_channels) // self.G
+         assert Ctot == self.G * outC_head_g, \
+             f"Head expects G*outC_head_g channels, got {Ctot} != {self.G}*{outC_head_g}"
+
+         if self.head_mixer is not None and self.head_reduce == 'learned':
+             # 1x1 learned mixing across G*outC_head_g -> out_channels
+             y = self.head_mixer(y_head_raw)  # (B, out_channels, N)
+         else:
+             # reshape to (B, G, outC_head_g, N) then reduce across G
+             y_g = y_head_raw.view(B, self.G, outC_head_g, N)
+             y = y_g.mean(dim=1)  # (B, outC_head_g, N)
+
+         y = self._as_tensor_batch(y)
+
+         # Optional norm + final activation
+         if self.task == 'segmentation' and self.head_bn is not None:
+             y = self.head_bn(y)
+
+         if self.final_activation == 'sigmoid':
+             y = torch.sigmoid(y)
+         elif self.final_activation == 'softmax':
+             y = torch.softmax(y, dim=1)
+
+         return y
+
+     # -------------------------- utilities --------------------------
+     @torch.no_grad()
+     def predict(self, x: torch.Tensor, batch_size: int = 8, cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
+         self.eval()
+         outs = []
+         if isinstance(x, np.ndarray):
+             x = self.to_tensor(x)
+
+         if not isinstance(x, torch.Tensor):
+             # list/sequence of per-sample tensors
+             for i in range(len(x)):
+                 if cell_ids is not None:
+                     outs.append(self.forward(x[i][None, :], cell_ids=cell_ids[i][:]))
+                 else:
+                     outs.append(self.forward(x[i][None, :]))
+         else:
+             for i in range(0, x.shape[0], batch_size):
+                 if cell_ids is not None:
+                     outs.append(self.forward(x[i:i + batch_size],
+                                              cell_ids=cell_ids[i:i + batch_size]))
+                 else:
+                     outs.append(self.forward(x[i:i + batch_size]))
+
+         return torch.cat(outs, dim=0)
+
+     def to_tensor(self, x):
+         return self.hconv_enc[0].f.backend.bk_cast(x)
+
+     def to_numpy(self, x):
+         if isinstance(x, np.ndarray):
+             return x
+         return x.detach().cpu().numpy()
+
+     # -----------------------------
+     # Kernel extraction & plotting
+     # -----------------------------
+     def _arch_shapes(self):
+         """Return expected (in_c, out_c) per conv for encoder/decoder.
+
+         Returns
+         -------
+         enc_shapes : list[tuple[tuple[int,int], tuple[int,int]]]
+             For each level `l`, ((in1, out1), (in2, out2)) for the two encoder convs.
+         dec_shapes : list[tuple[tuple[int,int], tuple[int,int]]]
+             For each decoder step `d`, ((in1, out1), (in2, out2)) for the two decoder convs.
+         """
+         nlayer = len(self.chanlist)
+         enc_shapes = []
+         l_chan = self.n_chan_in
+         for l in range(nlayer):
+             enc_shapes.append(((l_chan, self.chanlist[l]), (self.chanlist[l], self.chanlist[l])))
+             l_chan = self.chanlist[l]
+
+         dec_shapes = []
+         for d in range(nlayer):
+             level = nlayer - 1 - d
+             upC = self.chanlist[level + 1] if level + 1 < nlayer else self.chanlist[level]
+             skipC = self.chanlist[level]
+             in1 = upC + skipC if level < nlayer - 1 else skipC
+             dec_shapes.append(((in1, skipC), (skipC, skipC)))
+         return enc_shapes, dec_shapes
+
+     def extract_kernels(self, stage: str = "encoder", layer: int = 0, conv: int = 0):
+         """Extract raw convolution kernels for a given stage/level/conv.
+
+         Parameters
+         ----------
+         stage : {"encoder", "decoder"}
+             Which part of the network to inspect.
+         layer : int
+             Pyramid level (0 = finest encoder level / bottommost decoder level).
+         conv : int
+             0 for the first conv at that level, 1 for the second conv.
+
+         Returns
+         -------
+         np.ndarray
+             Array of shape (in_c, out_c, K, K) containing the spatial kernels.
+         """
+         assert stage in {"encoder", "decoder"}
+         assert conv in {0, 1}
+         K = self.KERNELSZ
+
+         if stage == "encoder":
+             w = self.enc_w1[layer] if conv == 0 else self.enc_w2[layer]
+         else:
+             w = self.dec_w1[layer] if conv == 0 else self.dec_w2[layer]
+
+         w_np = self.f.backend.to_numpy(w.detach())
+         return w_np.reshape(w.shape[0], w.shape[1], K, K)
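+
+     # Kernel inspection sketch (illustrative): with n_chan_in=1, chanlist=[8, 16],
+     # G=1 and KERNELSZ=3, the first encoder conv yields
+     #
+     #     W = net.extract_kernels(stage="encoder", layer=0, conv=0)
+     #     W.shape  # -> (1, 8, 3, 3): (in_c, out_c_per_gauge, K, K)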
+
+     def plot_kernels(
+         self,
+         stage: str = "encoder",
+         layer: int = 0,
+         conv: int = 0,
+         fixed: str = "in",
+         index: int = 0,
+         max_tiles: int = 16,
+     ):
+         """Quick visualization of kernels on a grid using matplotlib.
+
+         Parameters
+         ----------
+         stage : {"encoder", "decoder"}
+             Which tower to visualize.
+         layer : int
+             Level to visualize.
+         conv : int
+             0 or 1: first or second conv in the level.
+         fixed : {"in", "out"}
+             If "in", show kernels for a fixed input channel across many outputs.
+             If "out", show kernels for a fixed output channel across many inputs.
+         index : int
+             Channel index to fix (according to `fixed`).
+         max_tiles : int
+             Maximum number of tiles to display.
+         """
+         import math
+
+         W = self.extract_kernels(stage=stage, layer=layer, conv=conv)
+         ic, oc, K, _ = W.shape
+
+         if fixed == "in":
+             idx = min(index, ic - 1)
+             tiles = [W[idx, j] for j in range(oc)]
+             title = f"{stage} L{layer} C{conv} | in={idx}"
+         else:
+             idx = min(index, oc - 1)
+             tiles = [W[i, idx] for i in range(ic)]
+             title = f"{stage} L{layer} C{conv} | out={idx}"
+
+         tiles = tiles[:max_tiles]
+         n = len(tiles)
+         cols = int(math.ceil(math.sqrt(n)))
+         rows = int(math.ceil(n / cols))
+
+         plt.figure(figsize=(2.5 * cols, 2.5 * rows))
+         for i, ker in enumerate(tiles, 1):
+             ax = plt.subplot(rows, cols, i)
+             ax.imshow(ker)
+             ax.set_xticks([])
+             ax.set_yticks([])
+         plt.suptitle(title)
+         plt.tight_layout()
+         plt.show()
+
+ # -----------------------------
+ # Unit tests (smoke tests)
+ # -----------------------------
+ # Run with: python UNET.py (or) python UNET.py -q for quieter output
+ # These tests assume Foscat and its dependencies are installed.
+
+
+ def _dummy_cell_ids(nside: int) -> np.ndarray:
+     """Return a simple identity mapping for HEALPix nested pixel IDs.
+
+     Notes
+     -----
+     Replace with your pipeline's real `cell_ids` if you have a precomputed
+     mapping consistent with Foscat/HEALPix nested ordering.
+     """
+     return np.arange(12 * nside * nside, dtype=np.int64)
+
+
+ if __name__ == "__main__":
+     import unittest
+
+     class TestUNET(unittest.TestCase):
+         """Lightweight smoke tests for shape and parameter plumbing."""
+
+         def setUp(self):
+             self.nside = 4           # small grid for fast tests (npix = 192)
+             self.chanlist = [4, 8]   # two-level encoder/decoder
+             self.batch = 2
+             self.channels = 1
+             self.npix = 12 * self.nside * self.nside
+             self.cell_ids = _dummy_cell_ids(self.nside)
+             self.net = HealpixUNet(
+                 in_nside=self.nside,
+                 n_chan_in=self.channels,
+                 chanlist=self.chanlist,
+                 cell_ids=self.cell_ids,
+             )
+
+         def test_forward_shape(self):
+             # random input
+             x = np.random.randn(self.batch, self.channels, self.npix).astype(self.net.np_dtype)
+             x = torch.as_tensor(x, device=self.net.runtime_device)
+             self.net.eval()
+             with torch.no_grad():
+                 y = self.net(x)
+             # expected output: same npix, 1 channel at the very top
+             self.assertEqual(y.shape[0], self.batch)
+             self.assertEqual(y.shape[1], 1)
+             self.assertEqual(y.shape[2], self.npix)
+             # sanity: no NaNs
+             y_np = self.net.to_numpy(y)
+             self.assertFalse(np.isnan(y_np).any())
+
+         def test_param_roundtrip_and_determinism(self):
+             x = np.random.randn(self.batch, self.channels, self.npix).astype(self.net.np_dtype)
+             x = torch.as_tensor(x, device=self.net.runtime_device)
+             self.net.eval()
+
+             # forward twice -> identical outputs with fixed params
+             with torch.no_grad():
+                 y1 = self.net(x)
+                 y2 = self.net(x)
+             y1_np = self.net.to_numpy(y1)
+             y2_np = self.net.to_numpy(y2)
+             np.testing.assert_allclose(y1_np, y2_np, rtol=0, atol=0)
+
+             # perturb a parameter -> output should (very likely) change
+             with torch.no_grad():
+                 self.net.enc_w1[0].add_(1.0)
+                 y3 = self.net(x)
+             y3_np = self.net.to_numpy(y3)
+             with self.assertRaises(AssertionError):
+                 np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)
+
+     unittest.main()
+
+
+ # ---------------------------
+ # Datasets / Collate helpers
+ # ---------------------------
+
+ class HealpixDataset(Dataset):
+     """
+     Fixed-grid dataset (common Npix for all samples).
+     Returns (x, y) if cell_ids is None, else (x, y, cell_ids).
+
+     x: (B, C, Npix)
+     y: (B, C_out or 1, Npix) or class indices depending on task
+     cell_ids: (Npix,) or (B, Npix)
+     """
+     def __init__(self,
+                  x: torch.Tensor,
+                  y: torch.Tensor,
+                  cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
+                  dtype: torch.dtype = torch.float32):
+         x = torch.as_tensor(x, dtype=dtype)
+         y = torch.as_tensor(y, dtype=dtype if y.ndim == 3 else torch.long)
+         assert x.shape[0] == y.shape[0], "x and y must share batch size"
+         self.x, self.y = x, y
+         self._has_cids = cell_ids is not None
+         if self._has_cids:
+             c = torch.as_tensor(cell_ids, dtype=torch.long)
+             if c.ndim == 1:  # broadcast single (Npix,) to (B, Npix)
+                 c = c.unsqueeze(0).expand(x.shape[0], -1)
+             assert c.shape == (x.shape[0], x.shape[2]), "cell_ids must be (B,Npix) or (Npix,)"
+             self.cids = c
+         else:
+             self.cids = None
+
+     def __len__(self) -> int:
+         return self.x.shape[0]
+
+     def __getitem__(self, i: int):
+         if self._has_cids:
+             return self.x[i], self.y[i], self.cids[i]
+         return self.x[i], self.y[i]
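+
+ # Fixed-grid dataset sketch (illustrative, hypothetical random data): a full
+ # nside=8 regression set wrapped for a standard DataLoader.
+ #
+ #     npix = 12 * 8 * 8
+ #     ds = HealpixDataset(torch.randn(32, 1, npix), torch.randn(32, 1, npix),
+ #                         cell_ids=np.arange(npix))
+ #     loader = DataLoader(ds, batch_size=4, shuffle=True)
+ #     xb, yb, cb = next(iter(loader))   # (4,1,npix), (4,1,npix), (4,npix)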
+
+
+ class VarLenHealpixDataset(Dataset):
+     """
+     Variable-length per-sample dataset.
+
+     x_list[b]: (C, Npix_b) or (1, C, Npix_b)
+     y_list[b]: (C_out or 1, Npix_b) or (1, C_out, Npix_b) (regression/segmentation targets)
+                For multi-class segmentation with CrossEntropyLoss, you may pass
+                class indices of shape (Npix_b,) or (1, Npix_b) (squeezed later).
+     cids_list[b]: (Npix_b,) or None
+     """
+     def __init__(self,
+                  x_list: List[Union[np.ndarray, torch.Tensor]],
+                  y_list: List[Union[np.ndarray, torch.Tensor]],
+                  cids_list: Optional[List[Union[np.ndarray, torch.Tensor]]] = None,
+                  dtype: torch.dtype = torch.float32):
+         assert len(x_list) == len(y_list), "x_list and y_list must have the same length"
+         self.x = [torch.as_tensor(x, dtype=dtype) for x in x_list]
+         # y can be float (regression) or long (class indices); coerced later per task
+         self.y = [torch.as_tensor(y) for y in y_list]
+         if cids_list is not None:
+             assert len(cids_list) == len(x_list), "cids_list must match x_list length"
+             self.c = [torch.as_tensor(c, dtype=torch.long) for c in cids_list]
+         else:
+             self.c = None
+
+     def __len__(self) -> int:
+         return len(self.x)
+
+     def __getitem__(self, i: int):
+         ci = None if self.c is None else self.c[i]
+         return self.x[i], self.y[i], ci
+
+
+ def varlen_collate(batch):
+     """
+     Collate for variable-length samples: keep lists, do NOT stack.
+     Returns lists: xs, ys, cs (cs can be None).
+     """
+     xs, ys, cs = zip(*batch)
+     c_out = None if all(c is None for c in cs) else list(cs)
+     return list(xs), list(ys), c_out
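+
+ # Variable-length loader sketch (illustrative): batches come back as Python
+ # lists because samples have different Npix and cannot be stacked.
+ #
+ #     ds = VarLenHealpixDataset(x_list, y_list, cids_list=cids_list)
+ #     loader = DataLoader(ds, batch_size=4, shuffle=True, collate_fn=varlen_collate)
+ #     xs, ys, cs = next(iter(loader))   # lists of length <= 4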
+
+
+ # ---------------------------
+ # Training function
+ # ---------------------------
+
+ def fit(
+     model,
+     x_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
+     y_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
+     *,
+     cell_ids_train: Optional[Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]] = None,
+     n_epoch: int = 10,
+     view_epoch: int = 10,
+     batch_size: int = 16,
+     lr: float = 1e-3,
+     weight_decay: float = 0.0,
+     clip_grad_norm: Optional[float] = None,
+     verbose: bool = True,
+     optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
+ ) -> dict:
+     """
+     Training helper that supports:
+       - Fixed-grid tensors (B, C, N) with optional (B, N) or (N,) cell_ids.
+       - Variable-length lists: x=[(C, N_b)], y=[...], cell_ids=[(N_b,)], one grid per sample.
+
+     ADAM: standard minibatch update.
+     LBFGS: uses a closure that sums losses over the current (variable-length) mini-batch.
+
+     Notes
+     -----
+     - For segmentation with multiple classes, pass integer class targets for y:
+       fixed-grid: (B, N) int64; variable-length: each y[b] of shape (N_b,) or (1, N_b).
+     - For regression, pass float targets with the same (C_out, N) channel layout.
+     """
+     device = model.runtime_device if hasattr(model, "runtime_device") else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
+     model.to(device)
+
+     # Detect variable-length mode
+     varlen_mode = isinstance(x_train, (list, tuple))
+
+     # ----- Build DataLoader
+     if not varlen_mode:
+         # Fixed-grid path
+         x_t = torch.as_tensor(x_train, dtype=torch.float32, device=device)
+         y_is_class = (model.task != 'regression' and getattr(model, "out_channels", 1) > 1)
+         y_dtype = torch.long if y_is_class and (not torch.is_tensor(y_train) or y_train.ndim != 3) else torch.float32
+         y_t = torch.as_tensor(y_train, dtype=y_dtype, device=device)
+
+         if cell_ids_train is None:
+             ds = TensorDataset(x_t, y_t)
+             loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+             with_cell_ids = False
+         else:
+             ds = HealpixDataset(x_t, y_t, cell_ids=cell_ids_train, dtype=torch.float32)
+             loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+             with_cell_ids = True
+     else:
+         # Variable-length path
+         ds = VarLenHealpixDataset(x_train, y_train, cids_list=cell_ids_train, dtype=torch.float32)
+         loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False, collate_fn=varlen_collate)
+         with_cell_ids = cell_ids_train is not None
+
+     # ----- Loss
+     if getattr(model, "task", "regression") == 'regression':
+         criterion = nn.MSELoss(reduction='mean')
+         seg_multiclass = False
+     else:
+         # segmentation
+         if getattr(model, "out_channels", 1) == 1:
+             # binary; the head returns logits when final_activation == 'none'
+             criterion = nn.BCEWithLogitsLoss() if getattr(model, "final_activation", "none") == 'none' else nn.BCELoss()
+             seg_multiclass = False
+         else:
+             criterion = nn.CrossEntropyLoss()
+             seg_multiclass = True
+
+     # ----- Optimizer
+     if optimizer.upper() == 'ADAM':
+         optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+         outer = n_epoch
+         inner = 1
+     else:
+         optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20,
+                                   history_size=max(10, n_epoch * 5), line_search_fn="strong_wolfe")
+         # emulate "epochs" with multiple inner LBFGS steps
+         outer = max(1, n_epoch // 20)
+         inner = 20
+
+     # ----- Training loop
+     history: List[float] = []
+     model.train()
+
+     for epoch in range(outer):
+         for _ in range(inner):
+             epoch_loss, n_samples = 0.0, 0
+
+             for batch in loader:
+                 if not varlen_mode:
+                     # -------- fixed-grid
+                     if with_cell_ids:
+                         xb, yb, cb = batch
+                         cb_np = cb.detach().cpu().numpy()
+                     else:
+                         xb, yb = batch
+                         cb_np = None
+
+                     xb = xb.to(device, dtype=torch.float32, non_blocking=True)
+                     # y type: float for regression or binary; long for CrossEntropy
+                     yb = yb.to(device, non_blocking=True)
+
+                     if isinstance(optim, torch.optim.LBFGS):
+                         def closure():
+                             optim.zero_grad(set_to_none=True)
+                             preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
+                             loss = criterion(preds, yb)
+                             loss.backward()
+                             return loss
+                         _ = optim.step(closure)
+                         with torch.no_grad():
+                             preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
+                             loss = criterion(preds, yb)
+                     else:
+                         optim.zero_grad(set_to_none=True)
+                         preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
+                         loss = criterion(preds, yb)
+                         loss.backward()
+                         if clip_grad_norm is not None:
+                             nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+                         optim.step()
+
+                     bs = xb.shape[0]
+                     epoch_loss += float(loss.item()) * bs
+                     n_samples += bs
+
+                 else:
+                     # -------- variable-length (lists)
+                     xs, ys, cs = batch  # lists
+
+                     def _prep_xyc(i):
+                         # x_i : (C, N_i) -> (1, C, N_i)
+                         xb = torch.as_tensor(xs[i], device=device, dtype=torch.float32)
+                         if xb.dim() == 2:
+                             xb = xb.unsqueeze(0)
+                         elif xb.dim() != 3 or xb.shape[0] != 1:
+                             raise ValueError("Each x[i] must be (C,N) or (1,C,N)")
+
+                         # y_i :
+                         yb = torch.as_tensor(ys[i], device=device)
+                         if seg_multiclass:
+                             # class indices: (N_i,) or (1, N_i)
+                             if yb.dim() == 2 and yb.shape[0] == 1:
+                                 yb = yb.squeeze(0)  # -> (N_i,)
+                             elif yb.dim() != 1:
+                                 raise ValueError("For multiclass CE, y[i] must be (N,) or (1,N)")
+                             # the CE criterion will receive (1, C_out, N_i) and (N_i,)
+                         else:
+                             # regression / binary: target of shape (1, C_out, N_i)
+                             if yb.dim() == 2:
+                                 yb = yb.unsqueeze(0)
+                             elif yb.dim() != 3 or yb.shape[0] != 1:
+                                 raise ValueError("For regression/binary, y[i] must be (C_out,N) or (1,C_out,N)")
+
+                         # cell_ids: (N_i,) -> (1, N_i) as numpy (forward expects np.ndarray)
+                         if cs is None or cs[i] is None:
+                             cb_np = None
+                         else:
+                             c = cs[i].detach().cpu().numpy() if torch.is_tensor(cs[i]) else np.asarray(cs[i])
+                             if c.ndim == 1:
+                                 c = c[None, :]  # -> (1, N_i)
+                             cb_np = c
+                         return xb, yb, cb_np
+
+                     if isinstance(optim, torch.optim.LBFGS):
+                         def closure():
+                             optim.zero_grad(set_to_none=True)
+                             total = 0.0
+                             for i in range(len(xs)):
+                                 xb, yb, cb_np = _prep_xyc(i)
+                                 preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
+                                 # CE gets (1,C_out,N_i) vs (N_i,); other losses get matching shapes
+                                 loss_i = criterion(preds, yb)
+                                 loss_i.backward()
+                                 total += float(loss_i.item())
+                             # return a scalar Tensor for LBFGS
+                             return torch.tensor(total / max(1, len(xs)), device=device, dtype=torch.float32)
+
+                         _ = optim.step(closure)
+                         # logging (without grad)
+                         with torch.no_grad():
+                             total = 0.0
+                             for i in range(len(xs)):
+                                 xb, yb, cb_np = _prep_xyc(i)
+                                 preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
+                                 total += float(criterion(preds, yb).item())
+                         loss_val = total / max(1, len(xs))
+                     else:
+                         optim.zero_grad(set_to_none=True)
+                         total = 0.0
+                         for i in range(len(xs)):
+                             xb, yb, cb_np = _prep_xyc(i)
+                             preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
+                             loss_i = criterion(preds, yb)
+                             loss_i.backward()
+                             total += float(loss_i.item())
+                         if clip_grad_norm is not None:
+                             nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+                         optim.step()
+                         loss_val = total / max(1, len(xs))
+
+                     epoch_loss += loss_val * max(1, len(xs))
+                     n_samples += max(1, len(xs))
+
+             epoch_loss /= max(1, n_samples)
+             history.append(epoch_loss)
+             # print every view_epoch logical step
+             if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
+                 print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
+
+     return {"loss": history}
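+
+
+ # Training sketch (illustrative, hypothetical random data): fixed-grid
+ # regression with `net` built as in the usage sketch at the top of HealpixUNet.
+ #
+ #     hist = fit(net,
+ #                torch.randn(32, 1, 12 * 16 * 16),
+ #                torch.randn(32, 1, 12 * 16 * 16),
+ #                n_epoch=5, batch_size=8, lr=1e-3)
+ #     plt.plot(hist["loss"])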