foscat 2025.7.2__py3-none-any.whl → 2025.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -3,9 +3,10 @@ import sys
3
3
 
4
4
  import healpy as hp
5
5
  import numpy as np
6
+ import foscat.HealSpline as HS
6
7
  from scipy.interpolate import griddata
7
8
 
8
- TMPFILE_VERSION = "V6_0"
9
+ TMPFILE_VERSION = "V7_0"
9
10
 
10
11
 
11
12
  class FoCUS:
@@ -35,7 +36,7 @@ class FoCUS:
35
36
  mpi_rank=0
36
37
  ):
37
38
 
38
- self.__version__ = "2025.07.2"
39
+ self.__version__ = "2025.08.3"
39
40
  # P00 coeff for normalization for scat_cov
40
41
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
42
  self.P1_dic = None
@@ -637,11 +638,13 @@ class FoCUS:
637
638
  return image.numpy()[self.ring2nest[lout]]
638
639
 
639
640
  # --------------------------------------------------------
640
- def ud_grade(self, im, j, axis=0):
641
+ def ud_grade(self, im, j, axis=0, cell_ids=None, nside=None):
641
642
  rim = im
642
643
  for k in range(j):
643
644
  # rim = self.smooth(rim, axis=axis)
644
- rim = self.ud_grade_2(rim, axis=axis)
645
+ rim = self.ud_grade_2(rim, axis=axis,
646
+ cell_ids=cell_ids,
647
+ nside=nside)
645
648
  return rim
646
649
 
647
650
  # --------------------------------------------------------
@@ -654,55 +657,39 @@ class FoCUS:
654
657
  print("Use of 2D scat with data that has less than 2D")
655
658
  return None, None
656
659
 
657
- npix = im.shape[axis]
658
- npiy = im.shape[axis + 1]
659
- odata = 1
660
- if len(ishape) > axis + 2:
661
- for k in range(axis + 2, len(ishape)):
662
- odata = odata * ishape[k]
660
+ npix = im.shape[-2]
661
+ npiy = im.shape[-1]
663
662
 
664
663
  ndata = 1
665
- for k in range(axis):
664
+ for k in range(len(im.shape)-2):
666
665
  ndata = ndata * ishape[k]
667
666
 
668
667
  tim = self.backend.bk_reshape(
669
- self.backend.bk_cast(im), [ndata, npix, npiy, odata]
668
+ self.backend.bk_cast(im), [ndata, npix, npiy, 1]
670
669
  )
