foscat 2025.9.1__py3-none-any.whl → 2025.9.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +160 -93
- foscat/FoCUS.py +80 -267
- foscat/HOrientedConvol.py +233 -250
- foscat/HealBili.py +12 -8
- foscat/Plot.py +1112 -142
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +21 -7
- foscat/healpix_unet_torch.py +656 -171
- foscat/scat_cov.py +2 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/METADATA +1 -1
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/RECORD +14 -13
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/WHEEL +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.4.dist-info}/top_level.txt +0 -0
foscat/healpix_unet_torch.py
CHANGED
|
@@ -6,11 +6,11 @@ GPU by default (when available), with graceful CPU fallback if Foscat ops are CP
|
|
|
6
6
|
- ReLU + BatchNorm after each convolution (encoder & decoder)
|
|
7
7
|
- Segmentation/Regression heads with optional final activation
|
|
8
8
|
- PyTorch-ified: inherits from nn.Module, standard state_dict
|
|
9
|
-
- Device management: tries CUDA first; if Foscat
|
|
9
|
+
- Device management: tries CUDA first; if Foscat SphericalStencil cannot run on CUDA, falls back to CPU
|
|
10
10
|
|
|
11
11
|
Shape convention: (B, C, Npix)
|
|
12
12
|
|
|
13
|
-
Requirements: foscat (scat_cov.funct +
|
|
13
|
+
Requirements: foscat (scat_cov.funct + SphericalStencil.Convol_torch must be differentiable on torch tensors)
|
|
14
14
|
"""
|
|
15
15
|
from __future__ import annotations
|
|
16
16
|
from typing import List, Optional, Literal, Tuple
|
|
@@ -18,13 +18,15 @@ import numpy as np
|
|
|
18
18
|
|
|
19
19
|
import torch
|
|
20
20
|
import torch.nn as nn
|
|
21
|
+
import healpy as hp
|
|
22
|
+
|
|
21
23
|
import torch.nn.functional as F
|
|
24
|
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
|
22
25
|
|
|
23
26
|
import foscat.scat_cov as sc
|
|
24
|
-
import foscat.
|
|
27
|
+
import foscat.SphericalStencil as ho
|
|
25
28
|
import matplotlib.pyplot as plt
|
|
26
29
|
|
|
27
|
-
|
|
28
30
|
class HealpixUNet(nn.Module):
|
|
29
31
|
"""U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.
|
|
30
32
|
|
|
@@ -40,6 +42,13 @@ class HealpixUNet(nn.Module):
|
|
|
40
42
|
Cell indices for the finest resolution (nside = in_nside) in nested scheme.
|
|
41
43
|
KERNELSZ : int, default 3
|
|
42
44
|
Spatial kernel size K (K x K) for oriented convolution.
|
|
45
|
+
gauge_type : str
|
|
46
|
+
Type of gauge :
|
|
47
|
+
'cosmo' use the same definition than
|
|
48
|
+
https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html
|
|
49
|
+
'phi' is define at the pole, could be better for earth observation not using intensivly the pole
|
|
50
|
+
G : int, default 1
|
|
51
|
+
Number of gauges for the orientation definition.
|
|
43
52
|
task : {'regression','segmentation'}, default 'regression'
|
|
44
53
|
Chooses the head and default activation.
|
|
45
54
|
out_channels : int, default 1
|
|
@@ -50,6 +59,8 @@ class HealpixUNet(nn.Module):
|
|
|
50
59
|
device : str | torch.device | None, default: 'cuda' if available else 'cpu'
|
|
51
60
|
Preferred device. The module will probe whether Foscat ops can run on CUDA; if not,
|
|
52
61
|
it will fallback to CPU and keep all parameters/buffers on CPU for consistency.
|
|
62
|
+
down_type:
|
|
63
|
+
{"mean","max"}, default "max". Equivalent of max poll during down
|
|
53
64
|
prefer_foscat_gpu : bool, default True
|
|
54
65
|
When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry-run.
|
|
55
66
|
If the dry-run fails, everything falls back to CPU.
|
|
@@ -58,7 +69,7 @@ class HealpixUNet(nn.Module):
|
|
|
58
69
|
-----
|
|
59
70
|
- Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.
|
|
60
71
|
- Downsampling uses foscat ``ud_grade_2``; upsampling uses ``up_grade``.
|
|
61
|
-
- Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``
|
|
72
|
+
- Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``SphericalStencil.Convol_torch``.
|
|
62
73
|
- Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
|
|
63
74
|
"""
|
|
64
75
|
|
|
@@ -75,16 +86,28 @@ class HealpixUNet(nn.Module):
|
|
|
75
86
|
final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
|
|
76
87
|
device: Optional[torch.device | str] = None,
|
|
77
88
|
prefer_foscat_gpu: bool = True,
|
|
78
|
-
|
|
89
|
+
gauge_type: Optional[Literal['cosmo','phi']] = 'cosmo',
|
|
90
|
+
G: int =1,
|
|
91
|
+
down_type: Optional[Literal['mean','max']] = 'max',
|
|
92
|
+
dtype: Literal['float32','float64'] = 'float32',
|
|
93
|
+
head_reduce: Literal['mean','learned']='mean'
|
|
79
94
|
) -> None:
|
|
80
95
|
super().__init__()
|
|
81
|
-
|
|
96
|
+
|
|
82
97
|
self.dtype=dtype
|
|
83
98
|
if dtype=='float32':
|
|
84
99
|
self.np_dtype=np.float32
|
|
100
|
+
self.torch_dtype=torch.float32
|
|
85
101
|
else:
|
|
86
102
|
self.np_dtype=np.float64
|
|
87
|
-
|
|
103
|
+
self.torch_dtype=torch.float32
|
|
104
|
+
|
|
105
|
+
self.gauge_type=gauge_type
|
|
106
|
+
self.G = int(G)
|
|
107
|
+
|
|
108
|
+
if self.G < 1:
|
|
109
|
+
raise ValueError("G must be >= 1")
|
|
110
|
+
|
|
88
111
|
if cell_ids is None:
|
|
89
112
|
raise ValueError("cell_ids must be provided for the finest resolution.")
|
|
90
113
|
if len(chanlist) == 0:
|
|
@@ -93,11 +116,16 @@ class HealpixUNet(nn.Module):
|
|
|
93
116
|
self.in_nside = int(in_nside)
|
|
94
117
|
self.n_chan_in = int(n_chan_in)
|
|
95
118
|
self.chanlist = list(map(int, chanlist))
|
|
119
|
+
self.chanlist = [self.chanlist[k]*self.G for k in range(len(self.chanlist))]
|
|
96
120
|
self.KERNELSZ = int(KERNELSZ)
|
|
97
121
|
self.task = task
|
|
98
|
-
self.out_channels = int(out_channels)
|
|
122
|
+
self.out_channels = int(out_channels)*self.G
|
|
99
123
|
self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
|
|
100
|
-
|
|
124
|
+
if down_type == 'max':
|
|
125
|
+
self.max_poll = True
|
|
126
|
+
else:
|
|
127
|
+
self.max_poll = False
|
|
128
|
+
|
|
101
129
|
# Choose default final activation if not given
|
|
102
130
|
if final_activation is None:
|
|
103
131
|
if task == 'regression':
|
|
@@ -124,25 +152,26 @@ class HealpixUNet(nn.Module):
|
|
|
124
152
|
current_nside = self.in_nside
|
|
125
153
|
|
|
126
154
|
# ---------- Oriented convolutions per level (encoder & decoder) ----------
|
|
127
|
-
self.hconv_enc: List[ho.
|
|
128
|
-
self.hconv_dec: List[ho.
|
|
155
|
+
self.hconv_enc: List[ho.SphericalStencil] = []
|
|
156
|
+
self.hconv_dec: List[ho.SphericalStencil] = []
|
|
129
157
|
|
|
130
158
|
# dummy data to propagate shapes/ids through ud_grade_2
|
|
131
159
|
l_data = self.f.backend.bk_cast(np.zeros((1, 1, cell_ids.shape[0]), dtype=self.np_dtype))
|
|
132
160
|
|
|
133
161
|
for l in range(depth):
|
|
134
162
|
# operator at encoder level l
|
|
135
|
-
hc = ho.
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
163
|
+
hc = ho.SphericalStencil(current_nside,
|
|
164
|
+
self.KERNELSZ,
|
|
165
|
+
n_gauges = self.G,
|
|
166
|
+
gauge_type=self.gauge_type,
|
|
167
|
+
cell_ids=self.l_cell_ids[l],
|
|
168
|
+
dtype=self.torch_dtype)
|
|
140
169
|
|
|
141
170
|
self.hconv_enc.append(hc)
|
|
142
171
|
|
|
143
172
|
# downsample once to get next level ids and new data shape
|
|
144
173
|
l_data, next_ids = hc.Down(
|
|
145
|
-
l_data, cell_ids=self.l_cell_ids[l], nside=current_nside
|
|
174
|
+
l_data, cell_ids=self.l_cell_ids[l], nside=current_nside,max_poll=self.max_poll
|
|
146
175
|
)
|
|
147
176
|
self.l_cell_ids[l + 1] = self.f.backend.to_numpy(next_ids)
|
|
148
177
|
current_nside //= 2
|
|
@@ -158,20 +187,23 @@ class HealpixUNet(nn.Module):
|
|
|
158
187
|
|
|
159
188
|
inC = self.n_chan_in
|
|
160
189
|
for l, outC in enumerate(self.chanlist):
|
|
190
|
+
if outC % self.G != 0:
|
|
191
|
+
raise ValueError(f"chanlist[{l}] = {outC} must be divisible by G={self.G}")
|
|
192
|
+
outC_g = outC // self.G
|
|
161
193
|
|
|
162
|
-
# conv1: inC -> outC
|
|
163
|
-
w1 = torch.empty(inC,
|
|
164
|
-
nn.init.kaiming_uniform_(w1.view(inC *
|
|
194
|
+
# conv1: inC -> outC (via multi-gauge => noyau (Ci, Co_g, P))
|
|
195
|
+
w1 = torch.empty(inC, outC_g, self.KERNELSZ * self.KERNELSZ)
|
|
196
|
+
nn.init.kaiming_uniform_(w1.view(inC * outC_g, -1), a=np.sqrt(5))
|
|
165
197
|
self.enc_w1.append(nn.Parameter(w1))
|
|
166
|
-
self.enc_bn1.append(
|
|
198
|
+
self.enc_bn1.append(self._norm_1d(outC, kind="group"))
|
|
167
199
|
|
|
168
|
-
# conv2: outC -> outC
|
|
169
|
-
w2 = torch.empty(outC,
|
|
170
|
-
nn.init.kaiming_uniform_(w2.view(outC *
|
|
200
|
+
# conv2: outC -> outC (entrée = total outC ; noyau (outC, outC_g, P))
|
|
201
|
+
w2 = torch.empty(outC, outC_g, self.KERNELSZ * self.KERNELSZ)
|
|
202
|
+
nn.init.kaiming_uniform_(w2.view(outC * outC_g, -1), a=np.sqrt(5))
|
|
171
203
|
self.enc_w2.append(nn.Parameter(w2))
|
|
172
|
-
self.enc_bn2.append(
|
|
204
|
+
self.enc_bn2.append(self._norm_1d(outC, kind="group"))
|
|
173
205
|
|
|
174
|
-
inC = outC # next
|
|
206
|
+
inC = outC # next layer sees total channels
|
|
175
207
|
|
|
176
208
|
# decoder conv weights and BN (mirrored levels)
|
|
177
209
|
self.dec_w1 = nn.ParameterList()
|
|
@@ -181,45 +213,88 @@ class HealpixUNet(nn.Module):
|
|
|
181
213
|
|
|
182
214
|
for d in range(depth):
|
|
183
215
|
level = depth - 1 - d # encoder level we are going back to
|
|
184
|
-
hc = ho.
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
216
|
+
hc = ho.SphericalStencil(self.enc_nsides[level],
|
|
217
|
+
self.KERNELSZ,
|
|
218
|
+
n_gauges = self.G,
|
|
219
|
+
gauge_type=self.gauge_type,
|
|
220
|
+
cell_ids=self.l_cell_ids[level],
|
|
221
|
+
dtype=self.torch_dtype)
|
|
188
222
|
#hc.make_idx_weights()
|
|
189
223
|
self.hconv_dec.append(hc)
|
|
190
224
|
|
|
191
225
|
upC = self.chanlist[level + 1] if level + 1 < depth else self.chanlist[level]
|
|
192
226
|
skipC = self.chanlist[level]
|
|
193
|
-
inC_dec
|
|
194
|
-
outC_dec = skipC
|
|
227
|
+
inC_dec = upC + skipC # total en entrée
|
|
228
|
+
outC_dec = skipC # total en sortie (ce que tu avais déjà)
|
|
229
|
+
|
|
230
|
+
if outC_dec % self.G != 0:
|
|
231
|
+
raise ValueError(f"decoder outC at level {level} = {outC_dec} must be divisible by G={self.G}")
|
|
232
|
+
outC_dec_g = outC_dec // self.G
|
|
195
233
|
|
|
196
|
-
w1 = torch.empty(inC_dec,
|
|
197
|
-
nn.init.kaiming_uniform_(w1.view(inC_dec *
|
|
234
|
+
w1 = torch.empty(inC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
|
|
235
|
+
nn.init.kaiming_uniform_(w1.view(inC_dec * outC_dec_g, -1), a=np.sqrt(5))
|
|
198
236
|
self.dec_w1.append(nn.Parameter(w1))
|
|
199
|
-
self.dec_bn1.append(
|
|
237
|
+
self.dec_bn1.append(self._norm_1d(outC_dec, kind="group"))
|
|
200
238
|
|
|
201
|
-
w2 = torch.empty(outC_dec,
|
|
202
|
-
nn.init.kaiming_uniform_(w2.view(outC_dec *
|
|
239
|
+
w2 = torch.empty(outC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
|
|
240
|
+
nn.init.kaiming_uniform_(w2.view(outC_dec * outC_dec_g, -1), a=np.sqrt(5))
|
|
203
241
|
self.dec_w2.append(nn.Parameter(w2))
|
|
204
|
-
self.dec_bn2.append(
|
|
242
|
+
self.dec_bn2.append(self._norm_1d(outC_dec, kind="group"))
|
|
205
243
|
|
|
206
244
|
# Output head (on finest grid, channels = chanlist[0])
|
|
207
|
-
self.head_hconv = ho.
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
self.head_w = nn.Parameter(torch.empty(head_inC, self.out_channels, self.KERNELSZ * self.KERNELSZ))
|
|
214
|
-
nn.init.kaiming_uniform_(self.head_w.view(head_inC * self.out_channels, -1), a=np.sqrt(5))
|
|
215
|
-
self.head_bn = nn.BatchNorm1d(self.out_channels) if self.task == 'segmentation' else None
|
|
245
|
+
self.head_hconv = ho.SphericalStencil(self.in_nside,
|
|
246
|
+
self.KERNELSZ,
|
|
247
|
+
n_gauges=self.G, #Mandatory for the output
|
|
248
|
+
gauge_type=self.gauge_type,
|
|
249
|
+
cell_ids=self.l_cell_ids[0],
|
|
250
|
+
dtype=self.torch_dtype)
|
|
216
251
|
|
|
252
|
+
head_inC = self.chanlist[0]
|
|
253
|
+
if self.out_channels % self.G != 0:
|
|
254
|
+
raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
|
|
255
|
+
outC_head_g = self.out_channels // self.G
|
|
256
|
+
|
|
257
|
+
self.head_w = nn.Parameter(
|
|
258
|
+
torch.empty(head_inC, outC_head_g, self.KERNELSZ * self.KERNELSZ)
|
|
259
|
+
)
|
|
260
|
+
nn.init.kaiming_uniform_(self.head_w.view(head_inC * outC_head_g, -1), a=np.sqrt(5))
|
|
261
|
+
self.head_bn = self._norm_1d(self.out_channels, kind="group") if self.task == 'segmentation' else None
|
|
262
|
+
|
|
263
|
+
# Choose how to reduce across gauges at head:
|
|
264
|
+
# 'sum' (default), 'mean', or 'learned' (via 1x1 conv).
|
|
265
|
+
self.head_reduce = getattr(self, 'head_reduce', 'mean') # you can turn this into a ctor arg if you like
|
|
266
|
+
if self.head_reduce == 'learned':
|
|
267
|
+
# Mixer takes G*outC_head_g -> out_channels (K-wise 1x1)
|
|
268
|
+
self.head_mixer = nn.Conv1d(self.G * outC_head_g, self.out_channels, kernel_size=1, bias=True)
|
|
269
|
+
else:
|
|
270
|
+
self.head_mixer = None
|
|
271
|
+
|
|
217
272
|
# ---- Decide runtime device (probe Foscat on CUDA, else CPU) ----
|
|
218
273
|
self.runtime_device = self._probe_and_set_runtime_device(self.device)
|
|
219
274
|
|
|
275
|
+
# -------------------------- define local batchnorm/group -------------------
|
|
276
|
+
def _norm_1d(self, C: int, kind: str = "group", **kwargs) -> nn.Module:
|
|
277
|
+
"""
|
|
278
|
+
Return a normalization layer for (B, C, N) tensors.
|
|
279
|
+
kind: "group" | "instance" | "batch"
|
|
280
|
+
kwargs: extra args (e.g., num_groups for GroupNorm)
|
|
281
|
+
"""
|
|
282
|
+
if kind == "group":
|
|
283
|
+
num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1)
|
|
284
|
+
# s’assurer que num_groups divise C
|
|
285
|
+
while C % num_groups != 0 and num_groups > 1:
|
|
286
|
+
num_groups //= 2
|
|
287
|
+
return nn.GroupNorm(num_groups=num_groups, num_channels=C)
|
|
288
|
+
elif kind == "instance":
|
|
289
|
+
return nn.InstanceNorm1d(C, affine=True, track_running_stats=False)
|
|
290
|
+
elif kind == "batch":
|
|
291
|
+
return nn.BatchNorm1d(C)
|
|
292
|
+
else:
|
|
293
|
+
raise ValueError(f"Unknown norm kind: {kind}")
|
|
294
|
+
|
|
220
295
|
# -------------------------- device plumbing --------------------------
|
|
221
|
-
def _move_hconv_tensors(self, hc: ho.
|
|
222
|
-
"""Best-effort: move any torch.Tensor attribute of
|
|
296
|
+
def _move_hconv_tensors(self, hc: ho.SphericalStencil, device: torch.device) -> None:
|
|
297
|
+
"""Best-effort: move any torch.Tensor attribute of SphericalStencil to device."""
|
|
223
298
|
for name, val in list(vars(hc).items()):
|
|
224
299
|
try:
|
|
225
300
|
if torch.is_tensor(val):
|
|
@@ -265,6 +340,78 @@ class HealpixUNet(nn.Module):
|
|
|
265
340
|
self.device = device
|
|
266
341
|
self.runtime_device = self._probe_and_set_runtime_device(device)
|
|
267
342
|
return self.runtime_device
|
|
343
|
+
|
|
344
|
+
# --- inside HealpixUNet class, add a single-sample forward helper ---
|
|
345
|
+
def _forward_one(self, x1: torch.Tensor, cell_ids1=None) -> torch.Tensor:
|
|
346
|
+
"""
|
|
347
|
+
Single-sample forward. x1: (1, C_in, Npix_1). Returns (1, out_channels, Npix_1).
|
|
348
|
+
`cell_ids1` can be None or a 1D array (Npix_1,) for this sample.
|
|
349
|
+
"""
|
|
350
|
+
if x1.dim() != 3 or x1.shape[0] != 1:
|
|
351
|
+
raise ValueError(f"_forward_one expects (1, C, Npix), got {tuple(x1.shape)}")
|
|
352
|
+
# Reuse existing forward by calling it with B=1 (your code already supports per-sample ids)
|
|
353
|
+
if cell_ids1 is None:
|
|
354
|
+
return super().forward(x1)
|
|
355
|
+
else:
|
|
356
|
+
# normalize ids to numpy 1D
|
|
357
|
+
if isinstance(cell_ids1, torch.Tensor):
|
|
358
|
+
cell_ids1 = cell_ids1.detach().cpu().numpy()
|
|
359
|
+
elif isinstance(cell_ids1, list):
|
|
360
|
+
cell_ids1 = np.asarray(cell_ids1)
|
|
361
|
+
if cell_ids1.ndim == 1:
|
|
362
|
+
ci = cell_ids1[None, :] # (1, Npix_1) so the current code path is happy
|
|
363
|
+
else:
|
|
364
|
+
ci = cell_ids1
|
|
365
|
+
return super().forward(x1, cell_ids=ci)
|
|
366
|
+
|
|
367
|
+
def _as_tensor_batch(self, x):
|
|
368
|
+
"""
|
|
369
|
+
Ensure a (B, C, N) tensor.
|
|
370
|
+
- If x is a list of tensors, concatenate if all N are equal.
|
|
371
|
+
- If len==1, keep a batch dim (1, C, N).
|
|
372
|
+
- If x is already a tensor, return as-is.
|
|
373
|
+
"""
|
|
374
|
+
if isinstance(x, list):
|
|
375
|
+
if len(x) == 1:
|
|
376
|
+
t = x[0]
|
|
377
|
+
# If t is (C, N) -> make it (1, C, N)
|
|
378
|
+
return t.unsqueeze(0) if t.dim() == 2 else t
|
|
379
|
+
# all same length -> concat along batch
|
|
380
|
+
Ns = [t.shape[-1] for t in x]
|
|
381
|
+
if all(n == Ns[0] for n in Ns):
|
|
382
|
+
return torch.cat([t if t.dim() == 3 else t.unsqueeze(0) for t in x], dim=0)
|
|
383
|
+
# variable-length with B>1 not supported in a single tensor
|
|
384
|
+
raise ValueError("Variable-length batch detected; use batch_size=1 or loop per-sample.")
|
|
385
|
+
return x
|
|
386
|
+
|
|
387
|
+
# --- replace your current `forward` signature/body with a dispatcher ---
|
|
388
|
+
def forward_any(self, x, cell_ids: Optional[np.ndarray] = None):
|
|
389
|
+
"""
|
|
390
|
+
If `x` is a Tensor (B,C,N): standard batched path (requires same N for all).
|
|
391
|
+
If `x` is a list of Tensors: variable-length per-sample path, returns a list of outputs.
|
|
392
|
+
"""
|
|
393
|
+
# Variable-length list path
|
|
394
|
+
if isinstance(x, (list, tuple)):
|
|
395
|
+
outs = []
|
|
396
|
+
if cell_ids is None or isinstance(cell_ids, (list, tuple)):
|
|
397
|
+
cids = cell_ids if isinstance(cell_ids, (list, tuple)) else [None] * len(x)
|
|
398
|
+
else:
|
|
399
|
+
raise ValueError("When x is a list, cell_ids must be a list of same length or None.")
|
|
400
|
+
|
|
401
|
+
for xb, cb in zip(x, cids):
|
|
402
|
+
if not torch.is_tensor(xb):
|
|
403
|
+
xb = torch.as_tensor(xb, dtype=torch.float32, device=self.runtime_device)
|
|
404
|
+
if xb.dim() == 2:
|
|
405
|
+
xb = xb.unsqueeze(0) # (1,C,Nb)
|
|
406
|
+
elif xb.dim() != 3 or xb.shape[0] != 1:
|
|
407
|
+
raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
|
|
408
|
+
|
|
409
|
+
yb = self._forward_one(xb.to(self.runtime_device), cell_ids1=cb) # (1,Co,Nb)
|
|
410
|
+
outs.append(yb.squeeze(0)) # -> (Co, Nb)
|
|
411
|
+
return outs # List[Tensor] (each length Nb)
|
|
412
|
+
|
|
413
|
+
# Fixed-length tensor path (your current implementation)
|
|
414
|
+
return super().forward(x, cell_ids=cell_ids)
|
|
268
415
|
|
|
269
416
|
# -------------------------- forward --------------------------
|
|
270
417
|
def forward(self, x: torch.Tensor,cell_ids: Optional[np.ndarray ] = None) -> torch.Tensor:
|
|
@@ -306,6 +453,7 @@ class HealpixUNet(nn.Module):
|
|
|
306
453
|
l_data = self.hconv_enc[l].Convol_torch(l_data,
|
|
307
454
|
self.enc_w1[l],
|
|
308
455
|
cell_ids=l_cell_ids)
|
|
456
|
+
l_data = self._as_tensor_batch(l_data)
|
|
309
457
|
l_data = self.enc_bn1[l](l_data)
|
|
310
458
|
l_data = F.relu(l_data, inplace=True)
|
|
311
459
|
|
|
@@ -313,6 +461,7 @@ class HealpixUNet(nn.Module):
|
|
|
313
461
|
l_data = self.hconv_enc[l].Convol_torch(l_data,
|
|
314
462
|
self.enc_w2[l],
|
|
315
463
|
cell_ids=l_cell_ids)
|
|
464
|
+
l_data = self._as_tensor_batch(l_data)
|
|
316
465
|
l_data = self.enc_bn2[l](l_data)
|
|
317
466
|
l_data = F.relu(l_data, inplace=True)
|
|
318
467
|
|
|
@@ -321,12 +470,10 @@ class HealpixUNet(nn.Module):
|
|
|
321
470
|
|
|
322
471
|
# downsample (except bottom level) -> ensure output is on runtime_device
|
|
323
472
|
if l < len(self.chanlist) - 1:
|
|
324
|
-
#l_data, l_cell_ids = self.f.ud_grade_2(
|
|
325
|
-
# l_data, cell_ids=t_cell_ids[l], nside=current_nside
|
|
326
|
-
#)
|
|
327
473
|
l_data, l_cell_ids = self.hconv_enc[l].Down(
|
|
328
|
-
l_data, cell_ids=t_cell_ids[l], nside=current_nside
|
|
474
|
+
l_data, cell_ids=t_cell_ids[l], nside=current_nside,max_poll=self.max_poll
|
|
329
475
|
)
|
|
476
|
+
l_data = self._as_tensor_batch(l_data)
|
|
330
477
|
if cell_ids is not None:
|
|
331
478
|
t_cell_ids[l+1]=l_cell_ids
|
|
332
479
|
else:
|
|
@@ -335,34 +482,24 @@ class HealpixUNet(nn.Module):
|
|
|
335
482
|
if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
|
|
336
483
|
l_data = l_data.to(self.runtime_device)
|
|
337
484
|
current_nside //= 2
|
|
338
|
-
|
|
485
|
+
|
|
339
486
|
# Decoder
|
|
340
487
|
for d in range(len(self.chanlist)):
|
|
341
|
-
level = len(self.chanlist) - 1 - d #
|
|
488
|
+
level = len(self.chanlist) - 1 - d # encoder level we are going back to
|
|
342
489
|
|
|
343
490
|
if level < len(self.chanlist) - 1:
|
|
344
|
-
# upsample
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
# and target (fine) ids in `o_cell_ids` (matching original UNET code).
|
|
350
|
-
'''
|
|
351
|
-
l_data = self.f.up_grade(
|
|
352
|
-
l_data,
|
|
353
|
-
tgt_nside,
|
|
354
|
-
cell_ids=t_cell_ids[level + 1], # source (coarser) ids
|
|
355
|
-
o_cell_ids=t_cell_ids[level], # target (finer) ids
|
|
356
|
-
nside=src_nside,
|
|
357
|
-
)
|
|
358
|
-
'''
|
|
359
|
-
l_data = self.hconv_enc[l].Up(
|
|
491
|
+
# upsample: from encoder level (level+1) [coarser] -> level [finer]
|
|
492
|
+
src_nside = self.enc_nsides[level + 1] # coarse
|
|
493
|
+
|
|
494
|
+
# Use the **decoder** operator at this step (consistent with your hconv_dec stack)
|
|
495
|
+
l_data = self.hconv_dec[d].Up(
|
|
360
496
|
l_data,
|
|
361
|
-
cell_ids=t_cell_ids[level + 1],
|
|
362
|
-
o_cell_ids=t_cell_ids[level], # target
|
|
497
|
+
cell_ids=t_cell_ids[level + 1], # source/coarse IDs
|
|
498
|
+
o_cell_ids=t_cell_ids[level], # target/fine IDs
|
|
363
499
|
nside=src_nside,
|
|
364
500
|
)
|
|
365
|
-
|
|
501
|
+
l_data = self._as_tensor_batch(l_data)
|
|
502
|
+
|
|
366
503
|
if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
|
|
367
504
|
l_data = l_data.to(self.runtime_device)
|
|
368
505
|
|
|
@@ -370,49 +507,85 @@ class HealpixUNet(nn.Module):
|
|
|
370
507
|
concat = self.f.backend.bk_concat([skips[level], l_data], 1)
|
|
371
508
|
l_data = concat.to(self.runtime_device) if torch.is_tensor(concat) else concat
|
|
372
509
|
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
# apply decoder convs on this grid
|
|
510
|
+
# choose the right cell_ids for convolutions at this resolution
|
|
511
|
+
l_cell_ids = t_cell_ids[level] if (cell_ids is not None) else None
|
|
512
|
+
|
|
513
|
+
# apply decoder convs on this grid using the matching decoder operator
|
|
377
514
|
hc = self.hconv_dec[d]
|
|
378
|
-
l_data = hc.Convol_torch(l_data,
|
|
379
|
-
|
|
380
|
-
cell_ids=l_cell_ids)
|
|
381
|
-
|
|
515
|
+
l_data = hc.Convol_torch(l_data, self.dec_w1[d], cell_ids=l_cell_ids)
|
|
516
|
+
l_data = self._as_tensor_batch(l_data)
|
|
382
517
|
l_data = self.dec_bn1[d](l_data)
|
|
383
518
|
l_data = F.relu(l_data, inplace=True)
|
|
384
519
|
|
|
385
|
-
l_data = hc.Convol_torch(l_data,
|
|
386
|
-
|
|
387
|
-
cell_ids=l_cell_ids)
|
|
388
|
-
|
|
520
|
+
l_data = hc.Convol_torch(l_data, self.dec_w2[d], cell_ids=l_cell_ids)
|
|
521
|
+
l_data = self._as_tensor_batch(l_data)
|
|
389
522
|
l_data = self.dec_bn2[d](l_data)
|
|
390
523
|
l_data = F.relu(l_data, inplace=True)
|
|
391
524
|
|
|
392
525
|
# Head on finest grid
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
526
|
+
# y_head_raw: (B, G*outC_head_g, K)
|
|
527
|
+
y_head_raw = self.head_hconv.Convol_torch(l_data, self.head_w, cell_ids=l_cell_ids)
|
|
528
|
+
|
|
529
|
+
B, Ctot, K = y_head_raw.shape
|
|
530
|
+
outC_head_g = int(self.out_channels)//self.G
|
|
531
|
+
assert Ctot == self.G * outC_head_g, \
|
|
532
|
+
f"Head expects G*outC_head_g channels, got {Ctot} != {self.G}*{outC_head_g}"
|
|
533
|
+
|
|
534
|
+
if self.head_mixer is not None and self.head_reduce == 'learned':
|
|
535
|
+
# 1x1 learned mixing across G*outC_head_g -> out_channels
|
|
536
|
+
y = self.head_mixer(y_head_raw) # (B, out_channels, K)
|
|
537
|
+
else:
|
|
538
|
+
# reshape to (B, G, outC_head_g, K) then reduce across G
|
|
539
|
+
y_g = y_head_raw.view(B, self.G, outC_head_g, K)
|
|
540
|
+
|
|
541
|
+
y = y_g.mean(dim=1) # (B, outC_head_g, K)
|
|
542
|
+
|
|
543
|
+
y = self._as_tensor_batch(y)
|
|
544
|
+
|
|
545
|
+
# Optional BN + activation as before
|
|
546
|
+
if self.task == 'segmentation' and self.head_bn is not None:
|
|
547
|
+
y = self.head_bn(y)
|
|
548
|
+
|
|
396
549
|
if self.final_activation == 'sigmoid':
|
|
397
|
-
|
|
550
|
+
y = torch.sigmoid(y)
|
|
551
|
+
|
|
398
552
|
elif self.final_activation == 'softmax':
|
|
399
|
-
|
|
400
|
-
|
|
553
|
+
y = torch.softmax(y, dim=1)
|
|
554
|
+
|
|
555
|
+
return y
|
|
401
556
|
|
|
402
557
|
# -------------------------- utilities --------------------------
|
|
403
558
|
@torch.no_grad()
|
|
404
559
|
def predict(self, x: torch.Tensor, batch_size: int = 8,cell_ids: Optional[np.ndarray ] = None) -> torch.Tensor:
|
|
405
560
|
self.eval()
|
|
406
561
|
outs = []
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
562
|
+
if isinstance(x,np.ndarray):
|
|
563
|
+
x=self.to_Tensor(x)
|
|
564
|
+
|
|
565
|
+
if not isinstance(x, torch.Tensor):
|
|
566
|
+
for i in range(len(x)):
|
|
567
|
+
if cell_ids is not None:
|
|
568
|
+
outs.append(self.forward(x[i][None,:],cell_ids=cell_ids[i][:]))
|
|
569
|
+
else:
|
|
570
|
+
outs.append(self.forward(x[i][None,:]))
|
|
571
|
+
else:
|
|
572
|
+
for i in range(0, x.shape[0], batch_size):
|
|
573
|
+
if cell_ids is not None:
|
|
574
|
+
outs.append(self.forward(x[i : i + batch_size],
|
|
575
|
+
cell_ids=cell_ids[i : i + batch_size]))
|
|
576
|
+
else:
|
|
577
|
+
outs.append(self.forward(x[i : i + batch_size]))
|
|
413
578
|
|
|
414
579
|
return torch.cat(outs, dim=0)
|
|
415
580
|
|
|
581
|
+
def to_tensor(self,x):
|
|
582
|
+
return self.hconv_enc[0].f.backend.bk_cast(x)
|
|
583
|
+
|
|
584
|
+
def to_numpy(self,x):
|
|
585
|
+
if isinstance(x,np.ndarray):
|
|
586
|
+
return x
|
|
587
|
+
return x.cpu().numpy()
|
|
588
|
+
|
|
416
589
|
# -----------------------------
|
|
417
590
|
# Kernel extraction & plotting
|
|
418
591
|
# -----------------------------
|
|
@@ -609,109 +782,421 @@ if __name__ == "__main__":
|
|
|
609
782
|
np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)
|
|
610
783
|
|
|
611
784
|
unittest.main()
|
|
785
|
+
|
|
786
|
+
from torch.utils.data import Dataset
|
|
787
|
+
# 1) Dataset that omits cell_ids when None
|
|
788
|
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
|
789
|
+
|
|
790
|
+
class HealpixDataset(Dataset):
|
|
791
|
+
"""
|
|
792
|
+
Returns (x, y, cell_ids) per-sample if cell_ids is given, else (x, y).
|
|
793
|
+
Shapes:
|
|
794
|
+
x: (C, Npix)
|
|
795
|
+
y: (C_out or 1, Npix)
|
|
796
|
+
cell_ids: (Npix,) per-sample (or broadcasted from (Npix,))
|
|
797
|
+
"""
|
|
798
|
+
def __init__(self, x, y, cell_ids=None, dtype=torch.float32):
|
|
799
|
+
self.x = torch.as_tensor(x, dtype=dtype)
|
|
800
|
+
self.y = torch.as_tensor(y, dtype=dtype)
|
|
801
|
+
assert self.x.shape[0] == self.y.shape[0], "x and y must share batch size"
|
|
802
|
+
self._has_cids = cell_ids is not None
|
|
803
|
+
if self._has_cids:
|
|
804
|
+
cid = torch.as_tensor(cell_ids, dtype=torch.long)
|
|
805
|
+
if cid.dim() == 1:
|
|
806
|
+
cid = cid.unsqueeze(0).expand(self.x.shape[0], -1)
|
|
807
|
+
assert cid.shape[0] == self.x.shape[0], "cell_ids must match batch size"
|
|
808
|
+
self.cids = cid
|
|
809
|
+
else:
|
|
810
|
+
self.cids = None
|
|
811
|
+
|
|
812
|
+
def __len__(self):
|
|
813
|
+
return self.x.shape[0]
|
|
814
|
+
|
|
815
|
+
def __getitem__(self, i):
|
|
816
|
+
if self._has_cids:
|
|
817
|
+
return self.x[i], self.y[i], self.cids[i]
|
|
818
|
+
else:
|
|
819
|
+
return self.x[i], self.y[i]
|
|
820
|
+
|
|
821
|
+
# ---------------------------
|
|
822
|
+
# Datasets / Collate helpers
|
|
823
|
+
# ---------------------------
|
|
824
|
+
|
|
825
|
+
class HealpixDataset(Dataset):
|
|
826
|
+
"""
|
|
827
|
+
Fixed-grid dataset (common Npix for all samples).
|
|
828
|
+
Returns (x, y) if cell_ids is None, else (x, y, cell_ids).
|
|
829
|
+
|
|
830
|
+
x: (B, C, Npix)
|
|
831
|
+
y: (B, C_out or 1, Npix) or class indices depending on task
|
|
832
|
+
cell_ids: (Npix,) or (B, Npix)
|
|
833
|
+
"""
|
|
834
|
+
def __init__(self,
|
|
835
|
+
x: torch.Tensor,
|
|
836
|
+
y: torch.Tensor,
|
|
837
|
+
cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
|
838
|
+
dtype: torch.dtype = torch.float32):
|
|
839
|
+
x = torch.as_tensor(x, dtype=dtype)
|
|
840
|
+
y = torch.as_tensor(y, dtype=dtype if y.ndim == 3 else torch.long)
|
|
841
|
+
assert x.shape[0] == y.shape[0], "x and y must share batch size"
|
|
842
|
+
self.x, self.y = x, y
|
|
843
|
+
self._has_cids = cell_ids is not None
|
|
844
|
+
if self._has_cids:
|
|
845
|
+
c = torch.as_tensor(cell_ids, dtype=torch.long)
|
|
846
|
+
if c.ndim == 1: # broadcast single (Npix,) to (B, Npix)
|
|
847
|
+
c = c.unsqueeze(0).expand(x.shape[0], -1)
|
|
848
|
+
assert c.shape == (x.shape[0], x.shape[2]), "cell_ids must be (B,Npix) or (Npix,)"
|
|
849
|
+
self.cids = c
|
|
850
|
+
else:
|
|
851
|
+
self.cids = None
|
|
852
|
+
|
|
853
|
+
def __len__(self) -> int: return self.x.shape[0]
|
|
854
|
+
|
|
855
|
+
def __getitem__(self, i: int):
|
|
856
|
+
if self._has_cids:
|
|
857
|
+
return self.x[i], self.y[i], self.cids[i]
|
|
858
|
+
return self.x[i], self.y[i]
|
|
859
|
+
|
|
860
|
+
# ---------------------------
|
|
861
|
+
# Datasets / Collate helpers
|
|
862
|
+
# ---------------------------
|
|
863
|
+
|
|
864
|
+
class HealpixDataset(Dataset):
|
|
865
|
+
"""
|
|
866
|
+
Fixed-grid dataset (common Npix for all samples).
|
|
867
|
+
Returns (x, y) if cell_ids is None, else (x, y, cell_ids).
|
|
868
|
+
|
|
869
|
+
x: (B, C, Npix)
|
|
870
|
+
y: (B, C_out or 1, Npix) or class indices depending on task
|
|
871
|
+
cell_ids: (Npix,) or (B, Npix)
|
|
872
|
+
"""
|
|
873
|
+
def __init__(self,
|
|
874
|
+
x: torch.Tensor,
|
|
875
|
+
y: torch.Tensor,
|
|
876
|
+
cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
|
877
|
+
dtype: torch.dtype = torch.float32):
|
|
878
|
+
x = torch.as_tensor(x, dtype=dtype)
|
|
879
|
+
y = torch.as_tensor(y, dtype=dtype if y.ndim == 3 else torch.long)
|
|
880
|
+
assert x.shape[0] == y.shape[0], "x and y must share batch size"
|
|
881
|
+
self.x, self.y = x, y
|
|
882
|
+
self._has_cids = cell_ids is not None
|
|
883
|
+
if self._has_cids:
|
|
884
|
+
c = torch.as_tensor(cell_ids, dtype=torch.long)
|
|
885
|
+
if c.ndim == 1: # broadcast single (Npix,) to (B, Npix)
|
|
886
|
+
c = c.unsqueeze(0).expand(x.shape[0], -1)
|
|
887
|
+
assert c.shape == (x.shape[0], x.shape[2]), "cell_ids must be (B,Npix) or (Npix,)"
|
|
888
|
+
self.cids = c
|
|
889
|
+
else:
|
|
890
|
+
self.cids = None
|
|
891
|
+
|
|
892
|
+
def __len__(self) -> int: return self.x.shape[0]
|
|
893
|
+
|
|
894
|
+
def __getitem__(self, i: int):
|
|
895
|
+
if self._has_cids:
|
|
896
|
+
return self.x[i], self.y[i], self.cids[i]
|
|
897
|
+
return self.x[i], self.y[i]
|
|
898
|
+
|
|
899
|
+
|
|
900
|
+
class VarLenHealpixDataset(Dataset):
    """
    Variable-length per-sample dataset.

    x_list[b]: (C, Npix_b) or (1, C, Npix_b)
    y_list[b]: (C_out or 1, Npix_b) or (1, C_out, Npix_b) targets.
        For multi-class segmentation with CrossEntropyLoss you may pass
        class indices of shape (Npix_b,) or (1, Npix_b) (squeezed later).
    cids_list[b]: (Npix_b,) or None
    """

    def __init__(self,
                 x_list: List[Union[np.ndarray, torch.Tensor]],
                 y_list: List[Union[np.ndarray, torch.Tensor]],
                 cids_list: Optional[List[Union[np.ndarray, torch.Tensor]]] = None,
                 dtype: torch.dtype = torch.float32):
        assert len(x_list) == len(y_list), "x_list and y_list must have the same length"
        self.x = []
        for sample in x_list:
            self.x.append(torch.as_tensor(sample, dtype=dtype))
        # y may be float (regression) or long (class indices); keep the
        # caller's dtype here and coerce per task later.
        self.y = []
        for target in y_list:
            self.y.append(torch.as_tensor(target))
        if cids_list is None:
            self.c = None
        else:
            assert len(cids_list) == len(x_list), "cids_list must match x_list length"
            self.c = [torch.as_tensor(ids, dtype=torch.long) for ids in cids_list]

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i: int):
        if self.c is None:
            return self.x[i], self.y[i], None
        return self.x[i], self.y[i], self.c[i]
