foscat 2025.7.3__py3-none-any.whl → 2025.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -3,9 +3,10 @@ import sys
3
3
 
4
4
  import healpy as hp
5
5
  import numpy as np
6
+ import foscat.HealSpline as HS
6
7
  from scipy.interpolate import griddata
7
8
 
8
- TMPFILE_VERSION = "V6_0"
9
+ TMPFILE_VERSION = "V7_0"
9
10
 
10
11
 
11
12
  class FoCUS:
@@ -35,7 +36,7 @@ class FoCUS:
35
36
  mpi_rank=0
36
37
  ):
37
38
 
38
- self.__version__ = "2025.07.3"
39
+ self.__version__ = "2025.08.3"
39
40
  # P00 coeff for normalization for scat_cov
40
41
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
42
  self.P1_dic = None
@@ -637,11 +638,13 @@ class FoCUS:
637
638
  return image.numpy()[self.ring2nest[lout]]
638
639
 
639
640
  # --------------------------------------------------------
640
- def ud_grade(self, im, j, axis=0):
641
+ def ud_grade(self, im, j, axis=0, cell_ids=None, nside=None):
641
642
  rim = im
642
643
  for k in range(j):
643
644
  # rim = self.smooth(rim, axis=axis)
644
- rim = self.ud_grade_2(rim, axis=axis)
645
+ rim = self.ud_grade_2(rim, axis=axis,
646
+ cell_ids=cell_ids,
647
+ nside=nside)
645
648
  return rim
646
649
 
647
650
  # --------------------------------------------------------
@@ -654,55 +657,39 @@ class FoCUS:
654
657
  print("Use of 2D scat with data that has less than 2D")
655
658
  return None, None
656
659
 
657
- npix = im.shape[axis]
658
- npiy = im.shape[axis + 1]
659
- odata = 1
660
- if len(ishape) > axis + 2:
661
- for k in range(axis + 2, len(ishape)):
662
- odata = odata * ishape[k]
660
+ npix = im.shape[-2]
661
+ npiy = im.shape[-1]
663
662
 
664
663
  ndata = 1
665
- for k in range(axis):
664
+ for k in range(len(im.shape)-2):
666
665
  ndata = ndata * ishape[k]
667
666
 
668
667
  tim = self.backend.bk_reshape(
669
- self.backend.bk_cast(im), [ndata, npix, npiy, odata]
668
+ self.backend.bk_cast(im), [ndata, npix, npiy, 1]
670
669
  )