671
670
  tim = self.backend.bk_reshape(
672
671
  tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
673
- [ndata, npix // 2, 2, npiy // 2, 2, odata],
672
+ [ndata, npix // 2, 2, npiy // 2, 2, 1],
674
673
  )
675
674
 
676
675
  res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
677
676
 
678
- if axis == 0:
679
- if len(ishape) == 2:
680
- return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
681
- else:
682
- return (
683
- self.backend.bk_reshape(
684
- res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
685
- ),
686
- None,
687
- )
677
+ if len(ishape) == 2:
678
+ return (
679
+ self.backend.bk_reshape(
680
+ res, [npix // 2, npiy // 2]
681
+ ),
682
+ None,
683
+ )
688
684
  else:
689
- if len(ishape) == axis + 2:
690
- return (
691
- self.backend.bk_reshape(
692
- res, ishape[0:axis] + [npix // 2, npiy // 2]
693
- ),
694
- None,
695
- )
696
- else:
697
- return (
698
- self.backend.bk_reshape(
699
- res,
700
- ishape[0:axis]
701
- + [npix // 2, npiy // 2]
702
- + ishape[axis + 2 :],
703
- ),
704
- None,
705
- )
685
+ return (
686
+ self.backend.bk_reshape(
687
+ res,
688
+ ishape[0:-2]
689
+ + [npix // 2, npiy // 2],
690
+ ),
691
+ None,
692
+ )
706
693
 
707
694
  return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
708
695
  elif self.use_1D:
@@ -733,11 +720,16 @@ class FoCUS:
733
720
  ),None
734
721
 
735
722
  # --------------------------------------------------------
736
- def up_grade(self, im, nout, axis=-1, nouty=None):
723
+ def up_grade(self, im, nout,
724
+ axis=-1,
725
+ nouty=None,
726
+ cell_ids=None,
727
+ o_cell_ids=None,
728
+ nside=None):
737
729
 
738
730
  ishape = list(im.shape)
739
731
  if self.use_2D:
740
- if len(ishape) < axis + 2:
732
+ if len(ishape) < 2:
741
733
  if not self.silent:
742
734
  print("Use of 2D scat with data that has less than 2D")
743
735
  return None
@@ -745,39 +737,28 @@ class FoCUS:
745
737
  if nouty is None:
746
738
  nouty = nout
747
739
 
748
- if ishape[axis] == nout and ishape[axis + 1] == nouty:
740
+ if ishape[-2] == nout and ishape[-1] == nouty:
749
741
  return im
750
742
 
751
- npix = im.shape[axis]
752
- npiy = im.shape[axis + 1]
753
- odata = 1
743
+ npix = im.shape[-2]
744
+ npiy = im.shape[-1]
754
745
 
755
746
  ndata = 1
756
- for k in range(axis):
747
+ for k in range(len(im.shape)-2):
757
748
  ndata = ndata * ishape[k]
758
749
 
759
750
  tim = self.backend.bk_reshape(
760
- self.backend.bk_cast(im), [ndata, npix, npiy, odata]
751
+ self.backend.bk_cast(im), [ndata, npix, npiy,1]
761
752
  )
762
753
 
763
754
  res = self.backend.bk_resize_image(tim, [nout, nouty])
764
755
 
765
- if axis == 0:
766
- if len(ishape) == 2:
767
- return self.backend.bk_reshape(res, [nout, nouty])
768
- else:
769
- return self.backend.bk_reshape(
770
- res, [nout, nouty] + ishape[axis + 2 :]
771
- )
756
+ if len(ishape) == 2:
757
+ return self.backend.bk_reshape(res, [nout, nouty])
772
758
  else:
773
- if len(ishape) == axis + 2:
774
- return self.backend.bk_reshape(res, ishape[0:axis] + [nout, nouty])
775
- else:
776
- return self.backend.bk_reshape(
777
- res, ishape[0:axis] + [nout, nouty]
778
- )
779
-
780
- return self.backend.bk_reshape(res, [nout, nouty])
759
+ return self.backend.bk_reshape(
760
+ res, ishape[0:-2] + [nout, nouty]
761
+ )
781
762
 
782
763
  elif self.use_1D:
783
764
  if len(ishape) < axis + 1:
@@ -790,13 +771,11 @@ class FoCUS:
790
771
 
791
772
  npix = im.shape[axis]
792
773
  odata = 1
793
- if len(ishape) > axis + 1:
794
- for k in range(axis + 1, len(ishape)):
795
- odata = odata * ishape[k]
796
-
774
+
797
775
  ndata = 1
798
- for k in range(axis):
799
- ndata = ndata * ishape[k]
776
+ if len(ishape)>1:
777
+ for k in range(len(ishape)-1):
778
+ ndata = ndata * ishape[k]
800
779
 
801
780
  tim = self.backend.bk_reshape(
802
781
  self.backend.bk_cast(im), [ndata, npix, odata]
@@ -819,54 +798,107 @@ class FoCUS:
819
798
  self.backend.bk_concat([res1, res2], -2),
820
799
  [ndata, tim.shape[1] * 2, odata],
821
800
  )
801
+ return self.backend.bk_reshape(tim, ishape[0:-1] + [nout])
822
802
 
823
- if axis == 0:
824
- if len(ishape) == 1:
825
- return self.backend.bk_reshape(tim, [nout])
826
- else:
827
- return self.backend.bk_reshape(tim, [nout] + ishape[axis + 1 :])
803
+ else:
804
+ if nside is None:
805
+ lout = int(np.sqrt(im.shape[-1] // 12))
828
806
  else:
829
- if len(ishape) == axis + 1:
830
- return self.backend.bk_reshape(tim, ishape[0:axis] + [nout])
807
+ lout = nside
808
+
809
+ if (lout,nout) not in self.pix_interp_val:
810
+ if not self.silent:
811
+ print("compute lout nout", lout, nout)
812
+ if cell_ids is None:
813
+ o_cell_ids=np.arange(12 * nout**2, dtype="int")
814
+ i_npix=12*lout**2
815
+
816
+ #level=int(np.log2(lout)) # nside=128
817
+
818
+ #sp = HS.heal_spline(level,gamma=2.0)
819
+
820
+ th, ph = hp.pix2ang(
821
+ nout, o_cell_ids, nest=True
822
+ )
823
+
824
+ all_idx,www=hp.get_interp_weights(lout,th,ph,nest=True)
825
+
826
+ #www,all_idx,hidx=sp.ang2weigths(th,ph,nest=True)
827
+
828
+ w=www.T
829
+ p=all_idx.T
830
+
831
+ w=w.flatten()
832
+ p=p.flatten()
833
+
834
+ indice = np.zeros([o_cell_ids.shape[0] * 4, 2], dtype="int")
835
+ indice[:, 1] = np.repeat(np.arange(o_cell_ids.shape[0]), 4)
836
+ indice[:, 0] = p
837
+
838
+ self.pix_interp_val[(lout,nout)] = 1
839
+ self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
840
+ self.backend.bk_constant(indice),
841
+ self.backend.bk_constant(self.backend.bk_cast(w)),
842
+ dense_shape=[i_npix,o_cell_ids.shape[0]],
843
+ )
844
+
831
845
  else:
832
- return self.backend.bk_reshape(
833
- tim, ishape[0:axis] + [nout] + ishape[axis + 1 :]
846
+ ratio=(nout//lout)**2
847
+ if o_cell_ids is None:
848
+ o_cell_ids=np.tile(cell_ids,ratio)*ratio+np.repeat(np.arange(ratio),cell_ids.shape[0])
849
+ i_npix=cell_ids.shape[0]
850
+
851
+ #level=int(np.log2(lout)) # nside=128
852
+
853
+ #sp = HS.heal_spline(level,gamma=2.0)
854
+
855
+ th, ph = hp.pix2ang(
856
+ nout, o_cell_ids, nest=True
834
857
  )
835
858
 
836
- return self.backend.bk_reshape(tim, [nout])
859
+ all_idx,www=hp.get_interp_weights(lout,th,ph,nest=True)
860
+ #www,all_idx,hidx=sp.ang2weigths(th,ph,nest=True)
837
861
 
838
- else:
862
+ hidx,inv_idx = np.unique(all_idx,
863
+ return_inverse=True)
864
+ all_idx = inv_idx
865
+
866
+ sorter = np.argsort(hidx)
867
+
868
+ index=sorter[np.searchsorted(hidx,
869
+ cell_ids,
870
+ sorter=sorter)]
871
+
872
+ mask = -np.ones([hidx.shape[0]])
873
+
874
+ mask[index] = np.arange(index.shape[0],dtype='int')
839
875
 
840
- lout = int(np.sqrt(im.shape[-1] // 12))
876
+ all_idx=mask[all_idx]
841
877
 
842
- if (lout,nout) not in self.pix_interp_val:
843
- if not self.silent:
844
- print("compute lout nout", lout, nout)
845
- th, ph = hp.pix2ang(
846
- nout, np.arange(12 * nout**2, dtype="int"), nest=True
847
- )
848
- p, w = hp.get_interp_weights(lout, th, ph, nest=True)
849
- del th
850
- del ph
851
-
852
- indice = np.zeros([12 * nout * nout * 4, 2], dtype="int")
853
- p = p.T
854
- w = w.T
855
- t = np.argsort(
856
- p, 1
857
- ).flatten() # to make oder indices for sparsematrix computation
858
- t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
859
- p = p.flatten()[t]
860
- w = w.flatten()[t]
861
- indice[:, 1] = np.repeat(np.arange(12 * nout**2), 4)
862
- indice[:, 0] = p
863
-
864
- self.pix_interp_val[(lout,nout)] = 1
865
- self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
866
- self.backend.bk_constant(indice),
867
- self.backend.bk_constant(self.backend.bk_cast(w.flatten())),
868
- dense_shape=[12 * lout**2,12 * nout**2],
869
- )
878
+ www[all_idx==-1]=0.0
879
+ www/=np.sum(www,0)[None,:]
880
+
881
+ all_idx[all_idx==-1]=0
882
+
883
+ w=www.T
884
+ p=all_idx.T
885
+
886
+ w=w.flatten()
887
+ p=p.flatten()
888
+
889
+ indice = np.zeros([o_cell_ids.shape[0] * 4, 2], dtype="int")
890
+ indice[:, 1] = np.repeat(np.arange(o_cell_ids.shape[0]), 4)
891
+ indice[:, 0] = p
892
+
893
+ self.pix_interp_val[(lout,nout)] = 1
894
+ self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
895
+ self.backend.bk_constant(indice),
896
+ self.backend.bk_constant(self.backend.bk_cast(w)),
897
+ dense_shape=[i_npix,o_cell_ids.shape[0]],
898
+ )
899
+
900
+ del w
901
+ del p
870
902
 
871
903
  if lout == nout:
872
904
  imout = im
@@ -877,7 +909,7 @@ class FoCUS:
877
909
  for k in range(len(ishape)-1):
878
910
  ndata = ndata * ishape[k]
879
911
  tim = self.backend.bk_reshape(
880
- self.backend.bk_cast(im), [ndata, 12 * lout**2]
912
+ self.backend.bk_cast(im), [ndata, ishape[-1]]
881
913
  )
882
914
  if tim.dtype == self.all_cbk_type:
883
915
  rr = self.backend.bk_sparse_dense_matmul(
@@ -894,12 +926,12 @@ class FoCUS:
894
926
  tim,
895
927
  self.weight_interp_val[(lout,nout)],
896
928
  )
897
-
929
+
898
930
  if len(ishape) == 1:
899
- return self.backend.bk_reshape(imout, [12 * nout**2])
931
+ return self.backend.bk_reshape(imout, [imout.shape[-1]])
900
932
  else:
901
933
  return self.backend.bk_reshape(
902
- imout, ishape[0:axis]+[12 * nout**2]
934
+ imout, ishape[0:-1]+[imout.shape[-1]]
903
935
  )
904
936
  return imout
905
937
 
@@ -1191,7 +1223,7 @@ class FoCUS:
1191
1223
  % (self.TEMPLATE_PATH, l_kernel**2,TMPFILE_VERSION, nside)
1192
1224
  )
1193
1225
  else:
1194
- if cell_ids is not None and nside>512:
1226
+ if cell_ids is not None and spin==0:
1195
1227
  tmp = self.read_index(
1196
1228
  "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work
1197
1229
  % (
@@ -1204,6 +1236,7 @@ class FoCUS:
1204
1236
  )
1205
1237
 
1206
1238
  else:
1239
+ '''
1207
1240
  print('LOAD ',"%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1208
1241
  % (
1209
1242
  self.TEMPLATE_PATH,
@@ -1212,6 +1245,7 @@ class FoCUS:
1212
1245
  self.NORIENT,
1213
1246
  nside,spin # if cell_ids computes the index
1214
1247
  ))
1248
+ '''
1215
1249
  tmp = self.read_index(
1216
1250
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1217
1251
  % (
@@ -1224,20 +1258,21 @@ class FoCUS:
1224
1258
  )
1225
1259
 
1226
1260
  except:
1227
- if cell_ids is not None and nside<=512:
1261
+ if cell_ids is not None and spin!=0:
1228
1262
  self.init_index(nside, kernel=kernel, spin=spin)
1229
1263
 
1230
1264
  if not self.use_2D:
1231
- print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst'
1232
- % (
1233
- self.TEMPLATE_PATH,
1234
- TMPFILE_VERSION,
1235
- l_kernel**2,
1236
- self.NORIENT,
1237
- nside,spin # if cell_ids computes the index
1238
- )
1239
- )
1240
1265
  if spin!=0:
1266
+ # keep the print here as spin!=0 can be long
1267
+ print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst'
1268
+ % (
1269
+ self.TEMPLATE_PATH,
1270
+ TMPFILE_VERSION,
1271
+ l_kernel**2,
1272
+ self.NORIENT,
1273
+ nside,spin # if cell_ids computes the index
1274
+ )
1275
+ )
1241
1276
  try:
1242
1277
  tmp = self.read_index(
1243
1278
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst"
@@ -1250,6 +1285,7 @@ class FoCUS:
1250
1285
  )
1251
1286
  )
1252
1287
  except:
1288
+ '''
1253
1289
  print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst'
1254
1290
  % (
1255
1291
  self.TEMPLATE_PATH,
@@ -1259,7 +1295,7 @@ class FoCUS:
1259
1295
  nside
1260
1296
  )
1261
1297
  )
1262
-
1298
+ '''
1263
1299
  self.init_index(nside, kernel=kernel, spin=0)
1264
1300
 
1265
1301
  tmp = self.read_index(
@@ -1282,6 +1318,7 @@ class FoCUS:
1282
1318
  )
1283
1319
  )
1284
1320
 
1321
+ '''
1285
1322
  nn=self.NORIENT*12*nside**2
1286
1323
  idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1287
1324
  idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
@@ -1289,7 +1326,156 @@ class FoCUS:
1289
1326
  idxEB[2*tmp.shape[0]:,1]+=nn
1290
1327
 
1291
1328
  tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1329
+ '''
1330
+ import foscat.HOrientedConvol as hs
1292
1331
 
1332
+ hconvol=hs.HOrientedConvol(nside,3*self.KERNELSZ,cell_ids=cell_ids)
1333
+
1334
+ if cell_ids is None:
1335
+ l_cell_ids=np.arange(12*nside**2)
1336
+ else:
1337
+ l_cell_ids=cell_ids
1338
+
1339
+ nvalid=self.KERNELSZ**2
1340
+ idxEB=hconvol.idx_nn[:,0:nvalid]
1341
+ tmpEB=np.zeros([self.NORIENT,4,l_cell_ids.shape[0],nvalid],dtype='complex')
1342
+ tmpS=np.zeros([4,l_cell_ids.shape[0],nvalid],dtype='float')
1343
+
1344
+ idx={}
1345
+ nn=0
1346
+ nn2=1
1347
+ if nside<64:
1348
+ pp=10
1349
+ else:
1350
+ pp=1
1351
+ while nn2>0:
1352
+ idx2={}
1353
+ nn2=0
1354
+ im=np.zeros([12*nside**2])
1355
+ for n in range(l_cell_ids.shape[0]):
1356
+ if im[hconvol.idx_nn[n,0]]==0 and n not in idx:
1357
+ im[hconvol.idx_nn[n,:]]=1.0
1358
+ idx[hconvol.idx_nn[n,0]]=1.0
1359
+ idx2[hconvol.idx_nn[n,0]]=1.0
1360
+ nn+=1
1361
+ nn2+=1
1362
+ im=np.zeros([12*nside**2])
1363
+ for k in idx2:
1364
+ im[k]=1.0
1365
+ r=self.convol(im)
1366
+ for k in range(self.NORIENT):
1367
+ ralm=hp.map2alm(hp.reorder(r[k].cpu().numpy().real,n2r=True))[None,:]
1368
+ ialm=hp.map2alm(hp.reorder(r[k].cpu().numpy().imag,n2r=True))[None,:]
1369
+
1370
+ alm=np.concatenate([ralm,0*ralm,0*ralm],0)
1371
+ rqe,rue,rie=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1372
+ alm=np.concatenate([ialm,0*ialm,0*ialm],0)
1373
+ iqe,iue,iie=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1374
+
1375
+ alm=np.concatenate([0*ralm,ralm,0*ralm],0)
1376
+ rqb,rub,rib=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1377
+ alm=np.concatenate([0*ialm,ialm,0*ialm],0)
1378
+ iqb,iub,iib=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1379
+
1380
+ rqe=hp.reorder(rqe,r2n=True)
1381
+ rue=hp.reorder(rue,r2n=True)
1382
+ rqb=hp.reorder(rqb,r2n=True)
1383
+ rub=hp.reorder(rub,r2n=True)
1384
+
1385
+ iqe=hp.reorder(iqe,r2n=True)
1386
+ iue=hp.reorder(iue,r2n=True)
1387
+ iqb=hp.reorder(iqb,r2n=True)
1388
+ iub=hp.reorder(iub,r2n=True)
1389
+
1390
+ for l in idx2:
1391
+ tmpEB[k,0,l]=rqe[idxEB[l,:]]+1J*iqe[idxEB[l,:]]
1392
+ tmpEB[k,1,l]=rue[idxEB[l,:]]+1J*iue[idxEB[l,:]]
1393
+ tmpEB[k,2,l]=rqb[idxEB[l,:]]+1J*iqb[idxEB[l,:]]
1394
+ tmpEB[k,3,l]=rub[idxEB[l,:]]+1J*iub[idxEB[l,:]]
1395
+
1396
+ r=self.smooth(im)
1397
+
1398
+ ralm=hp.map2alm(hp.reorder(r.cpu().numpy(),n2r=True))[None,:]
1399
+
1400
+ alm=np.concatenate([ralm,0*ralm,0*ralm],0)
1401
+ rqe,rue,rie=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1402
+
1403
+ alm=np.concatenate([0*ralm,ralm,0*ralm],0)
1404
+ rqb,rub,rib=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1405
+
1406
+ rqe=hp.reorder(rqe,r2n=True)
1407
+ rue=hp.reorder(rue,r2n=True)
1408
+ rqb=hp.reorder(rqb,r2n=True)
1409
+ rub=hp.reorder(rub,r2n=True)
1410
+
1411
+ for l in idx2:
1412
+ tmpS[0,l,:]=rqe[idxEB[l,:]]
1413
+ tmpS[1,l,:]=rue[idxEB[l,:]]
1414
+ tmpS[2,l,:]=rqb[idxEB[l,:]]
1415
+ tmpS[3,l,:]=rub[idxEB[l,:]]
1416
+ if 100*nn/(l_cell_ids.shape[0])>pp:
1417
+ if nside<64:
1418
+ pp+=10
1419
+ else:
1420
+ pp+=1
1421
+ print('%.2f%% Done'%(100*nn/(l_cell_ids.shape[0])))
1422
+
1423
+ wav=tmpEB.flatten()
1424
+ wwav=tmpS.flatten()
1425
+ ndata=l_cell_ids.shape[0]*nvalid
1426
+ indice_1_1=np.tile(idxEB.flatten(),4*self.NORIENT)
1427
+ for k in range(self.NORIENT):
1428
+ indice_1_1[(4*k+1)*ndata:(4*k+2)*ndata]+=l_cell_ids.shape[0]
1429
+ indice_1_1[(4*k+3)*ndata:(4*k+4)*ndata]+=l_cell_ids.shape[0]
1430
+
1431
+ indice_1_0=np.tile(np.tile(np.repeat(np.arange(l_cell_ids.shape[0]),nvalid),4),self.NORIENT)
1432
+ for k in range(self.NORIENT):
1433
+ indice_1_0[(4*k+2)*ndata:(4*k+4)*ndata]+=self.NORIENT*l_cell_ids.shape[0]
1434
+ indice_1_0[(4*k)*ndata:(4*k+4)*ndata]+=k*l_cell_ids.shape[0]
1435
+ '''
1436
+ import matplotlib.pyplot as plt
1437
+ plt.figure()
1438
+ plt.subplot(2,2,1)
1439
+ plt.plot(indice_1_0)
1440
+ plt.subplot(2,2,2)
1441
+ plt.plot(indice_1_1)
1442
+ plt.subplot(2,2,3)
1443
+ plt.plot(wav.real)
1444
+ plt.subplot(2,2,4)
1445
+ plt.plot(abs(wav))
1446
+
1447
+ iarg=np.argsort(indice_1_0)
1448
+ indice_1_1=indice_1_1[iarg]
1449
+ indice_1_0=indice_1_0[iarg]
1450
+ wav=wav[iarg]
1451
+ '''
1452
+
1453
+ indice=np.concatenate([indice_1_1[:,None],indice_1_0[:,None]],1)
1454
+
1455
+ indice_2_1=np.tile(idxEB.flatten(),4)
1456
+ indice_2_1[ndata:2*ndata]+=l_cell_ids.shape[0]
1457
+ indice_2_1[3*ndata:4*ndata]+=l_cell_ids.shape[0]
1458
+ indice_2_0=np.tile(np.repeat(np.arange(l_cell_ids.shape[0]),nvalid),4)
1459
+ indice_2_0[2*ndata:]+=l_cell_ids.shape[0]
1460
+ '''
1461
+ plt.figure()
1462
+ plt.subplot(2,2,1)
1463
+ plt.plot(indice_2_0)
1464
+ plt.subplot(2,2,2)
1465
+ plt.plot(indice_2_1)
1466
+ plt.subplot(2,2,3)
1467
+ plt.plot(wav.real)
1468
+ plt.subplot(2,2,4)
1469
+ plt.plot(wwav)
1470
+
1471
+ iarg=np.argsort(indice_2_0)
1472
+ indice_2_1=indice_2_1[iarg]
1473
+ indice_2_0=indice_2_0[iarg]
1474
+ wwav=wwav[iarg]
1475
+ '''
1476
+ indice2=np.concatenate([indice_2_1[:,None],indice_2_0[:,None]],1)
1477
+
1478
+ '''
1293
1479
  for k in range(self.NORIENT*12*nside**2):
1294
1480
  if k%(nside**2)==0:
1295
1481
  print('Init index 1/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
@@ -1316,7 +1502,7 @@ class FoCUS:
1316
1502
  tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1317
1503
  tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1318
1504
 
1319
-
1505
+ '''
1320
1506
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"% (self.TEMPLATE_PATH,
1321
1507
  self.TMPFILE_VERSION,
1322
1508
  self.KERNELSZ**2,
@@ -1324,7 +1510,7 @@ class FoCUS:
1324
1510
  nside,
1325
1511
  spin
1326
1512
  ),
1327
- idxEB
1513
+ indice
1328
1514
  )
1329
1515
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"% (self.TEMPLATE_PATH,
1330
1516
  self.TMPFILE_VERSION,
@@ -1333,9 +1519,9 @@ class FoCUS:
1333
1519
  nside,
1334
1520
  spin,
1335
1521
  ),
1336
- tmpEB
1522
+ wav
1337
1523
  )
1338
-
1524
+ '''
1339
1525
  tmp = self.read_index(
1340
1526
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN0.fst"
1341
1527
  % (
@@ -1377,7 +1563,7 @@ class FoCUS:
1377
1563
  tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]
1378
1564
  tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]
1379
1565
 
1380
-
1566
+ '''
1381
1567
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"% (self.TEMPLATE_PATH,
1382
1568
  self.TMPFILE_VERSION,
1383
1569
  self.KERNELSZ**2,
@@ -1385,7 +1571,7 @@ class FoCUS:
1385
1571
  nside,
1386
1572
  spin
1387
1573
  ),