|
|
930
|
+
|
|
931
|
+
from torch.utils.data import Dataset, DataLoader

class VarLenHealpixDataset(Dataset):
    """
    Variable-length per-sample dataset.

    x_list: list of (C, Npix_b) tensors or arrays
    y_list: list of (C_out or 1, Npix_b) tensors or arrays; for multi-class
        segmentation with CrossEntropyLoss, pass integer class indices of
        shape (Npix_b,) or (1, Npix_b).
    cids_list: optional list of (Npix_b,) cell-id arrays
    """
    def __init__(self, x_list, y_list, cids_list=None, dtype=torch.float32):
        assert len(x_list) == len(y_list)
        self.x = [torch.as_tensor(x, dtype=dtype) for x in x_list]
        # BUGFIX: do NOT force y to float32 -- integer class-index targets
        # must keep their dtype for CrossEntropyLoss; float targets pass
        # through unchanged (matches the earlier definition of this class).
        self.y = [torch.as_tensor(y) for y in y_list]
        self.c = None
        if cids_list is not None:
            assert len(cids_list) == len(x_list)
            # Long tensors, consistent with the fixed-grid dataset.
            self.c = [torch.as_tensor(c, dtype=torch.long) for c in cids_list]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        if self.c is None:
            return self.x[i], self.y[i], None
        return self.x[i], self.y[i], self.c[i]
