foscat 2025.9.1__py3-none-any.whl → 2025.9.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,11 +6,11 @@ GPU by default (when available), with graceful CPU fallback if Foscat ops are CP
6
6
  - ReLU + BatchNorm after each convolution (encoder & decoder)
7
7
  - Segmentation/Regression heads with optional final activation
8
8
  - PyTorch-ified: inherits from nn.Module, standard state_dict
9
- - Device management: tries CUDA first; if Foscat HOrientedConvol cannot run on CUDA, falls back to CPU
9
+ - Device management: tries CUDA first; if Foscat SphericalStencil cannot run on CUDA, falls back to CPU
10
10
 
11
11
  Shape convention: (B, C, Npix)
12
12
 
13
- Requirements: foscat (scat_cov.funct + HOrientedConvol.Convol_torch must be differentiable on torch tensors)
13
+ Requirements: foscat (scat_cov.funct + SphericalStencil.Convol_torch must be differentiable on torch tensors)
14
14
  """
15
15
  from __future__ import annotations
16
16
  from typing import List, Optional, Literal, Tuple
@@ -18,13 +18,15 @@ import numpy as np
18
18
 
19
19
  import torch
20
20
  import torch.nn as nn
21
+ import healpy as hp
22
+
21
23
  import torch.nn.functional as F
24
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
22
25
 
23
26
  import foscat.scat_cov as sc
24
- import foscat.HOrientedConvol as ho
27
+ import foscat.SphericalStencil as ho
25
28
  import matplotlib.pyplot as plt
26
29
 
27
-
28
30
  class HealpixUNet(nn.Module):
29
31
  """U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.
30
32
 
@@ -40,6 +42,13 @@ class HealpixUNet(nn.Module):
40
42
  Cell indices for the finest resolution (nside = in_nside) in nested scheme.
41
43
  KERNELSZ : int, default 3
42
44
  Spatial kernel size K (K x K) for oriented convolution.
45
+ gauge_type : str
46
+ Type of gauge :
47
+ 'cosmo' use the same definition than
48
+ https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html
49
+ 'phi' is define at the pole, could be better for earth observation not using intensivly the pole
50
+ G : int, default 1
51
+ Number of gauges for the orientation definition.
43
52
  task : {'regression','segmentation'}, default 'regression'
44
53
  Chooses the head and default activation.
45
54
  out_channels : int, default 1
@@ -50,6 +59,8 @@ class HealpixUNet(nn.Module):
50
59
  device : str | torch.device | None, default: 'cuda' if available else 'cpu'
51
60
  Preferred device. The module will probe whether Foscat ops can run on CUDA; if not,
52
61
  it will fallback to CPU and keep all parameters/buffers on CPU for consistency.
62
+ down_type:
63
+ {"mean","max"}, default "max". Equivalent of max poll during down
53
64
  prefer_foscat_gpu : bool, default True
54
65
  When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry-run.
55
66
  If the dry-run fails, everything falls back to CPU.
@@ -58,7 +69,7 @@ class HealpixUNet(nn.Module):
58
69
  -----
59
70
  - Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.
60
71
  - Downsampling uses foscat ``ud_grade_2``; upsampling uses ``up_grade``.
61
- - Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``HOrientedConvol.Convol_torch``.
72
+ - Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``SphericalStencil.Convol_torch``.
62
73
  - Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