1388
- idxEB
1574
+ indice2
1389
1575
  )
1390
1576
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"% (self.TEMPLATE_PATH,
1391
1577
  self.TMPFILE_VERSION,
@@ -1394,10 +1580,11 @@ class FoCUS:
1394
1580
  nside,
1395
1581
  spin,
1396
1582
  ),
1397
- tmpEB
1583
+ wwav
1398
1584
  )
1585
+
1399
1586
  else:
1400
-
1587
+ '''
1401
1588
  if l_kernel == 5:
1402
1589
  pw = 0.5
1403
1590
  pw2 = 0.5
@@ -1412,8 +1599,20 @@ class FoCUS:
1412
1599
  pw = 0.5
1413
1600
  pw2 = 0.25
1414
1601
  threshold = 4e-5
1602
+ '''
1603
+ import foscat.HOrientedConvol as hs
1415
1604
 
1416
- if cell_ids is not None and nside>512:
1605
+ hconvol=hs.HOrientedConvol(nside,l_kernel,cell_ids=cell_ids)
1606
+
1607
+ orientations=np.pi*np.arange(self.NORIENT)/self.NORIENT
1608
+
1609
+ wav,indice,wwav,indice2=hconvol.make_wavelet_matrix(orientations,
1610
+ polar=True,
1611
+ return_index=True,
1612
+ return_smooth=True)
1613
+
1614
+ '''
1615
+ if cell_ids is not None and nside>256:
1417
1616
  if not isinstance(cell_ids, np.ndarray):
1418
1617
  cell_ids = self.backend.to_numpy(cell_ids)
1419
1618
  th, ph = hp.pix2ang(nside, cell_ids, nest=True)
@@ -1537,8 +1736,7 @@ class FoCUS:
1537
1736
  wav = wav[:iv]
1538
1737
  indice2 = indice2[:iv2, :]
1539
1738
  wwav = wwav[:iv2]
1540
- if not self.silent:
1541
- print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1739
+ '''
1542
1740
 
1543
1741
  if cell_ids is None:
1544
1742
  if not self.silent:
@@ -1613,7 +1811,7 @@ class FoCUS:
1613
1811
  )
1614
1812
  return None
1615
1813
 
1616
- if cell_ids is None or nside<=512:
1814
+ if cell_ids is None or spin!=0:
1617
1815
  self.barrier()
1618
1816
  if self.use_2D:
1619
1817
  tmp = self.read_index(
@@ -1719,6 +1917,20 @@ class FoCUS:
1719
1917
  tmp2[lidx,0]=0
1720
1918
  tmp2[:,1]+=i_id*lcell_ids.shape[0]
1721
1919
  tmp2[:,0]+=i_id2*lcell_ids.shape[0]
1920
+
1921
+ #add normalisation
1922
+ ww=np.bincount(tmp2[:,1],weights=ws)
1923
+ ws/=ww[tmp2[:,1]]
1924
+
1925
+ wh=np.bincount(tmp[:,1])
1926
+ ww=np.bincount(tmp[:,1],weights=wr)
1927
+ wr-=(ww/wh)[tmp[:,1]]
1928
+ ww=np.bincount(tmp[:,1],weights=wi)
1929
+ wi-=(ww/wh)[tmp[:,1]]
1930
+
1931
+ ww=np.bincount(tmp[:,1],weights=np.sqrt(wr*wr+wi*wi))
1932
+ wr/=ww[tmp[:,1]]
1933
+ wi/=ww[tmp[:,1]]
1722
1934
 
1723
1935
  else:
1724
1936
  tmp = indice
@@ -2116,8 +2328,8 @@ class FoCUS:
2116
2328
  ichannel = 1
2117
2329
  for i in range(1, len(shape) - 1):
2118
2330
  ichannel *= shape[i]
2119
-
2120
- l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel, shape[-1]])
2331
+
2332
+ l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel,shape[-1]])
2121
2333
 
2122
2334
  if self.padding == "VALID":
2123
2335
  oshape = [k for k in shape]
@@ -2168,13 +2380,8 @@ class FoCUS:
2168
2380
  res = v1 / vh
2169
2381
 
2170
2382
  oshape = [x.shape[0]] + [mask.shape[0]]
2171
- if axis > 0:
2172
- oshape = oshape + list(x.shape[1:axis])
2173
-
2174
- if len(x.shape[axis:-2]) > 0:
2175
- oshape = oshape + list(x.shape[axis:-2])
2176
- else:
2177
- oshape = oshape + [1]
2383
+ if len(x.shape)>1:
2384
+ oshape = oshape + list(x.shape[1:-2])
2178
2385
 
2179
2386
  if calc_var:
2180
2387
  if self.backend.bk_is_complex(vtmp):
@@ -2204,9 +2411,9 @@ class FoCUS:
2204
2411
  elif self.use_1D:
2205
2412
  mtmp = l_mask
2206
2413
  vtmp = l_x
2207
- v1 = self.backend.bk_reduce_sum(l_mask[1,:,...,:] * vtmp, axis=-1)
2208
- v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
2209
- vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
2414
+ v1 = self.backend.bk_reduce_sum(l_mask * vtmp, axis=-1)
2415
+ v2 = self.backend.bk_reduce_sum(l_mask * vtmp * vtmp, axis=-1)
2416
+ vh = self.backend.bk_reduce_sum(l_mask , axis=-1)
2210
2417
 
