foscat 2025.5.0__py3-none-any.whl → 2025.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -5,7 +5,7 @@ import healpy as hp
5
5
  import numpy as np
6
6
  from scipy.interpolate import griddata
7
7
 
8
- TMPFILE_VERSION = "V4_0"
8
+ TMPFILE_VERSION = "V5_0"
9
9
 
10
10
 
11
11
  class FoCUS:
@@ -35,7 +35,7 @@ class FoCUS:
35
35
  mpi_rank=0,
36
36
  ):
37
37
 
38
- self.__version__ = "2025.05.0"
38
+ self.__version__ = "2025.06.1"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -176,11 +176,14 @@ class FoCUS:
176
176
  self.Y_CNN = {}
177
177
  self.Z_CNN = {}
178
178
 
179
+ self.Idx_CNN = {}
180
+ self.Idx_WCNN = {}
181
+
179
182
  self.filters_set = {}
180
183
  self.edge_masks = {}
181
184
 
182
- wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
183
- wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
185
+ wwc = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type)
186
+ wws = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type)
184
187
 
185
188
  x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape(
186
189
  KERNELSZ, KERNELSZ
@@ -203,12 +206,12 @@ class FoCUS:
203
206
  -0.5 * (xx**2 + yy**2)
204
207
  )
205
208
 
206
- wwc[:, 0] = tmp.flatten() - tmp.mean()
209
+ wwc[0] = tmp.flatten() - tmp.mean()
207
210
  tmp = 0 * w_smooth
208
- wws[:, 0] = tmp.flatten()
211
+ wws[0] = tmp.flatten()
209
212
  sigma = np.sqrt((wwc[:, 0] ** 2).mean())
210
- wwc[:, 0] /= sigma
211
- wws[:, 0] /= sigma
213
+ wwc[0] /= sigma
214
+ wws[0] /= sigma
212
215
 
213
216
  w_smooth = w_smooth.flatten()
214
217
  else:
@@ -239,12 +242,12 @@ class FoCUS:
239
242
  tmp1 = np.cos(yy * np.pi) * w_smooth
240
243
  tmp2 = np.sin(yy * np.pi) * w_smooth
241
244
 
242
- wwc[:, i] = tmp1.flatten() - tmp1.mean()
243
- wws[:, i] = tmp2.flatten() - tmp2.mean()
245
+ wwc[i] = tmp1.flatten() - tmp1.mean()
246
+ wws[i] = tmp2.flatten() - tmp2.mean()
244
247
  # sigma = np.sqrt((wwc[:, i] ** 2).mean())
245
248
  sigma = np.mean(w_smooth)
246
- wwc[:, i] /= sigma
247
- wws[:, i] /= sigma
249
+ wwc[i] /= sigma
250
+ wws[i] /= sigma
248
251
 
249
252
  if DODIV and i == 0:
250
253
  r = xx**2 + yy**2
@@ -253,22 +256,22 @@ class FoCUS:
253
256
  tmp1 = r * np.cos(2 * theta) * w_smooth
254
257
  tmp2 = r * np.sin(2 * theta) * w_smooth
255
258
 
256
- wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
257
- wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
259
+ wwc[NORIENT] = tmp1.flatten() - tmp1.mean()
260
+ wws[NORIENT] = tmp2.flatten() - tmp2.mean()
258
261
  # sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
259
262
  sigma = np.mean(w_smooth)
260
263
 
261
- wwc[:, NORIENT] /= sigma
262
- wws[:, NORIENT] /= sigma
264
+ wwc[NORIENT] /= sigma
265
+ wws[NORIENT] /= sigma
263
266
  tmp1 = r * np.cos(2 * theta + np.pi)
264
267
  tmp2 = r * np.sin(2 * theta + np.pi)
265
268
 
266
- wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
267
- wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
269
+ wwc[NORIENT + 1] = tmp1.flatten() - tmp1.mean()
270
+ wws[NORIENT + 1] = tmp2.flatten() - tmp2.mean()
268
271
  # sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
269
272
  sigma = np.mean(w_smooth)
270
- wwc[:, NORIENT + 1] /= sigma
271
- wws[:, NORIENT + 1] /= sigma
273
+ wwc[NORIENT + 1] /= sigma
274
+ wws[NORIENT + 1] /= sigma
272
275
 
273
276
  w_smooth = w_smooth.flatten()
274
277
 
@@ -316,14 +319,14 @@ class FoCUS:
316
319
  r = np.sum(np.sqrt(c * c + s * s))
317
320
  c = c / r
318
321
  s = s / r
319
- self.ww_RealT[1] = self.backend.bk_constant(
320
- np.array(c).reshape(xx.shape[0], 1, 1)
322
+ self.ww_RealT[1] = self.backend.bk_cast(
323
+ self.backend.bk_constant(np.array(c).reshape(xx.shape[0]))
321
324
  )
322
- self.ww_ImagT[1] = self.backend.bk_constant(
323
- np.array(s).reshape(xx.shape[0], 1, 1)
325
+ self.ww_ImagT[1] = self.backend.bk_cast(
326
+ self.backend.bk_constant(np.array(s).reshape(xx.shape[0]))
324
327
  )
325
- self.ww_SmoothT[1] = self.backend.bk_constant(
326
- np.array(w).reshape(xx.shape[0], 1, 1)
328
+ self.ww_SmoothT[1] = self.backend.bk_cast(
329
+ self.backend.bk_constant(np.array(w).reshape(xx.shape[0]))
327
330
  )
328
331
 
329
332
  else:
@@ -333,22 +336,16 @@ class FoCUS:
333
336
  self.ww_SmoothT = {}
334
337
 
335
338
  self.ww_SmoothT[1] = self.backend.bk_constant(
336
- self.w_smooth.reshape(KERNELSZ, KERNELSZ, 1, 1)
337
- )
338
- www = np.zeros([KERNELSZ, KERNELSZ, NORIENT, NORIENT], dtype=self.all_type)
339
- for k in range(NORIENT):
340
- www[:, :, k, k] = self.w_smooth.reshape(KERNELSZ, KERNELSZ)
341
- self.ww_SmoothT[NORIENT] = self.backend.bk_constant(
342
- www.reshape(KERNELSZ, KERNELSZ, NORIENT, NORIENT)
339
+ self.w_smooth.reshape(1, KERNELSZ, KERNELSZ)
343
340
  )
344
341
  self.ww_RealT[1] = self.backend.bk_constant(
345
342
  self.backend.bk_reshape(
346
- wwc.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
343
+ wwc.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ]
347
344
  )
348
345
  )
349
346
  self.ww_ImagT[1] = self.backend.bk_constant(
350
347
  self.backend.bk_reshape(
351
- wws.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
348
+ wws.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ]
352
349
  )
353
350
  )
354
351
 
@@ -505,211 +502,27 @@ class FoCUS:
505
502
  )
506
503
  return indices, weights, xc, yc, zc
507
504
 
