foscat 2025.10.2__py3-none-any.whl → 2026.1.1__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -5,8 +5,11 @@ import healpy as hp
5
5
  import numpy as np
6
6
  import foscat.HealSpline as HS
7
7
  from scipy.interpolate import griddata
8
+ from foscat.SphereDownGeo import SphereDownGeo
9
+ from foscat.SphereUpGeo import SphereUpGeo
10
+ import torch
8
11
 
9
- TMPFILE_VERSION = "V10_0"
12
+ TMPFILE_VERSION = "V12_0"
10
13
 
11
14
 
12
15
  class FoCUS:
@@ -28,15 +31,15 @@ class FoCUS:
28
31
  use_2D=False,
29
32
  use_1D=False,
30
33
  return_data=False,
31
- JmaxDelta=0,
32
34
  DODIV=False,
35
+ use_median=False,
33
36
  InitWave=None,
34
37
  silent=True,
35
38
  mpi_size=1,
36
39
  mpi_rank=0
37
40
  ):
38
41
 
39
- self.__version__ = "2025.10.2"
42
+ self.__version__ = "2026.01.1"
40
43
  # P00 coeff for normalization for scat_cov
41
44
  self.TMPFILE_VERSION = TMPFILE_VERSION
42
45
  self.P1_dic = None
@@ -50,13 +53,15 @@ class FoCUS:
50
53
  self.mpi_rank = mpi_rank
51
54
  self.return_data = return_data
52
55
  self.silent = silent
56
+ self.use_median = use_median
53
57
 
54
58
  self.kernel_smooth = {}
55
59
  self.padding_smooth = {}
56
60
  self.kernelR_conv = {}
57
61
  self.kernelI_conv = {}
58
62
  self.padding_conv = {}
59
-
63
+ self.down = {}
64
+ self.up = {}
60
65
  if not self.silent:
61
66
  print("================================================")
62
67
  print(" START FOSCAT CONFIGURATION")
@@ -89,13 +94,6 @@ class FoCUS:
89
94
  self.nlog = 0
90
95
  self.padding = padding
91
96
 
92
- if JmaxDelta != 0:
93
- print(
94
- "OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
95
- )
96
- return None
97
-
98
- self.OSTEP = JmaxDelta
99
97
  self.use_2D = use_2D
100
98
  self.use_1D = use_1D
101
99
 
@@ -654,6 +652,7 @@ class FoCUS:
654
652
  return rim
655
653
 
656
654
  # --------------------------------------------------------
655
+
657
656
  def ud_grade_2(self, im, axis=0, cell_ids=None, nside=None,max_poll=False):
658
657
 
659
658
  if self.use_2D:
@@ -673,13 +672,20 @@ class FoCUS:
673
672
  tim = self.backend.bk_reshape(
674
673
  self.backend.bk_cast(im), [ndata, npix, npiy, 1]
675
674
  )
675
+ '''
676
676
  tim = self.backend.bk_reshape(
677
677
  tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
678
678
  [ndata, npix // 2, 2, npiy // 2, 2, 1],
679
679
  )
680
-
681
- res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
682
-
680
+
681
+ #res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
682
+ '''
683
+
684
+ if self.use_median:
685
+ res = self.backend.downsample_median_2x2(tim)
686
+ else:
687
+ res = self.backend.downsample_mean_2x2(tim)
688
+
683
689
  if len(ishape) == 2:
