foscat 2025.9.1__py3-none-any.whl → 2025.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +160 -93
- foscat/FoCUS.py +74 -267
- foscat/HOrientedConvol.py +233 -250
- foscat/HealBili.py +12 -8
- foscat/Plot.py +9 -6
- foscat/SphericalStencil.py +1346 -0
- foscat/UNET.py +21 -7
- foscat/healpix_unet_torch.py +656 -171
- foscat/scat_cov.py +2 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.3.dist-info}/METADATA +1 -1
- {foscat-2025.9.1.dist-info → foscat-2025.9.3.dist-info}/RECORD +14 -13
- {foscat-2025.9.1.dist-info → foscat-2025.9.3.dist-info}/WHEEL +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.1.dist-info → foscat-2025.9.3.dist-info}/top_level.txt +0 -0
foscat/UNET.py
CHANGED
|
@@ -24,7 +24,7 @@ All tensors follow the Foscat backend shape `(batch, channels, npix)`.
|
|
|
24
24
|
Dependencies
|
|
25
25
|
------------
|
|
26
26
|
- foscat.scat_cov as `sc`
|
|
27
|
-
- foscat.
|
|
27
|
+
- foscat.SphericalStencil as `hs`
|
|
28
28
|
|
|
29
29
|
Example
|
|
30
30
|
-------
|
|
@@ -52,7 +52,7 @@ from typing import Dict, Optional
|
|
|
52
52
|
import numpy as np
|
|
53
53
|
|
|
54
54
|
import foscat.scat_cov as sc
|
|
55
|
-
import foscat.
|
|
55
|
+
import foscat.SphericalPencil as hs
|
|
56
56
|
|
|
57
57
|
|
|
58
58
|
class UNET:
|
|
@@ -99,7 +99,7 @@ class UNET:
|
|
|
99
99
|
wconv, t_wconv : Dict[int, int]
|
|
100
100
|
Offsets into the flat parameter vector `self.x` for encoder/decoder
|
|
101
101
|
convolutions respectively.
|
|
102
|
-
hconv, t_hconv : Dict[int, hs.
|
|
102
|
+
hconv, t_hconv : Dict[int, hs.SphericalPencil]
|
|
103
103
|
Per-level oriented convolution operators for encoder/decoder.
|
|
104
104
|
l_cell_ids : Dict[int, np.ndarray]
|
|
105
105
|
Per-level cell ids for downsampled grids (encoder side).
|
|
@@ -148,7 +148,7 @@ class UNET:
|
|
|
148
148
|
# Internal registries
|
|
149
149
|
n = 0 # running offset in the flat parameter vector
|
|
150
150
|
wconv: Dict[int, int] = {} # encoder weight offsets
|
|
151
|
-
hconv: Dict[int, hs.
|
|
151
|
+
hconv: Dict[int, hs.SphericalPencil] = {} # encoder conv operators
|
|
152
152
|
l_cell_ids: Dict[int, np.ndarray] = {} # encoder level cell ids
|
|
153
153
|
self.KERNELSZ = KERNELSZ
|
|
154
154
|
kernelsz = self.KERNELSZ
|
|
@@ -181,7 +181,7 @@ class UNET:
|
|
|
181
181
|
n += nw
|
|
182
182
|
|
|
183
183
|
# Build oriented convolution operator for this level
|
|
184
|
-
hconvol = hs.
|
|
184
|
+
hconvol = hs.SphericalPencil(l_nside, 3, cell_ids=l_cell_ids[l])
|
|
185
185
|
hconvol.make_idx_weights() # precompute indices/weights once
|
|
186
186
|
hconv[l] = hconvol
|
|
187
187
|
|
|
@@ -207,7 +207,7 @@ class UNET:
|
|
|
207
207
|
m_cell_ids: Dict[int, np.ndarray] = {}
|
|
208
208
|
m_cell_ids[0] = l_cell_ids[nlayer]
|
|
209
209
|
t_wconv: Dict[int, int] = {} # decoder weight offsets
|
|
210
|
-
t_hconv: Dict[int, hs.
|
|
210
|
+
t_hconv: Dict[int, hs.SphericalPencil] = {} # decoder conv operators
|
|
211
211
|
|
|
212
212
|
for l in range(nlayer):
|
|
213
213
|
# Upsample features to the previous (finer) resolution
|
|
@@ -240,7 +240,7 @@ class UNET:
|
|
|
240
240
|
n += nw
|
|
241
241
|
|
|
242
242
|
# Build oriented convolution operator for this decoder level
|
|
243
|
-
hconvol = hs.
|
|
243
|
+
hconvol = hs.SphericalPencil(l_nside, 3, cell_ids=m_cell_ids[l])
|
|
244
244
|
hconvol.make_idx_weights()
|
|
245
245
|
t_hconv[l] = hconvol
|
|
246
246
|
|
|
@@ -400,6 +400,20 @@ class UNET:
|
|
|
400
400
|
return l_data
|
|
401
401
|
|
|
402
402
|
|
|
403
|
+
def to_tensor(self, x):
    """Cast ``x`` with the Foscat backend held by this model.

    The ``sc.funct`` helper is created lazily on the first call and cached
    on ``self.f``. It is built with ``all_type='float64'`` when
    ``self.dtype`` equals ``torch.float64``, and with ``'float32'``
    otherwise (including when ``self.dtype`` was never set).

    Parameters
    ----------
    x : array-like
        Data to convert. It is passed unchanged to
        ``self.f.backend.bk_cast``.

    Returns
    -------
    object
        Whatever ``self.f.backend.bk_cast(x)`` returns.
    """
    if getattr(self, "f", None) is None:
        # NOTE(review): no module-level ``import torch`` is visible next to the
        # numpy / foscat imports, so import it here. That way the dtype
        # comparison below cannot fail with NameError on first use.
        import torch

        if getattr(self, "dtype", None) == torch.float64:
            all_type = "float64"
        else:
            all_type = "float32"
        self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type=all_type)
    return self.f.backend.bk_cast(x)
|
|
410
|
+
|
|
411
|
+
def to_numpy(self, x):
    """Return ``x`` as a NumPy array.

    Parameters
    ----------
    x : np.ndarray, tensor, or array-like
        A NumPy array is returned unchanged, with no copy. A tensor (any
        object with a ``detach`` method) is detached from the autograd graph
        and moved to host memory before conversion. Anything else is passed
        to ``np.asarray``.

    Returns
    -------
    np.ndarray
        The converted array.
    """
    if isinstance(x, np.ndarray):
        return x
    if hasattr(x, "detach"):
        # ``Tensor.numpy()`` refuses tensors that require grad and tensors
        # that are not on the CPU, so detach first and then move to host.
        return x.detach().cpu().numpy()
    return np.asarray(x)
|
|
415
|
+
|
|
416
|
+
|
|
403
417
|
# -----------------------------
|
|
404
418
|
# Unit tests (smoke tests)
|
|
405
419
|
# -----------------------------
|