2211
2418
  res = v1 / vh
2212
2419
 
@@ -2576,15 +2783,16 @@ class FoCUS:
2576
2783
  print("Use of 2D scat with data that has less than 2D")
2577
2784
  return None
2578
2785
 
2579
- npix = ishape[axis]
2580
- npiy = ishape[axis + 1]
2786
+ npix = ishape[-2]
2787
+ npiy = ishape[-1]
2788
+
2581
2789
  odata = 1
2582
- if len(ishape) > axis + 2:
2583
- for k in range(axis + 2, len(ishape)):
2790
+ if len(ishape) > 1:
2791
+ for k in range(len(ishape)-2):
2584
2792
  odata = odata * ishape[k]
2585
2793
 
2586
2794
  ndata = 1
2587
- for k in range(axis):
2795
+ for k in range(len(ishape)-2):
2588
2796
  ndata = ndata * ishape[k]
2589
2797
 
2590
2798
  tim = self.backend.bk_reshape(
@@ -2605,7 +2813,7 @@ class FoCUS:
2605
2813
  ishape = list(in_image.shape)
2606
2814
 
2607
2815
  npix = ishape[-1]
2608
-
2816
+
2609
2817
  ndata = 1
2610
2818
  for k in range(len(ishape) - 1):
2611
2819
  ndata = ndata * ishape[k]
@@ -2618,7 +2826,7 @@ class FoCUS:
2618
2826
  res = self.backend.bk_complex(rr, ii)
2619
2827
  else:
2620
2828
  res = self.backend.conv1d(tim, self.ww_SmoothT[1])
2621
-
2829
+
2622
2830
  return self.backend.bk_reshape(res, ishape)
2623
2831
 
2624
2832
  else:
@@ -2646,16 +2854,30 @@ class FoCUS:
2646
2854
  odata = odata * ishape[k]
2647
2855
 
2648
2856
  tim = self.backend.bk_reshape(image, [odata, ishape[-1]])
2649
- if tim.dtype == self.all_cbk_type:
2650
- rr = self.backend.bk_sparse_dense_matmul(
2651
- self.backend.bk_real(tim), l_w_smooth
2652
- )
2653
- ri = self.backend.bk_sparse_dense_matmul(
2654
- self.backend.bk_imag(tim), l_w_smooth
2655
- )
2656
- res = self.backend.bk_complex(rr, ri)
2857
+ if spin==0:
2858
+ if tim.dtype == self.all_cbk_type:
2859
+ rr = self.backend.bk_sparse_dense_matmul(
2860
+ self.backend.bk_real(tim), l_w_smooth
2861
+ )
2862
+ ri = self.backend.bk_sparse_dense_matmul(
2863
+ self.backend.bk_imag(tim), l_w_smooth
2864
+ )
2865
+ res = self.backend.bk_complex(rr, ri)
2866
+ else:
2867
+ res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2657
2868
  else:
2658
- res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2869
+ tim=self.backend.bk_reshape(tim,[odata//2,2*tim.shape[-1]])
2870
+ if tim.dtype == self.all_cbk_type:
2871
+ rr = self.backend.bk_sparse_dense_matmul(
2872
+ self.backend.bk_real(tim), l_w_smooth
2873
+ )
2874
+ ri = self.backend.bk_sparse_dense_matmul(
2875
+ self.backend.bk_imag(tim), l_w_smooth
2876
+ )
2877
+ res = self.backend.bk_complex(rr, ri)
2878
+ else:
2879
+ res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2880
+
2659
2881
  if len(ishape) == 1:
2660
2882
  return self.backend.bk_reshape(res, [ishape[-1]])
2661
2883
  else: