foscat-2025.8.4-py3-none-any.whl → foscat-2025.9.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +309 -50
- foscat/FoCUS.py +74 -267
- foscat/HOrientedConvol.py +517 -130
- foscat/HealBili.py +309 -0
- foscat/Plot.py +331 -0
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +470 -179
- foscat/healpix_unet_torch.py +1202 -0
- foscat/scat_cov.py +3 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/METADATA +1 -1
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/RECORD +14 -10
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/WHEEL +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.8.4.dist-info → foscat-2025.9.3.dist-info}/top_level.txt +0 -0
foscat/UNET.py
CHANGED
@@ -1,200 +1,491 @@
+"""
+UNET for HEALPix (nested) using Foscat oriented convolutions.
+
+This module defines a lightweight, U-Net–like encoder/decoder that operates on
+signals defined on the HEALPix sphere (nested scheme). It leverages Foscat's
+`HOrientedConvol` for orientation-aware convolutions and `funct` utilities for
+upgrade/downgrade (change of nside) operations.
+
+Key design choices
+------------------
+• **Flat parameter vector**: all convolution kernels are stored in a single
+  1-D vector `self.x`. The dictionaries `self.wconv` and `self.t_wconv` map
+  layer indices to *offsets* within that vector.
+• **HEALPix-aware down/up-sampling**: down-sampling uses
+  `self.f.ud_grade_2`, and up-sampling uses `self.f.up_grade`, both with
+  per-level `cell_ids` to preserve locality and orientation information.
+• **Skip connections**: U-Net skip connections are implemented by concatenating
+  encoder features with downgraded/upsampled paths along the channel axis.
+
+Shape convention
+----------------
+All tensors follow the Foscat backend shape `(batch, channels, npix)`.
+
+Dependencies
+------------
+- foscat.scat_cov as `sc`
+- foscat.SphericalStencil as `hs`
+
+Example
+-------
+>>> import numpy as np
+>>> from UNET import UNET
+>>> nside = 8
+>>> npix = 12 * nside * nside
+>>> # Your backend tensor should be created via the foscat backend; here we
+>>> # show a placeholder np.array.
+>>> x = np.random.randn(1, 1, npix).astype(np.float32)
+>>> # cell_ids should be provided for the highest resolution (nside)
+>>> # and must be consistent with the nested scheme expected by Foscat.
+>>> # Example placeholder (use the real one from your pipeline):
+>>> cell_ids = np.arange(npix, dtype=np.int64)
+>>> net = UNET(in_nside=nside, n_chan_in=1, cell_ids=cell_ids)
+>>> y = net.eval(net.f.backend.bk_cast(x))  # forward pass
+
+Notes
+-----
+- This implementation assumes `cell_ids` is provided for the input resolution
+  `in_nside`. It propagates/derives the coarser `cell_ids` across levels.
+- Some constructor parameters are reserved for future use (see docstring).
+"""
+
+from typing import Dict, Optional
 import numpy as np
 
 import foscat.scat_cov as sc
-import foscat.
+import foscat.SphericalStencil as hs
+
 
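
The flat-vector bookkeeping described above can be pictured with plain NumPy.
This is an illustrative sketch only: the real `self.x` is a Foscat backend
tensor, and the offsets come from `self.wconv` / `self.t_wconv`:

>>> import numpy as np
>>> k2 = 3 * 3                                # KERNELSZ ** 2
>>> flat = np.zeros(1 * 4 * k2 + 4 * 4 * k2)  # two kernels packed end to end
>>> off = {0: 0, 1: 1 * 4 * k2}               # offsets, as stored in wconv
>>> w0 = flat[off[0] : off[0] + 1 * 4 * k2].reshape(1, 4, k2)
>>> w1 = flat[off[1] : off[1] + 4 * 4 * k2].reshape(4, 4, k2)
>>> w0.shape, w1.shape
((1, 4, 9), (4, 4, 9))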
 class UNET:
+    """U-Net–like network on HEALPix (nested) using Foscat oriented convolutions.
+
+    The network is built as an encoder/decoder (down/upsampling) tower. Each
+    level performs two oriented convolutions. All kernels are packed in a flat
+    parameter vector `self.x` to simplify optimization with external solvers.
+
+    Parameters
+    ----------
+    nparam : int, optional
+        Reserved for future use. Currently unused.
+    KERNELSZ : int, optional
+        Spatial kernel size (k × k) used by oriented convolutions. Default is 3.
+    NORIENT : int, optional
+        Reserved for future use (number of orientations). Currently unused.
+    chanlist : Optional[list[int]], optional
+        Number of output channels per encoder level. If ``None``, it defaults to
+        ``[4 * 2**k for k in range(log2(in_nside))]``. The length of this list
+        defines the number of encoder/decoder levels.
+    in_nside : int, optional
+        Input HEALPix nside. Must be a power of two for the implicit
+        ``log2(in_nside)`` depth when ``chanlist`` is not given.
+    n_chan_in : int, optional
+        Number of input channels at the finest resolution. Default is 1.
+    cell_ids : array-like of int, required
+        Pixel identifiers at the input resolution (nested indexing). They are
+        used to build oriented convolutions and to derive coarser grids.
+        **Must not be ``None``.**
+    SEED : int, optional
+        Reserved for future use (random initialization seed). Currently unused.
+    filename : Optional[str], optional
+        Reserved for future use (checkpoint I/O). Currently unused.
+
+    Attributes
+    ----------
+    f : object
+        Foscat helper exposing the backend and grade/convolution utils.
+    KERNELSZ : int
+        Effective kernel size used by all convolutions.
+    chanlist : list[int]
+        Channels per encoder level.
+    wconv, t_wconv : Dict[int, int]
+        Offsets into the flat parameter vector `self.x` for encoder/decoder
+        convolutions respectively.
+    hconv, t_hconv : Dict[int, hs.SphericalStencil]
+        Per-level oriented convolution operators for encoder/decoder.
+    l_cell_ids : Dict[int, np.ndarray]
+        Per-level cell ids for downsampled grids (encoder side).
+    m_cell_ids : Dict[int, np.ndarray]
+        Per-level cell ids for upsampled grids (decoder side). Mirrors levels of
+        ``l_cell_ids`` but indexed from the decoder traversal.
+    x : backend tensor (1-D)
+        Flat vector holding *all* convolution weights.
+    nside : int
+        Input nside (finest resolution).
+    n_chan_in : int
+        Number of channels at input.
+
+    Notes
+    -----
+    - The constructor prints informative messages about the architecture layout
+      (channels and pixel counts) to ease debugging.
+    - The implementation keeps the logic identical to the original code; only
+      comments, docstrings and variable explanations are added.
+    """
 
     def __init__(
-        filename=None,
+        self,
+        nparam: int = 1,
+        KERNELSZ: int = 3,
+        NORIENT: int = 4,
+        chanlist: Optional[list] = None,
+        in_nside: int = 1,
+        n_chan_in: int = 1,
+        cell_ids: Optional[np.ndarray] = None,
+        SEED: int = 1234,
+        filename: Optional[str] = None,
     ):
+        # Foscat function wrapper providing backend and grade ops
+        self.f = sc.funct(KERNELSZ=KERNELSZ)
+
+        # If no channel plan is provided, build a default pyramid depth of
+        # log2(in_nside) levels with channels growing as 4 * 2**k
         if chanlist is None:
-            nlayer=int(np.log2(in_nside))
-            chanlist=[4*2**k for k in range(nlayer)]
+            nlayer = int(np.log2(in_nside))
+            chanlist = [4 * 2 ** k for k in range(nlayer)]
         else:
-            nlayer=len(chanlist)
-        print(
+            nlayer = len(chanlist)
+        print("N_layer ", nlayer)
+
+        # Internal registries
+        n = 0  # running offset in the flat parameter vector
+        wconv: Dict[int, int] = {}  # encoder weight offsets
+        hconv: Dict[int, hs.SphericalStencil] = {}  # encoder conv operators
+        l_cell_ids: Dict[int, np.ndarray] = {}  # encoder level cell ids
+        self.KERNELSZ = KERNELSZ
+        kernelsz = self.KERNELSZ
+
+        # -----------------------------
+        # Encoder (downsampling) build
+        # -----------------------------
+        l_nside = in_nside
+        # NOTE: the original code assumes cell_ids is provided; we keep that
+        # contract and copy to avoid side effects.
+        l_cell_ids[0] = cell_ids.copy()
+        # Create a dummy data tensor to probe shapes; real data arrives in eval()
+        l_data = self.f.backend.bk_cast(np.ones([1, 1, l_cell_ids[0].shape[0]]))
+        l_chan = n_chan_in
+        print("Initial chan %d Npix=%d" % (l_chan, l_data.shape[2]))
+
         for l in range(nlayer):
-            print(
-            n
-            l_cell_ids[l
+            print("Layer %d Npix=%d" % (l, l_data.shape[2]))
+
+            # Record offset for first conv at this level: (in -> chanlist[l])
+            wconv[2 * l] = n
+            nw = l_chan * chanlist[l] * kernelsz * kernelsz
+            print("Layer %d conv [%d,%d]" % (l, l_chan, chanlist[l]))
+            n += nw
+
+            # Record offset for second conv at this level: (chanlist[l] -> chanlist[l])
+            wconv[2 * l + 1] = n
+            nw = chanlist[l] * chanlist[l] * kernelsz * kernelsz
+            print("Layer %d conv [%d,%d]" % (l, chanlist[l], chanlist[l]))
+            n += nw
+
+            # Build oriented convolution operator for this level
+            hconvol = hs.SphericalStencil(l_nside, 3, cell_ids=l_cell_ids[l])
+            hconvol.make_idx_weights()  # precompute indices/weights once
+            hconv[l] = hconvol
+
+            # Downsample features and propagate cell ids to the next level
+            l_data, n_cell_ids = self.f.ud_grade_2(
+                l_data, cell_ids=l_cell_ids[l], nside=l_nside
+            )
+            l_cell_ids[l + 1] = self.f.backend.to_numpy(n_cell_ids)
+            l_nside //= 2
+
+            # +1 channel to concatenate the downgraded input (skip-like feature)
+            l_chan = chanlist[l] + 1
+
+        # Freeze encoder bookkeeping
+        self.n_cnn = n
+        self.l_cell_ids = l_cell_ids
+        self.wconv = wconv
+        self.hconv = hconv
+
+        # -----------------------------
+        # Decoder (upsampling) build
+        # -----------------------------
+        m_cell_ids: Dict[int, np.ndarray] = {}
+        m_cell_ids[0] = l_cell_ids[nlayer]
+        t_wconv: Dict[int, int] = {}  # decoder weight offsets
+        t_hconv: Dict[int, hs.SphericalStencil] = {}  # decoder conv operators
+
         for l in range(nlayer):
-            #
-            l_chan+=
-            l_data=self.f.up_grade(
+            # Upsample features to the previous (finer) resolution
+            l_chan += 1  # account for concatenation before first conv at this level
+            l_data = self.f.up_grade(
+                l_data,
+                l_nside * 2,
+                cell_ids=l_cell_ids[nlayer - l],
+                o_cell_ids=l_cell_ids[nlayer - 1 - l],
+                nside=l_nside,
+            )
+            print("Transpose Layer %d Npix=%d" % (l, l_data.shape[2]))
+
+            m_cell_ids[l] = l_cell_ids[nlayer - 1 - l]
+            l_nside *= 2
+
+            # First decoder conv: (l_chan -> l_chan)
+            t_wconv[2 * l] = n
+            nw = l_chan * l_chan * kernelsz * kernelsz
+            print("Transpose Layer %d conv [%d,%d]" % (l, l_chan, l_chan))
+            n += nw
+
+            # Second decoder conv: (l_chan -> out_chan)
+            t_wconv[2 * l + 1] = n
+            out_chan = 1
+            if nlayer - 1 - l > 0:
+                out_chan += chanlist[nlayer - 1 - l]
+            print("Transpose Layer %d conv [%d,%d]" % (l, l_chan, out_chan))
+            nw = l_chan * out_chan * kernelsz * kernelsz
+            n += nw
+
+            # Build oriented convolution operator for this decoder level
+            hconvol = hs.SphericalStencil(l_nside, 3, cell_ids=m_cell_ids[l])
             hconvol.make_idx_weights()
-            t_hconv[l]=hconvol
-            #
-            l_chan=out_chan
-        self.
-        self.
-        self.
-        self.
+            t_hconv[l] = hconvol
+
+            # Update channel count after producing out_chan
+            l_chan = out_chan
+
+        print("Final chan %d Npix=%d" % (out_chan, l_data.shape[2]))
+
+        # Freeze decoder bookkeeping
+        self.n_cnn = n
+        self.m_cell_ids = l_cell_ids  # mirror of encoder ids (kept for backward compat)
+        self.t_wconv = t_wconv
+        self.t_hconv = t_hconv
+
+        # Initialize flat parameter vector with small random values
+        self.x = self.f.backend.bk_cast((np.random.rand(n) - 0.5) / self.KERNELSZ)
+
+        # Expose config
+        self.nside = in_nside
+        self.n_chan_in = n_chan_in
+        self.chanlist = chanlist
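
The default channel plan documented above is easy to check by hand (a small
sketch, not part of the module):

>>> import numpy as np
>>> in_nside = 8
>>> nlayer = int(np.log2(in_nside))
>>> [4 * 2 ** k for k in range(nlayer)]
[4, 8, 16]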
 
     def get_param(self):
+        """Return the flat parameter vector that stores all convolution kernels.
+
+        Returns
+        -------
+        backend tensor (1-D)
+            The Foscat backend representation (e.g., NumPy/Torch/TF tensor)
+            holding all convolution weights in a single vector.
+        """
         return self.x
 
-    def set_param(self,x):
+    def set_param(self, x):
+        """Overwrite the flat parameter vector with externally provided values.
+
+        This is useful when optimizing parameters with an external optimizer or
+        when restoring weights from a checkpoint (after proper conversion to the
+        Foscat backend type).
+
+        Parameters
+        ----------
+        x : array-like (1-D)
+            New values for the flat parameter vector. Must match `self.x` size.
+        """
+        self.x = self.f.backend.bk_cast(x)
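
An external-optimizer loop only needs `get_param`/`set_param`. A minimal
sketch, assuming the `net` instance from the module docstring example and the
backend's `to_numpy` helper used by the tests below:

>>> p = net.get_param()                      # flat 1-D backend tensor
>>> p_np = net.f.backend.to_numpy(p).copy()  # edit on the NumPy side
>>> p_np *= 0.9                              # stand-in for an optimizer step
>>> net.set_param(p_np)                      # cast back into the backend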
+
+    def eval(self, data):
+        """Run a forward pass through the encoder/decoder.
+
+        Parameters
+        ----------
+        data : backend tensor, shape (B, C, Npix)
+            Input signal at resolution `self.nside` (finest grid). `C` must
+            equal `self.n_chan_in`.
+
+        Returns
+        -------
+        backend tensor, shape (B, C_out, Npix)
+            Network output at the input resolution. `C_out` is `1` at the top
+            level, or `1 + chanlist[level]` for intermediate decoder levels.
+
+        Notes
+        -----
+        The forward comprises two stages:
+        (1) **Encoder**: for each level `l`, apply two oriented convolutions
+            ("conv -> conv"), downsample to the next coarser grid, and
+            concatenate with a downgraded copy of the running input (`m_data`).
+        (2) **Decoder**: for each level, upsample to the finer grid, concatenate
+            with the stored encoder feature (skip connection), then apply two
+            oriented convolutions ("conv -> conv") to produce `out_chan`.
+        """
+        # Encoder state
+        l_nside = self.nside
+        l_chan = self.n_chan_in
+        l_data = data
+        m_data = data  # running copy of input used for the additional concat
+        nlayer = len(self.chanlist)
+        kernelsz = self.KERNELSZ
+        ud_data: Dict[int, object] = {}  # stores per-level skip features
+
+        # -----------------
+        # Encoder traversal
+        # -----------------
         for l in range(nlayer):
-            #
-            nw=l_chan*self.chanlist[l]*kernelsz*kernelsz
-            ww=self.x[self.wconv[2*l]:self.wconv[2*l]+nw]
-            ww=self.f.backend.bk_reshape(
-            l_data = self.hconv[l].Convol_torch(l_data,ww)
-            ww=self.
-            l_data = self.hconv[l].Convol_torch(l_data,ww)
-            l_data,_=self.f.ud_grade_2(l_data,
-                                       cell_ids=self.l_cell_ids[l],
-                                       nside=l_nside)
-            ud_data[l]=m_data
-            m_data,_=self.f.ud_grade_2(m_data,
-                                       cell_ids=self.l_cell_ids[l],
-                                       nside=l_nside)
-            l_data = self.f.backend.bk_concat([m_data,l_data],1)
-            l_nside//=2
-            # plus one to add the input downgrade data
-            l_chan=self.chanlist[l]+self.n_chan_in
+            # Fetch weights for conv (in -> chanlist[l]) and reshape to Foscat backend
+            nw = l_chan * self.chanlist[l] * kernelsz * kernelsz
+            ww = self.x[self.wconv[2 * l] : self.wconv[2 * l] + nw]
+            ww = self.f.backend.bk_reshape(
+                ww, [l_chan, self.chanlist[l], kernelsz * kernelsz]
+            )
+            l_data = self.hconv[l].Convol_torch(l_data, ww)
+
+            # Second conv (chanlist[l] -> chanlist[l])
+            nw = self.chanlist[l] * self.chanlist[l] * kernelsz * kernelsz
+            ww = self.x[self.wconv[2 * l + 1] : self.wconv[2 * l + 1] + nw]
+            ww = self.f.backend.bk_reshape(
+                ww, [self.chanlist[l], self.chanlist[l], kernelsz * kernelsz]
+            )
+            l_data = self.hconv[l].Convol_torch(l_data, ww)
 
+            # Downsample features and store skip connection
+            l_data, _ = self.f.ud_grade_2(
+                l_data, cell_ids=self.l_cell_ids[l], nside=l_nside
+            )
+            ud_data[l] = m_data
+
+            # Also downgrade the running input for the auxiliary concat
+            m_data, _ = self.f.ud_grade_2(
+                m_data, cell_ids=self.l_cell_ids[l], nside=l_nside
+            )
+
+            # Concatenate along channels: [ downgraded_input , features ]
+            l_data = self.f.backend.bk_concat([m_data, l_data], 1)
+
+            l_nside //= 2
+            l_chan = self.chanlist[l] + 1  # account for the concat above
+
+        # -----------------
+        # Decoder traversal
+        # -----------------
         for l in range(nlayer):
-            #
+            # Upsample to finer grid
+            l_chan += 1  # due to upcoming concat with ud_data
+            l_data = self.f.up_grade(
+                l_data,
+                l_nside * 2,
+                cell_ids=self.l_cell_ids[nlayer - l],
+                o_cell_ids=self.l_cell_ids[nlayer - 1 - l],
+                nside=l_nside,
+            )
+
+            # Concatenate with encoder skip features
+            l_data = self.f.backend.bk_concat([ud_data[nlayer - 1 - l], l_data], 1)
+            l_nside *= 2
+
+            # Determine output channels at this level
+            out_chan = 1
+            if nlayer - 1 - l > 0:
+                out_chan += self.chanlist[nlayer - 1 - l]
+
+            # First decoder conv (l_chan -> l_chan)
+            nw = l_chan * l_chan * kernelsz * kernelsz
+            ww = self.x[self.t_wconv[2 * l] : self.t_wconv[2 * l] + nw]
+            ww = self.f.backend.bk_reshape(ww, [l_chan, l_chan, kernelsz * kernelsz])
+            c_data = self.t_hconv[l].Convol_torch(l_data, ww)
+
+            # Second decoder conv (l_chan -> out_chan)
+            nw = l_chan * out_chan * kernelsz * kernelsz
+            ww = self.x[self.t_wconv[2 * l + 1] : self.t_wconv[2 * l + 1] + nw]
+            ww = self.f.backend.bk_reshape(ww, [l_chan, out_chan, kernelsz * kernelsz])
+            l_data = self.t_hconv[l].Convol_torch(c_data, ww)
+
+            # Update channel count for next iteration
+            l_chan = out_chan
+
         return l_data
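
For the two-level test configuration used by the smoke tests below
(in_nside=4, chanlist=[4, 8], one input channel), the forward pass traced by
hand gives (B, 1, 192) -> (B, 5, 48) -> (B, 9, 12) on the way down and
(B, 10, 48) -> (B, 9, 48) -> (B, 10, 192) -> (B, 1, 192) back up: each
`ud_grade_2` divides npix by 4, and every concat adds one channel.

>>> npix = 12 * 4 * 4
>>> [npix // 4 ** k for k in range(3)]  # npix per level for nside 4, 2, 1
[192, 48, 12]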
+
+    def to_tensor(self, x):
+        # NOTE: self.f is always set by __init__, so this lazy branch is
+        # normally dead code; it also references `torch` and `self.dtype`,
+        # which this module does not define (assumed to come from the caller),
+        # hence the local import below.
+        if self.f is None:
+            import torch
+
+            if self.dtype == torch.float64:
+                self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type="float64")
+            else:
+                self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type="float32")
+        return self.f.backend.bk_cast(x)
+
+    def to_numpy(self, x):
+        if isinstance(x, np.ndarray):
+            return x
+        return x.cpu().numpy()
+
+
+# -----------------------------
+# Unit tests (smoke tests)
+# -----------------------------
+# Run with: python UNET.py (or) python UNET.py -q for quieter output
+# These tests assume Foscat and its dependencies are installed.
+
+
+def _dummy_cell_ids(nside: int) -> np.ndarray:
+    """Return a simple identity mapping for HEALPix nested pixel IDs.
+
+    Notes
+    -----
+    Replace with your pipeline's real `cell_ids` if you have a precomputed
+    mapping consistent with Foscat/HEALPix nested ordering.
+    """
+    return np.arange(12 * nside * nside, dtype=np.int64)
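
A quick doctest-style check of the helper's contract (identity mapping over
the full nested sphere):

>>> ids = _dummy_cell_ids(2)
>>> ids.shape
(48,)
>>> int(ids[0]), int(ids[-1])
(0, 47)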
+
+
+if __name__ == "__main__":
+    import unittest
+
+    class TestUNET(unittest.TestCase):
+        """Lightweight smoke tests for shape and parameter plumbing."""
+
+        def setUp(self):
+            self.nside = 4  # small grid for fast tests (npix = 192)
+            self.chanlist = [4, 8]  # two-level encoder/decoder
+            self.batch = 2
+            self.channels = 1
+            self.npix = 12 * self.nside * self.nside
+            self.cell_ids = _dummy_cell_ids(self.nside)
+            self.net = UNET(
+                in_nside=self.nside,
+                n_chan_in=self.channels,
+                chanlist=self.chanlist,
+                cell_ids=self.cell_ids,
+            )
+
+        def test_forward_shape(self):
+            # random input
+            x = np.random.randn(self.batch, self.channels, self.npix).astype(np.float32)
+            x = self.net.f.backend.bk_cast(x)
+            y = self.net.eval(x)
+            # expected output: same npix, 1 channel at the very top
+            self.assertEqual(y.shape[0], self.batch)
+            self.assertEqual(y.shape[1], 1)
+            self.assertEqual(y.shape[2], self.npix)
+            # sanity: no NaNs
+            y_np = self.net.f.backend.to_numpy(y)
+            self.assertFalse(np.isnan(y_np).any())
+
+        def test_param_roundtrip_and_determinism(self):
+            x = np.random.randn(self.batch, self.channels, self.npix).astype(np.float32)
+            x = self.net.f.backend.bk_cast(x)
+
+            # forward twice -> identical outputs with fixed params
+            y1 = self.net.eval(x)
+            y2 = self.net.eval(x)
+            y1_np = self.net.f.backend.to_numpy(y1)
+            y2_np = self.net.f.backend.to_numpy(y2)
+            np.testing.assert_allclose(y1_np, y2_np, rtol=0, atol=0)
+
+            # perturb parameters -> output should (very likely) change
+            p = self.net.get_param()
+            p_np = self.net.f.backend.to_numpy(p).copy()
+            if p_np.size > 0:
+                p_np[0] += 1.0
+                self.net.set_param(p_np)
+                y3 = self.net.eval(x)
+                y3_np = self.net.f.backend.to_numpy(y3)
+                with self.assertRaises(AssertionError):
+                    np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)
+
+    unittest.main()
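
For the same test configuration, the size of the flat parameter vector can be
re-derived by hand from the conv plan printed by `__init__` (KERNELSZ=3, so 9
spatial weights per in/out channel pair); the sum should match the length of
`net.get_param()`:

>>> k2 = 3 * 3
>>> enc = 1 * 4 * k2 + 4 * 4 * k2 + 5 * 8 * k2 + 8 * 8 * k2
>>> dec = 10 * 10 * k2 + 10 * 9 * k2 + 10 * 10 * k2 + 10 * 1 * k2
>>> enc + dec
3816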