foscat 2025.6.1__py3-none-any.whl → 2025.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -5,37 +5,37 @@ import healpy as hp
5
5
  import numpy as np
6
6
  from scipy.interpolate import griddata
7
7
 
8
- TMPFILE_VERSION = "V5_0"
8
+ TMPFILE_VERSION = "V6_0"
9
9
 
10
10
 
11
11
  class FoCUS:
12
12
  def __init__(
13
- self,
14
- NORIENT=4,
15
- LAMBDA=1.2,
16
- KERNELSZ=3,
17
- slope=1.0,
18
- all_type="float32",
19
- nstep_max=16,
20
- padding="SAME",
21
- gpupos=0,
22
- mask_thres=None,
23
- mask_norm=False,
24
- isMPI=False,
25
- TEMPLATE_PATH="data",
26
- BACKEND="tensorflow",
27
- use_2D=False,
28
- use_1D=False,
29
- return_data=False,
30
- JmaxDelta=0,
31
- DODIV=False,
32
- InitWave=None,
33
- silent=True,
34
- mpi_size=1,
35
- mpi_rank=0,
13
+ self,
14
+ NORIENT=4,
15
+ LAMBDA=1.2,
16
+ KERNELSZ=3,
17
+ slope=1.0,
18
+ all_type="float32",
19
+ nstep_max=20,
20
+ padding="SAME",
21
+ gpupos=0,
22
+ mask_thres=None,
23
+ mask_norm=False,
24
+ isMPI=False,
25
+ TEMPLATE_PATH=None,
26
+ BACKEND="torch",
27
+ use_2D=False,
28
+ use_1D=False,
29
+ return_data=False,
30
+ JmaxDelta=0,
31
+ DODIV=False,
32
+ InitWave=None,
33
+ silent=True,
34
+ mpi_size=1,
35
+ mpi_rank=0
36
36
  ):
37
37
 
38
- self.__version__ = "2025.06.1"
38
+ self.__version__ = "2025.07.1"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -62,6 +62,11 @@ class FoCUS:
62
62
  print("================================================")
63
63
  sys.stdout.flush()
64
64
 
65
+ home_dir = os.environ["HOME"]
66
+
67
+ if TEMPLATE_PATH is None:
68
+ TEMPLATE_PATH=home_dir+"/.FOSCAT/data"
69
+
65
70
  self.TEMPLATE_PATH = TEMPLATE_PATH
66
71
  if not os.path.exists(self.TEMPLATE_PATH):
67
72
  if not self.silent:
@@ -281,28 +286,10 @@ class FoCUS:
281
286
  self.KERNELSZ = KERNELSZ
282
287
 
283
288
  self.Idx_Neighbours = {}
289
+ self.w_smooth = {}
284
290
 
285
- if not self.use_2D and not self.use_1D:
286
- self.w_smooth = {}
287
- for i in range(nstep_max):
288
- lout = 2**i
289
- self.ww_Real[lout] = None
290
-
291
- for i in range(1, 6):
292
- lout = 2**i
293
- if not self.silent:
294
- print("Init Wave ", lout)
295
-
296
- if self.InitWave is None:
297
- wr, wi, ws, widx = self.init_index(lout)
298
- else:
299
- wr, wi, ws, widx = self.InitWave(self, lout)
300
-
301
- self.Idx_Neighbours[lout] = 1 # self.backend.bk_constant(widx)
302
- self.ww_Real[lout] = wr
303
- self.ww_Imag[lout] = wi
304
- self.w_smooth[lout] = ws
305
- elif self.use_1D:
291
+
292
+ if self.use_1D:
306
293
  self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
307
294
  self.ww_RealT = {}
308
295
  self.ww_ImagT = {}
@@ -329,7 +316,7 @@ class FoCUS:
329
316
  self.backend.bk_constant(np.array(w).reshape(xx.shape[0]))
330
317
  )
331
318
 
332
- else:
319
+ if self.use_2D:
333
320
  self.w_smooth = slope * (w_smooth / w_smooth.sum()).astype(self.all_type)
334
321
  self.ww_RealT = {}
335
322
  self.ww_ImagT = {}
@@ -369,41 +356,34 @@ class FoCUS:
369
356
  self.pix_interp_val = {}
370
357
  self.weight_interp_val = {}
371
358
  self.ring2nest = {}
372
- self.nest2R = {}
373
- self.nest2R1 = {}
374
- self.nest2R2 = {}
375
- self.nest2R3 = {}
376
- self.nest2R4 = {}
377
- self.inv_nest2R = {}
378
- self.remove_border = {}
379
-
380
359
  self.ampnorm = {}
381
360
 
382
- for i in range(nstep_max):
383
- lout = 2**i
384
- self.pix_interp_val[lout] = {}
385
- self.weight_interp_val[lout] = {}
386
- for j in range(nstep_max):
387
- lout2 = 2**j
388
- self.pix_interp_val[lout][lout2] = None
389
- self.weight_interp_val[lout][lout2] = None
390
- self.ring2nest[lout] = None
391
- self.Idx_Neighbours[lout] = None
392
- self.nest2R[lout] = None
393
- self.nest2R1[lout] = None
394
- self.nest2R2[lout] = None
395
- self.nest2R3[lout] = None
396
- self.nest2R4[lout] = None
397
- self.inv_nest2R[lout] = None
398
- self.remove_border[lout] = None
399
- self.ww_CNN_Transpose[lout] = None
400
- self.ww_CNN[lout] = None
401
- self.X_CNN[lout] = None
402
- self.Y_CNN[lout] = None
403
- self.Z_CNN[lout] = None
404
-
405
361
  self.loss = {}
406
362
 
363
+ self.dtype_dcode_map = {
364
+ 0: np.int64,
365
+ 1: np.int32,
366
+ 2: np.float32,
367
+ 3: np.float64,
368
+ 4: np.complex64,
369
+ 5: np.complex128
370
+ }
371
+ self.dtype_code_map = {
372
+ np.int64: 0,
373
+ np.int32: 1,
374
+ np.float32: 2,
375
+ np.float64: 3,
376
+ np.complex64: 4,
377
+ np.complex128: 5
378
+ }
379
+
380
+ # this is for the storage only
381
+ def get_dtype_code(self, dtype):
382
+ for key, code in self.dtype_code_map.items():
383
+ if np.dtype(dtype) == np.dtype(key):
384
+ return code
385
+ raise ValueError(f"Unsupported data type: {dtype}")
386
+
407
387
  def get_type(self):
408
388
  return self.all_type
409
389
 
@@ -502,6 +482,114 @@ class FoCUS:
502
482
  )
503
483
  return indices, weights, xc, yc, zc
504
484
 
485
+ #======================================================================================
486
+ # The next two functions prepare the ability of FOSCAT to work with large indexed file
487
+ #======================================================================================
488
+
489
+ def save_index(self, filepath, data, offset=0, count=None):
490
+ """
491
+ Save an N-dimensional NumPy array with shape (N, ...) to binary file.
492
+ A 12x int64 header is written, describing dtype and shape beyond axis 0.
493
+
494
+ Header layout (12 x int64):
495
+ [0] = dtype code (0=int64, 1=int32, 2=float32, 3=float64, 4=complex64, 5=complex128)
496
+ [1] = number of extra dimensions (i.e., data.ndim - 1)
497
+ [2:12] = shape[1:] padded with zeros
498
+
499
+ Parameters:
500
+ - filepath: target binary file path
501
+ - data: NumPy array with shape (N, ...)
502
+ - offset: number of items to skip on axis 0
503
+ - count: number of items to write on axis 0 (default: rest of the array)
504
+ """
505
+ if filepath is None:
506
+ raise ValueError("No filepath specified for writing.")
507
+
508
+ data = np.asarray(data)
509
+ if data.ndim < 1:
510
+ raise ValueError("Data must have at least one dimension.")
511
+
512
+ extra_dims = data.shape[1:]
513
+ if len(extra_dims) > 10:
514
+ raise ValueError(f"Too many dimensions: {data.ndim}. Max supported is 11 (1 + 10 extra).")
515
+
516
+ dtype_code = self.get_dtype_code(data.dtype)
517
+ itemsize = data.dtype.itemsize
518
+ item_shape = data.shape[1:]
519
+ item_count = np.prod(item_shape, dtype=np.int64) if item_shape else 1
520
+
521
+ if count is None:
522
+ count = data.shape[0]
523
+
524
+ header = np.zeros(12, dtype=np.int64)
525
+ header[0] = dtype_code
526
+ header[1] = len(extra_dims)
527
+ header[2:2 + len(extra_dims)] = extra_dims
528
+
529
+ mode = 'r+b' if os.path.exists(filepath) else 'w+b'
530
+ with open(filepath, mode) as f:
531
+ if os.path.getsize(filepath) == 0:
532
+ f.write(header.tobytes())
533
+
534
+ byte_offset = 12 * 8 + offset * itemsize * item_count # header is 96 bytes
535
+ f.seek(byte_offset)
536
+ f.write(data[offset:offset + count].tobytes())
537
+
538
+ def read_index(self, filepath, offset=0, count=None):
539
+ """
540
+ Load a NumPy array from a binary file with a 12x int64 header.
541
+
542
+ Header layout:
543
+ [0] = dtype code
544
+ [1] = number of extra dimensions (D)
545
+ [2:2+D] = shape[1:] of each sample (shape after axis 0)
546
+
547
+ Parameters:
548
+ - filepath: path to the binary file
549
+ - offset: number of samples to skip on axis 0
550
+ - count: number of samples to read (default: all remaining)
551
+
552
+ Returns:
553
+ - data: NumPy array with shape (count, ...) and correct dtype
554
+ """
555
+ if not os.path.exists(filepath):
556
+ raise FileNotFoundError(f"File not found: {filepath}")
557
+
558
+ with open(filepath, 'rb') as f:
559
+ header_bytes = f.read(12 * 8)
560
+ if len(header_bytes) != 96:
561
+ raise ValueError("Invalid or missing header (expected 96 bytes).")
562
+
563
+ header = np.frombuffer(header_bytes, dtype=np.int64)
564
+ dtype_code = header[0]
565
+ ndim_extra = header[1]
566
+ if dtype_code not in self.dtype_dcode_map:
567
+ raise ValueError(f"Unknown dtype code in header: {dtype_code}")
568
+
569
+ dtype = self.dtype_dcode_map[dtype_code]
570
+ shape1 = tuple(header[2:2 + ndim_extra])
571
+ itemsize = np.dtype(dtype).itemsize
572
+ item_count = np.prod(shape1, dtype=np.int64) if shape1 else 1
573
+ bytes_per_sample = itemsize * item_count
574
+
575
+ # Seek to data block
576
+ f.seek(12 * 8 + offset * bytes_per_sample)
577
+
578
+ # Determine number of items
579
+ if count is None:
580
+ remaining_bytes = os.path.getsize(filepath) - (12 * 8 + offset * bytes_per_sample)
581
+ count = remaining_bytes // bytes_per_sample
582
+
583
+ raw = f.read(count * bytes_per_sample)
584
+ data = np.frombuffer(raw, dtype=dtype)
585
+
586
+ if shape1:
587
+ data = data.reshape((count,) + shape1)
588
+ else:
589
+ data = data.reshape((count,))
590
+
591
+ return data
592
+
505
593
  # ---------------------------------------------−---------
506
594
  # ---------------------------------------------−---------
507
595
  def healpix_layer(self, im, ww, indices=None, weights=None):
@@ -543,7 +631,7 @@ class FoCUS:
543
631
  def toring(self, image, axis=0):
544
632
  lout = int(np.sqrt(image.shape[axis] // 12))
545
633
 
546
- if self.ring2nest[lout] is None:
634
+ if lout not in self.ring2nest:
547
635
  self.ring2nest[lout] = hp.ring2nest(lout, np.arange(12 * lout**2))
548
636
 
549
637
  return image.numpy()[self.ring2nest[lout]]
@@ -639,36 +727,16 @@ class FoCUS:
639
727
  if cell_ids is not None:
640
728
  sim, new_cell_ids = self.backend.binned_mean(im, cell_ids)
641
729
  return sim, new_cell_ids
642
-
643
- lout = int(np.sqrt(shape[axis] // 12))
644
- if im.__class__ == np.zeros([0]).__class__:
645
- oshape = np.zeros([len(shape) + 1], dtype="int")
646
- if axis > 0:
647
- oshape[0:axis] = shape[0:axis]
648
- oshape[axis] = 12 * lout * lout // 4
649
- oshape[axis + 1] = 4
650
- if len(shape) > axis:
651
- oshape[axis + 2 :] = shape[axis + 1 :]
652
- else:
653
- if axis > 0:
654
- oshape = shape[0:axis] + [12 * lout * lout // 4, 4]
655
- else:
656
- oshape = [12 * lout * lout // 4, 4]
657
- if len(shape) > axis:
658
- oshape = oshape + shape[axis + 1 :]
659
-
660
- return (
661
- self.backend.bk_reduce_mean(
662
- self.backend.bk_reshape(im, oshape), axis=axis + 1
663
- ),
664
- None,
665
- )
730
+
731
+ return self.backend.bk_reduce_mean(
732
+ self.backend.bk_reshape(im, shape[0:-1]+[shape[-1]//4,4]), axis=-1
733
+ ),None
666
734
 
667
735
  # --------------------------------------------------------
668
- def up_grade(self, im, nout, axis=0, nouty=None):
736
+ def up_grade(self, im, nout, axis=-1, nouty=None):
669
737
 
738
+ ishape = list(im.shape)
670
739
  if self.use_2D:
671
- ishape = list(im.shape)
672
740
  if len(ishape) < axis + 2:
673
741
  if not self.silent:
674
742
  print("Use of 2D scat with data that has less than 2D")
@@ -683,9 +751,6 @@ class FoCUS:
683
751
  npix = im.shape[axis]
684
752
  npiy = im.shape[axis + 1]
685
753
  odata = 1
686
- if len(ishape) > axis + 2:
687
- for k in range(axis + 2, len(ishape)):
688
- odata = odata * ishape[k]
689
754
 
690
755
  ndata = 1
691
756
  for k in range(axis):
@@ -709,13 +774,12 @@ class FoCUS:
709
774
  return self.backend.bk_reshape(res, ishape[0:axis] + [nout, nouty])
710
775
  else:
711
776
  return self.backend.bk_reshape(
712
- res, ishape[0:axis] + [nout, nouty] + ishape[axis + 2 :]
777
+ res, ishape[0:axis] + [nout, nouty]
713
778
  )
714
779
 
715
780
  return self.backend.bk_reshape(res, [nout, nouty])
716
781
 
717
782
  elif self.use_1D:
718
- ishape = list(im.shape)
719
783
  if len(ishape) < axis + 1:
720
784
  if not self.silent:
721
785
  print("Use of 1D scat with data that has less than 1D")
@@ -773,9 +837,9 @@ class FoCUS:
773
837
 
774
838
  else:
775
839
 
776
- lout = int(np.sqrt(im.shape[axis] // 12))
840
+ lout = int(np.sqrt(im.shape[-1] // 12))
777
841
 
778
- if self.pix_interp_val[lout][nout] is None:
842
+ if (lout,nout) not in self.pix_interp_val:
779
843
  if not self.silent:
780
844
  print("compute lout nout", lout, nout)
781
845
  th, ph = hp.pix2ang(
@@ -794,104 +858,49 @@ class FoCUS:
794
858
  t = t + np.repeat(np.arange(12 * nout * nout) * 4, 4)
795
859
  p = p.flatten()[t]
796
860
  w = w.flatten()[t]
797
- indice[:, 0] = np.repeat(np.arange(12 * nout**2), 4)
798
- indice[:, 1] = p
861
+ indice[:, 1] = np.repeat(np.arange(12 * nout**2), 4)
862
+ indice[:, 0] = p
799
863
 
800
- self.pix_interp_val[lout][nout] = 1
801
- self.weight_interp_val[lout][nout] = self.backend.bk_SparseTensor(
864
+ self.pix_interp_val[(lout,nout)] = 1
865
+ self.weight_interp_val[(lout,nout)] = self.backend.bk_SparseTensor(
802
866
  self.backend.bk_constant(indice),
803
867
  self.backend.bk_constant(self.backend.bk_cast(w.flatten())),
804
- dense_shape=[12 * nout**2, 12 * lout**2],
868
+ dense_shape=[12 * lout**2,12 * nout**2],
805
869
  )
806
870
 
807
871
  if lout == nout:
808
872
  imout = im
809
873
  else:
810
-
811
- ishape = list(im.shape)
812
- odata = 1
813
- for k in range(axis + 1, len(ishape)):
814
- odata = odata * ishape[k]
874
+ # work only on the last column
815
875
 
816
876
  ndata = 1
817
- for k in range(axis):
877
+ for k in range(len(ishape)-1):
818
878
  ndata = ndata * ishape[k]
819
879
  tim = self.backend.bk_reshape(
820
- self.backend.bk_cast(im), [ndata, 12 * lout**2, odata]
880
+ self.backend.bk_cast(im), [ndata, 12 * lout**2]
821
881
  )
822
882
  if tim.dtype == self.all_cbk_type:
823
- rr = self.backend.bk_reshape(
824
- self.backend.bk_sparse_dense_matmul(
825
- self.weight_interp_val[lout][nout],
826
- self.backend.bk_real(tim[0]),
827
- ),
828
- [1, 12 * nout**2, odata],
829
- )
830
- ii = self.backend.bk_reshape(
831
- self.backend.bk_sparse_dense_matmul(
832
- self.weight_interp_val[lout][nout],
833
- self.backend.bk_imag(tim[0]),
834
- ),
835
- [1, 12 * nout**2, odata],
836
- )
883
+ rr = self.backend.bk_sparse_dense_matmul(
884
+ self.backend.bk_real(tim),
885
+ self.weight_interp_val[(lout,nout)],
886
+ )
887
+ ii = self.backend.bk_sparse_dense_matmul(
888
+ self.backend.bk_real(tim),
889
+ self.weight_interp_val[(lout,nout)],
890
+ )
837
891
  imout = self.backend.bk_complex(rr, ii)
838
892
  else:
839
- imout = self.backend.bk_reshape(
840
- self.backend.bk_sparse_dense_matmul(
841
- self.weight_interp_val[lout][nout], tim[0]
842
- ),
843
- [1, 12 * nout**2, odata],
893
+ imout = self.backend.bk_sparse_dense_matmul(
894
+ tim,
895
+ self.weight_interp_val[(lout,nout)],
844
896
  )
845
-
846
- for k in range(1, ndata):
847
- if tim.dtype == self.all_cbk_type:
848
- rr = self.backend.bk_reshape(
849
- self.backend.bk_sparse_dense_matmul(
850
- self.weight_interp_val[lout][nout],
851
- self.backend.bk_real(tim[k]),
852
- ),
853
- [1, 12 * nout**2, odata],
854
- )
855
- ii = self.backend.bk_reshape(
856
- self.backend.bk_sparse_dense_matmul(
857
- self.weight_interp_val[lout][nout],
858
- self.backend.bk_imag(tim[k]),
859
- ),
860
- [1, 12 * nout**2, odata],
861
- )
862
- imout = self.backend.bk_concat(
863
- [imout, self.backend.bk_complex(rr, ii)], 0
864
- )
865
- else:
866
- imout = self.backend.bk_concat(
867
- [
868
- imout,
869
- self.backend.bk_reshape(
870
- self.backend.bk_sparse_dense_matmul(
871
- self.weight_interp_val[lout][nout], tim[k]
872
- ),
873
- [1, 12 * nout**2, odata],
874
- ),
875
- ],
876
- 0,
877
- )
878
-
879
- if axis == 0:
880
- if len(ishape) == 1:
881
- return self.backend.bk_reshape(imout, [12 * nout**2])
882
- else:
883
- return self.backend.bk_reshape(
884
- imout, [12 * nout**2] + ishape[axis + 1 :]
885
- )
897
+
898
+ if len(ishape) == 1:
899
+ return self.backend.bk_reshape(imout, [12 * nout**2])
886
900
  else:
887
- if len(ishape) == axis + 1:
888
- return self.backend.bk_reshape(
889
- imout, ishape[0:axis] + [12 * nout**2]
890
- )
891
- else:
892
- return self.backend.bk_reshape(
893
- imout, ishape[0:axis] + [12 * nout**2] + ishape[axis + 1 :]
894
- )
901
+ return self.backend.bk_reshape(
902
+ imout, ishape[0:axis]+[12 * nout**2]
903
+ )
895
904
  return imout
896
905
 
897
906
  # --------------------------------------------------------
@@ -1164,7 +1173,7 @@ class FoCUS:
1164
1173
  return res
1165
1174
 
1166
1175
  # ---------------------------------------------−---------
1167
- def init_index(self, nside, kernel=-1, cell_ids=None):
1176
+ def init_index(self, nside, kernel=-1, cell_ids=None, spin=0):
1168
1177
 
1169
1178
  if kernel == -1:
1170
1179
  l_kernel = self.KERNELSZ
@@ -1178,14 +1187,13 @@ class FoCUS:
1178
1187
 
1179
1188
  try:
1180
1189
  if self.use_2D:
1181
- tmp = np.load(
1182
- "%s/W%d_%s_%d_IDX.npy"
1183
- % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1190
+ tmp = self.read_index("%s/W%d_%s_%d_IDX.fst"
1191
+ % (self.TEMPLATE_PATH, l_kernel**2,TMPFILE_VERSION, nside)
1184
1192
  )
1185
1193
  else:
1186
- if cell_ids is not None:
1187
- tmp = np.load(
1188
- "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1194
+ if cell_ids is not None and nside>512:
1195
+ tmp = self.read_index(
1196
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work
1189
1197
  % (
1190
1198
  self.TEMPLATE_PATH,
1191
1199
  TMPFILE_VERSION,
@@ -1196,298 +1204,399 @@ class FoCUS:
1196
1204
  )
1197
1205
 
1198
1206
  else:
1199
- tmp = np.load(
1200
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1207
+ print('LOAD ',"%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1201
1208
  % (
1202
1209
  self.TEMPLATE_PATH,
1203
1210
  TMPFILE_VERSION,
1204
1211
  l_kernel**2,
1205
1212
  self.NORIENT,
1206
- nside, # if cell_ids computes the index
1213
+ nside,spin # if cell_ids computes the index
1214
+ ))
1215
+ tmp = self.read_index(
1216
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1217
+ % (
1218
+ self.TEMPLATE_PATH,
1219
+ TMPFILE_VERSION,
1220
+ l_kernel**2,
1221
+ self.NORIENT,
1222
+ nside,spin # if cell_ids computes the index
1207
1223
  )
1208
1224
  )
1225
+
1209
1226
  except:
1227
+ if cell_ids is not None and nside<=512:
1228
+ self.init_index(nside, kernel=kernel, spin=spin)
1229
+
1210
1230
  if not self.use_2D:
1231
+ print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst'
1232
+ % (
1233
+ self.TEMPLATE_PATH,
1234
+ TMPFILE_VERSION,
1235
+ l_kernel**2,
1236
+ self.NORIENT,
1237
+ nside,spin # if cell_ids computes the index
1238
+ )
1239
+ )
1240
+ if spin!=0:
1241
+ try:
1242
+ tmp = self.read_index(
1243
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst"
1244
+ % (
1245
+ self.TEMPLATE_PATH,
1246
+ TMPFILE_VERSION,
1247
+ l_kernel**2,
1248
+ self.NORIENT,
1249
+ nside
1250
+ )
1251
+ )
1252
+ except:
1253
+ print('NOT FOUND THEN COMPUTE %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst'
1254
+ % (
1255
+ self.TEMPLATE_PATH,
1256
+ TMPFILE_VERSION,
1257
+ l_kernel**2,
1258
+ self.NORIENT,
1259
+ nside
1260
+ )
1261
+ )
1262
+
1263
+ self.init_index(nside, kernel=kernel, spin=0)
1264
+
1265
+ tmp = self.read_index(
1266
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN0.fst"
1267
+ % (
1268
+ self.TEMPLATE_PATH,
1269
+ TMPFILE_VERSION,
1270
+ l_kernel**2,
1271
+ self.NORIENT,
1272
+ nside
1273
+ )
1274
+ )
1275
+
1276
+ tmpw = self.read_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN0.fst"% (
1277
+ self.TEMPLATE_PATH,
1278
+ self.TMPFILE_VERSION,
1279
+ self.KERNELSZ**2,
1280
+ self.NORIENT,
1281
+ nside,
1282
+ )
1283
+ )
1284
+
1285
+ nn=self.NORIENT*12*nside**2
1286
+ idxEB=np.concatenate([tmp,tmp,tmp,tmp],0)
1287
+ idxEB[tmp.shape[0]:2*tmp.shape[0],0]+=12*nside**2
1288
+ idxEB[3*tmp.shape[0]:,0]+=12*nside**2
1289
+ idxEB[2*tmp.shape[0]:,1]+=nn
1290
+
1291
+ tmpEB=np.zeros([tmpw.shape[0]*4],dtype='complex')
1292
+
1293
+ for k in range(self.NORIENT*12*nside**2):
1294
+ if k%(nside**2)==0:
1295
+ print('Init index 1/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
1296
+ self.NORIENT*12,
1297
+ nside,
1298
+ self.KERNELSZ))
1299
+ idx=np.where(tmp[:,1]==k)[0]
1300
+
1301
+ im=np.zeros([12*nside**2])
1302
+ im[tmp[idx,0]]=tmpw[idx].real
1303
+ almR=hp.map2alm(hp.reorder(im,n2r=True))
1304
+ im[tmp[idx,0]]=tmpw[idx].imag
1305
+ almI=hp.map2alm(hp.reorder(im,n2r=True))
1306
+
1307
+ i,q,u=hp.alm2map_spin([almR,almR*0,0*almR],nside,spin,3*nside-1)
1308
+ i2,q2,u2=hp.alm2map_spin([almI,0*almI,0*almI],nside,spin,3*nside-1)
1309
+
1310
+ tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1311
+ tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1312
+
1313
+ i,q,u=hp.alm2map_spin([0*almR,almR,0*almR],nside,spin,3*nside-1)
1314
+ i2,q2,u2=hp.alm2map_spin([0*almI,almI,0*almI],nside,spin,3*nside-1)
1315
+
1316
+ tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]+1J*hp.reorder(i2,r2n=True)[tmp[idx,0]]
1317
+ tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]+1J*hp.reorder(q2,r2n=True)[tmp[idx,0]]
1211
1318
 
1212
- if l_kernel == 5:
1213
- pw = 0.5
1214
- pw2 = 0.5
1215
- threshold = 2e-4
1319
+
1320
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"% (self.TEMPLATE_PATH,
1321
+ self.TMPFILE_VERSION,
1322
+ self.KERNELSZ**2,
1323
+ self.NORIENT,
1324
+ nside,
1325
+ spin
1326
+ ),
1327
+ idxEB
1328
+ )
1329
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"% (self.TEMPLATE_PATH,
1330
+ self.TMPFILE_VERSION,
1331
+ self.KERNELSZ**2,
1332
+ self.NORIENT,
1333
+ nside,
1334
+ spin,
1335
+ ),
1336
+ tmpEB
1337
+ )
1338
+
1339
+ tmp = self.read_index(
1340
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN0.fst"
1341
+ % (
1342
+ self.TEMPLATE_PATH,
1343
+ TMPFILE_VERSION,
1344
+ l_kernel**2,
1345
+ self.NORIENT,
1346
+ nside
1347
+ )
1348
+ )
1349
+
1350
+ tmpw = self.read_index("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN0.fst"% (
1351
+ self.TEMPLATE_PATH,
1352
+ self.TMPFILE_VERSION,
1353
+ self.KERNELSZ**2,
1354
+ self.NORIENT,
1355
+ nside,
1356
+ )
1357
+ )
1358
+ for k in range(12*nside**2):
1359
+ if k%(nside**2)==0:
1360
+ print('Init index 2/2 spin=%d Please wait %d done against %d nside=%d kernel=%d'%(spin,k//(nside**2),
1361
+ 12,
1362
+ nside,
1363
+ self.KERNELSZ))
1364
+ idx=np.where(tmp[:,1]==k)[0]
1365
+
1366
+ im=np.zeros([12*nside**2])
1367
+ im[tmp[idx,0]]=tmpw[idx]
1368
+ almR=hp.map2alm(hp.reorder(im,n2r=True))
1369
+
1370
+ i,q,u=hp.alm2map_spin([almR,almR*0,0*almR],nside,spin,3*nside-1)
1371
+
1372
+ tmpEB[idx]=hp.reorder(i,r2n=True)[tmp[idx,0]]
1373
+ tmpEB[idx+tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]
1374
+
1375
+ i,q,u=hp.alm2map_spin([0*almR,almR,0*almR],nside,spin,3*nside-1)
1376
+
1377
+ tmpEB[idx+2*tmp.shape[0]]=hp.reorder(i,r2n=True)[tmp[idx,0]]
1378
+ tmpEB[idx+3*tmp.shape[0]]=hp.reorder(q,r2n=True)[tmp[idx,0]]
1216
1379
 
1217
- elif l_kernel == 3:
1218
- pw = 1.0 / np.sqrt(2)
1219
- pw2 = 1.0
1220
- threshold = 1e-3
1380
+
1381
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"% (self.TEMPLATE_PATH,
1382
+ self.TMPFILE_VERSION,
1383
+ self.KERNELSZ**2,
1384
+ self.NORIENT,
1385
+ nside,
1386
+ spin
1387
+ ),
1388
+ idxEB
1389
+ )
1390
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"% (self.TEMPLATE_PATH,
1391
+ self.TMPFILE_VERSION,
1392
+ self.KERNELSZ**2,
1393
+ self.NORIENT,
1394
+ nside,
1395
+ spin,
1396
+ ),
1397
+ tmpEB
1398
+ )
1399
+ else:
1221
1400
 
1222
- elif l_kernel == 7:
1223
- pw = 0.5
1224
- pw2 = 0.25
1225
- threshold = 4e-5
1401
+ if l_kernel == 5:
1402
+ pw = 0.5
1403
+ pw2 = 0.5
1404
+ threshold = 2e-4
1226
1405
 
1227
- if cell_ids is not None:
1228
- if not isinstance(cell_ids, np.ndarray):
1229
- cell_ids = self.backend.to_numpy(cell_ids)
1230
- th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1231
- x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1406
+ elif l_kernel == 3:
1407
+ pw = 1.0 / np.sqrt(2)
1408
+ pw2 = 1.0
1409
+ threshold = 1e-3
1232
1410
 
1233
- t, p = hp.pix2ang(nside, cell_ids, nest=True)
1234
- phi = [p[k] / np.pi * 180 for k in range(ncell)]
1235
- thi = [t[k] / np.pi * 180 for k in range(ncell)]
1411
+ elif l_kernel == 7:
1412
+ pw = 0.5
1413
+ pw2 = 0.25
1414
+ threshold = 4e-5
1236
1415
 
1237
- indice2 = np.zeros([ncell * 64, 2], dtype="int")
1238
- indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1239
- wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1240
- wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1416
+ if cell_ids is not None and nside>512:
1417
+ if not isinstance(cell_ids, np.ndarray):
1418
+ cell_ids = self.backend.to_numpy(cell_ids)
1419
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1420
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1241
1421
 
1242
- else:
1422
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1423
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1424
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1243
1425
 
1244
- th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1245
- x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1426
+ indice2 = np.zeros([ncell * 64, 2], dtype="int")
1427
+ indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1428
+ wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1429
+ wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1246
1430
 
1247
- t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1248
- phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1249
- thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1431
+ else:
1250
1432
 
1251
- indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1252
- indice = np.zeros(
1253
- [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1254
- )
1255
- wav = np.zeros(
1256
- [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1257
- )
1258
- wwav = np.zeros(
1259
- [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1260
- )
1261
- iv = 0
1262
- iv2 = 0
1433
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1434
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1263
1435
 
1264
- for iii in range(ncell):
1265
- if cell_ids is None:
1266
- if iii % (nside * nside) == nside * nside - 1:
1267
- if not self.silent:
1268
- print(
1269
- "Pre-compute nside=%6d %.2f%%"
1270
- % (nside, 100 * iii / (12 * nside * nside))
1271
- )
1436
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1437
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1438
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1272
1439
 
1273
- if cell_ids is not None:
1274
- hidx = np.where(
1275
- (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1276
- < (2 * np.pi / nside) ** 2
1277
- )[0]
1278
- else:
1279
- hidx = hp.query_disc(
1280
- nside,
1281
- [x[iii], y[iii], z[iii]],
1282
- 2 * np.pi / nside,
1283
- nest=True,
1440
+ indice2 = np.zeros([12 * nside * nside * 64, 2],
1441
+ dtype="int")
1442
+
1443
+ indice = np.zeros(
1444
+ [12 * nside * nside * 64 * self.NORIENT, 2],
1445
+ dtype="int"
1284
1446
  )
1447
+ wav = np.zeros(
1448
+ [12 * nside * nside * 64 * self.NORIENT],
1449
+ dtype="complex"
1450
+ )
1451
+ wwav = np.zeros(
1452
+ [12 * nside * nside * 64 * self.NORIENT],
1453
+ dtype="float"
1454
+ )
1455
+ iv = 0
1456
+ iv2 = 0
1457
+
1458
+ for iii in range(ncell):
1459
+ if cell_ids is None:
1460
+ if iii % (nside * nside) == nside * nside - 1:
1461
+ if not self.silent:
1462
+ print(
1463
+ "Pre-compute nside=%6d %.2f%%"
1464
+ % (nside, 100 * iii / (12 * nside * nside))
1465
+ )
1466
+
1467
+ if cell_ids is not None:
1468
+ hidx = np.where(
1469
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1470
+ < (2 * np.pi / nside) ** 2
1471
+ )[0]
1472
+ else:
1473
+ hidx = hp.query_disc(
1474
+ nside,
1475
+ [x[iii], y[iii], z[iii]],
1476
+ 2 * np.pi / nside,
1477
+ nest=True,
1478
+ )
1285
1479
 
1286
- R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1287
-
1288
- t2, p2 = R(th[hidx], ph[hidx])
1289
-
1290
- vec2 = hp.ang2vec(t2, p2)
1480
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1291
1481
 
1292
- x2 = vec2[:, 0]
1293
- y2 = vec2[:, 1]
1294
- z2 = vec2[:, 2]
1482
+ t2, p2 = R(th[hidx], ph[hidx])
1295
1483
 
1296
- ww = np.exp(
1297
- -pw2
1298
- * ((nside) ** 2)
1299
- * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1300
- )
1301
- idx = np.where((ww**2) > threshold)[0]
1302
- nval2 = len(idx)
1303
- indice2[iv2 : iv2 + nval2, 1] = iii
1304
- indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1305
- wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1306
- iv2 += nval2
1484
+ vec2 = hp.ang2vec(t2, p2)
1307
1485
 
1308
- for l_rotation in range(self.NORIENT):
1486
+ x2 = vec2[:, 0]
1487
+ y2 = vec2[:, 1]
1488
+ z2 = vec2[:, 2]
1309
1489
 
1310
- angle = (
1311
- l_rotation / 4.0 * np.pi
1312
- - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1313
- - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1490
+ ww = np.exp(
1491
+ -pw2
1492
+ * ((nside) ** 2)
1493
+ * ((x2) ** 2 + (y2) ** 2 + (z2 - 1.0) ** 2)
1314
1494
  )
1495
+ idx = np.where((ww**2) > threshold)[0]
1496
+ nval2 = len(idx)
1497
+ indice2[iv2 : iv2 + nval2, 1] = iii
1498
+ indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1499
+ wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1500
+ iv2 += nval2
1501
+
1502
+ for l_rotation in range(self.NORIENT):
1503
+
1504
+ angle = (
1505
+ l_rotation / 4.0 * np.pi
1506
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1507
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1508
+ )
1315
1509
 
1316
- # posi=2*(0.5-(z[hidx]<0))
1317
-
1318
- axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1319
- wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1320
- wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1321
-
1322
- vnorm = wresr * wresr + wresi * wresi
1323
- idx = np.where(vnorm > threshold)[0]
1324
-
1325
- nval = len(idx)
1326
- indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1327
- indice[iv : iv + nval, 0] = hidx[idx]
1328
- # print([hidx[k] for k in idx])
1329
- # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1330
- normr = np.mean(wresr[idx])
1331
- normi = np.mean(wresi[idx])
1332
-
1333
- val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1334
- r = abs(val).sum()
1510
+ # posi=2*(0.5-(z[hidx]<0))
1335
1511
 
1336
- if r > 0:
1337
- val = val / r
1512
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1513
+ wresr = ww * np.cos(pw * axes * (nside) * np.pi)
1514
+ wresi = ww * np.sin(pw * axes * (nside) * np.pi)
1338
1515
 
1339
- wav[iv : iv + nval] = val
1340
- iv += nval
1516
+ vnorm = wresr * wresr + wresi * wresi
1517
+ idx = np.where(vnorm > threshold)[0]
1341
1518
 
1342
- indice = indice[:iv, :]
1343
- wav = wav[:iv]
1344
- indice2 = indice2[:iv2, :]
1345
- wwav = wwav[:iv2]
1346
- if not self.silent:
1347
- print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1348
- """
1349
- # OLD VERSION OLD VERSION OLD VERSION (3.0)
1350
- if self.KERNELSZ*self.KERNELSZ>12*nside*nside:
1351
- l_kernel=3
1352
-
1353
- aa=np.cos(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1354
- bb=np.sin(np.arange(self.NORIENT)/self.NORIENT*np.pi).reshape(1,self.NORIENT)
1355
- x,y,z=hp.pix2vec(nside,np.arange(12*nside*nside),nest=True)
1356
- to,po=hp.pix2ang(nside,np.arange(12*nside*nside),nest=True)
1357
-
1358
- wav=np.zeros([12*nside*nside,l_kernel**2,self.NORIENT],dtype='complex')
1359
- wwav=np.zeros([12*nside*nside,l_kernel**2])
1360
- iwav=np.zeros([12*nside*nside,l_kernel**2],dtype='int')
1361
-
1362
- scale=4
1363
- if nside>scale*2:
1364
- th,ph=hp.pix2ang(nside//scale,np.arange(12*(nside//scale)**2),nest=True)
1365
- else:
1366
- lidx=np.arange(12*nside*nside)
1519
+ nval = len(idx)
1520
+ indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1521
+ indice[iv : iv + nval, 0] = hidx[idx]
1522
+ # print([hidx[k] for k in idx])
1523
+ # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1524
+ normr = np.mean(wresr[idx])
1525
+ normi = np.mean(wresi[idx])
1367
1526
 
1368
- pw=np.pi/4.0
1369
- pw2=1/2
1370
- amp=1.0
1527
+ val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1528
+ r = abs(val).sum()
1371
1529
 
1372
- if l_kernel==5:
1373
- pw=np.pi/4.0
1374
- pw2=1/2.25
1375
- amp=1.0/9.2038
1530
+ if r > 0:
1531
+ val = val / r
1376
1532
 
1377
- elif l_kernel==3:
1378
- pw=1.0/np.sqrt(2)
1379
- pw2=1.0
1380
- amp=1/8.45
1533
+ wav[iv : iv + nval] = val
1534
+ iv += nval
1381
1535
 
1382
- elif l_kernel==7:
1383
- pw=np.pi/4.0
1384
- pw2=1.0/3.0
1536
+ indice = indice[:iv, :]
1537
+ wav = wav[:iv]
1538
+ indice2 = indice2[:iv2, :]
1539
+ wwav = wwav[:iv2]
1540
+ if not self.silent:
1541
+ print("Kernel Size ", iv / (self.NORIENT * 12 * nside * nside))
1385
1542
 
1386
- for k in range(12*nside*nside):
1387
- if k%(nside*nside)==0:
1543
+ if cell_ids is None:
1388
1544
  if not self.silent:
1389
- print('Pre-compute nside=%6d %.2f%%'%(nside,100*k/(12*nside*nside)))
1390
- if nside>scale*2:
1391
- lidx=hp.get_all_neighbours(nside//scale,th[k//(scale*scale)],ph[k//(scale*scale)],nest=True)
1392
- lidx=np.concatenate([lidx,np.array([(k//(scale*scale))])],0)
1393
- lidx=np.repeat(lidx*(scale*scale),(scale*scale))+ \
1394
- np.tile(np.arange((scale*scale)),lidx.shape[0])
1395
-
1396
- delta=(x[lidx]-x[k])**2+(y[lidx]-y[k])**2+(z[lidx]-z[k])**2
1397
- pidx=np.where(delta<(10)/(nside**2))[0]
1398
- if len(pidx)<l_kernel**2:
1399
- pidx=np.arange(delta.shape[0])
1400
-
1401
- w=np.exp(-pw2*delta[pidx]*(nside**2))
1402
- pidx=pidx[np.argsort(-w)[0:l_kernel**2]]
1403
- pidx=pidx[np.argsort(lidx[pidx])]
1404
-
1405
- w=np.exp(-pw2*delta[pidx]*(nside**2))
1406
- iwav[k]=lidx[pidx]
1407
- wwav[k]=w
1408
- rot=[po[k]/np.pi*180.0,90+(-to[k])/np.pi*180.0]
1409
- r=hp.Rotator(rot=rot)
1410
- ty,tx=r(to[iwav[k]],po[iwav[k]])
1411
- ty=ty-np.pi/2
1412
-
1413
- xx=np.expand_dims(pw*nside*np.pi*tx/np.cos(ty),-1)
1414
- yy=np.expand_dims(pw*nside*np.pi*ty,-1)
1415
-
1416
- wav[k,:,:]=(np.cos(xx*aa+yy*bb)+complex(0.0,1.0)*np.sin(xx*aa+yy*bb))*np.expand_dims(w,-1)
1417
-
1418
- wav=wav-np.expand_dims(np.mean(wav,1),1)
1419
- wav=amp*wav/np.expand_dims(np.std(wav,1),1)
1420
- wwav=wwav/np.expand_dims(np.sum(wwav,1),1)
1421
-
1422
- nk=l_kernel*l_kernel
1423
- indice=np.zeros([12*nside*nside*nk*self.NORIENT,2],dtype='int')
1424
- lidx=np.arange(self.NORIENT)
1425
- for i in range(12*nside*nside):
1426
- indice[i*nk*self.NORIENT:i*nk*self.NORIENT+nk*self.NORIENT,0]=i*self.NORIENT+np.repeat(lidx,nk)
1427
- indice[i*nk*self.NORIENT:i*nk*self.NORIENT+nk*self.NORIENT,1]=np.tile(iwav[i],self.NORIENT)
1428
-
1429
- indice2=np.zeros([12*nside*nside*nk,2],dtype='int')
1430
- for i in range(12*nside*nside):
1431
- indice2[i*nk:i*nk+nk,0]=i
1432
- indice2[i*nk:i*nk+nk,1]=iwav[i]
1433
-
1434
- w=np.zeros([12*nside*nside,wav.shape[2],wav.shape[1]],dtype='complex')
1435
- for i in range(wav.shape[1]):
1436
- for j in range(wav.shape[2]):
1437
- w[:,j,i]=wav[:,i,j]
1438
- wav=w.flatten()
1439
- wwav=wwav.flatten()
1440
- """
1441
- if cell_ids is None:
1442
- if not self.silent:
1443
- print(
1444
- "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1445
- % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1545
+ print(
1546
+ "Write %s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1547
+ % ( self.TEMPLATE_PATH,
1548
+ TMPFILE_VERSION, self.KERNELSZ**2,
1549
+ self.NORIENT,
1550
+ nside,
1551
+ spin)
1552
+ )
1553
+ self.save_index("%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1554
+ % (
1555
+ self.TEMPLATE_PATH,
1556
+ TMPFILE_VERSION,
1557
+ self.KERNELSZ**2,
1558
+ self.NORIENT,
1559
+ nside,
1560
+ spin,
1561
+ ),
1562
+ indice
1563
+ )
1564
+ self.save_index(
1565
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"
1566
+ % (
1567
+ self.TEMPLATE_PATH,
1568
+ TMPFILE_VERSION,
1569
+ self.KERNELSZ**2,
1570
+ self.NORIENT,
1571
+ nside,
1572
+ spin,
1573
+ ),
1574
+ wav,
1575
+ )
1576
+ self.save_index(
1577
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"
1578
+ % (
1579
+ self.TEMPLATE_PATH,
1580
+ TMPFILE_VERSION,
1581
+ self.KERNELSZ**2,
1582
+ self.NORIENT,
1583
+ nside,
1584
+ spin,
1585
+ ),
1586
+ indice2,
1587
+ )
1588
+ self.save_index(
1589
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"
1590
+ % (
1591
+ self.TEMPLATE_PATH,
1592
+ TMPFILE_VERSION,
1593
+ self.KERNELSZ**2,
1594
+ self.NORIENT,
1595
+ nside,
1596
+ spin,
1597
+ ),
1598
+ wwav,
1446
1599
  )
1447
- np.save(
1448
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1449
- % (
1450
- self.TEMPLATE_PATH,
1451
- TMPFILE_VERSION,
1452
- self.KERNELSZ**2,
1453
- self.NORIENT,
1454
- nside,
1455
- ),
1456
- indice,
1457
- )
1458
- np.save(
1459
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1460
- % (
1461
- self.TEMPLATE_PATH,
1462
- TMPFILE_VERSION,
1463
- self.KERNELSZ**2,
1464
- self.NORIENT,
1465
- nside,
1466
- ),
1467
- wav,
1468
- )
1469
- np.save(
1470
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1471
- % (
1472
- self.TEMPLATE_PATH,
1473
- TMPFILE_VERSION,
1474
- self.KERNELSZ**2,
1475
- self.NORIENT,
1476
- nside,
1477
- ),
1478
- indice2,
1479
- )
1480
- np.save(
1481
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1482
- % (
1483
- self.TEMPLATE_PATH,
1484
- TMPFILE_VERSION,
1485
- self.KERNELSZ**2,
1486
- self.NORIENT,
1487
- nside,
1488
- ),
1489
- wwav,
1490
- )
1491
1600
  if self.use_2D:
1492
1601
  if l_kernel**2 == 9:
1493
1602
  if self.rank == 0:
@@ -1504,64 +1613,107 @@ class FoCUS:
1504
1613
  )
1505
1614
  return None
1506
1615
 
1507
- if cell_ids is None:
1616
+ if cell_ids is None or nside<=512:
1508
1617
  self.barrier()
1509
1618
  if self.use_2D:
1510
- tmp = np.load(
1511
- "%s/W%d_%s_%d_IDX.npy"
1512
- % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1619
+ tmp = self.read_index(
1620
+ "%s/W%d_%s_%d_IDX-SPIN%d.fst"
1621
+ % (
1622
+ self.TEMPLATE_PATH,
1623
+ l_kernel**2,
1624
+ TMPFILE_VERSION,
1625
+ nside,
1626
+ spin)
1513
1627
  )
1514
1628
  else:
1515
- tmp = np.load(
1516
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1629
+ tmp = self.read_index(
1630
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX-SPIN%d.fst"
1517
1631
  % (
1518
1632
  self.TEMPLATE_PATH,
1519
1633
  TMPFILE_VERSION,
1520
1634
  self.KERNELSZ**2,
1521
1635
  self.NORIENT,
1522
1636
  nside,
1637
+ spin,
1523
1638
  )
1524
1639
  )
1525
- tmp2 = np.load(
1526
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1640
+ tmp2 = self.read_index(
1641
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2-SPIN%d.fst"
1527
1642
  % (
1528
1643
  self.TEMPLATE_PATH,
1529
1644
  TMPFILE_VERSION,
1530
1645
  self.KERNELSZ**2,
1531
1646
  self.NORIENT,
1532
1647
  nside,
1648
+ spin,
1533
1649
  )
1534
1650
  )
1535
- wr = np.load(
1536
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1651
+ wr = self.read_index(
1652
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"
1537
1653
  % (
1538
1654
  self.TEMPLATE_PATH,
1539
1655
  TMPFILE_VERSION,
1540
1656
  self.KERNELSZ**2,
1541
1657
  self.NORIENT,
1542
1658
  nside,
1659
+ spin,
1543
1660
  )
1544
1661
  ).real
1545
- wi = np.load(
1546
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1662
+ wi = self.read_index(
1663
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE-SPIN%d.fst"
1547
1664
  % (
1548
1665
  self.TEMPLATE_PATH,
1549
1666
  TMPFILE_VERSION,
1550
1667
  self.KERNELSZ**2,
1551
1668
  self.NORIENT,
1552
1669
  nside,
1670
+ spin,
1553
1671
  )
1554
1672
  ).imag
1555
- ws = self.slope * np.load(
1556
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1673
+ ws = self.slope * self.read_index(
1674
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO-SPIN%d.fst"
1557
1675
  % (
1558
1676
  self.TEMPLATE_PATH,
1559
1677
  TMPFILE_VERSION,
1560
1678
  self.KERNELSZ**2,
1561
1679
  self.NORIENT,
1562
1680
  nside,
1681
+ spin,
1563
1682
  )
1564
1683
  )
1684
+
1685
+ if cell_ids is not None:
1686
+ idx_map=-np.ones([12*nside**2],dtype='int32')
1687
+ lcell_ids=cell_ids
1688
+
1689
+ try:
1690
+ idx_map[lcell_ids]=np.arange(lcell_ids.shape[0],dtype='int32')
1691
+ except:
1692
+ lcell_ids=self.to_numpy(cell_ids)
1693
+ idx_map[lcell_ids]=np.arange(lcell_ids.shape[0],dtype='int32')
1694
+
1695
+ lidx=np.where(idx_map[tmp[:,1]%(12*nside**2)]!=-1)[0]
1696
+ orientation=tmp[lidx,1]//(12*nside**2)
1697
+ tmp=tmp[lidx]
1698
+ wr=wr[lidx]
1699
+ wi=wi[lidx]
1700
+ tmp=idx_map[tmp%(12*nside**2)]
1701
+ lidx=np.where(tmp[:,0]==-1)[0]
1702
+ wr[lidx]=0.0
1703
+ wi[lidx]=0.0
1704
+ tmp[lidx,0]=0
1705
+ tmp[:,1]+=orientation*lcell_ids.shape[0]
1706
+
1707
+ idx_map=-np.ones([12*nside**2],dtype='int32')
1708
+ idx_map[lcell_ids]=np.arange(cell_ids.shape[0],dtype='int32')
1709
+ lidx=np.where(idx_map[tmp2[:,1]]!=-1)[0]
1710
+ tmp2=tmp2[lidx]
1711
+ ws=ws[lidx]
1712
+ tmp2=idx_map[tmp2]
1713
+ lidx=np.where(tmp2[:,0]==-1)[0]
1714
+ ws[lidx]=0.0
1715
+ tmp2[lidx,0]=0
1716
+
1565
1717
  else:
1566
1718
  tmp = indice
1567
1719
  tmp2 = indice2
@@ -1569,21 +1721,39 @@ class FoCUS:
1569
1721
  wi = wav.imag
1570
1722
  ws = self.slope * wwav
1571
1723
 
1572
- wr = self.backend.bk_SparseTensor(
1573
- self.backend.bk_constant(tmp),
1574
- self.backend.bk_constant(self.backend.bk_cast(wr)),
1575
- dense_shape=[ncell, self.NORIENT * ncell],
1576
- )
1577
- wi = self.backend.bk_SparseTensor(
1578
- self.backend.bk_constant(tmp),
1579
- self.backend.bk_constant(self.backend.bk_cast(wi)),
1580
- dense_shape=[ncell, self.NORIENT * ncell],
1581
- )
1582
- ws = self.backend.bk_SparseTensor(
1583
- self.backend.bk_constant(tmp2),
1584
- self.backend.bk_constant(self.backend.bk_cast(ws)),
1585
- dense_shape=[ncell, ncell],
1586
- )
1724
+
1725
+ if spin==0:
1726
+ wr = self.backend.bk_SparseTensor(
1727
+ self.backend.bk_constant(tmp),
1728
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1729
+ dense_shape=[ncell, self.NORIENT * ncell],
1730
+ )
1731
+ wi = self.backend.bk_SparseTensor(
1732
+ self.backend.bk_constant(tmp),
1733
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1734
+ dense_shape=[ncell, self.NORIENT * ncell],
1735
+ )
1736
+ ws = self.backend.bk_SparseTensor(
1737
+ self.backend.bk_constant(tmp2),
1738
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1739
+ dense_shape=[ncell, ncell],
1740
+ )
1741
+ else:
1742
+ wr = self.backend.bk_SparseTensor(
1743
+ self.backend.bk_constant(tmp),
1744
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1745
+ dense_shape=[2*ncell, 2*self.NORIENT * ncell],
1746
+ )
1747
+ wi = self.backend.bk_SparseTensor(
1748
+ self.backend.bk_constant(tmp),
1749
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1750
+ dense_shape=[2*ncell, 2*self.NORIENT * ncell],
1751
+ )
1752
+ ws = self.backend.bk_SparseTensor(
1753
+ self.backend.bk_constant(tmp2),
1754
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1755
+ dense_shape=[2*ncell, 2*ncell],
1756
+ )
1587
1757
 
1588
1758
  if kernel == -1:
1589
1759
  self.Idx_Neighbours[nside] = tmp
@@ -1592,7 +1762,7 @@ class FoCUS:
1592
1762
  if kernel != -1:
1593
1763
  return tmp
1594
1764
 
1595
- return wr, wi, ws, tmp
1765
+ return wr, wi, ws,tmp
1596
1766
 
1597
1767
 
1598
1768
  # ---------------------------------------------−---------
@@ -1611,8 +1781,8 @@ class FoCUS:
1611
1781
  try:
1612
1782
 
1613
1783
  if cell_ids is not None:
1614
- tmp = np.load(
1615
- "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1784
+ tmp = self.read_index(
1785
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.fst" # can not work
1616
1786
  % (
1617
1787
  self.TEMPLATE_PATH,
1618
1788
  TMPFILE_VERSION,
@@ -1623,8 +1793,8 @@ class FoCUS:
1623
1793
  )
1624
1794
 
1625
1795
  else:
1626
- tmp = np.load(
1627
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1796
+ tmp = self.read_index(
1797
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1628
1798
  % (
1629
1799
  self.TEMPLATE_PATH,
1630
1800
  TMPFILE_VERSION,
@@ -1758,11 +1928,11 @@ class FoCUS:
1758
1928
  if cell_ids is None:
1759
1929
  if not self.silent:
1760
1930
  print(
1761
- "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1931
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1762
1932
  % (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside)
1763
1933
  )
1764
- np.save(
1765
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1934
+ self.save_index(
1935
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1766
1936
  % (
1767
1937
  self.TEMPLATE_PATH,
1768
1938
  TMPFILE_VERSION,
@@ -1772,8 +1942,8 @@ class FoCUS:
1772
1942
  ),
1773
1943
  indice,
1774
1944
  )
1775
- np.save(
1776
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1945
+ self.save_index(
1946
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.fst"
1777
1947
  % (
1778
1948
  self.TEMPLATE_PATH,
1779
1949
  TMPFILE_VERSION,
@@ -1787,13 +1957,13 @@ class FoCUS:
1787
1957
  if cell_ids is None:
1788
1958
  self.barrier()
1789
1959
  if self.use_2D:
1790
- tmp = np.load(
1791
- "%s/W%d_%s_%d_IDX.npy"
1960
+ tmp = self.read_index(
1961
+ "%s/W%d_%s_%d_IDX.fst"
1792
1962
  % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1793
1963
  )
1794
1964
  else:
1795
- tmp = np.load(
1796
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1965
+ tmp = self.read_index(
1966
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.fst"
1797
1967
  % (
1798
1968
  self.TEMPLATE_PATH,
1799
1969
  TMPFILE_VERSION,
@@ -1802,8 +1972,8 @@ class FoCUS:
1802
1972
  nside,
1803
1973
  )
1804
1974
  )
1805
- wav = np.load(
1806
- "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1975
+ wav = self.read_index(
1976
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.fst"
1807
1977
  % (
1808
1978
  self.TEMPLATE_PATH,
1809
1979
  TMPFILE_VERSION,
@@ -1840,10 +2010,10 @@ class FoCUS:
1840
2010
  return self.backend.bk_transpose(x, thelist)
1841
2011
 
1842
2012
  # ---------------------------------------------−---------
1843
- # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
2013
+ # Mean using mask x [n_b,....,Npix], mask[Nmask,Npix] to [n_b,Nmask,....]
1844
2014
  # if use_2D
1845
- # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
1846
- def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
2015
+ # Mean using mask x [n_b,....,N_1,N_2], mask[Nmask,N_1,N_2] to [n_b,Nmask,....]
2016
+ def masked_mean(self, x, mask, rank=0, calc_var=False):
1847
2017
 
1848
2018
  # ==========================================================================
1849
2019
  # in input data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]]
@@ -1855,7 +2025,7 @@ class FoCUS:
1855
2025
  shape = list(x.shape)
1856
2026
 
1857
2027
  if not self.use_2D and not self.use_1D:
1858
- nside = int(np.sqrt(x.shape[axis] // 12))
2028
+ nside = int(np.sqrt(x.shape[-1] // 12))
1859
2029
 
1860
2030
  l_mask = mask
1861
2031
  if self.mask_norm:
@@ -1949,16 +2119,24 @@ class FoCUS:
1949
2119
  l_x = self.backend.bk_reshape(
1950
2120
  l_x[:, :, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1, :], oshape
1951
2121
  )
1952
- else:
2122
+ else:
1953
2123
  ichannel = 1
1954
- for i in range(len(shape) - 1):
1955
- ichannel *= shape[i]
2124
+ if len(shape)>1:
2125
+ ichannel = shape[0]
2126
+
2127
+ ochannel = 1
2128
+ for i in range(1,len(shape)-1):
2129
+ ochannel *= shape[i]
1956
2130
 
1957
- l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[-1]])
2131
+ l_x = self.backend.bk_reshape(x, [ichannel,1,ochannel,shape[-1]])
1958
2132
 
1959
- # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,1,...,NORIENT[,NORIENT],X[,Y]]
2133
+ # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,...,1,NORIENT[,NORIENT],X[,Y]]
1960
2134
  # mask=[Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1961
- l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask, 0), 0)
2135
+
2136
+ if self.use_2D:
2137
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-3)
2138
+ else:
2139
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask,0),-2)
1962
2140
 
1963
2141
  if l_x.dtype == self.all_cbk_type:
1964
2142
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
@@ -1989,6 +2167,8 @@ class FoCUS:
1989
2167
 
1990
2168
  if len(x.shape[axis:-2]) > 0:
1991
2169
  oshape = oshape + list(x.shape[axis:-2])
2170
+ else:
2171
+ oshape = oshape + [1]
1992
2172
 
1993
2173
  if calc_var:
1994
2174
  if self.backend.bk_is_complex(vtmp):
@@ -2018,7 +2198,7 @@ class FoCUS:
2018
2198
  elif self.use_1D:
2019
2199
  mtmp = l_mask
2020
2200
  vtmp = l_x
2021
- v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1)
2201
+ v1 = self.backend.bk_reduce_sum(l_mask[1,:,...,:] * vtmp, axis=-1)
2022
2202
  v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
2023
2203
  vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
2024
2204
 
@@ -2027,6 +2207,8 @@ class FoCUS:
2027
2207
  oshape = [x.shape[0]] + [mask.shape[0]]
2028
2208
  if len(x.shape) > 1:
2029
2209
  oshape = oshape + list(x.shape[1:-1])
2210
+ else:
2211
+ oshape = oshape + [1]
2030
2212
 
2031
2213
  if calc_var:
2032
2214
  if self.backend.bk_is_complex(vtmp):
@@ -2060,13 +2242,16 @@ class FoCUS:
2060
2242
  res = v1 / vh
2061
2243
 
2062
2244
  oshape = []
2063
- if axis > 0:
2245
+ if len(shape) > 1:
2064
2246
  oshape = [x.shape[0]]
2065
2247
  else:
2066
2248
  oshape = [1]
2249
+
2067
2250
  oshape = oshape + [mask.shape[0]]
2068
- if axis > 1:
2069
- oshape = oshape + list(x.shape[1:-1])
2251
+ if len(shape) > 2:
2252
+ oshape = oshape + shape[1:-1]
2253
+ else:
2254
+ oshape = oshape + [1]
2070
2255
 
2071
2256
  if calc_var:
2072
2257
  if self.backend.bk_is_complex(l_x):
@@ -2220,7 +2405,7 @@ class FoCUS:
2220
2405
  return self.backend.bk_reduce_sum(r)
2221
2406
 
2222
2407
  # ---------------------------------------------−---------
2223
- def convol(self, in_image, axis=0, cell_ids=None, nside=None):
2408
+ def convol(self, in_image, axis=0, cell_ids=None, nside=None, spin=0):
2224
2409
 
2225
2410
  image = self.backend.bk_cast(in_image)
2226
2411
 
@@ -2283,77 +2468,22 @@ class FoCUS:
2283
2468
 
2284
2469
  else:
2285
2470
  ishape = list(image.shape)
2286
- """
2287
- if cell_ids is not None:
2288
- if cell_ids.shape[0] not in self.padding_conv:
2289
- print(image.shape,cell_ids.shape)
2290
- import healpix_convolution as hc
2291
- from xdggs.healpix import HealpixInfo
2292
-
2293
- res = self.backend.bk_zeros(
2294
- ishape[0:-1] + [self.NORIENT]+ishape[-1], dtype=self.backend.all_cbk_type
2295
- )
2296
-
2297
- grid_info = HealpixInfo(
2298
- level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2299
- )
2300
-
2301
- for k in range(self.NORIENT):
2302
- kernelR, kernelI = hc.kernels.wavelet_kernel(
2303
- cell_ids, grid_info=grid_info, orientation=k, is_torch=True
2304
- )
2305
- self.kernelR_conv[(cell_ids.shape[0], k)] = kernelR.to(
2306
- self.backend.all_bk_type
2307
- ).to(image.device)
2308
- self.kernelI_conv[(cell_ids.shape[0], k)] = kernelI.to(
2309
- self.backend.all_bk_type
2310
- ).to(image.device)
2311
- self.padding_conv[(cell_ids.shape[0], k)] = hc.pad(
2312
- cell_ids,
2313
- grid_info=grid_info,
2314
- ring=5 // 2, # wavelet kernel_size=5 is hard coded
2315
- mode="mean",
2316
- constant_value=0,
2317
- )
2318
-
2319
- for k in range(self.NORIENT):
2320
-
2321
- kernelR = self.kernelR_conv[(cell_ids.shape[0], k)]
2322
- kernelI = self.kernelI_conv[(cell_ids.shape[0], k)]
2323
- padding = self.padding_conv[(cell_ids.shape[0], k)]
2324
- if len(ishape) == 2:
2325
- for l in range(ishape[0]):
2326
- padded_data = padding.apply(image[l], is_torch=True)
2327
- res[l, :, k] = kernelR.matmul(
2328
- padded_data
2329
- ) + 1j * kernelI.matmul(padded_data)
2330
- else:
2331
- for l in range(ishape[0]):
2332
- for k2 in range(ishape[2]):
2333
- padded_data = padding.apply(
2334
- image[l, :, k2], is_torch=True
2335
- )
2336
- res[l, :, k2, k] = kernelR.matmul(
2337
- padded_data
2338
- ) + 1j * kernelI.matmul(padded_data)
2339
- return res
2340
- """
2341
2471
  if nside is None:
2342
2472
  nside = int(np.sqrt(image.shape[-1] // 12))
2343
2473
 
2344
- if self.Idx_Neighbours[nside] is None:
2474
+ if (spin,nside) not in self.Idx_Neighbours:
2345
2475
  if self.InitWave is None:
2346
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2476
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2347
2477
  else:
2348
- wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2478
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2349
2479
 
2350
- self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2351
- self.ww_Real[nside] = wr
2352
- self.ww_Imag[nside] = wi
2353
- self.w_smooth[nside] = ws
2480
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2481
+ self.ww_Real[(spin,nside)] = wr
2482
+ self.ww_Imag[(spin,nside)] = wi
2483
+ self.w_smooth[(spin,nside)] = ws
2354
2484
 
2355
- l_ww_real = self.ww_Real[nside]
2356
- l_ww_imag = self.ww_Imag[nside]
2485
+ l_ww_real = self.ww_Real[(spin,nside)]
2486
+ l_ww_imag = self.ww_Imag[(spin,nside)]
2357
2487
 
2358
2488
  # always convolve the last dimension
2359
2489
 
@@ -2361,10 +2491,14 @@ class FoCUS:
2361
2491
  if len(ishape) > 1:
2362
2492
  for k in range(len(ishape) - 1):
2363
2493
  ndata = ndata * ishape[k]
2364
- tim = self.backend.bk_reshape(
2365
- self.backend.bk_cast(image), [ndata, ishape[-1]]
2366
- )
2367
-
2494
+ if spin>0:
2495
+ tim = self.backend.bk_reshape(
2496
+ self.backend.bk_cast(image), [ndata//2,2*ishape[-1]]
2497
+ )
2498
+ else:
2499
+ tim = self.backend.bk_reshape(
2500
+ self.backend.bk_cast(image), [ndata, ishape[-1]]
2501
+ )
2368
2502
  if tim.dtype == self.all_cbk_type:
2369
2503
  rr1 = self.backend.bk_reshape(
2370
2504
  self.backend.bk_sparse_dense_matmul(
@@ -2405,17 +2539,26 @@ class FoCUS:
2405
2539
  [ndata, self.NORIENT, ishape[-1]],
2406
2540
  )
2407
2541
  res = self.backend.bk_complex(rr, ii)
2408
- if len(ishape) > 1:
2409
- return self.backend.bk_reshape(
2410
- res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2411
- )
2542
+
2543
+ if spin==0:
2544
+ if len(ishape) > 1:
2545
+ return self.backend.bk_reshape(
2546
+ res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2547
+ )
2548
+ else:
2549
+ return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2412
2550
  else:
2413
- return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2414
-
2551
+ if len(ishape) > 2:
2552
+ return self.backend.bk_reshape(
2553
+ res, ishape[0:-2] + [2,self.NORIENT, ishape[-1]]
2554
+ )
2555
+ else:
2556
+ return self.backend.bk_reshape(res, [2,self.NORIENT, ishape[-1]])
2557
+
2415
2558
  return res
2416
2559
 
2417
2560
  # ---------------------------------------------−---------
2418
- def smooth(self, in_image, axis=0, cell_ids=None, nside=None):
2561
+ def smooth(self, in_image, axis=0, cell_ids=None, nside=None, spin=0):
2419
2562
 
2420
2563
  image = self.backend.bk_cast(in_image)
2421
2564
 
@@ -2475,64 +2618,22 @@ class FoCUS:
2475
2618
  else:
2476
2619
 
2477
2620
  ishape = list(image.shape)
2478
- """
2479
- if cell_ids is not None:
2480
- if cell_ids.shape[0] not in self.padding_smooth:
2481
- import healpix_convolution as hc
2482
- from xdggs.healpix import HealpixInfo
2483
-
2484
- grid_info = HealpixInfo(
2485
- level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2486
- )
2487
-
2488
- kernel = hc.kernels.wavelet_smooth_kernel(
2489
- cell_ids, grid_info=grid_info, is_torch=True
2490
- )
2491
-
2492
- self.kernel_smooth[cell_ids.shape[0]] = kernel.to(
2493
- self.backend.all_bk_type
2494
- ).to(image.device)
2495
-
2496
- self.padding_smooth[cell_ids.shape[0]] = hc.pad(
2497
- cell_ids,
2498
- grid_info=grid_info,
2499
- ring=5 // 2, # wavelet kernel_size=5 is hard coded
2500
- mode="mean",
2501
- constant_value=0,
2502
- )
2503
-
2504
- kernel = self.kernel_smooth[cell_ids.shape[0]]
2505
- padding = self.padding_smooth[cell_ids.shape[0]]
2506
-
2507
- res = self.backend.bk_zeros(ishape, dtype=self.backend.all_cbk_type)
2508
-
2509
- if len(ishape) == 2:
2510
- for l in range(ishape[0]):
2511
- padded_data = padding.apply(image[l], is_torch=True)
2512
- res[l] = kernel.matmul(padded_data)
2513
- else:
2514
- for l in range(ishape[0]):
2515
- for k2 in range(ishape[2]):
2516
- padded_data = padding.apply(image[l, :, k2], is_torch=True)
2517
- res[l, :, k2] = kernel.matmul(padded_data)
2518
- return res
2519
- """
2621
+
2520
2622
  if nside is None:
2521
2623
  nside = int(np.sqrt(image.shape[-1] // 12))
2522
2624
 
2523
- if self.Idx_Neighbours[nside] is None:
2524
-
2625
+ if (spin,nside) not in self.Idx_Neighbours:
2525
2626
  if self.InitWave is None:
2526
- wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2627
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids,spin=spin)
2527
2628
  else:
2528
- wr, wi, ws, widx = self.InitWave(self, nside, cell_ids=cell_ids)
2629
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids,spin=spin)
2529
2630
 
2530
- self.Idx_Neighbours[nside] = 1
2531
- self.ww_Real[nside] = wr
2532
- self.ww_Imag[nside] = wi
2533
- self.w_smooth[nside] = ws
2534
-
2535
- l_w_smooth = self.w_smooth[nside]
2631
+ self.Idx_Neighbours[(spin,nside)] = 1 # self.backend.bk_constant(tmp)
2632
+ self.ww_Real[(spin,nside)] = wr
2633
+ self.ww_Imag[(spin,nside)] = wi
2634
+ self.w_smooth[(spin,nside)] = ws
2635
+
2636
+ l_w_smooth = self.w_smooth[(spin,nside)]
2536
2637
 
2537
2638
  odata = 1
2538
2639
  for k in range(0, len(ishape) - 1):