foscat 2025.9.5__py3-none-any.whl → 2025.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
6
  import foscat.HealSpline as HS
7
7
  from scipy.interpolate import griddata
8
8
 
9
- TMPFILE_VERSION = "V9_0"
9
+ TMPFILE_VERSION = "V10_0"
10
10
 
11
11
 
12
12
  class FoCUS:
@@ -28,15 +28,15 @@ class FoCUS:
28
28
  use_2D=False,
29
29
  use_1D=False,
30
30
  return_data=False,
31
- JmaxDelta=0,
32
31
  DODIV=False,
32
+ use_median=False,
33
33
  InitWave=None,
34
34
  silent=True,
35
35
  mpi_size=1,
36
36
  mpi_rank=0
37
37
  ):
38
38
 
39
- self.__version__ = "2025.09.5"
39
+ self.__version__ = "2025.11.1"
40
40
  # P00 coeff for normalization for scat_cov
41
41
  self.TMPFILE_VERSION = TMPFILE_VERSION
42
42
  self.P1_dic = None
@@ -50,6 +50,7 @@ class FoCUS:
50
50
  self.mpi_rank = mpi_rank
51
51
  self.return_data = return_data
52
52
  self.silent = silent
53
+ self.use_median = use_median
53
54
 
54
55
  self.kernel_smooth = {}
55
56
  self.padding_smooth = {}
@@ -89,13 +90,6 @@ class FoCUS:
89
90
  self.nlog = 0
90
91
  self.padding = padding
91
92
 
92
- if JmaxDelta != 0:
93
- print(
94
- "OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
95
- )
96
- return None
97
-
98
- self.OSTEP = JmaxDelta
99
93
  self.use_2D = use_2D
100
94
  self.use_1D = use_1D
101
95
 
@@ -673,13 +667,20 @@ class FoCUS:
673
667
  tim = self.backend.bk_reshape(
674
668
  self.backend.bk_cast(im), [ndata, npix, npiy, 1]
675
669
  )
670
+ '''
676
671
  tim = self.backend.bk_reshape(
677
672
  tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
678
673
  [ndata, npix // 2, 2, npiy // 2, 2, 1],
679
674
  )
680
-
681
- res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
682
-
675
+
676
+ #res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
677
+ '''
678
+
679
+ if self.use_median:
680
+ res = self.backend.downsample_median_2x2(tim)
681
+ else:
682
+ res = self.backend.downsample_mean_2x2(tim)
683
+
683
684
  if len(ishape) == 2:
684
685
  return (
685
686
  self.backend.bk_reshape(
@@ -711,13 +712,26 @@ class FoCUS:
711
712
  self.backend.bk_cast(im), [ndata, npix // 2, 2]
712
713
  )
713
714
 
714
- res = self.backend.bk_reduce_mean(tim, -1)
715
+ if self.use_median:
716
+ res=self.backend.bk_reduce_median(tim,axis=-1)
717
+ else:
718
+ res=self.backend.bk_reduce_mean(tim,axis=-1)
715
719
 
716
720
  return self.backend.bk_reshape(res, ishape[0:-1] + [npix // 2]), None
717
721
 
718
722
  else:
719
723
  shape = list(im.shape)
720
- if max_poll:
724
+ if self.use_median:
725
+ if cell_ids is not None:
726
+ sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='median')
727
+ return sim, new_cell_ids
728
+
729
+ return self.backend.bk_reduce_median(
730
+ self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
731
+ ),None
732
+
733
+ elif max_poll:
734
+
721
735
  if cell_ids is not None:
722
736
  sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='max')
723
737
  return sim, new_cell_ids
@@ -1488,7 +1502,7 @@ class FoCUS:
1488
1502
  if l_kernel == 5:
1489
1503
  pw = 0.5
1490
1504
  pw2 = 0.5
1491
- threshold = 2e-4
1505
+ threshold = 2e-5
1492
1506
 
1493
1507
  elif l_kernel == 3:
1494
1508
  pw = 1.0 / np.sqrt(2)
@@ -1498,7 +1512,7 @@ class FoCUS:
1498
1512
  elif l_kernel == 7:
1499
1513
  pw = 0.5
1500
1514
  pw2 = 0.25
1501
- threshold = 4e-5
1515
+ threshold = 2e-5
1502
1516
 
1503
1517
  import foscat.SphericalStencil as hs
1504
1518
  import torch
@@ -1517,14 +1531,19 @@ class FoCUS:
1517
1531
  n_gauges=self.NORIENT,
1518
1532
  gauge_type='cosmo')
1519
1533
 
1520
- xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ*self.KERNELSZ)
1534
+ xx=np.tile(np.arange(self.KERNELSZ)-self.KERNELSZ//2,self.KERNELSZ).reshape(self.KERNELSZ,self.KERNELSZ)
1521
1535
 
1522
- wwr=hconvol.to_tensor((np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ))
1536
+ wwr=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.cos(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
1523
1537
  wwr-=wwr.mean()
1524
- wwi=hconvol.to_tensor((np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ))
1538
+ wwi=(np.exp(-pw2*(xx**2+(xx.T)**2))*np.sin(pw*xx*np.pi)).reshape(1,1,self.KERNELSZ*self.KERNELSZ)
1525
1539
  wwi-=wwi.mean()
1526
- wwr/=(abs(wwr+1J*wwi)).sum()
1527
- wwi/=(abs(wwr+1J*wwi)).sum()
1540
+ amp=np.sum(abs(wwr+1J*wwi))
1541
+
1542
+ wwr/=amp
1543
+ wwi/=amp
1544
+
1545
+ wwr=hconvol.to_tensor(wwr)
1546
+ wwi=hconvol.to_tensor(wwi)
1528
1547
 
1529
1548
  wavr,indice,mshape=hconvol.make_matrix(wwr)
1530
1549
  wavi,indice,mshape=hconvol.make_matrix(wwi)
@@ -2180,17 +2199,21 @@ class FoCUS:
2180
2199
  # mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
2181
2200
  # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
2182
2201
 
2183
- v1 = self.backend.bk_reduce_sum(
2184
- self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
2185
- )
2186
- v2 = self.backend.bk_reduce_sum(
2187
- self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
2188
- )
2189
- vh = self.backend.bk_reduce_sum(
2190
- self.backend.bk_reduce_sum(mtmp, axis=-1), -1
2191
- )
2202
+ if self.use_median:
2203
+ res,res2 = self.backend.bk_masked_median_2d_weiszfeld(vtmp, mtmp)
2204
+ else:
2205
+ v1 = self.backend.bk_reduce_sum(
2206
+ self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
2207
+ )
2208
+ v2 = self.backend.bk_reduce_sum(
2209
+ self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
2210
+ )
2211
+ vh = self.backend.bk_reduce_sum(
2212
+ self.backend.bk_reduce_sum(mtmp, axis=-1), -1
2213
+ )
2192
2214
 
2193
- res = v1 / vh
2215
+ res = v1 / vh
2216
+ res2= v2 / vh
2194
2217
 
2195
2218
  oshape = [x.shape[0]] + [mask.shape[0]]
2196
2219
  if len(x.shape) > 3:
@@ -2199,22 +2222,26 @@ class FoCUS:
2199
2222
  oshape = oshape + [1]
2200
2223
 
2201
2224
  if calc_var:
2225
+ if self.use_median:
2226
+ vh = self.backend.bk_reduce_sum(
2227
+ self.backend.bk_reduce_sum(mtmp, axis=-1), -1
2228
+ )
2202
2229
  if self.backend.bk_is_complex(vtmp):
2203
2230
  res2 = self.backend.bk_sqrt(
2204
2231
  (
2205
2232
  (
2206
- self.backend.bk_real(v2) / self.backend.bk_real(vh)
2233
+ self.backend.bk_real(res2)
2207
2234
  - self.backend.bk_real(res) * self.backend.bk_real(res)
2208
2235
  )
2209
2236
  + (
2210
- self.backend.bk_imag(v2) / self.backend.bk_real(vh)
2237
+ self.backend.bk_imag(res2)
2211
2238
  - self.backend.bk_imag(res) * self.backend.bk_imag(res)
2212
2239
  )
2213
2240
  )
2214
2241
  / self.backend.bk_real(vh)
2215
2242
  )
2216
2243
  else:
2217
- res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
2244
+ res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
2218
2245
 
2219
2246
  res = self.backend.bk_reshape(res, oshape)
2220
2247
  res2 = self.backend.bk_reshape(res2, oshape)
@@ -2226,11 +2253,16 @@ class FoCUS:
2226
2253
  elif self.use_1D:
2227
2254
  mtmp = l_mask
2228
2255
  vtmp = l_x
2229
- v1 = self.backend.bk_reduce_sum(l_mask * vtmp, axis=-1)
2230
- v2 = self.backend.bk_reduce_sum(l_mask * vtmp * vtmp, axis=-1)
2231
- vh = self.backend.bk_reduce_sum(l_mask , axis=-1)
2256
+
2257
+ if self.use_median:
2258
+ res,res2 = self.backend.bk_masked_median(l_x, l_mask)
2259
+ else:
2260
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2261
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2262
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2232
2263
 
2233
- res = v1 / vh
2264
+ res = v1 / vh
2265
+ res2= v2 / vh
2234
2266
 
2235
2267
  oshape = [x.shape[0]] + [mask.shape[0]]
2236
2268
  if len(x.shape) > 1:
@@ -2239,35 +2271,42 @@ class FoCUS:
2239
2271
  oshape = oshape + [1]
2240
2272
 
2241
2273
  if calc_var:
2274
+ if self.use_median:
2275
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2276
+
2242
2277
  if self.backend.bk_is_complex(vtmp):
2243
2278
  res2 = self.backend.bk_sqrt(
2244
2279
  (
2245
2280
  (
2246
- self.backend.bk_real(v2) / self.backend.bk_real(vh)
2281
+ self.backend.bk_real(res2)
2247
2282
  - self.backend.bk_real(res) * self.backend.bk_real(res)
2248
2283
  )
2249
2284
  + (
2250
- self.backend.bk_imag(v2) / self.backend.bk_real(vh)
2285
+ self.backend.bk_imag(res2)
2251
2286
  - self.backend.bk_imag(res) * self.backend.bk_imag(res)
2252
2287
  )
2253
2288
  )
2254
2289
  / self.backend.bk_real(vh)
2255
2290
  )
2256
2291
  else:
2257
- res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
2292
+ res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
2293
+
2258
2294
  res = self.backend.bk_reshape(res, oshape)
2259
2295
  res2 = self.backend.bk_reshape(res2, oshape)
2260
2296
  return res, res2
2261
2297
  else:
2262
2298
  res = self.backend.bk_reshape(res, oshape)
2263
2299
  return res
2264
-
2265
2300
  else:
2266
- v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2267
- v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2268
- vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2301
+ if self.use_median:
2302
+ res,res2 = self.backend.bk_masked_median(l_x, l_mask)
2303
+ else:
2304
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2305
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2306
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2269
2307
 
2270
- res = v1 / vh
2308
+ res = v1 / vh
2309
+ res2= v2 / vh
2271
2310
 
2272
2311
  oshape = []
2273
2312
  if len(shape) > 1:
@@ -2276,24 +2315,27 @@ class FoCUS:
2276
2315
  oshape = [1]
2277
2316
 
2278
2317
  oshape = oshape + [mask.shape[0]]
2318
+
2279
2319
  if len(shape) > 2:
2280
2320
  oshape = oshape + shape[1:-1]
2281
2321
  else:
2282
2322
  oshape = oshape + [1]
2283
2323
 
2284
2324
  if calc_var:
2325
+ if self.use_median:
2326
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2285
2327
  if self.backend.bk_is_complex(l_x):
2286
2328
  res2 = self.backend.bk_sqrt(
2287
2329
  (
2288
- self.backend.bk_real(v2) / self.backend.bk_real(vh)
2330
+ self.backend.bk_real(res2)
2289
2331
  - self.backend.bk_real(res) * self.backend.bk_real(res)
2290
- + self.backend.bk_imag(v2) / self.backend.bk_real(vh)
2332
+ + self.backend.bk_imag(res2)
2291
2333
  - self.backend.bk_imag(res) * self.backend.bk_imag(res)
2292
2334
  )
2293
2335
  / self.backend.bk_real(vh)
2294
2336
  )
2295
2337
  else:
2296
- res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
2338
+ res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
2297
2339
 
2298
2340
  res = self.backend.bk_reshape(res, oshape)
2299
2341
  res2 = self.backend.bk_reshape(res2, oshape)
foscat/Plot.py CHANGED
@@ -959,8 +959,9 @@ def conjugate_gradient_normal_equation(data, x0, www, all_idx,
959
959
 
960
960
  rs_new = np.dot(r, r)
961
961
 
962
- if verbose and i % 50 == 0:
963
- print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e}")
962
+ if verbose and i % 10 == 0:
963
+ v=np.mean((LP(p, www, all_idx)-data)**2)
964
+ print(f"Iter {i:03d}: residual = {np.sqrt(rs_new):.3e},{np.sqrt(v):.3e}")
964
965
 
965
966
  if np.sqrt(rs_new) < tol:
966
967
  if verbose:
@@ -1155,7 +1156,7 @@ def plot_wave(wave,title="spectrum",unit="Amplitude",cmap="viridis"):
1155
1156
  plt.xlabel(r"$k_x$ [cycles / km]")
1156
1157
  plt.ylabel(r"$k_y$ [cycles / km]")
1157
1158
  plt.title(title)
1158
-
1159
+
1159
1160
  def lonlat_edges_from_ref(shape, ref_lon, ref_lat, dlon, dlat, anchor="center"):
1160
1161
  """
1161
1162
  Build lon/lat *edges* (H+1, W+1) for a regular, axis-aligned grid.
@@ -982,6 +982,9 @@ def fit(
982
982
  n_epoch: int = 10,
983
983
  view_epoch: int = 10,
984
984
  batch_size: int = 16,
985
+ x_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
986
+ y_valid: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]= None,
987
+ save_model: bool = False,
985
988
  lr: float = 1e-3,
986
989
  weight_decay: float = 0.0,
987
990
  clip_grad_norm: Optional[float] = None,
@@ -1005,6 +1008,11 @@ def fit(
1005
1008
  device = model.runtime_device if hasattr(model, "runtime_device") else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
1006
1009
  model.to(device)
1007
1010
 
1011
+ if save_model:
1012
+ assert x_valid is None, "If save_mode=True x_valid should not be None"
1013
+ assert y_valid is None, "If save_mode=True y_valid should not be None"
1014
+ best_valid=1E30
1015
+
1008
1016
  # Detect variable-length mode
1009
1017
  varlen_mode = isinstance(x_train, (list, tuple))
1010
1018
 
@@ -1197,6 +1205,14 @@ def fit(
1197
1205
  history.append(epoch_loss)
1198
1206
  # print every view_epoch logical step
1199
1207
  if verbose and ((len(history) % view_epoch == 0) or (len(history) == 1)):
1200
- print(f"[epoch {len(history)}] loss={epoch_loss:.6f}")
1208
+ if x_valid is not None:
1209
+ preds=model.predict(model.to_tensor(x_valid)).cpu().numpy()
1210
+ valid_loss=np.mean((preds-y_valid)**2)
1211
+ if save_model:
1212
+ if best_valid>valid_loss:
1213
+ torch.save({"model": self.state_dict(), "cfg": CFG}, os.path.join(CFG["save_dir"], "best.pt"))
1214
+ print(f"[epoch {len(history)}] loss={epoch_loss:.4f} loss_valid={valid_loss:.4f}")
1215
+ else:
1216
+ print(f"[epoch {len(history)}] loss={epoch_loss:.4f}")
1201
1217
 
1202
1218
  return {"loss": history}