|
|
954
|
+
|
|
955
|
+
def varlen_collate(batch):
    """Collate variable-length samples: keep per-sample lists, never stack."""
    xs, ys, cs = [], [], []
    for x, y, c in batch:
        xs.append(x)
        ys.append(y)
        cs.append(c)
    # Collapse to None when no sample carried cell ids.
    has_cids = any(c is not None for c in cs)
    return xs, ys, cs if has_cids else None
|
|
961
|
+
|
|
962
|
+
def varlen_collate(batch):
    """
    Collate for variable-length samples: keep lists, do NOT stack.
    Returns lists: xs, ys, cs (cs can be None).
    """
    xs, ys, cs = zip(*batch)
    cids = list(cs) if any(c is not None for c in cs) else None
    return list(xs), list(ys), cids


# ---------------------------
# Training function
# ---------------------------
|
|
612
975
|
|
|
613
976
|
def fit(
    model,
    x_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    y_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    *,
    cell_ids_train: Optional[Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]] = None,
    n_epoch: int = 10,
    view_epoch: int = 10,
    batch_size: int = 16,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    clip_grad_norm: Optional[float] = None,
    verbose: bool = True,
    optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
) -> dict:
    """
    Train helper that supports:
    - Fixed-grid tensors (B,C,N) with optional (B,N) or (N,) cell_ids.
    - Variable-length lists: x=[(C,N_b)], y=[...], cell_ids=[(N_b,)], one grid per sample.

    ADAM: standard minibatch update.
    LBFGS: uses a closure that accumulates losses over the current mini-batch.

    Parameters are the usual training knobs; ``view_epoch`` controls print
    frequency and ``clip_grad_norm`` (when set) clips gradients before each
    ADAM step.

    Returns
    -------
    dict with key "loss": list of per-step mean losses.

    Notes
    -----
    - For segmentation with multiple classes, pass integer class targets for y:
      fixed-grid: (B, N) int64; variable-length: each y[b] of shape (N_b,) or (1,N_b).
    - For regression, pass float targets with the same (C_out, N) channeling.
    """
    # Prefer the device the Foscat ops were probed on; otherwise CUDA if present.
    if hasattr(model, "runtime_device"):
        device = model.runtime_device
    else:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    # BUGFIX: use getattr consistently (one code path used model.task directly,
    # which crashed on models without a `task` attribute).
    task = getattr(model, "task", "regression")

    # Detect variable-length mode.
    varlen_mode = isinstance(x_train, (list, tuple))

    # ----- Build DataLoader
    if not varlen_mode:
        # Fixed-grid path.
        x_t = torch.as_tensor(x_train, dtype=torch.float32, device=device)
        y_is_class = (task != 'regression' and getattr(model, "out_channels", 1) > 1)
        # Heuristic: 3-D targets are dense float maps; otherwise class indices.
        y_dtype = torch.long if y_is_class and (not torch.is_tensor(y_train) or y_train.ndim != 3) else torch.float32
        y_t = torch.as_tensor(y_train, dtype=y_dtype, device=device)

        if cell_ids_train is None:
            loader = DataLoader(TensorDataset(x_t, y_t), batch_size=batch_size,
                                shuffle=True, drop_last=False)
            with_cell_ids = False
        else:
            ds = HealpixDataset(x_t, y_t, cell_ids=cell_ids_train, dtype=torch.float32)
            loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
            with_cell_ids = True
    else:
        # Variable-length path: the collate keeps lists, never stacks.
        ds = VarLenHealpixDataset(x_train, y_train, cids_list=cell_ids_train, dtype=torch.float32)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False,
                            collate_fn=varlen_collate)
        with_cell_ids = cell_ids_train is not None

    # ----- Loss selection
    if task == 'regression':
        criterion = nn.MSELoss(reduction='mean')
        seg_multiclass = False
    elif getattr(model, "out_channels", 1) == 1:
        # Binary segmentation: the head returns logits when final_activation == 'none'.
        criterion = nn.BCEWithLogitsLoss() if getattr(model, "final_activation", "none") == 'none' else nn.BCELoss()
        seg_multiclass = False
    else:
        criterion = nn.CrossEntropyLoss()
        seg_multiclass = True

    # ----- Optimizer
    if optimizer.upper() == 'ADAM':
        optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        outer, inner = n_epoch, 1
    else:
        optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20,
                                  history_size=max(10, n_epoch * 5),
                                  line_search_fn="strong_wolfe")
        # Emulate "epochs" with several inner LBFGS steps.
        outer, inner = max(1, n_epoch // 20), 20

    # ----- Training loop
    history: List[float] = []
    model.train()

    for _ in range(outer):
        for _ in range(inner):
            epoch_loss, n_samples = 0.0, 0

            for batch in loader:
                if not varlen_mode:
                    # -------- fixed-grid
                    if with_cell_ids:
                        xb, yb, cb = batch
                        cb_np = cb.detach().cpu().numpy()  # forward expects numpy cell ids
                    else:
                        xb, yb = batch
                        cb_np = None

                    xb = xb.to(device, dtype=torch.float32, non_blocking=True)
                    # y dtype: float for regression/binary, long for CrossEntropy.
                    yb = yb.to(device, non_blocking=True)

                    if isinstance(optim, torch.optim.LBFGS):
                        def closure():
                            optim.zero_grad(set_to_none=True)
                            preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                            loss = criterion(preds, yb)
                            loss.backward()
                            return loss
                        optim.step(closure)
                        with torch.no_grad():  # re-evaluate for logging only
                            preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                            loss = criterion(preds, yb)
                    else:
                        optim.zero_grad(set_to_none=True)
                        preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                        loss = criterion(preds, yb)
                        loss.backward()
                        if clip_grad_norm is not None:
                            nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                        optim.step()

                    bs = xb.shape[0]
                    epoch_loss += float(loss.item()) * bs
                    n_samples += bs
                else:
                    # -------- variable-length (lists)
                    xs, ys, cs = batch

                    def _prep_xyc(i):
                        """Normalize sample i to (1,C,N) input, task-shaped target, numpy cell ids."""
                        xb = torch.as_tensor(xs[i], device=device, dtype=torch.float32)
                        if xb.dim() == 2:
                            xb = xb.unsqueeze(0)
                        elif xb.dim() != 3 or xb.shape[0] != 1:
                            raise ValueError("Each x[i] must be (C,N) or (1,C,N)")

                        yb = torch.as_tensor(ys[i], device=device)
                        if seg_multiclass:
                            # BUGFIX: CrossEntropyLoss with preds of shape (1,C_out,N)
                            # needs a (1,N) target, not (N,) -- keep/add the batch dim.
                            if yb.dim() == 1:
                                yb = yb.unsqueeze(0)          # -> (1, N_i)
                            elif yb.dim() != 2 or yb.shape[0] != 1:
                                raise ValueError("For multiclass CE, y[i] must be (N,) or (1,N)")
                            yb = yb.long()  # CE requires integer class indices
                        else:
                            # regression / binary: target shaped (1, C_out, N_i)
                            if yb.dim() == 2:
                                yb = yb.unsqueeze(0)
                            elif yb.dim() != 3 or yb.shape[0] != 1:
                                raise ValueError("For regression/binary, y[i] must be (C_out,N) or (1,C_out,N)")

                        # cell_ids: (N_i,) -> (1, N_i) numpy (the forward expects np.ndarray).
                        if cs is None or cs[i] is None:
                            cb_np = None
                        else:
                            c = cs[i].detach().cpu().numpy() if torch.is_tensor(cs[i]) else np.asarray(cs[i])
                            if c.ndim == 1:
                                c = c[None, :]  # -> (1, N_i)
                            cb_np = c
                        return xb, yb, cb_np

                    if isinstance(optim, torch.optim.LBFGS):
                        def closure():
                            optim.zero_grad(set_to_none=True)
                            total = 0.0
                            for i in range(len(xs)):
                                xb, yb, cb_np = _prep_xyc(i)
                                preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                                loss_i = criterion(preds, yb)
                                loss_i.backward()  # accumulate per-sample gradients
                                total += float(loss_i.item())
                            # Return a scalar Tensor for LBFGS; gradients were accumulated above.
                            # NOTE(review): grads are a *sum* over samples while the returned
                            # value is a mean -- confirm this is acceptable for strong_wolfe.
                            return torch.tensor(total / max(1, len(xs)), device=device, dtype=torch.float32)

                        optim.step(closure)
                        with torch.no_grad():  # logging pass, no gradients
                            total = 0.0
                            for i in range(len(xs)):
                                xb, yb, cb_np = _prep_xyc(i)
                                preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                                total += float(criterion(preds, yb).item())
                            loss_val = total / max(1, len(xs))
                    else:
                        optim.zero_grad(set_to_none=True)
                        total = 0.0
                        for i in range(len(xs)):
                            xb, yb, cb_np = _prep_xyc(i)
                            preds = model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)
                            loss_i = criterion(preds, yb)
                            loss_i.backward()
                            total += float(loss_i.item())
                        if clip_grad_norm is not None:
                            nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                        optim.step()
                        loss_val = total / max(1, len(xs))

                    epoch_loss += loss_val * max(1, len(xs))
                    n_samples += max(1, len(xs))

            epoch_loss /= max(1, n_samples)
            history.append(epoch_loss)
            # Print every view_epoch logical step (and the very first one).
            if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
                print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")

    return {"loss": history}
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
__all__ = ["HealpixUNet", "fit"]
|