foscat 2025.8.4__py3-none-any.whl → 2025.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/UNET.py CHANGED
@@ -1,200 +1,477 @@
1
+ """
2
+ UNET for HEALPix (nested) using Foscat oriented convolutions.
3
+
4
+ This module defines a lightweight, U-Net–like encoder/decoder that operates on
5
+ signals defined on the HEALPix sphere (nested scheme). It leverages Foscat's
6
+ `HOrientedConvol` for orientation-aware convolutions and `funct` utilities for
7
+ upgrade/downgrade (change of nside) operations.
8
+
9
+ Key design choices
10
+ ------------------
11
+ • **Flat parameter vector**: all convolution kernels are stored in a single
12
+ 1‑D vector `self.x`. The dictionaries `self.wconv` and `self.t_wconv` map
13
+ layer indices to *offsets* within that vector.
14
+ • **HEALPix-aware down/up-sampling**: down-sampling uses
15
+ `self.f.ud_grade_2`, and up-sampling uses `self.f.up_grade`, both with
16
+ per-level `cell_ids` to preserve locality and orientation information.
17
+ • **Skip connections**: U‑Net skip connections are implemented by concatenating
18
+ encoder features with downgraded/upsampled paths along the channel axis.
19
+
20
+ Shape convention
21
+ ----------------
22
+ All tensors follow the Foscat backend shape `(batch, channels, npix)`.
23
+
24
+ Dependencies
25
+ ------------
26
+ - foscat.scat_cov as `sc`
27
+ - foscat.HOrientedConvol as `hs`
28
+
29
+ Example
30
+ -------
31
+ >>> import numpy as np
32
+ >>> from UNET import UNET
33
+ >>> nside = 8
34
+ >>> npix = 12 * nside * nside
35
+ >>> # Your backend tensor should be created via foscat backend; here we show a placeholder np.array
36
+ >>> x = np.random.randn(1, 1, npix).astype(np.float32)
37
+ >>> # cell_ids should be provided for the highest resolution (nside)
38
+ >>> # and must be consistent with the nested scheme expected by Foscat.
39
+ >>> # Example placeholder (use the real one from your pipeline):
40
+ >>> cell_ids = np.arange(npix, dtype=np.int64)
41
+ >>> net = UNET(in_nside=nside, n_chan_in=1, cell_ids=cell_ids)
42
+ >>> y = net.eval(net.f.backend.bk_cast(x)) # forward pass
43
+
44
+ Notes
45
+ -----
46
+ - This implementation assumes `cell_ids` is provided for the input resolution
47
+ `in_nside`. It propagates/derives the coarser `cell_ids` across levels.
48
+ - Some constructor parameters are reserved for future use (see docstring).
49
+ """
50
+
51
+ from typing import Dict, Optional
1
52
  import numpy as np
2
53
 
3
54
  import foscat.scat_cov as sc
4
55
  import foscat.HOrientedConvol as hs
5
56
 
57
+
6
58
  class UNET:
