foscat 2025.9.5__py3-none-any.whl → 2025.11.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +635 -141
- foscat/FoCUS.py +92 -50
- foscat/Plot.py +4 -3
- foscat/healpix_unet_torch.py +17 -1
- foscat/healpix_vit_skip.py +445 -0
- foscat/healpix_vit_torch.py +521 -0
- foscat/planar_vit.py +206 -0
- foscat/scat.py +1 -1
- foscat/scat1D.py +1 -1
- foscat/scat_cov.py +2 -2
- foscat/unet_2_d_from_healpix_params.py +421 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/METADATA +1 -1
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/RECORD +16 -12
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/WHEEL +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.9.5.dist-info → foscat-2025.11.1.dist-info}/top_level.txt +0 -0
foscat/FoCUS.py
CHANGED
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
|
6
6
|
import foscat.HealSpline as HS
|
|
7
7
|
from scipy.interpolate import griddata
|
|
8
8
|
|
|
9
|
-
TMPFILE_VERSION = "
|
|
9
|
+
TMPFILE_VERSION = "V10_0"
|
|
10
10
|
|
|
11
11
|
|
|
12
12
|
class FoCUS:
|
|
@@ -28,15 +28,15 @@ class FoCUS:
|
|
|
28
28
|
use_2D=False,
|
|
29
29
|
use_1D=False,
|
|
30
30
|
return_data=False,
|
|
31
|
-
JmaxDelta=0,
|
|
32
31
|
DODIV=False,
|
|
32
|
+
use_median=False,
|
|
33
33
|
InitWave=None,
|
|
34
34
|
silent=True,
|
|
35
35
|
mpi_size=1,
|
|
36
36
|
mpi_rank=0
|
|
37
37
|
):
|
|
38
38
|
|
|
39
|
-
self.__version__ = "2025.
|
|
39
|
+
self.__version__ = "2025.11.1"
|
|
40
40
|
# P00 coeff for normalization for scat_cov
|
|
41
41
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
42
42
|
self.P1_dic = None
|
|
@@ -50,6 +50,7 @@ class FoCUS:
|
|
|
50
50
|
self.mpi_rank = mpi_rank
|
|
51
51
|
self.return_data = return_data
|
|
52
52
|
self.silent = silent
|
|
53
|
+
self.use_median = use_median
|
|
53
54
|
|
|
54
55
|
self.kernel_smooth = {}
|
|
55
56
|
self.padding_smooth = {}
|
|
@@ -89,13 +90,6 @@ class FoCUS:
|
|
|
89
90
|
self.nlog = 0
|
|
90
91
|
self.padding = padding
|
|
91
92
|
|
|
92
|
-
if JmaxDelta != 0:
|
|
93
|
-
print(
|
|
94
|
-
"OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
|
|
95
|
-
)
|
|
96
|
-
return None
|
|
97
|
-
|
|
98
|
-
self.OSTEP = JmaxDelta
|
|
99
93
|
self.use_2D = use_2D
|
|
100
94
|
self.use_1D = use_1D
|
|
101
95
|
|
|
@@ -673,13 +667,20 @@ class FoCUS:
|
|
|
673
667
|
tim = self.backend.bk_reshape(
|
|
674
668
|
self.backend.bk_cast(im), [ndata, npix, npiy, 1]
|
|
675
669
|
)
|
|
670
|
+
'''
|
|
676
671
|
tim = self.backend.bk_reshape(
|
|
677
672
|
tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
|
|
678
673
|
[ndata, npix // 2, 2, npiy // 2, 2, 1],
|
|
679
674
|
)
|
|
680
|
-
|
|
681
|
-
res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
|
|
682
|
-
|
|
675
|
+
|
|
676
|
+
#res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
|
|
677
|
+
'''
|
|
678
|
+
|
|
679
|
+
if self.use_median:
|
|
680
|
+
res = self.backend.downsample_median_2x2(tim)
|
|
681
|
+
else:
|
|
682
|
+
res = self.backend.downsample_mean_2x2(tim)
|
|
683
|
+
|
|
683
684
|
if len(ishape) == 2:
|
|
684
685
|
return (
|
|
685
686
|
self.backend.bk_reshape(
|
|
@@ -711,13 +712,26 @@ class FoCUS:
|
|
|
711
712
|
self.backend.bk_cast(im), [ndata, npix // 2, 2]
|
|
712
713
|
)
|
|
713
714
|
|
|
714
|
-
|
|
715
|
+
if self.use_median:
|
|
716
|
+
res=self.backend.bk_reduce_median(tim,axis=-1)
|
|
717
|
+
else:
|
|
718
|
+
res=self.backend.bk_reduce_mean(tim,axis=-1)
|
|
715
719
|
|
|
716
720
|
return self.backend.bk_reshape(res, ishape[0:-1] + [npix // 2]), None
|
|
717
721
|
|
|
718
722
|
else:
|
|
719
723
|
shape = list(im.shape)
|
|
720
|
-
if
|
|
724
|
+
if self.use_median:
|
|
725
|
+
if cell_ids is not None:
|
|
726
|
+
sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='median')
|
|
727
|
+
return sim, new_cell_ids
|
|
728
|
+
|
|
729
|
+
return self.backend.bk_reduce_median(
|
|
730
|
+
self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
|
|
731
|
+
),None
|
|
732
|
+
|
|
733
|
+
elif max_poll:
|
|
734
|
+
|
|
721
735
|
if cell_ids is not None:
|
|
722
736
|
sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='max')
|
|
723
737
|
return sim, new_cell_ids
|
|
@@ -1488,7 +1502,7 @@ class FoCUS:
|
|
|
1488
1502
|
if l_kernel == 5:
|
|
1489
1503
|
pw = 0.5
|
|
1490
1504
|
pw2 = 0.5
|
|
1491
|
-
threshold = 2e-
|
|
1505
|
+
threshold = 2e-5
|
|
1492
1506
|
|
|
1493
1507
|
elif l_kernel == 3:
|
|
1494
1508
|
pw = 1.0 / np.sqrt(2)
|
|
@@ -1498,7 +1512,7 @@ class FoCUS:
|
|
|
1498
1512
|
elif l_kernel == 7:
|
|
1499
1513
|
pw = 0.5
|
|
1500
1514
|
pw2 = 0.25
|
|
1501
|
-
threshold =
|
|
1515
|
+
threshold = 2e-5
|
|
1502
1516
|
|
|
1503
1517
|
import foscat.SphericalStencil as hs
|
|
1504
1518
|
import torch
|
|
@@ -1517,14 +1531,19 @@ class FoCUS:
|
|
|
1517
1531
|
n_gauges=self.NORIENT,
|
|
1518
1532
|
gauge_type='cosmo')
|
|
1519
1533
|
|
|
1520
|
-
xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ
|
|
1534
|
+
xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ,self.KERNELSZ)
|
|
1521
1535
|
|
|
1522
|
-
wwr=
|
|
1536
|
+
wwr=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
|
|
1523
1537
|
wwr-=wwr.mean()
|
|
1524
|
-
wwi=
|
|
1538
|
+
wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
|
|
1525
1539
|
wwi-=wwi.mean()
|
|
1526
|
-
|
|
1527
|
-
|
|
1540
|
+
amp=np.sum(abs(wwr+1J*wwi))
|
|
1541
|
+
|
|
1542
|
+
wwr/=amp
|
|
1543
|
+
wwi/=amp
|
|
1544
|
+
|
|
1545
|
+
wwr=hconvol.to_tensor(wwr)
|
|
1546
|
+
wwi=hconvol.to_tensor(wwi)
|
|
1528
1547
|
|
|
1529
1548
|
wavr,indice,mshape=hconvol.make_matrix(wwr)
|
|
1530
1549
|
wavi,indice,mshape=hconvol.make_matrix(wwi)
|
|
@@ -2180,17 +2199,21 @@ class FoCUS:
|
|
|
2180
2199
|
# mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
2181
2200
|
# vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
2182
2201
|
|
|
2183
|
-
|
|
2184
|
-
self.backend.
|
|
2185
|
-
|
|
2186
|
-
|
|
2187
|
-
|
|
2188
|
-
|
|
2189
|
-
|
|
2190
|
-
|
|
2191
|
-
|
|
2202
|
+
if self.use_median:
|
|
2203
|
+
res,res2 = self.backend.bk_masked_median_2d_weiszfeld(vtmp, mtmp)
|
|
2204
|
+
else:
|
|
2205
|
+
v1 = self.backend.bk_reduce_sum(
|
|
2206
|
+
self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
|
|
2207
|
+
)
|
|
2208
|
+
v2 = self.backend.bk_reduce_sum(
|
|
2209
|
+
self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
|
|
2210
|
+
)
|
|
2211
|
+
vh = self.backend.bk_reduce_sum(
|
|
2212
|
+
self.backend.bk_reduce_sum(mtmp, axis=-1), -1
|
|
2213
|
+
)
|
|
2192
2214
|
|
|
2193
|
-
|
|
2215
|
+
res = v1 / vh
|
|
2216
|
+
res2= v2 / vh
|
|
2194
2217
|
|
|
2195
2218
|
oshape = [x.shape[0]] + [mask.shape[0]]
|
|
2196
2219
|
if len(x.shape) > 3:
|
|
@@ -2199,22 +2222,26 @@ class FoCUS:
|
|
|
2199
2222
|
oshape = oshape + [1]
|
|
2200
2223
|
|
|
2201
2224
|
if calc_var:
|
|
2225
|
+
if self.use_median:
|
|
2226
|
+
vh = self.backend.bk_reduce_sum(
|
|
2227
|
+
self.backend.bk_reduce_sum(mtmp, axis=-1), -1
|
|
2228
|
+
)
|
|
2202
2229
|
if self.backend.bk_is_complex(vtmp):
|
|
2203
2230
|
res2 = self.backend.bk_sqrt(
|
|
2204
2231
|
(
|
|
2205
2232
|
(
|
|
2206
|
-
self.backend.bk_real(
|
|
2233
|
+
self.backend.bk_real(res2)
|
|
2207
2234
|
- self.backend.bk_real(res) * self.backend.bk_real(res)
|
|
2208
2235
|
)
|
|
2209
2236
|
+ (
|
|
2210
|
-
self.backend.bk_imag(
|
|
2237
|
+
self.backend.bk_imag(res2)
|
|
2211
2238
|
- self.backend.bk_imag(res) * self.backend.bk_imag(res)
|
|
2212
2239
|
)
|
|
2213
2240
|
)
|
|
2214
2241
|
/ self.backend.bk_real(vh)
|
|
2215
2242
|
)
|
|
2216
2243
|
else:
|
|
2217
|
-
res2 = self.backend.bk_sqrt((
|
|
2244
|
+
res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
|
|
2218
2245
|
|
|
2219
2246
|
res = self.backend.bk_reshape(res, oshape)
|
|
2220
2247
|
res2 = self.backend.bk_reshape(res2, oshape)
|
|
@@ -2226,11 +2253,16 @@ class FoCUS:
|
|
|
2226
2253
|
elif self.use_1D:
|
|
2227
2254
|
mtmp = l_mask
|
|
2228
2255
|
vtmp = l_x
|
|
2229
|
-
|
|
2230
|
-
|
|
2231
|
-
|
|
2256
|
+
|
|
2257
|
+
if self.use_median:
|
|
2258
|
+
res,res2 = self.backend.bk_masked_median(l_x, l_mask)
|
|
2259
|
+
else:
|
|
2260
|
+
v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
|
|
2261
|
+
v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
|
|
2262
|
+
vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
|
|
2232
2263
|
|
|
2233
|
-
|
|
2264
|
+
res = v1 / vh
|
|
2265
|
+
res2= v2 / vh
|
|
2234
2266
|
|
|
2235
2267
|
oshape = [x.shape[0]] + [mask.shape[0]]
|
|
2236
2268
|
if len(x.shape) > 1:
|
|
@@ -2239,35 +2271,42 @@ class FoCUS:
|
|
|
2239
2271
|
oshape = oshape + [1]
|
|
2240
2272
|
|
|
2241
2273
|
if calc_var:
|
|
2274
|
+
if self.use_median:
|
|
2275
|
+
vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
|
|
2276
|
+
|
|
2242
2277
|
if self.backend.bk_is_complex(vtmp):
|
|
2243
2278
|
res2 = self.backend.bk_sqrt(
|
|
2244
2279
|
(
|
|
2245
2280
|
(
|
|
2246
|
-
self.backend.bk_real(
|
|
2281
|
+
self.backend.bk_real(res2)
|
|
2247
2282
|
- self.backend.bk_real(res) * self.backend.bk_real(res)
|
|
2248
2283
|
)
|
|
2249
2284
|
+ (
|
|
2250
|
-
self.backend.bk_imag(
|
|
2285
|
+
self.backend.bk_imag(res2)
|
|
2251
2286
|
- self.backend.bk_imag(res) * self.backend.bk_imag(res)
|
|
2252
2287
|
)
|
|
2253
2288
|
)
|
|
2254
2289
|
/ self.backend.bk_real(vh)
|
|
2255
2290
|
)
|
|
2256
2291
|
else:
|
|
2257
|
-
res2 = self.backend.bk_sqrt((
|
|
2292
|
+
res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
|
|
2293
|
+
|
|
2258
2294
|
res = self.backend.bk_reshape(res, oshape)
|
|
2259
2295
|
res2 = self.backend.bk_reshape(res2, oshape)
|
|
2260
2296
|
return res, res2
|
|
2261
2297
|
else:
|
|
2262
2298
|
res = self.backend.bk_reshape(res, oshape)
|
|
2263
2299
|
return res
|
|
2264
|
-
|
|
2265
2300
|
else:
|
|
2266
|
-
|
|
2267
|
-
|
|
2268
|
-
|
|
2301
|
+
if self.use_median:
|
|
2302
|
+
res,res2 = self.backend.bk_masked_median(l_x, l_mask)
|
|
2303
|
+
else:
|
|
2304
|
+
v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
|
|
2305
|
+
v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
|
|
2306
|
+
vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
|
|
2269
2307
|
|
|
2270
|
-
|
|
2308
|
+
res = v1 / vh
|
|
2309
|
+
res2= v2 / vh
|
|
2271
2310
|
|
|
2272
2311
|
oshape = []
|
|
2273
2312
|
if len(shape) > 1:
|
|
@@ -2276,24 +2315,27 @@ class FoCUS:
|
|
|
2276
2315
|
oshape = [1]
|
|
2277
2316
|
|
|
2278
2317
|
oshape = oshape + [mask.shape[0]]
|
|
2318
|
+
|
|
2279
2319
|
if len(shape) > 2:
|
|
2280
2320
|
oshape = oshape + shape[1:-1]
|
|
2281
2321
|
else:
|
|
2282
2322
|
oshape = oshape + [1]
|
|
2283
2323
|
|
|
2284
2324
|
if calc_var:
|
|
2325
|
+
if self.use_median:
|
|
2326
|
+
vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
|
|
2285
2327
|
if self.backend.bk_is_complex(l_x):
|
|
2286
2328
|
res2 = self.backend.bk_sqrt(
|
|
2287
2329
|
(
|
|
2288
|
-
self.backend.bk_real(
|
|
2330
|
+
self.backend.bk_real(res2)
|
|
2289
2331
|
- self.backend.bk_real(res) * self.backend.bk_real(res)
|
|
2290
|
-
+ self.backend.bk_imag(
|
|
2332
|
+
+ self.backend.bk_imag(res2)
|
|
2291
2333
|
- self.backend.bk_imag(res) * self.backend.bk_imag(res)
|
|
2292
2334
|
)
|
|
2293
2335
|
/ self.backend.bk_real(vh)
|
|
2294
2336
|
)
|
|
2295
2337
|
else:
|
|
2296
|
-
res2 = self.backend.bk_sqrt((
|
|
2338
|
+
res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
|
|
2297
2339
|
|
|
2298
2340
|
res = self.backend.bk_reshape(res, oshape)
|
|
2299
2341
|
res2 = self.backend.bk_reshape(res2, oshape)
|
foscat/Plot.py
CHANGED
|
@@ -959,8 +959,9 @@ def conjugate_gradient_normal_equation(data, x0, www, all_idx,
|
|
|
959
959
|
|
|
960
960
|
rs_new = np.dot(r, r)
|
|
961
961
|
|
|
962
|
-
if verbose and i %
|
|
963
|
-
|
|
962
|
+
if verbose and i % 10 == 0:
|
|
963
|
+
v=np.mean((LP(p, www, all_idx)-data)**2)
|
|
964
|
+
print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e},{np.sqrt(v):.3e}")
|
|
964
965
|
|
|
965
966
|
if np.sqrt(rs_new) < tol:
|
|
966
967
|
if verbose:
|
|
@@ -1155,7 +1156,7 @@ def plot_wave(wave,title="spectrum",unit="Amplitude",cmap="viridis"):
|
|
|
1155
1156
|
plt.xlabel(r"$k_x$ [cycles / km]")
|
|
1156
1157
|
plt.ylabel(r"$k_y$ [cycles / km]")
|
|
1157
1158
|
plt.title(title)
|
|
1158
|
-
|
|
1159
|
+
|
|
1159
1160
|
def lonlat_edges_from_ref(shape, ref_lon, ref_lat, dlon, dlat, anchor="center"):
|
|
1160
1161
|
"""
|
|
1161
1162
|
Build lon/lat *edges* (H+1, W+1) for a regular, axis-aligned grid.
|
foscat/healpix_unet_torch.py
CHANGED
|
@@ -982,6 +982,9 @@ def fit(
|
|
|
982
982
|
n_epoch: int = 10,
|
|
983
983
|
view_epoch: int = 10,
|
|
984
984
|
batch_size: int = 16,
|
|
985
|
+
x_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
|
|
986
|
+
y_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
|
|
987
|
+
save_model: bool = False,
|
|
985
988
|
lr: float = 1e-3,
|
|
986
989
|
weight_decay: float = 0.0,
|
|
987
990
|
clip_grad_norm: Optional[float] = None,
|
|
@@ -1005,6 +1008,11 @@ def fit(
|
|
|
1005
1008
|
device = model.runtime_device if hasattr(model, "runtime_device") else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
|
1006
1009
|
model.to(device)
|
|
1007
1010
|
|
|
1011
|
+
if save_model:
|
|
1012
|
+
assert x_valid is None, "If save_mode=True x_valid should not be None"
|
|
1013
|
+
assert y_valid is None, "If save_mode=True y_valid should not be None"
|
|
1014
|
+
best_valid=1E30
|
|
1015
|
+
|
|
1008
1016
|
# Detect variable-length mode
|
|
1009
1017
|
varlen_mode = isinstance(x_train, (list, tuple))
|
|
1010
1018
|
|
|
@@ -1197,6 +1205,14 @@ def fit(
|
|
|
1197
1205
|
history.append(epoch_loss)
|
|
1198
1206
|
# print every view_epoch logical step
|
|
1199
1207
|
if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
|
|
1200
|
-
|
|
1208
|
+
if x_valid is not None:
|
|
1209
|
+
preds=model.predict(model.to_tensor(x_valid)).cpu().numpy()
|
|
1210
|
+
valid_loss=np.mean((preds-y_valid)**2)
|
|
1211
|
+
if save_model:
|
|
1212
|
+
if best_valid>valid_loss:
|
|
1213
|
+
torch.save({"model": self.state_dict(), "cfg": CFG}, os.path.join(CFG["save_dir"], "best.pt"))
|
|
1214
|
+
print(f"[epoch {len(history)}] loss={epoch_loss:.4f} loss_valid={valid_loss:.4f}")
|
|
1215
|
+
else:
|
|
1216
|
+
print(f"[epoch {len(history)}] loss={epoch_loss:.4f}")
|
|
1201
1217
|
|
|
1202
1218
|
return {"loss": history}
|