671
670
  tim = self.backend.bk_reshape(
672
671
  tim[:, 0 : 2 * (npix // 2), 0 : 2 * (npiy // 2), :],
673
- [ndata, npix // 2, 2, npiy // 2, 2, odata],
672
+ [ndata, npix // 2, 2, npiy // 2, 2, 1],
674
673
  )
675
674
 
676
675
  res = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(tim, 4), 2) / 4
677
676
 
678
- if axis == 0:
679
- if len(ishape) == 2:
680
- return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
681
- else:
682
- return (
683
- self.backend.bk_reshape(
684
- res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
685
- ),
686
- None,
687
- )
677
+ if len(ishape) == 2:
678
+ return (
679
+ self.backend.bk_reshape(
680
+ res, [npix // 2, npiy // 2]
681
+ ),
682
+ None,
683
+ )
688
684
  else:
689
- if len(ishape) == axis + 2:
690
- return (
691
- self.backend.bk_reshape(
692
- res, ishape[0:axis] + [npix // 2, npiy // 2]
693
- ),
694
- None,
695
- )
696
- else:
697
- return (
698
- self.backend.bk_reshape(
699
- res,
700
- ishape[0:axis]
701
- + [npix // 2, npiy // 2]
702
- + ishape[axis + 2 :],
703
- ),
704
- None,
705
- )
685
+ return (
686
+ self.backend.bk_reshape(
687
+ res,
688
+ ishape[0:-2]
689
+ + [npix // 2, npiy // 2],
690
+ ),
691
+ None,
692
+ )
706
693
 
707
694
  return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
708
695
  elif self.use_1D:
@@ -733,11 +720,16 @@ class FoCUS:
733
720
  ),None
734
721
 
735
722
  # --------------------------------------------------------
736
- def up_grade(self, im, nout, axis=-1, nouty=None):
723
+ def up_grade(self, im, nout,
724
+ axis=-1,
725
+ nouty=None,
726
+ cell_ids=None,
727
+ o_cell_ids=None,
728
+ nside=None):
737
729
 
738
730
  ishape = list(im.shape)
739
731
  if self.use_2D:
740
- if len(ishape) < axis + 2:
732
+ if len(ishape) < 2:
741
733
  if not self.silent:
742
734
  print("Use of 2D scat with data that has less than 2D")
743
735
  return None
@@ -745,39 +737,28 @@ class FoCUS:
745
737
  if nouty is None:
746
738
  nouty = nout
747
739
 
748
- if ishape[axis] == nout and ishape[axis + 1] == nouty:
740
+ if ishape[-2] == nout and ishape[-1] == nouty:
749
741
  return im
750
742
 
751
- npix = im.shape[axis]
752
- npiy = im.shape[axis + 1]
753
- odata = 1
743
+ npix = im.shape[-2]
744
+ npiy = im.shape[-1]
754
745
 
755
746
  ndata = 1
756
- for k in range(axis):
747
+ for k in range(len(im.shape)-2):
757
748
  ndata = ndata * ishape[k]
758
749
 
759
750
  tim = self.backend.bk_reshape(
760
- self.backend.bk_cast(im), [ndata, npix, npiy, odata]
751
+ self.backend.bk_cast(im), [ndata, npix, npiy,1]
761
752
  )
762
753
 
763
754
  res = self.backend.bk_resize_image(tim, [nout, nouty])
764
755
 
765
- if axis == 0:
766
- if len(ishape) == 2:
767
- return self.backend.bk_reshape(res, [nout, nouty])
768
- else:
769
- return self.backend.bk_reshape(
770
- res, [nout, nouty] + ishape[axis + 2 :]
771
- )
756
+ if len(ishape) == 2:
757
+ return self.backend.bk_reshape(res, [nout, nouty])
772
758
  else:
773
- if len(ishape) == axis + 2:
774
- return self.backend.bk_reshape(res, ishape[0:axis] + [nout, nouty])
775
- else:
776
- return self.backend.bk_reshape(
777
- res, ishape[0:axis] + [nout, nouty]
778
- )
779
-
780
- return self.backend.bk_reshape(res, [nout, nouty])
759
+ return self.backend.bk_reshape(
760
+ res, ishape[0:-2] + [nout, nouty]
761
+ )
781
762
 
782
763
  elif self.use_1D:
783
764
  if len(ishape) < axis + 1:
@@ -820,37 +801,104 @@ class FoCUS:
820
801
  return self.backend.bk_reshape(tim, ishape[0:-1] + [nout])
821
802
 
822
803
  else:
823
-
824
- lout = int(np.sqrt(im.shape[-1] // 12))
825
-
804
+ if nside is None:
805
+ lout = int(np.sqrt(im.shape[-1] // 12))
806
+ else:
807
+ lout = nside
808
+
826
809
  if (lout,nout) not in self.pix_interp_val:
827
810
  if not self.silent:
828
811
  print("compute lout nout", lout, nout)
829
- th, ph = hp.pix2ang(
830
- nout, np.arange(12 * nout**2, dtype="int"), nest=True
831
- )
832
- p, w = hp.get_interp_weights(lout, th, ph, nest=True)
833
- del th
834
- del ph
835
-
836
- indice = np.zeros([12 * nout * nout * 4, 2], dtype="int")
837
- p = p.T
838
- w = w.T
839
- t = np.argsort(
840
- p, 1
841
- ).flatten() # to make oder indices for sparsematrix computation
842
- t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
843
- p = p.flatten()[t]
844
- w = w.flatten()[t]
845
- indice[:, 1] = np.repeat(np.arange(12 * nout**2), 4)
846
- indice[:, 0] = p
847
-
848
- self.pix_interp_val[(lout,nout)] = 1
849
- self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
850
- self.backend.bk_constant(indice),
851
- self.backend.bk_constant(self.backend.bk_cast(w.flatten())),
852
- dense_shape=[12 * lout**2,12 * nout**2],
853
- )
812
+ if cell_ids is None:
813
+ o_cell_ids=np.arange(12 * nout**2, dtype="int")
814
+ i_npix=12*lout**2
815
+
816
+ #level=int(np.log2(lout)) # nside=128
817
+
818
+ #sp = HS.heal_spline(level,gamma=2.0)
819
+
820
+ th, ph = hp.pix2ang(
821
+ nout, o_cell_ids, nest=True
822
+ )
823
+
824
+ all_idx,www=hp.get_interp_weights(lout,th,ph,nest=True)
825
+
826
+ #www,all_idx,hidx=sp.ang2weigths(th,ph,nest=True)
827
+
828
+ w=www.T
829
+ p=all_idx.T
830
+
831
+ w=w.flatten()
832
+ p=p.flatten()
833
+
834
+ indice = np.zeros([o_cell_ids.shape[0] * 4, 2], dtype="int")
835
+ indice[:, 1] = np.repeat(np.arange(o_cell_ids.shape[0]), 4)
836
+ indice[:, 0] = p
837
+
838
+ self.pix_interp_val[(lout,nout)] = 1
839
+ self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
840
+ self.backend.bk_constant(indice),
841
+ self.backend.bk_constant(self.backend.bk_cast(w)),
842
+ dense_shape=[i_npix,o_cell_ids.shape[0]],
843
+ )
844
+
845
+ else:
846
+ ratio=(nout//lout)**2
847
+ if o_cell_ids is None:
848
+ o_cell_ids=np.tile(cell_ids,ratio)*ratio+np.repeat(np.arange(ratio),cell_ids.shape[0])
849
+ i_npix=cell_ids.shape[0]
850
+
851
+ #level=int(np.log2(lout)) # nside=128
852
+
853
+ #sp = HS.heal_spline(level,gamma=2.0)
854
+
855
+ th, ph = hp.pix2ang(
856
+ nout, o_cell_ids, nest=True
857
+ )
858
+
859
+ all_idx,www=hp.get_interp_weights(lout,th,ph,nest=True)
860
+ #www,all_idx,hidx=sp.ang2weigths(th,ph,nest=True)
861
+
862
+ hidx,inv_idx = np.unique(all_idx,
863
+ return_inverse=True)
864
+ all_idx = inv_idx
865
+
866
+ sorter = np.argsort(hidx)
867
+
868
+ index=sorter[np.searchsorted(hidx,
869
+ cell_ids,
870
+ sorter=sorter)]
871
+
872
+ mask = -np.ones([hidx.shape[0]])
873
+
874
+ mask[index] = np.arange(index.shape[0],dtype='int')
875
+
876
+ all_idx=mask[all_idx]
877
+
878
+ www[all_idx==-1]=0.0
879
+ www/=np.sum(www,0)[None,:]
880
+
881
+ all_idx[all_idx==-1]=0
882
+
883
+ w=www.T
884
+ p=all_idx.T
885
+
886
+ w=w.flatten()
887
+ p=p.flatten()
888
+
889
+ indice = np.zeros([o_cell_ids.shape[0] * 4, 2], dtype="int")
890
+ indice[:, 1] = np.repeat(np.arange(o_cell_ids.shape[0]), 4)
891
+ indice[:, 0] = p
892
+
893
+ self.pix_interp_val[(lout,nout)] = 1
894
+ self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
895
+ self.backend.bk_constant(indice),
896
+ self.backend.bk_constant(self.backend.bk_cast(w)),
897
+ dense_shape=[i_npix,o_cell_ids.shape[0]],
898
+ )
899
+
900
+ del w
901
+ del p
854
902
 
855
903
  if lout == nout:
856
904
  imout = im
@@ -861,7 +909,7 @@ class FoCUS:
861
909
  for k in range(len(ishape)-1):
862
910
  ndata = ndata * ishape[k]
863
911
  tim = self.backend.bk_reshape(
864
- self.backend.bk_cast(im), [ndata, 12 * lout**2]
912
+ self.backend.bk_cast(im), [ndata, ishape[-1]]
865
913
  )
866
914
  if tim.dtype == self.all_cbk_type:
867
915
  rr = self.backend.bk_sparse_dense_matmul(
@@ -878,12 +926,12 @@ class FoCUS:
878
926
  tim,
879
927
  self.weight_interp_val[(lout,nout)],
880
928
  )
881
-
929
+
882
930
  if len(ishape) == 1:
883
- return self.backend.bk_reshape(imout, [12 * nout**2])
931
+ return self.backend.bk_reshape(imout, [imout.shape[-1]])
884
932
  else:
885
933
  return self.backend.bk_reshape(
886
- imout, ishape[0:axis]+[12 * nout**2]
934
+ imout, ishape[0:-1]+[imout.shape[-1]]
887
935
  )
888
936
  return imout
889
937
 
@@ -1175,7 +1223,7 @@ class FoCUS:
1175
1223
  % (self.TEMPLATE_PATH, l_kernel**2,TMPFILE_VERSION, nside)
1176
1224
  )
1177
1225
  else:
1178
- if cell_ids is not None and nside>512:
1226
+ if cell_ids is not None and spin==0:
1179
1227
  tmp = self.read_index(
1180
1228
  "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work
1181
1229
  % (
@@ -1188,6 +1236,7 @@ class FoCUS:
1188
1236
  )
1189
1237
 
1190
1238
  else:
1239
+ '''
1191
1240
  print('LOAD ',"%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1192
1241
  % (
1193
1242
  self.TEMPLATE_PATH,
@@ -1196,6 +1245,7 @@ class FoCUS:
1196
1245
  self.NORIENT,
1197
1246
  nside,spin # if cell_ids computes the index
1198
1247
  ))
1248
+ '''
1199
1249
  tmp = self.read_index(
1200
1250
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1201
1251
  % (
@@ -1208,20 +1258,21 @@ class FoCUS:
1208
1258
  )
1209
1259
 
1210
1260
  except:
1211
- if cell_ids is not None and nside<=512:
1261
+ if cell_ids is not None and spin!=0:
1212
1262
  self.init_index(nside, kernel=kernel, spin=spin)
1213
1263
 
1214
1264
  if not self.use_2D:
1215
- print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst'
1216
- % (
1217
- self.TEMPLATE_PATH,
1218
- TMPFILE_VERSION,
1219
- l_kernel**2,
1220
- self.NORIENT,
1221
- nside,spin # if cell_ids computes the index
1222
- )
1223
- )
1224
1265
  if spin!=0:
1266
+ # keep the print here as spin!=0 can be long
1267
+ print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst'
1268
+ % (
1269
+ self.TEMPLATE_PATH,
1270
+ TMPFILE_VERSION,
1271
+ l_kernel**2,
1272
+ self.NORIENT,
1273
+ nside,spin # if cell_ids computes the index
1274
+ )
1275
+ )
1225
1276
  try:
1226
1277
  tmp = self.read_index(
1227
1278
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst"
@@ -1234,6 +1285,7 @@ class FoCUS:
1234
1285
  )
1235
1286
  )
1236
1287
  except:
1288
+ '''
1237
1289
  print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst'
1238
1290
  % (
1239
1291
  self.TEMPLATE_PATH,
@@ -1243,7 +1295,7 @@ class FoCUS:
1243
1295
  nside
1244
1296
  )
1245
1297
  )
1246
-
1298
+ '''
1247
1299
  self.init_index(nside, kernel=kernel, spin=0)
1248
1300
 
1249
1301
  tmp = self.read_index(
@@ -1266,6 +1318,7 @@ class FoCUS:
1266
1318
  )
1267
1319
  )
1268
1320
 
1321
+ '''
1269
1322
  nn=self.NORIENT*12*nside**2
1270
1323
  idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1271
1324
  idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
@@ -1273,7 +1326,156 @@ class FoCUS:
1273
1326
  idxEB[2*tmp.shape[0]:,1]+=nn
1274
1327
 
1275
1328
  tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1329
+ '''
1330
+ import foscat.HOrientedConvol as hs
1276
1331
 
1332
+ hconvol=hs.HOrientedConvol(nside,3*self.KERNELSZ,cell_ids=cell_ids)
1333
+
1334
+ if cell_ids is None:
1335
+ l_cell_ids=np.arange(12*nside**2)
1336
+ else:
1337
+ l_cell_ids=cell_ids
1338
+
1339
+ nvalid=self.KERNELSZ**2
1340
+ idxEB=hconvol.idx_nn[:,0:nvalid]
1341
+ tmpEB=np.zeros([self.NORIENT,4,l_cell_ids.shape[0],nvalid],dtype='complex')
1342
+ tmpS=np.zeros([4,l_cell_ids.shape[0],nvalid],dtype='float')
1343
+
1344
+ idx={}
1345
+ nn=0
1346
+ nn2=1
1347
+ if nside<64:
1348
+ pp=10
1349
+ else:
1350
+ pp=1
1351
+ while nn2>0:
1352
+ idx2={}
1353
+ nn2=0
1354
+ im=np.zeros([12*nside**2])
1355
+ for n in range(l_cell_ids.shape[0]):
1356
+ if im[hconvol.idx_nn[n,0]]==0 and n not in idx:
1357
+ im[hconvol.idx_nn[n,:]]=1.0
1358
+ idx[hconvol.idx_nn[n,0]]=1.0
1359
+ idx2[hconvol.idx_nn[n,0]]=1.0
1360
+ nn+=1
1361
+ nn2+=1
1362
+ im=np.zeros([12*nside**2])
1363
+ for k in idx2:
1364
+ im[k]=1.0
1365
+ r=self.convol(im)
1366
+ for k in range(self.NORIENT):
1367
+ ralm=hp.map2alm(hp.reorder(r[k].cpu().numpy().real,n2r=True))[None,:]
1368
+ ialm=hp.map2alm(hp.reorder(r[k].cpu().numpy().imag,n2r=True))[None,:]
1369
+
1370
+ alm=np.concatenate([ralm,0*ralm,0*ralm],0)
1371
+ rqe,rue,rie=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1372
+ alm=np.concatenate([ialm,0*ialm,0*ialm],0)
1373
+ iqe,iue,iie=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1374
+
1375
+ alm=np.concatenate([0*ralm,ralm,0*ralm],0)
1376
+ rqb,rub,rib=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1377
+ alm=np.concatenate([0*ialm,ialm,0*ialm],0)
1378
+ iqb,iub,iib=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1379
+
1380
+ rqe=hp.reorder(rqe,r2n=True)
1381
+ rue=hp.reorder(rue,r2n=True)
1382
+ rqb=hp.reorder(rqb,r2n=True)
1383
+ rub=hp.reorder(rub,r2n=True)
1384
+
1385
+ iqe=hp.reorder(iqe,r2n=True)
1386
+ iue=hp.reorder(iue,r2n=True)
1387
+ iqb=hp.reorder(iqb,r2n=True)
1388
+ iub=hp.reorder(iub,r2n=True)
1389
+
1390
+ for l in idx2:
1391
+ tmpEB[k,0,l]=rqe[idxEB[l,:]]+1J*iqe[idxEB[l,:]]
1392
+ tmpEB[k,1,l]=rue[idxEB[l,:]]+1J*iue[idxEB[l,:]]
1393
+ tmpEB[k,2,l]=rqb[idxEB[l,:]]+1J*iqb[idxEB[l,:]]
1394
+ tmpEB[k,3,l]=rub[idxEB[l,:]]+1J*iub[idxEB[l,:]]
1395
+
1396
+ r=self.smooth(im)
1397
+
1398
+ ralm=hp.map2alm(hp.reorder(r.cpu().numpy(),n2r=True))[None,:]
1399
+
1400
+ alm=np.concatenate([ralm,0*ralm,0*ralm],0)
1401
+ rqe,rue,rie=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1402
+
1403
+ alm=np.concatenate([0*ralm,ralm,0*ralm],0)
1404
+ rqb,rub,rib=hp.alm2map_spin(alm,nside,spin,3*nside-1)
1405
+
1406
+ rqe=hp.reorder(rqe,r2n=True)
1407
+ rue=hp.reorder(rue,r2n=True)
1408
+ rqb=hp.reorder(rqb,r2n=True)
1409
+ rub=hp.reorder(rub,r2n=True)
1410
+
1411
+ for l in idx2:
1412
+ tmpS[0,l,:]=rqe[idxEB[l,:]]
1413
+ tmpS[1,l,:]=rue[idxEB[l,:]]
1414
+ tmpS[2,l,:]=rqb[idxEB[l,:]]
1415
+ tmpS[3,l,:]=rub[idxEB[l,:]]
1416
+ if 100*nn/(l_cell_ids.shape[0])>pp:
1417
+ if nside<64:
1418
+ pp+=10
1419
+ else:
1420
+ pp+=1
1421
+ print('%.2f%% Done'%(100*nn/(l_cell_ids.shape[0])))
1422
+
1423
+ wav=tmpEB.flatten()
1424
+ wwav=tmpS.flatten()
1425
+ ndata=l_cell_ids.shape[0]*nvalid
1426
+ indice_1_1=np.tile(idxEB.flatten(),4*self.NORIENT)
1427
+ for k in range(self.NORIENT):
1428
+ indice_1_1[(4*k+1)*ndata:(4*k+2)*ndata]+=l_cell_ids.shape[0]
1429
+ indice_1_1[(4*k+3)*ndata:(4*k+4)*ndata]+=l_cell_ids.shape[0]
1430
+
1431
+ indice_1_0=np.tile(np.tile(np.repeat(np.arange(l_cell_ids.shape[0]),nvalid),4),self.NORIENT)
1432
+ for k in range(self.NORIENT):
1433
+ indice_1_0[(4*k+2)*ndata:(4*k+4)*ndata]+=self.NORIENT*l_cell_ids.shape[0]
1434
+ indice_1_0[(4*k)*ndata:(4*k+4)*ndata]+=k*l_cell_ids.shape[0]
1435
+ '''
1436
+ import matplotlib.pyplot as plt
1437
+ plt.figure()
1438
+ plt.subplot(2,2,1)
1439
+ plt.plot(indice_1_0)
1440
+ plt.subplot(2,2,2)
1441
+ plt.plot(indice_1_1)
1442
+ plt.subplot(2,2,3)
1443
+ plt.plot(wav.real)
1444
+ plt.subplot(2,2,4)
1445
+ plt.plot(abs(wav))
1446
+
1447
+ iarg=np.argsort(indice_1_0)
1448
+ indice_1_1=indice_1_1[iarg]
1449
+ indice_1_0=indice_1_0[iarg]
1450
+ wav=wav[iarg]
1451
+ '''
1452
+
1453
+ indice=np.concatenate([indice_1_1[:,None],indice_1_0[:,None]],1)
1454
+
1455
+ indice_2_1=np.tile(idxEB.flatten(),4)
1456
+ indice_2_1[ndata:2*ndata]+=l_cell_ids.shape[0]
1457
+ indice_2_1[3*ndata:4*ndata]+=l_cell_ids.shape[0]
1458
+ indice_2_0=np.tile(np.repeat(np.arange(l_cell_ids.shape[0]),nvalid),4)
1459
+ indice_2_0[2*ndata:]+=l_cell_ids.shape[0]
1460
+ '''
1461
+ plt.figure()
1462
+ plt.subplot(2,2,1)
1463
+ plt.plot(indice_2_0)
1464
+ plt.subplot(2,2,2)
1465
+ plt.plot(indice_2_1)
1466
+ plt.subplot(2,2,3)
1467
+ plt.plot(wav.real)
1468
+ plt.subplot(2,2,4)
1469
+ plt.plot(wwav)
1470
+
1471
+ iarg=np.argsort(indice_2_0)
1472
+ indice_2_1=indice_2_1[iarg]
1473
+ indice_2_0=indice_2_0[iarg]
1474
+ wwav=wwav[iarg]
1475
+ '''
1476
+ indice2=np.concatenate([indice_2_1[:,None],indice_2_0[:,None]],1)
1477
+
1478
+ '''
1277
1479
  for k in range(self.NORIENT*12*nside**2):
1278
1480
  if k%(nside**2)==0:
1279
1481
  print('Init index 1/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
@@ -1300,7 +1502,7 @@ class FoCUS:
1300
1502
  tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1301
1503
  tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1302
1504
 
1303
-
1505
+ '''
1304
1506
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"% (self.TEMPLATE_PATH,
1305
1507
  self.TMPFILE_VERSION,
1306
1508
  self.KERNELSZ**2,
@@ -1308,7 +1510,7 @@ class FoCUS:
1308
1510
  nside,
1309
1511
  spin
1310
1512
  ),
1311
- idxEB
1513
+ indice
1312
1514
  )
1313
1515
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"% (self.TEMPLATE_PATH,
1314
1516
  self.TMPFILE_VERSION,
@@ -1317,9 +1519,9 @@ class FoCUS:
1317
1519
  nside,
1318
1520
  spin,
1319
1521
  ),
1320
- tmpEB
1522
+ wav
1321
1523
  )
1322
-
1524
+ '''
1323
1525
  tmp = self.read_index(
1324
1526
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN0.fst"
1325
1527
  % (
@@ -1361,7 +1563,7 @@ class FoCUS:
1361
1563
  tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]
1362
1564
  tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]
1363
1565
 
1364
-
1566
+ '''
1365
1567
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"% (self.TEMPLATE_PATH,
1366
1568
  self.TMPFILE_VERSION,
1367
1569
  self.KERNELSZ**2,
@@ -1369,7 +1571,7 @@ class FoCUS:
1369
1571
  nside,
1370
1572
  spin
1371
1573
  ),