59
+ """U‑Net–like network on HEALPix (nested) using Foscat oriented convolutions.
60
+
61
+ The network is built as an encoder/decoder (down/upsampling) tower. Each
62
+ level performs two oriented convolutions. All kernels are packed in a flat
63
+ parameter vector `self.x` to simplify optimization with external solvers.
64
+
65
+ Parameters
66
+ ----------
67
+ nparam : int, optional
68
+ Reserved for future use. Currently unused.
69
+ KERNELSZ : int, optional
70
+ Spatial kernel size (k × k) used by oriented convolutions. Default is 3.
71
+ NORIENT : int, optional
72
+ Reserved for future use (number of orientations). Currently unused.
73
+ chanlist : Optional[list[int]], optional
74
+ Number of output channels per encoder level. If ``None``, it defaults to
75
+ ``[4 * 2**k for k in range(log2(in_nside))]``. The length of this list
76
+ defines the number of encoder/decoder levels.
77
+ in_nside : int, optional
78
+ Input HEALPix nside. Must be a power of two for the implicit
79
+ ``log2(in_nside)`` depth when ``chanlist`` is not given.
80
+ n_chan_in : int, optional
81
+ Number of input channels at the finest resolution. Default is 1.
82
+ cell_ids : array-like of int, required
83
+ Pixel identifiers at the input resolution (nested indexing). They are
84
+ used to build oriented convolutions and to derive coarser grids.
85
+ **Must not be ``None``.**
86
+ SEED : int, optional
87
+ Reserved for future use (random initialization seed). Currently unused.
88
+ filename : Optional[str], optional
89
+ Reserved for future use (checkpoint I/O). Currently unused.
90
+
91
+ Attributes
92
+ ----------
93
+ f : object
94
+ Foscat helper exposing the backend and grade/convolution utils.
95
+ KERNELSZ : int
96
+ Effective kernel size used by all convolutions.
97
+ chanlist : list[int]
98
+ Channels per encoder level.
99
+ wconv, t_wconv : Dict[int, int]
100
+ Offsets into the flat parameter vector `self.x` for encoder/decoder
101
+ convolutions respectively.
102
+ hconv, t_hconv : Dict[int, hs.HOrientedConvol]
103
+ Per-level oriented convolution operators for encoder/decoder.
104
+ l_cell_ids : Dict[int, np.ndarray]
105
+ Per-level cell ids for downsampled grids (encoder side).
106
+ m_cell_ids : Dict[int, np.ndarray]
107
+ Per-level cell ids for upsampled grids (decoder side). Mirrors levels of
108
+ ``l_cell_ids`` but indexed from the decoder traversal.
109
+ x : backend tensor (1‑D)
110
+ Flat vector holding *all* convolution weights.
111
+ nside : int
112
+ Input nside (finest resolution).
113
+ n_chan_in : int
114
+ Number of channels at input.
115
+
116
+ Notes
117
+ -----
118
+ - The constructor prints informative messages about the architecture layout
119
+ (channels and pixel counts) to ease debugging.
120
+ - The implementation keeps the logic identical to the original code; only
121
+ comments, docstrings and variable explanations are added.
122
+ """
7
123
 
8
124
  def __init__(
9
- self,
10
- nparam=1,
11
- KERNELSZ=3,
12
- NORIENT=4,
13
- chanlist=None,
14
- in_nside=1,
15
- n_chan_in=1,
16
- n_chan_out=1,
17
- cell_ids=None,
18
- SEED=1234,
19
- filename=None,
125
+ self,
126
+ nparam: int = 1,
127
+ KERNELSZ: int = 3,
128
+ NORIENT: int = 4,
129
+ chanlist: Optional[list] = None,
130
+ in_nside: int = 1,
131
+ n_chan_in: int = 1,
132
+ cell_ids: Optional[np.ndarray] = None,
133
+ SEED: int = 1234,
134
+ filename: Optional[str] = None,
20
135
  ):
21
- self.f=sc.funct(KERNELSZ=KERNELSZ)
22
-
136
+ # Foscat function wrapper providing backend and grade ops
137
+ self.f = sc.funct(KERNELSZ=KERNELSZ)
138
+
139
+ # If no channel plan is provided, build a default pyramid depth of
140
+ # log2(in_nside) levels with channels growing as 4 * 2**k
23
141
  if chanlist is None:
24
- nlayer=int(np.log2(in_nside))
25
- chanlist=[4*2**k for k in range(nlayer)]
142
+ nlayer = int(np.log2(in_nside))
143
+ chanlist = [4 * 2 ** k for k in range(nlayer)]
26
144
  else:
27
- nlayer=len(chanlist)
28
- print('N_layer ',nlayer)
29
-
30
- n=0
31
- wconv={}
32
- hconv={}
33
- l_cell_ids={}
34
- self.KERNELSZ=KERNELSZ
35
- kernelsz=self.KERNELSZ
36
-
37
- # create the CNN part
38
- l_nside=in_nside
39
- l_cell_ids[0]=cell_ids.copy()
40
- l_data=self.f.backend.bk_cast(np.ones([1,1,l_cell_ids[0].shape[0]]))
41
- l_chan=n_chan_in
42
- print('Initial chan %d Npix=%d'%(l_chan,l_data.shape[2]))
145
+ nlayer = len(chanlist)
146
+ print("N_layer ", nlayer)
147
+
148
+ # Internal registries
149
+ n = 0 # running offset in the flat parameter vector
150
+ wconv: Dict[int, int] = {} # encoder weight offsets
151
+ hconv: Dict[int, hs.HOrientedConvol] = {} # encoder conv operators
152
+ l_cell_ids: Dict[int, np.ndarray] = {} # encoder level cell ids
153
+ self.KERNELSZ = KERNELSZ
154
+ kernelsz = self.KERNELSZ
155
+
156
+ # -----------------------------
157
+ # Encoder (downsampling) build
158
+ # -----------------------------
159
+ l_nside = in_nside
160
+ # NOTE: the original code assumes cell_ids is provided; we keep that
161
+ # contract and copy to avoid side effects.
162
+ l_cell_ids[0] = cell_ids.copy()
163
+ # Create a dummy data tensor to probe shapes; real data arrives in eval()
164
+ l_data = self.f.backend.bk_cast(np.ones([1, 1, l_cell_ids[0].shape[0]]))
165
+ l_chan = n_chan_in
166
+ print("Initial chan %d Npix=%d" % (l_chan, l_data.shape[2]))
167
+
43
168
  for l in range(nlayer):
44
- print('Layer %d Npix=%d'%(l,l_data.shape[2]))
45
- # init double convol weights
46
- wconv[2*l]=n
47
- nw=l_chan*chanlist[l]*kernelsz*kernelsz
48
- print('Layer %d conv [%d,%d]'%(l,l_chan,chanlist[l]))
49
- n+=nw
50
- wconv[2*l+1]=n
51
- nw=chanlist[l]*chanlist[l]*kernelsz*kernelsz
52
- print('Layer %d conv [%d,%d]'%(l,chanlist[l],chanlist[l]))
53
- n+=nw
54
-
55
- hconvol=hs.HOrientedConvol(l_nside,3,cell_ids=l_cell_ids[l])
56
- hconvol.make_idx_weights()
57
- hconv[l]=hconvol
58
- l_data,n_cell_ids=self.f.ud_grade_2(l_data,cell_ids=l_cell_ids[l],nside=l_nside)
59
- l_cell_ids[l+1]=self.f.backend.to_numpy(n_cell_ids)
60
- l_nside//=2
61
- # plus one to add the input downgrade data
62
- l_chan=chanlist[l]+n_chan_in
63
-
64
- self.n_cnn=n
65
- self.l_cell_ids=l_cell_ids
66
- self.wconv=wconv
67
- self.hconv=hconv
68
-
69
- # create the transpose CNN part
70
- m_cell_ids={}
71
- m_cell_ids[0]=l_cell_ids[nlayer]
72
- t_wconv={}
73
- t_hconv={}
74
-
169
+ print("Layer %d Npix=%d" % (l, l_data.shape[2]))
170
+
171
+ # Record offset for first conv at this level: (in -> chanlist[l])
172
+ wconv[2 * l] = n
173
+ nw = l_chan * chanlist[l] * kernelsz * kernelsz
174
+ print("Layer %d conv [%d,%d]" % (l, l_chan, chanlist[l]))
175
+ n += nw
176
+
177
+ # Record offset for second conv at this level: (chanlist[l] -> chanlist[l])
178
+ wconv[2 * l + 1] = n
179
+ nw = chanlist[l] * chanlist[l] * kernelsz * kernelsz
180
+ print("Layer %d conv [%d,%d]" % (l, chanlist[l], chanlist[l]))
181
+ n += nw
182
+
183
+ # Build oriented convolution operator for this level
184
+ hconvol = hs.HOrientedConvol(l_nside, 3, cell_ids=l_cell_ids[l])
185
+ hconvol.make_idx_weights() # precompute indices/weights once
186
+ hconv[l] = hconvol
187
+
188
+ # Downsample features and propagate cell ids to the next level
189
+ l_data, n_cell_ids = self.f.ud_grade_2(
190
+ l_data, cell_ids=l_cell_ids[l], nside=l_nside
191
+ )
192
+ l_cell_ids[l + 1] = self.f.backend.to_numpy(n_cell_ids)
193
+ l_nside //= 2
194
+
195
+ # +1 channel to concatenate the downgraded input (skip-like feature)
196
+ l_chan = chanlist[l] + 1
197
+
198
+ # Freeze encoder bookkeeping
199
+ self.n_cnn = n
200
+ self.l_cell_ids = l_cell_ids
201
+ self.wconv = wconv
202
+ self.hconv = hconv
203
+
204
+ # -----------------------------
205
+ # Decoder (upsampling) build
206
+ # -----------------------------
207
+ m_cell_ids: Dict[int, np.ndarray] = {}
208
+ m_cell_ids[0] = l_cell_ids[nlayer]
209
+ t_wconv: Dict[int, int] = {} # decoder weight offsets
210
+ t_hconv: Dict[int, hs.HOrientedConvol] = {} # decoder conv operators
211
+
75
212
  for l in range(nlayer):
76
- #upgrade data
77
- l_chan+=n_chan_in
78
- l_data=self.f.up_grade(l_data,l_nside*2,
79
- cell_ids=l_cell_ids[nlayer-l],
80
- o_cell_ids=l_cell_ids[nlayer-1-l],
81
- nside=l_nside)
82
- print('Transpose Layer %d Npix=%d'%(l,l_data.shape[2]))
83
-
84
-
85
- m_cell_ids[l]=l_cell_ids[nlayer-1-l]
86
- l_nside*=2
87
-
88
- # init double convol weights
89
- t_wconv[2*l]=n
90
- nw=l_chan*l_chan*kernelsz*kernelsz
91
- print('Transpose Layer %d conv [%d,%d]'%(l,l_chan,l_chan))
92
- n+=nw
93
- t_wconv[2*l+1]=n
94
- out_chan=n_chan_out
95
- if nlayer-1-l>0:
96
- out_chan+=chanlist[nlayer-1-l]
97
- print('Transpose Layer %d conv [%d,%d]'%(l,l_chan,out_chan))
98
- nw=l_chan*out_chan*kernelsz*kernelsz
99
- n+=nw
100
-
101
- hconvol=hs.HOrientedConvol(l_nside,3,cell_ids=m_cell_ids[l])
213
+ # Upsample features to the previous (finer) resolution
214
+ l_chan += 1 # account for concatenation before first conv at this level
215
+ l_data = self.f.up_grade(
216
+ l_data,
217
+ l_nside * 2,
218
+ cell_ids=l_cell_ids[nlayer - l],
219
+ o_cell_ids=l_cell_ids[nlayer - 1 - l],
220
+ nside=l_nside,
221
+ )
222
+ print("Transpose Layer %d Npix=%d" % (l, l_data.shape[2]))
223
+
224
+ m_cell_ids[l] = l_cell_ids[nlayer - 1 - l]
225
+ l_nside *= 2
226
+
227
+ # First decoder conv: (l_chan -> l_chan)
228
+ t_wconv[2 * l] = n
229
+ nw = l_chan * l_chan * kernelsz * kernelsz
230
+ print("Transpose Layer %d conv [%d,%d]" % (l, l_chan, l_chan))
231
+ n += nw
232
+
233
+ # Second decoder conv: (l_chan -> out_chan)
234
+ t_wconv[2 * l + 1] = n
235
+ out_chan = 1
236
+ if nlayer - 1 - l > 0:
237
+ out_chan += chanlist[nlayer - 1 - l]
238
+ print("Transpose Layer %d conv [%d,%d]" % (l, l_chan, out_chan))
239
+ nw = l_chan * out_chan * kernelsz * kernelsz
240
+ n += nw
241
+
242
+ # Build oriented convolution operator for this decoder level
243
+ hconvol = hs.HOrientedConvol(l_nside, 3, cell_ids=m_cell_ids[l])
102
244
  hconvol.make_idx_weights()
