foscat 2025.5.2__py3-none-any.whl → 2025.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -10,32 +10,32 @@ TMPFILE_VERSION = "V5_0"
10
10
 
11
11
  class FoCUS:
12
12
  def __init__(
13
- self,
14
- NORIENT=4,
15
- LAMBDA=1.2,
16
- KERNELSZ=3,
17
- slope=1.0,
18
- all_type="float32",
19
- nstep_max=16,
20
- padding="SAME",
21
- gpupos=0,
22
- mask_thres=None,
23
- mask_norm=False,
24
- isMPI=False,
25
- TEMPLATE_PATH="data",
26
- BACKEND="tensorflow",
27
- use_2D=False,
28
- use_1D=False,
29
- return_data=False,
30
- JmaxDelta=0,
31
- DODIV=False,
32
- InitWave=None,
33
- silent=True,
34
- mpi_size=1,
35
- mpi_rank=0,
13
+ self,
14
+ NORIENT=4,
15
+ LAMBDA=1.2,
16
+ KERNELSZ=3,
17
+ slope=1.0,
18
+ all_type="float32",
19
+ nstep_max=20,
20
+ padding="SAME",
21
+ gpupos=0,
22
+ mask_thres=None,
23
+ mask_norm=False,
24
+ isMPI=False,
25
+ TEMPLATE_PATH="data",
26
+ BACKEND="torch",
27
+ use_2D=False,
28
+ use_1D=False,
29
+ return_data=False,
30
+ JmaxDelta=0,
31
+ DODIV=False,
32
+ InitWave=None,
33
+ silent=True,
34
+ mpi_size=1,
35
+ mpi_rank=0
36
36
  ):
37
37
 
38
- self.__version__ = "2025.05.2"
38
+ self.__version__ = "2025.06.3"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -176,6 +176,9 @@ class FoCUS:
176
176
  self.Y_CNN = {}
177
177
  self.Z_CNN = {}
178
178
 
179
+ self.Idx_CNN = {}
180
+ self.Idx_WCNN = {}
181
+
179
182
  self.filters_set = {}
180
183
  self.edge_masks = {}
181
184
 
@@ -366,39 +369,8 @@ class FoCUS:
366
369
  self.pix_interp_val = {}
367
370
  self.weight_interp_val = {}
368
371
  self.ring2nest = {}
369
- self.nest2R = {}
370
- self.nest2R1 = {}
371
- self.nest2R2 = {}
372
- self.nest2R3 = {}
373
- self.nest2R4 = {}
374
- self.inv_nest2R = {}
375
- self.remove_border = {}
376
-
377
372
  self.ampnorm = {}
378
373
 
379
- for i in range(nstep_max):
380
- lout = 2**i
381
- self.pix_interp_val[lout] = {}
382
- self.weight_interp_val[lout] = {}
383
- for j in range(nstep_max):
384
- lout2 = 2**j
385
- self.pix_interp_val[lout][lout2] = None
386
- self.weight_interp_val[lout][lout2] = None
387
- self.ring2nest[lout] = None
388
- self.Idx_Neighbours[lout] = None
389
- self.nest2R[lout] = None
390
- self.nest2R1[lout] = None
391
- self.nest2R2[lout] = None
392
- self.nest2R3[lout] = None
393
- self.nest2R4[lout] = None
394
- self.inv_nest2R[lout] = None
395
- self.remove_border[lout] = None
396
- self.ww_CNN_Transpose[lout] = None
397
- self.ww_CNN[lout] = None
398
- self.X_CNN[lout] = None
399
- self.Y_CNN[lout] = None
400
- self.Z_CNN[lout] = None
401
-
402
374
  self.loss = {}
403
375
 
404
376
  def get_type(self):
@@ -500,210 +472,26 @@ class FoCUS:
500
472
  return indices, weights, xc, yc, zc
501
473
 
502
474
  # ---------------------------------------------−---------
503
- def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
504
- nside = int(np.sqrt(im.shape[1] // 12))
505
- l_kernel = self.KERNELSZ * self.KERNELSZ
506
- norient = 32
507
- w = np.zeros([l_kernel, 1, 2 * norient])
508
- ca = np.cos(np.arange(norient) / norient * np.pi)
509
- sa = np.sin(np.arange(norient) / norient * np.pi)
510
- stat = np.zeros([12 * nside**2, norient])
511
-
512
- if self.ww_CNN[nside] is None:
513
- self.init_CNN_index(nside, transpose=False)
514
-
515
- y = self.Y_CNN[nside]
516
- z = self.Z_CNN[nside]
517
-
518
- for k in range(norient):
519
- w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
520
- nside * (y * ca[k] + z * sa[k]) * np.pi / 2
521
- )
522
- w[:, 0, k + norient] = np.exp(
523
- -0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
524
- ) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
525
- w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
526
- w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
527
-
528
- for k in range(im.shape[0]):
529
- tmp = im[k].reshape(12 * nside**2, 1)
530
- im2 = self.healpix_layer(tmp, w)
531
- stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
532
-
533
- rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
534
-
535
- indices, weights, x, y, z = self.calc_indices_convol(
536
- nside, 9, rotation=rotation
537
- )
538
-
539
- return indices, weights
540
-
541
- def init_CNN_index(self, nside, transpose=False):
542
- l_kernel = int(self.KERNELSZ * self.KERNELSZ)
543
- try:
544
- indices = np.load(
545
- "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
546
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
547
- )
548
- weights = np.load(
549
- "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
550
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
551
- )
552
- xc = np.load(
553
- "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
554
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
555
- )
556
- yc = np.load(
557
- "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
558
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
559
- )
560
- zc = np.load(
561
- "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
562
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
563
- )
564
- except:
565
- indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
566
- np.save(
567
- "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
568
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
569
- indices,
570
- )
571
- np.save(
572
- "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
573
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
574
- weights,
575
- )
576
- np.save(
577
- "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
578
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
579
- xc,
580
- )
581
- np.save(
582
- "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
583
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
584
- yc,
585
- )
586
- np.save(
587
- "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
588
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
589
- zc,
590
- )
591
- if not self.silent:
592
- print(
593
- "Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
594
- % (
595
- self.TEMPLATE_PATH,
596
- TMPFILE_VERSION,
597
- l_kernel,
598
- self.NORIENT,
599
- nside,
600
- )
601
- )
602
-
603
- self.X_CNN[nside] = xc
604
- self.Y_CNN[nside] = yc
605
- self.Z_CNN[nside] = zc
606
- self.ww_CNN[nside] = self.backend.bk_SparseTensor(
607
- indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
608
- )
609
-
610
- # ---------------------------------------------−---------
611
- def healpix_layer_coord(self, im, axis=0):
612
- nside = int(np.sqrt(im.shape[axis] // 12))
613
- if self.ww_CNN[nside] is None:
614
- self.init_CNN_index(nside)
615
- return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
616
-
617
475
  # ---------------------------------------------−---------
618
- def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
619
- nside = int(np.sqrt(im.shape[axis] // 12))
620
-
621
- if im.shape[1 + axis] != ww.shape[1]:
622
- if not self.silent:
623
- print("Weights channels should be equal to the input image channels")
624
- return -1
625
- if axis == 1:
626
- results = []
627
-
628
- for k in range(im.shape[0]):
629
-
630
- tmp = self.healpix_layer(
631
- im[k], ww, indices=indices, weights=weights, axis=0
632
- )
633
- tmp = self.backend.bk_reshape(
634
- self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
635
- )
636
-
637
- results.append(tmp)
638
-
639
- return self.backend.bk_stack(results, axis=0)
640
- else:
641
- tmp = self.healpix_layer(
642
- im, ww, indices=indices, weights=weights, axis=axis
643
- )
644
-
645
- return self.up_grade(tmp, 2 * nside)
646
-
647
- # ---------------------------------------------−---------
648
- # ---------------------------------------------−---------
649
- def healpix_layer(self, im, ww, indices=None, weights=None, axis=0):
650
- nside = int(np.sqrt(im.shape[axis] // 12))
651
- l_kernel = self.KERNELSZ * self.KERNELSZ
652
-
653
- if im.shape[1 + axis] != ww.shape[1]:
654
- if not self.silent:
655
- print("Weights channels should be equal to the input image channels")
656
- return -1
657
-
476
+ def healpix_layer(self, im, ww, indices=None, weights=None):
477
+ #ww [N_i,NORIENT,KERNELSZ*KERNELSZ//2,N_o,NORIENT]
478
+ #im [N_batch,N_i, NORIENT,N]
479
+ nside=int(np.sqrt(im.shape[-1]//12))
658
480
  if indices is None:
659
- if self.ww_CNN[nside] is None:
660
- self.init_CNN_index(nside, transpose=False)
661
- mat = self.ww_CNN[nside]
662
- else:
663
- if weights is None:
664
- print(
665
- "healpix_layer : If indices is not none weights should be specify"
666
- )
667
- return 0
668
-
669
- mat = self.backend.bk_SparseTensor(
670
- indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
671
- )
672
-
673
- if axis == 1:
674
- results = []
675
-
676
- for k in range(im.shape[0]):
677
-
678
- tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
679
-
680
- density = self.backend.bk_reshape(
681
- tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
682
- )
683
-
684
- density = self.backend.bk_matmul(
685
- density,
686
- self.backend.bk_reshape(
687
- ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
688
- ),
689
- )
690
-
691
- results.append(
692
- self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
693
- )
694
-
695
- return self.backend.bk_stack(results, axis=0)
696
- else:
697
- tmp = self.backend.bk_sparse_dense_matmul(mat, im)
698
-
699
- density = self.backend.bk_reshape(
700
- tmp, [12 * nside * nside, l_kernel * im.shape[1]]
701
- )
702
-
703
- return self.backend.bk_matmul(
704
- density,
705
- self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
706
- )
481
+ if (nside,self.NORIENT,self.KERNELSZ) not in self.ww_CNN:
482
+ self.init_index_cnn(nside,self.NORIENT)
483
+ indices = self.Idx_CNN[(nside,self.NORIENT,self.KERNELSZ)]
484
+ mat = self.Idx_WCNN[(nside,self.NORIENT,self.KERNELSZ)]
485
+
486
+ wim = self.backend.bk_gather(im,indices.flatten(),axis=3) #[N_batch,N_i,NORIENT,K*(K+1),N_o,NORIENT,N,N_w]
487
+
488
+ wim = self.backend.bk_reshape(wim,[im.shape[0],im.shape[1],im.shape[2]]+list(indices.shape))*mat[None,...]
489
+ #win is [N_batch,N_i, NORIENT,K*(K+1),1, NORIENT,N,N_w]
490
+ #ww is [1, N_i, NORIENT,K*(K+1),N_o,NORIENT]
491
+ wim = self.backend.bk_reduce_sum(wim[:,:,:,:,None]*ww[None,:,:,:,:,:,None,None],[1,2,3])
492
+
493
+ wim = self.backend.bk_reduce_sum(wim,-1)
494
+ return self.backend.bk_reshape(wim,[im.shape[0],ww.shape[3],ww.shape[4],im.shape[-1]])
707
495
 
708
496
  # ---------------------------------------------−---------
709
497
 
@@ -724,7 +512,7 @@ class FoCUS:
724
512
  def toring(self, image, axis=0):
725
513
  lout = int(np.sqrt(image.shape[axis] // 12))
726
514
 
727
- if self.ring2nest[lout] is None:
515
+ if lout not in self.ring2nest:
728
516
  self.ring2nest[lout] = hp.ring2nest(lout, np.arange(12 * lout**2))
729
517
 
730
518
  return image.numpy()[self.ring2nest[lout]]
@@ -820,30 +608,10 @@ class FoCUS:
820
608
  if cell_ids is not None:
821
609
  sim, new_cell_ids = self.backend.binned_mean(im, cell_ids)
822
610
  return sim, new_cell_ids
823
-
824
- lout = int(np.sqrt(shape[axis] // 12))
825
- if im.__class__ == np.zeros([0]).__class__:
826
- oshape = np.zeros([len(shape) + 1], dtype="int")
827
- if axis > 0:
828
- oshape[0:axis] = shape[0:axis]
829
- oshape[axis] = 12 * lout * lout // 4
830
- oshape[axis + 1] = 4
831
- if len(shape) > axis:
832
- oshape[axis + 2 :] = shape[axis + 1 :]
833
- else:
834
- if axis > 0:
835
- oshape = shape[0:axis] + [12 * lout * lout // 4, 4]
836
- else:
837
- oshape = [12 * lout * lout // 4, 4]
838
- if len(shape) > axis:
839
- oshape = oshape + shape[axis + 1 :]
840
-
841
- return (
842
- self.backend.bk_reduce_mean(
843
- self.backend.bk_reshape(im, oshape), axis=axis + 1
844
- ),
845
- None,
846
- )
611
+
612
+ return self.backend.bk_reduce_mean(
613
+ self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
614
+ ),None
847
615
 
848
616
  # --------------------------------------------------------
849
617
  def up_grade(self, im, nout, axis=0, nouty=None):
@@ -954,9 +722,9 @@ class FoCUS:
954
722
 
955
723
  else:
956
724
 
957
- lout = int(np.sqrt(im.shape[axis] // 12))
725
+ lout = int(np.sqrt(im.shape[-1] // 12))
958
726
 
959
- if self.pix_interp_val[lout][nout] is None:
727
+ if (lout,nout) not in self.pix_interp_val:
960
728
  if not self.silent:
961
729
  print("compute lout nout", lout, nout)
962
730
  th, ph = hp.pix2ang(
@@ -975,104 +743,51 @@ class FoCUS:
975
743
  t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
976
744
  p = p.flatten()[t]
977
745
  w = w.flatten()[t]
978
- indice[:, 0] = np.repeat(np.arange(12 * nout**2), 4)
979
- indice[:, 1] = p
746
+ indice[:, 1] = np.repeat(np.arange(12 * nout**2), 4)
747
+ indice[:, 0] = p
980
748
 
981
- self.pix_interp_val[lout][nout] = 1
982
- self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(
749
+ self.pix_interp_val[(lout,nout)] = 1
750
+ self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
983
751
  self.backend.bk_constant(indice),
984
752
  self.backend.bk_constant(self.backend.bk_cast(w.flatten())),
985
- dense_shape=[12 * nout**2, 12 * lout**2],
753
+ dense_shape=[12 * lout**2,12 * nout**2],
986
754
  )
987
755
 
988
756
  if lout == nout:
989
757
  imout = im
990
758
  else:
991
-
759
+ # work only on the last column
760
+
992
761
  ishape = list(im.shape)
993
- odata = 1
994
- for k in range(axis + 1, len(ishape)):
995
- odata = odata * ishape[k]
996
762
 
997
763
  ndata = 1
998
- for k in range(axis):
764
+ for k in range(len(ishape)-1):
999
765
  ndata = ndata * ishape[k]
1000
766
  tim = self.backend.bk_reshape(
1001
- self.backend.bk_cast(im), [ndata, 12 * lout**2, odata]
767
+ self.backend.bk_cast(im), [ndata, 12 * lout**2]
1002
768
  )
1003
769
  if tim.dtype == self.all_cbk_type:
1004
- rr = self.backend.bk_reshape(
1005
- self.backend.bk_sparse_dense_matmul(
1006
- self.weight_interp_val[lout][nout],
1007
- self.backend.bk_real(tim[0]),
1008
- ),
1009
- [1, 12 * nout**2, odata],
1010
- )
1011
- ii = self.backend.bk_reshape(
1012
- self.backend.bk_sparse_dense_matmul(
1013
- self.weight_interp_val[lout][nout],
1014
- self.backend.bk_imag(tim[0]),
1015
- ),
1016
- [1, 12 * nout**2, odata],
1017
- )
770
+ rr = self.backend.bk_sparse_dense_matmul(
771
+ self.backend.bk_real(tim),
772
+ self.weight_interp_val[(lout,nout)],
773
+ )
774
+ ii = self.backend.bk_sparse_dense_matmul(
775
+ self.backend.bk_real(tim),
776
+ self.weight_interp_val[(lout,nout)],
777
+ )
1018
778
  imout = self.backend.bk_complex(rr, ii)
1019
779
  else:
1020
- imout = self.backend.bk_reshape(
1021
- self.backend.bk_sparse_dense_matmul(
1022
- self.weight_interp_val[lout][nout], tim[0]
1023
- ),
1024
- [1, 12 * nout**2, odata],
780
+ imout = self.backend.bk_sparse_dense_matmul(
781
+ tim,
782
+ self.weight_interp_val[(lout,nout)],
1025
783
  )
1026
784
 
1027
- for k in range(1, ndata):
1028
- if tim.dtype == self.all_cbk_type:
1029
- rr = self.backend.bk_reshape(
1030
- self.backend.bk_sparse_dense_matmul(
1031
- self.weight_interp_val[lout][nout],
1032
- self.backend.bk_real(tim[k]),
1033
- ),
1034
- [1, 12 * nout**2, odata],
1035
- )
1036
- ii = self.backend.bk_reshape(
1037
- self.backend.bk_sparse_dense_matmul(
1038
- self.weight_interp_val[lout][nout],
1039
- self.backend.bk_imag(tim[k]),
1040
- ),
1041
- [1, 12 * nout**2, odata],
1042
- )
1043
- imout = self.backend.bk_concat(
1044
- [imout, self.backend.bk_complex(rr, ii)], 0
1045
- )
1046
- else:
1047
- imout = self.backend.bk_concat(
1048
- [
1049
- imout,
1050
- self.backend.bk_reshape(
1051
- self.backend.bk_sparse_dense_matmul(
1052
- self.weight_interp_val[lout][nout], tim[k]
1053
- ),
1054
- [1, 12 * nout**2, odata],
1055
- ),
1056
- ],
1057
- 0,
1058
- )
1059
-
1060
- if axis == 0:
1061
- if len(ishape) == 1:
1062
- return self.backend.bk_reshape(imout, [12 * nout**2])
1063
- else:
1064
- return self.backend.bk_reshape(
1065
- imout, [12 * nout**2] + ishape[axis + 1 :]
1066
- )
785
+ if len(ishape) == 1:
786
+ return self.backend.bk_reshape(imout, [12 * nout**2])
1067
787
  else:
1068
- if len(ishape) == axis + 1:
1069
- return self.backend.bk_reshape(
1070
- imout, ishape[0:axis] + [12 * nout**2]
1071
- )
1072
- else:
1073
- return self.backend.bk_reshape(
1074
- imout, ishape[0:axis] + [12 * nout**2] + ishape[axis + 1 :]
1075
- )
788
+ return self.backend.bk_reshape(
789
+ imout, ishape[0:axis-1]+[12 * nout**2]
790
+ )
1076
791
  return imout
1077
792
 
1078
793
  # --------------------------------------------------------
@@ -1345,7 +1060,7 @@ class FoCUS:
1345
1060
  return res
1346
1061
 
1347
1062
  # ---------------------------------------------−---------
1348
- def init_index(self, nside, kernel=-1, cell_ids=None):
1063
+ def init_index(self, nside, kernel=-1, cell_ids=None, spin=0):
1349
1064
 
1350
1065
  if kernel == -1:
1351
1066
  l_kernel = self.KERNELSZ
@@ -1378,297 +1093,372 @@ class FoCUS:
1378
1093
 
1379
1094
  else:
1380
1095
  tmp = np.load(
1381
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1096
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1382
1097
  % (
1383
1098
  self.TEMPLATE_PATH,
1384
1099
  TMPFILE_VERSION,
1385
1100
  l_kernel**2,
1386
1101
  self.NORIENT,
1387
- nside, # if cell_ids computes the index
1102
+ nside,spin # if cell_ids computes the index
1388
1103
  )
1389
1104
  )
1105
+
1390
1106
  except:
1391
1107
  if not self.use_2D:
1392
-
1393
- if l_kernel == 5:
1394
- pw = 0.5
1395
- pw2 = 0.5
1396
- threshold = 2e-4
1397
-
1398
- elif l_kernel == 3:
1399
- pw = 1.0 / np.sqrt(2)
1400
- pw2 = 1.0
1401
- threshold = 1e-3
1402
-
1403
- elif l_kernel == 7:
1404
- pw = 0.5
1405
- pw2 = 0.25
1406
- threshold = 4e-5
1407
-
1408
- if cell_ids is not None:
1409
- if not isinstance(cell_ids, np.ndarray):
1410
- cell_ids = self.backend.to_numpy(cell_ids)
1411
- th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1412
- x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1413
-
1414
- t, p = hp.pix2ang(nside, cell_ids, nest=True)
1415
- phi = [p[k] / np.pi * 180 for k in range(ncell)]
1416
- thi = [t[k] / np.pi * 180 for k in range(ncell)]
1417
-
1418
- indice2 = np.zeros([ncell * 64, 2], dtype="int")
1419
- indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1420
- wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1421
- wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1422
-
1108
+ if spin!=0:
1109
+ try:
1110
+ tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.npy"% (
1111
+ self.TEMPLATE_PATH,
1112
+ self.TMPFILE_VERSION,
1113
+ self.KERNELSZ**2,
1114
+ self.NORIENT,
1115
+ nside)
1116
+ )
1117
+ except:
1118
+ self.init_index(nside, kernel=kernel, spin=0)
1119
+
1120
+ tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.npy"% (
1121
+ self.TEMPLATE_PATH,
1122
+ self.TMPFILE_VERSION,
1123
+ self.KERNELSZ**2,
1124
+ self.NORIENT,
1125
+ nside)
1126
+ )
1127
+
1128
+ tmpw = np.load("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN0.npy"% (
1129
+ self.TEMPLATE_PATH,
1130
+ self.TMPFILE_VERSION,
1131
+ self.KERNELSZ**2,
1132
+ self.NORIENT,
1133
+ nside,
1134
+ )
1135
+ )
1136
+
1137
+ nn=self.NORIENT*12*nside**2
1138
+ idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1139
+ idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
1140
+ idxEB[3*tmp.shape[0]:,0]+=12*nside**2
1141
+ idxEB[2*tmp.shape[0]:,1]+=nn
1142
+
1143
+ tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1144
+
1145
+ for k in range(self.NORIENT*12*nside**2):
1146
+ if k%(nside**2)==0:
1147
+ print('Init index 1/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
1148
+ self.NORIENT*12,
1149
+ nside,
1150
+ self.KERNELSZ))
1151
+ idx=np.where(tmp[:,1]==k)[0]
1152
+
1153
+ im=np.zeros([12*nside**2])
1154
+ im[tmp[idx,0]]=tmpw[idx].real
1155
+ almR=hp.map2alm(hp.reorder(im,n2r=True))
1156
+ im[tmp[idx,0]]=tmpw[idx].imag
1157
+ almI=hp.map2alm(hp.reorder(im,n2r=True))
1158
+
1159
+ i,q,u=hp.alm2map_spin([almR,almR*0,0*almR],nside,spin,3*nside-1)
1160
+ i2,q2,u2=hp.alm2map_spin([almI,0*almI,0*almI],nside,spin,3*nside-1)
1161
+
1162
+ tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1163
+ tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1164
+
1165
+ i,q,u=hp.alm2map_spin([0*almR,almR,0*almR],nside,spin,3*nside-1)
1166
+ i2,q2,u2=hp.alm2map_spin([0*almI,almI,0*almI],nside,spin,3*nside-1)
1167
+
1168
+ tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1169
+ tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1170
+
1171
+
1172
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"% (self.TEMPLATE_PATH,
1173
+ self.TMPFILE_VERSION,
1174
+ self.KERNELSZ**2,
1175
+ self.NORIENT,
1176
+ nside,
1177
+ spin
1178
+ ),
1179
+ idxEB
1180
+ )
1181
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"% (self.TEMPLATE_PATH,
1182
+ self.TMPFILE_VERSION,
1183
+ self.KERNELSZ**2,
1184
+ self.NORIENT,
1185
+ nside,
1186
+ spin,
1187
+ ),
1188
+ tmpEB
1189
+ )
1190
+ tmp = np.load("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN0.npy"%
1191
+ (
1192
+ self.TEMPLATE_PATH,
1193
+ self.TMPFILE_VERSION,
1194
+ self.KERNELSZ**2,
1195
+ self.NORIENT,
1196
+ nside,
1197
+ )
1198
+ )
1199
+ tmpw = np.load("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN0.npy"%
1200
+ (
1201
+ self.TEMPLATE_PATH,
1202
+ self.TMPFILE_VERSION,
1203
+ self.KERNELSZ**2,
1204
+ self.NORIENT,
1205
+ nside,
1206
+ )
1207
+ )
1208
+
1209
+ nn=12*nside**2
1210
+ idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1211
+ idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
1212
+ idxEB[3*tmp.shape[0]:,0]+=12*nside**2
1213
+ idxEB[2*tmp.shape[0]:,1]+=nn
1214
+
1215
+ tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1216
+
1217
+ for k in range(12*nside**2):
1218
+ if k%(nside**2)==0:
1219
+ print('Init index 2/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
1220
+ 12,
1221
+ nside,
1222
+ self.KERNELSZ))
1223
+ idx=np.where(tmp[:,1]==k)[0]
1224
+
1225
+ im=np.zeros([12*nside**2])
1226
+ im[tmp[idx,0]]=tmpw[idx].real
1227
+ almR=hp.map2alm(hp.reorder(im,n2r=True))
1228
+ im[tmp[idx,0]]=tmpw[idx].imag
1229
+ almI=hp.map2alm(hp.reorder(im,n2r=True))
1230
+
1231
+ i,q,u=hp.alm2map_spin([almR,almR*0,0*almR],nside,spin,3*nside-1)
1232
+ i2,q2,u2=hp.alm2map_spin([almI,0*almI,0*almI],nside,spin,3*nside-1)
1233
+
1234
+ tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1235
+ tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1236
+
1237
+ i,q,u=hp.alm2map_spin([0*almR,almR,0*almR],nside,spin,3*nside-1)
1238
+ i2,q2,u2=hp.alm2map_spin([0*almI,almI,0*almI],nside,spin,3*nside-1)
1239
+
1240
+ tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1241
+ tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1242
+
1243
+
1244
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"%
1245
+ (
1246
+ self.TEMPLATE_PATH,
1247
+ self.TMPFILE_VERSION,
1248
+ self.KERNELSZ**2,
1249
+ self.NORIENT,
1250
+ nside,
1251
+ spin,
1252
+ ),
1253
+ idxEB
1254
+ )
1255
+ np.save("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"%
1256
+ (
1257
+ self.TEMPLATE_PATH,
1258
+ self.TMPFILE_VERSION,
1259
+ self.KERNELSZ**2,
1260
+ self.NORIENT,
1261
+ nside,
1262
+ spin,
1263
+ ),
1264
+ tmpEB
1265
+ )
1423
1266
  else:
1424
1267
 
1425
- th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1426
- x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1427
-
1428
- t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1429
- phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1430
- thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1268
+ if l_kernel == 5:
1269
+ pw = 0.5
1270
+ pw2 = 0.5
1271
+ threshold = 2e-4
1431
1272
 
1432
- indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1433
- indice = np.zeros(
1434
- [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1435
- )
1436
- wav = np.zeros(
1437
- [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1438
- )
1439
- wwav = np.zeros(
1440
- [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1441
- )
1442
- iv = 0
1443
- iv2 = 0
1273
+ elif l_kernel == 3:
1274
+ pw = 1.0 / np.sqrt(2)
1275
+ pw2 = 1.0
1276
+ threshold = 1e-3
1444
1277
 
1445
- for iii in range(ncell):
1446
- if cell_ids is None:
1447
- if iii % (nside * nside) == nside * nside - 1:
1448
- if not self.silent:
1449
- print(
1450
- "Pre-compute nside=%6d %.2f%%"
1451
- % (nside, 100 * iii / (12 * nside * nside))
1452
- )
1278
+ elif l_kernel == 7:
1279
+ pw = 0.5
1280
+ pw2 = 0.25
1281
+ threshold = 4e-5
1453
1282
 
1454
1283
  if cell_ids is not None:
1455
- hidx = np.where(
1456
- (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1457
- < (2 * np.pi / nside) ** 2
1458
- )[0]
1459
- else:
1460
- hidx = hp.query_disc(
1461
- nside,
1462
- [x[iii], y[iii], z[iii]],
1463
- 2 * np.pi / nside,
1464
- nest=True,
1465
- )
1284
+ if not isinstance(cell_ids, np.ndarray):
1285
+ cell_ids = self.backend.to_numpy(cell_ids)
1286
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1287
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1466
1288
 
1467
- R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1289
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1290
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1291
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1468
1292
 
1469
- t2, p2 = R(th[hidx], ph[hidx])
1293
+ indice2 = np.zeros([ncell * 64, 2], dtype="int")
1294
+ indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1295
+ wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1296
+ wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1470
1297
 
1471
- vec2 = hp.ang2vec(t2, p2)
1298
+ else:
1472
1299
 
1473
- x2 = vec2[:, 0]
1474
- y2 = vec2[:, 1]
1475
- z2 = vec2[:, 2]
1300
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1301
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1476
1302
 
1477
- ww = np.exp(
1478
- -pw2
1479
- * ((nside) ** 2)
1480
- * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1481
- )
1482
- idx = np.where((ww**2) > threshold)[0]
1483
- nval2 = len(idx)
1484
- indice2[iv2 : iv2 + nval2, 1] = iii
1485
- indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1486
- wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1487
- iv2 += nval2
1303
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1304
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1305
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1488
1306
 
1489
- for l_rotation in range(self.NORIENT):
1490
-
1491
- angle = (
1492
- l_rotation / 4.0 * np.pi
1493
- - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1494
- - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1307
+ indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1308
+ indice = np.zeros(
1309
+ [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1310
+ )
1311
+ wav = np.zeros(
1312
+ [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1313
+ )
1314
+ wwav = np.zeros(
1315
+ [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1495
1316
  )
1317
+ iv = 0
1318
+ iv2 = 0
1319
+
1320
+ for iii in range(ncell):
1321
+ if cell_ids is None:
1322
+ if iii % (nside * nside) == nside * nside - 1:
1323
+ if not self.silent:
1324
+ print(
1325
+ "Pre-compute nside=%6d %.2f%%"
1326
+ % (nside, 100 * iii / (12 * nside * nside))
1327
+ )
1328
+
1329
+ if cell_ids is not None:
1330
+ hidx = np.where(
1331
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1332
+ < (2 * np.pi / nside) ** 2
1333
+ )[0]
1334
+ else:
1335
+ hidx = hp.query_disc(
1336
+ nside,
1337
+ [x[iii], y[iii], z[iii]],
1338
+ 2 * np.pi / nside,
1339
+ nest=True,
1340
+ )
1496
1341
 
1497
- # posi=2*(0.5-(z[hidx]<0))
1342
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1498
1343
 
1499
- axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1500
- wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1501
- wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1344
+ t2, p2 = R(th[hidx], ph[hidx])
1502
1345
 
1503
- vnorm = wresr * wresr + wresi * wresi
1504
- idx = np.where(vnorm > threshold)[0]
1346
+ vec2 = hp.ang2vec(t2, p2)
1505
1347
 
1506
- nval = len(idx)
1507
- indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1508
- indice[iv : iv + nval, 0] = hidx[idx]
1509
- # print([hidx[k] for k in idx])
1510
- # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1511
- normr = np.mean(wresr[idx])
1512
- normi = np.mean(wresi[idx])
1348
+ x2 = vec2[:, 0]
1349
+ y2 = vec2[:, 1]
1350
+ z2 = vec2[:, 2]
1513
1351
 
1514
- val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1515
- r = abs(val).sum()
1352
+ ww = np.exp(
1353
+ -pw2
1354
+ * ((nside) ** 2)
1355
+ * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1356
+ )
1357
+ idx = np.where((ww**2) > threshold)[0]
1358
+ nval2 = len(idx)
1359
+ indice2[iv2 : iv2 + nval2, 1] = iii
1360
+ indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1361
+ wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1362
+ iv2 += nval2
1363
+
1364
+ for l_rotation in range(self.NORIENT):
1365
+
1366
+ angle = (
1367
+ l_rotation / 4.0 * np.pi
1368
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1369
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1370
+ )
1516
1371
 
1517
- if r > 0:
1518
- val = val / r
1372
+ # posi=2*(0.5-(z[hidx]<0))
1519
1373
 
1520
- wav[iv : iv + nval] = val
1521
- iv += nval
1374
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1375
+ wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1376
+ wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1522
1377
 
1523
- indice = indice[:iv, :]
1524
- wav = wav[:iv]
1525
- indice2 = indice2[:iv2, :]
1526
- wwav = wwav[:iv2]
1527
- if not self.silent:
1528
- print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1529
- """
1530
- # OLD VERSION OLD VERSION OLD VERSION (3.0)
1531
- if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
1532
- l_kernel=3
1533
-
1534
- aa=np.cos(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1535
- bb=np.sin(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1536
- x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
1537
- to,po=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
1538
-
1539
- wav=np.zeros([12*nside*nside,l_kernel**2,self.NORIENT],dtype='complex')
1540
- wwav=np.zeros([12*nside*nside,l_kernel**2])
1541
- iwav=np.zeros([12*nside*nside,l_kernel**2],dtype='int')
1542
-
1543
- scale=4
1544
- if nside>scale*2:
1545
- th,ph=hp.pix2ang(nside//scale,np.arange(12*(nside//scale)**2),nest=True)
1546
- else:
1547
- lidx=np.arange(12*nside*nside)
1378
+ vnorm = wresr * wresr + wresi * wresi
1379
+ idx = np.where(vnorm > threshold)[0]
1548
1380
 
1549
- pw=np.pi/4.0
1550
- pw2=1/2
1551
- amp=1.0
1381
+ nval = len(idx)
1382
+ indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1383
+ indice[iv : iv + nval, 0] = hidx[idx]
1384
+ # print([hidx[k] for k in idx])
1385
+ # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1386
+ normr = np.mean(wresr[idx])
1387
+ normi = np.mean(wresi[idx])
1552
1388
 
1553
- if l_kernel==5:
1554
- pw=np.pi/4.0
1555
- pw2=1/2.25
1556
- amp=1.0/9.2038
1389
+ val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1390
+ r = abs(val).sum()
1557
1391
 
1558
- elif l_kernel==3:
1559
- pw=1.0/np.sqrt(2)
1560
- pw2=1.0
1561
- amp=1/8.45
1392
+ if r > 0:
1393
+ val = val / r
1562
1394
 
1563
- elif l_kernel==7:
1564
- pw=np.pi/4.0
1565
- pw2=1.0/3.0
1395
+ wav[iv : iv + nval] = val
1396
+ iv += nval
1566
1397
 
1567
- for k in range(12*nside*nside):
1568
- if k%(nside*nside)==0:
1569
- if not self.silent:
1570
- print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
1571
- if nside>scale*2:
1572
- lidx=hp.get_all_neighbours(nside//scale,th[k//(scale*scale)],ph[k//(scale*scale)],nest=True)
1573
- lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
1574
- lidx=np.repeat(lidx*(scale*scale),(scale*scale))+ \
1575
- np.tile(np.arange((scale*scale)),lidx.shape[0])
1576
-
1577
- delta=(x[lidx]-x[k])**2+(y[lidx]-y[k])**2+(z[lidx]-z[k])**2
1578
- pidx=np.where(delta<(10)/(nside**2))[0]
1579
- if len(pidx)<l_kernel**2:
1580
- pidx=np.arange(delta.shape[0])
1581
-
1582
- w=np.exp(-pw2*delta[pidx]*(nside**2))
1583
- pidx=pidx[np.argsort(-w)[0:l_kernel**2]]
1584
- pidx=pidx[np.argsort(lidx[pidx])]
1585
-
1586
- w=np.exp(-pw2*delta[pidx]*(nside**2))
1587
- iwav[k]=lidx[pidx]
1588
- wwav[k]=w
1589
- rot=[po[k]/np.pi*180.0,90+(-to[k])/np.pi*180.0]
1590
- r=hp.Rotator(rot=rot)
1591
- ty,tx=r(to[iwav[k]],po[iwav[k]])
1592
- ty=ty-np.pi/2
1593
-
1594
- xx=np.expand_dims(pw*nside*np.pi*tx/np.cos(ty),-1)
1595
- yy=np.expand_dims(pw*nside*np.pi*ty,-1)
1596
-
1597
- wav[k,:,:]=(np.cos(xx*aa+yy*bb)+complex(0.0,1.0)*np.sin(xx*aa+yy*bb))*np.expand_dims(w,-1)
1598
-
1599
- wav=wav-np.expand_dims(np.mean(wav,1),1)
1600
- wav=amp*wav/np.expand_dims(np.std(wav,1),1)
1601
- wwav=wwav/np.expand_dims(np.sum(wwav,1),1)
1602
-
1603
- nk=l_kernel*l_kernel
1604
- indice=np.zeros([12*nside*nside*nk*self.NORIENT,2],dtype='int')
1605
- lidx=np.arange(self.NORIENT)
1606
- for i in range(12*nside*nside):
1607
- indice[i*nk*self.NORIENT:i*nk*self.NORIENT+nk*self.NORIENT,0]=i*self.NORIENT+np.repeat(lidx,nk)
1608
- indice[i*nk*self.NORIENT:i*nk*self.NORIENT+nk*self.NORIENT,1]=np.tile(iwav[i],self.NORIENT)
1609
-
1610
- indice2=np.zeros([12*nside*nside*nk,2],dtype='int')
1611
- for i in range(12*nside*nside):
1612
- indice2[i*nk:i*nk+nk,0]=i
1613
- indice2[i*nk:i*nk+nk,1]=iwav[i]
1614
-
1615
- w=np.zeros([12*nside*nside,wav.shape[2],wav.shape[1]],dtype='complex')
1616
- for i in range(wav.shape[1]):
1617
- for j in range(wav.shape[2]):
1618
- w[:,j,i]=wav[:,i,j]
1619
- wav=w.flatten()
1620
- wwav=wwav.flatten()
1621
- """
1622
- if cell_ids is None:
1398
+ indice = indice[:iv, :]
1399
+ wav = wav[:iv]
1400
+ indice2 = indice2[:iv2, :]
1401
+ wwav = wwav[:iv2]
1623
1402
  if not self.silent:
1624
- print(
1625
- "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1626
- % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1403
+ print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1404
+
1405
+ if cell_ids is None:
1406
+ if not self.silent:
1407
+ print(
1408
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1409
+ % (TMPFILE_VERSION, self.KERNELSZ**2,
1410
+ self.NORIENT,
1411
+ nside,
1412
+ spin,)
1413
+ )
1414
+ np.save(
1415
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1416
+ % (
1417
+ self.TEMPLATE_PATH,
1418
+ TMPFILE_VERSION,
1419
+ self.KERNELSZ**2,
1420
+ self.NORIENT,
1421
+ nside,
1422
+ spin,
1423
+ ),
1424
+ indice,
1425
+ )
1426
+ np.save(
1427
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1428
+ % (
1429
+ self.TEMPLATE_PATH,
1430
+ TMPFILE_VERSION,
1431
+ self.KERNELSZ**2,
1432
+ self.NORIENT,
1433
+ nside,
1434
+ spin,
1435
+ ),
1436
+ wav,
1437
+ )
1438
+ np.save(
1439
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"
1440
+ % (
1441
+ self.TEMPLATE_PATH,
1442
+ TMPFILE_VERSION,
1443
+ self.KERNELSZ**2,
1444
+ self.NORIENT,
1445
+ nside,
1446
+ spin,
1447
+ ),
1448
+ indice2,
1449
+ )
1450
+ np.save(
1451
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"
1452
+ % (
1453
+ self.TEMPLATE_PATH,
1454
+ TMPFILE_VERSION,
1455
+ self.KERNELSZ**2,
1456
+ self.NORIENT,
1457
+ nside,
1458
+ spin,
1459
+ ),
1460
+ wwav,
1627
1461
  )
1628
- np.save(
1629
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1630
- % (
1631
- self.TEMPLATE_PATH,
1632
- TMPFILE_VERSION,
1633
- self.KERNELSZ**2,
1634
- self.NORIENT,
1635
- nside,
1636
- ),
1637
- indice,
1638
- )
1639
- np.save(
1640
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1641
- % (
1642
- self.TEMPLATE_PATH,
1643
- TMPFILE_VERSION,
1644
- self.KERNELSZ**2,
1645
- self.NORIENT,
1646
- nside,
1647
- ),
1648
- wav,
1649
- )
1650
- np.save(
1651
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1652
- % (
1653
- self.TEMPLATE_PATH,
1654
- TMPFILE_VERSION,
1655
- self.KERNELSZ**2,
1656
- self.NORIENT,
1657
- nside,
1658
- ),
1659
- indice2,
1660
- )
1661
- np.save(
1662
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1663
- % (
1664
- self.TEMPLATE_PATH,
1665
- TMPFILE_VERSION,
1666
- self.KERNELSZ**2,
1667
- self.NORIENT,
1668
- nside,
1669
- ),
1670
- wwav,
1671
- )
1672
1462
  if self.use_2D:
1673
1463
  if l_kernel**2 == 9:
1674
1464
  if self.rank == 0:
@@ -1689,58 +1479,68 @@ class FoCUS:
1689
1479
  self.barrier()
1690
1480
  if self.use_2D:
1691
1481
  tmp = np.load(
1692
- "%s/W%d_%s_%d_IDX.npy"
1693
- % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1482
+ "%s/W%d_%s_%d_IDX-SPIN%d.npy"
1483
+ % (
1484
+ self.TEMPLATE_PATH,
1485
+ l_kernel**2,
1486
+ TMPFILE_VERSION,
1487
+ nside,
1488
+ spin)
1694
1489
  )
1695
1490
  else:
1696
1491
  tmp = np.load(
1697
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1492
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.npy"
1698
1493
  % (
1699
1494
  self.TEMPLATE_PATH,
1700
1495
  TMPFILE_VERSION,
1701
1496
  self.KERNELSZ**2,
1702
1497
  self.NORIENT,
1703
1498
  nside,
1499
+ spin,
1704
1500
  )
1705
1501
  )
1706
1502
  tmp2 = np.load(
1707
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1503
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.npy"
1708
1504
  % (
1709
1505
  self.TEMPLATE_PATH,
1710
1506
  TMPFILE_VERSION,
1711
1507
  self.KERNELSZ**2,
1712
1508
  self.NORIENT,
1713
1509
  nside,
1510
+ spin,
1714
1511
  )
1715
1512
  )
1716
1513
  wr = np.load(
1717
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1514
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1718
1515
  % (
1719
1516
  self.TEMPLATE_PATH,
1720
1517
  TMPFILE_VERSION,
1721
1518
  self.KERNELSZ**2,
1722
1519
  self.NORIENT,
1723
1520
  nside,
1521
+ spin,
1724
1522
  )
1725
1523
  ).real
1726
1524
  wi = np.load(
1727
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1525
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.npy"
1728
1526
  % (
1729
1527
  self.TEMPLATE_PATH,
1730
1528
  TMPFILE_VERSION,
1731
1529
  self.KERNELSZ**2,
1732
1530
  self.NORIENT,
1733
1531
  nside,
1532
+ spin,
1734
1533
  )
1735
1534
  ).imag
1736
1535
  ws = self.slope * np.load(
1737
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1536
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.npy"
1738
1537
  % (
1739
1538
  self.TEMPLATE_PATH,
1740
1539
  TMPFILE_VERSION,
1741
1540
  self.KERNELSZ**2,
1742
1541
  self.NORIENT,
1743
1542
  nside,
1543
+ spin,
1744
1544
  )
1745
1545
  )
1746
1546
  else:
@@ -1750,21 +1550,38 @@ class FoCUS:
1750
1550
  wi = wav.imag
1751
1551
  ws = self.slope * wwav
1752
1552
 
1753
- wr = self.backend.bk_SparseTensor(
1754
- self.backend.bk_constant(tmp),
1755
- self.backend.bk_constant(self.backend.bk_cast(wr)),
1756
- dense_shape=[ncell, self.NORIENT * ncell],
1757
- )
1758
- wi = self.backend.bk_SparseTensor(
1759
- self.backend.bk_constant(tmp),
1760
- self.backend.bk_constant(self.backend.bk_cast(wi)),
1761
- dense_shape=[ncell, self.NORIENT * ncell],
1762
- )
1763
- ws = self.backend.bk_SparseTensor(
1764
- self.backend.bk_constant(tmp2),
1765
- self.backend.bk_constant(self.backend.bk_cast(ws)),
1766
- dense_shape=[ncell, ncell],
1767
- )
1553
+ if spin==0:
1554
+ wr = self.backend.bk_SparseTensor(
1555
+ self.backend.bk_constant(tmp),
1556
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1557
+ dense_shape=[ncell, self.NORIENT * ncell],
1558
+ )
1559
+ wi = self.backend.bk_SparseTensor(
1560
+ self.backend.bk_constant(tmp),
1561
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1562
+ dense_shape=[ncell, self.NORIENT * ncell],
1563
+ )
1564
+ ws = self.backend.bk_SparseTensor(
1565
+ self.backend.bk_constant(tmp2),
1566
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1567
+ dense_shape=[ncell, ncell],
1568
+ )
1569
+ else:
1570
+ wr = self.backend.bk_SparseTensor(
1571
+ self.backend.bk_constant(tmp),
1572
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1573
+ dense_shape=[2*ncell, 2*self.NORIENT * ncell],
1574
+ )
1575
+ wi = self.backend.bk_SparseTensor(
1576
+ self.backend.bk_constant(tmp),
1577
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1578
+ dense_shape=[2*ncell, 2*self.NORIENT * ncell],
1579
+ )
1580
+ ws = self.backend.bk_SparseTensor(
1581
+ self.backend.bk_constant(tmp2),
1582
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1583
+ dense_shape=[2*ncell, 2*ncell],
1584
+ )
1768
1585
 
1769
1586
  if kernel == -1:
1770
1587
  self.Idx_Neighbours[nside] = tmp
@@ -1775,6 +1592,232 @@ class FoCUS:
1775
1592
 
1776
1593
  return wr, wi, ws, tmp
1777
1594
 
1595
+
1596
+ # ---------------------------------------------−---------
1597
+ def init_index_cnn(self, nside, NORIENT=4,kernel=-1, cell_ids=None):
1598
+
1599
+ if kernel == -1:
1600
+ l_kernel = self.KERNELSZ
1601
+ else:
1602
+ l_kernel = kernel
1603
+
1604
+ if cell_ids is not None:
1605
+ ncell = cell_ids.shape[0]
1606
+ else:
1607
+ ncell = 12 * nside * nside
1608
+
1609
+ try:
1610
+
1611
+ if cell_ids is not None:
1612
+ tmp = np.load(
1613
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1614
+ % (
1615
+ self.TEMPLATE_PATH,
1616
+ TMPFILE_VERSION,
1617
+ l_kernel**2,
1618
+ NORIENT,
1619
+ nside, # if cell_ids computes the index
1620
+ )
1621
+ )
1622
+
1623
+ else:
1624
+ tmp = np.load(
1625
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1626
+ % (
1627
+ self.TEMPLATE_PATH,
1628
+ TMPFILE_VERSION,
1629
+ l_kernel**2,
1630
+ NORIENT,
1631
+ nside, # if cell_ids computes the index
1632
+ )
1633
+ )
1634
+ except:
1635
+
1636
+ pw = 8.0
1637
+ pw2 = 1.0
1638
+ threshold = 1e-3
1639
+
1640
+ if l_kernel == 5:
1641
+ pw = 8.0
1642
+ pw2 = 0.5
1643
+ threshold = 2e-4
1644
+
1645
+ elif l_kernel == 3:
1646
+ pw = 8.0
1647
+ pw2 = 1.0
1648
+ threshold = 1e-3
1649
+
1650
+ elif l_kernel == 7:
1651
+ pw = 8.0
1652
+ pw2 = 0.25
1653
+ threshold = 4e-5
1654
+
1655
+ n_weights = self.KERNELSZ*(self.KERNELSZ//2+1)
1656
+
1657
+ if cell_ids is not None:
1658
+ if not isinstance(cell_ids, np.ndarray):
1659
+ cell_ids = self.backend.to_numpy(cell_ids)
1660
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1661
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1662
+
1663
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1664
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1665
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1666
+
1667
+ indice = np.zeros([n_weights, NORIENT, ncell,4], dtype="int")
1668
+
1669
+ wav = np.zeros([n_weights, NORIENT, ncell,4], dtype="float")
1670
+
1671
+ else:
1672
+
1673
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1674
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1675
+
1676
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1677
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1678
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1679
+
1680
+ indice = np.zeros(
1681
+ [n_weights, NORIENT, 12 * nside * nside,4], dtype="int"
1682
+ )
1683
+ wav = np.zeros(
1684
+ [n_weights, NORIENT, 12 * nside * nside,4], dtype="float"
1685
+ )
1686
+ iv = 0
1687
+ iv2 = 0
1688
+
1689
+ for iii in range(ncell):
1690
+ if cell_ids is None:
1691
+ if iii % (nside * nside) == nside * nside - 1:
1692
+ if not self.silent:
1693
+ print(
1694
+ "Pre-compute nside=%6d %.2f%%"
1695
+ % (nside, 100 * iii / (12 * nside * nside))
1696
+ )
1697
+
1698
+ if cell_ids is not None:
1699
+ hidx = np.where(
1700
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1701
+ < (2 * np.pi / nside) ** 2
1702
+ )[0]
1703
+ else:
1704
+ hidx = hp.query_disc(
1705
+ nside,
1706
+ [x[iii], y[iii], z[iii]],
1707
+ 2 * np.pi / nside,
1708
+ nest=True,
1709
+ )
1710
+
1711
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1712
+
1713
+ t2, p2 = R(th[hidx], ph[hidx])
1714
+
1715
+ vec2 = hp.ang2vec(t2, p2)
1716
+
1717
+ x2 = vec2[:, 0]
1718
+ y2 = vec2[:, 1]
1719
+ z2 = vec2[:, 2]
1720
+
1721
+ for l_rotation in range(NORIENT):
1722
+
1723
+ angle = (
1724
+ l_rotation / 4.0 * np.pi
1725
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1726
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1727
+ )
1728
+
1729
+
1730
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1731
+ axes2 = -y2 * np.sin(angle) - x2 * np.cos(angle)
1732
+
1733
+ for k_weights in range(self.KERNELSZ//2+1):
1734
+ for l_weights in range(self.KERNELSZ):
1735
+
1736
+ val=np.exp(-(pw*(axes2*(nside)-(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))+ \
1737
+ np.exp(-(pw*(axes2*(nside)+(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))
1738
+
1739
+ idx = np.argsort(-val)
1740
+ idx = idx[0:4]
1741
+
1742
+ nval = len(idx)
1743
+ val=val[idx]
1744
+
1745
+ r = abs(val).sum()
1746
+
1747
+ if r > 0:
1748
+ val = val / r
1749
+
1750
+ indice[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = hidx[idx]
1751
+ wav[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = val
1752
+
1753
+ if not self.silent:
1754
+ print("Kernel Size ", iv / (NORIENT * 12 * nside * nside))
1755
+
1756
+ if cell_ids is None:
1757
+ if not self.silent:
1758
+ print(
1759
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1760
+ % (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside)
1761
+ )
1762
+ np.save(
1763
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1764
+ % (
1765
+ self.TEMPLATE_PATH,
1766
+ TMPFILE_VERSION,
1767
+ self.KERNELSZ**2,
1768
+ NORIENT,
1769
+ nside,
1770
+ ),
1771
+ indice,
1772
+ )
1773
+ np.save(
1774
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1775
+ % (
1776
+ self.TEMPLATE_PATH,
1777
+ TMPFILE_VERSION,
1778
+ self.KERNELSZ**2,
1779
+ NORIENT,
1780
+ nside,
1781
+ ),
1782
+ wav,
1783
+ )
1784
+
1785
+ if cell_ids is None:
1786
+ self.barrier()
1787
+ if self.use_2D:
1788
+ tmp = np.load(
1789
+ "%s/W%d_%s_%d_IDX.npy"
1790
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1791
+ )
1792
+ else:
1793
+ tmp = np.load(
1794
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1795
+ % (
1796
+ self.TEMPLATE_PATH,
1797
+ TMPFILE_VERSION,
1798
+ self.KERNELSZ**2,
1799
+ NORIENT,
1800
+ nside,
1801
+ )
1802
+ )
1803
+ wav = np.load(
1804
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1805
+ % (
1806
+ self.TEMPLATE_PATH,
1807
+ TMPFILE_VERSION,
1808
+ self.KERNELSZ**2,
1809
+ NORIENT,
1810
+ nside,
1811
+ )
1812
+ )
1813
+ else:
1814
+ tmp = indice
1815
+
1816
+ self.Idx_CNN[(nside,NORIENT,self.KERNELSZ)] = tmp
1817
+ self.Idx_WCNN[(nside,NORIENT,self.KERNELSZ)] = self.backend.bk_cast(wav)
1818
+
1819
+ return wav, tmp
1820
+
1778
1821
  # ---------------------------------------------−---------
1779
1822
  # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
1780
1823
  def swapaxes(self, x, axis1, axis2):
@@ -1795,10 +1838,10 @@ class FoCUS:
1795
1838
  return self.backend.bk_transpose(x, thelist)
1796
1839
 
1797
1840
  # ---------------------------------------------−---------
1798
- # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
1841
+ # Mean using mask x [n_b,....,Npix], mask[Nmask,Npix] to [n_b,Nmask,....]
1799
1842
  # if use_2D
1800
- # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
1801
- def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1843
+ # Mean using mask x [n_b,....,N_1,N_2], mask[Nmask,N_1,N_2] to [n_b,Nmask,....]
1844
+ def masked_mean(self, x, mask, rank=0, calc_var=False):
1802
1845
 
1803
1846
  # ==========================================================================
1804
1847
  # in input data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]]
@@ -1810,7 +1853,7 @@ class FoCUS:
1810
1853
  shape = list(x.shape)
1811
1854
 
1812
1855
  if not self.use_2D and not self.use_1D:
1813
- nside = int(np.sqrt(x.shape[axis] // 12))
1856
+ nside = int(np.sqrt(x.shape[-1] // 12))
1814
1857
 
1815
1858
  l_mask = mask
1816
1859
  if self.mask_norm:
@@ -1904,16 +1947,24 @@ class FoCUS:
1904
1947
  l_x = self.backend.bk_reshape(
1905
1948
  l_x[:, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :], oshape
1906
1949
  )
1907
- else:
1950
+ else:
1908
1951
  ichannel = 1
1909
- for i in range(len(shape) - 1):
1910
- ichannel *= shape[i]
1952
+ if len(shape)>1:
1953
+ ichannel = shape[0]
1954
+
1955
+ ochannel = 1
1956
+ for i in range(1,len(shape)-1):
1957
+ ochannel *= shape[i]
1911
1958
 
1912
- l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[-1]])
1959
+ l_x = self.backend.bk_reshape(x, [ichannel,1,ochannel,shape[-1]])
1913
1960
 
1914
- # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,1,...,NORIENT[,NORIENT],X[,Y]]
1961
+ # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,...,1,NORIENT[,NORIENT],X[,Y]]
1915
1962
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1916
- l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask, 0), 0)
1963
+
1964
+ if self.use_2D:
1965
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-3)
1966
+ else:
1967
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-2)
1917
1968
 
1918
1969
  if l_x.dtype == self.all_cbk_type:
1919
1970
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
@@ -1944,6 +1995,8 @@ class FoCUS:
1944
1995
 
1945
1996
  if len(x.shape[axis:-2]) > 0:
1946
1997
  oshape = oshape + list(x.shape[axis:-2])
1998
+ else:
1999
+ oshape = oshape + [1]
1947
2000
 
1948
2001
  if calc_var:
1949
2002
  if self.backend.bk_is_complex(vtmp):
@@ -1973,7 +2026,7 @@ class FoCUS:
1973
2026
  elif self.use_1D:
1974
2027
  mtmp = l_mask
1975
2028
  vtmp = l_x
1976
- v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1)
2029
+ v1 = self.backend.bk_reduce_sum(l_mask[1,:,...,:] * vtmp, axis=-1)
1977
2030
  v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
1978
2031
  vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
1979
2032
 
@@ -1982,6 +2035,8 @@ class FoCUS:
1982
2035
  oshape = [x.shape[0]] + [mask.shape[0]]
1983
2036
  if len(x.shape) > 1:
1984
2037
  oshape = oshape + list(x.shape[1:-1])
2038
+ else:
2039
+ oshape = oshape + [1]
1985
2040
 
1986
2041
  if calc_var:
1987
2042
  if self.backend.bk_is_complex(vtmp):
@@ -2015,13 +2070,16 @@ class FoCUS:
2015
2070
  res = v1 / vh
2016
2071
 
2017
2072
  oshape = []
2018
- if axis > 0:
2073
+ if len(shape) > 1:
2019
2074
  oshape = [x.shape[0]]
2020
2075
  else:
2021
2076
  oshape = [1]
2077
+
2022
2078
  oshape = oshape + [mask.shape[0]]
2023
- if axis > 1:
2024
- oshape = oshape + list(x.shape[1:-1])
2079
+ if len(shape) > 2:
2080
+ oshape = oshape + shape[1:-1]
2081
+ else:
2082
+ oshape = oshape + [1]
2025
2083
 
2026
2084
  if calc_var:
2027
2085
  if self.backend.bk_is_complex(l_x):
@@ -2175,7 +2233,7 @@ class FoCUS:
2175
2233
  return self.backend.bk_reduce_sum(r)
2176
2234
 
2177
2235
  # ---------------------------------------------−---------
2178
- def convol(self, in_image, axis=0, cell_ids=None, nside=None):
2236
+ def convol(self, in_image, axis=0, cell_ids=None, nside=None, spin=0):
2179
2237
 
2180
2238
  image = self.backend.bk_cast(in_image)
2181
2239
 
@@ -2238,77 +2296,37 @@ class FoCUS:
2238
2296
 
2239
2297
  else:
2240
2298
  ishape = list(image.shape)
2241
- """
2242
- if cell_ids is not None:
2243
- if cell_ids.shape[0] not in self.padding_conv:
2244
- print(image.shape,cell_ids.shape)
2245
- import healpix_convolution as hc
2246
- from xdggs.healpix import HealpixInfo
2247
-
2248
- res = self.backend.bk_zeros(
2249
- ishape[0:-1] + [self.NORIENT]+ishape[-1], dtype=self.backend.all_cbk_type
2250
- )
2299
+ if nside is None:
2300
+ nside = int(np.sqrt(image.shape[-1] // 12))
2251
2301
 
2252
- grid_info = HealpixInfo(
2253
- level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2254
- )
2302
+ if spin==0:
2303
+ if nside not in self.Idx_Neighbours:
2304
+ if self.InitWave is None:
2305
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2306
+ else:
2307
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2255
2308
 
2256
- for k in range(self.NORIENT):
2257
- kernelR, kernelI = hc.kernels.wavelet_kernel(
2258
- cell_ids, grid_info=grid_info, orientation=k, is_torch=True
2259
- )
2260
- self.kernelR_conv[(cell_ids.shape[0], k)] = kernelR.to(
2261
- self.backend.all_bk_type
2262
- ).to(image.device)
2263
- self.kernelI_conv[(cell_ids.shape[0], k)] = kernelI.to(
2264
- self.backend.all_bk_type
2265
- ).to(image.device)
2266
- self.padding_conv[(cell_ids.shape[0], k)] = hc.pad(
2267
- cell_ids,
2268
- grid_info=grid_info,
2269
- ring=5 // 2, # wavelet kernel_size=5 is hard coded
2270
- mode="mean",
2271
- constant_value=0,
2272
- )
2309
+ self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2310
+ self.ww_Real[nside] = wr
2311
+ self.ww_Imag[nside] = wi
2312
+ self.w_smooth[nside] = ws
2273
2313
 
2274
- for k in range(self.NORIENT):
2275
-
2276
- kernelR = self.kernelR_conv[(cell_ids.shape[0], k)]
2277
- kernelI = self.kernelI_conv[(cell_ids.shape[0], k)]
2278
- padding = self.padding_conv[(cell_ids.shape[0], k)]
2279
- if len(ishape) == 2:
2280
- for l in range(ishape[0]):
2281
- padded_data = padding.apply(image[l], is_torch=True)
2282
- res[l, :, k] = kernelR.matmul(
2283
- padded_data
2284
- ) + 1j * kernelI.matmul(padded_data)
2314
+ l_ww_real = self.ww_Real[nside]
2315
+ l_ww_imag = self.ww_Imag[nside]
2316
+ else:
2317
+ if (spin,nside) not in self.Idx_Neighbours:
2318
+ if self.InitWave is None:
2319
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2285
2320
  else:
2286
- for l in range(ishape[0]):
2287
- for k2 in range(ishape[2]):
2288
- padded_data = padding.apply(
2289
- image[l, :, k2], is_torch=True
2290
- )
2291
- res[l, :, k2, k] = kernelR.matmul(
2292
- padded_data
2293
- ) + 1j * kernelI.matmul(padded_data)
2294
- return res
2295
- """
2296
- if nside is None:
2297
- nside = int(np.sqrt(image.shape[-1] // 12))
2298
-
2299
- if self.Idx_Neighbours[nside] is None:
2300
- if self.InitWave is None:
2301
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2302
- else:
2303
- wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2321
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2304
2322
 
2305
- self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2306
- self.ww_Real[nside] = wr
2307
- self.ww_Imag[nside] = wi
2308
- self.w_smooth[nside] = ws
2323
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2324
+ self.ww_Real[(spin,nside)] = wr
2325
+ self.ww_Imag[(spin,nside)] = wi
2326
+ self.w_smooth[(spin,nside)] = ws
2309
2327
 
2310
- l_ww_real = self.ww_Real[nside]
2311
- l_ww_imag = self.ww_Imag[nside]
2328
+ l_ww_real = self.ww_Real[(spin,nside)]
2329
+ l_ww_imag = self.ww_Imag[(spin,nside)]
2312
2330
 
2313
2331
  # always convolve the last dimension
2314
2332
 
@@ -2316,9 +2334,14 @@ class FoCUS:
2316
2334
  if len(ishape) > 1:
2317
2335
  for k in range(len(ishape) - 1):
2318
2336
  ndata = ndata * ishape[k]
2319
- tim = self.backend.bk_reshape(
2320
- self.backend.bk_cast(image), [ndata, ishape[-1]]
2321
- )
2337
+ if spin>0:
2338
+ tim = self.backend.bk_reshape(
2339
+ self.backend.bk_cast(image), [ndata//2,2*ishape[-1]]
2340
+ )
2341
+ else:
2342
+ tim = self.backend.bk_reshape(
2343
+ self.backend.bk_cast(image), [ndata, ishape[-1]]
2344
+ )
2322
2345
 
2323
2346
  if tim.dtype == self.all_cbk_type:
2324
2347
  rr1 = self.backend.bk_reshape(
@@ -2360,17 +2383,27 @@ class FoCUS:
2360
2383
  [ndata, self.NORIENT, ishape[-1]],
2361
2384
  )
2362
2385
  res = self.backend.bk_complex(rr, ii)
2363
- if len(ishape) > 1:
2364
- return self.backend.bk_reshape(
2365
- res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2366
- )
2386
+
2387
+ if spin==0:
2388
+ if len(ishape) > 1:
2389
+ return self.backend.bk_reshape(
2390
+ res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2391
+ )
2392
+ else:
2393
+ return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2367
2394
  else:
2368
- return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2395
+ if len(ishape) > 2:
2396
+ return self.backend.bk_reshape(
2397
+ res, ishape[0:-2] + [2,self.NORIENT, ishape[-1]]
2398
+ )
2399
+ else:
2400
+ return self.backend.bk_reshape(res, [2,self.NORIENT, ishape[-1]])
2401
+
2369
2402
 
2370
2403
  return res
2371
2404
 
2372
2405
  # ---------------------------------------------−---------
2373
- def smooth(self, in_image, axis=0, cell_ids=None, nside=None):
2406
+ def smooth(self, in_image, axis=0, cell_ids=None, nside=None, spin=0):
2374
2407
 
2375
2408
  image = self.backend.bk_cast(in_image)
2376
2409
 
@@ -2430,64 +2463,35 @@ class FoCUS:
2430
2463
  else:
2431
2464
 
2432
2465
  ishape = list(image.shape)
2433
- """
2434
- if cell_ids is not None:
2435
- if cell_ids.shape[0] not in self.padding_smooth:
2436
- import healpix_convolution as hc
2437
- from xdggs.healpix import HealpixInfo
2438
-
2439
- grid_info = HealpixInfo(
2440
- level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2441
- )
2442
-
2443
- kernel = hc.kernels.wavelet_smooth_kernel(
2444
- cell_ids, grid_info=grid_info, is_torch=True
2445
- )
2446
-
2447
- self.kernel_smooth[cell_ids.shape[0]] = kernel.to(
2448
- self.backend.all_bk_type
2449
- ).to(image.device)
2450
-
2451
- self.padding_smooth[cell_ids.shape[0]] = hc.pad(
2452
- cell_ids,
2453
- grid_info=grid_info,
2454
- ring=5 // 2, # wavelet kernel_size=5 is hard coded
2455
- mode="mean",
2456
- constant_value=0,
2457
- )
2458
-
2459
- kernel = self.kernel_smooth[cell_ids.shape[0]]
2460
- padding = self.padding_smooth[cell_ids.shape[0]]
2461
-
2462
- res = self.backend.bk_zeros(ishape, dtype=self.backend.all_cbk_type)
2463
-
2464
- if len(ishape) == 2:
2465
- for l in range(ishape[0]):
2466
- padded_data = padding.apply(image[l], is_torch=True)
2467
- res[l] = kernel.matmul(padded_data)
2468
- else:
2469
- for l in range(ishape[0]):
2470
- for k2 in range(ishape[2]):
2471
- padded_data = padding.apply(image[l, :, k2], is_torch=True)
2472
- res[l, :, k2] = kernel.matmul(padded_data)
2473
- return res
2474
- """
2466
+
2475
2467
  if nside is None:
2476
2468
  nside = int(np.sqrt(image.shape[-1] // 12))
2477
2469
 
2478
- if self.Idx_Neighbours[nside] is None:
2470
+ if spin==0:
2471
+ if nside not in self.Idx_Neighbours:
2472
+ if self.InitWave is None:
2473
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2474
+ else:
2475
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2479
2476
 
2480
- if self.InitWave is None:
2481
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2482
- else:
2483
- wr, wi, ws, widx = self.InitWave(self, nside, cell_ids=cell_ids)
2477
+ self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2478
+ self.ww_Real[nside] = wr
2479
+ self.ww_Imag[nside] = wi
2480
+ self.w_smooth[nside] = ws
2484
2481
 
2485
- self.Idx_Neighbours[nside] = 1
2486
- self.ww_Real[nside] = wr
2487
- self.ww_Imag[nside] = wi
2488
- self.w_smooth[nside] = ws
2482
+ l_w_smooth = self.w_smooth[nside]
2483
+ else:
2484
+ if (spin,nside) not in self.Idx_Neighbours:
2485
+ if self.InitWave is None:
2486
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2487
+ else:
2488
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2489
2489
 
2490
- l_w_smooth = self.w_smooth[nside]
2490
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2491
+ self.ww_Real[(spin,nside)] = wr
2492
+ self.ww_Imag[(spin,nside)] = wi
2493
+ self.w_smooth[(spin,nside)] = ws
2494
+ l_w_smooth = self.w_smooth[(spin,nside)]
2491
2495
 
2492
2496
  odata = 1
2493
2497
  for k in range(0, len(ishape) - 1):