foscat 2025.6.1__py3-none-any.whl → 2025.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -10,32 +10,32 @@ TMPFILE_VERSION = "V5_0"
10
10
 
11
11
  class FoCUS:
12
12
  def __init__(
13
- self,
14
- NORIENT=4,
15
- LAMBDA=1.2,
16
- KERNELSZ=3,
17
- slope=1.0,
18
- all_type="float32",
19
- nstep_max=16,
20
- padding="SAME",
21
- gpupos=0,
22
- mask_thres=None,
23
- mask_norm=False,
24
- isMPI=False,
25
- TEMPLATE_PATH="data",
26
- BACKEND="tensorflow",
27
- use_2D=False,
28
- use_1D=False,
29
- return_data=False,
30
- JmaxDelta=0,
31
- DODIV=False,
32
- InitWave=None,
33
- silent=True,
34
- mpi_size=1,
35
- mpi_rank=0,
13
+ self,
14
+ NORIENT=4,
15
+ LAMBDA=1.2,
16
+ KERNELSZ=3,
17
+ slope=1.0,
18
+ all_type="float32",
19
+ nstep_max=20,
20
+ padding="SAME",
21
+ gpupos=0,
22
+ mask_thres=None,
23
+ mask_norm=False,
24
+ isMPI=False,
25
+ TEMPLATE_PATH="data",
26
+ BACKEND="torch",
27
+ use_2D=False,
28
+ use_1D=False,
29
+ return_data=False,
30
+ JmaxDelta=0,
31
+ DODIV=False,
32
+ InitWave=None,
33
+ silent=True,
34
+ mpi_size=1,
35
+ mpi_rank=0
36
36
  ):
37
37
 
38
- self.__version__ = "2025.06.1"
38
+ self.__version__ = "2025.06.3"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -369,39 +369,8 @@ class FoCUS:
369
369
  self.pix_interp_val = {}
370
370
  self.weight_interp_val = {}
371
371
  self.ring2nest = {}
372
- self.nest2R = {}
373
- self.nest2R1 = {}
374
- self.nest2R2 = {}
375
- self.nest2R3 = {}
376
- self.nest2R4 = {}
377
- self.inv_nest2R = {}
378
- self.remove_border = {}
379
-
380
372
  self.ampnorm = {}
381
373
 
382
- for i in range(nstep_max):
383
- lout = 2**i
384
- self.pix_interp_val[lout] = {}
385
- self.weight_interp_val[lout] = {}
386
- for j in range(nstep_max):
387
- lout2 = 2**j
388
- self.pix_interp_val[lout][lout2] = None
389
- self.weight_interp_val[lout][lout2] = None
390
- self.ring2nest[lout] = None
391
- self.Idx_Neighbours[lout] = None
392
- self.nest2R[lout] = None
393
- self.nest2R1[lout] = None
394
- self.nest2R2[lout] = None
395
- self.nest2R3[lout] = None
396
- self.nest2R4[lout] = None
397
- self.inv_nest2R[lout] = None
398
- self.remove_border[lout] = None
399
- self.ww_CNN_Transpose[lout] = None
400
- self.ww_CNN[lout] = None
401
- self.X_CNN[lout] = None
402
- self.Y_CNN[lout] = None
403
- self.Z_CNN[lout] = None
404
-
405
374
  self.loss = {}
406
375
 
407
376
  def get_type(self):
@@ -543,7 +512,7 @@ class FoCUS:
543
512
  def toring(self, image, axis=0):