103
- t_hconv[l]=hconvol
104
-
105
- # plus one to add the input downgrade data
106
- l_chan=out_chan
107
- print('Final chan %d Npix=%d'%(out_chan,l_data.shape[2]))
108
- self.n_cnn=n
109
- self.m_cell_ids=l_cell_ids
110
- self.t_wconv=t_wconv
111
- self.t_hconv=t_hconv
112
- self.x=self.f.backend.bk_cast((np.random.rand(n)-0.5)/self.KERNELSZ)
113
- self.nside=in_nside
114
- self.n_chan_in=n_chan_in
115
- self.n_chan_out=n_chan_out
116
- self.chanlist=chanlist
245
+ t_hconv[l] = hconvol
246
+
247
+ # Update channel count after producing out_chan
248
+ l_chan = out_chan
249
+
250
+ print("Final chan %d Npix=%d" % (out_chan, l_data.shape[2]))
251
+
252
+ # Freeze decoder bookkeeping
253
+ self.n_cnn = n
254
+ self.m_cell_ids = l_cell_ids # mirror of encoder ids (kept for backward compat)
255
+ self.t_wconv = t_wconv
256
+ self.t_hconv = t_hconv
257
+
258
+ # Initialize flat parameter vector with small random values
259
+ self.x = self.f.backend.bk_cast((np.random.rand(n) - 0.5) / self.KERNELSZ)
260
+
261
+ # Expose config
262
+ self.nside = in_nside
263
+ self.n_chan_in = n_chan_in
264
+ self.chanlist = chanlist
117
265
 
118
266
  def get_param(self):
267
+ """Return the flat parameter vector that stores all convolution kernels.
268
+
269
+ Returns
270
+ -------
271
+ backend tensor (1‑D)
272
+ The Foscat backend representation (e.g., NumPy/Torch/TF tensor)
273
+ holding all convolution weights in a single vector.
274
+ """
119
275
  return self.x
120
276
 
121
- def set_param(self,x):
122
- self.x=self.f.backend.bk_cast(x)
123
-
124
- def eval(self,data):
125
- # create the CNN part
126
- l_nside=self.nside
127
- l_chan=self.n_chan_in
128
- l_data=data
129
- m_data=data
130
- nlayer=len(self.chanlist)
131
- kernelsz=self.KERNELSZ
132
- ud_data={}
133
-
277
+ def set_param(self, x):
278
+ """Overwrite the flat parameter vector with externally provided values.
279
+
280
+ This is useful when optimizing parameters with an external optimizer or
281
+ when restoring weights from a checkpoint (after proper conversion to the
282
+ Foscat backend type).
283
+
284
+ Parameters
285
+ ----------
286
+ x : array-like (1‑D)
287
+ New values for the flat parameter vector. Must match `self.x` size.
288
+ """
289
+ self.x = self.f.backend.bk_cast(x)
290
+
291
+ def eval(self, data):
292
+ """Run a forward pass through the encoder/decoder.
293
+
294
+ Parameters
295
+ ----------
296
+ data : backend tensor, shape (B, C, Npix)
297
+ Input signal at resolution `self.nside` (finest grid). `C` must
298
+ equal `self.n_chan_in`.
299
+
300
+ Returns
301
+ -------
302
+ backend tensor, shape (B, C_out, Npix)
303
+ Network output at the input resolution. `C_out` is `1` at the top
304
+ level, or `1 + chanlist[level]` for intermediate decoder levels.
305
+
306
+ Notes
307
+ -----
308
+ The forward comprises two stages:
309
+ (1) **Encoder**: for each level `l`, apply two oriented convolutions
310
+ ("conv -> conv"), downsample to the next coarser grid, and
311
+ concatenate with a downgraded copy of the running input (`m_data`).
312
+ (2) **Decoder**: for each level, upsample to the finer grid, concatenate
313
+ with the stored encoder feature (skip connection), then apply two
314
+ oriented convolutions ("conv -> conv") to produce `out_chan`.
315
+ """
316
+ # Encoder state
317
+ l_nside = self.nside
318
+ l_chan = self.n_chan_in
319
+ l_data = data
320
+ m_data = data # running copy of input used for the additional concat
321
+ nlayer = len(self.chanlist)
322
+ kernelsz = self.KERNELSZ
323
+ ud_data: Dict[int, object] = {} # stores per-level skip features
324
+
325
+ # -----------------
326
+ # Encoder traversal
327
+ # -----------------
134
328
  for l in range(nlayer):
135
- # init double convol weights
136
- nw=l_chan*self.chanlist[l]*kernelsz*kernelsz
137
- ww=self.x[self.wconv[2*l]:self.wconv[2*l]+nw]
138
- ww=self.f.backend.bk_reshape(ww,[l_chan,
139
- self.chanlist[l],
140
- kernelsz*kernelsz])
141
- l_data = self.hconv[l].Convol_torch(l_data,ww)
142
-
143
- nw=self.chanlist[l]*self.chanlist[l]*kernelsz*kernelsz
144
- ww=self.x[self.wconv[2*l+1]:self.wconv[2*l+1]+nw]
145
- ww=self.f.backend.bk_reshape(ww,[self.chanlist[l],
146
- self.chanlist[l],
147
- kernelsz*kernelsz])
148
-
149
- l_data = self.hconv[l].Convol_torch(l_data,ww)
150
-
151
- l_data,_=self.f.ud_grade_2(l_data,
152
- cell_ids=self.l_cell_ids[l],
153
- nside=l_nside)
154
-
155
- ud_data[l]=m_data
156
-
157
- m_data,_=self.f.ud_grade_2(m_data,
158
- cell_ids=self.l_cell_ids[l],
159
- nside=l_nside)
160
-
161
- l_data = self.f.backend.bk_concat([m_data,l_data],1)
162
-
163
- l_nside//=2
164
- # plus one to add the input downgrade data
165
- l_chan=self.chanlist[l]+self.n_chan_in
329
+ # Fetch weights for conv (in -> chanlist[l]) and reshape to Foscat backend
330
+ nw = l_chan * self.chanlist[l] * kernelsz * kernelsz
331
+ ww = self.x[self.wconv[2 * l] : self.wconv[2 * l] + nw]
332
+ ww = self.f.backend.bk_reshape(
333
+ ww, [l_chan, self.chanlist[l], kernelsz * kernelsz]
334
+ )
335
+ l_data = self.hconv[l].Convol_torch(l_data, ww)
336
+
337
+ # Second conv (chanlist[l] -> chanlist[l])
338
+ nw = self.chanlist[l] * self.chanlist[l] * kernelsz * kernelsz
339
+ ww = self.x[self.wconv[2 * l + 1] : self.wconv[2 * l + 1] + nw]
340
+ ww = self.f.backend.bk_reshape(
341
+ ww, [self.chanlist[l], self.chanlist[l], kernelsz * kernelsz]
342
+ )
343
+ l_data = self.hconv[l].Convol_torch(l_data, ww)
166
344
 
