foscat 2025.9.1__py3-none-any.whl → 2025.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/UNET.py CHANGED
@@ -24,7 +24,7 @@ All tensors follow the Foscat backend shape `(batch, channels, npix)`.
24
24
  Dependencies
25
25
  ------------
26
26
  - foscat.scat_cov as `sc`
27
- - foscat.HOrientedConvol as `hs`
27
+ - foscat.SphericalStencil as `hs`
28
28
 
29
29
  Example
30
30
  -------
@@ -52,7 +52,7 @@ from typing import Dict, Optional
52
52
  import numpy as np
53
53
 
54
54
  import foscat.scat_cov as sc
55
- import foscat.HOrientedConvol as hs
55
+ import foscat.SphericalPencil as hs
56
56
 
57
57
 
58
58
  class UNET:
@@ -99,7 +99,7 @@ class UNET:
99
99
  wconv, t_wconv : Dict[int, int]
100
100
  Offsets into the flat parameter vector `self.x` for encoder/decoder
101
101
  convolutions respectively.
102
- hconv, t_hconv : Dict[int, hs.HOrientedConvol]
102
+ hconv, t_hconv : Dict[int, hs.SphericalPencil]
103
103
  Per-level oriented convolution operators for encoder/decoder.
104
104
  l_cell_ids : Dict[int, np.ndarray]
105
105
  Per-level cell ids for downsampled grids (encoder side).
@@ -148,7 +148,7 @@ class UNET:
148
148
  # Internal registries
149
149
  n = 0 # running offset in the flat parameter vector
150
150
  wconv: Dict[int, int] = {} # encoder weight offsets
151
- hconv: Dict[int, hs.HOrientedConvol] = {} # encoder conv operators
151
+ hconv: Dict[int, hs.SphericalPencil] = {} # encoder conv operators
152
152
  l_cell_ids: Dict[int, np.ndarray] = {} # encoder level cell ids
153
153
  self.KERNELSZ = KERNELSZ
154
154
  kernelsz = self.KERNELSZ
@@ -181,7 +181,7 @@ class UNET:
181
181
  n += nw
182
182
 
183
183
  # Build oriented convolution operator for this level
184
- hconvol = hs.HOrientedConvol(l_nside, 3, cell_ids=l_cell_ids[l])
184
+ hconvol = hs.SphericalPencil(l_nside, 3, cell_ids=l_cell_ids[l])
185
185
  hconvol.make_idx_weights() # precompute indices/weights once
186
186
  hconv[l] = hconvol
187
187
 
@@ -207,7 +207,7 @@ class UNET:
207
207
  m_cell_ids: Dict[int, np.ndarray] = {}
208
208
  m_cell_ids[0] = l_cell_ids[nlayer]
209
209
  t_wconv: Dict[int, int] = {} # decoder weight offsets
210
- t_hconv: Dict[int, hs.HOrientedConvol] = {} # decoder conv operators
210
+ t_hconv: Dict[int, hs.SphericalPencil] = {} # decoder conv operators
211
211
 
212
212
  for l in range(nlayer):
213
213
  # Upsample features to the previous (finer) resolution
@@ -240,7 +240,7 @@ class UNET:
240
240
  n += nw
241
241
 
242
242
  # Build oriented convolution operator for this decoder level
243
- hconvol = hs.HOrientedConvol(l_nside, 3, cell_ids=m_cell_ids[l])
243
+ hconvol = hs.SphericalPencil(l_nside, 3, cell_ids=m_cell_ids[l])
244
244
  hconvol.make_idx_weights()
245
245
  t_hconv[l] = hconvol
246
246
 
@@ -400,6 +400,20 @@ class UNET:
400
400
  return l_data
401
401
 
402
402
 
403
def to_tensor(self, x):
    """Cast ``x`` to this network's backend tensor type.

    On the first call the ``sc.funct`` helper is built and cached on
    ``self.f``. Its precision is ``'float64'`` when ``self.dtype`` equals
    ``torch.float64`` and ``'float32'`` in every other case. Later calls
    reuse the cached helper unchanged.

    Parameters
    ----------
    x : array-like
        Data to convert.

    Returns
    -------
    object
        Whatever ``self.f.backend.bk_cast(x)`` returns.
    """
    # NOTE(review): `torch` is not among the imports visible in this part of
    # the file; confirm it is imported at module level, otherwise the first
    # call (self.f is None) raises NameError.
    if self.f is None:
        precision = 'float64' if self.dtype == torch.float64 else 'float32'
        self.f = sc.funct(KERNELSZ=self.KERNELSZ, all_type=precision)
    backend = self.f.backend
    return backend.bk_cast(x)
410
+
411
def to_numpy(self, x):
    """Return ``x`` as a NumPy array.

    Parameters
    ----------
    x : np.ndarray or tensor
        NumPy arrays are returned as-is (same object, no copy). Any other
        value is treated as a torch tensor and must provide ``detach``,
        ``cpu`` and ``numpy``.

    Returns
    -------
    np.ndarray
        The input array, or a host-memory NumPy view of the tensor data.
    """
    if isinstance(x, np.ndarray):
        return x
    # detach() first: Tensor.numpy() raises RuntimeError on tensors that
    # require grad (e.g. network outputs or trainable parameters).
    # cpu() moves device tensors to host memory before conversion.
    return x.detach().cpu().numpy()
415
+
416
+
403
417
  # -----------------------------
404
418
  # Unit tests (smoke tests)
405
419
  # -----------------------------