544
513
  lout = int(np.sqrt(image.shape[axis] // 12))
545
514
 
546
- if self.ring2nest[lout] is None:
515
+ if lout not in self.ring2nest:
547
516
  self.ring2nest[lout] = hp.ring2nest(lout, np.arange(12 * lout**2))
548
517
 
549
518
  return image.numpy()[self.ring2nest[lout]]
@@ -639,30 +608,10 @@ class FoCUS:
639
608
  if cell_ids is not None:
640
609
  sim, new_cell_ids = self.backend.binned_mean(im, cell_ids)
641
610
  return sim, new_cell_ids
642
-
643
- lout = int(np.sqrt(shape[axis] // 12))
644
- if im.__class__ == np.zeros([0]).__class__:
645
- oshape = np.zeros([len(shape) + 1], dtype="int")
646
- if axis > 0:
647
- oshape[0:axis] = shape[0:axis]
648
- oshape[axis] = 12 * lout * lout // 4
649
- oshape[axis + 1] = 4
650
- if len(shape) > axis:
651
- oshape[axis + 2 :] = shape[axis + 1 :]
652
- else:
653
- if axis > 0:
654
- oshape = shape[0:axis] + [12 * lout * lout // 4, 4]
655
- else:
656
- oshape = [12 * lout * lout // 4, 4]
657
- if len(shape) > axis:
658
- oshape = oshape + shape[axis + 1 :]
659
-
660
- return (
661
- self.backend.bk_reduce_mean(
662
- self.backend.bk_reshape(im, oshape), axis=axis + 1
663
- ),
664
- None,
665
- )
611
+
612
+ return self.backend.bk_reduce_mean(
613
+ self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
614
+ ),None
666
615
 
667
616
  # --------------------------------------------------------
668
617
  def up_grade(self, im, nout, axis=0, nouty=None):
@@ -773,9 +722,9 @@ class FoCUS:
773
722
 
774
723
  else:
775
724
 
776
- lout = int(np.sqrt(im.shape[axis] // 12))
725
+ lout = int(np.sqrt(im.shape[-1] // 12))
777
726
 
778
- if self.pix_interp_val[lout][nout] is None:
727
+ if (lout,nout) not in self.pix_interp_val:
779
728
  if not self.silent:
780
729
  print("compute lout nout", lout, nout)
781
730
  th, ph = hp.pix2ang(
@@ -794,104 +743,51 @@ class FoCUS:
794
743
  t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
795
744
  p = p.flatten()[t]
796
745
  w = w.flatten()[t]
797
- indice[:, 0] = np.repeat(np.arange(12 * nout**2), 4)
798
- indice[:, 1] = p
746
+ indice[:, 1] = np.repeat(np.arange(12 * nout**2), 4)
747
+ indice[:, 0] = p
799
748
 
800
- self.pix_interp_val[lout][nout] = 1
801
- self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(
749
+ self.pix_interp_val[(lout,nout)] = 1
750
+ self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
802
751
  self.backend.bk_constant(indice),
803
752
  self.backend.bk_constant(self.backend.bk_cast(w.flatten())),
804
- dense_shape=[12 * nout**2, 12 * lout**2],
753
+ dense_shape=[12 * lout**2,12 * nout**2],
805
754
  )
806
755
 
807
756
  if lout == nout:
808
757
  imout = im
809
758
  else:
810
-
759
+ # work only on the last axis
760
+
811
761
  ishape = list(im.shape)
812
- odata = 1
813
- for k in range(axis + 1, len(ishape)):
814
- odata = odata * ishape[k]
815
762
 
816
763
  ndata = 1
817
- for k in range(axis):
764
+ for k in range(len(ishape)-1):
818
765
  ndata = ndata * ishape[k]
819
766
  tim = self.backend.bk_reshape(
820
- self.backend.bk_cast(im), [ndata, 12 * lout**2, odata]
767
+ self.backend.bk_cast(im), [ndata, 12 * lout**2]
821
768
  )
822
769
  if tim.dtype == self.all_cbk_type:
823
- rr = self.backend.bk_reshape(
824
- self.backend.bk_sparse_dense_matmul(
825
- self.weight_interp_val[lout][nout],
826
- self.backend.bk_real(tim[0]),
827
- ),
828
- [1, 12 * nout**2, odata],
829
- )
830
- ii = self.backend.bk_reshape(
831
- self.backend.bk_sparse_dense_matmul(
832
- self.weight_interp_val[lout][nout],
833
- self.backend.bk_imag(tim[0]),
834
- ),
835
- [1, 12 * nout**2, odata],
836
- )
770
+ rr = self.backend.bk_sparse_dense_matmul(
771
+ self.backend.bk_real(tim),
772
+ self.weight_interp_val[(lout,nout)],
773
+ )
774
+ ii = self.backend.bk_sparse_dense_matmul(
775
+ self.backend.bk_real(tim),
776
+ self.weight_interp_val[(lout,nout)],
777
+ )
837
778
  imout = self.backend.bk_complex(rr, ii)
838
779
  else:
839
- imout = self.backend.bk_reshape(
840
- self.backend.bk_sparse_dense_matmul(
841
- self.weight_interp_val[lout][nout], tim[0]
842
- ),
843
- [1, 12 * nout**2, odata],
780
+ imout = self.backend.bk_sparse_dense_matmul(
781
+ tim,
782
+ self.weight_interp_val[(lout,nout)],
844
783
  )
845
784
 
846
- for k in range(1, ndata):
847
- if tim.dtype == self.all_cbk_type:
848
- rr = self.backend.bk_reshape(
849
- self.backend.bk_sparse_dense_matmul(
850
- self.weight_interp_val[lout][nout],
851
- self.backend.bk_real(tim[k]),
852
- ),
853
- [1, 12 * nout**2, odata],
854
- )
855
- ii = self.backend.bk_reshape(
856
- self.backend.bk_sparse_dense_matmul(
857
- self.weight_interp_val[lout][nout],
858
- self.backend.bk_imag(tim[k]),
859
- ),
860
- [1, 12 * nout**2, odata],
861
- )
862
- imout = self.backend.bk_concat(
863
- [imout, self.backend.bk_complex(rr, ii)], 0
864
- )
865
- else:
866
- imout = self.backend.bk_concat(
867
- [
868
- imout,
869
- self.backend.bk_reshape(
870
- self.backend.bk_sparse_dense_matmul(
871
- self.weight_interp_val[lout][nout], tim[k]
872
- ),
873
- [1, 12 * nout**2, odata],
874
- ),
875
- ],
876
- 0,
877
- )
878
-
879
- if axis == 0:
880
- if len(ishape) == 1:
881
- return self.backend.bk_reshape(imout, [12 * nout**2])
882
- else:
883
- return self.backend.bk_reshape(
884
- imout, [12 * nout**2] + ishape[axis + 1 :]
885
- )
785
+ if len(ishape) == 1:
786
+ return self.backend.bk_reshape(imout, [12 * nout**2])
886
787
  else:
887
- if len(ishape) == axis + 1:
888
- return self.backend.bk_reshape(
889
- imout, ishape[0:axis] + [12 * nout**2]
890
- )
891
- else:
892
- return self.backend.bk_reshape(
893
- imout, ishape[0:axis] + [12 * nout**2] + ishape[axis + 1 :]
894
- )
788
+ return self.backend.bk_reshape(
789
+ imout, ishape[0:axis-1]+[12 * nout**2]
790
+ )
895
791
  return imout
896
792
 
897
793
  # --------------------------------------------------------
@@ -1164,7 +1060,7 @@ class FoCUS:
1164
1060
  return res
1165
1061
 
1166
1062
  # -------------------------------------------------------
1167
- def init_index(self, nside, kernel=-1, cell_ids=None):
1063
+ def init_index(self, nside, kernel=-1, cell_ids=None, spin=0):
1168
1064
 
1169
1065
  if kernel == -1:
1170
1066
  l_kernel = self.KERNELSZ
@@ -1197,297 +1093,372 @@ class FoCUS:
1197
1093
 
1198
1094
  else:
1199
1095
  tmp = np.load(
1200
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1096
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1201
1097
  % (
1202
1098
  self.TEMPLATE_PATH,
1203
1099
  TMPFILE_VERSION,
1204
1100
  l_kernel**2,
1205
1101
  self.NORIENT,
1206
- nside, # if cell_ids computes the index
1102
+ nside,spin # if cell_ids computes the index
1207
1103
  )
1208
1104
  )
1105
+
1209
1106
  except:
1210
1107
  if not self.use_2D:
1211
-
1212
- if l_kernel == 5:
1213
- pw = 0.5
1214
- pw2 = 0.5
1215
- threshold = 2e-4
1216
-
1217
- elif l_kernel == 3:
1218
- pw = 1.0 / np.sqrt(2)
1219
- pw2 = 1.0
1220
- threshold = 1e-3
1221
-
1222
- elif l_kernel == 7:
1223
- pw = 0.5
1224
- pw2 = 0.25
1225
- threshold = 4e-5
1226
-
1227
- if cell_ids is not None:
1228
- if not isinstance(cell_ids, np.ndarray):
1229
- cell_ids = self.backend.to_numpy(cell_ids)
1230
- th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1231
- x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1232
-
1233
- t, p = hp.pix2ang(nside, cell_ids, nest=True)
1234
- phi = [p[k] / np.pi * 180 for k in range(ncell)]
1235
- thi = [t[k] / np.pi * 180 for k in range(ncell)]
1236
-
1237
- indice2 = np.zeros([ncell * 64, 2], dtype="int")
1238
- indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1239
- wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1240
- wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1241
-
1108
+ if spin!=0:
1109
+ try:
1110
+ tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.npy"% (
1111
+ self.TEMPLATE_PATH,
1112
+ self.TMPFILE_VERSION,
1113
+ self.KERNELSZ**2,
1114
+ self.NORIENT,
1115
+ nside)
1116
+ )
1117
+ except:
1118
+ self.init_index(nside, kernel=kernel, spin=0)
1119
+
1120
+ tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.npy"% (
1121
+ self.TEMPLATE_PATH,
1122
+ self.TMPFILE_VERSION,
1123
+ self.KERNELSZ**2,
1124
+ self.NORIENT,
1125
+ nside)
1126
+ )
1127
+
1128
+ tmpw = np.load("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN0.npy"% (
1129
+ self.TEMPLATE_PATH,
1130
+ self.TMPFILE_VERSION,
1131
+ self.KERNELSZ**2,
1132
+ self.NORIENT,
1133
+ nside,
1134
+ )
1135
+ )
1136
+
1137
+ nn=self.NORIENT*12*nside**2
1138
+ idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1139
+ idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
1140
+ idxEB[3*tmp.shape[0]:,0]+=12*nside**2
1141
+ idxEB[2*tmp.shape[0]:,1]+=nn
1142
+
1143
+ tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1144
+
1145
+ for k in range(self.NORIENT*12*nside**2):
1146
+ if k%(nside**2)==0:
1147
+ print('Init index 1/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
1148
+ self.NORIENT*12,
1149
+ nside,
1150
+ self.KERNELSZ))
1151
+ idx=np.where(tmp[:,1]==k)[0]
1152
+
1153
+ im=np.zeros([12*nside**2])
1154
+ im[tmp[idx,0]]=tmpw[idx].real
1155
+ almR=hp.map2alm(hp.reorder(im,n2r=True))
1156
+ im[tmp[idx,0]]=tmpw[idx].imag
1157
+ almI=hp.map2alm(hp.reorder(im,n2r=True))
1158
+
1159
+ i,q,u=hp.alm2map_spin([almR,almR*0,0*almR],nside,spin,3*nside-1)
1160
+ i2,q2,u2=hp.alm2map_spin([almI,0*almI,0*almI],nside,spin,3*nside-1)
1161
+
1162
+ tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1163
+ tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1164
+
1165
+ i,q,u=hp.alm2map_spin([0*almR,almR,0*almR],nside,spin,3*nside-1)
1166
+ i2,q2,u2=hp.alm2map_spin([0*almI,almI,0*almI],nside,spin,3*nside-1)
1167
+
1168
+ tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1169
+ tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1170
+
1171
+
1172
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"% (self.TEMPLATE_PATH,
1173
+ self.TMPFILE_VERSION,
1174
+ self.KERNELSZ**2,
1175
+ self.NORIENT,
1176
+ nside,
1177
+ spin
1178
+ ),
1179
+ idxEB
1180
+ )
1181
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"% (self.TEMPLATE_PATH,
1182
+ self.TMPFILE_VERSION,
1183
+ self.KERNELSZ**2,
1184
+ self.NORIENT,
1185
+ nside,
1186
+ spin,
1187
+ ),
1188
+ tmpEB
1189
+ )
1190
+ tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN0.npy"%
1191
+ (
1192
+ self.TEMPLATE_PATH,
1193
+ self.TMPFILE_VERSION,
1194
+ self.KERNELSZ**2,
1195
+ self.NORIENT,
1196
+ nside,
1197
+ )
1198
+ )
1199
+ tmpw = np.load("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN0.npy"%
1200
+ (
1201
+ self.TEMPLATE_PATH,
1202
+ self.TMPFILE_VERSION,
1203
+ self.KERNELSZ**2,
1204
+ self.NORIENT,
1205
+ nside,
1206
+ )
1207
+ )
1208
+
1209
+ nn=12*nside**2
1210
+ idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1211
+ idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
1212
+ idxEB[3*tmp.shape[0]:,0]+=12*nside**2
1213
+ idxEB[2*tmp.shape[0]:,1]+=nn
1214
+
1215
+ tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1216
+
1217
+ for k in range(12*nside**2):
1218
+ if k%(nside**2)==0:
1219
+ print('Init index 2/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
1220
+ 12,
1221
+ nside,
1222
+ self.KERNELSZ))
1223
+ idx=np.where(tmp[:,1]==k)[0]
1224
+
1225
+ im=np.zeros([12*nside**2])
1226
+ im[tmp[idx,0]]=tmpw[idx].real
1227
+ almR=hp.map2alm(hp.reorder(im,n2r=True))
1228
+ im[tmp[idx,0]]=tmpw[idx].imag
1229
+ almI=hp.map2alm(hp.reorder(im,n2r=True))
1230
+
1231
+ i,q,u=hp.alm2map_spin([almR,almR*0,0*almR],nside,spin,3*nside-1)
1232
+ i2,q2,u2=hp.alm2map_spin([almI,0*almI,0*almI],nside,spin,3*nside-1)
1233
+
1234
+ tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1235
+ tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1236
+
1237
+ i,q,u=hp.alm2map_spin([0*almR,almR,0*almR],nside,spin,3*nside-1)
1238
+ i2,q2,u2=hp.alm2map_spin([0*almI,almI,0*almI],nside,spin,3*nside-1)
1239
+
1240
+ tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1241
+ tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1242
+
1243
+
1244
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"%
1245
+ (
1246
+ self.TEMPLATE_PATH,
1247
+ self.TMPFILE_VERSION,
1248
+ self.KERNELSZ**2,
1249
+ self.NORIENT,
1250
+ nside,
1251
+ spin,
1252
+ ),
1253
+ idxEB
1254
+ )
1255
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"%
1256
+ (
1257
+ self.TEMPLATE_PATH,
1258
+ self.TMPFILE_VERSION,
1259
+ self.KERNELSZ**2,
1260
+ self.NORIENT,
1261
+ nside,
1262
+ spin,
1263
+ ),
1264
+ tmpEB
1265
+ )
1242
1266
  else:
1243
1267
 
1244
- th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1245
- x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1246
-
1247
- t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1248
- phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1249
- thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1268
+ if l_kernel == 5:
1269
+ pw = 0.5
1270
+ pw2 = 0.5
1271
+ threshold = 2e-4
1250
1272
 
1251
- indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1252
- indice = np.zeros(
1253
- [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1254
- )
1255
- wav = np.zeros(
1256
- [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1257
- )
1258
- wwav = np.zeros(
1259
- [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1260
- )
1261
- iv = 0
1262
- iv2 = 0
1273
+ elif l_kernel == 3:
1274
+ pw = 1.0 / np.sqrt(2)
1275
+ pw2 = 1.0
1276
+ threshold = 1e-3
1263
1277
 
1264
- for iii in range(ncell):
1265
- if cell_ids is None:
1266
- if iii % (nside * nside) == nside * nside - 1:
1267
- if not self.silent:
1268
- print(
1269
- "Pre-compute nside=%6d %.2f%%"
1270
- % (nside, 100 * iii / (12 * nside * nside))
1271
- )
1278
+ elif l_kernel == 7:
1279
+ pw = 0.5
1280
+ pw2 = 0.25
1281
+ threshold = 4e-5
1272
1282
 
1273
1283
  if cell_ids is not None:
1274
- hidx = np.where(
1275
- (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1276
- < (2 * np.pi / nside) ** 2
1277
- )[0]
1278
- else:
1279
- hidx = hp.query_disc(
1280
- nside,
1281
- [x[iii], y[iii], z[iii]],
1282
- 2 * np.pi / nside,
1283
- nest=True,
1284
- )
1284
+ if not isinstance(cell_ids, np.ndarray):
1285
+ cell_ids = self.backend.to_numpy(cell_ids)
1286
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1287
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1285
1288
 
1286
- R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1289
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1290
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1291
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1287
1292
 
1288
- t2, p2 = R(th[hidx], ph[hidx])
1293
+ indice2 = np.zeros([ncell * 64, 2], dtype="int")
1294
+ indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1295
+ wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1296
+ wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1289
1297
 
1290
- vec2 = hp.ang2vec(t2, p2)
1291
-
1292
- x2 = vec2[:, 0]
1293
- y2 = vec2[:, 1]
1294
- z2 = vec2[:, 2]
1298
+ else:
1295
1299
 
1296
- ww = np.exp(
1297
- -pw2
1298
- * ((nside) ** 2)
1299
- * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1300
- )
1301
- idx = np.where((ww**2) > threshold)[0]
1302
- nval2 = len(idx)
1303
- indice2[iv2 : iv2 + nval2, 1] = iii
1304
- indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1305
- wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1306
- iv2 += nval2
1300
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1301
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1307
1302
 
1308
- for l_rotation in range(self.NORIENT):
1303
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1304
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1305
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1309
1306
 
1310
- angle = (
1311
- l_rotation / 4.0 * np.pi
1312
- - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1313
- - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1307
+ indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1308
+ indice = np.zeros(
1309
+ [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1310
+ )
1311
+ wav = np.zeros(
1312
+ [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1314
1313
  )
1314
+ wwav = np.zeros(
1315
+ [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1316
+ )
1317
+ iv = 0
1318
+ iv2 = 0
1319
+
1320
+ for iii in range(ncell):
1321
+ if cell_ids is None:
1322
+ if iii % (nside * nside) == nside * nside - 1:
1323
+ if not self.silent:
1324
+ print(
1325
+ "Pre-compute nside=%6d %.2f%%"
1326
+ % (nside, 100 * iii / (12 * nside * nside))
1327
+ )
1328
+
1329
+ if cell_ids is not None:
1330
+ hidx = np.where(
1331
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1332
+ < (2 * np.pi / nside) ** 2
1333
+ )[0]
1334
+ else:
1335
+ hidx = hp.query_disc(
1336
+ nside,
1337
+ [x[iii], y[iii], z[iii]],
1338
+ 2 * np.pi / nside,
1339
+ nest=True,
1340
+ )
1315
1341
 
1316
- # posi=2*(0.5-(z[hidx]<0))
1342
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1317
1343
 
1318
- axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1319
- wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1320
- wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1344
+ t2, p2 = R(th[hidx], ph[hidx])
1321
1345
 
1322
- vnorm = wresr * wresr + wresi * wresi
1323
- idx = np.where(vnorm > threshold)[0]
1346
+ vec2 = hp.ang2vec(t2, p2)
1324
1347
 
1325
- nval = len(idx)
1326
- indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1327
- indice[iv : iv + nval, 0] = hidx[idx]
1328
- # print([hidx[k] for k in idx])
1329
- # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1330
- normr = np.mean(wresr[idx])
1331
- normi = np.mean(wresi[idx])
1348
+ x2 = vec2[:, 0]
1349
+ y2 = vec2[:, 1]
1350
+ z2 = vec2[:, 2]
1332
1351
 
1333
- val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1334
- r = abs(val).sum()
1352
+ ww = np.exp(
1353
+ -pw2
1354
+ * ((nside) ** 2)
1355
+ * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1356
+ )
1357
+ idx = np.where((ww**2) > threshold)[0]
1358
+ nval2 = len(idx)
1359
+ indice2[iv2 : iv2 + nval2, 1] = iii
1360
+ indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1361
+ wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1362
+ iv2 += nval2
1363
+
1364
+ for l_rotation in range(self.NORIENT):
1365
+
1366
+ angle = (
1367
+ l_rotation / 4.0 * np.pi
1368
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1369
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1370
+ )
1335
1371
 
1336
- if r > 0:
1337
- val = val / r
1372
+ # posi=2*(0.5-(z[hidx]<0))
1338
1373
 
1339
- wav[iv : iv + nval] = val
1340
- iv += nval
1374
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1375
+ wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1376
+ wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1341
1377
 
1342
- indice = indice[:iv, :]
1343
- wav = wav[:iv]
1344
- indice2 = indice2[:iv2, :]
1345
- wwav = wwav[:iv2]
1346
- if not self.silent:
1347
- print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1348
- """
1349
- # OLD VERSION OLD VERSION OLD VERSION (3.0)
1350
- if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
1351
- l_kernel=3
1352
-
1353
- aa=np.cos(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1354
- bb=np.sin(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1355
- x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
1356
- to,po=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
1357
-
1358
- wav=np.zeros([12*nside*nside,l_kernel**2,self.NORIENT],dtype='complex')
1359
- wwav=np.zeros([12*nside*nside,l_kernel**2])
1360
- iwav=np.zeros([12*nside*nside,l_kernel**2],dtype='int')
1361
-
1362
- scale=4
1363
- if nside>scale*2:
1364
- th,ph=hp.pix2ang(nside//scale,np.arange(12*(nside//scale)**2),nest=True)
1365
- else:
1366
- lidx=np.arange(12*nside*nside)
1378
+ vnorm = wresr * wresr + wresi * wresi
1379
+ idx = np.where(vnorm > threshold)[0]
1367
1380
 
1368
- pw=np.pi/4.0
1369
- pw2=1/2
1370
- amp=1.0
1381
+ nval = len(idx)
1382
+ indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1383
+ indice[iv : iv + nval, 0] = hidx[idx]
1384
+ # print([hidx[k] for k in idx])
1385
+ # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1386
+ normr = np.mean(wresr[idx])
1387
+ normi = np.mean(wresi[idx])
1371
1388
 
1372
- if l_kernel==5:
1373
- pw=np.pi/4.0
1374
- pw2=1/2.25
1375
- amp=1.0/9.2038
1389
+ val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1390
+ r = abs(val).sum()
1376
1391
 
1377
- elif l_kernel==3:
1378
- pw=1.0/np.sqrt(2)
1379
- pw2=1.0
1380
- amp=1/8.45
1392
+ if r > 0:
1393
+ val = val / r
1381
1394
 
1382
- elif l_kernel==7:
1383
- pw=np.pi/4.0
1384
- pw2=1.0/3.0
1395
+ wav[iv : iv + nval] = val
1396
+ iv += nval
1385
1397
 
1386
- for k in range(12*nside*nside):
1387
- if k%(nside*nside)==0:
1388
- if not self.silent:
1389
- print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
1390
- if nside>scale*2:
1391
- lidx=hp.get_all_neighbours(nside//scale,th[k//(scale*scale)],ph[k//(scale*scale)],nest=True)
1392
- lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
1393
- lidx=np.repeat(lidx*(scale*scale),(scale*scale))+ \
1394
- np.tile(np.arange((scale*scale)),lidx.shape[0])
1395
-
1396
- delta=(x[lidx]-x[k])**2+(y[lidx]-y[k])**2+(z[lidx]-z[k])**2
1397
- pidx=np.where(delta<(10)/(nside**2))[0]
1398
- if len(pidx)<l_kernel**2:
1399
- pidx=np.arange(delta.shape[0])
1400
-
1401
- w=np.exp(-pw2*delta[pidx]*(nside**2))
1402
- pidx=pidx[np.argsort(-w)[0:l_kernel**2]]
1403
- pidx=pidx[np.argsort(lidx[pidx])]
1404
-
1405
- w=np.exp(-pw2*delta[pidx]*(nside**2))
1406
- iwav[k]=lidx[pidx]
1407
- wwav[k]=w
1408
- rot=[po[k]/np.pi*180.0,90+(-to[k])/np.pi*180.0]
1409
- r=hp.Rotator(rot=rot)
1410
- ty,tx=r(to[iwav[k]],po[iwav[k]])
1411
- ty=ty-np.pi/2
1412
-
1413
- xx=np.expand_dims(pw*nside*np.pi*tx/np.cos(ty),-1)
1414
- yy=np.expand_dims(pw*nside*np.pi*ty,-1)
1415
-
1416
- wav[k,:,:]=(np.cos(xx*aa+yy*bb)+complex(0.0,1.0)*np.sin(xx*aa+yy*bb))*np.expand_dims(w,-1)
1417
-
1418
- wav=wav-np.expand_dims(np.mean(wav,1),1)
1419
- wav=amp*wav/np.expand_dims(np.std(wav,1),1)
1420
- wwav=wwav/np.expand_dims(np.sum(wwav,1),1)
1421
-
1422
- nk=l_kernel*l_kernel
1423
- indice=np.zeros([12*nside*nside*nk*self.NORIENT,2],dtype='int')
1424
- lidx=np.arange(self.NORIENT)
1425
- for i in range(12*nside*nside):
1426
- indice[i*nk*self.NORIENT:i*nk*self.NORIENT+nk*self.NORIENT,0]=i*self.NORIENT+np.repeat(lidx,nk)
1427
- indice[i*nk*self.NORIENT:i*nk*self.NORIENT+nk*self.NORIENT,1]=np.tile(iwav[i],self.NORIENT)
1428
-
1429
- indice2=np.zeros([12*nside*nside*nk,2],dtype='int')
1430
- for i in range(12*nside*nside):
1431
- indice2[i*nk:i*nk+nk,0]=i
1432
- indice2[i*nk:i*nk+nk,1]=iwav[i]
1433
-
1434
- w=np.zeros([12*nside*nside,wav.shape[2],wav.shape[1]],dtype='complex')
1435
- for i in range(wav.shape[1]):
1436
- for j in range(wav.shape[2]):
1437
- w[:,j,i]=wav[:,i,j]
1438
- wav=w.flatten()
1439
- wwav=wwav.flatten()
1440
- """
1441
- if cell_ids is None:
1398
+ indice = indice[:iv, :]
1399
+ wav = wav[:iv]
1400
+ indice2 = indice2[:iv2, :]
1401
+ wwav = wwav[:iv2]
1442
1402
  if not self.silent:
1443
- print(
1444
- "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1445
- % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1403
+ print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1404
+
1405
+ if cell_ids is None:
1406
+ if not self.silent:
1407
+ print(
1408
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1409
+ % (TMPFILE_VERSION, self.KERNELSZ**2,
1410
+ self.NORIENT,
1411
+ nside,
1412
+ spin,)
1413
+ )
1414
+ np.save(
1415
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1416
+ % (
1417
+ self.TEMPLATE_PATH,
1418
+ TMPFILE_VERSION,
1419
+ self.KERNELSZ**2,
1420
+ self.NORIENT,
1421
+ nside,
1422
+ spin,
1423
+ ),
1424
+ indice,
1425
+ )
1426
+ np.save(
1427
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1428
+ % (
1429
+ self.TEMPLATE_PATH,
1430
+ TMPFILE_VERSION,
1431
+ self.KERNELSZ**2,
1432
+ self.NORIENT,
1433
+ nside,
1434
+ spin,
1435
+ ),
1436
+ wav,
1437
+ )
1438
+ np.save(
1439
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"
1440
+ % (
1441
+ self.TEMPLATE_PATH,
1442
+ TMPFILE_VERSION,
1443
+ self.KERNELSZ**2,
1444
+ self.NORIENT,
1445
+ nside,
1446
+ spin,
1447
+ ),
1448
+ indice2,
1449
+ )
1450
+ np.save(
1451
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"
1452
+ % (
1453
+ self.TEMPLATE_PATH,
1454
+ TMPFILE_VERSION,
1455
+ self.KERNELSZ**2,
1456
+ self.NORIENT,
1457
+ nside,
1458
+ spin,
1459
+ ),
1460
+ wwav,
1446
1461
  )
1447
- np.save(
1448
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1449
- % (
1450
- self.TEMPLATE_PATH,
1451
- TMPFILE_VERSION,
1452
- self.KERNELSZ**2,
1453
- self.NORIENT,
1454
- nside,
1455
- ),
1456
- indice,
1457
- )
1458
- np.save(
1459
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1460
- % (
1461
- self.TEMPLATE_PATH,
1462
- TMPFILE_VERSION,
1463
- self.KERNELSZ**2,
1464
- self.NORIENT,
1465
- nside,
1466
- ),
1467
- wav,
1468
- )
1469
- np.save(
1470
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1471
- % (
1472
- self.TEMPLATE_PATH,
1473
- TMPFILE_VERSION,
1474
- self.KERNELSZ**2,
1475
- self.NORIENT,
1476
- nside,
1477
- ),
1478
- indice2,
1479
- )
1480
- np.save(
1481
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1482
- % (
1483
- self.TEMPLATE_PATH,
1484
- TMPFILE_VERSION,
1485
- self.KERNELSZ**2,
1486
- self.NORIENT,
1487
- nside,
1488
- ),
1489
- wwav,
1490
- )
1491
1462
  if self.use_2D:
1492
1463
  if l_kernel**2 == 9:
1493
1464
  if self.rank == 0:
@@ -1508,58 +1479,68 @@ class FoCUS:
1508
1479
  self.barrier()
1509
1480
  if self.use_2D:
1510
1481
  tmp = np.load(
1511
- "%s/W%d_%s_%d_IDX.npy"
1512
- % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1482
+ "%s/W%d_%s_%d_IDX-SPIN%d.npy"
1483
+ % (
1484
+ self.TEMPLATE_PATH,
1485
+ l_kernel**2,
1486
+ TMPFILE_VERSION,
1487
+ nside,
1488
+ spin)
1513
1489
  )
1514
1490
  else:
1515
1491
  tmp = np.load(
1516
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1492
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1517
1493
  % (
1518
1494
  self.TEMPLATE_PATH,
1519
1495
  TMPFILE_VERSION,
1520
1496
  self.KERNELSZ**2,
1521
1497
  self.NORIENT,
1522
1498
  nside,
1499
+ spin,
1523
1500
  )
1524
1501
  )
1525
1502
  tmp2 = np.load(
1526
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1503
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"
1527
1504
  % (
1528
1505
  self.TEMPLATE_PATH,
1529
1506
  TMPFILE_VERSION,
1530
1507
  self.KERNELSZ**2,
1531
1508
  self.NORIENT,
1532
1509
  nside,
1510
+ spin,
1533
1511
  )
1534
1512
  )
1535
1513
  wr = np.load(
1536
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1514
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1537
1515
  % (
1538
1516
  self.TEMPLATE_PATH,
1539
1517
  TMPFILE_VERSION,
1540
1518
  self.KERNELSZ**2,
1541
1519
  self.NORIENT,
1542
1520
  nside,
1521
+ spin,
1543
1522
  )
1544
1523
  ).real
1545
1524
  wi = np.load(
1546
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1525
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1547
1526
  % (
1548
1527
  self.TEMPLATE_PATH,
1549
1528
  TMPFILE_VERSION,
1550
1529
  self.KERNELSZ**2,
1551
1530
  self.NORIENT,
1552
1531
  nside,
1532
+ spin,
1553
1533
  )
1554
1534
  ).imag
1555
1535
  ws = self.slope * np.load(
1556
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1536
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"
1557
1537
  % (
1558
1538
  self.TEMPLATE_PATH,
1559
1539
  TMPFILE_VERSION,
1560
1540
  self.KERNELSZ**2,
1561
1541
  self.NORIENT,
1562
1542
  nside,
1543
+ spin,
1563
1544
  )
1564
1545
  )
1565
1546
  else:
@@ -1569,21 +1550,38 @@ class FoCUS:
1569
1550
  wi = wav.imag
1570
1551
  ws = self.slope * wwav
1571
1552
 
1572
- wr = self.backend.bk_SparseTensor(
1573
- self.backend.bk_constant(tmp),
1574
- self.backend.bk_constant(self.backend.bk_cast(wr)),
1575
- dense_shape=[ncell, self.NORIENT * ncell],
1576
- )
1577
- wi = self.backend.bk_SparseTensor(
1578
- self.backend.bk_constant(tmp),
1579
- self.backend.bk_constant(self.backend.bk_cast(wi)),
1580
- dense_shape=[ncell, self.NORIENT * ncell],
1581
- )
1582
- ws = self.backend.bk_SparseTensor(
1583
- self.backend.bk_constant(tmp2),
1584
- self.backend.bk_constant(self.backend.bk_cast(ws)),
1585
- dense_shape=[ncell, ncell],
1586
- )
1553
+ if spin==0:
1554
+ wr = self.backend.bk_SparseTensor(
1555
+ self.backend.bk_constant(tmp),
1556
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1557
+ dense_shape=[ncell, self.NORIENT * ncell],
1558
+ )
1559
+ wi = self.backend.bk_SparseTensor(
1560
+ self.backend.bk_constant(tmp),
1561
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1562
+ dense_shape=[ncell, self.NORIENT * ncell],
1563
+ )
1564
+ ws = self.backend.bk_SparseTensor(
1565
+ self.backend.bk_constant(tmp2),
1566
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1567
+ dense_shape=[ncell, ncell],
1568
+ )
1569
+ else:
1570
+ wr = self.backend.bk_SparseTensor(
1571
+ self.backend.bk_constant(tmp),
1572
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1573
+ dense_shape=[2*ncell, 2*self.NORIENT * ncell],
1574
+ )
1575
+ wi = self.backend.bk_SparseTensor(
1576
+ self.backend.bk_constant(tmp),
1577
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1578
+ dense_shape=[2*ncell, 2*self.NORIENT * ncell],
1579
+ )
1580
+ ws = self.backend.bk_SparseTensor(
1581
+ self.backend.bk_constant(tmp2),
1582
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1583
+ dense_shape=[2*ncell, 2*ncell],
1584
+ )
1587
1585
 
1588
1586
  if kernel == -1:
1589
1587
  self.Idx_Neighbours[nside] = tmp
@@ -1840,10 +1838,10 @@ class FoCUS:
1840
1838
  return self.backend.bk_transpose(x, thelist)
1841
1839
 
1842
1840
  # ---------------------------------------------−---------
1843
- # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
1841
+ # Mean using mask x [n_b,....,Npix], mask[Nmask,Npix] to [n_b,Nmask,....]
1844
1842
  # if use_2D
1845
- # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
1846
- def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1843
+ # Mean using mask x [n_b,....,N_1,N_2], mask[Nmask,N_1,N_2] to [n_b,Nmask,....]
1844
+ def masked_mean(self, x, mask, rank=0, calc_var=False):
1847
1845
 
1848
1846
  # ==========================================================================
1849
1847
  # in input data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]]
@@ -1855,7 +1853,7 @@ class FoCUS:
1855
1853
  shape = list(x.shape)
1856
1854
 
1857
1855
  if not self.use_2D and not self.use_1D:
1858
- nside = int(np.sqrt(x.shape[axis] // 12))
1856
+ nside = int(np.sqrt(x.shape[-1] // 12))
1859
1857
 
1860
1858
  l_mask = mask
1861
1859
  if self.mask_norm:
@@ -1949,16 +1947,24 @@ class FoCUS:
1949
1947
  l_x = self.backend.bk_reshape(
1950
1948
  l_x[:, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :], oshape
1951
1949
  )
1952
- else:
1950
+ else:
1953
1951
  ichannel = 1
1954
- for i in range(len(shape) - 1):
1955
- ichannel *= shape[i]
1952
+ if len(shape)>1:
1953
+ ichannel = shape[0]
1954
+
1955
+ ochannel = 1
1956
+ for i in range(1,len(shape)-1):
1957
+ ochannel *= shape[i]
1956
1958
 
1957
- l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[-1]])
1959
+ l_x = self.backend.bk_reshape(x, [ichannel,1,ochannel,shape[-1]])
1958
1960
 
1959
- # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,1,...,NORIENT[,NORIENT],X[,Y]]
1961
+ # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,...,1,NORIENT[,NORIENT],X[,Y]]
1960
1962
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1961
- l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask, 0), 0)
1963
+
1964
+ if self.use_2D:
1965
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-3)
1966
+ else:
1967
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-2)
1962
1968
 
1963
1969
  if l_x.dtype == self.all_cbk_type:
1964
1970
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
@@ -1989,6 +1995,8 @@ class FoCUS:
1989
1995
 
1990
1996
  if len(x.shape[axis:-2]) > 0:
1991
1997
  oshape = oshape + list(x.shape[axis:-2])
1998
+ else:
1999
+ oshape = oshape + [1]
1992
2000
 
1993
2001
  if calc_var:
1994
2002
  if self.backend.bk_is_complex(vtmp):
@@ -2018,7 +2026,7 @@ class FoCUS:
2018
2026
  elif self.use_1D:
2019
2027
  mtmp = l_mask
2020
2028
  vtmp = l_x
2021
- v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1)
2029
+ v1 = self.backend.bk_reduce_sum(l_mask[1,:,...,:] * vtmp, axis=-1)
2022
2030
  v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
2023
2031
  vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
2024
2032
 
@@ -2027,6 +2035,8 @@ class FoCUS:
2027
2035
  oshape = [x.shape[0]] + [mask.shape[0]]
2028
2036
  if len(x.shape) > 1:
2029
2037
  oshape = oshape + list(x.shape[1:-1])
2038
+ else:
2039
+ oshape = oshape + [1]
2030
2040
 
2031
2041
  if calc_var:
2032
2042
  if self.backend.bk_is_complex(vtmp):
@@ -2060,13 +2070,16 @@ class FoCUS:
2060
2070
  res = v1 / vh
2061
2071
 
2062
2072
  oshape = []
2063
- if axis > 0:
2073
+ if len(shape) > 1:
2064
2074
  oshape = [x.shape[0]]
2065
2075
  else:
2066
2076
  oshape = [1]
2077
+
2067
2078
  oshape = oshape + [mask.shape[0]]
2068
- if axis > 1:
2069
- oshape = oshape + list(x.shape[1:-1])
2079
+ if len(shape) > 2:
2080
+ oshape = oshape + shape[1:-1]
2081
+ else:
2082
+ oshape = oshape + [1]
2070
2083
 
2071
2084
  if calc_var:
2072
2085
  if self.backend.bk_is_complex(l_x):
@@ -2220,7 +2233,7 @@ class FoCUS:
2220
2233
  return self.backend.bk_reduce_sum(r)
2221
2234
 
2222
2235
  # ---------------------------------------------−---------
2223
- def convol(self, in_image, axis=0, cell_ids=None, nside=None):
2236
+ def convol(self, in_image, axis=0, cell_ids=None, nside=None, spin=0):
2224
2237
 
2225
2238
  image = self.backend.bk_cast(in_image)
2226
2239
 
@@ -2283,77 +2296,37 @@ class FoCUS:
2283
2296
 
2284
2297
  else:
2285
2298
  ishape = list(image.shape)
2286
- """
2287
- if cell_ids is not None:
2288
- if cell_ids.shape[0] not in self.padding_conv:
2289
- print(image.shape,cell_ids.shape)
2290
- import healpix_convolution as hc
2291
- from xdggs.healpix import HealpixInfo
2299
+ if nside is None:
2300
+ nside = int(np.sqrt(image.shape[-1] // 12))
2292
2301
 
2293
- res = self.backend.bk_zeros(
2294
- ishape[0:-1] + [self.NORIENT]+ishape[-1], dtype=self.backend.all_cbk_type
2295
- )
2302
+ if spin==0:
2303
+ if nside not in self.Idx_Neighbours:
2304
+ if self.InitWave is None:
2305
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2306
+ else:
2307
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2296
2308
 
2297
- grid_info = HealpixInfo(
2298
- level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2299
- )
2309
+ self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2310
+ self.ww_Real[nside] = wr
2311
+ self.ww_Imag[nside] = wi
2312
+ self.w_smooth[nside] = ws
2300
2313
 
2301
- for k in range(self.NORIENT):
2302
- kernelR, kernelI = hc.kernels.wavelet_kernel(
2303
- cell_ids, grid_info=grid_info, orientation=k, is_torch=True
2304
- )
2305
- self.kernelR_conv[(cell_ids.shape[0], k)] = kernelR.to(
2306
- self.backend.all_bk_type
2307
- ).to(image.device)
2308
- self.kernelI_conv[(cell_ids.shape[0], k)] = kernelI.to(
2309
- self.backend.all_bk_type
2310
- ).to(image.device)
2311
- self.padding_conv[(cell_ids.shape[0], k)] = hc.pad(
2312
- cell_ids,
2313
- grid_info=grid_info,
2314
- ring=5 // 2, # wavelet kernel_size=5 is hard coded
2315
- mode="mean",
2316
- constant_value=0,
2317
- )
2318
-
2319
- for k in range(self.NORIENT):
2320
-
2321
- kernelR = self.kernelR_conv[(cell_ids.shape[0], k)]
2322
- kernelI = self.kernelI_conv[(cell_ids.shape[0], k)]
2323
- padding = self.padding_conv[(cell_ids.shape[0], k)]
2324
- if len(ishape) == 2:
2325
- for l in range(ishape[0]):
2326
- padded_data = padding.apply(image[l], is_torch=True)
2327
- res[l, :, k] = kernelR.matmul(
2328
- padded_data
2329
- ) + 1j * kernelI.matmul(padded_data)
2314
+ l_ww_real = self.ww_Real[nside]
2315
+ l_ww_imag = self.ww_Imag[nside]
2316
+ else:
2317
+ if (spin,nside) not in self.Idx_Neighbours:
2318
+ if self.InitWave is None:
2319
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2330
2320
  else:
2331
- for l in range(ishape[0]):
2332
- for k2 in range(ishape[2]):
2333
- padded_data = padding.apply(
2334
- image[l, :, k2], is_torch=True
2335
- )
2336
- res[l, :, k2, k] = kernelR.matmul(
2337
- padded_data
2338
- ) + 1j * kernelI.matmul(padded_data)
2339
- return res
2340
- """
2341
- if nside is None:
2342
- nside = int(np.sqrt(image.shape[-1] // 12))
2343
-
2344
- if self.Idx_Neighbours[nside] is None:
2345
- if self.InitWave is None:
2346
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2347
- else:
2348
- wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2321
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2349
2322
 
2350
- self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2351
- self.ww_Real[nside] = wr
2352
- self.ww_Imag[nside] = wi
2353
- self.w_smooth[nside] = ws
2323
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2324
+ self.ww_Real[(spin,nside)] = wr
2325
+ self.ww_Imag[(spin,nside)] = wi
2326
+ self.w_smooth[(spin,nside)] = ws
2354
2327
 
2355
- l_ww_real = self.ww_Real[nside]
2356
- l_ww_imag = self.ww_Imag[nside]
2328
+ l_ww_real = self.ww_Real[(spin,nside)]
2329
+ l_ww_imag = self.ww_Imag[(spin,nside)]
2357
2330
 
2358
2331
  # always convolve the last dimension
2359
2332
 
@@ -2361,9 +2334,14 @@ class FoCUS:
2361
2334
  if len(ishape) > 1:
2362
2335
  for k in range(len(ishape) - 1):
2363
2336
  ndata = ndata * ishape[k]
2364
- tim = self.backend.bk_reshape(
2365
- self.backend.bk_cast(image), [ndata, ishape[-1]]
2366
- )
2337
+ if spin>0:
2338
+ tim = self.backend.bk_reshape(
2339
+ self.backend.bk_cast(image), [ndata//2,2*ishape[-1]]
2340
+ )
2341
+ else:
2342
+ tim = self.backend.bk_reshape(
2343
+ self.backend.bk_cast(image), [ndata, ishape[-1]]
2344
+ )
2367
2345
 
2368
2346
  if tim.dtype == self.all_cbk_type:
2369
2347
  rr1 = self.backend.bk_reshape(
@@ -2405,17 +2383,27 @@ class FoCUS:
2405
2383
  [ndata, self.NORIENT, ishape[-1]],
2406
2384
  )
2407
2385
  res = self.backend.bk_complex(rr, ii)
2408
- if len(ishape) > 1:
2409
- return self.backend.bk_reshape(
2410
- res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2411
- )
2386
+
2387
+ if spin==0:
2388
+ if len(ishape) > 1:
2389
+ return self.backend.bk_reshape(
2390
+ res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2391
+ )
2392
+ else:
2393
+ return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2412
2394
  else:
2413
- return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2395
+ if len(ishape) > 2:
2396
+ return self.backend.bk_reshape(
2397
+ res, ishape[0:-2] + [2,self.NORIENT, ishape[-1]]
2398
+ )
2399
+ else:
2400
+ return self.backend.bk_reshape(res, [2,self.NORIENT, ishape[-1]])
2401
+
2414
2402
 
2415
2403
  return res
2416
2404
 
2417
2405
  # ---------------------------------------------−---------
2418
- def smooth(self, in_image, axis=0, cell_ids=None, nside=None):
2406
+ def smooth(self, in_image, axis=0, cell_ids=None, nside=None, spin=0):
2419
2407
 
2420
2408
  image = self.backend.bk_cast(in_image)
2421
2409
 
@@ -2475,64 +2463,35 @@ class FoCUS:
2475
2463
  else:
2476
2464
 
2477
2465
  ishape = list(image.shape)
2478
- """
2479
- if cell_ids is not None:
2480
- if cell_ids.shape[0] not in self.padding_smooth:
2481
- import healpix_convolution as hc
2482
- from xdggs.healpix import HealpixInfo
2483
-
2484
- grid_info = HealpixInfo(
2485
- level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2486
- )
2487
-
2488
- kernel = hc.kernels.wavelet_smooth_kernel(
2489
- cell_ids, grid_info=grid_info, is_torch=True
2490
- )
2491
-
2492
- self.kernel_smooth[cell_ids.shape[0]] = kernel.to(
2493
- self.backend.all_bk_type
2494
- ).to(image.device)
2495
-
2496
- self.padding_smooth[cell_ids.shape[0]] = hc.pad(
2497
- cell_ids,
2498
- grid_info=grid_info,
2499
- ring=5 // 2, # wavelet kernel_size=5 is hard coded
2500
- mode="mean",
2501
- constant_value=0,
2502
- )
2503
-
2504
- kernel = self.kernel_smooth[cell_ids.shape[0]]
2505
- padding = self.padding_smooth[cell_ids.shape[0]]
2506
-
2507
- res = self.backend.bk_zeros(ishape, dtype=self.backend.all_cbk_type)
2508
-
2509
- if len(ishape) == 2:
2510
- for l in range(ishape[0]):
2511
- padded_data = padding.apply(image[l], is_torch=True)
2512
- res[l] = kernel.matmul(padded_data)
2513
- else:
2514
- for l in range(ishape[0]):
2515
- for k2 in range(ishape[2]):
2516
- padded_data = padding.apply(image[l, :, k2], is_torch=True)
2517
- res[l, :, k2] = kernel.matmul(padded_data)
2518
- return res
2519
- """
2466
+
2520
2467
  if nside is None:
2521
2468
  nside = int(np.sqrt(image.shape[-1] // 12))
2522
2469
 
2523
- if self.Idx_Neighbours[nside] is None:
2470
+ if spin==0:
2471
+ if nside not in self.Idx_Neighbours:
2472
+ if self.InitWave is None:
2473
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2474
+ else:
2475
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2524
2476
 
2525
- if self.InitWave is None:
2526
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2527
- else:
2528
- wr, wi, ws, widx = self.InitWave(self, nside, cell_ids=cell_ids)
2477
+ self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2478
+ self.ww_Real[nside] = wr
2479
+ self.ww_Imag[nside] = wi
2480
+ self.w_smooth[nside] = ws
2529
2481
 
2530
- self.Idx_Neighbours[nside] = 1
2531
- self.ww_Real[nside] = wr
2532
- self.ww_Imag[nside] = wi
2533
- self.w_smooth[nside] = ws
2482
+ l_w_smooth = self.w_smooth[nside]
2483
+ else:
2484
+ if (spin,nside) not in self.Idx_Neighbours:
2485
+ if self.InitWave is None:
2486
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2487
+ else:
2488
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2534
2489
 
2535
- l_w_smooth = self.w_smooth[nside]
2490
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2491
+ self.ww_Real[(spin,nside)] = wr
2492
+ self.ww_Imag[(spin,nside)] = wi
2493
+ self.w_smooth[(spin,nside)] = ws
2494
+ l_w_smooth = self.w_smooth[(spin,nside)]
2536
2495
 
2537
2496
  odata = 1
2538
2497
  for k in range(0, len(ishape) - 1):