345
+ # Downsample features and store skip connection
346
+ l_data, _ = self.f.ud_grade_2(
347
+ l_data, cell_ids=self.l_cell_ids[l], nside=l_nside
348
+ )
349
+ ud_data[l] = m_data
350
+
351
+ # Also downgrade the running input for the auxiliary concat
352
+ m_data, _ = self.f.ud_grade_2(
353
+ m_data, cell_ids=self.l_cell_ids[l], nside=l_nside
354
+ )
355
+
356
+ # Concatenate along channels: [ downgraded_input , features ]
357
+ l_data = self.f.backend.bk_concat([m_data, l_data], 1)
358
+
359
+ l_nside //= 2
360
+ l_chan = self.chanlist[l] + 1 # account for the concat above
361
+
362
+ # -----------------
363
+ # Decoder traversal
364
+ # -----------------
167
365
  for l in range(nlayer):
168
- l_chan+=self.n_chan_in
169
- l_data=self.f.up_grade(l_data,l_nside*2,
170
- cell_ids=self.l_cell_ids[nlayer-l],
171
- o_cell_ids=self.l_cell_ids[nlayer-1-l],
172
- nside=l_nside)
173
-
174
-
175
- l_data = self.f.backend.bk_concat([ud_data[nlayer-1-l],l_data],1)
176
- l_nside*=2
177
-
178
- # init double convol weights
179
- out_chan=self.n_chan_out
180
- if nlayer-1-l>0:
181
- out_chan+=self.chanlist[nlayer-1-l]
182
- nw=l_chan*l_chan*kernelsz*kernelsz
183
- ww=self.x[self.t_wconv[2*l]:self.t_wconv[2*l]+nw]
184
- ww=self.f.backend.bk_reshape(ww,[l_chan,
185
- l_chan,
186
- kernelsz*kernelsz])
187
-
188
- c_data = self.t_hconv[l].Convol_torch(l_data,ww)
189
-
190
- nw=l_chan*out_chan*kernelsz*kernelsz
191
- ww=self.x[self.t_wconv[2*l+1]:self.t_wconv[2*l+1]+nw]
192
- ww=self.f.backend.bk_reshape(ww,[l_chan,
193
- out_chan,
194
- kernelsz*kernelsz])
195
- l_data = self.t_hconv[l].Convol_torch(c_data,ww)
196
-
197
- # plus one to add the input downgrade data
198
- l_chan=out_chan
199
-
366
+ # Upsample to finer grid
367
+ l_chan += 1 # due to upcoming concat with ud_data
368
+ l_data = self.f.up_grade(
369
+ l_data,
370
+ l_nside * 2,
371
+ cell_ids=self.l_cell_ids[nlayer - l],
372
+ o_cell_ids=self.l_cell_ids[nlayer - 1 - l],
373
+ nside=l_nside,
374
+ )
375
+
376
+ # Concatenate with encoder skip features
377
+ l_data = self.f.backend.bk_concat([ud_data[nlayer - 1 - l], l_data], 1)
378
+ l_nside *= 2
379
+
380
+ # Determine output channels at this level
381
+ out_chan = 1
382
+ if nlayer - 1 - l > 0:
383
+ out_chan += self.chanlist[nlayer - 1 - l]
384
+
385
+ # First decoder conv (l_chan -> l_chan)
386
+ nw = l_chan * l_chan * kernelsz * kernelsz
387
+ ww = self.x[self.t_wconv[2 * l] : self.t_wconv[2 * l] + nw]
388
+ ww = self.f.backend.bk_reshape(ww, [l_chan, l_chan, kernelsz * kernelsz])
389
+ c_data = self.t_hconv[l].Convol_torch(l_data, ww)
390
+
391
+ # Second decoder conv (l_chan -> out_chan)
392
+ nw = l_chan * out_chan * kernelsz * kernelsz
393
+ ww = self.x[self.t_wconv[2 * l + 1] : self.t_wconv[2 * l + 1] + nw]
394
+ ww = self.f.backend.bk_reshape(ww, [l_chan, out_chan, kernelsz * kernelsz])
395
+ l_data = self.t_hconv[l].Convol_torch(c_data, ww)
396
+
397
+ # Update channel count for next iteration
398
+ l_chan = out_chan
399
+
200
400
  return l_data
