foscat 2025.8.3__py3-none-any.whl → 2025.9.1__py3-none-any.whl

@@ -0,0 +1,717 @@
+ """
+ HEALPix U-Net (nested) with Foscat + PyTorch niceties
+ -----------------------------------------------------
+ GPU by default (when available), with a graceful CPU fallback if the Foscat ops are CPU-only.
+ 
+ - ReLU + BatchNorm after each convolution (encoder & decoder)
+ - Segmentation/regression heads with an optional final activation
+ - PyTorch-ified: inherits from nn.Module, standard state_dict
+ - Device management: tries CUDA first; if Foscat's HOrientedConvol cannot run on CUDA, falls back to CPU
+ 
+ Shape convention: (B, C, Npix)
+ 
+ Requirements: foscat (scat_cov.funct + HOrientedConvol.Convol_torch must be differentiable on torch tensors)
+ """
+ from __future__ import annotations
+ from typing import List, Optional, Literal
+ 
+ import numpy as np
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ import foscat.scat_cov as sc
+ import foscat.HOrientedConvol as ho
+ 
+ 
+ class HealpixUNet(nn.Module):
+     """U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.
+ 
+     Parameters
+     ----------
+     in_nside : int
+         Input HEALPix nside (nested scheme).
+     n_chan_in : int
+         Number of input channels.
+     chanlist : list[int]
+         Channels per encoder level (depth = len(chanlist)). Example: [16, 32, 64].
+     cell_ids : np.ndarray
+         Cell indices for the finest resolution (nside = in_nside) in nested scheme.
+     KERNELSZ : int, default 3
+         Spatial kernel size K (K x K) for the oriented convolution.
+     task : {'regression', 'segmentation'}, default 'regression'
+         Chooses the head and the default activation.
+     out_channels : int, default 1
+         Number of output channels (e.g. num_classes for segmentation).
+     final_activation : {'none', 'sigmoid', 'softmax'} | None
+         If None, a sensible default is chosen per task: 'none' for regression,
+         'softmax' for multi-class segmentation, 'sigmoid' for segmentation when
+         out_channels == 1.
+     device : str | torch.device | None, default: 'cuda' if available else 'cpu'
+         Preferred device. The module probes whether the Foscat ops can run on CUDA;
+         if not, it falls back to CPU and keeps all parameters/buffers on CPU for
+         consistency.
+     prefer_foscat_gpu : bool, default True
+         When the device is CUDA, try to move the Foscat operators (internal tensors)
+         to CUDA and do a dry run. If the dry run fails, everything falls back to CPU.
+ 
+     Notes
+     -----
+     - Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.
+     - Downsampling uses the operators' ``Down``; upsampling uses ``Up``.
+     - Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and are
+       applied via ``HOrientedConvol.Convol_torch``.
+     - The Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
+     """
+ 
+     def __init__(
+         self,
+         *,
+         in_nside: int,
+         n_chan_in: int,
+         chanlist: List[int],
+         cell_ids: np.ndarray,
+         KERNELSZ: int = 3,
+         task: Literal['regression', 'segmentation'] = 'regression',
+         out_channels: int = 1,
+         final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
+         device: Optional[torch.device | str] = None,
+         prefer_foscat_gpu: bool = True,
+         dtype: Literal['float32', 'float64'] = 'float32',
+     ) -> None:
+         super().__init__()
+ 
+         self.dtype = dtype
+         self.np_dtype = np.float32 if dtype == 'float32' else np.float64
+         # keep the torch parameters in the same precision as the Foscat operators
+         self.torch_dtype = torch.float32 if dtype == 'float32' else torch.float64
+ 
+         if cell_ids is None:
+             raise ValueError("cell_ids must be provided for the finest resolution.")
+         if len(chanlist) == 0:
+             raise ValueError("chanlist must be non-empty (depth >= 1).")
+ 
+         self.in_nside = int(in_nside)
+         self.n_chan_in = int(n_chan_in)
+         self.chanlist = list(map(int, chanlist))
+         self.KERNELSZ = int(KERNELSZ)
+         self.task = task
+         self.out_channels = int(out_channels)
+         self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
+ 
+         # Choose a default final activation if none is given
+         if final_activation is None:
+             if task == 'regression':
+                 self.final_activation = 'none'
+             else:  # segmentation
+                 self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
+         else:
+             self.final_activation = final_activation
+ 
+         # Resolve the preferred device
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = torch.device(device)
+ 
+         # Foscat functional wrapper (backend + grade ops)
+         self.f = sc.funct(KERNELSZ=self.KERNELSZ)
+ 
+         # ---------- Build multi-resolution bookkeeping ----------
+         depth = len(self.chanlist)
+         self.l_cell_ids: List[np.ndarray] = [None] * (depth + 1)  # per encoder level + bottom
+         self.l_cell_ids[0] = np.asarray(cell_ids)
+ 
+         enc_nsides: List[int] = [self.in_nside]
+         current_nside = self.in_nside
+ 
+         # ---------- Oriented convolutions per level (encoder & decoder) ----------
+         self.hconv_enc: List[ho.HOrientedConvol] = []
+         self.hconv_dec: List[ho.HOrientedConvol] = []
+ 
+         # dummy data to propagate shapes/ids through the downsampling chain
+         l_data = self.f.backend.bk_cast(np.zeros((1, 1, cell_ids.shape[0]), dtype=self.np_dtype))
+ 
+         for l in range(depth):
+             # operator at encoder level l
+             hc = ho.HOrientedConvol(current_nside,
+                                     self.KERNELSZ,
+                                     cell_ids=self.l_cell_ids[l],
+                                     dtype=self.dtype)
+             self.hconv_enc.append(hc)
+ 
+             # downsample once to get the next level's ids and data shape
+             l_data, next_ids = hc.Down(
+                 l_data, cell_ids=self.l_cell_ids[l], nside=current_nside
+             )
+             self.l_cell_ids[l + 1] = self.f.backend.to_numpy(next_ids)
+             current_nside //= 2
+             enc_nsides.append(current_nside)
+ 
+         # encoder conv weights and BN
+         self.enc_w1 = nn.ParameterList()
+         self.enc_bn1 = nn.ModuleList()
+         self.enc_w2 = nn.ParameterList()
+         self.enc_bn2 = nn.ModuleList()
+ 
+         self.enc_nsides = enc_nsides  # [in, in/2, ..., in/2**depth]
+ 
+         inC = self.n_chan_in
+         for l, outC in enumerate(self.chanlist):
+             # conv1: inC -> outC
+             w1 = torch.empty(inC, outC, self.KERNELSZ * self.KERNELSZ, dtype=self.torch_dtype)
+             nn.init.kaiming_uniform_(w1.view(inC * outC, -1), a=np.sqrt(5))
+             self.enc_w1.append(nn.Parameter(w1))
+             self.enc_bn1.append(nn.BatchNorm1d(outC).to(self.torch_dtype))
+ 
+             # conv2: outC -> outC
+             w2 = torch.empty(outC, outC, self.KERNELSZ * self.KERNELSZ, dtype=self.torch_dtype)
+             nn.init.kaiming_uniform_(w2.view(outC * outC, -1), a=np.sqrt(5))
+             self.enc_w2.append(nn.Parameter(w2))
+             self.enc_bn2.append(nn.BatchNorm1d(outC).to(self.torch_dtype))
+ 
+             inC = outC  # next level's input channels
+ 
+         # decoder conv weights and BN (mirrored levels)
+         self.dec_w1 = nn.ParameterList()
+         self.dec_bn1 = nn.ModuleList()
+         self.dec_w2 = nn.ParameterList()
+         self.dec_bn2 = nn.ModuleList()
+ 
+         for d in range(depth):
+             level = depth - 1 - d  # encoder level we are going back to
+             hc = ho.HOrientedConvol(self.enc_nsides[level],
+                                     self.KERNELSZ,
+                                     cell_ids=self.l_cell_ids[level],
+                                     dtype=self.dtype)
+             self.hconv_dec.append(hc)
+ 
+             skipC = self.chanlist[level]
+             if level + 1 < depth:
+                 # upsampled features from level+1 concatenated with the skip at `level`
+                 inC_dec = self.chanlist[level + 1] + skipC
+             else:
+                 # bottom of the U: forward() applies no upsample/concat at this level
+                 inC_dec = skipC
+             outC_dec = skipC
+ 
+             w1 = torch.empty(inC_dec, outC_dec, self.KERNELSZ * self.KERNELSZ, dtype=self.torch_dtype)
+             nn.init.kaiming_uniform_(w1.view(inC_dec * outC_dec, -1), a=np.sqrt(5))
+             self.dec_w1.append(nn.Parameter(w1))
+             self.dec_bn1.append(nn.BatchNorm1d(outC_dec).to(self.torch_dtype))
+ 
+             w2 = torch.empty(outC_dec, outC_dec, self.KERNELSZ * self.KERNELSZ, dtype=self.torch_dtype)
+             nn.init.kaiming_uniform_(w2.view(outC_dec * outC_dec, -1), a=np.sqrt(5))
+             self.dec_w2.append(nn.Parameter(w2))
+             self.dec_bn2.append(nn.BatchNorm1d(outC_dec).to(self.torch_dtype))
+ 
+         # Output head (on the finest grid, channels = chanlist[0])
+         self.head_hconv = ho.HOrientedConvol(self.in_nside,
+                                              self.KERNELSZ,
+                                              cell_ids=self.l_cell_ids[0],
+                                              dtype=self.dtype)
+ 
+         head_inC = self.chanlist[0]
+         self.head_w = nn.Parameter(torch.empty(head_inC, self.out_channels,
+                                                self.KERNELSZ * self.KERNELSZ,
+                                                dtype=self.torch_dtype))
+         nn.init.kaiming_uniform_(self.head_w.view(head_inC * self.out_channels, -1), a=np.sqrt(5))
+         self.head_bn = nn.BatchNorm1d(self.out_channels).to(self.torch_dtype) if self.task == 'segmentation' else None
+ 
+         # ---- Decide the runtime device (probe Foscat on CUDA, else CPU) ----
+         self.runtime_device = self._probe_and_set_runtime_device(self.device)
219
+
220
+ # -------------------------- device plumbing --------------------------
221
+ def _move_hconv_tensors(self, hc: ho.HOrientedConvol, device: torch.device) -> None:
222
+ """Best-effort: move any torch.Tensor attribute of HOrientedConvol to device."""
223
+ for name, val in list(vars(hc).items()):
224
+ try:
225
+ if torch.is_tensor(val):
226
+ setattr(hc, name, val.to(device))
227
+ elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
228
+ setattr(hc, name, type(val)([v.to(device) for v in val]))
229
+ except Exception:
230
+ # silently ignore non-tensor or protected attributes
231
+ pass
232
+
233
+ @torch.no_grad()
234
+ def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
235
+ """Try to run a tiny Foscat conv on preferred device; fallback to CPU if it fails."""
236
+ if preferred.type == 'cuda' and self.prefer_foscat_gpu:
237
+ try:
238
+ # move module params/buffers first
239
+ super().to(preferred)
240
+ # move Foscat operator internals
241
+ for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
242
+ self._move_hconv_tensors(hc, preferred)
243
+ # dry run on level 0
244
+ npix0 = int(len(self.l_cell_ids[0]))
245
+ x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
246
+ y_try = self.hconv_enc[0].Convol_torch(x_try, self.enc_w1[0])
247
+ # success -> stay on CUDA
248
+ self._foscat_device = preferred
249
+ return preferred
250
+ except Exception as e:
251
+ # fallback to CPU; keep error for info
252
+ self._gpu_probe_error = repr(e)
253
+ pass
254
+ # CPU fallback
255
+ cpu = torch.device('cpu')
256
+ super().to(cpu)
257
+ for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
258
+ self._move_hconv_tensors(hc, cpu)
259
+ self._foscat_device = cpu
260
+ return cpu
261
+
262
+ def set_device(self, device: torch.device | str) -> torch.device:
263
+ """Request a (re)device; will probe Foscat and return the actual runtime device used."""
264
+ device = torch.device(device)
265
+ self.device = device
266
+ self.runtime_device = self._probe_and_set_runtime_device(device)
267
+ return self.runtime_device
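+ 
+     # Hedged usage note: `set_device` reports where the model actually ended up,
+     # which may differ from the request when the CUDA dry run fails, e.g.
+     #
+     #     dev = net.set_device('cuda')   # returns torch.device('cpu') on fallback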
+ 
+     # -------------------------- forward --------------------------
+     def forward(self, x: torch.Tensor, cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
+         """Forward pass.
+ 
+         Parameters
+         ----------
+         x : torch.Tensor, shape (B, C_in, Npix)
+             Input tensor on the `in_nside` grid.
+         cell_ids : np.ndarray, shape (B, Npix), optional
+             Use other cell ids than the initial ones; if None, the initial
+             `cell_ids` are used.
+         """
+         if not isinstance(x, torch.Tensor):
+             raise TypeError("Input must be a torch.Tensor")
+         if x.dim() != 3:
+             raise ValueError("Input must be (B, C, Npix)")
+ 
+         # Ensure the input lives on the runtime (probed) device
+         x = x.to(self.runtime_device)
+ 
+         B, C, N = x.shape
+         if C != self.n_chan_in:
+             raise ValueError(f"Expected {self.n_chan_in} input channels, got {C}")
+ 
+         # Encoder
+         skips: List[torch.Tensor] = []
+         l_data = x
+         current_nside = self.in_nside
+         l_cell_ids = cell_ids
+ 
+         if cell_ids is not None:
+             t_cell_ids = {0: l_cell_ids}
+         else:
+             t_cell_ids = self.l_cell_ids
+ 
+         for l, outC in enumerate(self.chanlist):
+             # conv1 + BN + ReLU
+             l_data = self.hconv_enc[l].Convol_torch(l_data,
+                                                     self.enc_w1[l],
+                                                     cell_ids=l_cell_ids)
+             l_data = self.enc_bn1[l](l_data)
+             l_data = F.relu(l_data, inplace=True)
+ 
+             # conv2 + BN + ReLU
+             l_data = self.hconv_enc[l].Convol_torch(l_data,
+                                                     self.enc_w2[l],
+                                                     cell_ids=l_cell_ids)
+             l_data = self.enc_bn2[l](l_data)
+             l_data = F.relu(l_data, inplace=True)
+ 
+             # save the skip at this resolution
+             skips.append(l_data)
+ 
+             # downsample (except at the bottom level) -> keep output on runtime_device
+             if l < len(self.chanlist) - 1:
+                 l_data, l_cell_ids = self.hconv_enc[l].Down(
+                     l_data, cell_ids=t_cell_ids[l], nside=current_nside
+                 )
+                 if cell_ids is not None:
+                     t_cell_ids[l + 1] = l_cell_ids
+                 else:
+                     l_cell_ids = None
+ 
+                 if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
+                     l_data = l_data.to(self.runtime_device)
+                 current_nside //= 2
+ 
+         # Decoder
+         for d in range(len(self.chanlist)):
+             level = len(self.chanlist) - 1 - d  # corresponding encoder level
+ 
+             if level < len(self.chanlist) - 1:
+                 # upsample to the next finer grid (from level+1 -> level); the
+                 # operator stored at the coarser level holds the source grid.
+                 # `Up` expects the current (coarse) ids in `cell_ids` and the
+                 # target (fine) ids in `o_cell_ids`.
+                 src_nside = self.enc_nsides[level + 1]  # current (coarser)
+                 l_data = self.hconv_enc[level + 1].Up(
+                     l_data,
+                     cell_ids=t_cell_ids[level + 1],  # source (coarser) ids
+                     o_cell_ids=t_cell_ids[level],    # target (finer) ids
+                     nside=src_nside,
+                 )
+ 
+                 if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
+                     l_data = l_data.to(self.runtime_device)
+ 
+                 # concat with the skip features at this resolution
+                 concat = self.f.backend.bk_concat([skips[level], l_data], 1)
+                 l_data = concat.to(self.runtime_device) if torch.is_tensor(concat) else concat
+ 
+                 if cell_ids is not None:
+                     l_cell_ids = t_cell_ids[level]
+ 
+             # apply the decoder convs on this grid
+             hc = self.hconv_dec[d]
+             l_data = hc.Convol_torch(l_data,
+                                      self.dec_w1[d],
+                                      cell_ids=l_cell_ids)
+             l_data = self.dec_bn1[d](l_data)
+             l_data = F.relu(l_data, inplace=True)
+ 
+             l_data = hc.Convol_torch(l_data,
+                                      self.dec_w2[d],
+                                      cell_ids=l_cell_ids)
+             l_data = self.dec_bn2[d](l_data)
+             l_data = F.relu(l_data, inplace=True)
+ 
+         # Head on the finest grid
+         out = self.head_hconv.Convol_torch(l_data, self.head_w)
+         if self.head_bn is not None:
+             out = self.head_bn(out)
+         if self.final_activation == 'sigmoid':
+             out = torch.sigmoid(out)
+         elif self.final_activation == 'softmax':
+             out = torch.softmax(out, dim=1)
+         return out
+ 
+     # -------------------------- utilities --------------------------
+     @torch.no_grad()
+     def predict(self, x: torch.Tensor, batch_size: int = 8,
+                 cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
+         self.eval()
+         outs = []
+         for i in range(0, x.shape[0], batch_size):
+             if cell_ids is not None:
+                 outs.append(self.forward(x[i:i + batch_size],
+                                          cell_ids=cell_ids[i:i + batch_size]))
+             else:
+                 outs.append(self.forward(x[i:i + batch_size]))
+         return torch.cat(outs, dim=0)
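+ 
+     # Batched-inference sketch (hedged example; hypothetical sizes): predict()
+     # chunks along the batch axis so a large map stack need not fit on the
+     # device at once, e.g.
+     #
+     #     maps = torch.randn(100, 1, 12 * 16 * 16)
+     #     preds = net.predict(maps, batch_size=8)   # (100, out_channels, Npix)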
+ 
+     # -----------------------------
+     # Kernel extraction & plotting
+     # -----------------------------
+     def _arch_shapes(self):
+         """Return the expected (in_c, out_c) per conv for the encoder/decoder.
+ 
+         Returns
+         -------
+         enc_shapes : list[tuple[tuple[int, int], tuple[int, int]]]
+             For each level `l`, ((in1, out1), (in2, out2)) for the two encoder convs.
+         dec_shapes : list[tuple[tuple[int, int], tuple[int, int]]]
+             For each decoder step `d`, ((in1, out1), (in2, out2)) for the two decoder convs.
+         """
+         nlayer = len(self.chanlist)
+         enc_shapes = []
+         l_chan = self.n_chan_in
+         for l in range(nlayer):
+             enc_shapes.append(((l_chan, self.chanlist[l]), (self.chanlist[l], self.chanlist[l])))
+             l_chan = self.chanlist[l]
+ 
+         dec_shapes = []
+         for d in range(nlayer):
+             level = nlayer - 1 - d
+             outC = self.chanlist[level]
+             if level + 1 < nlayer:
+                 inC = self.chanlist[level + 1] + outC  # upsampled + skip
+             else:
+                 inC = outC  # bottom level: no concatenation
+             dec_shapes.append(((inC, outC), (outC, outC)))
+         return enc_shapes, dec_shapes
+ 
+     def extract_kernels(self, stage: str = "encoder", layer: int = 0, conv: int = 0):
+         """Extract the raw convolution kernels for a given stage/level/conv.
+ 
+         Parameters
+         ----------
+         stage : {"encoder", "decoder"}
+             Which part of the network to inspect.
+         layer : int
+             Pyramid level (0 = finest encoder level / bottommost decoder level).
+         conv : int
+             0 for the first conv at that level, 1 for the second conv.
+ 
+         Returns
+         -------
+         np.ndarray
+             Array of shape (in_c, out_c, K, K) containing the spatial kernels.
+         """
+         assert stage in {"encoder", "decoder"}
+         assert conv in {0, 1}
+         K = self.KERNELSZ
+ 
+         if stage == "encoder":
+             w = self.enc_w1[layer] if conv == 0 else self.enc_w2[layer]
+         else:
+             w = self.dec_w1[layer] if conv == 0 else self.dec_w2[layer]
+ 
+         w_np = self.f.backend.to_numpy(w.detach())
+         return w_np.reshape(w.shape[0], w.shape[1], K, K)
+ 
+     def plot_kernels(
+         self,
+         stage: str = "encoder",
+         layer: int = 0,
+         conv: int = 0,
+         fixed: str = "in",
+         index: int = 0,
+         max_tiles: int = 16,
+     ):
+         """Quick visualization of the kernels on a grid using matplotlib.
+ 
+         Parameters
+         ----------
+         stage : {"encoder", "decoder"}
+             Which tower to visualize.
+         layer : int
+             Level to visualize.
+         conv : int
+             0 or 1: first or second conv of the level.
+         fixed : {"in", "out"}
+             If "in", show the kernels for a fixed input channel across many outputs.
+             If "out", show the kernels for a fixed output channel across many inputs.
+         index : int
+             Channel index to fix (according to `fixed`).
+         max_tiles : int
+             Maximum number of tiles to display.
+         """
+         import math
+         import matplotlib.pyplot as plt
+ 
+         W = self.extract_kernels(stage=stage, layer=layer, conv=conv)
+         ic, oc, K, _ = W.shape
+ 
+         if fixed == "in":
+             idx = min(index, ic - 1)
+             tiles = [W[idx, j] for j in range(oc)]
+             title = f"{stage} L{layer} C{conv} | in={idx}"
+         else:
+             idx = min(index, oc - 1)
+             tiles = [W[i, idx] for i in range(ic)]
+             title = f"{stage} L{layer} C{conv} | out={idx}"
+ 
+         tiles = tiles[:max_tiles]
+         n = len(tiles)
+         cols = int(math.ceil(math.sqrt(n)))
+         rows = int(math.ceil(n / cols))
+ 
+         plt.figure(figsize=(2.5 * cols, 2.5 * rows))
+         for i, ker in enumerate(tiles, 1):
+             ax = plt.subplot(rows, cols, i)
+             ax.imshow(ker)
+             ax.set_xticks([])
+             ax.set_yticks([])
+         plt.suptitle(title)
+         plt.tight_layout()
+         plt.show()
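+ 
+     # Inspection sketch (hedged example; `net` as constructed above):
+     #
+     #     W = net.extract_kernels(stage="encoder", layer=0, conv=0)  # (in_c, out_c, K, K)
+     #     net.plot_kernels(stage="encoder", layer=0, conv=0, fixed="in", index=0)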
+ 
+ 
+ # -----------------------------
+ # Unit tests (smoke tests)
+ # -----------------------------
+ # Run with: python UNET.py (or) python UNET.py -q for quieter output
+ # These tests assume Foscat and its dependencies are installed.
+ 
+ 
+ def _dummy_cell_ids(nside: int) -> np.ndarray:
+     """Return a simple identity mapping for HEALPix nested pixel IDs.
+ 
+     Notes
+     -----
+     Replace with your pipeline's real `cell_ids` if you have a precomputed
+     mapping consistent with Foscat/HEALPix nested ordering.
+     """
+     return np.arange(12 * nside * nside, dtype=np.int64)
+ 
+ 
+ if __name__ == "__main__":
+     import unittest
+ 
+     class TestUNET(unittest.TestCase):
+         """Lightweight smoke tests for shape and parameter plumbing."""
+ 
+         def setUp(self):
+             self.nside = 4          # small grid for fast tests (npix = 192)
+             self.chanlist = [4, 8]  # two-level encoder/decoder
+             self.batch = 2
+             self.channels = 1
+             self.npix = 12 * self.nside * self.nside
+             self.cell_ids = _dummy_cell_ids(self.nside)
+             self.net = HealpixUNet(
+                 in_nside=self.nside,
+                 n_chan_in=self.channels,
+                 chanlist=self.chanlist,
+                 cell_ids=self.cell_ids,
+             )
+ 
+         def test_forward_shape(self):
+             # random input
+             x = np.random.randn(self.batch, self.channels, self.npix).astype(self.net.np_dtype)
+             x = torch.from_numpy(x)
+             y = self.net.predict(x)
+             # expected output: same npix, 1 channel at the very top
+             self.assertEqual(y.shape[0], self.batch)
+             self.assertEqual(y.shape[1], 1)
+             self.assertEqual(y.shape[2], self.npix)
+             # sanity: no NaNs
+             y_np = self.net.f.backend.to_numpy(y)
+             self.assertFalse(np.isnan(y_np).any())
+ 
+         def test_param_roundtrip_and_determinism(self):
+             x = np.random.randn(self.batch, self.channels, self.npix).astype(self.net.np_dtype)
+             x = torch.from_numpy(x)
+ 
+             # forward twice -> identical outputs with fixed params (eval mode)
+             y1 = self.net.predict(x)
+             y2 = self.net.predict(x)
+             y1_np = self.net.f.backend.to_numpy(y1)
+             y2_np = self.net.f.backend.to_numpy(y2)
+             np.testing.assert_allclose(y1_np, y2_np, rtol=0, atol=0)
+ 
+             # perturb a kernel parameter -> output should (very likely) change
+             with torch.no_grad():
+                 self.net.enc_w1[0][0, 0, 0] += 1.0
+             y3 = self.net.predict(x)
+             y3_np = self.net.f.backend.to_numpy(y3)
+             with self.assertRaises(AssertionError):
+                 np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)
+ 
+     unittest.main()
+ 
+ 
+ def fit(
+     model: HealpixUNet,
+     x_train: torch.Tensor | np.ndarray,
+     y_train: torch.Tensor | np.ndarray,
+     *,
+     n_epoch: int = 10,
+     view_epoch: int = 10,
+     batch_size: int = 16,
+     lr: float = 1e-3,
+     weight_decay: float = 0.0,
+     clip_grad_norm: float | None = None,
+     verbose: bool = True,
+     optimizer: Literal['ADAM', 'LBFGS'] = 'LBFGS',
+ ) -> dict:
+     """Training helper for the torch-ified HEALPix U-Net.
+ 
+     Optimizes all registered parameters (kernels + BN affine) with Adam or LBFGS,
+     on MSE for regression and CrossEntropy/BCE for segmentation.
+ 
+     Device policy
+     -------------
+     Uses the model's probed runtime device (CUDA if the Foscat conv works there;
+     otherwise CPU).
+     """
+     import numpy as _np
+     from torch.utils.data import TensorDataset, DataLoader
+ 
+     # Ensure the model is on its runtime device (already probed in __init__)
+     model.to(model.runtime_device)
+ 
+     def _to_t(t):
+         if isinstance(t, torch.Tensor):
+             return t.float().to(model.runtime_device)
+         return torch.from_numpy(_np.asarray(t)).float().to(model.runtime_device)
+ 
+     x_t = _to_t(x_train)
+     y_t = _to_t(y_train)
+ 
+     # choose the loss
+     if model.task == 'regression':
+         criterion = nn.MSELoss()
+     else:
+         if model.out_channels == 1:
+             criterion = nn.BCEWithLogitsLoss() if model.final_activation == 'none' else nn.BCELoss()
+         else:
+             if y_t.dim() == 3:  # one-hot (B, C, Npix) -> class indices
+                 y_t = y_t.argmax(dim=1)
+             y_t = y_t.long()    # CrossEntropyLoss expects integer class targets
+             criterion = nn.CrossEntropyLoss()
+ 
+     ds = TensorDataset(x_t, y_t)
+     loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
+ 
+     if optimizer == 'ADAM':
+         optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+         l_n_epoch = n_epoch
+         n_inter = 1
+     else:
+         optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20,
+                                   history_size=n_epoch * 5, line_search_fn="strong_wolfe")
+         l_n_epoch = max(1, n_epoch // 20)  # each LBFGS step runs up to 20 inner iterations
+         n_inter = 20
+ 
+     history: List[float] = []
+     model.train()
+ 
+     for epoch in range(l_n_epoch):
+         for k in range(n_inter):
+             epoch_loss = 0.0
+             n_samples = 0
+             for xb, yb in loader:
+                 # LBFGS needs a closure that recomputes the loss and the gradients
+                 if isinstance(optim, torch.optim.LBFGS):
+                     def closure():
+                         optim.zero_grad(set_to_none=True)
+                         preds = model(xb)
+                         loss = criterion(preds, yb)
+                         loss.backward()
+                         return loss
+ 
+                     _ = optim.step(closure)  # LBFGS calls the closure several times
+                     # recompute the final loss for aggregation (without gradients)
+                     with torch.no_grad():
+                         preds = model(xb)
+                         loss = criterion(preds, yb)
+                 else:
+                     optim.zero_grad(set_to_none=True)
+                     preds = model(xb)
+                     loss = criterion(preds, yb)
+                     loss.backward()
+                     if clip_grad_norm is not None:
+                         nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
+                     optim.step()
+ 
+                 bs = xb.shape[0]
+                 epoch_loss += loss.item() * bs
+                 n_samples += bs
+ 
+             epoch_loss /= max(1, n_samples)
+             history.append(epoch_loss)
+             if verbose and ((epoch * n_inter + k + 1) % view_epoch == 0 or epoch * n_inter + k == 0):
+                 print(f"[epoch {epoch * n_inter + k + 1}/{l_n_epoch * n_inter}] loss={epoch_loss:.6f}")
+ 
+     return {"loss": history}
+
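+ # Training sketch (hedged example; hypothetical shapes, `net` as constructed above):
+ #
+ #     x = torch.randn(32, 1, 12 * 16 * 16)   # (B, C_in, Npix)
+ #     y = torch.randn(32, 1, 12 * 16 * 16)   # regression targets
+ #     hist = fit(net, x, y, n_epoch=40, optimizer='ADAM', lr=1e-3)
+ #     # hist["loss"] holds the per-epoch training loss history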
+ 
+ __all__ = ["HealpixUNet", "fit"]