1372
- idxEB
1574
+ indice2
1373
1575
  )
1374
1576
  self.save_index("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"% (self.TEMPLATE_PATH,
1375
1577
  self.TMPFILE_VERSION,
@@ -1378,10 +1580,11 @@ class FoCUS:
1378
1580
  nside,
1379
1581
  spin,
1380
1582
  ),
1381
- tmpEB
1583
+ wwav
1382
1584
  )
1585
+
1383
1586
  else:
1384
-
1587
+ '''
1385
1588
  if l_kernel == 5:
1386
1589
  pw = 0.5
1387
1590
  pw2 = 0.5
@@ -1396,8 +1599,20 @@ class FoCUS:
1396
1599
  pw = 0.5
1397
1600
  pw2 = 0.25
1398
1601
  threshold = 4e-5
1602
+ '''
1603
+ import foscat.HOrientedConvol as hs
1399
1604
 
1400
- if cell_ids is not None and nside>512:
1605
+ hconvol=hs.HOrientedConvol(nside,l_kernel,cell_ids=cell_ids)
1606
+
1607
+ orientations=np.pi*np.arange(self.NORIENT)/self.NORIENT
1608
+
1609
+ wav,indice,wwav,indice2=hconvol.make_wavelet_matrix(orientations,
1610
+ polar=True,
1611
+ return_index=True,
1612
+ return_smooth=True)
1613
+
1614
+ '''
1615
+ if cell_ids is not None and nside>256:
1401
1616
  if not isinstance(cell_ids, np.ndarray):
1402
1617
  cell_ids = self.backend.to_numpy(cell_ids)
1403
1618
  th, ph = hp.pix2ang(nside, cell_ids, nest=True)
@@ -1521,8 +1736,7 @@ class FoCUS:
1521
1736
  wav = wav[:iv]
1522
1737
  indice2 = indice2[:iv2, :]
1523
1738
  wwav = wwav[:iv2]
1524
- if not self.silent:
1525
- print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1739
+ '''
1526
1740
 
1527
1741
  if cell_ids is None:
1528
1742
  if not self.silent:
@@ -1597,7 +1811,7 @@ class FoCUS:
1597
1811
  )
1598
1812
  return None
1599
1813
 
1600
- if cell_ids is None or nside<=512:
1814
+ if cell_ids is None or spin!=0:
1601
1815
  self.barrier()
1602
1816
  if self.use_2D:
1603
1817
  tmp = self.read_index(
@@ -1703,6 +1917,20 @@ class FoCUS:
1703
1917
  tmp2[lidx,0]=0
1704
1918
  tmp2[:,1]+=i_id*lcell_ids.shape[0]
1705
1919
  tmp2[:,0]+=i_id2*lcell_ids.shape[0]
1920
+
1921
+ #add normalisation
1922
+ ww=np.bincount(tmp2[:,1],weights=ws)
1923
+ ws/=ww[tmp2[:,1]]
1924
+
1925
+ wh=np.bincount(tmp[:,1])
1926
+ ww=np.bincount(tmp[:,1],weights=wr)
1927
+ wr-=(ww/wh)[tmp[:,1]]
1928
+ ww=np.bincount(tmp[:,1],weights=wi)
1929
+ wi-=(ww/wh)[tmp[:,1]]
1930
+
1931
+ ww=np.bincount(tmp[:,1],weights=np.sqrt(wr*wr+wi*wi))
1932
+ wr/=ww[tmp[:,1]]
1933
+ wi/=ww[tmp[:,1]]
1706
1934
 
1707
1935
  else:
1708
1936
  tmp = indice
@@ -2152,13 +2380,8 @@ class FoCUS:
2152
2380
  res = v1 / vh
2153
2381
 
2154
2382
  oshape = [x.shape[0]] + [mask.shape[0]]
2155
- if axis > 0:
2156
- oshape = oshape + list(x.shape[1:axis])
2157
-
2158
- if len(x.shape[axis:-2]) > 0:
2159
- oshape = oshape + list(x.shape[axis:-2])
2160
- else:
2161
- oshape = oshape + [1]
2383
+ if len(x.shape)>1:
2384
+ oshape = oshape + list(x.shape[1:-2])
2162
2385
 
2163
2386
  if calc_var:
2164
2387
  if self.backend.bk_is_complex(vtmp):
@@ -2560,15 +2783,16 @@ class FoCUS:
2560
2783
  print("Use of 2D scat with data that has less than 2D")
2561
2784
  return None
2562
2785
 
2563
- npix = ishape[axis]
2564
- npiy = ishape[axis + 1]
2786
+ npix = ishape[-2]
2787
+ npiy = ishape[-1]
2788
+
2565
2789
  odata = 1
2566
- if len(ishape) > axis + 2:
2567
- for k in range(axis + 2, len(ishape)):
2790
+ if len(ishape) > 1:
2791
+ for k in range(len(ishape)-2):
2568
2792
  odata = odata * ishape[k]
2569
2793
 
2570
2794
  ndata = 1
2571
- for k in range(axis):
2795
+ for k in range(len(ishape)-2):
2572
2796
  ndata = ndata * ishape[k]
2573
2797
 
2574
2798
  tim = self.backend.bk_reshape(
@@ -2630,16 +2854,30 @@ class FoCUS:
2630
2854
  odata = odata * ishape[k]
2631
2855
 
2632
2856
  tim = self.backend.bk_reshape(image, [odata, ishape[-1]])
2633
- if tim.dtype == self.all_cbk_type:
2634
- rr = self.backend.bk_sparse_dense_matmul(
2635
- self.backend.bk_real(tim), l_w_smooth
2636
- )
2637
- ri = self.backend.bk_sparse_dense_matmul(
2638
- self.backend.bk_imag(tim), l_w_smooth
2639
- )
2640
- res = self.backend.bk_complex(rr, ri)
2857
+ if spin==0:
2858
+ if tim.dtype == self.all_cbk_type:
2859
+ rr = self.backend.bk_sparse_dense_matmul(
2860
+ self.backend.bk_real(tim), l_w_smooth
2861
+ )
2862
+ ri = self.backend.bk_sparse_dense_matmul(
2863
+ self.backend.bk_imag(tim), l_w_smooth
2864
+ )
2865
+ res = self.backend.bk_complex(rr, ri)
2866
+ else:
2867
+ res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2641
2868
  else:
2642
- res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2869
+ tim=self.backend.bk_reshape(tim,[odata//2,2*tim.shape[-1]])
2870
+ if tim.dtype == self.all_cbk_type:
2871
+ rr = self.backend.bk_sparse_dense_matmul(
2872
+ self.backend.bk_real(tim), l_w_smooth
2873
+ )
2874
+ ri = self.backend.bk_sparse_dense_matmul(
2875
+ self.backend.bk_imag(tim), l_w_smooth
2876
+ )
2877
+ res = self.backend.bk_complex(rr, ri)
2878
+ else:
2879
+ res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2880
+
2643
2881
  if len(ishape) == 1:
2644
2882
  return self.backend.bk_reshape(res, [ishape[-1]])
2645
2883
  else: