foscat 2025.8.4__py3-none-any.whl → 2025.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +309 -50
- foscat/FoCUS.py +74 -267
- foscat/HOrientedConvol.py +517 -130
- foscat/HealBili.py +309 -0
- foscat/Plot.py +331 -0
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +470 -179
- foscat/healpix_unet_torch.py +1202 -0
- foscat/scat_cov.py +3 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/METADATA +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/RECORD +14 -10
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/WHEEL +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1202 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HEALPix U-Net (nested) with Foscat + PyTorch niceties
|
|
3
|
+
----------------------------------------------------
|
|
4
|
+
GPU by default (when available), with graceful CPU fallback if Foscat ops are CPU-only.
|
|
5
|
+
|
|
6
|
+
- Normalization (GroupNorm by default) + ReLU after each convolution (encoder & decoder)
|
|
7
|
+
- Segmentation/Regression heads with optional final activation
|
|
8
|
+
- PyTorch-ified: inherits from nn.Module, standard state_dict
|
|
9
|
+
- Device management: tries CUDA first; if Foscat SphericalStencil cannot run on CUDA, falls back to CPU
|
|
10
|
+
|
|
11
|
+
Shape convention: (B, C, Npix)
|
|
12
|
+
|
|
13
|
+
Requirements: foscat (scat_cov.funct + SphericalStencil.Convol_torch must be differentiable on torch tensors)
|
|
14
|
+
"""
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
from typing import List, Optional, Literal, Tuple
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
import torch.nn as nn
|
|
21
|
+
import healpy as hp
|
|
22
|
+
|
|
23
|
+
import torch.nn.functional as F
|
|
24
|
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
|
25
|
+
|
|
26
|
+
import foscat.scat_cov as sc
|
|
27
|
+
import foscat.SphericalStencil as ho
|
|
28
|
+
import matplotlib.pyplot as plt
|
|
29
|
+
|
|
30
|
+
class HealpixUNet(nn.Module):
|
|
31
|
+
"""U-Net-like architecture on the HEALPix sphere using Foscat oriented convolutions.
|
|
32
|
+
|
|
33
|
+
Parameters
|
|
34
|
+
----------
|
|
35
|
+
in_nside : int
|
|
36
|
+
Input HEALPix nside (nested scheme).
|
|
37
|
+
n_chan_in : int
|
|
38
|
+
Number of input channels.
|
|
39
|
+
chanlist : list[int]
|
|
40
|
+
Channels per encoder level (depth = len(chanlist)). Example: [16, 32, 64].
|
|
41
|
+
cell_ids : np.ndarray
|
|
42
|
+
Cell indices for the finest resolution (nside = in_nside) in nested scheme.
|
|
43
|
+
KERNELSZ : int, default 3
|
|
44
|
+
Spatial kernel size K (K x K) for oriented convolution.
|
|
45
|
+
gauge_type : str
|
|
46
|
+
Type of gauge :
|
|
47
|
+
'cosmo' use the same definition than
|
|
48
|
+
https://www.aanda.org/articles/aa/abs/2022/12/aa44566-22/aa44566-22.html
|
|
49
|
+
'phi' is define at the pole, could be better for earth observation not using intensivly the pole
|
|
50
|
+
G : int, default 1
|
|
51
|
+
Number of gauges for the orientation definition.
|
|
52
|
+
task : {'regression','segmentation'}, default 'regression'
|
|
53
|
+
Chooses the head and default activation.
|
|
54
|
+
out_channels : int, default 1
|
|
55
|
+
Number of output channels (e.g. num_classes for segmentation).
|
|
56
|
+
final_activation : {'none','sigmoid','softmax'} | None
|
|
57
|
+
If None, uses sensible default per task: 'none' for regression, 'softmax' for segmentation (multi-class),
|
|
58
|
+
'sigmoid' for segmentation when out_channels==1.
|
|
59
|
+
device : str | torch.device | None, default: 'cuda' if available else 'cpu'
|
|
60
|
+
Preferred device. The module will probe whether Foscat ops can run on CUDA; if not,
|
|
61
|
+
it will fallback to CPU and keep all parameters/buffers on CPU for consistency.
|
|
62
|
+
down_type:
|
|
63
|
+
{"mean","max"}, default "max". Equivalent of max poll during down
|
|
64
|
+
prefer_foscat_gpu : bool, default True
|
|
65
|
+
When device is CUDA, try to move Foscat operators (internal tensors) to CUDA and do a dry-run.
|
|
66
|
+
If the dry-run fails, everything falls back to CPU.
|
|
67
|
+
|
|
68
|
+
Notes
|
|
69
|
+
-----
|
|
70
|
+
- Two oriented convolutions per level. After each conv: BatchNorm1d + ReLU.
|
|
71
|
+
- Downsampling uses foscat ``ud_grade_2``; upsampling uses ``up_grade``.
|
|
72
|
+
- Convolution kernels are explicit parameters (shape [C_in, C_out, K*K]) and applied via ``SphericalStencil.Convol_torch``.
|
|
73
|
+
- Foscat ops device is auto-probed to avoid CPU/CUDA mismatches.
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
def __init__(
|
|
77
|
+
self,
|
|
78
|
+
*,
|
|
79
|
+
in_nside: int,
|
|
80
|
+
n_chan_in: int,
|
|
81
|
+
chanlist: List[int],
|
|
82
|
+
cell_ids: np.ndarray,
|
|
83
|
+
KERNELSZ: int = 3,
|
|
84
|
+
task: Literal['regression', 'segmentation'] = 'regression',
|
|
85
|
+
out_channels: int = 1,
|
|
86
|
+
final_activation: Optional[Literal['none', 'sigmoid', 'softmax']] = None,
|
|
87
|
+
device: Optional[torch.device | str] = None,
|
|
88
|
+
prefer_foscat_gpu: bool = True,
|
|
89
|
+
gauge_type: Optional[Literal['cosmo','phi']] = 'cosmo',
|
|
90
|
+
G: int =1,
|
|
91
|
+
down_type: Optional[Literal['mean','max']] = 'max',
|
|
92
|
+
dtype: Literal['float32','float64'] = 'float32',
|
|
93
|
+
head_reduce: Literal['mean','learned']='mean'
|
|
94
|
+
) -> None:
|
|
95
|
+
super().__init__()
|
|
96
|
+
|
|
97
|
+
self.dtype=dtype
|
|
98
|
+
if dtype=='float32':
|
|
99
|
+
self.np_dtype=np.float32
|
|
100
|
+
self.torch_dtype=torch.float32
|
|
101
|
+
else:
|
|
102
|
+
self.np_dtype=np.float64
|
|
103
|
+
self.torch_dtype=torch.float32
|
|
104
|
+
|
|
105
|
+
self.gauge_type=gauge_type
|
|
106
|
+
self.G = int(G)
|
|
107
|
+
|
|
108
|
+
if self.G < 1:
|
|
109
|
+
raise ValueError("G must be >= 1")
|
|
110
|
+
|
|
111
|
+
if cell_ids is None:
|
|
112
|
+
raise ValueError("cell_ids must be provided for the finest resolution.")
|
|
113
|
+
if len(chanlist) == 0:
|
|
114
|
+
raise ValueError("chanlist must be non-empty (depth >= 1).")
|
|
115
|
+
|
|
116
|
+
self.in_nside = int(in_nside)
|
|
117
|
+
self.n_chan_in = int(n_chan_in)
|
|
118
|
+
self.chanlist = list(map(int, chanlist))
|
|
119
|
+
self.chanlist = [self.chanlist[k]*self.G for k in range(len(self.chanlist))]
|
|
120
|
+
self.KERNELSZ = int(KERNELSZ)
|
|
121
|
+
self.task = task
|
|
122
|
+
self.out_channels = int(out_channels)*self.G
|
|
123
|
+
self.prefer_foscat_gpu = bool(prefer_foscat_gpu)
|
|
124
|
+
if down_type == 'max':
|
|
125
|
+
self.max_poll = True
|
|
126
|
+
else:
|
|
127
|
+
self.max_poll = False
|
|
128
|
+
|
|
129
|
+
# Choose default final activation if not given
|
|
130
|
+
if final_activation is None:
|
|
131
|
+
if task == 'regression':
|
|
132
|
+
self.final_activation = 'none'
|
|
133
|
+
else: # segmentation
|
|
134
|
+
self.final_activation = 'sigmoid' if out_channels == 1 else 'softmax'
|
|
135
|
+
else:
|
|
136
|
+
self.final_activation = final_activation
|
|
137
|
+
|
|
138
|
+
# Resolve preferred device
|
|
139
|
+
if device is None:
|
|
140
|
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
141
|
+
self.device = torch.device(device)
|
|
142
|
+
|
|
143
|
+
# Foscat functional wrapper (backend + grade ops)
|
|
144
|
+
self.f = sc.funct(KERNELSZ=self.KERNELSZ)
|
|
145
|
+
|
|
146
|
+
# ---------- Build multi-resolution bookkeeping ----------
|
|
147
|
+
depth = len(self.chanlist)
|
|
148
|
+
self.l_cell_ids: List[np.ndarray] = [None] * (depth + 1) # per encoder level + bottom
|
|
149
|
+
self.l_cell_ids[0] = np.asarray(cell_ids)
|
|
150
|
+
|
|
151
|
+
enc_nsides: List[int] = [self.in_nside]
|
|
152
|
+
current_nside = self.in_nside
|
|
153
|
+
|
|
154
|
+
# ---------- Oriented convolutions per level (encoder & decoder) ----------
|
|
155
|
+
self.hconv_enc: List[ho.SphericalStencil] = []
|
|
156
|
+
self.hconv_dec: List[ho.SphericalStencil] = []
|
|
157
|
+
|
|
158
|
+
# dummy data to propagate shapes/ids through ud_grade_2
|
|
159
|
+
l_data = self.f.backend.bk_cast(np.zeros((1, 1, cell_ids.shape[0]), dtype=self.np_dtype))
|
|
160
|
+
|
|
161
|
+
for l in range(depth):
|
|
162
|
+
# operator at encoder level l
|
|
163
|
+
hc = ho.SphericalStencil(current_nside,
|
|
164
|
+
self.KERNELSZ,
|
|
165
|
+
n_gauges = self.G,
|
|
166
|
+
gauge_type=self.gauge_type,
|
|
167
|
+
cell_ids=self.l_cell_ids[l],
|
|
168
|
+
dtype=self.torch_dtype)
|
|
169
|
+
|
|
170
|
+
self.hconv_enc.append(hc)
|
|
171
|
+
|
|
172
|
+
# downsample once to get next level ids and new data shape
|
|
173
|
+
l_data, next_ids = hc.Down(
|
|
174
|
+
l_data, cell_ids=self.l_cell_ids[l], nside=current_nside,max_poll=self.max_poll
|
|
175
|
+
)
|
|
176
|
+
self.l_cell_ids[l + 1] = self.f.backend.to_numpy(next_ids)
|
|
177
|
+
current_nside //= 2
|
|
178
|
+
enc_nsides.append(current_nside)
|
|
179
|
+
|
|
180
|
+
# encoder conv weights and BN
|
|
181
|
+
self.enc_w1 = nn.ParameterList()
|
|
182
|
+
self.enc_bn1 = nn.ModuleList()
|
|
183
|
+
self.enc_w2 = nn.ParameterList()
|
|
184
|
+
self.enc_bn2 = nn.ModuleList()
|
|
185
|
+
|
|
186
|
+
self.enc_nsides = enc_nsides # [in, in/2, ..., in/2**depth]
|
|
187
|
+
|
|
188
|
+
inC = self.n_chan_in
|
|
189
|
+
for l, outC in enumerate(self.chanlist):
|
|
190
|
+
if outC % self.G != 0:
|
|
191
|
+
raise ValueError(f"chanlist[{l}] = {outC} must be divisible by G={self.G}")
|
|
192
|
+
outC_g = outC // self.G
|
|
193
|
+
|
|
194
|
+
# conv1: inC -> outC (via multi-gauge => noyau (Ci, Co_g, P))
|
|
195
|
+
w1 = torch.empty(inC, outC_g, self.KERNELSZ * self.KERNELSZ)
|
|
196
|
+
nn.init.kaiming_uniform_(w1.view(inC * outC_g, -1), a=np.sqrt(5))
|
|
197
|
+
self.enc_w1.append(nn.Parameter(w1))
|
|
198
|
+
self.enc_bn1.append(self._norm_1d(outC, kind="group"))
|
|
199
|
+
|
|
200
|
+
# conv2: outC -> outC (entrée = total outC ; noyau (outC, outC_g, P))
|
|
201
|
+
w2 = torch.empty(outC, outC_g, self.KERNELSZ * self.KERNELSZ)
|
|
202
|
+
nn.init.kaiming_uniform_(w2.view(outC * outC_g, -1), a=np.sqrt(5))
|
|
203
|
+
self.enc_w2.append(nn.Parameter(w2))
|
|
204
|
+
self.enc_bn2.append(self._norm_1d(outC, kind="group"))
|
|
205
|
+
|
|
206
|
+
inC = outC # next layer sees total channels
|
|
207
|
+
|
|
208
|
+
# decoder conv weights and BN (mirrored levels)
|
|
209
|
+
self.dec_w1 = nn.ParameterList()
|
|
210
|
+
self.dec_bn1 = nn.ModuleList()
|
|
211
|
+
self.dec_w2 = nn.ParameterList()
|
|
212
|
+
self.dec_bn2 = nn.ModuleList()
|
|
213
|
+
|
|
214
|
+
for d in range(depth):
|
|
215
|
+
level = depth - 1 - d # encoder level we are going back to
|
|
216
|
+
hc = ho.SphericalStencil(self.enc_nsides[level],
|
|
217
|
+
self.KERNELSZ,
|
|
218
|
+
n_gauges = self.G,
|
|
219
|
+
gauge_type=self.gauge_type,
|
|
220
|
+
cell_ids=self.l_cell_ids[level],
|
|
221
|
+
dtype=self.torch_dtype)
|
|
222
|
+
#hc.make_idx_weights()
|
|
223
|
+
self.hconv_dec.append(hc)
|
|
224
|
+
|
|
225
|
+
upC = self.chanlist[level + 1] if level + 1 < depth else self.chanlist[level]
|
|
226
|
+
skipC = self.chanlist[level]
|
|
227
|
+
inC_dec = upC + skipC # total en entrée
|
|
228
|
+
outC_dec = skipC # total en sortie (ce que tu avais déjà)
|
|
229
|
+
|
|
230
|
+
if outC_dec % self.G != 0:
|
|
231
|
+
raise ValueError(f"decoder outC at level {level} = {outC_dec} must be divisible by G={self.G}")
|
|
232
|
+
outC_dec_g = outC_dec // self.G
|
|
233
|
+
|
|
234
|
+
w1 = torch.empty(inC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
|
|
235
|
+
nn.init.kaiming_uniform_(w1.view(inC_dec * outC_dec_g, -1), a=np.sqrt(5))
|
|
236
|
+
self.dec_w1.append(nn.Parameter(w1))
|
|
237
|
+
self.dec_bn1.append(self._norm_1d(outC_dec, kind="group"))
|
|
238
|
+
|
|
239
|
+
w2 = torch.empty(outC_dec, outC_dec_g, self.KERNELSZ * self.KERNELSZ)
|
|
240
|
+
nn.init.kaiming_uniform_(w2.view(outC_dec * outC_dec_g, -1), a=np.sqrt(5))
|
|
241
|
+
self.dec_w2.append(nn.Parameter(w2))
|
|
242
|
+
self.dec_bn2.append(self._norm_1d(outC_dec, kind="group"))
|
|
243
|
+
|
|
244
|
+
# Output head (on finest grid, channels = chanlist[0])
|
|
245
|
+
self.head_hconv = ho.SphericalStencil(self.in_nside,
|
|
246
|
+
self.KERNELSZ,
|
|
247
|
+
n_gauges=self.G, #Mandatory for the output
|
|
248
|
+
gauge_type=self.gauge_type,
|
|
249
|
+
cell_ids=self.l_cell_ids[0],
|
|
250
|
+
dtype=self.torch_dtype)
|
|
251
|
+
|
|
252
|
+
head_inC = self.chanlist[0]
|
|
253
|
+
if self.out_channels % self.G != 0:
|
|
254
|
+
raise ValueError(f"out_channels={self.out_channels} must be divisible by G={self.G}")
|
|
255
|
+
outC_head_g = self.out_channels // self.G
|
|
256
|
+
|
|
257
|
+
self.head_w = nn.Parameter(
|
|
258
|
+
torch.empty(head_inC, outC_head_g, self.KERNELSZ * self.KERNELSZ)
|
|
259
|
+
)
|
|
260
|
+
nn.init.kaiming_uniform_(self.head_w.view(head_inC * outC_head_g, -1), a=np.sqrt(5))
|
|
261
|
+
self.head_bn = self._norm_1d(self.out_channels, kind="group") if self.task == 'segmentation' else None
|
|
262
|
+
|
|
263
|
+
# Choose how to reduce across gauges at head:
|
|
264
|
+
# 'sum' (default), 'mean', or 'learned' (via 1x1 conv).
|
|
265
|
+
self.head_reduce = getattr(self, 'head_reduce', 'mean') # you can turn this into a ctor arg if you like
|
|
266
|
+
if self.head_reduce == 'learned':
|
|
267
|
+
# Mixer takes G*outC_head_g -> out_channels (K-wise 1x1)
|
|
268
|
+
self.head_mixer = nn.Conv1d(self.G * outC_head_g, self.out_channels, kernel_size=1, bias=True)
|
|
269
|
+
else:
|
|
270
|
+
self.head_mixer = None
|
|
271
|
+
|
|
272
|
+
# ---- Decide runtime device (probe Foscat on CUDA, else CPU) ----
|
|
273
|
+
self.runtime_device = self._probe_and_set_runtime_device(self.device)
|
|
274
|
+
|
|
275
|
+
# -------------------------- define local batchnorm/group -------------------
|
|
276
|
+
def _norm_1d(self, C: int, kind: str = "group", **kwargs) -> nn.Module:
|
|
277
|
+
"""
|
|
278
|
+
Return a normalization layer for (B, C, N) tensors.
|
|
279
|
+
kind: "group" | "instance" | "batch"
|
|
280
|
+
kwargs: extra args (e.g., num_groups for GroupNorm)
|
|
281
|
+
"""
|
|
282
|
+
if kind == "group":
|
|
283
|
+
num_groups = kwargs.get("num_groups", min(8, max(1, C // 8)) or 1)
|
|
284
|
+
# s’assurer que num_groups divise C
|
|
285
|
+
while C % num_groups != 0 and num_groups > 1:
|
|
286
|
+
num_groups //= 2
|
|
287
|
+
return nn.GroupNorm(num_groups=num_groups, num_channels=C)
|
|
288
|
+
elif kind == "instance":
|
|
289
|
+
return nn.InstanceNorm1d(C, affine=True, track_running_stats=False)
|
|
290
|
+
elif kind == "batch":
|
|
291
|
+
return nn.BatchNorm1d(C)
|
|
292
|
+
else:
|
|
293
|
+
raise ValueError(f"Unknown norm kind: {kind}")
|
|
294
|
+
|
|
295
|
+
# -------------------------- device plumbing --------------------------
|
|
296
|
+
def _move_hconv_tensors(self, hc: ho.SphericalStencil, device: torch.device) -> None:
|
|
297
|
+
"""Best-effort: move any torch.Tensor attribute of SphericalStencil to device."""
|
|
298
|
+
for name, val in list(vars(hc).items()):
|
|
299
|
+
try:
|
|
300
|
+
if torch.is_tensor(val):
|
|
301
|
+
setattr(hc, name, val.to(device))
|
|
302
|
+
elif isinstance(val, (list, tuple)) and val and torch.is_tensor(val[0]):
|
|
303
|
+
setattr(hc, name, type(val)([v.to(device) for v in val]))
|
|
304
|
+
except Exception:
|
|
305
|
+
# silently ignore non-tensor or protected attributes
|
|
306
|
+
pass
|
|
307
|
+
|
|
308
|
+
@torch.no_grad()
|
|
309
|
+
def _probe_and_set_runtime_device(self, preferred: torch.device) -> torch.device:
|
|
310
|
+
"""Try to run a tiny Foscat conv on preferred device; fallback to CPU if it fails."""
|
|
311
|
+
if preferred.type == 'cuda' and self.prefer_foscat_gpu:
|
|
312
|
+
try:
|
|
313
|
+
# move module params/buffers first
|
|
314
|
+
super().to(preferred)
|
|
315
|
+
# move Foscat operator internals
|
|
316
|
+
for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
|
|
317
|
+
self._move_hconv_tensors(hc, preferred)
|
|
318
|
+
# dry run on level 0
|
|
319
|
+
npix0 = int(len(self.l_cell_ids[0]))
|
|
320
|
+
x_try = torch.zeros(1, self.n_chan_in, npix0, device=preferred)
|
|
321
|
+
y_try = self.hconv_enc[0].Convol_torch(x_try, self.enc_w1[0])
|
|
322
|
+
# success -> stay on CUDA
|
|
323
|
+
self._foscat_device = preferred
|
|
324
|
+
return preferred
|
|
325
|
+
except Exception as e:
|
|
326
|
+
# fallback to CPU; keep error for info
|
|
327
|
+
self._gpu_probe_error = repr(e)
|
|
328
|
+
pass
|
|
329
|
+
# CPU fallback
|
|
330
|
+
cpu = torch.device('cpu')
|
|
331
|
+
super().to(cpu)
|
|
332
|
+
for hc in self.hconv_enc + self.hconv_dec + [self.head_hconv]:
|
|
333
|
+
self._move_hconv_tensors(hc, cpu)
|
|
334
|
+
self._foscat_device = cpu
|
|
335
|
+
return cpu
|
|
336
|
+
|
|
337
|
+
def set_device(self, device: torch.device | str) -> torch.device:
|
|
338
|
+
"""Request a (re)device; will probe Foscat and return the actual runtime device used."""
|
|
339
|
+
device = torch.device(device)
|
|
340
|
+
self.device = device
|
|
341
|
+
self.runtime_device = self._probe_and_set_runtime_device(device)
|
|
342
|
+
return self.runtime_device
|
|
343
|
+
|
|
344
|
+
# --- inside HealpixUNet class, add a single-sample forward helper ---
|
|
345
|
+
def _forward_one(self, x1: torch.Tensor, cell_ids1=None) -> torch.Tensor:
|
|
346
|
+
"""
|
|
347
|
+
Single-sample forward. x1: (1, C_in, Npix_1). Returns (1, out_channels, Npix_1).
|
|
348
|
+
`cell_ids1` can be None or a 1D array (Npix_1,) for this sample.
|
|
349
|
+
"""
|
|
350
|
+
if x1.dim() != 3 or x1.shape[0] != 1:
|
|
351
|
+
raise ValueError(f"_forward_one expects (1, C, Npix), got {tuple(x1.shape)}")
|
|
352
|
+
# Reuse existing forward by calling it with B=1 (your code already supports per-sample ids)
|
|
353
|
+
if cell_ids1 is None:
|
|
354
|
+
return super().forward(x1)
|
|
355
|
+
else:
|
|
356
|
+
# normalize ids to numpy 1D
|
|
357
|
+
if isinstance(cell_ids1, torch.Tensor):
|
|
358
|
+
cell_ids1 = cell_ids1.detach().cpu().numpy()
|
|
359
|
+
elif isinstance(cell_ids1, list):
|
|
360
|
+
cell_ids1 = np.asarray(cell_ids1)
|
|
361
|
+
if cell_ids1.ndim == 1:
|
|
362
|
+
ci = cell_ids1[None, :] # (1, Npix_1) so the current code path is happy
|
|
363
|
+
else:
|
|
364
|
+
ci = cell_ids1
|
|
365
|
+
return super().forward(x1, cell_ids=ci)
|
|
366
|
+
|
|
367
|
+
def _as_tensor_batch(self, x):
|
|
368
|
+
"""
|
|
369
|
+
Ensure a (B, C, N) tensor.
|
|
370
|
+
- If x is a list of tensors, concatenate if all N are equal.
|
|
371
|
+
- If len==1, keep a batch dim (1, C, N).
|
|
372
|
+
- If x is already a tensor, return as-is.
|
|
373
|
+
"""
|
|
374
|
+
if isinstance(x, list):
|
|
375
|
+
if len(x) == 1:
|
|
376
|
+
t = x[0]
|
|
377
|
+
# If t is (C, N) -> make it (1, C, N)
|
|
378
|
+
return t.unsqueeze(0) if t.dim() == 2 else t
|
|
379
|
+
# all same length -> concat along batch
|
|
380
|
+
Ns = [t.shape[-1] for t in x]
|
|
381
|
+
if all(n == Ns[0] for n in Ns):
|
|
382
|
+
return torch.cat([t if t.dim() == 3 else t.unsqueeze(0) for t in x], dim=0)
|
|
383
|
+
# variable-length with B>1 not supported in a single tensor
|
|
384
|
+
raise ValueError("Variable-length batch detected; use batch_size=1 or loop per-sample.")
|
|
385
|
+
return x
|
|
386
|
+
|
|
387
|
+
# --- replace your current `forward` signature/body with a dispatcher ---
|
|
388
|
+
def forward_any(self, x, cell_ids: Optional[np.ndarray] = None):
|
|
389
|
+
"""
|
|
390
|
+
If `x` is a Tensor (B,C,N): standard batched path (requires same N for all).
|
|
391
|
+
If `x` is a list of Tensors: variable-length per-sample path, returns a list of outputs.
|
|
392
|
+
"""
|
|
393
|
+
# Variable-length list path
|
|
394
|
+
if isinstance(x, (list, tuple)):
|
|
395
|
+
outs = []
|
|
396
|
+
if cell_ids is None or isinstance(cell_ids, (list, tuple)):
|
|
397
|
+
cids = cell_ids if isinstance(cell_ids, (list, tuple)) else [None] * len(x)
|
|
398
|
+
else:
|
|
399
|
+
raise ValueError("When x is a list, cell_ids must be a list of same length or None.")
|
|
400
|
+
|
|
401
|
+
for xb, cb in zip(x, cids):
|
|
402
|
+
if not torch.is_tensor(xb):
|
|
403
|
+
xb = torch.as_tensor(xb, dtype=torch.float32, device=self.runtime_device)
|
|
404
|
+
if xb.dim() == 2:
|
|
405
|
+
xb = xb.unsqueeze(0) # (1,C,Nb)
|
|
406
|
+
elif xb.dim() != 3 or xb.shape[0] != 1:
|
|
407
|
+
raise ValueError(f"Each sample must be (C,N) or (1,C,N); got {tuple(xb.shape)}")
|
|
408
|
+
|
|
409
|
+
yb = self._forward_one(xb.to(self.runtime_device), cell_ids1=cb) # (1,Co,Nb)
|
|
410
|
+
outs.append(yb.squeeze(0)) # -> (Co, Nb)
|
|
411
|
+
return outs # List[Tensor] (each length Nb)
|
|
412
|
+
|
|
413
|
+
# Fixed-length tensor path (your current implementation)
|
|
414
|
+
return super().forward(x, cell_ids=cell_ids)
|
|
415
|
+
|
|
416
|
+
# -------------------------- forward --------------------------
|
|
417
|
+
def forward(self, x: torch.Tensor,cell_ids: Optional[np.ndarray ] = None) -> torch.Tensor:
|
|
418
|
+
"""Forward pass.
|
|
419
|
+
|
|
420
|
+
Parameters
|
|
421
|
+
----------
|
|
422
|
+
x : torch.Tensor, shape (B, C_in, Npix)
|
|
423
|
+
Input tensor on `in_nside` grid.
|
|
424
|
+
cell_ids : np.ndarray (B, Npix) optional, use another cell_ids than the initial one.
|
|
425
|
+
if None use the initial cell_ids.
|
|
426
|
+
"""
|
|
427
|
+
if not isinstance(x, torch.Tensor):
|
|
428
|
+
raise TypeError("Input must be a torch.Tensor")
|
|
429
|
+
if x.dim() != 3:
|
|
430
|
+
raise ValueError("Input must be (B, C, Npix)")
|
|
431
|
+
|
|
432
|
+
# Ensure input lives on the runtime (probed) device
|
|
433
|
+
x = x.to(self.runtime_device)
|
|
434
|
+
|
|
435
|
+
B, C, N = x.shape
|
|
436
|
+
if C != self.n_chan_in:
|
|
437
|
+
raise ValueError(f"Expected {self.n_chan_in} input channels, got {C}")
|
|
438
|
+
|
|
439
|
+
# Encoder
|
|
440
|
+
skips: List[torch.Tensor] = []
|
|
441
|
+
l_data = x
|
|
442
|
+
current_nside = self.in_nside
|
|
443
|
+
l_cell_ids=cell_ids
|
|
444
|
+
|
|
445
|
+
if cell_ids is not None:
|
|
446
|
+
t_cell_ids={}
|
|
447
|
+
t_cell_ids[0]=l_cell_ids
|
|
448
|
+
else:
|
|
449
|
+
t_cell_ids=self.l_cell_ids
|
|
450
|
+
|
|
451
|
+
for l, outC in enumerate(self.chanlist):
|
|
452
|
+
# conv1 + BN + ReLU
|
|
453
|
+
l_data = self.hconv_enc[l].Convol_torch(l_data,
|
|
454
|
+
self.enc_w1[l],
|
|
455
|
+
cell_ids=l_cell_ids)
|
|
456
|
+
l_data = self._as_tensor_batch(l_data)
|
|
457
|
+
l_data = self.enc_bn1[l](l_data)
|
|
458
|
+
l_data = F.relu(l_data, inplace=True)
|
|
459
|
+
|
|
460
|
+
# conv2 + BN + ReLU
|
|
461
|
+
l_data = self.hconv_enc[l].Convol_torch(l_data,
|
|
462
|
+
self.enc_w2[l],
|
|
463
|
+
cell_ids=l_cell_ids)
|
|
464
|
+
l_data = self._as_tensor_batch(l_data)
|
|
465
|
+
l_data = self.enc_bn2[l](l_data)
|
|
466
|
+
l_data = F.relu(l_data, inplace=True)
|
|
467
|
+
|
|
468
|
+
# save skip at this resolution
|
|
469
|
+
skips.append(l_data)
|
|
470
|
+
|
|
471
|
+
# downsample (except bottom level) -> ensure output is on runtime_device
|
|
472
|
+
if l < len(self.chanlist) - 1:
|
|
473
|
+
l_data, l_cell_ids = self.hconv_enc[l].Down(
|
|
474
|
+
l_data, cell_ids=t_cell_ids[l], nside=current_nside,max_poll=self.max_poll
|
|
475
|
+
)
|
|
476
|
+
l_data = self._as_tensor_batch(l_data)
|
|
477
|
+
if cell_ids is not None:
|
|
478
|
+
t_cell_ids[l+1]=l_cell_ids
|
|
479
|
+
else:
|
|
480
|
+
l_cell_ids=None
|
|
481
|
+
|
|
482
|
+
if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
|
|
483
|
+
l_data = l_data.to(self.runtime_device)
|
|
484
|
+
current_nside //= 2
|
|
485
|
+
|
|
486
|
+
# Decoder
|
|
487
|
+
for d in range(len(self.chanlist)):
|
|
488
|
+
level = len(self.chanlist) - 1 - d # encoder level we are going back to
|
|
489
|
+
|
|
490
|
+
if level < len(self.chanlist) - 1:
|
|
491
|
+
# upsample: from encoder level (level+1) [coarser] -> level [finer]
|
|
492
|
+
src_nside = self.enc_nsides[level + 1] # coarse
|
|
493
|
+
|
|
494
|
+
# Use the **decoder** operator at this step (consistent with your hconv_dec stack)
|
|
495
|
+
l_data = self.hconv_dec[d].Up(
|
|
496
|
+
l_data,
|
|
497
|
+
cell_ids=t_cell_ids[level + 1], # source/coarse IDs
|
|
498
|
+
o_cell_ids=t_cell_ids[level], # target/fine IDs
|
|
499
|
+
nside=src_nside,
|
|
500
|
+
)
|
|
501
|
+
l_data = self._as_tensor_batch(l_data)
|
|
502
|
+
|
|
503
|
+
if isinstance(l_data, torch.Tensor) and l_data.device != self.runtime_device:
|
|
504
|
+
l_data = l_data.to(self.runtime_device)
|
|
505
|
+
|
|
506
|
+
# concat with skip features at this resolution
|
|
507
|
+
concat = self.f.backend.bk_concat([skips[level], l_data], 1)
|
|
508
|
+
l_data = concat.to(self.runtime_device) if torch.is_tensor(concat) else concat
|
|
509
|
+
|
|
510
|
+
# choose the right cell_ids for convolutions at this resolution
|
|
511
|
+
l_cell_ids = t_cell_ids[level] if (cell_ids is not None) else None
|
|
512
|
+
|
|
513
|
+
# apply decoder convs on this grid using the matching decoder operator
|
|
514
|
+
hc = self.hconv_dec[d]
|
|
515
|
+
l_data = hc.Convol_torch(l_data, self.dec_w1[d], cell_ids=l_cell_ids)
|
|
516
|
+
l_data = self._as_tensor_batch(l_data)
|
|
517
|
+
l_data = self.dec_bn1[d](l_data)
|
|
518
|
+
l_data = F.relu(l_data, inplace=True)
|
|
519
|
+
|
|
520
|
+
l_data = hc.Convol_torch(l_data, self.dec_w2[d], cell_ids=l_cell_ids)
|
|
521
|
+
l_data = self._as_tensor_batch(l_data)
|
|
522
|
+
l_data = self.dec_bn2[d](l_data)
|
|
523
|
+
l_data = F.relu(l_data, inplace=True)
|
|
524
|
+
|
|
525
|
+
# Head on finest grid
|
|
526
|
+
# y_head_raw: (B, G*outC_head_g, K)
|
|
527
|
+
y_head_raw = self.head_hconv.Convol_torch(l_data, self.head_w, cell_ids=l_cell_ids)
|
|
528
|
+
|
|
529
|
+
B, Ctot, K = y_head_raw.shape
|
|
530
|
+
outC_head_g = int(self.out_channels)//self.G
|
|
531
|
+
assert Ctot == self.G * outC_head_g, \
|
|
532
|
+
f"Head expects G*outC_head_g channels, got {Ctot} != {self.G}*{outC_head_g}"
|
|
533
|
+
|
|
534
|
+
if self.head_mixer is not None and self.head_reduce == 'learned':
|
|
535
|
+
# 1x1 learned mixing across G*outC_head_g -> out_channels
|
|
536
|
+
y = self.head_mixer(y_head_raw) # (B, out_channels, K)
|
|
537
|
+
else:
|
|
538
|
+
# reshape to (B, G, outC_head_g, K) then reduce across G
|
|
539
|
+
y_g = y_head_raw.view(B, self.G, outC_head_g, K)
|
|
540
|
+
|
|
541
|
+
y = y_g.mean(dim=1) # (B, outC_head_g, K)
|
|
542
|
+
|
|
543
|
+
y = self._as_tensor_batch(y)
|
|
544
|
+
|
|
545
|
+
# Optional BN + activation as before
|
|
546
|
+
if self.task == 'segmentation' and self.head_bn is not None:
|
|
547
|
+
y = self.head_bn(y)
|
|
548
|
+
|
|
549
|
+
if self.final_activation == 'sigmoid':
|
|
550
|
+
y = torch.sigmoid(y)
|
|
551
|
+
|
|
552
|
+
elif self.final_activation == 'softmax':
|
|
553
|
+
y = torch.softmax(y, dim=1)
|
|
554
|
+
|
|
555
|
+
return y
|
|
556
|
+
|
|
557
|
+
# -------------------------- utilities --------------------------
|
|
558
|
+
@torch.no_grad()
|
|
559
|
+
def predict(self, x: torch.Tensor, batch_size: int = 8,cell_ids: Optional[np.ndarray ] = None) -> torch.Tensor:
|
|
560
|
+
self.eval()
|
|
561
|
+
outs = []
|
|
562
|
+
if isinstance(x,np.ndarray):
|
|
563
|
+
x=self.to_Tensor(x)
|
|
564
|
+
|
|
565
|
+
if not isinstance(x, torch.Tensor):
|
|
566
|
+
for i in range(len(x)):
|
|
567
|
+
if cell_ids is not None:
|
|
568
|
+
outs.append(self.forward(x[i][None,:],cell_ids=cell_ids[i][:]))
|
|
569
|
+
else:
|
|
570
|
+
outs.append(self.forward(x[i][None,:]))
|
|
571
|
+
else:
|
|
572
|
+
for i in range(0, x.shape[0], batch_size):
|
|
573
|
+
if cell_ids is not None:
|
|
574
|
+
outs.append(self.forward(x[i : i + batch_size],
|
|
575
|
+
cell_ids=cell_ids[i : i + batch_size]))
|
|
576
|
+
else:
|
|
577
|
+
outs.append(self.forward(x[i : i + batch_size]))
|
|
578
|
+
|
|
579
|
+
return torch.cat(outs, dim=0)
|
|
580
|
+
|
|
581
|
+
def to_tensor(self,x):
|
|
582
|
+
return self.hconv_enc[0].f.backend.bk_cast(x)
|
|
583
|
+
|
|
584
|
+
def to_numpy(self,x):
|
|
585
|
+
if isinstance(x,np.ndarray):
|
|
586
|
+
return x
|
|
587
|
+
return x.cpu().numpy()
|
|
588
|
+
|
|
589
|
+
# -----------------------------
|
|
590
|
+
# Kernel extraction & plotting
|
|
591
|
+
# -----------------------------
|
|
592
|
+
def _arch_shapes(self):
|
|
593
|
+
"""Return expected (in_c, out_c) per conv for encoder/decoder.
|
|
594
|
+
|
|
595
|
+
Returns
|
|
596
|
+
-------
|
|
597
|
+
enc_shapes : list[tuple[tuple[int,int], tuple[int,int]]]
|
|
598
|
+
For each level `l`, ((in1, out1), (in2, out2)) for the two encoder convs.
|
|
599
|
+
dec_shapes : list[tuple[tuple[int,int], tuple[int,int]]]
|
|
600
|
+
For each level `l`, ((in1, out1), (in2, out2)) for the two decoder convs.
|
|
601
|
+
"""
|
|
602
|
+
nlayer = len(self.chanlist)
|
|
603
|
+
enc_shapes = []
|
|
604
|
+
l_chan = self.n_chan_in
|
|
605
|
+
for l in range(nlayer):
|
|
606
|
+
enc_shapes.append(((l_chan, self.chanlist[l]), (self.chanlist[l], self.chanlist[l])))
|
|
607
|
+
l_chan = self.chanlist[l] + 1
|
|
608
|
+
|
|
609
|
+
dec_shapes = []
|
|
610
|
+
l_chan = self.chanlist[-1] + 1
|
|
611
|
+
for l in range(nlayer):
|
|
612
|
+
in1 = l_chan + 1
|
|
613
|
+
out2 = 1 + (self.chanlist[nlayer - 1 - l] if (nlayer - 1 - l) > 0 else 0)
|
|
614
|
+
dec_shapes.append(((in1, in1), (in1, out2)))
|
|
615
|
+
l_chan = out2
|
|
616
|
+
return enc_shapes, dec_shapes
|
|
617
|
+
|
|
618
|
+
def extract_kernels(self, stage: str = "encoder", layer: int = 0, conv: int = 0):
|
|
619
|
+
"""Extract raw convolution kernels for a given stage/level/conv.
|
|
620
|
+
|
|
621
|
+
Parameters
|
|
622
|
+
----------
|
|
623
|
+
stage : {"encoder", "decoder"}
|
|
624
|
+
Which part of the network to inspect.
|
|
625
|
+
layer : int
|
|
626
|
+
Pyramid level (0 = finest encoder level / bottommost decoder level).
|
|
627
|
+
conv : int
|
|
628
|
+
0 for the first conv at that level, 1 for the second conv.
|
|
629
|
+
|
|
630
|
+
Returns
|
|
631
|
+
-------
|
|
632
|
+
np.ndarray
|
|
633
|
+
Array of shape (in_c, out_c, K, K) containing the spatial kernels.
|
|
634
|
+
"""
|
|
635
|
+
assert stage in {"encoder", "decoder"}
|
|
636
|
+
assert conv in {0, 1}
|
|
637
|
+
K = self.KERNELSZ
|
|
638
|
+
enc_shapes, dec_shapes = self._arch_shapes()
|
|
639
|
+
|
|
640
|
+
if stage == "encoder":
|
|
641
|
+
if conv==0:
|
|
642
|
+
w = self.enc_w1[layer]
|
|
643
|
+
else:
|
|
644
|
+
w = self.enc_w2[layer]
|
|
645
|
+
else:
|
|
646
|
+
if conv==0:
|
|
647
|
+
w = self.dec_w1[layer]
|
|
648
|
+
else:
|
|
649
|
+
w = self.dec_w2[layer]
|
|
650
|
+
|
|
651
|
+
w_np = self.f.backend.to_numpy(w.detach())
|
|
652
|
+
return w_np.reshape(w.shape[0],w.shape[1],K,K)
|
|
653
|
+
|
|
654
|
+
def plot_kernels(
    self,
    stage: str = "encoder",
    layer: int = 0,
    conv: int = 0,
    fixed: str = "in",
    index: int = 0,
    max_tiles: int = 16,
):
    """Quick visualization of kernels on a grid using matplotlib.

    Parameters
    ----------
    stage : {"encoder", "decoder"}
        Which tower to visualize.
    layer : int
        Level to visualize.
    conv : int
        0 or 1: first or second conv in the level.
    fixed : {"in", "out"}
        If "in", show kernels for a fixed input channel across many outputs.
        If "out", show kernels for a fixed output channel across many inputs.
    index : int
        Channel index to fix (according to `fixed`). Values past the last
        available channel are clamped to the last channel.
    max_tiles : int
        Maximum number of tiles to display (must be >= 1).

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn (it is also displayed with ``plt.show()``).

    Raises
    ------
    ValueError
        If ``fixed`` is not "in"/"out", if ``max_tiles`` < 1, or if there are
        no kernels to display.
    """
    # Validate cheap arguments first, before importing matplotlib or
    # extracting weights.
    if fixed not in ("in", "out"):
        raise ValueError(f"fixed must be 'in' or 'out', got {fixed!r}")
    if max_tiles < 1:
        raise ValueError(f"max_tiles must be >= 1, got {max_tiles}")

    import math
    import matplotlib.pyplot as plt

    # W has shape (in_c, out_c, K, K), as returned by extract_kernels.
    W = self.extract_kernels(stage=stage, layer=layer, conv=conv)
    ic, oc = W.shape[0], W.shape[1]

    if fixed == "in":
        idx = min(index, ic - 1)
        tiles = [W[idx, j] for j in range(oc)]
        title = f"{stage} L{layer} C{conv} | in={idx}"
    else:
        idx = min(index, oc - 1)
        tiles = [W[i, idx] for i in range(ic)]
        title = f"{stage} L{layer} C{conv} | out={idx}"

    tiles = tiles[:max_tiles]
    n = len(tiles)
    if n == 0:
        # Guards the grid computation below against a division by zero.
        raise ValueError("no kernels to display (zero channels)")

    # Near-square grid large enough to hold n tiles.
    cols = int(math.ceil(math.sqrt(n)))
    rows = int(math.ceil(n / cols))

    fig = plt.figure(figsize=(2.5 * cols, 2.5 * rows))
    for i, ker in enumerate(tiles, 1):
        ax = fig.add_subplot(rows, cols, i)
        ax.imshow(ker)
        ax.set_xticks([])
        ax.set_yticks([])
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()
    return fig
|
|
710
|
+
|
|
711
|
+
# -----------------------------
|
|
712
|
+
# Unit tests (smoke tests)
|
|
713
|
+
# -----------------------------
|
|
714
|
+
# Run with: python -m foscat.healpix_unet_torch (add -q for quieter output)
|
|
715
|
+
# These tests assume Foscat and its dependencies are installed.
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
def _dummy_cell_ids(nside: int) -> np.ndarray:
|
|
719
|
+
"""Return a simple identity mapping for HEALPix nested pixel IDs.
|
|
720
|
+
|
|
721
|
+
Notes
|
|
722
|
+
-----
|
|
723
|
+
Replace with your pipeline's real `cell_ids` if you have a precomputed
|
|
724
|
+
mapping consistent with Foscat/HEALPix nested ordering.
|
|
725
|
+
"""
|
|
726
|
+
return np.arange(12 * nside * nside, dtype=np.int64)
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
if __name__ == "__main__":
    import unittest

    class TestUNET(unittest.TestCase):
        """Lightweight smoke tests for shape and parameter plumbing.

        NOTE(review): these tests assume a ``UNET`` class exposing the Foscat
        ``eval`` / ``get_param`` / ``set_param`` API and an ``f.backend``
        attribute is in scope here — TODO confirm it is defined or imported
        in this module.
        """

        def setUp(self):
            self.nside = 4  # small grid for fast tests (npix = 192)
            self.chanlist = [4, 8]  # two-level encoder/decoder
            self.batch = 2
            self.channels = 1
            # numpy dtype of the random test inputs. Both tests read
            # ``self.np_dtype``, so it must be defined here.
            self.np_dtype = np.float32
            self.npix = 12 * self.nside * self.nside
            self.cell_ids = _dummy_cell_ids(self.nside)
            self.net = UNET(
                in_nside=self.nside,
                n_chan_in=self.channels,
                chanlist=self.chanlist,
                cell_ids=self.cell_ids,
            )

        def test_forward_shape(self):
            # random input
            x = np.random.randn(self.batch, self.channels, self.npix).astype(self.np_dtype)
            x = self.net.f.backend.bk_cast(x)
            y = self.net.eval(x)
            # expected output: same npix, 1 channel at the very top
            self.assertEqual(y.shape[0], self.batch)
            self.assertEqual(y.shape[1], 1)
            self.assertEqual(y.shape[2], self.npix)
            # sanity: no NaNs
            y_np = self.net.f.backend.to_numpy(y)
            self.assertFalse(np.isnan(y_np).any())

        def test_param_roundtrip_and_determinism(self):
            x = np.random.randn(self.batch, self.channels, self.npix).astype(self.np_dtype)
            x = self.net.f.backend.bk_cast(x)

            # forward twice -> identical outputs with fixed params
            y1 = self.net.eval(x)
            y2 = self.net.eval(x)
            y1_np = self.net.f.backend.to_numpy(y1)
            y2_np = self.net.f.backend.to_numpy(y2)
            np.testing.assert_allclose(y1_np, y2_np, rtol=0, atol=0)

            # perturb parameters -> output should (very likely) change
            p = self.net.get_param()
            p_np = self.net.f.backend.to_numpy(p).copy()
            if p_np.size > 0:
                p_np[0] += 1.0
                self.net.set_param(p_np)
                y3 = self.net.eval(x)
                y3_np = self.net.f.backend.to_numpy(y3)
                with self.assertRaises(AssertionError):
                    np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)

    unittest.main()
|
|
785
|
+
|
|
786
|
+
from torch.utils.data import Dataset
|
|
787
|
+
# 1) Dataset that omits cell_ids when None
|
|
788
|
+
from torch.utils.data import Dataset, DataLoader, TensorDataset
|
|
789
|
+
|
|
790
|
+
class HealpixDataset(Dataset):
    """
    Per-sample view over fixed-grid HEALPix arrays.

    Items are ``(x, y, cell_ids)`` when ``cell_ids`` was supplied, otherwise
    ``(x, y)``.

    Shapes:
        x: (C, Npix) per item; x and y are both cast to ``dtype``
        y: (C_out or 1, Npix) per item
        cell_ids: (Npix,) per item; a single (Npix,) vector is shared by all samples

    NOTE: ``HealpixDataset`` is defined again later in this module; the later
    definition is the one bound when the module finishes importing.
    """

    def __init__(self, x, y, cell_ids=None, dtype=torch.float32):
        self.x = torch.as_tensor(x, dtype=dtype)
        self.y = torch.as_tensor(y, dtype=dtype)
        n_samples = self.x.shape[0]
        assert n_samples == self.y.shape[0], "x and y must share batch size"

        self._has_cids = cell_ids is not None
        self.cids = None
        if not self._has_cids:
            return

        ids = torch.as_tensor(cell_ids, dtype=torch.long)
        if ids.dim() == 1:
            # one shared grid -> (B, Npix) view without copying
            ids = ids.unsqueeze(0).expand(n_samples, -1)
        assert ids.shape[0] == n_samples, "cell_ids must match batch size"
        self.cids = ids

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, i):
        sample = (self.x[i], self.y[i])
        return sample + (self.cids[i],) if self._has_cids else sample
|
|
820
|
+
|
|
821
|
+
# ---------------------------
|
|
822
|
+
# Datasets / Collate helpers
|
|
823
|
+
# ---------------------------
|
|
824
|
+
|
|
825
|
+
class HealpixDataset(Dataset):
    """
    Map-style dataset over HEALPix maps that all share one pixel grid.

    Each item is ``(x, y)``, or ``(x, y, cell_ids)`` when ``cell_ids`` is given.

    x: (B, C, Npix), cast to ``dtype``
    y: (B, C_out or 1, Npix) cast to ``dtype``; any other rank (e.g. class
       indices) is cast to int64
    cell_ids: (Npix,) shared by all samples, or (B, Npix) per sample

    NOTE: ``HealpixDataset`` is defined again later in this module; that later
    definition is the one bound at import time.
    """

    def __init__(self,
                 x: torch.Tensor,
                 y: torch.Tensor,
                 cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 dtype: torch.dtype = torch.float32):
        inputs = torch.as_tensor(x, dtype=dtype)
        target_dtype = dtype if y.ndim == 3 else torch.long
        targets = torch.as_tensor(y, dtype=target_dtype)
        n_samples = inputs.shape[0]
        assert n_samples == targets.shape[0], "x and y must share batch size"
        self.x = inputs
        self.y = targets

        self._has_cids = cell_ids is not None
        self.cids = None
        if self._has_cids:
            ids = torch.as_tensor(cell_ids, dtype=torch.long)
            if ids.ndim == 1:
                # one shared grid: view it as (B, Npix) without copying
                ids = ids.unsqueeze(0).expand(n_samples, -1)
            assert ids.shape == (n_samples, inputs.shape[2]), "cell_ids must be (B,Npix) or (Npix,)"
            self.cids = ids

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i: int):
        x_i, y_i = self.x[i], self.y[i]
        if not self._has_cids:
            return x_i, y_i
        return x_i, y_i, self.cids[i]
|
|
859
|
+
|
|
860
|
+
# ---------------------------
|
|
861
|
+
# Datasets / Collate helpers
|
|
862
|
+
# ---------------------------
|
|
863
|
+
|
|
864
|
+
class HealpixDataset(Dataset):
    """
    Fixed-grid dataset (common Npix for all samples).
    Returns (x, y) if cell_ids is None, else (x, y, cell_ids).

    x: (B, C, Npix), cast to ``dtype``
    y: (B, C_out or 1, Npix) cast to ``dtype``; any other rank (e.g. (B, Npix)
       class indices for CrossEntropyLoss) is cast to int64.
       Tensors, numpy arrays and nested Python lists are all accepted.
    cell_ids: (Npix,) shared by every sample, or (B, Npix)
    """
    def __init__(self,
                 x: torch.Tensor,
                 y: torch.Tensor,
                 cell_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
                 dtype: torch.dtype = torch.float32):
        x = torch.as_tensor(x, dtype=dtype)
        # Convert y before inspecting its rank: plain Python lists have no
        # `.ndim`. The two-step conversion gives the same values as a direct
        # `as_tensor(y, dtype=...)` for arrays and tensors.
        y = torch.as_tensor(y)
        y = y.to(dtype if y.ndim == 3 else torch.long)
        assert x.shape[0] == y.shape[0], "x and y must share batch size"
        self.x, self.y = x, y
        self._has_cids = cell_ids is not None
        if self._has_cids:
            # The shape check below reads x.shape[2]; fail with a clear message.
            assert x.ndim == 3, "x must be (B, C, Npix) when cell_ids is given"
            c = torch.as_tensor(cell_ids, dtype=torch.long)
            if c.ndim == 1:  # broadcast single (Npix,) to (B, Npix) as a view
                c = c.unsqueeze(0).expand(x.shape[0], -1)
            assert c.shape == (x.shape[0], x.shape[2]), "cell_ids must be (B,Npix) or (Npix,)"
            self.cids = c
        else:
            self.cids = None

    def __len__(self) -> int:
        return self.x.shape[0]

    def __getitem__(self, i: int):
        if self._has_cids:
            return self.x[i], self.y[i], self.cids[i]
        return self.x[i], self.y[i]
|
|
898
|
+
|
|
899
|
+
|
|
900
|
+
class VarLenHealpixDataset(Dataset):
    """
    Dataset whose samples may each live on a different HEALPix grid.

    x_list[b]: (C, Npix_b) or (1, C, Npix_b), cast to ``dtype``
    y_list[b]: (C_out or 1, Npix_b) or (1, C_out, Npix_b) for regression /
        segmentation targets, or integer class indices (Npix_b,) / (1, Npix_b)
        for multi-class segmentation with CrossEntropyLoss. Targets keep the
        dtype they were given in.
    cids_list[b]: (Npix_b,) cell ids (stored as int64 tensors), or None

    Items are always 3-tuples ``(x, y, cell_ids_or_None)``.

    NOTE: ``VarLenHealpixDataset`` is defined again later in this module; the
    later definition is the one bound at import time.
    """

    def __init__(self,
                 x_list: List[Union[np.ndarray, torch.Tensor]],
                 y_list: List[Union[np.ndarray, torch.Tensor]],
                 cids_list: Optional[List[Union[np.ndarray, torch.Tensor]]] = None,
                 dtype: torch.dtype = torch.float32):
        assert len(x_list) == len(y_list), "x_list and y_list must have the same length"
        self.x = [torch.as_tensor(sample, dtype=dtype) for sample in x_list]
        # No dtype given: float targets stay float, integer labels stay integer.
        self.y = [torch.as_tensor(target) for target in y_list]
        self.c = None
        if cids_list is not None:
            assert len(cids_list) == len(x_list), "cids_list must match x_list length"
            self.c = [torch.as_tensor(ids, dtype=torch.long) for ids in cids_list]

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i: int):
        if self.c is None:
            return self.x[i], self.y[i], None
        return self.x[i], self.y[i], self.c[i]
