foscat-2025.8.4-py3-none-any.whl → foscat-2025.9.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +241 -49
- foscat/FoCUS.py +1 -1
- foscat/HOrientedConvol.py +446 -42
- foscat/HealBili.py +305 -0
- foscat/Plot.py +328 -0
- foscat/UNET.py +455 -178
- foscat/healpix_unet_torch.py +717 -0
- foscat/scat_cov.py +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/METADATA +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/RECORD +13 -10
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/WHEEL +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.1.dist-info}/top_level.txt +0 -0
foscat/healpix_unet_torch.py
@@ -0,0 +1,717 @@
"""
HEALPix U-Net (nested) with Foscat + PyTorch niceties
-----------------------------------------------------
GPU by default (when available), with a graceful CPU fallback if Foscat ops are CPU-only.

- ReLU + BatchNorm after each convolution (encoder & decoder)
- Segmentation/regression heads with an optional final activation
- PyTorch-ified: inherits from nn.Module, standard state_dict
- Device management: tries CUDA first; if Foscat HOrientedConvol cannot run on CUDA, falls back to CPU

Shape convention: (B, C, Npix)

Requirements: foscat (scat_cov.funct + HOrientedConvol.Convol_torch must be differentiable on torch tensors)
"""
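# Minimal usage sketch (illustrative only; assumes foscat is installed and the
# full nested sphere is used, so cell_ids is simply arange(12 * nside**2)):
#
#     nside = 8
#     npix = 12 * nside * nside
#     net = HealpixUNet(in_nside=nside, n_chan_in=1, chanlist=[8, 16],
#                       cell_ids=np.arange(npix))
#     y = net(torch.randn(2, 1, npix))    # -> (2, 1, npix)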
from __future__ import annotations
from typing import List, Optional, Literal, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import foscat.scat_cov as sc
import foscat.HOrientedConvol as ho
import matplotlib.pyplot as plt


class HealpixUNet(nn.Module):
    """U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.

    Parameters
    ----------
    in_nside : int
        Input HEALPix nside (nested scheme).
    n_chan_in : int
        Number of input channels.
    chanlist : list[int]
        Channels per encoder level (depth = len(chanlist)). Example: [16, 32, 64].
    cell_ids : np.ndarray
        Cell indices for the finest resolution (nside = in_nside) in nested scheme.
    KERNELSZ : int, default 3
        Spatial kernel size K (K x K) for oriented convolution.
    task : {'regression', 'segmentation'}, default 'regression'
        Chooses the head and default activation.
    out_channels : int, default 1
        Number of output channels (e.g. num_classes for segmentation).
    final_activation : {'none', 'sigmoid', 'softmax'} | None
        If None, uses a sensible default per task: 'none' for regression, 'softmax' for
        multi-class segmentation, 'sigmoid' for segmentation when out_channels == 1.
    device : str | torch.device | None, default: 'cuda' if available else 'cpu'
        Preferred device. The module probes whether Foscat ops can run on CUDA; if not,
        it falls back to CPU and keeps all parameters/buffers on CPU for consistency.
    prefer_foscat_gpu : bool, default True
        When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry run.
        If the dry run fails, everything falls back to CPU.
    dtype : {'float32', 'float64'}, default 'float32'
        Floating-point precision used when casting data for the Foscat operators.

    Notes
    -----
    - Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.
    - Downsampling uses the Foscat ``Down`` operator; upsampling uses ``Up``.
    - Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``HOrientedConvol.Convol_torch``.
    - The Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
    """

    def __init__(
        self,
        *,
        in_nside: int,
        n_chan_in: int,
        chanlist: List[int],
        cell_ids: np.ndarray,
        KERNELSZ: int = 3,
        task: Literal['regression', 'segmentation'] = 'regression',
        out_channels: int = 1,
        final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
        device: Optional[torch.device | str] = None,
        prefer_foscat_gpu: bool = True,
        dtype: Literal['float32', 'float64'] = 'float32',
    ) -> None:
        super().__init__()

        self.dtype = dtype
        if dtype == 'float32':
            self.np_dtype = np.float32
        else:
            self.np_dtype = np.float64

        if cell_ids is None:
            raise ValueError("cell_ids must be provided for the finest resolution.")
        if len(chanlist) == 0:
            raise ValueError("chanlist must be non-empty (depth >= 1).")

        self.in_nside = int(in_nside)
        self.n_chan_in = int(n_chan_in)
        self.chanlist = list(map(int, chanlist))
        self.KERNELSZ = int(KERNELSZ)
        self.task = task
        self.out_channels = int(out_channels)
        self.prefer_foscat_gpu = bool(prefer_foscat_gpu)

        # Choose default final activation if not given
        if final_activation is None:
            if task == 'regression':
                self.final_activation = 'none'
            else:  # segmentation
                self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
        else:
            self.final_activation = final_activation

        # Resolve preferred device
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # Foscat functional wrapper (backend + grade ops)
        self.f = sc.funct(KERNELSZ=self.KERNELSZ)

        # ---------- Build multi-resolution bookkeeping ----------
        depth = len(self.chanlist)
        self.l_cell_ids: List[np.ndarray] = [None] * (depth + 1)  # per encoder level + bottom
        self.l_cell_ids[0] = np.asarray(cell_ids)

        enc_nsides: List[int] = [self.in_nside]
        current_nside = self.in_nside

        # ---------- Oriented convolutions per level (encoder & decoder) ----------
        self.hconv_enc: List[ho.HOrientedConvol] = []
        self.hconv_dec: List[ho.HOrientedConvol] = []

        # dummy data to propagate shapes/ids through the Down operator
        l_data = self.f.backend.bk_cast(np.zeros((1, 1, cell_ids.shape[0]), dtype=self.np_dtype))

        for l in range(depth):
            # operator at encoder level l
            hc = ho.HOrientedConvol(current_nside,
                                    self.KERNELSZ,
                                    cell_ids=self.l_cell_ids[l],
                                    dtype=self.dtype)
            # hc.make_idx_weights()

            self.hconv_enc.append(hc)

            # downsample once to get next level ids and new data shape
            l_data, next_ids = hc.Down(
                l_data, cell_ids=self.l_cell_ids[l], nside=current_nside
            )
            self.l_cell_ids[l + 1] = self.f.backend.to_numpy(next_ids)
            current_nside //= 2
            enc_nsides.append(current_nside)

        # encoder conv weights and BN
        self.enc_w1 = nn.ParameterList()
        self.enc_bn1 = nn.ModuleList()
        self.enc_w2 = nn.ParameterList()
        self.enc_bn2 = nn.ModuleList()

        self.enc_nsides = enc_nsides  # [in, in/2, ..., in/2**depth]

        inC = self.n_chan_in
        for l, outC in enumerate(self.chanlist):
            # conv1: inC -> outC
            w1 = torch.empty(inC, outC, self.KERNELSZ * self.KERNELSZ)
            nn.init.kaiming_uniform_(w1.view(inC * outC, -1), a=np.sqrt(5))
            self.enc_w1.append(nn.Parameter(w1))
            self.enc_bn1.append(nn.BatchNorm1d(outC))

            # conv2: outC -> outC
            w2 = torch.empty(outC, outC, self.KERNELSZ * self.KERNELSZ)
            nn.init.kaiming_uniform_(w2.view(outC * outC, -1), a=np.sqrt(5))
            self.enc_w2.append(nn.Parameter(w2))
            self.enc_bn2.append(nn.BatchNorm1d(outC))

            inC = outC  # next level input channels

        # decoder conv weights and BN (mirrored levels)
        self.dec_w1 = nn.ParameterList()
        self.dec_bn1 = nn.ModuleList()
        self.dec_w2 = nn.ParameterList()
        self.dec_bn2 = nn.ModuleList()

        for d in range(depth):
            level = depth - 1 - d  # encoder level we are going back to
            hc = ho.HOrientedConvol(self.enc_nsides[level],
                                    self.KERNELSZ,
                                    cell_ids=self.l_cell_ids[level],
                                    dtype=self.dtype)
            # hc.make_idx_weights()
            self.hconv_dec.append(hc)

            upC = self.chanlist[level + 1] if level + 1 < depth else self.chanlist[level]
            skipC = self.chanlist[level]
            inC_dec = upC + skipC
            outC_dec = skipC

            w1 = torch.empty(inC_dec, outC_dec, self.KERNELSZ * self.KERNELSZ)
            nn.init.kaiming_uniform_(w1.view(inC_dec * outC_dec, -1), a=np.sqrt(5))
            self.dec_w1.append(nn.Parameter(w1))
            self.dec_bn1.append(nn.BatchNorm1d(outC_dec))

            w2 = torch.empty(outC_dec, outC_dec, self.KERNELSZ * self.KERNELSZ)
            nn.init.kaiming_uniform_(w2.view(outC_dec * outC_dec, -1), a=np.sqrt(5))
            self.dec_w2.append(nn.Parameter(w2))
            self.dec_bn2.append(nn.BatchNorm1d(outC_dec))

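        # Worked example of the channel bookkeeping above (not executed):
        # with chanlist=[16, 32, 64], decoder step d=0 (level 2, bottom) gets
        # in = 64 + 64 = 128 -> out = 64; d=1 (level 1) gets in = 64 + 32 = 96
        # -> out = 32; d=2 (level 0) gets in = 32 + 16 = 48 -> out = 16.
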
        # Output head (on finest grid, channels = chanlist[0])
        self.head_hconv = ho.HOrientedConvol(self.in_nside,
                                             self.KERNELSZ,
                                             cell_ids=self.l_cell_ids[0],
                                             dtype=self.dtype)

        head_inC = self.chanlist[0]
        self.head_w = nn.Parameter(torch.empty(head_inC, self.out_channels, self.KERNELSZ * self.KERNELSZ))
        nn.init.kaiming_uniform_(self.head_w.view(head_inC * self.out_channels, -1), a=np.sqrt(5))
        self.head_bn = nn.BatchNorm1d(self.out_channels) if self.task == 'segmentation' else None

        # ---- Decide runtime device (probe Foscat on CUDA, else CPU) ----
        self.runtime_device = self._probe_and_set_runtime_device(self.device)

    # -------------------------- device plumbing --------------------------
    def _move_hconv_tensors(self, hc: ho.HOrientedConvol, device: torch.device) -> None:
        """Best-effort: move any torch.Tensor attribute of HOrientedConvol to `device`."""
        for name, val in list(vars(hc).items()):
            try:
                if torch.is_tensor(val):
                    setattr(hc, name, val.to(device))
                elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
                    setattr(hc, name, type(val)([v.to(device) for v in val]))
            except Exception:
                # silently ignore non-tensor or protected attributes
                pass

    @torch.no_grad()
    def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
        """Try a tiny Foscat conv on the preferred device; fall back to CPU if it fails."""
        if preferred.type == 'cuda' and self.prefer_foscat_gpu:
            try:
                # move module params/buffers first
                super().to(preferred)
                # move Foscat operator internals
                for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
                    self._move_hconv_tensors(hc, preferred)
                # dry run on level 0
                npix0 = int(len(self.l_cell_ids[0]))
                x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
                _ = self.hconv_enc[0].Convol_torch(x_try, self.enc_w1[0])
                # success -> stay on CUDA
                self._foscat_device = preferred
                return preferred
            except Exception as e:
                # fall back to CPU; keep the error for inspection
                self._gpu_probe_error = repr(e)
        # CPU fallback
        cpu = torch.device('cpu')
        super().to(cpu)
        for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
            self._move_hconv_tensors(hc, cpu)
        self._foscat_device = cpu
        return cpu

    def set_device(self, device: torch.device | str) -> torch.device:
        """Request a (re)device; probes Foscat and returns the actual runtime device used."""
        device = torch.device(device)
        self.device = device
        self.runtime_device = self._probe_and_set_runtime_device(device)
        return self.runtime_device

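    # Device sketch (illustrative): request CUDA after construction and check
    # what the probe actually selected:
    #
    #     dev = net.set_device("cuda")
    #     print(dev)                                       # cuda, or cpu if the probe failed
    #     print(getattr(net, "_gpu_probe_error", None))    # reason for a CPU fallback, if any
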
    # -------------------------- forward --------------------------
    def forward(self, x: torch.Tensor, cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor, shape (B, C_in, Npix)
            Input tensor on the `in_nside` grid.
        cell_ids : np.ndarray, shape (B, Npix), optional
            Alternative cell ids to use instead of the ones given at
            construction. If None, the initial `cell_ids` are used.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        if x.dim() != 3:
            raise ValueError("Input must be (B, C, Npix)")

        # Ensure input lives on the runtime (probed) device
        x = x.to(self.runtime_device)

        B, C, N = x.shape
        if C != self.n_chan_in:
            raise ValueError(f"Expected {self.n_chan_in} input channels, got {C}")

        # Encoder
        skips: List[torch.Tensor] = []
        l_data = x
        current_nside = self.in_nside
        l_cell_ids = cell_ids

        if cell_ids is not None:
            t_cell_ids = {}
            t_cell_ids[0] = l_cell_ids
        else:
            t_cell_ids = self.l_cell_ids

        for l, outC in enumerate(self.chanlist):
            # conv1 + BN + ReLU
            l_data = self.hconv_enc[l].Convol_torch(l_data,
                                                    self.enc_w1[l],
                                                    cell_ids=l_cell_ids)
            l_data = self.enc_bn1[l](l_data)
            l_data = F.relu(l_data, inplace=True)

            # conv2 + BN + ReLU
            l_data = self.hconv_enc[l].Convol_torch(l_data,
                                                    self.enc_w2[l],
                                                    cell_ids=l_cell_ids)
            l_data = self.enc_bn2[l](l_data)
            l_data = F.relu(l_data, inplace=True)

            # save skip at this resolution
            skips.append(l_data)

            # downsample (except bottom level) -> ensure output is on runtime_device
            if l < len(self.chanlist) - 1:
                l_data, l_cell_ids = self.hconv_enc[l].Down(
                    l_data, cell_ids=t_cell_ids[l], nside=current_nside
                )
                if cell_ids is not None:
                    t_cell_ids[l + 1] = l_cell_ids
                else:
                    l_cell_ids = None

                if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
                    l_data = l_data.to(self.runtime_device)
                current_nside //= 2

        # Decoder
        for d in range(len(self.chanlist)):
            level = len(self.chanlist) - 1 - d  # corresponding encoder level

            if level < len(self.chanlist) - 1:
                # upsample to next finer grid (from level+1 -> level)
                src_nside = self.enc_nsides[level + 1]        # current (coarser)
                # Foscat Up expects the current (coarse) ids in `cell_ids` and the
                # target (fine) ids in `o_cell_ids` (matching the original UNET code).
                l_data = self.hconv_enc[level + 1].Up(        # operator on the coarser (source) grid
                    l_data,
                    cell_ids=t_cell_ids[level + 1],           # source (coarser) ids
                    o_cell_ids=t_cell_ids[level],             # target (finer) ids
                    nside=src_nside,
                )

                if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
                    l_data = l_data.to(self.runtime_device)

            # concat with skip features at this resolution
            concat = self.f.backend.bk_concat([skips[level], l_data], 1)
            l_data = concat.to(self.runtime_device) if torch.is_tensor(concat) else concat

            if cell_ids is not None:
                l_cell_ids = t_cell_ids[level]

            # apply decoder convs on this grid
            hc = self.hconv_dec[d]
            l_data = hc.Convol_torch(l_data,
                                     self.dec_w1[d],
                                     cell_ids=l_cell_ids)
            l_data = self.dec_bn1[d](l_data)
            l_data = F.relu(l_data, inplace=True)

            l_data = hc.Convol_torch(l_data,
                                     self.dec_w2[d],
                                     cell_ids=l_cell_ids)
            l_data = self.dec_bn2[d](l_data)
            l_data = F.relu(l_data, inplace=True)

        # Head on finest grid
        out = self.head_hconv.Convol_torch(l_data, self.head_w, cell_ids=l_cell_ids)
        if self.head_bn is not None:
            out = self.head_bn(out)
        if self.final_activation == 'sigmoid':
            out = torch.sigmoid(out)
        elif self.final_activation == 'softmax':
            out = torch.softmax(out, dim=1)
        return out

    # -------------------------- utilities --------------------------
    @torch.no_grad()
    def predict(self, x: torch.Tensor, batch_size: int = 8, cell_ids: Optional[np.ndarray] = None) -> torch.Tensor:
        self.eval()
        outs = []
        for i in range(0, x.shape[0], batch_size):
            if cell_ids is not None:
                outs.append(self.forward(x[i : i + batch_size],
                                         cell_ids=cell_ids[i : i + batch_size]))
            else:
                outs.append(self.forward(x[i : i + batch_size]))

        return torch.cat(outs, dim=0)

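    # Batched inference sketch (illustrative; assumes a full-sky nested grid):
    #
    #     x = torch.randn(32, net.n_chan_in, 12 * net.in_nside**2)
    #     y = net.predict(x, batch_size=8)    # -> (32, out_channels, Npix)
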
    # -----------------------------
    # Kernel extraction & plotting
    # -----------------------------
    def _arch_shapes(self):
        """Return expected (in_c, out_c) per conv for encoder/decoder.

        Returns
        -------
        enc_shapes : list[tuple[tuple[int,int], tuple[int,int]]]
            For each level `l`, ((in1, out1), (in2, out2)) for the two encoder convs.
        dec_shapes : list[tuple[tuple[int,int], tuple[int,int]]]
            For each decoder step `d`, ((in1, out1), (in2, out2)) for the two decoder convs.
        """
        nlayer = len(self.chanlist)
        enc_shapes = []
        l_chan = self.n_chan_in
        for l in range(nlayer):
            enc_shapes.append(((l_chan, self.chanlist[l]), (self.chanlist[l], self.chanlist[l])))
            l_chan = self.chanlist[l]

        dec_shapes = []
        for d in range(nlayer):
            level = nlayer - 1 - d
            upC = self.chanlist[level + 1] if level + 1 < nlayer else self.chanlist[level]
            in1 = upC + self.chanlist[level]
            out1 = self.chanlist[level]
            dec_shapes.append(((in1, out1), (out1, out1)))
        return enc_shapes, dec_shapes

    def extract_kernels(self, stage: str = "encoder", layer: int = 0, conv: int = 0):
        """Extract raw convolution kernels for a given stage/level/conv.

        Parameters
        ----------
        stage : {"encoder", "decoder"}
            Which part of the network to inspect.
        layer : int
            Pyramid level (0 = finest encoder level / bottommost decoder level).
        conv : int
            0 for the first conv at that level, 1 for the second conv.

        Returns
        -------
        np.ndarray
            Array of shape (in_c, out_c, K, K) containing the spatial kernels.
        """
        assert stage in {"encoder", "decoder"}
        assert conv in {0, 1}
        K = self.KERNELSZ

        if stage == "encoder":
            w = self.enc_w1[layer] if conv == 0 else self.enc_w2[layer]
        else:
            w = self.dec_w1[layer] if conv == 0 else self.dec_w2[layer]

        w_np = self.f.backend.to_numpy(w.detach())
        return w_np.reshape(w.shape[0], w.shape[1], K, K)

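    # Inspection sketch (illustrative): first-level encoder kernels come back
    # as (in_c, out_c, K, K):
    #
    #     W = net.extract_kernels(stage="encoder", layer=0, conv=0)
    #     assert W.shape == (net.n_chan_in, net.chanlist[0], net.KERNELSZ, net.KERNELSZ)
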
    def plot_kernels(
        self,
        stage: str = "encoder",
        layer: int = 0,
        conv: int = 0,
        fixed: str = "in",
        index: int = 0,
        max_tiles: int = 16,
    ):
        """Quick visualization of kernels on a grid using matplotlib.

        Parameters
        ----------
        stage : {"encoder", "decoder"}
            Which tower to visualize.
        layer : int
            Level to visualize.
        conv : int
            0 or 1: first or second conv in the level.
        fixed : {"in", "out"}
            If "in", show kernels for a fixed input channel across many outputs.
            If "out", show kernels for a fixed output channel across many inputs.
        index : int
            Channel index to fix (according to `fixed`).
        max_tiles : int
            Maximum number of tiles to display.
        """
        import math

        W = self.extract_kernels(stage=stage, layer=layer, conv=conv)
        ic, oc, K, _ = W.shape

        if fixed == "in":
            idx = min(index, ic - 1)
            tiles = [W[idx, j] for j in range(oc)]
            title = f"{stage} L{layer} C{conv} | in={idx}"
        else:
            idx = min(index, oc - 1)
            tiles = [W[i, idx] for i in range(ic)]
            title = f"{stage} L{layer} C{conv} | out={idx}"

        tiles = tiles[:max_tiles]
        n = len(tiles)
        cols = int(math.ceil(math.sqrt(n)))
        rows = int(math.ceil(n / cols))

        plt.figure(figsize=(2.5 * cols, 2.5 * rows))
        for i, ker in enumerate(tiles, 1):
            ax = plt.subplot(rows, cols, i)
            ax.imshow(ker)
            ax.set_xticks([])
            ax.set_yticks([])
        plt.suptitle(title)
        plt.tight_layout()
        plt.show()

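    # Plotting sketch (illustrative): tile the K x K kernels feeding output
    # channel 0 of the second conv of the bottommost decoder block:
    #
    #     net.plot_kernels(stage="decoder", layer=0, conv=1, fixed="out", index=0)
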
# -----------------------------
# Unit tests (smoke tests)
# -----------------------------
# Run with: python healpix_unet_torch.py (or python healpix_unet_torch.py -q for quieter output)
# These tests assume Foscat and its dependencies are installed.


def _dummy_cell_ids(nside: int) -> np.ndarray:
    """Return a simple identity mapping for HEALPix nested pixel IDs.

    Notes
    -----
    Replace with your pipeline's real `cell_ids` if you have a precomputed
    mapping consistent with Foscat/HEALPix nested ordering.
    """
    return np.arange(12 * nside * nside, dtype=np.int64)


if __name__ == "__main__":
    import unittest

    class TestUNET(unittest.TestCase):
        """Lightweight smoke tests for shape and parameter plumbing."""

        def setUp(self):
            self.nside = 4            # small grid for fast tests (npix = 192)
            self.chanlist = [4, 8]    # two-level encoder/decoder
            self.batch = 2
            self.channels = 1
            self.npix = 12 * self.nside * self.nside
            self.cell_ids = _dummy_cell_ids(self.nside)
            self.net = HealpixUNet(
                in_nside=self.nside,
                n_chan_in=self.channels,
                chanlist=self.chanlist,
                cell_ids=self.cell_ids,
            )

        def test_forward_shape(self):
            # random input
            x = np.random.randn(self.batch, self.channels, self.npix).astype(self.net.np_dtype)
            x = self.net.f.backend.bk_cast(x)
            y = self.net.predict(x)
            # expected output: same npix, 1 channel at the very top
            self.assertEqual(y.shape[0], self.batch)
            self.assertEqual(y.shape[1], 1)
            self.assertEqual(y.shape[2], self.npix)
            # sanity: no NaNs
            y_np = self.net.f.backend.to_numpy(y)
            self.assertFalse(np.isnan(y_np).any())

        def test_param_roundtrip_and_determinism(self):
            x = np.random.randn(self.batch, self.channels, self.npix).astype(self.net.np_dtype)
            x = self.net.f.backend.bk_cast(x)

            # forward twice in eval mode -> identical outputs with fixed params
            y1 = self.net.predict(x)
            y2 = self.net.predict(x)
            y1_np = self.net.f.backend.to_numpy(y1)
            y2_np = self.net.f.backend.to_numpy(y2)
            np.testing.assert_allclose(y1_np, y2_np, rtol=0, atol=0)

            # perturb a parameter -> output should (very likely) change
            with torch.no_grad():
                next(self.net.parameters()).add_(1.0)
            y3 = self.net.predict(x)
            y3_np = self.net.f.backend.to_numpy(y3)
            with self.assertRaises(AssertionError):
                np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)

    unittest.main()

def fit(
    model: HealpixUNet,
    x_train: torch.Tensor | np.ndarray,
    y_train: torch.Tensor | np.ndarray,
    *,
    n_epoch: int = 10,
    view_epoch: int = 10,
    batch_size: int = 16,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    clip_grad_norm: float | None = None,
    verbose: bool = True,
    optimizer: Literal['ADAM', 'LBFGS'] = 'LBFGS',
) -> dict:
    """Train helper for the torch-ified HEALPix U-Net.

    Optimizes all registered parameters (kernels + BN affine) with Adam or LBFGS,
    on MSE for regression and CrossEntropy/BCE for segmentation.

    Device policy
    -------------
    Uses the model's probed runtime device (CUDA if the Foscat conv works there; otherwise CPU).
    """
    import numpy as _np
    from torch.utils.data import TensorDataset, DataLoader

    # Ensure the model is on its runtime device (already probed in __init__)
    model.to(model.runtime_device)

    def _to_t(x):
        if isinstance(x, torch.Tensor):
            return x.float().to(model.runtime_device)
        return torch.from_numpy(_np.asarray(x)).float().to(model.runtime_device)

    x_t = _to_t(x_train)
    y_t = _to_t(y_train)

    # choose loss
    if model.task == 'regression':
        criterion = nn.MSELoss()
    else:
        if model.out_channels == 1:
            criterion = nn.BCEWithLogitsLoss() if model.final_activation == 'none' else nn.BCELoss()
        else:
            if y_t.dim() == 3:  # one-hot (B, C, Npix) -> class indices (B, Npix)
                y_t = y_t.argmax(dim=1)
            criterion = nn.CrossEntropyLoss()

    ds = TensorDataset(x_t, y_t)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)

    if optimizer == 'ADAM':
        optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        l_n_epoch = n_epoch
        n_inter = 1
    else:
        optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20,
                                  history_size=n_epoch * 5, line_search_fn="strong_wolfe")
        l_n_epoch = max(1, n_epoch // 20)
        n_inter = 20

    history: List[float] = []
    model.train()

    for epoch in range(l_n_epoch):
        for k in range(n_inter):
            epoch_loss = 0.0
            n_samples = 0
            for xb, yb in loader:
                # LBFGS needs a closure that recomputes the loss and gradients
                if isinstance(optim, torch.optim.LBFGS):
                    def closure():
                        optim.zero_grad(set_to_none=True)
                        preds = model(xb)
                        loss = criterion(preds, yb)
                        loss.backward()
                        return loss

                    _ = optim.step(closure)  # LBFGS may call the closure several times
                    # recompute the final loss for aggregation (without gradients)
                    with torch.no_grad():
                        preds = model(xb)
                        loss = criterion(preds, yb)

                else:
                    optim.zero_grad(set_to_none=True)
                    preds = model(xb)
                    loss = criterion(preds, yb)
                    loss.backward()
                    if clip_grad_norm is not None:
                        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                    optim.step()

                bs = xb.shape[0]
                epoch_loss += loss.item() * bs
                n_samples += bs

            epoch_loss /= max(1, n_samples)
            history.append(epoch_loss)
            if verbose and ((epoch * n_inter + k + 1) % view_epoch == 0 or epoch * n_inter + k == 0):
                print(f"[epoch {epoch * n_inter + k + 1}/{l_n_epoch * n_inter}] loss={epoch_loss:.6f}")

    return {"loss": history}

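# End-to-end training sketch (illustrative; a small full-sky regression run):
#
#     nside = 8
#     npix = 12 * nside * nside
#     net = HealpixUNet(in_nside=nside, n_chan_in=1, chanlist=[8, 16],
#                       cell_ids=np.arange(npix))
#     x = np.random.randn(16, 1, npix).astype(np.float32)
#     y = np.random.randn(16, 1, npix).astype(np.float32)
#     hist = fit(net, x, y, n_epoch=40, batch_size=4, optimizer='ADAM', lr=1e-3)
#     print(hist["loss"][-1])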

__all__ = ["HealpixUNet", "fit"]