63
74
  """
64
75
 
@@ -75,16 +86,28 @@ class HealpixUNet(nn.Module):
75
86
  final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
76
87
  device: Optional[torch.device | str] = None,
77
88
  prefer_foscat_gpu: bool = True,
78
- dtype: Literal['float32','float64'] = 'float32'
89
+ gauge_type: Optional[Literal['cosmo','phi']] = 'cosmo',
90
+ G: int =1,
91
+ down_type: Optional[Literal['mean','max']] = 'max',
92
+ dtype: Literal['float32','float64'] = 'float32',
93
+ head_reduce: Literal['mean','learned']='mean'
79
94
  ) -> None:
80
95
  super().__init__()
81
-
96
+
82
97
  self.dtype=dtype
83
98
  if dtype=='float32':
84
99
  self.np_dtype=np.float32
100
+ self.torch_dtype=torch.float32
85
101
  else:
86
102
  self.np_dtype=np.float64
87
-
103
+ self.torch_dtype=torch.float32
104
+
105
+ self.gauge_type=gauge_type
106
+ self.G = int(G)
107
+
108
+ if self.G < 1:
109
+ raise ValueError("G must be >= 1")
110
+
88
111
  if cell_ids is None:
89
112
  raise ValueError("cell_ids must be provided for the finest resolution.")
90
113
  if len(chanlist) == 0:
@@ -93,11 +116,16 @@ class HealpixUNet(nn.Module):
93
116
  self.in_nside = int(in_nside)
94
117
  self.n_chan_in = int(n_chan_in)
95
118
  self.chanlist = list(map(int, chanlist))
119
+ self.chanlist = [self.chanlist[k]*self.G for k in range(len(self.chanlist))]
96
120
  self.KERNELSZ = int(KERNELSZ)
97
121
  self.task = task
98
- self.out_channels = int(out_channels)
122
+ self.out_channels = int(out_channels)*self.G
99
123
  self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
100
-
124
+ if down_type == 'max':
125
+ self.max_poll = True
126
+ else:
127
+ self.max_poll = False
128
+
101
129
  # Choose default final activation if not given
102
130
  if final_activation is None:
103
131
  if task == 'regression':
@@ -124,25 +152,26 @@ class HealpixUNet(nn.Module):
124
152
  current_nside = self.in_nside
125
153
 
126
154
  # ---------- Oriented convolutions per level (encoder & decoder) ----------
127
- self.hconv_enc: List[ho.HOrientedConvol] = []
128
- self.hconv_dec: List[ho.HOrientedConvol] = []
155
+ self.hconv_enc: List[ho.SphericalStencil] = []
156
+ self.hconv_dec: List[ho.SphericalStencil] = []
129
157
 
130
158
  # dummy data to propagate shapes/ids through ud_grade_2
131
159
  l_data = self.f.backend.bk_cast(np.zeros((1, 1, cell_ids.shape[0]), dtype=self.np_dtype))
132
160
 
133
161
  for l in range(depth):
134
162
  # operator at encoder level l
135
- hc = ho.HOrientedConvol(current_nside,
136
- self.KERNELSZ,
137
- cell_ids=self.l_cell_ids[l],
138
- dtype=self.dtype)
139
- #hc.make_idx_weights()
163
+ hc = ho.SphericalStencil(current_nside,
164
+ self.KERNELSZ,
165
+ n_gauges = self.G,
166
+ gauge_type=self.gauge_type,
167
+ cell_ids=self.l_cell_ids[l],
168
+ dtype=self.torch_dtype)
140
169
 
141
170
  self.hconv_enc.append(hc)
142
171
 
143
172
  # downsample once to get next level ids and new data shape
144
173
  l_data, next_ids = hc.Down(
145
- l_data, cell_ids=self.l_cell_ids[l], nside=current_nside
174
+ l_data, cell_ids=self.l_cell_ids[l], nside=current_nside,max_poll=self.max_poll
146
175
  )
147
176
  self.l_cell_ids[l + 1] = self.f.backend.to_numpy(next_ids)
148
177
  current_nside //= 2
@@ -158,20 +187,23 @@ class HealpixUNet(nn.Module):
158
187
 
159
188
  inC = self.n_chan_in
160
189
  for l, outC in enumerate(self.chanlist):
190
+ if outC % self.G != 0:
191
+ raise ValueError(f"chanlist[{l}] = {outC} must be divisible by G={self.G}")
192
+ outC_g = outC // self.G
161
193
 
162
- # conv1: inC -> outC
163
- w1 = torch.empty(inC, outC, self.KERNELSZ * self.KERNELSZ)
164
- nn.init.kaiming_uniform_(w1.view(inC * outC, -1), a=np.sqrt(5))
194
+ # conv1: inC -> outC (via multi-gauge => noyau (Ci, Co_g, P))
195
+ w1 = torch.empty(inC, outC_g, self.KERNELSZ * self.KERNELSZ)
196
+ nn.init.kaiming_uniform_(w1.view(inC * outC_g, -1), a=np.sqrt(5))
165
197
  self.enc_w1.append(nn.Parameter(w1))
166
- self.enc_bn1.append(nn.BatchNorm1d(outC))
198
+ self.enc_bn1.append(self._norm_1d(outC, kind="group"))
167
199
 
168
- # conv2: outC -> outC
169
- w2 = torch.empty(outC, outC, self.KERNELSZ * self.KERNELSZ)
170
- nn.init.kaiming_uniform_(w2.view(outC * outC, -1), a=np.sqrt(5))
200
+ # conv2: outC -> outC (entrée = total outC ; noyau (outC, outC_g, P))
201
+ w2 = torch.empty(outC, outC_g, self.KERNELSZ * self.KERNELSZ)
202
+ nn.init.kaiming_uniform_(w2.view(outC * outC_g, -1), a=np.sqrt(5))
171
203
  self.enc_w2.append(nn.Parameter(w2))
172
- self.enc_bn2.append(nn.BatchNorm1d(outC))
204
+ self.enc_bn2.append(self._norm_1d(outC, kind="group"))
173
205
 
174
- inC = outC # next level input channels
206
+ inC = outC # next layer sees total channels
175
207
 
176
208
  # decoder conv weights and BN (mirrored levels)
177
209
  self.dec_w1 = nn.ParameterList()
@@ -181,45 +213,88 @@ class HealpixUNet(nn.Module):
181
213
 
182
214
  for d in range(depth):
183
215
  level = depth - 1 - d # encoder level we are going back to
184
- hc = ho.HOrientedConvol(self.enc_nsides[level],
185
- self.KERNELSZ,
186
- cell_ids=self.l_cell_ids[level],
187
- dtype=self.dtype)
216
+ hc = ho.SphericalStencil(self.enc_nsides[level],
217
+ self.KERNELSZ,
218
+ n_gauges = self.G,
219
+ gauge_type=self.gauge_type,
220
+ cell_ids=self.l_cell_ids[level],
221
+ dtype=self.torch_dtype)
188
222
  #hc.make_idx_weights()
189
223
  self.hconv_dec.append(hc)
190
224
 
191
225
  upC = self.chanlist[level + 1] if level + 1 < depth else self.chanlist[level]
192
226
  skipC = self.chanlist[level]
193
- inC_dec = upC + skipC
194
- outC_dec = skipC
227
+ inC_dec = upC + skipC # total en entrée
228
+ outC_dec = skipC # total en sortie (ce que tu avais déjà)
229
+
230
+ if outC_dec % self.G != 0:
231
+ raise ValueError(f"decoder outC at level {level} = {outC_dec} must be divisible by G={self.G}")
232
+ outC_dec_g = outC_dec // self.G
195
233
 
196
- w1 = torch.empty(inC_dec, outC_dec, self.KERNELSZ * self.KERNELSZ)
197
- nn.init.kaiming_uniform_(w1.view(inC_dec * outC_dec, -1), a=np.sqrt(5))
234
+ w1 = torch.empty(inC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
235
+ nn.init.kaiming_uniform_(w1.view(inC_dec * outC_dec_g, -1), a=np.sqrt(5))
198
236
  self.dec_w1.append(nn.Parameter(w1))
199
- self.dec_bn1.append(nn.BatchNorm1d(outC_dec))
237
+ self.dec_bn1.append(self._norm_1d(outC_dec, kind="group"))
200
238
 
201
- w2 = torch.empty(outC_dec, outC_dec, self.KERNELSZ * self.KERNELSZ)
202
- nn.init.kaiming_uniform_(w2.view(outC_dec * outC_dec, -1), a=np.sqrt(5))
239
+ w2 = torch.empty(outC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
240
+ nn.init.kaiming_uniform_(w2.view(outC_dec * outC_dec_g, -1), a=np.sqrt(5))
203
241
  self.dec_w2.append(nn.Parameter(w2))
204
- self.dec_bn2.append(nn.BatchNorm1d(outC_dec))
242
+ self.dec_bn2.append(self._norm_1d(outC_dec, kind="group"))
205
243
 
206
244
  # Output head (on finest grid, channels = chanlist[0])
207
- self.head_hconv = ho.HOrientedConvol(self.in_nside,
208
- self.KERNELSZ,
209
- cell_ids=self.l_cell_ids[0],
210
- dtype=self.dtype)
211
-
212
- head_inC = self.chanlist[0]
213
- self.head_w = nn.Parameter(torch.empty(head_inC, self.out_channels, self.KERNELSZ * self.KERNELSZ))
214
- nn.init.kaiming_uniform_(self.head_w.view(head_inC * self.out_channels, -1), a=np.sqrt(5))
215
- self.head_bn = nn.BatchNorm1d(self.out_channels) if self.task == 'segmentation' else None
245
+ self.head_hconv = ho.SphericalStencil(self.in_nside,
246
+ self.KERNELSZ,
247
+ n_gauges=self.G, #Mandatory for the output
248
+ gauge_type=self.gauge_type,
249
+ cell_ids=self.l_cell_ids[0],
250
+ dtype=self.torch_dtype)
216
251
 
252
+ head_inC = self.chanlist[0]
253
+ if self.out_channels % self.G != 0:
254
+ raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
255
+ outC_head_g = self.out_channels // self.G
256
+
257
+ self.head_w = nn.Parameter(
258
+ torch.empty(head_inC, outC_head_g, self.KERNELSZ * self.KERNELSZ)
259
+ )
260
+ nn.init.kaiming_uniform_(self.head_w.view(head_inC * outC_head_g, -1), a=np.sqrt(5))
261
+ self.head_bn = self._norm_1d(self.out_channels, kind="group") if self.task == 'segmentation' else None
262
+
263
+ # Choose how to reduce across gauges at head:
264
+ # 'sum' (default), 'mean', or 'learned' (via 1x1 conv).
265
+ self.head_reduce = getattr(self, 'head_reduce', 'mean') # you can turn this into a ctor arg if you like
266
+ if self.head_reduce == 'learned':
267
+ # Mixer takes G*outC_head_g -> out_channels (K-wise 1x1)
268
+ self.head_mixer = nn.Conv1d(self.G * outC_head_g, self.out_channels, kernel_size=1, bias=True)
269
+ else:
270
+ self.head_mixer = None
271
+
217
272
  # ---- Decide runtime device (probe Foscat on CUDA, else CPU) ----
218
273
  self.runtime_device = self._probe_and_set_runtime_device(self.device)
219
274
 
275
+ # -------------------------- define local batchnorm/group -------------------
276
+ def _norm_1d(self, C: int, kind: str = "group", **kwargs) -> nn.Module:
277
+ """
278
+ Return a normalization layer for (B, C, N) tensors.
279
+ kind: "group" | "instance" | "batch"
280
+ kwargs: extra args (e.g., num_groups for GroupNorm)
281
+ """
282
+ if kind == "group":
283
+ num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1)
284
+ # s’assurer que num_groups divise C
285
+ while C % num_groups != 0 and num_groups > 1:
286
+ num_groups //= 2
287
+ return nn.GroupNorm(num_groups=num_groups, num_channels=C)
288
+ elif kind == "instance":
289
+ return nn.InstanceNorm1d(C, affine=True, track_running_stats=False)
290
+ elif kind == "batch":
291
+ return nn.BatchNorm1d(C)
292
+ else:
293
+ raise ValueError(f"Unknown norm kind: {kind}")
294
+
220
295
  # -------------------------- device plumbing --------------------------
221
- def _move_hconv_tensors(self, hc: ho.HOrientedConvol, device: torch.device) -> None:
222
- """Best-effort: move any torch.Tensor attribute of HOrientedConvol to device."""
296
+ def _move_hconv_tensors(self, hc: ho.SphericalStencil, device: torch.device) -> None:
297
+ """Best-effort: move any torch.Tensor attribute of SphericalStencil to device."""
223
298
  for name, val in list(vars(hc).items()):
224
299
  try:
225
300
  if torch.is_tensor(val):
@@ -265,6 +340,78 @@ class HealpixUNet(nn.Module):
265
340
  self.device = device
266
341
  self.runtime_device = self._probe_and_set_runtime_device(device)
267
342
  return self.runtime_device
343
+
344
+ # --- inside HealpixUNet class, add a single-sample forward helper ---
345
+ def _forward_one(self, x1: torch.Tensor, cell_ids1=None) -> torch.Tensor:
346
+ """
347
+ Single-sample forward. x1: (1, C_in, Npix_1). Returns (1, out_channels, Npix_1).
348
+ `cell_ids1` can be None or a 1D array (Npix_1,) for this sample.
349
+ """
350
+ if x1.dim() != 3 or x1.shape[0] != 1:
351
+ raise ValueError(f"_forward_one expects (1, C, Npix), got {tuple(x1.shape)}")
352
+ # Reuse existing forward by calling it with B=1 (your code already supports per-sample ids)
353
+ if cell_ids1 is None:
354
+ return super().forward(x1)
355
+ else:
356
+ # normalize ids to numpy 1D
357
+ if isinstance(cell_ids1, torch.Tensor):
358
+ cell_ids1 = cell_ids1.detach().cpu().numpy()
359
+ elif isinstance(cell_ids1, list):
360
+ cell_ids1 = np.asarray(cell_ids1)
361
+ if cell_ids1.ndim == 1:
362
+ ci = cell_ids1[None, :] # (1, Npix_1) so the current code path is happy
363
+ else:
364
+ ci = cell_ids1
365
+ return super().forward(x1, cell_ids=ci)
366
+
367
+ def _as_tensor_batch(self, x):
368
+ """
369
+ Ensure a (B, C, N) tensor.
370
+ - If x is a list of tensors, concatenate if all N are equal.
371
+ - If len==1, keep a batch dim (1, C, N).
372
+ - If x is already a tensor, return as-is.
373
+ """
374
+ if isinstance(x, list):
375
+ if len(x) == 1:
376
+ t = x[0]
377
+ # If t is (C, N) -> make it (1, C, N)
378
+ return t.unsqueeze(0) if t.dim() == 2 else t
379
+ # all same length -> concat along batch
380
+ Ns = [t.shape[-1] for t in x]
381
+ if all(n == Ns[0] for n in Ns):
382
+ return torch.cat([t if t.dim() == 3 else t.unsqueeze(0) for t in x], dim=0)
383
+ # variable-length with B>1 not supported in a single tensor
384
+ raise ValueError("Variable-length batch detected; use batch_size=1 or loop per-sample.")
385
+ return x
386
+
387
+ # --- replace your current `forward` signature/body with a dispatcher ---
388
+ def forward_any(self, x, cell_ids: Optional[np.ndarray] = None):
389
+ """
390
+ If `x` is a Tensor (B,C,N): standard batched path (requires same N for all).
391
+ If `x` is a list of Tensors: variable-length per-sample path, returns a list of outputs.
392
+ """
393
+ # Variable-length list path
394
+ if isinstance(x, (list, tuple)):
395
+ outs = []
396
+ if cell_ids is None or isinstance(cell_ids, (list, tuple)):
397
+ cids = cell_ids if isinstance(cell_ids, (list, tuple)) else [None] * len(x)
398
+ else:
399
+ raise ValueError("When x is a list, cell_ids must be a list of same length or None.")
400
+
401
+ for xb, cb in zip(x, cids):
402
+ if not torch.is_tensor(xb):
403
+ xb = torch.as_tensor(xb, dtype=torch.float32, device=self.runtime_device)
404
+ if xb.dim() == 2:
405
+ xb = xb.unsqueeze(0) # (1,C,Nb)
406
+ elif xb.dim() != 3 or xb.shape[0] != 1:
407
+ raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
408
+
409
+ yb = self._forward_one(xb.to(self.runtime_device), cell_ids1=cb) # (1,Co,Nb)
410
+ outs.append(yb.squeeze(0)) # -> (Co, Nb)
411
+ return outs # List[Tensor] (each length Nb)
412
+
413
+ # Fixed-length tensor path (your current implementation)
414
+ return super().forward(x, cell_ids=cell_ids)
268
415
 
269
416
  # -------------------------- forward --------------------------
270
417
  def forward(self, x: torch.Tensor,cell_ids: Optional[np.ndarray ] = None) -> torch.Tensor:
@@ -306,6 +453,7 @@ class HealpixUNet(nn.Module):
306
453
  l_data = self.hconv_enc[l].Convol_torch(l_data,
307
454
  self.enc_w1[l],
308
455
  cell_ids=l_cell_ids)
456
+ l_data = self._as_tensor_batch(l_data)
309
457
  l_data = self.enc_bn1[l](l_data)
310
458
  l_data = F.relu(l_data, inplace=True)
311
459
 
@@ -313,6 +461,7 @@ class HealpixUNet(nn.Module):
313
461
  l_data = self.hconv_enc[l].Convol_torch(l_data,
314
462
  self.enc_w2[l],
315
463
  cell_ids=l_cell_ids)
464
+ l_data = self._as_tensor_batch(l_data)
316
465
  l_data = self.enc_bn2[l](l_data)
317
466
  l_data = F.relu(l_data, inplace=True)
318
467
 
@@ -321,12 +470,10 @@ class HealpixUNet(nn.Module):
321
470
 
322
471
  # downsample (except bottom level) -> ensure output is on runtime_device
323
472
  if l < len(self.chanlist) - 1:
324
- #l_data, l_cell_ids = self.f.ud_grade_2(
325
- # l_data, cell_ids=t_cell_ids[l], nside=current_nside
326
- #)
327
473
  l_data, l_cell_ids = self.hconv_enc[l].Down(
328
- l_data, cell_ids=t_cell_ids[l], nside=current_nside
474
+ l_data, cell_ids=t_cell_ids[l], nside=current_nside,max_poll=self.max_poll
329
475
  )
476
+ l_data = self._as_tensor_batch(l_data)
330
477
  if cell_ids is not None:
331
478
  t_cell_ids[l+1]=l_cell_ids
332
479
  else:
@@ -335,34 +482,24 @@ class HealpixUNet(nn.Module):
335
482
  if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
336
483
  l_data = l_data.to(self.runtime_device)
337
484
  current_nside //= 2
338
-
485
+
339
486
  # Decoder
340
487
  for d in range(len(self.chanlist)):
341
- level = len(self.chanlist) - 1 - d # corresponding encoder level
488
+ level = len(self.chanlist) - 1 - d # encoder level we are going back to
342
489
 
343
490
  if level < len(self.chanlist) - 1:
344
- # upsample to next finer grid
345
- # upsample to next finer grid (from level+1 -> level)
346
- src_nside = self.enc_nsides[level + 1] # current (coarser)
347
- tgt_nside = self.enc_nsides[level] # next finer (== src*2)
348
- # Foscat up_grade signature expects current (coarse) ids in `cell_ids`
349
- # and target (fine) ids in `o_cell_ids` (matching original UNET code).
350
- '''
351
- l_data = self.f.up_grade(
352
- l_data,
353
- tgt_nside,
354
- cell_ids=t_cell_ids[level + 1], # source (coarser) ids
355
- o_cell_ids=t_cell_ids[level], # target (finer) ids
356
- nside=src_nside,
357
- )
358
- '''
359
- l_data = self.hconv_enc[l].Up(
491
+ # upsample: from encoder level (level+1) [coarser] -> level [finer]
492
+ src_nside = self.enc_nsides[level + 1] # coarse
493
+
494
+ # Use the **decoder** operator at this step (consistent with your hconv_dec stack)
495
+ l_data = self.hconv_dec[d].Up(
360
496
  l_data,
361
- cell_ids=t_cell_ids[level + 1], # source (coarser) ids
362
- o_cell_ids=t_cell_ids[level], # target (finer) ids
497
+ cell_ids=t_cell_ids[level + 1], # source/coarse IDs
498
+ o_cell_ids=t_cell_ids[level], # target/fine IDs
363
499
  nside=src_nside,
364
500
  )
365
-
501
+ l_data = self._as_tensor_batch(l_data)
502
+
366
503
  if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
367
504
  l_data = l_data.to(self.runtime_device)
368
505
 
@@ -370,49 +507,85 @@ class HealpixUNet(nn.Module):
370
507
  concat = self.f.backend.bk_concat([skips[level], l_data], 1)
371
508
  l_data = concat.to(self.runtime_device) if torch.is_tensor(concat) else concat
372
509
 
373
- if cell_ids is not None:
374
- l_cell_ids = t_cell_ids[level]
375
-
376
- # apply decoder convs on this grid
510
+ # choose the right cell_ids for convolutions at this resolution
511
+ l_cell_ids = t_cell_ids[level] if (cell_ids is not None) else None
512
+
513
+ # apply decoder convs on this grid using the matching decoder operator
377
514
  hc = self.hconv_dec[d]
378
- l_data = hc.Convol_torch(l_data,
379
- self.dec_w1[d],
380
- cell_ids=l_cell_ids)
381
-
515
+ l_data = hc.Convol_torch(l_data, self.dec_w1[d], cell_ids=l_cell_ids)
516
+ l_data = self._as_tensor_batch(l_data)
382
517
  l_data = self.dec_bn1[d](l_data)
383
518
  l_data = F.relu(l_data, inplace=True)
384
519
 
385
- l_data = hc.Convol_torch(l_data,
386
- self.dec_w2[d],
387
- cell_ids=l_cell_ids)
388
-
520
+ l_data = hc.Convol_torch(l_data, self.dec_w2[d], cell_ids=l_cell_ids)
521
+ l_data = self._as_tensor_batch(l_data)
389
522
  l_data = self.dec_bn2[d](l_data)
390
523
  l_data = F.relu(l_data, inplace=True)
391
524
 
392
525
  # Head on finest grid
393
- out = self.head_hconv.Convol_torch(l_data, self.head_w)
394
- if self.head_bn is not None:
395
- out = self.head_bn(out)
526
+ # y_head_raw: (B, G*outC_head_g, K)
527
+ y_head_raw = self.head_hconv.Convol_torch(l_data, self.head_w, cell_ids=l_cell_ids)
528
+
529
+ B, Ctot, K = y_head_raw.shape
530
+ outC_head_g = int(self.out_channels)//self.G
531
+ assert Ctot == self.G * outC_head_g, \
532
+ f"Head expects G*outC_head_g channels, got {Ctot} != {self.G}*{outC_head_g}"
533
+
534
+ if self.head_mixer is not None and self.head_reduce == 'learned':
535
+ # 1x1 learned mixing across G*outC_head_g -> out_channels
536
+ y = self.head_mixer(y_head_raw) # (B, out_channels, K)
537
+ else:
538
+ # reshape to (B, G, outC_head_g, K) then reduce across G
539
+ y_g = y_head_raw.view(B, self.G, outC_head_g, K)
540
+
541
+ y = y_g.mean(dim=1) # (B, outC_head_g, K)
542
+
543
+ y = self._as_tensor_batch(y)
544
+
545
+ # Optional BN + activation as before
546
+ if self.task == 'segmentation' and self.head_bn is not None:
547
+ y = self.head_bn(y)
548
+
396
549
  if self.final_activation == 'sigmoid':
397
- out = torch.sigmoid(out)
550
+ y = torch.sigmoid(y)
551
+
398
552
  elif self.final_activation == 'softmax':
399
- out = torch.softmax(out, dim=1)
400
- return out
553
+ y = torch.softmax(y, dim=1)
554
+
555
+ return y
401
556
 
402
557
  # -------------------------- utilities --------------------------
403
558
  @torch.no_grad()
404
559
  def predict(self, x: torch.Tensor, batch_size: int = 8,cell_ids: Optional[np.ndarray ] = None) -> torch.Tensor:
405
560
  self.eval()
406
561
  outs = []
407
- for i in range(0, x.shape[0], batch_size):
408
- if cell_ids is not None:
409
- outs.append(self.forward(x[i : i + batch_size],
410
- cell_ids=cell_ids[i : i + batch_size]))
411
- else:
412
- outs.append(self.forward(x[i : i + batch_size]))
562
+ if isinstance(x,np.ndarray):
563
+ x=self.to_Tensor(x)
564
+
565
+ if not isinstance(x, torch.Tensor):
566
+ for i in range(len(x)):
567
+ if cell_ids is not None:
568
+ outs.append(self.forward(x[i][None,:],cell_ids=cell_ids[i][:]))
569
+ else:
570
+ outs.append(self.forward(x[i][None,:]))
571
+ else:
572
+ for i in range(0, x.shape[0], batch_size):
573
+ if cell_ids is not None:
574
+ outs.append(self.forward(x[i : i + batch_size],
575
+ cell_ids=cell_ids[i : i + batch_size]))
576
+ else:
577
+ outs.append(self.forward(x[i : i + batch_size]))
413
578
 
414
579
  return torch.cat(outs, dim=0)
415
580
 
581
+ def to_tensor(self,x):
582
+ return self.hconv_enc[0].f.backend.bk_cast(x)
583
+
584
+ def to_numpy(self,x):
585
+ if isinstance(x,np.ndarray):
586
+ return x
587
+ return x.cpu().numpy()
588
+
416
589
  # -----------------------------
417
590
  # Kernel extraction & plotting
418
591
  # -----------------------------
@@ -609,109 +782,421 @@ if __name__ == "__main__":
609
782
  np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)
610
783
 
611
784
  unittest.main()
785
+
786
+ from torch.utils.data import Dataset
787
+ # 1) Dataset that omits cell_ids when None
788
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
789
+
790
+ class HealpixDataset(Dataset):
791
+ """
792
+ Returns (x, y, cell_ids) per-sample if cell_ids is given, else (x, y).
793
+ Shapes:
794
+ x: (C, Npix)
795
+ y: (C_out or 1, Npix)
796
+ cell_ids: (Npix,) per-sample (or broadcasted from (Npix,))
797
+ """
798
+ def __init__(self, x, y, cell_ids=None, dtype=torch.float32):
799
+ self.x = torch.as_tensor(x, dtype=dtype)
800
+ self.y = torch.as_tensor(y, dtype=dtype)
801
+ assert self.x.shape[0] == self.y.shape[0], "x and y must share batch size"
802
+ self._has_cids = cell_ids is not None
803
+ if self._has_cids:
804
+ cid = torch.as_tensor(cell_ids, dtype=torch.long)
805
+ if cid.dim() == 1:
806
+ cid = cid.unsqueeze(0).expand(self.x.shape[0], -1)
807
+ assert cid.shape[0] == self.x.shape[0], "cell_ids must match batch size"
808
+ self.cids = cid
809
+ else:
810
+ self.cids = None
811
+
812
+ def __len__(self):
813
+ return self.x.shape[0]
814
+
815
+ def __getitem__(self, i):
816
+ if self._has_cids:
817
+ return self.x[i], self.y[i], self.cids[i]
818
+ else:
819
+ return self.x[i], self.y[i]
820
+
821
+ # ---------------------------
822
+ # Datasets / Collate helpers
823
+ # ---------------------------
824
+
825
+ class HealpixDataset(Dataset):
826
+ """
827
+ Fixed-grid dataset (common Npix for all samples).
828
+ Returns (x, y) if cell_ids is None, else (x, y, cell_ids).
829
+
830
+ x: (B, C, Npix)
831
+ y: (B, C_out or 1, Npix) or class indices depending on task
832
+ cell_ids: (Npix,) or (B, Npix)
833
+ """
834
+ def __init__(self,
835
+ x: torch.Tensor,
836
+ y: torch.Tensor,
837
+ cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
838
+ dtype: torch.dtype = torch.float32):
839
+ x = torch.as_tensor(x, dtype=dtype)
840
+ y = torch.as_tensor(y, dtype=dtype if y.ndim == 3 else torch.long)
841
+ assert x.shape[0] == y.shape[0], "x and y must share batch size"
842
+ self.x, self.y = x, y
843
+ self._has_cids = cell_ids is not None
844
+ if self._has_cids:
845
+ c = torch.as_tensor(cell_ids, dtype=torch.long)
846
+ if c.ndim == 1: # broadcast single (Npix,) to (B, Npix)
847
+ c = c.unsqueeze(0).expand(x.shape[0], -1)
848
+ assert c.shape == (x.shape[0], x.shape[2]), "cell_ids must be (B,Npix) or (Npix,)"
849
+ self.cids = c
850
+ else:
851
+ self.cids = None
852
+
853
+ def __len__(self) -> int: return self.x.shape[0]
854
+
855
+ def __getitem__(self, i: int):
856
+ if self._has_cids:
857
+ return self.x[i], self.y[i], self.cids[i]
858
+ return self.x[i], self.y[i]
859
+
860
+ # ---------------------------
861
+ # Datasets / Collate helpers
862
+ # ---------------------------
863
+
864
+ class HealpixDataset(Dataset):
865
+ """
866
+ Fixed-grid dataset (common Npix for all samples).
867
+ Returns (x, y) if cell_ids is None, else (x, y, cell_ids).
868
+
869
+ x: (B, C, Npix)
870
+ y: (B, C_out or 1, Npix) or class indices depending on task
871
+ cell_ids: (Npix,) or (B, Npix)
872
+ """
873
+ def __init__(self,
874
+ x: torch.Tensor,
875
+ y: torch.Tensor,
876
+ cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
877
+ dtype: torch.dtype = torch.float32):
878
+ x = torch.as_tensor(x, dtype=dtype)
879
+ y = torch.as_tensor(y, dtype=dtype if y.ndim == 3 else torch.long)
880
+ assert x.shape[0] == y.shape[0], "x and y must share batch size"
881
+ self.x, self.y = x, y
882
+ self._has_cids = cell_ids is not None
883
+ if self._has_cids:
884
+ c = torch.as_tensor(cell_ids, dtype=torch.long)
885
+ if c.ndim == 1: # broadcast single (Npix,) to (B, Npix)
886
+ c = c.unsqueeze(0).expand(x.shape[0], -1)
887
+ assert c.shape == (x.shape[0], x.shape[2]), "cell_ids must be (B,Npix) or (Npix,)"
888
+ self.cids = c
889
+ else:
890
+ self.cids = None
891
+
892
+ def __len__(self) -> int: return self.x.shape[0]
893
+
894
+ def __getitem__(self, i: int):
895
+ if self._has_cids:
896
+ return self.x[i], self.y[i], self.cids[i]
897
+ return self.x[i], self.y[i]
898
+
899
+
900
+ class VarLenHealpixDataset(Dataset):
901
+ """
902
+ Variable-length per-sample dataset.
903
+
904
+ x_list[b]: (C, Npix_b) or (1, C, Npix_b)
905
+ y_list[b]: (C_out or 1, Npix_b) or (1, C_out, Npix_b) (regression/segmentation targets)
906
+ For multi-class segmentation with CrossEntropyLoss, you may pass
907
+ class indices of shape (Npix_b,) or (1, Npix_b) (we’ll squeeze later).
908
+ cids_list[b]: (Npix_b,) or None
909
+ """
910
+ def __init__(self,
911
+ x_list: List[Union[np.ndarray, torch.Tensor]],
912
+ y_list: List[Union[np.ndarray, torch.Tensor]],
913
+ cids_list: Optional[List[Union[np.ndarray, torch.Tensor]]] = None,
914
+ dtype: torch.dtype = torch.float32):
915
+ assert len(x_list) == len(y_list), "x_list and y_list must have the same length"
916
+ self.x = [torch.as_tensor(x, dtype=dtype) for x in x_list]
917
+ # y can be float (regression) or long (class indices); we’ll coerce later per task
918
+ self.y = [torch.as_tensor(y) for y in y_list]
919
+ if cids_list is not None:
920
+ assert len(cids_list) == len(x_list), "cids_list must match x_list length"
921
+ self.c = [torch.as_tensor(c, dtype=torch.long) for c in cids_list]
922
+ else:
923
+ self.c = None
924
+
925
+ def __len__(self) -> int: return len(self.x)
926
+
927
+ def __getitem__(self, i: int):
928
+ ci = None if self.c is None else self.c[i]
929
+ return self.x[i], self.y[i], ci
930
+
931
+ from torch.utils.data import Dataset, DataLoader
932
+
933
+ class VarLenHealpixDataset(Dataset):
934
+ """
935
+ x_list: list of (C, Npix_b) tensors or arrays
936
+ y_list: list of (C_out or 1, Npix_b) tensors or arrays
937
+ cids_list: optional list of (Npix_b,) arrays
938
+ """
939
+ def __init__(self, x_list, y_list, cids_list=None, dtype=torch.float32):
940
+ assert len(x_list) == len(y_list)
941
+ self.x = [torch.as_tensor(x, dtype=dtype) for x in x_list]
942
+ self.y = [torch.as_tensor(y, dtype=dtype) for y in y_list]
943
+ self.c = None
944
+ if cids_list is not None:
945
+ assert len(cids_list) == len(x_list)
946
+ self.c = [np.asarray(c) for c in cids_list]
947
+
948
+ def __len__(self): return len(self.x)
949
+
950
+ def __getitem__(self, i):
951
+ if self.c is None:
952
+ return self.x[i], self.y[i], None
953
+ return self.x[i], self.y[i], self.c[i]
954
+
955
+ def varlen_collate(batch):
956
+ # Just return lists; do not stack.
957
+ xs, ys, cs = zip(*batch) # tuples of length B
958
+ # keep None if all Nones, else list
959
+ c_out = None if all(c is None for c in cs) else list(cs)
960
+ return list(xs), list(ys), c_out
961
+
962
+ def varlen_collate(batch):
963
+ """
964
+ Collate for variable-length samples: keep lists, do NOT stack.
965
+ Returns lists: xs, ys, cs (cs can be None).
966
+ """
967
+ xs, ys, cs = zip(*batch)
968
+ c_out = None if all(c is None for c in cs) else list(cs)
969
+ return list(xs), list(ys), c_out
970
+
971
+
972
+ # ---------------------------
973
+ # Training function
974
+ # ---------------------------
612
975
 
613
976
def fit(
    model,
    x_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    y_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    *,
    cell_ids_train: Optional[Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]] = None,
    n_epoch: int = 10,
    view_epoch: int = 10,
    batch_size: int = 16,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    clip_grad_norm: Optional[float] = None,
    verbose: bool = True,
    optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
) -> dict:
    """Train helper for the HEALPix U-Net.

    Supports two input layouts:
      - Fixed grid: ``x_train``/``y_train`` as (B, C, N) tensors/arrays, with
        optional ``cell_ids_train`` of shape (B, N) or (N,).
      - Variable length: lists of per-sample (C, N_b) arrays (one grid per
        sample), with an optional list of (N_b,) cell-id arrays; batches are
        kept as lists and processed sample by sample (never stacked).

    Parameters
    ----------
    model : nn.Module
        Must expose ``task`` ('regression' or segmentation), ``out_channels``
        and ``final_activation`` attributes, optionally ``runtime_device``;
        its forward accepts an optional ``cell_ids=`` numpy array.
    n_epoch, view_epoch, batch_size, lr, weight_decay, clip_grad_norm, verbose
        Usual training knobs; ``clip_grad_norm`` only applies to ADAM steps.
    optimizer : 'ADAM' or 'LBFGS'
        LBFGS emulates epochs as blocks of 20 inner iterations per outer step.

    Returns
    -------
    dict with key ``"loss"``: list of per-step mean training losses.

    Notes
    -----
    - Multiclass segmentation targets may be class indices (B, N) or one-hot
      (B, C, N); one-hot targets are reduced to indices via argmax (this was
      the behavior of the previous implementation and is required by
      ``nn.CrossEntropyLoss`` for integer targets).
    - Binary segmentation assumes the head returns logits when
      ``final_activation == 'none'`` (BCEWithLogitsLoss), probabilities
      otherwise (BCELoss).
    """
    device = model.runtime_device if hasattr(model, "runtime_device") else (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.to(device)

    # ----- Loss selection (first, so target dtype handling below can use it)
    task = getattr(model, "task", "regression")
    if task == 'regression':
        criterion = nn.MSELoss(reduction='mean')
        seg_multiclass = False
    elif getattr(model, "out_channels", 1) == 1:
        # Binary segmentation: head emits logits when final_activation == 'none'.
        criterion = nn.BCEWithLogitsLoss() if getattr(model, "final_activation", "none") == 'none' else nn.BCELoss()
        seg_multiclass = False
    else:
        criterion = nn.CrossEntropyLoss()
        seg_multiclass = True

    # ----- Build DataLoader
    varlen_mode = isinstance(x_train, (list, tuple))
    if not varlen_mode:
        # Fixed-grid path
        x_t = torch.as_tensor(x_train, dtype=torch.float32, device=device)
        y_t = torch.as_tensor(y_train, device=device)
        if seg_multiclass:
            # CrossEntropy wants integer class indices of shape (B, N);
            # accept one-hot (B, C, N) targets by reducing the channel axis.
            if y_t.dim() == 3:
                y_t = y_t.argmax(dim=1)
            y_t = y_t.long()
        else:
            y_t = y_t.float()

        if cell_ids_train is None:
            loader = DataLoader(TensorDataset(x_t, y_t),
                                batch_size=batch_size, shuffle=True, drop_last=False)
            with_cell_ids = False
        else:
            ds = HealpixDataset(x_t, y_t, cell_ids=cell_ids_train, dtype=torch.float32)
            loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
            with_cell_ids = True
    else:
        # Variable-length path: custom collate keeps per-sample lists
        ds = VarLenHealpixDataset(x_train, y_train, cids_list=cell_ids_train, dtype=torch.float32)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False,
                            collate_fn=varlen_collate)
        with_cell_ids = cell_ids_train is not None

    # ----- Optimizer
    if optimizer.upper() == 'ADAM':
        optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        outer, inner = n_epoch, 1
    else:
        optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20,
                                  history_size=max(10, n_epoch * 5),
                                  line_search_fn="strong_wolfe")
        # emulate "epochs" with multiple inner LBFGS steps
        outer, inner = max(1, n_epoch // 20), 20

    # ----- Training loop
    history: List[float] = []
    model.train()

    for _epoch in range(outer):
        for _ in range(inner):
            epoch_loss, n_samples = 0.0, 0

            for batch in loader:
                if not varlen_mode:
                    # -------- fixed-grid batch
                    if with_cell_ids:
                        xb, yb, cb = batch
                        cb_np = cb.detach().cpu().numpy()
                    else:
                        xb, yb = batch
                        cb_np = None

                    xb = xb.to(device, dtype=torch.float32, non_blocking=True)
                    # target dtype already set above: float (regression/binary) or long (CE)
                    yb = yb.to(device, non_blocking=True)

                    if isinstance(optim, torch.optim.LBFGS):
                        def closure():
                            # LBFGS re-evaluates loss + gradients several times per step
                            optim.zero_grad(set_to_none=True)
                            preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                            loss = criterion(preds, yb)
                            loss.backward()
                            return loss

                        optim.step(closure)
                        # recompute the final loss (no grad) for logging
                        with torch.no_grad():
                            preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                            loss = criterion(preds, yb)
                    else:
                        optim.zero_grad(set_to_none=True)
                        preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                        loss = criterion(preds, yb)
                        loss.backward()
                        if clip_grad_norm is not None:
                            nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                        optim.step()

                    bs = xb.shape[0]
                    epoch_loss += float(loss.item()) * bs
                    n_samples += bs

                else:
                    # -------- variable-length batch (lists of per-sample tensors)
                    xs, ys, cs = batch

                    def _prep_xyc(i):
                        """Normalize sample i to (1,C,N) input, criterion-ready target, (1,N) numpy ids."""
                        xb = torch.as_tensor(xs[i], device=device, dtype=torch.float32)
                        if xb.dim() == 2:
                            xb = xb.unsqueeze(0)  # (C, N_i) -> (1, C, N_i)
                        elif xb.dim() != 3 or xb.shape[0] != 1:
                            raise ValueError("Each x[i] must be (C,N) or (1,C,N)")

                        yb = torch.as_tensor(ys[i], device=device)
                        if seg_multiclass:
                            # CE gets preds (1, C_out, N_i) vs indices (N_i,)
                            if yb.dim() == 2 and yb.shape[0] == 1:
                                yb = yb.squeeze(0)
                            elif yb.dim() != 1:
                                raise ValueError("For multiclass CE, y[i] must be (N,) or (1,N)")
                        else:
                            # regression / binary target of shape (1, C_out, N_i)
                            if yb.dim() == 2:
                                yb = yb.unsqueeze(0)
                            elif yb.dim() != 3 or yb.shape[0] != 1:
                                raise ValueError("For regression/binary, y[i] must be (C_out,N) or (1,C_out,N)")

                        # cell ids: (N_i,) -> (1, N_i) numpy (forward expects np.ndarray)
                        if cs is None or cs[i] is None:
                            cb_np = None
                        else:
                            c = cs[i].detach().cpu().numpy() if torch.is_tensor(cs[i]) else np.asarray(cs[i])
                            if c.ndim == 1:
                                c = c[None, :]
                            cb_np = c
                        return xb, yb, cb_np

                    n_b = max(1, len(xs))
                    if isinstance(optim, torch.optim.LBFGS):
                        def closure():
                            # gradients are accumulated sample-by-sample via backward()
                            optim.zero_grad(set_to_none=True)
                            total = 0.0
                            for i in range(len(xs)):
                                xb, yb, cb_np = _prep_xyc(i)
                                preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                                loss_i = criterion(preds, yb)
                                loss_i.backward()
                                total += float(loss_i.item())
                            # LBFGS only needs a scalar tensor tracking the objective
                            return torch.tensor(total / n_b, device=device, dtype=torch.float32)

                        optim.step(closure)
                        # logging pass (no grad)
                        with torch.no_grad():
                            total = 0.0
                            for i in range(len(xs)):
                                xb, yb, cb_np = _prep_xyc(i)
                                preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                                total += float(criterion(preds, yb).item())
                        loss_val = total / n_b
                    else:
                        optim.zero_grad(set_to_none=True)
                        total = 0.0
                        for i in range(len(xs)):
                            xb, yb, cb_np = _prep_xyc(i)
                            preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                            loss_i = criterion(preds, yb)
                            loss_i.backward()
                            total += float(loss_i.item())
                        if clip_grad_norm is not None:
                            nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                        optim.step()
                        loss_val = total / n_b

                    epoch_loss += loss_val * n_b
                    n_samples += n_b

            epoch_loss /= max(1, n_samples)
            history.append(epoch_loss)
            # print every view_epoch logical step (and the very first one)
            if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
                print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")

    return {"loss": history}
715
-
716
-
717
- __all__ = ["HealpixUNet", "fit"]