508
- # ---------------------------------------------−---------
509
- def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
510
- nside = int(np.sqrt(im.shape[1] // 12))
511
- l_kernel = self.KERNELSZ * self.KERNELSZ
512
- norient = 32
513
- w = np.zeros([l_kernel, 1, 2 * norient])
514
- ca = np.cos(np.arange(norient) / norient * np.pi)
515
- sa = np.sin(np.arange(norient) / norient * np.pi)
516
- stat = np.zeros([12 * nside**2, norient])
517
-
518
- if self.ww_CNN[nside] is None:
519
- self.init_CNN_index(nside, transpose=False)
520
-
521
- y = self.Y_CNN[nside]
522
- z = self.Z_CNN[nside]
523
-
524
- for k in range(norient):
525
- w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
526
- nside * (y * ca[k] + z * sa[k]) * np.pi / 2
527
- )
528
- w[:, 0, k + norient] = np.exp(
529
- -0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
530
- ) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
531
- w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
532
- w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
533
-
534
- for k in range(im.shape[0]):
535
- tmp = im[k].reshape(12 * nside**2, 1)
536
- im2 = self.healpix_layer(tmp, w)
537
- stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
538
-
539
- rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
540
-
541
- indices, weights, x, y, z = self.calc_indices_convol(
542
- nside, 9, rotation=rotation
543
- )
544
-
545
- return indices, weights
546
-
547
- def init_CNN_index(self, nside, transpose=False):
548
- l_kernel = int(self.KERNELSZ * self.KERNELSZ)
549
- try:
550
- indices = np.load(
551
- "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
552
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
553
- )
554
- weights = np.load(
555
- "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
556
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
557
- )
558
- xc = np.load(
559
- "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
560
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
561
- )
562
- yc = np.load(
563
- "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
564
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
565
- )
566
- zc = np.load(
567
- "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
568
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
569
- )
570
- except:
571
- indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
572
- np.save(
573
- "%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
574
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
575
- indices,
576
- )
577
- np.save(
578
- "%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
579
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
580
- weights,
581
- )
582
- np.save(
583
- "%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
584
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
585
- xc,
586
- )
587
- np.save(
588
- "%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
589
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
590
- yc,
591
- )
592
- np.save(
593
- "%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
594
- % (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
595
- zc,
596
- )
597
- if not self.silent:
598
- print(
599
- "Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
600
- % (
601
- self.TEMPLATE_PATH,
602
- TMPFILE_VERSION,
603
- l_kernel,
604
- self.NORIENT,
605
- nside,
606
- )
607
- )
608
-
609
- self.X_CNN[nside] = xc
610
- self.Y_CNN[nside] = yc
611
- self.Z_CNN[nside] = zc
612
- self.ww_CNN[nside] = self.backend.bk_SparseTensor(
613
- indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
614
- )
615
-
616
- # ---------------------------------------------−---------
617
- def healpix_layer_coord(self, im, axis=0):
618
- nside = int(np.sqrt(im.shape[axis] // 12))
619
- if self.ww_CNN[nside] is None:
620
- self.init_CNN_index(nside)
621
- return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
622
-
623
- # ---------------------------------------------−---------
624
- def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
625
- nside = int(np.sqrt(im.shape[axis] // 12))
626
-
627
- if im.shape[1 + axis] != ww.shape[1]:
628
- if not self.silent:
629
- print("Weights channels should be equal to the input image channels")
630
- return -1
631
- if axis == 1:
632
- results = []
633
-
634
- for k in range(im.shape[0]):
635
-
636
- tmp = self.healpix_layer(
637
- im[k], ww, indices=indices, weights=weights, axis=0
638
- )
639
- tmp = self.backend.bk_reshape(
640
- self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
641
- )
642
-
643
- results.append(tmp)
644
-
645
- return self.backend.bk_stack(results, axis=0)
646
- else:
647
- tmp = self.healpix_layer(
648
- im, ww, indices=indices, weights=weights, axis=axis
649
- )
650
-
651
- return self.up_grade(tmp, 2 * nside)
652
-
653
505
  # ---------------------------------------------−---------
654
506
  # ---------------------------------------------−---------
655
- def healpix_layer(self, im, ww, indices=None, weights=None, axis=0):
656
- nside = int(np.sqrt(im.shape[axis] // 12))
657
- l_kernel = self.KERNELSZ * self.KERNELSZ
658
-
659
- if im.shape[1 + axis] != ww.shape[1]:
660
- if not self.silent:
661
- print("Weights channels should be equal to the input image channels")
662
- return -1
663
-
507
+ def healpix_layer(self, im, ww, indices=None, weights=None):
508
+ #ww [N_i,NORIENT,KERNELSZ*KERNELSZ//2,N_o,NORIENT]
509
+ #im [N_batch,N_i, NORIENT,N]
510
+ nside=int(np.sqrt(im.shape[-1]//12))
664
511
  if indices is None:
665
- if self.ww_CNN[nside] is None:
666
- self.init_CNN_index(nside, transpose=False)
667
- mat = self.ww_CNN[nside]
668
- else:
669
- if weights is None:
670
- print(
671
- "healpix_layer : If indices is not none weights should be specify"
672
- )
673
- return 0
674
-
675
- mat = self.backend.bk_SparseTensor(
676
- indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
677
- )
678
-
679
- if axis == 1:
680
- results = []
681
-
682
- for k in range(im.shape[0]):
683
-
684
- tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
685
-
686
- density = self.backend.bk_reshape(
687
- tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
688
- )
689
-
690
- density = self.backend.bk_matmul(
691
- density,
692
- self.backend.bk_reshape(
693
- ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
694
- ),
695
- )
696
-
697
- results.append(
698
- self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
699
- )
700
-
701
- return self.backend.bk_stack(results, axis=0)
702
- else:
703
- tmp = self.backend.bk_sparse_dense_matmul(mat, im)
704
-
705
- density = self.backend.bk_reshape(
706
- tmp, [12 * nside * nside, l_kernel * im.shape[1]]
707
- )
708
-
709
- return self.backend.bk_matmul(
710
- density,
711
- self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
712
- )
512
+ if (nside,self.NORIENT,self.KERNELSZ) not in self.ww_CNN:
513
+ self.init_index_cnn(nside,self.NORIENT)
514
+ indices = self.Idx_CNN[(nside,self.NORIENT,self.KERNELSZ)]
515
+ mat = self.Idx_WCNN[(nside,self.NORIENT,self.KERNELSZ)]
516
+
517
+ wim = self.backend.bk_gather(im,indices.flatten(),axis=3) #[N_batch,N_i,NORIENT,K*(K+1),N_o,NORIENT,N,N_w]
518
+
519
+ wim = self.backend.bk_reshape(wim,[im.shape[0],im.shape[1],im.shape[2]]+list(indices.shape))*mat[None,...]
520
+ #win is [N_batch,N_i, NORIENT,K*(K+1),1, NORIENT,N,N_w]
521
+ #ww is [1, N_i, NORIENT,K*(K+1),N_o,NORIENT]
522
+ wim = self.backend.bk_reduce_sum(wim[:,:,:,:,None]*ww[None,:,:,:,:,:,None,None],[1,2,3])
523
+
524
+ wim = self.backend.bk_reduce_sum(wim,-1)
525
+ return self.backend.bk_reshape(wim,[im.shape[0],ww.shape[3],ww.shape[4],im.shape[-1]])
713
526
 
714
527
  # ---------------------------------------------−---------
715
528
 
@@ -806,53 +619,20 @@ class FoCUS:
806
619
  return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
807
620
  elif self.use_1D:
808
621
  ishape = list(im.shape)
809
- if len(ishape) < axis + 1:
810
- if not self.silent:
811
- print("Use of 1D scat with data that has less than 1D")
812
- return None, None
813
622
 
814
- npix = im.shape[axis]
815
- odata = 1
816
- if len(ishape) > axis + 1:
817
- for k in range(axis + 1, len(ishape)):
818
- odata = odata * ishape[k]
623
+ npix = ishape[-1]
819
624
 
820
625
  ndata = 1
821
- for k in range(axis):
626
+ for k in range(len(ishape) - 1):
822
627
  ndata = ndata * ishape[k]
823
628
 
824
629
  tim = self.backend.bk_reshape(
825
- self.backend.bk_cast(im), [ndata, npix, odata]
826
- )
827
- tim = self.backend.bk_reshape(
828
- tim[:, 0 : 2 * (npix // 2), :], [ndata, npix // 2, 2, odata]
630
+ self.backend.bk_cast(im), [ndata, npix // 2, 2]
829
631
  )
830
632
 
831
- res = self.backend.bk_reduce_mean(tim, 2)
832
-
833
- if axis == 0:
834
- if len(ishape) == 1:
835
- return self.backend.bk_reshape(res, [npix // 2]), None
836
- else:
837
- return (
838
- self.backend.bk_reshape(res, [npix // 2] + ishape[axis + 1 :]),
839
- None,
840
- )
841
- else:
842
- if len(ishape) == axis + 1:
843
- return (
844
- self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2]),
845
- None,
846
- )
847
- else:
848
- return (
849
- self.backend.bk_reshape(
850
- res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
851
- ),
852
- None,
853
- )
633
+ res = self.backend.bk_reduce_mean(tim, -1)
854
634
 
855
- return self.backend.bk_reshape(res, [npix // 2]), None
635
+ return self.backend.bk_reshape(res, ishape[0:-1] + [npix // 2]), None
856
636
 
857
637
  else:
858
638
  shape = list(im.shape)
@@ -1384,13 +1164,18 @@ class FoCUS:
1384
1164
  return res
1385
1165
 
1386
1166
  # ---------------------------------------------−---------
1387
- def init_index(self, nside, kernel=-1):
1167
+ def init_index(self, nside, kernel=-1, cell_ids=None):
1388
1168
 
1389
1169
  if kernel == -1:
1390
1170
  l_kernel = self.KERNELSZ
1391
1171
  else:
1392
1172
  l_kernel = kernel
1393
1173
 
1174
+ if cell_ids is not None:
1175
+ ncell = cell_ids.shape[0]
1176
+ else:
1177
+ ncell = 12 * nside * nside
1178
+
1394
1179
  try:
1395
1180
  if self.use_2D:
1396
1181
  tmp = np.load(
@@ -1398,16 +1183,29 @@ class FoCUS:
1398
1183
  % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1399
1184
  )
1400
1185
  else:
1401
- tmp = np.load(
1402
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1403
- % (
1404
- self.TEMPLATE_PATH,
1405
- TMPFILE_VERSION,
1406
- l_kernel**2,
1407
- self.NORIENT,
1408
- nside,
1186
+ if cell_ids is not None:
1187
+ tmp = np.load(
1188
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1189
+ % (
1190
+ self.TEMPLATE_PATH,
1191
+ TMPFILE_VERSION,
1192
+ l_kernel**2,
1193
+ self.NORIENT,
1194
+ nside, # if cell_ids computes the index
1195
+ )
1196
+ )
1197
+
1198
+ else:
1199
+ tmp = np.load(
1200
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1201
+ % (
1202
+ self.TEMPLATE_PATH,
1203
+ TMPFILE_VERSION,
1204
+ l_kernel**2,
1205
+ self.NORIENT,
1206
+ nside, # if cell_ids computes the index
1207
+ )
1409
1208
  )
1410
- )
1411
1209
  except:
1412
1210
  if not self.use_2D:
1413
1211
 
@@ -1426,36 +1224,64 @@ class FoCUS:
1426
1224
  pw2 = 0.25
1427
1225
  threshold = 4e-5
1428
1226
 
1429
- th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1430
- x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1227
+ if cell_ids is not None:
1228
+ if not isinstance(cell_ids, np.ndarray):
1229
+ cell_ids = self.backend.to_numpy(cell_ids)
1230
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1231
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1431
1232
 
1432
- t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1433
- phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1434
- thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1233
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1234
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1235
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1435
1236
 
1436
- indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1437
- indice = np.zeros(
1438
- [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1439
- )
1440
- wav = np.zeros(
1441
- [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1442
- )
1443
- wwav = np.zeros([12 * nside * nside * 64 * self.NORIENT], dtype="float")
1237
+ indice2 = np.zeros([ncell * 64, 2], dtype="int")
1238
+ indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1239
+ wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1240
+ wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1444
1241
 
1242
+ else:
1243
+
1244
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1245
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1246
+
1247
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1248
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1249
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1250
+
1251
+ indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1252
+ indice = np.zeros(
1253
+ [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1254
+ )
1255
+ wav = np.zeros(
1256
+ [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1257
+ )
1258
+ wwav = np.zeros(
1259
+ [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1260
+ )
1445
1261
  iv = 0
1446
1262
  iv2 = 0
1447
- for iii in range(12 * nside * nside):
1448
1263
 
1449
- if iii % (nside * nside) == nside * nside - 1:
1450
- if not self.silent:
1451
- print(
1452
- "Pre-compute nside=%6d %.2f%%"
1453
- % (nside, 100 * iii / (12 * nside * nside))
1454
- )
1264
+ for iii in range(ncell):
1265
+ if cell_ids is None:
1266
+ if iii % (nside * nside) == nside * nside - 1:
1267
+ if not self.silent:
1268
+ print(
1269
+ "Pre-compute nside=%6d %.2f%%"
1270
+ % (nside, 100 * iii / (12 * nside * nside))
1271
+ )
1455
1272
 
1456
- hidx = hp.query_disc(
1457
- nside, [x[iii], y[iii], z[iii]], 2 * np.pi / nside, nest=True
1458
- )
1273
+ if cell_ids is not None:
1274
+ hidx = np.where(
1275
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1276
+ < (2 * np.pi / nside) ** 2
1277
+ )[0]
1278
+ else:
1279
+ hidx = hp.query_disc(
1280
+ nside,
1281
+ [x[iii], y[iii], z[iii]],
1282
+ 2 * np.pi / nside,
1283
+ nest=True,
1284
+ )
1459
1285
 
1460
1286
  R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1461
1287
 
@@ -1474,8 +1300,8 @@ class FoCUS:
1474
1300
  )
1475
1301
  idx = np.where((ww**2) > threshold)[0]
1476
1302
  nval2 = len(idx)
1477
- indice2[iv2 : iv2 + nval2, 0] = iii
1478
- indice2[iv2 : iv2 + nval2, 1] = hidx[idx]
1303
+ indice2[iv2 : iv2 + nval2, 1] = iii
1304
+ indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1479
1305
  wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1480
1306
  iv2 += nval2
1481
1307
 
@@ -1497,15 +1323,18 @@ class FoCUS:
1497
1323
  idx = np.where(vnorm > threshold)[0]
1498
1324
 
1499
1325
  nval = len(idx)
1500
- indice[iv : iv + nval, 0] = iii * 4 + l_rotation
1501
- indice[iv : iv + nval, 1] = hidx[idx]
1326
+ indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1327
+ indice[iv : iv + nval, 0] = hidx[idx]
1502
1328
  # print([hidx[k] for k in idx])
1503
1329
  # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1504
1330
  normr = np.mean(wresr[idx])
1505
1331
  normi = np.mean(wresi[idx])
1506
1332
 
1507
1333
  val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1508
- val = val / abs(val).sum()
1334
+ r = abs(val).sum()
1335
+
1336
+ if r > 0:
1337
+ val = val / r
1509
1338
 
1510
1339
  wav[iv : iv + nval] = val
1511
1340
  iv += nval
@@ -1609,56 +1438,57 @@ class FoCUS:
1609
1438
  wav=w.flatten()
1610
1439
  wwav=wwav.flatten()
1611
1440
  """
1612
- if not self.silent:
1613
- print(
1614
- "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1615
- % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1441
+ if cell_ids is None:
1442
+ if not self.silent:
1443
+ print(
1444
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1445
+ % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1446
+ )
1447
+ np.save(
1448
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1449
+ % (
1450
+ self.TEMPLATE_PATH,
1451
+ TMPFILE_VERSION,
1452
+ self.KERNELSZ**2,
1453
+ self.NORIENT,
1454
+ nside,
1455
+ ),
1456
+ indice,
1457
+ )
1458
+ np.save(
1459
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1460
+ % (
1461
+ self.TEMPLATE_PATH,
1462
+ TMPFILE_VERSION,
1463
+ self.KERNELSZ**2,
1464
+ self.NORIENT,
1465
+ nside,
1466
+ ),
1467
+ wav,
1468
+ )
1469
+ np.save(
1470
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1471
+ % (
1472
+ self.TEMPLATE_PATH,
1473
+ TMPFILE_VERSION,
1474
+ self.KERNELSZ**2,
1475
+ self.NORIENT,
1476
+ nside,
1477
+ ),
1478
+ indice2,
1479
+ )
1480
+ np.save(
1481
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1482
+ % (
1483
+ self.TEMPLATE_PATH,
1484
+ TMPFILE_VERSION,
1485
+ self.KERNELSZ**2,
1486
+ self.NORIENT,
1487
+ nside,
1488
+ ),
1489
+ wwav,
1616
1490
  )
1617
- np.save(
1618
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1619
- % (
1620
- self.TEMPLATE_PATH,
1621
- TMPFILE_VERSION,
1622
- self.KERNELSZ**2,
1623
- self.NORIENT,
1624
- nside,
1625
- ),
1626
- indice,
1627
- )
1628
- np.save(
1629
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1630
- % (
1631
- self.TEMPLATE_PATH,
1632
- TMPFILE_VERSION,
1633
- self.KERNELSZ**2,
1634
- self.NORIENT,
1635
- nside,
1636
- ),
1637
- wav,
1638
- )
1639
- np.save(
1640
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1641
- % (
1642
- self.TEMPLATE_PATH,
1643
- TMPFILE_VERSION,
1644
- self.KERNELSZ**2,
1645
- self.NORIENT,
1646
- nside,
1647
- ),
1648
- indice2,
1649
- )
1650
- np.save(
1651
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1652
- % (
1653
- self.TEMPLATE_PATH,
1654
- TMPFILE_VERSION,
1655
- self.KERNELSZ**2,
1656
- self.NORIENT,
1657
- nside,
1658
- ),
1659
- wwav,
1660
- )
1661
- else:
1491
+ if self.use_2D:
1662
1492
  if l_kernel**2 == 9:
1663
1493
  if self.rank == 0:
1664
1494
  self.comp_idx_w9(nside)
@@ -1674,23 +1504,24 @@ class FoCUS:
1674
1504
  )
1675
1505
  return None
1676
1506
 
1677
- self.barrier()
1678
- if self.use_2D:
1679
- tmp = np.load(
1680
- "%s/W%d_%s_%d_IDX.npy"
1681
- % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1682
- )
1683
- else:
1684
- tmp = np.load(
1685
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1686
- % (
1687
- self.TEMPLATE_PATH,
1688
- TMPFILE_VERSION,
1689
- self.KERNELSZ**2,
1690
- self.NORIENT,
1691
- nside,
1507
+ if cell_ids is None:
1508
+ self.barrier()
1509
+ if self.use_2D:
1510
+ tmp = np.load(
1511
+ "%s/W%d_%s_%d_IDX.npy"
1512
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1513
+ )
1514
+ else:
1515
+ tmp = np.load(
1516
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1517
+ % (
1518
+ self.TEMPLATE_PATH,
1519
+ TMPFILE_VERSION,
1520
+ self.KERNELSZ**2,
1521
+ self.NORIENT,
1522
+ nside,
1523
+ )
1692
1524
  )
1693
- )
1694
1525
  tmp2 = np.load(
1695
1526
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1696
1527
  % (
@@ -1731,22 +1562,28 @@ class FoCUS:
1731
1562
  nside,
1732
1563
  )
1733
1564
  )
1734
-
1735
- wr = self.backend.bk_SparseTensor(
1736
- self.backend.bk_constant(tmp),
1737
- self.backend.bk_constant(self.backend.bk_cast(wr)),
1738
- dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1739
- )
1740
- wi = self.backend.bk_SparseTensor(
1741
- self.backend.bk_constant(tmp),
1742
- self.backend.bk_constant(self.backend.bk_cast(wi)),
1743
- dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1744
- )
1745
- ws = self.backend.bk_SparseTensor(
1746
- self.backend.bk_constant(tmp2),
1747
- self.backend.bk_constant(self.backend.bk_cast(ws)),
1748
- dense_shape=[12 * nside**2, 12 * nside**2],
1749
- )
1565
+ else:
1566
+ tmp = indice
1567
+ tmp2 = indice2
1568
+ wr = wav.real
1569
+ wi = wav.imag
1570
+ ws = self.slope * wwav
1571
+
1572
+ wr = self.backend.bk_SparseTensor(
1573
+ self.backend.bk_constant(tmp),
1574
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1575
+ dense_shape=[ncell, self.NORIENT * ncell],
1576
+ )
1577
+ wi = self.backend.bk_SparseTensor(
1578
+ self.backend.bk_constant(tmp),
1579
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1580
+ dense_shape=[ncell, self.NORIENT * ncell],
1581
+ )
1582
+ ws = self.backend.bk_SparseTensor(
1583
+ self.backend.bk_constant(tmp2),
1584
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1585
+ dense_shape=[ncell, ncell],
1586
+ )
1750
1587
 
1751
1588
  if kernel == -1:
1752
1589
  self.Idx_Neighbours[nside] = tmp
@@ -1757,42 +1594,268 @@ class FoCUS:
1757
1594
 
1758
1595
  return wr, wi, ws, tmp
1759
1596
 
1597
+
1760
1598
  # ---------------------------------------------−---------
1761
- # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
1762
- def swapaxes(self, x, axis1, axis2):
1763
- shape = list(x.shape)
1764
- if axis1 < 0:
1765
- laxis1 = len(shape) + axis1
1766
- else:
1767
- laxis1 = axis1
1768
- if axis2 < 0:
1769
- laxis2 = len(shape) + axis2
1599
+ def init_index_cnn(self, nside, NORIENT=4,kernel=-1, cell_ids=None):
1600
+
1601
+ if kernel == -1:
1602
+ l_kernel = self.KERNELSZ
1770
1603
  else:
1771
- laxis2 = axis2
1604
+ l_kernel = kernel
1772
1605
 
1773
- naxes = len(shape)
1774
- thelist = [i for i in range(naxes)]
1775
- thelist[laxis1] = laxis2
1776
- thelist[laxis2] = laxis1
1777
- return self.backend.bk_transpose(x, thelist)
1606
+ if cell_ids is not None:
1607
+ ncell = cell_ids.shape[0]
1608
+ else:
1609
+ ncell = 12 * nside * nside
1778
1610
 
1779
- # ---------------------------------------------−---------
1780
- # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
1781
- # if use_2D
1782
- # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
1783
- def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1611
+ try:
1612
+
1613
+ if cell_ids is not None:
1614
+ tmp = np.load(
1615
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1616
+ % (
1617
+ self.TEMPLATE_PATH,
1618
+ TMPFILE_VERSION,
1619
+ l_kernel**2,
1620
+ NORIENT,
1621
+ nside, # if cell_ids computes the index
1622
+ )
1623
+ )
1784
1624
 
1785
- # ==========================================================================
1786
- # in input data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]]
1787
- # in input mask=[Nmask,X[,Y]]
1788
- # if self.use_2D : X[,Y]] = [X,Y]
1789
- # if second level: NORIENT[,NORIENT]= NORIENT,NORIENT
1790
- # ==========================================================================
1625
+ else:
1626
+ tmp = np.load(
1627
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1628
+ % (
1629
+ self.TEMPLATE_PATH,
1630
+ TMPFILE_VERSION,
1631
+ l_kernel**2,
1632
+ NORIENT,
1633
+ nside, # if cell_ids computes the index
1634
+ )
1635
+ )
1636
+ except:
1791
1637
 
1792
- shape = list(x.shape)
1638
+ pw = 8.0
1639
+ pw2 = 1.0
1640
+ threshold = 1e-3
1641
+
1642
+ if l_kernel == 5:
1643
+ pw = 8.0
1644
+ pw2 = 0.5
1645
+ threshold = 2e-4
1793
1646
 
1794
- if not self.use_2D:
1795
- nside = int(np.sqrt(x.shape[axis] // 12))
1647
+ elif l_kernel == 3:
1648
+ pw = 8.0
1649
+ pw2 = 1.0
1650
+ threshold = 1e-3
1651
+
1652
+ elif l_kernel == 7:
1653
+ pw = 8.0
1654
+ pw2 = 0.25
1655
+ threshold = 4e-5
1656
+
1657
+ n_weights = self.KERNELSZ*(self.KERNELSZ//2+1)
1658
+
1659
+ if cell_ids is not None:
1660
+ if not isinstance(cell_ids, np.ndarray):
1661
+ cell_ids = self.backend.to_numpy(cell_ids)
1662
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1663
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1664
+
1665
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1666
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1667
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1668
+
1669
+ indice = np.zeros([n_weights, NORIENT, ncell,4], dtype="int")
1670
+
1671
+ wav = np.zeros([n_weights, NORIENT, ncell,4], dtype="float")
1672
+
1673
+ else:
1674
+
1675
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1676
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1677
+
1678
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1679
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1680
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1681
+
1682
+ indice = np.zeros(
1683
+ [n_weights, NORIENT, 12 * nside * nside,4], dtype="int"
1684
+ )
1685
+ wav = np.zeros(
1686
+ [n_weights, NORIENT, 12 * nside * nside,4], dtype="float"
1687
+ )
1688
+ iv = 0
1689
+ iv2 = 0
1690
+
1691
+ for iii in range(ncell):
1692
+ if cell_ids is None:
1693
+ if iii % (nside * nside) == nside * nside - 1:
1694
+ if not self.silent:
1695
+ print(
1696
+ "Pre-compute nside=%6d %.2f%%"
1697
+ % (nside, 100 * iii / (12 * nside * nside))
1698
+ )
1699
+
1700
+ if cell_ids is not None:
1701
+ hidx = np.where(
1702
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1703
+ < (2 * np.pi / nside) ** 2
1704
+ )[0]
1705
+ else:
1706
+ hidx = hp.query_disc(
1707
+ nside,
1708
+ [x[iii], y[iii], z[iii]],
1709
+ 2 * np.pi / nside,
1710
+ nest=True,
1711
+ )
1712
+
1713
+ R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1714
+
1715
+ t2, p2 = R(th[hidx], ph[hidx])
1716
+
1717
+ vec2 = hp.ang2vec(t2, p2)
1718
+
1719
+ x2 = vec2[:, 0]
1720
+ y2 = vec2[:, 1]
1721
+ z2 = vec2[:, 2]
1722
+
1723
+ for l_rotation in range(NORIENT):
1724
+
1725
+ angle = (
1726
+ l_rotation / 4.0 * np.pi
1727
+ - phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
1728
+ - (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
1729
+ )
1730
+
1731
+
1732
+ axes = y2 * np.cos(angle) - x2 * np.sin(angle)
1733
+ axes2 = -y2 * np.sin(angle) - x2 * np.cos(angle)
1734
+
1735
+ for k_weights in range(self.KERNELSZ//2+1):
1736
+ for l_weights in range(self.KERNELSZ):
1737
+
1738
+ val=np.exp(-(pw*(axes2*(nside)-(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))+ \
1739
+ np.exp(-(pw*(axes2*(nside)+(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))
1740
+
1741
+ idx = np.argsort(-val)
1742
+ idx = idx[0:4]
1743
+
1744
+ nval = len(idx)
1745
+ val=val[idx]
1746
+
1747
+ r = abs(val).sum()
1748
+
1749
+ if r > 0:
1750
+ val = val / r
1751
+
1752
+ indice[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = hidx[idx]
1753
+ wav[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = val
1754
+
1755
+ if not self.silent:
1756
+ print("Kernel Size ", iv / (NORIENT * 12 * nside * nside))
1757
+
1758
+ if cell_ids is None:
1759
+ if not self.silent:
1760
+ print(
1761
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1762
+ % (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside)
1763
+ )
1764
+ np.save(
1765
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1766
+ % (
1767
+ self.TEMPLATE_PATH,
1768
+ TMPFILE_VERSION,
1769
+ self.KERNELSZ**2,
1770
+ NORIENT,
1771
+ nside,
1772
+ ),
1773
+ indice,
1774
+ )
1775
+ np.save(
1776
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1777
+ % (
1778
+ self.TEMPLATE_PATH,
1779
+ TMPFILE_VERSION,
1780
+ self.KERNELSZ**2,
1781
+ NORIENT,
1782
+ nside,
1783
+ ),
1784
+ wav,
1785
+ )
1786
+
1787
+ if cell_ids is None:
1788
+ self.barrier()
1789
+ if self.use_2D:
1790
+ tmp = np.load(
1791
+ "%s/W%d_%s_%d_IDX.npy"
1792
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1793
+ )
1794
+ else:
1795
+ tmp = np.load(
1796
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1797
+ % (
1798
+ self.TEMPLATE_PATH,
1799
+ TMPFILE_VERSION,
1800
+ self.KERNELSZ**2,
1801
+ NORIENT,
1802
+ nside,
1803
+ )
1804
+ )
1805
+ wav = np.load(
1806
+ "%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1807
+ % (
1808
+ self.TEMPLATE_PATH,
1809
+ TMPFILE_VERSION,
1810
+ self.KERNELSZ**2,
1811
+ NORIENT,
1812
+ nside,
1813
+ )
1814
+ )
1815
+ else:
1816
+ tmp = indice
1817
+
1818
+ self.Idx_CNN[(nside,NORIENT,self.KERNELSZ)] = tmp
1819
+ self.Idx_WCNN[(nside,NORIENT,self.KERNELSZ)] = self.backend.bk_cast(wav)
1820
+
1821
+ return wav, tmp
1822
+
1823
+ # ---------------------------------------------−---------
1824
+ # convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
1825
+ def swapaxes(self, x, axis1, axis2):
1826
+ shape = list(x.shape)
1827
+ if axis1 < 0:
1828
+ laxis1 = len(shape) + axis1
1829
+ else:
1830
+ laxis1 = axis1
1831
+ if axis2 < 0:
1832
+ laxis2 = len(shape) + axis2
1833
+ else:
1834
+ laxis2 = axis2
1835
+
1836
+ naxes = len(shape)
1837
+ thelist = [i for i in range(naxes)]
1838
+ thelist[laxis1] = laxis2
1839
+ thelist[laxis2] = laxis1
1840
+ return self.backend.bk_transpose(x, thelist)
1841
+
1842
+ # ---------------------------------------------−---------
1843
+ # Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
1844
+ # if use_2D
1845
+ # Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
1846
+ def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1847
+
1848
+ # ==========================================================================
1849
+ # in input data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]]
1850
+ # in input mask=[Nmask,X[,Y]]
1851
+ # if self.use_2D : X[,Y]] = [X,Y]
1852
+ # if second level: NORIENT[,NORIENT]= NORIENT,NORIENT
1853
+ # ==========================================================================
1854
+
1855
+ shape = list(x.shape)
1856
+
1857
+ if not self.use_2D and not self.use_1D:
1858
+ nside = int(np.sqrt(x.shape[axis] // 12))
1796
1859
 
1797
1860
  l_mask = mask
1798
1861
  if self.mask_norm:
@@ -1802,6 +1865,7 @@ class FoCUS:
1802
1865
  ),
1803
1866
  1,
1804
1867
  )
1868
+
1805
1869
  if not self.use_2D:
1806
1870
  l_mask = (
1807
1871
  12
@@ -1845,13 +1909,11 @@ class FoCUS:
1845
1909
  ]
1846
1910
 
1847
1911
  ichannel = 1
1848
- for i in range(axis):
1912
+ for i in range(1, len(shape) - 2):
1849
1913
  ichannel *= shape[i]
1850
- ochannel = 1
1851
- for i in range(axis + 2, len(shape)):
1852
- ochannel *= shape[i]
1914
+
1853
1915
  l_x = self.backend.bk_reshape(
1854
- x, [ichannel, 1, shape[axis], shape[axis + 1], ochannel]
1916
+ x, [shape[0], 1, ichannel, shape[-2], shape[-1]]
1855
1917
  )
1856
1918
 
1857
1919
  if self.padding == "VALID":
@@ -1876,12 +1938,10 @@ class FoCUS:
1876
1938
  l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1877
1939
 
1878
1940
  ichannel = 1
1879
- for i in range(axis):
1941
+ for i in range(1, len(shape) - 1):
1880
1942
  ichannel *= shape[i]
1881
- ochannel = 1
1882
- for i in range(axis + 1, len(shape)):
1883
- ochannel *= shape[i]
1884
- l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1943
+
1944
+ l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel, shape[-1]])
1885
1945
 
1886
1946
  if self.padding == "VALID":
1887
1947
  oshape = [k for k in shape]
@@ -1891,18 +1951,14 @@ class FoCUS:
1891
1951
  )
1892
1952
  else:
1893
1953
  ichannel = 1
1894
- for i in range(axis):
1954
+ for i in range(len(shape) - 1):
1895
1955
  ichannel *= shape[i]
1896
- ochannel = 1
1897
- for i in range(axis + 1, len(shape)):
1898
- ochannel *= shape[i]
1899
- l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1900
1956
 
1901
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
1902
- # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
1903
- l_mask = self.backend.bk_expand_dims(l_mask, 0)
1904
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1905
- l_mask = self.backend.bk_expand_dims(l_mask, -1)
1957
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[-1]])
1958
+
1959
+ # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,1,...,NORIENT[,NORIENT],X[,Y]]
1960
+ # mask=[Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1961
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask, 0), 0)
1906
1962
 
1907
1963
  if l_x.dtype == self.all_cbk_type:
1908
1964
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
@@ -1916,21 +1972,23 @@ class FoCUS:
1916
1972
  # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1917
1973
 
1918
1974
  v1 = self.backend.bk_reduce_sum(
1919
- self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
1975
+ self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
1920
1976
  )
1921
1977
  v2 = self.backend.bk_reduce_sum(
1922
- self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2), 2
1978
+ self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
1979
+ )
1980
+ vh = self.backend.bk_reduce_sum(
1981
+ self.backend.bk_reduce_sum(mtmp, axis=-1), -1
1923
1982
  )
1924
- vh = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp, axis=2), 2)
1925
1983
 
1926
1984
  res = v1 / vh
1927
1985
 
1928
- oshape = []
1986
+ oshape = [x.shape[0]] + [mask.shape[0]]
1929
1987
  if axis > 0:
1930
- oshape = oshape + list(x.shape[0:axis])
1931
- oshape = oshape + [mask.shape[0]]
1932
- if axis + 1 < len(x.shape):
1933
- oshape = oshape + list(x.shape[axis + 2 :])
1988
+ oshape = oshape + list(x.shape[1:axis])
1989
+
1990
+ if len(x.shape[axis:-2]) > 0:
1991
+ oshape = oshape + list(x.shape[axis:-2])
1934
1992
 
1935
1993
  if calc_var:
1936
1994
  if self.backend.bk_is_complex(vtmp):
@@ -1960,19 +2018,15 @@ class FoCUS:
1960
2018
  elif self.use_1D:
1961
2019
  mtmp = l_mask
1962
2020
  vtmp = l_x
1963
-
1964
- v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=2)
1965
- v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2)
1966
- vh = self.backend.bk_reduce_sum(mtmp, axis=2)
2021
+ v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1)
2022
+ v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
2023
+ vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
1967
2024
 
1968
2025
  res = v1 / vh
1969
2026
 
1970
- oshape = []
1971
- if axis > 0:
1972
- oshape = oshape + list(x.shape[0:axis])
1973
- oshape = oshape + [mask.shape[0]]
1974
- if axis + 1 < len(x.shape):
1975
- oshape = oshape + list(x.shape[axis + 1 :])
2027
+ oshape = [x.shape[0]] + [mask.shape[0]]
2028
+ if len(x.shape) > 1:
2029
+ oshape = oshape + list(x.shape[1:-1])
1976
2030
 
1977
2031
  if calc_var:
1978
2032
  if self.backend.bk_is_complex(vtmp):
@@ -1991,7 +2045,6 @@ class FoCUS:
1991
2045
  )
1992
2046
  else:
1993
2047
  res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1994
-
1995
2048
  res = self.backend.bk_reshape(res, oshape)
1996
2049
  res2 = self.backend.bk_reshape(res2, oshape)
1997
2050
  return res, res2
@@ -2000,18 +2053,20 @@ class FoCUS:
2000
2053
  return res
2001
2054
 
2002
2055
  else:
2003
- v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=2)
2004
- v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=2)
2005
- vh = self.backend.bk_reduce_sum(l_mask, axis=2)
2056
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2057
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2058
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2006
2059
 
2007
2060
  res = v1 / vh
2008
2061
 
2009
2062
  oshape = []
2010
2063
  if axis > 0:
2011
- oshape = oshape + list(x.shape[0:axis])
2064
+ oshape = [x.shape[0]]
2065
+ else:
2066
+ oshape = [1]
2012
2067
  oshape = oshape + [mask.shape[0]]
2013
- if axis + 1 < len(x.shape):
2014
- oshape = oshape + list(x.shape[axis + 1 :])
2068
+ if axis > 1:
2069
+ oshape = oshape + list(x.shape[1:-1])
2015
2070
 
2016
2071
  if calc_var:
2017
2072
  if self.backend.bk_is_complex(l_x):
@@ -2176,169 +2231,67 @@ class FoCUS:
2176
2231
  print("Use of 2D scat with data that has less than 2D")
2177
2232
  return None
2178
2233
 
2179
- npix = ishape[axis]
2180
- npiy = ishape[axis + 1]
2181
- odata = 1
2182
- if len(ishape) > axis + 2:
2183
- for k in range(axis + 2, len(ishape)):
2184
- odata = odata * ishape[k]
2234
+ npix = ishape[-2]
2235
+ npiy = ishape[-1]
2185
2236
 
2186
2237
  ndata = 1
2187
- for k in range(axis):
2238
+ for k in range(len(ishape) - 2):
2188
2239
  ndata = ndata * ishape[k]
2189
2240
 
2190
2241
  tim = self.backend.bk_reshape(
2191
- self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2242
+ self.backend.bk_cast(in_image), [ndata, npix, npiy]
2192
2243
  )
2193
2244
 
2194
2245
  if self.backend.bk_is_complex(tim):
2195
- rr1 = self.backend.conv2d(
2196
- self.backend.bk_real(tim),
2197
- self.ww_RealT[odata],
2198
- strides=[1, 1, 1, 1],
2199
- padding=self.padding,
2200
- )
2201
- ii1 = self.backend.conv2d(
2202
- self.backend.bk_real(tim),
2203
- self.ww_ImagT[odata],
2204
- strides=[1, 1, 1, 1],
2205
- padding=self.padding,
2206
- )
2207
- rr2 = self.backend.conv2d(
2208
- self.backend.bk_imag(tim),
2209
- self.ww_RealT[odata],
2210
- strides=[1, 1, 1, 1],
2211
- padding=self.padding,
2212
- )
2213
- ii2 = self.backend.conv2d(
2214
- self.backend.bk_imag(tim),
2215
- self.ww_ImagT[odata],
2216
- strides=[1, 1, 1, 1],
2217
- padding=self.padding,
2218
- )
2246
+ rr1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_RealT[1])
2247
+ ii1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_ImagT[1])
2248
+ rr2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_RealT[1])
2249
+ ii2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_ImagT[1])
2219
2250
  res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2220
2251
  else:
2221
- rr = self.backend.conv2d(
2222
- tim,
2223
- self.ww_RealT[odata],
2224
- strides=[1, 1, 1, 1],
2225
- padding=self.padding,
2226
- )
2227
- ii = self.backend.conv2d(
2228
- tim,
2229
- self.ww_ImagT[odata],
2230
- strides=[1, 1, 1, 1],
2231
- padding=self.padding,
2232
- )
2252
+ rr = self.backend.conv2d(tim, self.ww_RealT[1])
2253
+ ii = self.backend.conv2d(tim, self.ww_ImagT[1])
2233
2254
  res = self.backend.bk_complex(rr, ii)
2234
2255
 
2235
- if axis == 0:
2236
- if len(ishape) == 2:
2237
- return self.backend.bk_reshape(
2238
- res, [res.shape[1], res.shape[2], self.NORIENT]
2239
- )
2240
- else:
2241
- return self.backend.bk_reshape(
2242
- res,
2243
- [res.shape[1], res.shape[2], self.NORIENT] + ishape[axis + 2 :],
2244
- )
2245
- else:
2246
- if len(ishape) == axis + 2:
2247
- return self.backend.bk_reshape(
2248
- res, ishape[0:axis] + [res.shape[1], res.shape[2], self.NORIENT]
2249
- )
2250
- else:
2251
- return self.backend.bk_reshape(
2252
- res,
2253
- ishape[0:axis]
2254
- + [res.shape[1], res.shape[2], self.NORIENT]
2255
- + ishape[axis + 2 :],
2256
- )
2256
+ return self.backend.bk_reshape(
2257
+ res, ishape[0:-2] + [self.NORIENT, npix, npiy]
2258
+ )
2257
2259
 
2258
- return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2259
2260
  elif self.use_1D:
2260
2261
  ishape = list(in_image.shape)
2261
- if len(ishape) < axis + 1:
2262
- if not self.silent:
2263
- print("Use of 1D scat with data that has less than 1D")
2264
- return None
2265
2262
 
2266
- npix = ishape[axis]
2267
- odata = 1
2268
- if len(ishape) > axis + 1:
2269
- for k in range(axis + 1, len(ishape)):
2270
- odata = odata * ishape[k]
2263
+ npix = ishape[-1]
2271
2264
 
2272
2265
  ndata = 1
2273
- for k in range(axis):
2266
+ for k in range(len(ishape) - 1):
2274
2267
  ndata = ndata * ishape[k]
2275
2268
 
2276
- tim = self.backend.bk_reshape(
2277
- self.backend.bk_cast(in_image), [ndata, npix, odata]
2278
- )
2269
+ tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix])
2279
2270
 
2280
2271
  if self.backend.bk_is_complex(tim):
2281
- rr1 = self.backend.conv1d(
2282
- self.backend.bk_real(tim),
2283
- self.ww_RealT[odata],
2284
- strides=[1, 1, 1],
2285
- padding=self.padding,
2286
- )
2287
- ii1 = self.backend.conv1d(
2288
- self.backend.bk_real(tim),
2289
- self.ww_ImagT[odata],
2290
- strides=[1, 1, 1],
2291
- padding=self.padding,
2292
- )
2293
- rr2 = self.backend.conv1d(
2294
- self.backend.bk_imag(tim),
2295
- self.ww_RealT[odata],
2296
- strides=[1, 1, 1],
2297
- padding=self.padding,
2298
- )
2299
- ii2 = self.backend.conv1d(
2300
- self.backend.bk_imag(tim),
2301
- self.ww_ImagT[odata],
2302
- strides=[1, 1, 1],
2303
- padding=self.padding,
2304
- )
2272
+ rr1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_RealT[1])
2273
+ ii1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_ImagT[1])
2274
+ rr2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_RealT[1])
2275
+ ii2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_ImagT[1])
2305
2276
  res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2306
2277
  else:
2307
- rr = self.backend.conv1d(
2308
- tim, self.ww_RealT[odata], strides=[1, 1, 1], padding=self.padding
2309
- )
2310
- ii = self.backend.conv1d(
2311
- tim, self.ww_ImagT[odata], strides=[1, 1, 1], padding=self.padding
2312
- )
2278
+ rr = self.backend.conv1d(tim, self.ww_RealT[1])
2279
+ ii = self.backend.conv1d(tim, self.ww_ImagT[1])
2313
2280
  res = self.backend.bk_complex(rr, ii)
2314
2281
 
2315
- if axis == 0:
2316
- if len(ishape) == 1:
2317
- return self.backend.bk_reshape(res, [res.shape[1]])
2318
- else:
2319
- return self.backend.bk_reshape(
2320
- res, [res.shape[1]] + ishape[axis + 2 :]
2321
- )
2322
- else:
2323
- if len(ishape) == axis + 1:
2324
- return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
2325
- else:
2326
- return self.backend.bk_reshape(
2327
- res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2328
- )
2329
-
2330
- return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2282
+ return self.backend.bk_reshape(res, ishape)
2331
2283
 
2332
2284
  else:
2333
2285
  ishape = list(image.shape)
2334
-
2286
+ """
2335
2287
  if cell_ids is not None:
2336
2288
  if cell_ids.shape[0] not in self.padding_conv:
2289
+ print(image.shape,cell_ids.shape)
2337
2290
  import healpix_convolution as hc
2338
2291
  from xdggs.healpix import HealpixInfo
2339
2292
 
2340
2293
  res = self.backend.bk_zeros(
2341
- ishape + [self.NORIENT], dtype=self.backend.all_cbk_type
2294
+ ishape[0:-1] + [self.NORIENT]+ishape[-1], dtype=self.backend.all_cbk_type
2342
2295
  )
2343
2296
 
2344
2297
  grid_info = HealpixInfo(
@@ -2384,14 +2337,15 @@ class FoCUS:
2384
2337
  padded_data
2385
2338
  ) + 1j * kernelI.matmul(padded_data)
2386
2339
  return res
2387
-
2388
- nside = int(np.sqrt(image.shape[axis] // 12))
2340
+ """
2341
+ if nside is None:
2342
+ nside = int(np.sqrt(image.shape[-1] // 12))
2389
2343
 
2390
2344
  if self.Idx_Neighbours[nside] is None:
2391
2345
  if self.InitWave is None:
2392
- wr, wi, ws, widx = self.init_index(nside)
2346
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2393
2347
  else:
2394
- wr, wi, ws, widx = self.InitWave(self, nside)
2348
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2395
2349
 
2396
2350
  self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2397
2351
  self.ww_Real[nside] = wr
@@ -2401,156 +2355,63 @@ class FoCUS:
2401
2355
  l_ww_real = self.ww_Real[nside]
2402
2356
  l_ww_imag = self.ww_Imag[nside]
2403
2357
 
2404
- odata = 1
2405
- for k in range(axis + 1, len(ishape)):
2406
- odata = odata * ishape[k]
2358
+ # always convolve the last dimension
2407
2359
 
2408
- if axis > 0:
2409
- ndata = 1
2410
- for k in range(axis):
2360
+ ndata = 1
2361
+ if len(ishape) > 1:
2362
+ for k in range(len(ishape) - 1):
2411
2363
  ndata = ndata * ishape[k]
2412
- tim = self.backend.bk_reshape(
2413
- self.backend.bk_cast(image), [ndata, 12 * nside**2, odata]
2414
- )
2415
- if tim.dtype == self.all_cbk_type:
2416
- rr1 = self.backend.bk_reshape(
2417
- self.backend.bk_sparse_dense_matmul(
2418
- l_ww_real, self.backend.bk_real(tim[0])
2419
- ),
2420
- [1, 12 * nside**2, self.NORIENT, odata],
2421
- )
2422
- ii1 = self.backend.bk_reshape(
2423
- self.backend.bk_sparse_dense_matmul(
2424
- l_ww_imag, self.backend.bk_real(tim[0])
2425
- ),
2426
- [1, 12 * nside**2, self.NORIENT, odata],
2427
- )
2428
- rr2 = self.backend.bk_reshape(
2429
- self.backend.bk_sparse_dense_matmul(
2430
- l_ww_real, self.backend.bk_imag(tim[0])
2431
- ),
2432
- [1, 12 * nside**2, self.NORIENT, odata],
2433
- )
2434
- ii2 = self.backend.bk_reshape(
2435
- self.backend.bk_sparse_dense_matmul(
2436
- l_ww_imag, self.backend.bk_imag(tim[0])
2437
- ),
2438
- [1, 12 * nside**2, self.NORIENT, odata],
2439
- )
2440
- res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2441
- else:
2442
- rr = self.backend.bk_reshape(
2443
- self.backend.bk_sparse_dense_matmul(l_ww_real, tim[0]),
2444
- [1, 12 * nside**2, self.NORIENT, odata],
2445
- )
2446
- ii = self.backend.bk_reshape(
2447
- self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[0]),
2448
- [1, 12 * nside**2, self.NORIENT, odata],
2449
- )
2450
- res = self.backend.bk_complex(rr, ii)
2451
-
2452
- for k in range(1, ndata):
2453
- if tim.dtype == self.all_cbk_type:
2454
- rr1 = self.backend.bk_reshape(
2455
- self.backend.bk_sparse_dense_matmul(
2456
- l_ww_real, self.backend.bk_real(tim[k])
2457
- ),
2458
- [1, 12 * nside**2, self.NORIENT, odata],
2459
- )
2460
- ii1 = self.backend.bk_reshape(
2461
- self.backend.bk_sparse_dense_matmul(
2462
- l_ww_imag, self.backend.bk_real(tim[k])
2463
- ),
2464
- [1, 12 * nside**2, self.NORIENT, odata],
2465
- )
2466
- rr2 = self.backend.bk_reshape(
2467
- self.backend.bk_sparse_dense_matmul(
2468
- l_ww_real, self.backend.bk_imag(tim[k])
2469
- ),
2470
- [1, 12 * nside**2, self.NORIENT, odata],
2471
- )
2472
- ii2 = self.backend.bk_reshape(
2473
- self.backend.bk_sparse_dense_matmul(
2474
- l_ww_imag, self.backend.bk_imag(tim[k])
2475
- ),
2476
- [1, 12 * nside**2, self.NORIENT, odata],
2477
- )
2478
- res = self.backend.bk_concat(
2479
- [res, self.backend.bk_complex(rr1 - ii2, ii1 + rr2)], 0
2480
- )
2481
- else:
2482
- rr = self.backend.bk_reshape(
2483
- self.backend.bk_sparse_dense_matmul(l_ww_real, tim[k]),
2484
- [1, 12 * nside**2, self.NORIENT, odata],
2485
- )
2486
- ii = self.backend.bk_reshape(
2487
- self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[k]),
2488
- [1, 12 * nside**2, self.NORIENT, odata],
2489
- )
2490
- res = self.backend.bk_concat(
2491
- [res, self.backend.bk_complex(rr, ii)], 0
2492
- )
2493
-
2494
- if len(ishape) == axis + 1:
2495
- return self.backend.bk_reshape(
2496
- res, ishape[0:axis] + [12 * nside**2, self.NORIENT]
2497
- )
2498
- else:
2499
- return self.backend.bk_reshape(
2500
- res,
2501
- ishape[0:axis]
2502
- + [12 * nside**2]
2503
- + ishape[axis + 1 :]
2504
- + [self.NORIENT],
2505
- )
2364
+ tim = self.backend.bk_reshape(
2365
+ self.backend.bk_cast(image), [ndata, ishape[-1]]
2366
+ )
2506
2367
 
2507
- if axis == 0:
2508
- tim = self.backend.bk_reshape(
2509
- self.backend.bk_cast(image), [12 * nside**2, odata]
2368
+ if tim.dtype == self.all_cbk_type:
2369
+ rr1 = self.backend.bk_reshape(
2370
+ self.backend.bk_sparse_dense_matmul(
2371
+ self.backend.bk_real(tim),
2372
+ l_ww_real,
2373
+ ),
2374
+ [ndata, self.NORIENT, ishape[-1]],
2510
2375
  )
2511
- if tim.dtype == self.all_cbk_type:
2512
- rr1 = self.backend.bk_reshape(
2513
- self.backend.bk_sparse_dense_matmul(
2514
- l_ww_real, self.backend.bk_real(tim)
2515
- ),
2516
- [12 * nside**2, self.NORIENT, odata],
2517
- )
2518
- ii1 = self.backend.bk_reshape(
2519
- self.backend.bk_sparse_dense_matmul(
2520
- l_ww_imag, self.backend.bk_real(tim)
2521
- ),
2522
- [12 * nside**2, self.NORIENT, odata],
2523
- )
2524
- rr2 = self.backend.bk_reshape(
2525
- self.backend.bk_sparse_dense_matmul(
2526
- l_ww_real, self.backend.bk_imag(tim)
2527
- ),
2528
- [12 * nside**2, self.NORIENT, odata],
2529
- )
2530
- ii2 = self.backend.bk_reshape(
2531
- self.backend.bk_sparse_dense_matmul(
2532
- l_ww_imag, self.backend.bk_imag(tim)
2533
- ),
2534
- [12 * nside**2, self.NORIENT, odata],
2535
- )
2536
- res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2537
- else:
2538
- rr = self.backend.bk_reshape(
2539
- self.backend.bk_sparse_dense_matmul(l_ww_real, tim),
2540
- [12 * nside**2, self.NORIENT, odata],
2541
- )
2542
- ii = self.backend.bk_reshape(
2543
- self.backend.bk_sparse_dense_matmul(l_ww_imag, tim),
2544
- [12 * nside**2, self.NORIENT, odata],
2545
- )
2546
- res = self.backend.bk_complex(rr, ii)
2376
+ ii1 = self.backend.bk_reshape(
2377
+ self.backend.bk_sparse_dense_matmul(
2378
+ self.backend.bk_real(tim),
2379
+ l_ww_imag,
2380
+ ),
2381
+ [ndata, self.NORIENT, ishape[-1]],
2382
+ )
2383
+ rr2 = self.backend.bk_reshape(
2384
+ self.backend.bk_sparse_dense_matmul(
2385
+ self.backend.bk_imag(tim),
2386
+ l_ww_real,
2387
+ ),
2388
+ [ndata, self.NORIENT, ishape[-1]],
2389
+ )
2390
+ ii2 = self.backend.bk_reshape(
2391
+ self.backend.bk_sparse_dense_matmul(
2392
+ self.backend.bk_imag(tim),
2393
+ l_ww_imag,
2394
+ ),
2395
+ [ndata, self.NORIENT, ishape[-1]],
2396
+ )
2397
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2398
+ else:
2399
+ rr = self.backend.bk_reshape(
2400
+ self.backend.bk_sparse_dense_matmul(tim, l_ww_real),
2401
+ [ndata, self.NORIENT, ishape[-1]],
2402
+ )
2403
+ ii = self.backend.bk_reshape(
2404
+ self.backend.bk_sparse_dense_matmul(tim, l_ww_imag),
2405
+ [ndata, self.NORIENT, ishape[-1]],
2406
+ )
2407
+ res = self.backend.bk_complex(rr, ii)
2408
+ if len(ishape) > 1:
2409
+ return self.backend.bk_reshape(
2410
+ res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2411
+ )
2412
+ else:
2413
+ return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2547
2414
 
2548
- if len(ishape) == 1:
2549
- return self.backend.bk_reshape(res, [12 * nside**2, self.NORIENT])
2550
- else:
2551
- return self.backend.bk_reshape(
2552
- res, [12 * nside**2] + ishape[axis + 1 :] + [self.NORIENT]
2553
- )
2554
2415
  return res
2555
2416
 
2556
2417
  # ---------------------------------------------−---------
@@ -2578,114 +2439,43 @@ class FoCUS:
2578
2439
  ndata = ndata * ishape[k]
2579
2440
 
2580
2441
  tim = self.backend.bk_reshape(
2581
- self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2442
+ self.backend.bk_cast(in_image), [ndata, npix, npiy]
2582
2443
  )
2583
2444
 
2584
2445
  if self.backend.bk_is_complex(tim):
2585
- rr = self.backend.conv2d(
2586
- self.backend.bk_real(tim),
2587
- self.ww_SmoothT[odata],
2588
- strides=[1, 1, 1, 1],
2589
- padding=self.padding,
2590
- )
2591
- ii = self.backend.conv2d(
2592
- self.backend.bk_imag(tim),
2593
- self.ww_SmoothT[odata],
2594
- strides=[1, 1, 1, 1],
2595
- padding=self.padding,
2596
- )
2446
+ rr = self.backend.conv2d(self.backend.bk_real(tim), self.ww_SmoothT[1])
2447
+ ii = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_SmoothT[1])
2597
2448
  res = self.backend.bk_complex(rr, ii)
2598
2449
  else:
2599
- res = self.backend.conv2d(
2600
- tim,
2601
- self.ww_SmoothT[odata],
2602
- strides=[1, 1, 1, 1],
2603
- padding=self.padding,
2604
- )
2450
+ res = self.backend.conv2d(tim, self.ww_SmoothT[1])
2605
2451
 
2606
- if axis == 0:
2607
- if len(ishape) == 2:
2608
- return self.backend.bk_reshape(res, [res.shape[1], res.shape[2]])
2609
- else:
2610
- return self.backend.bk_reshape(
2611
- res, [res.shape[1], res.shape[2]] + ishape[axis + 2 :]
2612
- )
2613
- else:
2614
- if len(ishape) == axis + 2:
2615
- return self.backend.bk_reshape(
2616
- res, ishape[0:axis] + [res.shape[1], res.shape[2]]
2617
- )
2618
- else:
2619
- return self.backend.bk_reshape(
2620
- res,
2621
- ishape[0:axis]
2622
- + [res.shape[1], res.shape[2]]
2623
- + ishape[axis + 2 :],
2624
- )
2452
+ return self.backend.bk_reshape(res, ishape)
2625
2453
 
2626
- return self.backend.bk_reshape(res, in_image.shape)
2627
2454
  elif self.use_1D:
2628
2455
 
2629
2456
  ishape = list(in_image.shape)
2630
- if len(ishape) < axis + 1:
2631
- if not self.silent:
2632
- print("Use of 1D scat with data that has less than 1D")
2633
- return None
2634
2457
 
2635
- npix = ishape[axis]
2636
- odata = 1
2637
- if len(ishape) > axis + 1:
2638
- for k in range(axis + 1, len(ishape)):
2639
- odata = odata * ishape[k]
2458
+ npix = ishape[-1]
2640
2459
 
2641
2460
  ndata = 1
2642
- for k in range(axis):
2461
+ for k in range(len(ishape) - 1):
2643
2462
  ndata = ndata * ishape[k]
2644
2463
 
2645
- tim = self.backend.bk_reshape(
2646
- self.backend.bk_cast(in_image), [ndata, npix, odata]
2647
- )
2464
+ tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix])
2648
2465
 
2649
2466
  if self.backend.bk_is_complex(tim):
2650
- rr = self.backend.conv1d(
2651
- self.backend.bk_real(tim),
2652
- self.ww_SmoothT[odata],
2653
- strides=[1, 1, 1],
2654
- padding=self.padding,
2655
- )
2656
- ii = self.backend.conv1d(
2657
- self.backend.bk_imag(tim),
2658
- self.ww_SmoothT[odata],
2659
- strides=[1, 1, 1],
2660
- padding=self.padding,
2661
- )
2467
+ rr = self.backend.conv1d(self.backend.bk_real(tim), self.ww_SmoothT[1])
2468
+ ii = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_SmoothT[1])
2662
2469
  res = self.backend.bk_complex(rr, ii)
2663
2470
  else:
2664
- res = self.backend.conv1d(
2665
- tim, self.ww_SmoothT[odata], strides=[1, 1, 1], padding=self.padding
2666
- )
2667
-
2668
- if axis == 0:
2669
- if len(ishape) == 1:
2670
- return self.backend.bk_reshape(res, [res.shape[1]])
2671
- else:
2672
- return self.backend.bk_reshape(
2673
- res, [res.shape[1]] + ishape[axis + 1 :]
2674
- )
2675
- else:
2676
- if len(ishape) == axis + 1:
2677
- return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
2678
- else:
2679
- return self.backend.bk_reshape(
2680
- res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2681
- )
2471
+ res = self.backend.conv1d(tim, self.ww_SmoothT[1])
2682
2472
 
2683
- return self.backend.bk_reshape(res, in_image.shape)
2473
+ return self.backend.bk_reshape(res, ishape)
2684
2474
 
2685
2475
  else:
2686
2476
 
2687
2477
  ishape = list(image.shape)
2688
-
2478
+ """
2689
2479
  if cell_ids is not None:
2690
2480
  if cell_ids.shape[0] not in self.padding_smooth:
2691
2481
  import healpix_convolution as hc
@@ -2726,15 +2516,16 @@ class FoCUS:
2726
2516
  padded_data = padding.apply(image[l, :, k2], is_torch=True)
2727
2517
  res[l, :, k2] = kernel.matmul(padded_data)
2728
2518
  return res
2729
-
2730
- nside = int(np.sqrt(image.shape[axis] // 12))
2519
+ """
2520
+ if nside is None:
2521
+ nside = int(np.sqrt(image.shape[-1] // 12))
2731
2522
 
2732
2523
  if self.Idx_Neighbours[nside] is None:
2733
2524
 
2734
2525
  if self.InitWave is None:
2735
- wr, wi, ws, widx = self.init_index(nside)
2526
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2736
2527
  else:
2737
- wr, wi, ws, widx = self.InitWave(self, nside)
2528
+ wr, wi, ws, widx = self.InitWave(self, nside, cell_ids=cell_ids)
2738
2529
 
2739
2530
  self.Idx_Neighbours[nside] = 1
2740
2531
  self.ww_Real[nside] = wr
@@ -2744,92 +2535,24 @@ class FoCUS:
2744
2535
  l_w_smooth = self.w_smooth[nside]
2745
2536
 
2746
2537
  odata = 1
2747
- for k in range(axis + 1, len(ishape)):
2538
+ for k in range(0, len(ishape) - 1):
2748
2539
  odata = odata * ishape[k]
2749
2540
 
2750
- if axis == 0:
2751
- tim = self.backend.bk_reshape(image, [12 * nside**2, odata])
2752
- if tim.dtype == self.all_cbk_type:
2753
- rr = self.backend.bk_sparse_dense_matmul(
2754
- l_w_smooth, self.backend.bk_real(tim)
2755
- )
2756
- ri = self.backend.bk_sparse_dense_matmul(
2757
- l_w_smooth, self.backend.bk_imag(tim)
2758
- )
2759
- res = self.backend.bk_complex(rr, ri)
2760
- else:
2761
- res = self.backend.bk_sparse_dense_matmul(l_w_smooth, tim)
2762
- if len(ishape) == 1:
2763
- return self.backend.bk_reshape(res, [12 * nside**2])
2764
- else:
2765
- return self.backend.bk_reshape(
2766
- res, [12 * nside**2] + ishape[axis + 1 :]
2767
- )
2768
-
2769
- if axis > 0:
2770
- ndata = ishape[0]
2771
- for k in range(1, axis):
2772
- ndata = ndata * ishape[k]
2773
- tim = self.backend.bk_reshape(image, [ndata, 12 * nside**2, odata])
2774
- if tim.dtype == self.all_cbk_type:
2775
- rr = self.backend.bk_reshape(
2776
- self.backend.bk_sparse_dense_matmul(
2777
- l_w_smooth, self.backend.bk_real(tim[0])
2778
- ),
2779
- [1, 12 * nside**2, odata],
2780
- )
2781
- ri = self.backend.bk_reshape(
2782
- self.backend.bk_sparse_dense_matmul(
2783
- l_w_smooth, self.backend.bk_imag(tim[0])
2784
- ),
2785
- [1, 12 * nside**2, odata],
2786
- )
2787
- res = self.backend.bk_complex(rr, ri)
2788
- else:
2789
- res = self.backend.bk_reshape(
2790
- self.backend.bk_sparse_dense_matmul(l_w_smooth, tim[0]),
2791
- [1, 12 * nside**2, odata],
2792
- )
2793
-
2794
- for k in range(1, ndata):
2795
- if tim.dtype == self.all_cbk_type:
2796
- rr = self.backend.bk_reshape(
2797
- self.backend.bk_sparse_dense_matmul(
2798
- l_w_smooth, self.backend.bk_real(tim[k])
2799
- ),
2800
- [1, 12 * nside**2, odata],
2801
- )
2802
- ri = self.backend.bk_reshape(
2803
- self.backend.bk_sparse_dense_matmul(
2804
- l_w_smooth, self.backend.bk_imag(tim[k])
2805
- ),
2806
- [1, 12 * nside**2, odata],
2807
- )
2808
- res = self.backend.bk_concat(
2809
- [res, self.backend.bk_complex(rr, ri)], 0
2810
- )
2811
- else:
2812
- res = self.backend.bk_concat(
2813
- [
2814
- res,
2815
- self.backend.bk_reshape(
2816
- self.backend.bk_sparse_dense_matmul(
2817
- l_w_smooth, tim[k]
2818
- ),
2819
- [1, 12 * nside**2, odata],
2820
- ),
2821
- ],
2822
- 0,
2823
- )
2824
-
2825
- if len(ishape) == axis + 1:
2826
- return self.backend.bk_reshape(
2827
- res, ishape[0:axis] + [12 * nside**2]
2828
- )
2829
- else:
2830
- return self.backend.bk_reshape(
2831
- res, ishape[0:axis] + [12 * nside**2] + ishape[axis + 1 :]
2832
- )
2541
+ tim = self.backend.bk_reshape(image, [odata, ishape[-1]])
2542
+ if tim.dtype == self.all_cbk_type:
2543
+ rr = self.backend.bk_sparse_dense_matmul(
2544
+ self.backend.bk_real(tim), l_w_smooth
2545
+ )
2546
+ ri = self.backend.bk_sparse_dense_matmul(
2547
+ self.backend.bk_imag(tim), l_w_smooth
2548
+ )
2549
+ res = self.backend.bk_complex(rr, ri)
2550
+ else:
2551
+ res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2552
+ if len(ishape) == 1:
2553
+ return self.backend.bk_reshape(res, [ishape[-1]])
2554
+ else:
2555
+ return self.backend.bk_reshape(res, ishape[0:-1] + [ishape[-1]])
2833
2556
 
2834
2557
  return res
2835
2558