401
+
402
+
403
+ # -----------------------------
404
+ # Unit tests (smoke tests)
405
+ # -----------------------------
406
+ # Run with: python UNET.py (or) python UNET.py -q for quieter output
407
+ # These tests assume Foscat and its dependencies are installed.
408
+
409
+
410
+ def _dummy_cell_ids(nside: int) -> np.ndarray:
411
+ """Return a simple identity mapping for HEALPix nested pixel IDs.
412
+
413
+ Notes
414
+ -----
415
+ Replace with your pipeline's real `cell_ids` if you have a precomputed
416
+ mapping consistent with Foscat/HEALPix nested ordering.
417
+ """
418
+ return np.arange(12 * nside * nside, dtype=np.int64)
419
+
420
+
421
+ if __name__ == "__main__":
422
+ import unittest
423
+
424
+ class TestUNET(unittest.TestCase):
425
+ """Lightweight smoke tests for shape and parameter plumbing."""
426
+
427
+ def setUp(self):
428
+ self.nside = 4 # small grid for fast tests (npix = 192)
429
+ self.chanlist = [4, 8] # two-level encoder/decoder
430
+ self.batch = 2
431
+ self.channels = 1
432
+ self.npix = 12 * self.nside * self.nside
433
+ self.cell_ids = _dummy_cell_ids(self.nside)
434
+ self.net = UNET(
435
+ in_nside=self.nside,
436
+ n_chan_in=self.channels,
437
+ chanlist=self.chanlist,
438
+ cell_ids=self.cell_ids,
439
+ )
440
+
441
+ def test_forward_shape(self):
442
+ # random input
443
+ x = np.random.randn(self.batch, self.channels, self.npix).astype(np.float32)
444
+ x = self.net.f.backend.bk_cast(x)
445
+ y = self.net.eval(x)
446
+ # expected output: same npix, 1 channel at the very top
447
+ self.assertEqual(y.shape[0], self.batch)
448
+ self.assertEqual(y.shape[1], 1)
449
+ self.assertEqual(y.shape[2], self.npix)
450
+ # sanity: no NaNs
451
+ y_np = self.net.f.backend.to_numpy(y)
452
+ self.assertFalse(np.isnan(y_np).any())
453
+
454
+ def test_param_roundtrip_and_determinism(self):
455
+ x = np.random.randn(self.batch, self.channels, self.npix).astype(np.float32)
456
+ x = self.net.f.backend.bk_cast(x)
457
+
458
+ # forward twice -> identical outputs with fixed params
459
+ y1 = self.net.eval(x)
460
+ y2 = self.net.eval(x)
461
+ y1_np = self.net.f.backend.to_numpy(y1)
462
+ y2_np = self.net.f.backend.to_numpy(y2)
463
+ np.testing.assert_allclose(y1_np, y2_np, rtol=0, atol=0)
464
+
465
+ # perturb parameters -> output should (very likely) change
466
+ p = self.net.get_param()
467
+ p_np = self.net.f.backend.to_numpy(p).copy()
468
+ if p_np.size > 0:
469
+ p_np[0] += 1.0
470
+ self.net.set_param(p_np)
471
+ y3 = self.net.eval(x)
472
+ y3_np = self.net.f.backend.to_numpy(y3)
473
+ with self.assertRaises(AssertionError):
474
+ np.testing.assert_allclose(y1_np, y3_np, rtol=0, atol=0)
475
+
476
+ unittest.main()
477
+