|
|
930
|
+
|
|
931
|
+
from torch.utils.data import Dataset, DataLoader
|
|
932
|
+
|
|
933
|
+
class VarLenHealpixDataset(Dataset):
    """
    Dataset of samples that may each sit on their own HEALPix grid.

    x_list: sequence of (C, Npix_b) tensors/arrays, cast to ``dtype``
    y_list: sequence of (C_out or 1, Npix_b) tensors/arrays, also cast to
        ``dtype`` (so integer class labels become floating point here)
    cids_list: optional sequence of (Npix_b,) cell-id arrays, kept as numpy

    Items are always 3-tuples ``(x, y, cell_ids_or_None)``.
    """

    def __init__(self, x_list, y_list, cids_list=None, dtype=torch.float32):
        assert len(x_list) == len(y_list)
        self.x = [torch.as_tensor(sample, dtype=dtype) for sample in x_list]
        # NOTE(review): class-index targets are cast to `dtype` too; a
        # CrossEntropyLoss consumer must convert them back to int64.
        self.y = [torch.as_tensor(target, dtype=dtype) for target in y_list]
        self.c = None
        if cids_list is not None:
            assert len(cids_list) == len(x_list)
            self.c = [np.asarray(ids) for ids in cids_list]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        ids = None if self.c is None else self.c[i]
        return self.x[i], self.y[i], ids