684
690
  return (
685
691
  self.backend.bk_reshape(
@@ -711,13 +717,42 @@ class FoCUS:
711
717
  self.backend.bk_cast(im), [ndata, npix // 2, 2]
712
718
  )
713
719
 
714
- res = self.backend.bk_reduce_mean(tim, -1)
720
+ if self.use_median:
721
+ res=self.backend.bk_reduce_median(tim,axis=-1)
722
+ else:
723
+ res=self.backend.bk_reduce_mean(tim,axis=-1)
715
724
 
716
725
  return self.backend.bk_reshape(res, ishape[0:-1] + [npix // 2]), None
717
726
 
718
727
  else:
719
728
  shape = list(im.shape)
720
- if max_poll:
729
+ if nside is None:
730
+ l_nside=int(np.sqrt(shape[-1]//12))
731
+ else:
732
+ l_nside=nside
733
+
734
+ nbatch=1
735
+ for k in range(len(shape)-1):
736
+ nbatch*=shape[k]
737
+ if l_nside not in self.down:
738
+ print('initialise down', l_nside)
739
+ self.down[l_nside] = SphereDownGeo(nside_in=l_nside, dtype=self.all_bk_type,mode="smooth", in_cell_ids=cell_ids)
740
+
741
+ res,out_cell=self.down[l_nside](self.backend.bk_reshape(im,[nbatch,1,shape[-1]]))
742
+
743
+ return self.backend.bk_reshape(res,shape[:-1]+[out_cell.shape[0]]),out_cell
744
+ '''
745
+ if self.use_median:
746
+ if cell_ids is not None:
747
+ sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='median')
748
+ return sim, new_cell_ids
749
+
750
+ return self.backend.bk_reduce_median(
751
+ self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
752
+ ),None
753
+
754
+ elif max_poll:
755
+
721
756
  if cell_ids is not None:
722
757
  sim, new_cell_ids = self.backend.binned_mean(im, cell_ids,reduce='max')
723
758
  return sim, new_cell_ids
@@ -733,6 +768,7 @@ class FoCUS:
733
768
  return self.backend.bk_reduce_mean(
734
769
  self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
735
770
  ),None
771
+ '''
736
772
 
737
773
  # --------------------------------------------------------
738
774
  def up_grade(self, im, nout,
@@ -822,6 +858,7 @@ class FoCUS:
822
858
  else:
823
859
  lout = nside
824
860
 
861
+ '''
825
862
  if (lout,nout) not in self.pix_interp_val or force_init_index:
826
863
  if not self.silent:
827
864
  print("compute lout nout", lout, nout)
@@ -912,12 +949,32 @@ class FoCUS:
912
949
 
913
950
  del w
914
951
  del p
915
-
916
- if lout == nout:
917
- imout = im
918
- else:
919
- # work only on the last column
920
-
952
+ '''
953
+ shape=list(im.shape)
954
+ nbatch=1
955
+ for k in range(len(shape)-1):
956
+ nbatch*=shape[k]
957
+
958
+ im=self.backend.bk_reshape(im,[nbatch,1,shape[-1]])
959
+
960
+ while lout<nout:
961
+ if lout not in self.up:
962
+ if o_cell_ids is None:
963
+ l_o_cell_ids=torch.tensor(np.arange(12*(lout**2),dtype='int'),device=im.device)
964
+ else:
965
+ l_o_cell_ids=o_cell_ids
966
+ self.up[lout] = SphereUpGeo(nside_out=lout,
967
+ dtype=self.all_bk_type,
968
+ cell_ids_out=l_o_cell_ids,
969
+ up_norm="col_l1")
970
+ im, fine_ids = self.up[lout](self.backend.bk_cast(im))
971
+ lout*=2
972
+ if lout<nout and o_cell_ids is not None:
973
+ o_cell_ids=torch.repeat(fine_ids,4)*4+ \
974
+ torch.tile(torch.tensor([0,1,2,3],device=fine_ids.device,dtype=fine_ids.dtype),fine_ids.shape[0])
975
+
976
+ return self.backend.bk_reshape(im,shape[:-1]+[im.shape[-1]])
977
+ '''
921
978
  ndata = 1
922
979
  for k in range(len(ishape)-1):
923
980
  ndata = ndata * ishape[k]
@@ -946,6 +1003,7 @@ class FoCUS:
946
1003
  return self.backend.bk_reshape(
947
1004
  imout, ishape[0:-1]+[imout.shape[-1]]
948
1005
  )
1006
+ '''
949
1007
  return imout
950
1008
 
951
1009
  # --------------------------------------------------------
@@ -1340,7 +1398,9 @@ class FoCUS:
1340
1398
  else:
1341
1399
  l_cell_ids=cell_ids
1342
1400
 
1343
- nvalid=self.KERNELSZ**2
1401
+ nvalid=4*self.KERNELSZ**2
1402
+ if nvalid>12*nside**2:
1403
+ nvalid=12*nside**2
1344
1404
  idxEB=hconvol.idx_nn[:,0:nvalid]
1345
1405
  tmpEB=np.zeros([self.NORIENT,4,l_cell_ids.shape[0],nvalid],dtype='complex')
1346
1406
  tmpS=np.zeros([4,l_cell_ids.shape[0],nvalid],dtype='float')
@@ -1486,7 +1546,7 @@ class FoCUS:
1486
1546
 
1487
1547
  else:
1488
1548
  if l_kernel == 5:
1489
- pw = 0.5
1549
+ pw = 0.75
1490
1550
  pw2 = 0.5
1491
1551
  threshold = 2e-5
1492
1552
 
@@ -2185,17 +2245,21 @@ class FoCUS:
2185
2245
  # mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
2186
2246
  # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
2187
2247
 
2188
- v1 = self.backend.bk_reduce_sum(
2189
- self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
2190
- )
2191
- v2 = self.backend.bk_reduce_sum(
2192
- self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
2193
- )
2194
- vh = self.backend.bk_reduce_sum(
2195
- self.backend.bk_reduce_sum(mtmp, axis=-1), -1
2196
- )
2248
+ if self.use_median:
2249
+ res,res2 = self.backend.bk_masked_median_2d_weiszfeld(vtmp, mtmp)
2250
+ else:
2251
+ v1 = self.backend.bk_reduce_sum(
2252
+ self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
2253
+ )
2254
+ v2 = self.backend.bk_reduce_sum(
2255
+ self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
2256
+ )
2257
+ vh = self.backend.bk_reduce_sum(
2258
+ self.backend.bk_reduce_sum(mtmp, axis=-1), -1
2259
+ )
2197
2260
 
2198
- res = v1 / vh
2261
+ res = v1 / vh
2262
+ res2= v2 / vh
2199
2263
 
2200
2264
  oshape = [x.shape[0]] + [mask.shape[0]]
2201
2265
  if len(x.shape) > 3:
@@ -2204,22 +2268,26 @@ class FoCUS:
2204
2268
  oshape = oshape + [1]
2205
2269
 
2206
2270
  if calc_var:
2271
+ if self.use_median:
2272
+ vh = self.backend.bk_reduce_sum(
2273
+ self.backend.bk_reduce_sum(mtmp, axis=-1), -1
2274
+ )
2207
2275
  if self.backend.bk_is_complex(vtmp):
2208
2276
  res2 = self.backend.bk_sqrt(
2209
2277
  (
2210
2278
  (
2211
- self.backend.bk_real(v2) / self.backend.bk_real(vh)
2279
+ self.backend.bk_real(res2)
2212
2280
  - self.backend.bk_real(res) * self.backend.bk_real(res)
2213
2281
  )
2214
2282
  + (
2215
- self.backend.bk_imag(v2) / self.backend.bk_real(vh)
2283
+ self.backend.bk_imag(res2)
2216
2284
  - self.backend.bk_imag(res) * self.backend.bk_imag(res)
2217
2285
  )
2218
2286
  )
2219
2287
  / self.backend.bk_real(vh)
2220
2288
  )
2221
2289
  else:
2222
- res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
2290
+ res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
2223
2291
 
2224
2292
  res = self.backend.bk_reshape(res, oshape)
2225
2293
  res2 = self.backend.bk_reshape(res2, oshape)
@@ -2231,11 +2299,16 @@ class FoCUS:
2231
2299
  elif self.use_1D:
2232
2300
  mtmp = l_mask
2233
2301
  vtmp = l_x
2234
- v1 = self.backend.bk_reduce_sum(l_mask * vtmp, axis=-1)
2235
- v2 = self.backend.bk_reduce_sum(l_mask * vtmp * vtmp, axis=-1)
2236
- vh = self.backend.bk_reduce_sum(l_mask , axis=-1)
2302
+
2303
+ if self.use_median:
2304
+ res,res2 = self.backend.bk_masked_median(l_x, l_mask)
2305
+ else:
2306
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2307
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2308
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2237
2309
 
2238
- res = v1 / vh
2310
+ res = v1 / vh
2311
+ res2= v2 / vh
2239
2312
 
2240
2313
  oshape = [x.shape[0]] + [mask.shape[0]]
2241
2314
  if len(x.shape) > 1:
@@ -2244,35 +2317,42 @@ class FoCUS:
2244
2317
  oshape = oshape + [1]
2245
2318
 
2246
2319
  if calc_var:
2320
+ if self.use_median:
2321
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2322
+
2247
2323
  if self.backend.bk_is_complex(vtmp):
2248
2324
  res2 = self.backend.bk_sqrt(
2249
2325
  (
2250
2326
  (
2251
- self.backend.bk_real(v2) / self.backend.bk_real(vh)
2327
+ self.backend.bk_real(res2)
2252
2328
  - self.backend.bk_real(res) * self.backend.bk_real(res)
2253
2329
  )
2254
2330
  + (
2255
- self.backend.bk_imag(v2) / self.backend.bk_real(vh)
2331
+ self.backend.bk_imag(res2)
2256
2332
  - self.backend.bk_imag(res) * self.backend.bk_imag(res)
2257
2333
  )
2258
2334
  )
2259
2335
  / self.backend.bk_real(vh)
2260
2336
  )
2261
2337
  else:
2262
- res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
2338
+ res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
2339
+
2263
2340
  res = self.backend.bk_reshape(res, oshape)
2264
2341
  res2 = self.backend.bk_reshape(res2, oshape)
2265
2342
  return res, res2
2266
2343
  else:
2267
2344
  res = self.backend.bk_reshape(res, oshape)
2268
2345
  return res
2269
-
2270
2346
  else:
2271
- v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2272
- v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2273
- vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2347
+ if self.use_median:
2348
+ res,res2 = self.backend.bk_masked_median(l_x, l_mask)
2349
+ else:
2350
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2351
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2352
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2274
2353
 
2275
- res = v1 / vh
2354
+ res = v1 / vh
2355
+ res2= v2 / vh
2276
2356
 
2277
2357
  oshape = []
2278
2358
  if len(shape) > 1:
@@ -2281,24 +2361,27 @@ class FoCUS:
2281
2361
  oshape = [1]
2282
2362
 
2283
2363
  oshape = oshape + [mask.shape[0]]
2364
+
2284
2365
  if len(shape) > 2:
2285
2366
  oshape = oshape + shape[1:-1]
2286
2367
  else:
2287
2368
  oshape = oshape + [1]
2288
2369
 
2289
2370
  if calc_var:
2371
+ if self.use_median:
2372
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2290
2373
  if self.backend.bk_is_complex(l_x):
2291
2374
  res2 = self.backend.bk_sqrt(
2292
2375
  (
2293
- self.backend.bk_real(v2) / self.backend.bk_real(vh)
2376
+ self.backend.bk_real(res2)
2294
2377
  - self.backend.bk_real(res) * self.backend.bk_real(res)
2295
- + self.backend.bk_imag(v2) / self.backend.bk_real(vh)
2378
+ + self.backend.bk_imag(res2)
2296
2379
  - self.backend.bk_imag(res) * self.backend.bk_imag(res)
2297
2380
  )
2298
2381
  / self.backend.bk_real(vh)
2299
2382
  )
2300
2383
  else:
2301
- res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
2384
+ res2 = self.backend.bk_sqrt((res2 - res * res) / (vh))
2302
2385
 
2303
2386
  res = self.backend.bk_reshape(res, oshape)
2304
2387
  res2 = self.backend.bk_reshape(res2, oshape)