|
|
954
|
+
|
|
955
|
+
def varlen_collate(batch):
    """Collate variable-length samples into parallel lists without stacking.

    ``batch`` is a sequence of ``(x, y, cell_ids)`` triples. The result is
    ``(xs, ys, cs)`` where ``cs`` is None when no sample carries cell ids.
    """
    columns = zip(*batch)  # transpose B triples into 3 tuples of length B
    xs, ys, cs = columns
    c_out = list(cs) if any(c is not None for c in cs) else None
    return list(xs), list(ys), c_out
|
|
961
|
+
|
|
962
|
+
def varlen_collate(batch):
    """
    Collate variable-length samples: keep per-sample lists, never stack.

    Returns ``(xs, ys, cs)`` as lists; ``cs`` collapses to None if every
    sample's cell ids are None.
    """
    x_col, y_col, c_col = zip(*batch)
    has_ids = not all(c is None for c in c_col)
    return list(x_col), list(y_col), (list(c_col) if has_ids else None)
|
|
970
|
+
|
|
971
|
+
|
|
972
|
+
# ---------------------------
|
|
973
|
+
# Training function
|
|
974
|
+
# ---------------------------
|
|
975
|
+
|
|
976
|
+
def fit(
    model,
    x_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    y_train: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
    *,
    cell_ids_train: Optional[Union[np.ndarray, torch.Tensor, List[Union[np.ndarray, torch.Tensor]]]] = None,
    n_epoch: int = 10,
    view_epoch: int = 10,
    batch_size: int = 16,
    lr: float = 1e-3,
    weight_decay: float = 0.0,
    clip_grad_norm: Optional[float] = None,
    verbose: bool = True,
    optimizer: Literal['ADAM', 'LBFGS'] = 'ADAM',
) -> dict:
    """
    Train ``model`` and return its per-pass loss history.

    Two input layouts are supported:

    - Fixed-grid tensors/arrays: x (B, C, N), with optional cell_ids of shape
      (B, N) or (N,).
    - Variable-length lists: x = [(C, N_b) or (1, C, N_b)], y = [...],
      cell_ids = [(N_b,)]; each sample is forwarded on its own grid.

    Optimizers
    ----------
    'ADAM'  : one Adam step per mini-batch.
    'LBFGS' : one LBFGS step per mini-batch, driven by a closure. Logical
              "epochs" are emulated as ``max(1, n_epoch // 20)`` outer passes
              of 20 inner passes over the data.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(x)`` or ``model(x, cell_ids=...)``. The attributes
        ``task`` ('regression'; anything else means segmentation),
        ``out_channels``, ``final_activation`` and ``runtime_device`` are read
        when present.
    x_train, y_train, cell_ids_train
        Training data in one of the layouts above.
    n_epoch : int
        Number of epochs (see the LBFGS note above).
    view_epoch : int
        When ``verbose``, print the loss after the first pass and every
        ``view_epoch`` passes. Values <= 0 only print the first pass.
    batch_size, lr, weight_decay
        DataLoader batch size and optimizer hyper-parameters
        (``weight_decay`` is only used by Adam).
    clip_grad_norm : float, optional
        Maximum gradient norm, applied before each Adam step.
    verbose : bool
        Enable progress printing.
    optimizer : {'ADAM', 'LBFGS'}
        Anything other than 'ADAM' (case-insensitive) selects LBFGS.

    Returns
    -------
    dict
        ``{"loss": [mean training loss of each pass over the data]}``.

    Notes
    -----
    - Multi-class segmentation (``out_channels > 1``): pass integer class
      targets. Fixed-grid targets are (B, N); variable-length y[b] is (N_b,)
      or (1, N_b). Fixed-grid (B, C, N) targets are kept as float
      (one-hot / class-probability targets for CrossEntropyLoss).
    - Regression / binary segmentation: pass float targets with the same
      (C_out, N) channel layout as the model output.
    - In variable-length mode the batch loss is the mean of the per-sample
      losses, and the gradients are those of that mean.
    """
    device = model.runtime_device if hasattr(model, "runtime_device") else (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    model.to(device)

    task = getattr(model, "task", "regression")
    out_channels = getattr(model, "out_channels", 1)

    # Lists/tuples of per-sample arrays select variable-length mode.
    varlen_mode = isinstance(x_train, (list, tuple))

    # ----- Build DataLoader
    if not varlen_mode:
        # Fixed-grid path
        x_t = torch.as_tensor(x_train, dtype=torch.float32, device=device)
        y_is_class = task != 'regression' and out_channels > 1
        # Rank of the raw targets, for tensors, numpy arrays and nested lists alike.
        y_ndim = y_train.ndim if hasattr(y_train, "ndim") else np.ndim(y_train)
        # (B, N) class indices must be int64 for CrossEntropyLoss; (B, C, N)
        # targets (regression, one-hot or soft labels) stay float.
        y_dtype = torch.long if (y_is_class and y_ndim != 3) else torch.float32
        y_t = torch.as_tensor(y_train, dtype=y_dtype, device=device)

        if cell_ids_train is None:
            ds = TensorDataset(x_t, y_t)
            with_cell_ids = False
        else:
            ds = HealpixDataset(x_t, y_t, cell_ids=cell_ids_train, dtype=torch.float32)
            with_cell_ids = True
        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False)
    else:
        # Variable-length path
        ds = VarLenHealpixDataset(x_train, y_train, cids_list=cell_ids_train, dtype=torch.float32)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False,
                            collate_fn=varlen_collate)
        with_cell_ids = cell_ids_train is not None

    # ----- Loss
    if task == 'regression':
        criterion = nn.MSELoss(reduction='mean')
        seg_multiclass = False
    elif out_channels == 1:
        # Binary segmentation: the head is assumed to return logits when
        # final_activation == 'none', probabilities otherwise.
        criterion = (nn.BCEWithLogitsLoss()
                     if getattr(model, "final_activation", "none") == 'none'
                     else nn.BCELoss())
        seg_multiclass = False
    else:
        criterion = nn.CrossEntropyLoss()
        seg_multiclass = True

    # ----- Optimizer
    if optimizer.upper() == 'ADAM':
        optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        outer = n_epoch
        inner = 1
    else:
        optim = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=20,
                                  history_size=max(10, n_epoch * 5), line_search_fn="strong_wolfe")
        # emulate "epochs" with multiple inner LBFGS passes
        outer = max(1, n_epoch // 20)
        inner = 20

    def _forward(xb, cb_np):
        """Run the model, passing cell_ids only when they are available."""
        return model(xb, cell_ids=cb_np) if cb_np is not None else model(xb)

    def _prep_sample(xs, ys, cs, i):
        """Return sample i of a variable-length batch as (x, y, cell_ids) with batch dim 1."""
        # x_i : (C, N_i) -> (1, C, N_i)
        xb = torch.as_tensor(xs[i], device=device, dtype=torch.float32)
        if xb.dim() == 2:
            xb = xb.unsqueeze(0)
        elif xb.dim() != 3 or xb.shape[0] != 1:
            raise ValueError("Each x[i] must be (C,N) or (1,C,N)")

        yb = torch.as_tensor(ys[i], device=device)
        if seg_multiclass:
            # Class indices (N_i,) or (1, N_i). CrossEntropyLoss pairs an input
            # of shape (1, C_out, N_i) with an int64 target of shape (1, N_i).
            if yb.dim() == 1:
                yb = yb.unsqueeze(0)
            elif yb.dim() != 2 or yb.shape[0] != 1:
                raise ValueError("For multiclass CE, y[i] must be (N,) or (1,N)")
            yb = yb.long()
        else:
            # regression / binary: target of shape (1, C_out, N_i)
            if yb.dim() == 2:
                yb = yb.unsqueeze(0)
            elif yb.dim() != 3 or yb.shape[0] != 1:
                raise ValueError("For regression/binary, y[i] must be (C_out,N) or (1,C_out,N)")
            yb = yb.to(torch.float32)

        # cell_ids : (N_i,) -> (1, N_i) as numpy (the forward expects np.ndarray)
        c = None if cs is None else cs[i]
        if c is None:
            cb_np = None
        else:
            c = c.detach().cpu().numpy() if torch.is_tensor(c) else np.asarray(c)
            cb_np = c[None, :] if c.ndim == 1 else c
        return xb, yb, cb_np

    def _varlen_batch_loss(xs, ys, cs, backward):
        """Mean per-sample loss over a variable-length batch.

        When ``backward`` is true, gradients of that *mean* are accumulated.
        The scaling by 1/B keeps the LBFGS line search, which sees the
        returned mean, consistent with the gradients.
        """
        n = max(1, len(xs))
        total = 0.0
        for i in range(len(xs)):
            xb, yb, cb_np = _prep_sample(xs, ys, cs, i)
            loss_i = criterion(_forward(xb, cb_np), yb)
            if backward:
                (loss_i / n).backward()
            total += float(loss_i.item())
        return total / n

    # ----- Training loop
    history: List[float] = []
    model.train()

    for _epoch in range(outer):
        for _ in range(inner):
            epoch_loss, n_samples = 0.0, 0

            for batch in loader:
                if not varlen_mode:
                    # -------- fixed-grid
                    if with_cell_ids:
                        xb, yb, cb = batch
                        cb_np = cb.detach().cpu().numpy()
                    else:
                        xb, yb = batch
                        cb_np = None

                    xb = xb.to(device, dtype=torch.float32, non_blocking=True)
                    # y type: float for regression/binary, long for CrossEntropy
                    yb = yb.to(device, non_blocking=True)

                    if isinstance(optim, torch.optim.LBFGS):
                        def closure():
                            optim.zero_grad(set_to_none=True)
                            loss = criterion(_forward(xb, cb_np), yb)
                            loss.backward()
                            return loss
                        optim.step(closure)
                        with torch.no_grad():
                            loss = criterion(_forward(xb, cb_np), yb)
                    else:
                        optim.zero_grad(set_to_none=True)
                        loss = criterion(_forward(xb, cb_np), yb)
                        loss.backward()
                        if clip_grad_norm is not None:
                            nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                        optim.step()

                    bs = xb.shape[0]
                    epoch_loss += float(loss.item()) * bs
                    n_samples += bs

                else:
                    # -------- variable-length (lists)
                    xs, ys, cs = batch

                    if isinstance(optim, torch.optim.LBFGS):
                        def closure():
                            optim.zero_grad(set_to_none=True)
                            mean_loss = _varlen_batch_loss(xs, ys, cs, backward=True)
                            # LBFGS only needs a scalar tensor value here
                            return torch.tensor(mean_loss, device=device, dtype=torch.float32)

                        optim.step(closure)
                        # logging (no grad)
                        with torch.no_grad():
                            loss_val = _varlen_batch_loss(xs, ys, cs, backward=False)
                    else:
                        optim.zero_grad(set_to_none=True)
                        loss_val = _varlen_batch_loss(xs, ys, cs, backward=True)
                        if clip_grad_norm is not None:
                            nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                        optim.step()

                    n_b = max(1, len(xs))
                    epoch_loss += loss_val * n_b
                    n_samples += n_b

            epoch_loss /= max(1, n_samples)
            history.append(epoch_loss)
            # print after the first pass and every `view_epoch` passes
            step = len(history)
            if verbose and (step == 1 or (view_epoch > 0 and step % view_epoch == 0)):
                print(f"[epoch {step}] loss={epoch_loss:.6f}")

    return {"loss": history}
|