foscat 2025.5.0__py3-none-any.whl → 2025.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -5,7 +5,7 @@ import healpy as hp
5
5
  import numpy as np
6
6
  from scipy.interpolate import griddata
7
7
 
8
- TMPFILE_VERSION = "V4_0"
8
+ TMPFILE_VERSION = "V5_0"
9
9
 
10
10
 
11
11
  class FoCUS:
@@ -35,7 +35,7 @@ class FoCUS:
35
35
  mpi_rank=0,
36
36
  ):
37
37
 
38
- self.__version__ = "2025.05.0"
38
+ self.__version__ = "2025.05.2"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -179,8 +179,8 @@ class FoCUS:
179
179
  self.filters_set = {}
180
180
  self.edge_masks = {}
181
181
 
182
- wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
183
- wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
182
+ wwc = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type)
183
+ wws = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type)
184
184
 
185
185
  x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape(
186
186
  KERNELSZ, KERNELSZ
@@ -203,12 +203,12 @@ class FoCUS:
203
203
  -0.5 * (xx**2 + yy**2)
204
204
  )
205
205
 
206
- wwc[:, 0] = tmp.flatten() - tmp.mean()
206
+ wwc[0] = tmp.flatten() - tmp.mean()
207
207
  tmp = 0 * w_smooth
208
- wws[:, 0] = tmp.flatten()
208
+ wws[0] = tmp.flatten()
209
209
  sigma = np.sqrt((wwc[:, 0] ** 2).mean())
210
- wwc[:, 0] /= sigma
211
- wws[:, 0] /= sigma
210
+ wwc[0] /= sigma
211
+ wws[0] /= sigma
212
212
 
213
213
  w_smooth = w_smooth.flatten()
214
214
  else:
@@ -239,12 +239,12 @@ class FoCUS:
239
239
  tmp1 = np.cos(yy * np.pi) * w_smooth
240
240
  tmp2 = np.sin(yy * np.pi) * w_smooth
241
241
 
242
- wwc[:, i] = tmp1.flatten() - tmp1.mean()
243
- wws[:, i] = tmp2.flatten() - tmp2.mean()
242
+ wwc[i] = tmp1.flatten() - tmp1.mean()
243
+ wws[i] = tmp2.flatten() - tmp2.mean()
244
244
  # sigma = np.sqrt((wwc[:, i] ** 2).mean())
245
245
  sigma = np.mean(w_smooth)
246
- wwc[:, i] /= sigma
247
- wws[:, i] /= sigma
246
+ wwc[i] /= sigma
247
+ wws[i] /= sigma
248
248
 
249
249
  if DODIV and i == 0:
250
250
  r = xx**2 + yy**2
@@ -253,22 +253,22 @@ class FoCUS:
253
253
  tmp1 = r * np.cos(2 * theta) * w_smooth
254
254
  tmp2 = r * np.sin(2 * theta) * w_smooth
255
255
 
256
- wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
257
- wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
256
+ wwc[NORIENT] = tmp1.flatten() - tmp1.mean()
257
+ wws[NORIENT] = tmp2.flatten() - tmp2.mean()
258
258
  # sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
259
259
  sigma = np.mean(w_smooth)
260
260
 
261
- wwc[:, NORIENT] /= sigma
262
- wws[:, NORIENT] /= sigma
261
+ wwc[NORIENT] /= sigma
262
+ wws[NORIENT] /= sigma
263
263
  tmp1 = r * np.cos(2 * theta + np.pi)
264
264
  tmp2 = r * np.sin(2 * theta + np.pi)
265
265
 
266
- wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
267
- wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
266
+ wwc[NORIENT + 1] = tmp1.flatten() - tmp1.mean()
267
+ wws[NORIENT + 1] = tmp2.flatten() - tmp2.mean()
268
268
  # sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
269
269
  sigma = np.mean(w_smooth)
270
- wwc[:, NORIENT + 1] /= sigma
271
- wws[:, NORIENT + 1] /= sigma
270
+ wwc[NORIENT + 1] /= sigma
271
+ wws[NORIENT + 1] /= sigma
272
272
 
273
273
  w_smooth = w_smooth.flatten()
274
274
 
@@ -316,14 +316,14 @@ class FoCUS:
316
316
  r = np.sum(np.sqrt(c * c + s * s))
317
317
  c = c / r
318
318
  s = s / r
319
- self.ww_RealT[1] = self.backend.bk_constant(
320
- np.array(c).reshape(xx.shape[0], 1, 1)
319
+ self.ww_RealT[1] = self.backend.bk_cast(
320
+ self.backend.bk_constant(np.array(c).reshape(xx.shape[0]))
321
321
  )
322
- self.ww_ImagT[1] = self.backend.bk_constant(
323
- np.array(s).reshape(xx.shape[0], 1, 1)
322
+ self.ww_ImagT[1] = self.backend.bk_cast(
323
+ self.backend.bk_constant(np.array(s).reshape(xx.shape[0]))
324
324
  )
325
- self.ww_SmoothT[1] = self.backend.bk_constant(
326
- np.array(w).reshape(xx.shape[0], 1, 1)
325
+ self.ww_SmoothT[1] = self.backend.bk_cast(
326
+ self.backend.bk_constant(np.array(w).reshape(xx.shape[0]))
327
327
  )
328
328
 
329
329
  else:
@@ -333,22 +333,16 @@ class FoCUS:
333
333
  self.ww_SmoothT = {}
334
334
 
335
335
  self.ww_SmoothT[1] = self.backend.bk_constant(
336
- self.w_smooth.reshape(KERNELSZ, KERNELSZ, 1, 1)
337
- )
338
- www = np.zeros([KERNELSZ, KERNELSZ, NORIENT, NORIENT], dtype=self.all_type)
339
- for k in range(NORIENT):
340
- www[:, :, k, k] = self.w_smooth.reshape(KERNELSZ, KERNELSZ)
341
- self.ww_SmoothT[NORIENT] = self.backend.bk_constant(
342
- www.reshape(KERNELSZ, KERNELSZ, NORIENT, NORIENT)
336
+ self.w_smooth.reshape(1, KERNELSZ, KERNELSZ)
343
337
  )
344
338
  self.ww_RealT[1] = self.backend.bk_constant(
345
339
  self.backend.bk_reshape(
346
- wwc.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
340
+ wwc.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ]
347
341
  )
348
342
  )
349
343
  self.ww_ImagT[1] = self.backend.bk_constant(
350
344
  self.backend.bk_reshape(
351
- wws.astype(self.all_type), [KERNELSZ, KERNELSZ, 1, NORIENT]
345
+ wws.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ]
352
346
  )
353
347
  )
354
348
 
@@ -806,53 +800,20 @@ class FoCUS:
806
800
  return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
807
801
  elif self.use_1D:
808
802
  ishape = list(im.shape)
809
- if len(ishape) < axis + 1:
810
- if not self.silent:
811
- print("Use of 1D scat with data that has less than 1D")
812
- return None, None
813
803
 
814
- npix = im.shape[axis]
815
- odata = 1
816
- if len(ishape) > axis + 1:
817
- for k in range(axis + 1, len(ishape)):
818
- odata = odata * ishape[k]
804
+ npix = ishape[-1]
819
805
 
820
806
  ndata = 1
821
- for k in range(axis):
807
+ for k in range(len(ishape) - 1):
822
808
  ndata = ndata * ishape[k]
823
809
 
824
810
  tim = self.backend.bk_reshape(
825
- self.backend.bk_cast(im), [ndata, npix, odata]
826
- )
827
- tim = self.backend.bk_reshape(
828
- tim[:, 0 : 2 * (npix // 2), :], [ndata, npix // 2, 2, odata]
811
+ self.backend.bk_cast(im), [ndata, npix // 2, 2]
829
812
  )
830
813
 
831
- res = self.backend.bk_reduce_mean(tim, 2)
832
-
833
- if axis == 0:
834
- if len(ishape) == 1:
835
- return self.backend.bk_reshape(res, [npix // 2]), None
836
- else:
837
- return (
838
- self.backend.bk_reshape(res, [npix // 2] + ishape[axis + 1 :]),
839
- None,
840
- )
841
- else:
842
- if len(ishape) == axis + 1:
843
- return (
844
- self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2]),
845
- None,
846
- )
847
- else:
848
- return (
849
- self.backend.bk_reshape(
850
- res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
851
- ),
852
- None,
853
- )
814
+ res = self.backend.bk_reduce_mean(tim, -1)
854
815
 
855
- return self.backend.bk_reshape(res, [npix // 2]), None
816
+ return self.backend.bk_reshape(res, ishape[0:-1] + [npix // 2]), None
856
817
 
857
818
  else:
858
819
  shape = list(im.shape)
@@ -1384,13 +1345,18 @@ class FoCUS:
1384
1345
  return res
1385
1346
 
1386
1347
  # ---------------------------------------------−---------
1387
- def init_index(self, nside, kernel=-1):
1348
+ def init_index(self, nside, kernel=-1, cell_ids=None):
1388
1349
 
1389
1350
  if kernel == -1:
1390
1351
  l_kernel = self.KERNELSZ
1391
1352
  else:
1392
1353
  l_kernel = kernel
1393
1354
 
1355
+ if cell_ids is not None:
1356
+ ncell = cell_ids.shape[0]
1357
+ else:
1358
+ ncell = 12 * nside * nside
1359
+
1394
1360
  try:
1395
1361
  if self.use_2D:
1396
1362
  tmp = np.load(
@@ -1398,16 +1364,29 @@ class FoCUS:
1398
1364
  % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1399
1365
  )
1400
1366
  else:
1401
- tmp = np.load(
1402
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1403
- % (
1404
- self.TEMPLATE_PATH,
1405
- TMPFILE_VERSION,
1406
- l_kernel**2,
1407
- self.NORIENT,
1408
- nside,
1367
+ if cell_ids is not None:
1368
+ tmp = np.load(
1369
+ "%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
1370
+ % (
1371
+ self.TEMPLATE_PATH,
1372
+ TMPFILE_VERSION,
1373
+ l_kernel**2,
1374
+ self.NORIENT,
1375
+ nside, # if cell_ids computes the index
1376
+ )
1377
+ )
1378
+
1379
+ else:
1380
+ tmp = np.load(
1381
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1382
+ % (
1383
+ self.TEMPLATE_PATH,
1384
+ TMPFILE_VERSION,
1385
+ l_kernel**2,
1386
+ self.NORIENT,
1387
+ nside, # if cell_ids computes the index
1388
+ )
1409
1389
  )
1410
- )
1411
1390
  except:
1412
1391
  if not self.use_2D:
1413
1392
 
@@ -1426,36 +1405,64 @@ class FoCUS:
1426
1405
  pw2 = 0.25
1427
1406
  threshold = 4e-5
1428
1407
 
1429
- th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1430
- x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1408
+ if cell_ids is not None:
1409
+ if not isinstance(cell_ids, np.ndarray):
1410
+ cell_ids = self.backend.to_numpy(cell_ids)
1411
+ th, ph = hp.pix2ang(nside, cell_ids, nest=True)
1412
+ x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
1431
1413
 
1432
- t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1433
- phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1434
- thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1414
+ t, p = hp.pix2ang(nside, cell_ids, nest=True)
1415
+ phi = [p[k] / np.pi * 180 for k in range(ncell)]
1416
+ thi = [t[k] / np.pi * 180 for k in range(ncell)]
1435
1417
 
1436
- indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1437
- indice = np.zeros(
1438
- [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1439
- )
1440
- wav = np.zeros(
1441
- [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1442
- )
1443
- wwav = np.zeros([12 * nside * nside * 64 * self.NORIENT], dtype="float")
1418
+ indice2 = np.zeros([ncell * 64, 2], dtype="int")
1419
+ indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
1420
+ wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
1421
+ wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
1422
+
1423
+ else:
1444
1424
 
1425
+ th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
1426
+ x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
1427
+
1428
+ t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
1429
+ phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
1430
+ thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
1431
+
1432
+ indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
1433
+ indice = np.zeros(
1434
+ [12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
1435
+ )
1436
+ wav = np.zeros(
1437
+ [12 * nside * nside * 64 * self.NORIENT], dtype="complex"
1438
+ )
1439
+ wwav = np.zeros(
1440
+ [12 * nside * nside * 64 * self.NORIENT], dtype="float"
1441
+ )
1445
1442
  iv = 0
1446
1443
  iv2 = 0
1447
- for iii in range(12 * nside * nside):
1448
1444
 
1449
- if iii % (nside * nside) == nside * nside - 1:
1450
- if not self.silent:
1451
- print(
1452
- "Pre-compute nside=%6d %.2f%%"
1453
- % (nside, 100 * iii / (12 * nside * nside))
1454
- )
1445
+ for iii in range(ncell):
1446
+ if cell_ids is None:
1447
+ if iii % (nside * nside) == nside * nside - 1:
1448
+ if not self.silent:
1449
+ print(
1450
+ "Pre-compute nside=%6d %.2f%%"
1451
+ % (nside, 100 * iii / (12 * nside * nside))
1452
+ )
1455
1453
 
1456
- hidx = hp.query_disc(
1457
- nside, [x[iii], y[iii], z[iii]], 2 * np.pi / nside, nest=True
1458
- )
1454
+ if cell_ids is not None:
1455
+ hidx = np.where(
1456
+ (x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
1457
+ < (2 * np.pi / nside) ** 2
1458
+ )[0]
1459
+ else:
1460
+ hidx = hp.query_disc(
1461
+ nside,
1462
+ [x[iii], y[iii], z[iii]],
1463
+ 2 * np.pi / nside,
1464
+ nest=True,
1465
+ )
1459
1466
 
1460
1467
  R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
1461
1468
 
@@ -1474,8 +1481,8 @@ class FoCUS:
1474
1481
  )
1475
1482
  idx = np.where((ww**2) > threshold)[0]
1476
1483
  nval2 = len(idx)
1477
- indice2[iv2 : iv2 + nval2, 0] = iii
1478
- indice2[iv2 : iv2 + nval2, 1] = hidx[idx]
1484
+ indice2[iv2 : iv2 + nval2, 1] = iii
1485
+ indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
1479
1486
  wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
1480
1487
  iv2 += nval2
1481
1488
 
@@ -1497,15 +1504,18 @@ class FoCUS:
1497
1504
  idx = np.where(vnorm > threshold)[0]
1498
1505
 
1499
1506
  nval = len(idx)
1500
- indice[iv : iv + nval, 0] = iii * 4 + l_rotation
1501
- indice[iv : iv + nval, 1] = hidx[idx]
1507
+ indice[iv : iv + nval, 1] = iii + l_rotation * ncell
1508
+ indice[iv : iv + nval, 0] = hidx[idx]
1502
1509
  # print([hidx[k] for k in idx])
1503
1510
  # print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
1504
1511
  normr = np.mean(wresr[idx])
1505
1512
  normi = np.mean(wresi[idx])
1506
1513
 
1507
1514
  val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
1508
- val = val / abs(val).sum()
1515
+ r = abs(val).sum()
1516
+
1517
+ if r > 0:
1518
+ val = val / r
1509
1519
 
1510
1520
  wav[iv : iv + nval] = val
1511
1521
  iv += nval
@@ -1609,56 +1619,57 @@ class FoCUS:
1609
1619
  wav=w.flatten()
1610
1620
  wwav=wwav.flatten()
1611
1621
  """
1612
- if not self.silent:
1613
- print(
1614
- "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1615
- % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1622
+ if cell_ids is None:
1623
+ if not self.silent:
1624
+ print(
1625
+ "Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1626
+ % (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
1627
+ )
1628
+ np.save(
1629
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1630
+ % (
1631
+ self.TEMPLATE_PATH,
1632
+ TMPFILE_VERSION,
1633
+ self.KERNELSZ**2,
1634
+ self.NORIENT,
1635
+ nside,
1636
+ ),
1637
+ indice,
1638
+ )
1639
+ np.save(
1640
+ "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1641
+ % (
1642
+ self.TEMPLATE_PATH,
1643
+ TMPFILE_VERSION,
1644
+ self.KERNELSZ**2,
1645
+ self.NORIENT,
1646
+ nside,
1647
+ ),
1648
+ wav,
1649
+ )
1650
+ np.save(
1651
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1652
+ % (
1653
+ self.TEMPLATE_PATH,
1654
+ TMPFILE_VERSION,
1655
+ self.KERNELSZ**2,
1656
+ self.NORIENT,
1657
+ nside,
1658
+ ),
1659
+ indice2,
1660
+ )
1661
+ np.save(
1662
+ "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1663
+ % (
1664
+ self.TEMPLATE_PATH,
1665
+ TMPFILE_VERSION,
1666
+ self.KERNELSZ**2,
1667
+ self.NORIENT,
1668
+ nside,
1669
+ ),
1670
+ wwav,
1616
1671
  )
1617
- np.save(
1618
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1619
- % (
1620
- self.TEMPLATE_PATH,
1621
- TMPFILE_VERSION,
1622
- self.KERNELSZ**2,
1623
- self.NORIENT,
1624
- nside,
1625
- ),
1626
- indice,
1627
- )
1628
- np.save(
1629
- "%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
1630
- % (
1631
- self.TEMPLATE_PATH,
1632
- TMPFILE_VERSION,
1633
- self.KERNELSZ**2,
1634
- self.NORIENT,
1635
- nside,
1636
- ),
1637
- wav,
1638
- )
1639
- np.save(
1640
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1641
- % (
1642
- self.TEMPLATE_PATH,
1643
- TMPFILE_VERSION,
1644
- self.KERNELSZ**2,
1645
- self.NORIENT,
1646
- nside,
1647
- ),
1648
- indice2,
1649
- )
1650
- np.save(
1651
- "%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
1652
- % (
1653
- self.TEMPLATE_PATH,
1654
- TMPFILE_VERSION,
1655
- self.KERNELSZ**2,
1656
- self.NORIENT,
1657
- nside,
1658
- ),
1659
- wwav,
1660
- )
1661
- else:
1672
+ if self.use_2D:
1662
1673
  if l_kernel**2 == 9:
1663
1674
  if self.rank == 0:
1664
1675
  self.comp_idx_w9(nside)
@@ -1674,23 +1685,24 @@ class FoCUS:
1674
1685
  )
1675
1686
  return None
1676
1687
 
1677
- self.barrier()
1678
- if self.use_2D:
1679
- tmp = np.load(
1680
- "%s/W%d_%s_%d_IDX.npy"
1681
- % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1682
- )
1683
- else:
1684
- tmp = np.load(
1685
- "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1686
- % (
1687
- self.TEMPLATE_PATH,
1688
- TMPFILE_VERSION,
1689
- self.KERNELSZ**2,
1690
- self.NORIENT,
1691
- nside,
1688
+ if cell_ids is None:
1689
+ self.barrier()
1690
+ if self.use_2D:
1691
+ tmp = np.load(
1692
+ "%s/W%d_%s_%d_IDX.npy"
1693
+ % (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
1694
+ )
1695
+ else:
1696
+ tmp = np.load(
1697
+ "%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
1698
+ % (
1699
+ self.TEMPLATE_PATH,
1700
+ TMPFILE_VERSION,
1701
+ self.KERNELSZ**2,
1702
+ self.NORIENT,
1703
+ nside,
1704
+ )
1692
1705
  )
1693
- )
1694
1706
  tmp2 = np.load(
1695
1707
  "%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
1696
1708
  % (
@@ -1731,22 +1743,28 @@ class FoCUS:
1731
1743
  nside,
1732
1744
  )
1733
1745
  )
1734
-
1735
- wr = self.backend.bk_SparseTensor(
1736
- self.backend.bk_constant(tmp),
1737
- self.backend.bk_constant(self.backend.bk_cast(wr)),
1738
- dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1739
- )
1740
- wi = self.backend.bk_SparseTensor(
1741
- self.backend.bk_constant(tmp),
1742
- self.backend.bk_constant(self.backend.bk_cast(wi)),
1743
- dense_shape=[12 * nside**2 * self.NORIENT, 12 * nside**2],
1744
- )
1745
- ws = self.backend.bk_SparseTensor(
1746
- self.backend.bk_constant(tmp2),
1747
- self.backend.bk_constant(self.backend.bk_cast(ws)),
1748
- dense_shape=[12 * nside**2, 12 * nside**2],
1749
- )
1746
+ else:
1747
+ tmp = indice
1748
+ tmp2 = indice2
1749
+ wr = wav.real
1750
+ wi = wav.imag
1751
+ ws = self.slope * wwav
1752
+
1753
+ wr = self.backend.bk_SparseTensor(
1754
+ self.backend.bk_constant(tmp),
1755
+ self.backend.bk_constant(self.backend.bk_cast(wr)),
1756
+ dense_shape=[ncell, self.NORIENT * ncell],
1757
+ )
1758
+ wi = self.backend.bk_SparseTensor(
1759
+ self.backend.bk_constant(tmp),
1760
+ self.backend.bk_constant(self.backend.bk_cast(wi)),
1761
+ dense_shape=[ncell, self.NORIENT * ncell],
1762
+ )
1763
+ ws = self.backend.bk_SparseTensor(
1764
+ self.backend.bk_constant(tmp2),
1765
+ self.backend.bk_constant(self.backend.bk_cast(ws)),
1766
+ dense_shape=[ncell, ncell],
1767
+ )
1750
1768
 
1751
1769
  if kernel == -1:
1752
1770
  self.Idx_Neighbours[nside] = tmp
@@ -1783,7 +1801,7 @@ class FoCUS:
1783
1801
  def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
1784
1802
 
1785
1803
  # ==========================================================================
1786
- # in input data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]]
1804
+ # in input data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]]
1787
1805
  # in input mask=[Nmask,X[,Y]]
1788
1806
  # if self.use_2D : X[,Y]] = [X,Y]
1789
1807
  # if second level: NORIENT[,NORIENT]= NORIENT,NORIENT
@@ -1791,7 +1809,7 @@ class FoCUS:
1791
1809
 
1792
1810
  shape = list(x.shape)
1793
1811
 
1794
- if not self.use_2D:
1812
+ if not self.use_2D and not self.use_1D:
1795
1813
  nside = int(np.sqrt(x.shape[axis] // 12))
1796
1814
 
1797
1815
  l_mask = mask
@@ -1802,6 +1820,7 @@ class FoCUS:
1802
1820
  ),
1803
1821
  1,
1804
1822
  )
1823
+
1805
1824
  if not self.use_2D:
1806
1825
  l_mask = (
1807
1826
  12
@@ -1845,13 +1864,11 @@ class FoCUS:
1845
1864
  ]
1846
1865
 
1847
1866
  ichannel = 1
1848
- for i in range(axis):
1867
+ for i in range(1, len(shape) - 2):
1849
1868
  ichannel *= shape[i]
1850
- ochannel = 1
1851
- for i in range(axis + 2, len(shape)):
1852
- ochannel *= shape[i]
1869
+
1853
1870
  l_x = self.backend.bk_reshape(
1854
- x, [ichannel, 1, shape[axis], shape[axis + 1], ochannel]
1871
+ x, [shape[0], 1, ichannel, shape[-2], shape[-1]]
1855
1872
  )
1856
1873
 
1857
1874
  if self.padding == "VALID":
@@ -1876,12 +1893,10 @@ class FoCUS:
1876
1893
  l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
1877
1894
 
1878
1895
  ichannel = 1
1879
- for i in range(axis):
1896
+ for i in range(1, len(shape) - 1):
1880
1897
  ichannel *= shape[i]
1881
- ochannel = 1
1882
- for i in range(axis + 1, len(shape)):
1883
- ochannel *= shape[i]
1884
- l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1898
+
1899
+ l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel, shape[-1]])
1885
1900
 
1886
1901
  if self.padding == "VALID":
1887
1902
  oshape = [k for k in shape]
@@ -1891,18 +1906,14 @@ class FoCUS:
1891
1906
  )
1892
1907
  else:
1893
1908
  ichannel = 1
1894
- for i in range(axis):
1909
+ for i in range(len(shape) - 1):
1895
1910
  ichannel *= shape[i]
1896
- ochannel = 1
1897
- for i in range(axis + 1, len(shape)):
1898
- ochannel *= shape[i]
1899
- l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
1900
1911
 
1901
- # data=[Nbatch,...,X[,Y],NORIENT[,NORIENT]] => data=[Nbatch,1,...,X[,Y],NORIENT[,NORIENT]]
1902
- # mask=[Nmask,X[,Y]] => mask=[1,Nmask,X[,Y]]
1903
- l_mask = self.backend.bk_expand_dims(l_mask, 0)
1904
- # mask=[1,Nmask,X[,Y]] => mask=[1,Nmask,X[,Y],1]
1905
- l_mask = self.backend.bk_expand_dims(l_mask, -1)
1912
+ l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[-1]])
1913
+
1914
+ # data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,1,...,NORIENT[,NORIENT],X[,Y]]
1915
+ # mask=[Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
1916
+ l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask, 0), 0)
1906
1917
 
1907
1918
  if l_x.dtype == self.all_cbk_type:
1908
1919
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
@@ -1916,21 +1927,23 @@ class FoCUS:
1916
1927
  # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1917
1928
 
1918
1929
  v1 = self.backend.bk_reduce_sum(
1919
- self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
1930
+ self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
1920
1931
  )
1921
1932
  v2 = self.backend.bk_reduce_sum(
1922
- self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2), 2
1933
+ self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
1934
+ )
1935
+ vh = self.backend.bk_reduce_sum(
1936
+ self.backend.bk_reduce_sum(mtmp, axis=-1), -1
1923
1937
  )
1924
- vh = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp, axis=2), 2)
1925
1938
 
1926
1939
  res = v1 / vh
1927
1940
 
1928
- oshape = []
1941
+ oshape = [x.shape[0]] + [mask.shape[0]]
1929
1942
  if axis > 0:
1930
- oshape = oshape + list(x.shape[0:axis])
1931
- oshape = oshape + [mask.shape[0]]
1932
- if axis + 1 < len(x.shape):
1933
- oshape = oshape + list(x.shape[axis + 2 :])
1943
+ oshape = oshape + list(x.shape[1:axis])
1944
+
1945
+ if len(x.shape[axis:-2]) > 0:
1946
+ oshape = oshape + list(x.shape[axis:-2])
1934
1947
 
1935
1948
  if calc_var:
1936
1949
  if self.backend.bk_is_complex(vtmp):
@@ -1960,19 +1973,15 @@ class FoCUS:
1960
1973
  elif self.use_1D:
1961
1974
  mtmp = l_mask
1962
1975
  vtmp = l_x
1963
-
1964
- v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=2)
1965
- v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=2)
1966
- vh = self.backend.bk_reduce_sum(mtmp, axis=2)
1976
+ v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1)
1977
+ v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
1978
+ vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
1967
1979
 
1968
1980
  res = v1 / vh
1969
1981
 
1970
- oshape = []
1971
- if axis > 0:
1972
- oshape = oshape + list(x.shape[0:axis])
1973
- oshape = oshape + [mask.shape[0]]
1974
- if axis + 1 < len(x.shape):
1975
- oshape = oshape + list(x.shape[axis + 1 :])
1982
+ oshape = [x.shape[0]] + [mask.shape[0]]
1983
+ if len(x.shape) > 1:
1984
+ oshape = oshape + list(x.shape[1:-1])
1976
1985
 
1977
1986
  if calc_var:
1978
1987
  if self.backend.bk_is_complex(vtmp):
@@ -1991,7 +2000,6 @@ class FoCUS:
1991
2000
  )
1992
2001
  else:
1993
2002
  res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
1994
-
1995
2003
  res = self.backend.bk_reshape(res, oshape)
1996
2004
  res2 = self.backend.bk_reshape(res2, oshape)
1997
2005
  return res, res2
@@ -2000,18 +2008,20 @@ class FoCUS:
2000
2008
  return res
2001
2009
 
2002
2010
  else:
2003
- v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=2)
2004
- v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=2)
2005
- vh = self.backend.bk_reduce_sum(l_mask, axis=2)
2011
+ v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
2012
+ v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
2013
+ vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
2006
2014
 
2007
2015
  res = v1 / vh
2008
2016
 
2009
2017
  oshape = []
2010
2018
  if axis > 0:
2011
- oshape = oshape + list(x.shape[0:axis])
2019
+ oshape = [x.shape[0]]
2020
+ else:
2021
+ oshape = [1]
2012
2022
  oshape = oshape + [mask.shape[0]]
2013
- if axis + 1 < len(x.shape):
2014
- oshape = oshape + list(x.shape[axis + 1 :])
2023
+ if axis > 1:
2024
+ oshape = oshape + list(x.shape[1:-1])
2015
2025
 
2016
2026
  if calc_var:
2017
2027
  if self.backend.bk_is_complex(l_x):
@@ -2176,169 +2186,67 @@ class FoCUS:
2176
2186
  print("Use of 2D scat with data that has less than 2D")
2177
2187
  return None
2178
2188
 
2179
- npix = ishape[axis]
2180
- npiy = ishape[axis + 1]
2181
- odata = 1
2182
- if len(ishape) > axis + 2:
2183
- for k in range(axis + 2, len(ishape)):
2184
- odata = odata * ishape[k]
2189
+ npix = ishape[-2]
2190
+ npiy = ishape[-1]
2185
2191
 
2186
2192
  ndata = 1
2187
- for k in range(axis):
2193
+ for k in range(len(ishape) - 2):
2188
2194
  ndata = ndata * ishape[k]
2189
2195
 
2190
2196
  tim = self.backend.bk_reshape(
2191
- self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2197
+ self.backend.bk_cast(in_image), [ndata, npix, npiy]
2192
2198
  )
2193
2199
 
2194
2200
  if self.backend.bk_is_complex(tim):
2195
- rr1 = self.backend.conv2d(
2196
- self.backend.bk_real(tim),
2197
- self.ww_RealT[odata],
2198
- strides=[1, 1, 1, 1],
2199
- padding=self.padding,
2200
- )
2201
- ii1 = self.backend.conv2d(
2202
- self.backend.bk_real(tim),
2203
- self.ww_ImagT[odata],
2204
- strides=[1, 1, 1, 1],
2205
- padding=self.padding,
2206
- )
2207
- rr2 = self.backend.conv2d(
2208
- self.backend.bk_imag(tim),
2209
- self.ww_RealT[odata],
2210
- strides=[1, 1, 1, 1],
2211
- padding=self.padding,
2212
- )
2213
- ii2 = self.backend.conv2d(
2214
- self.backend.bk_imag(tim),
2215
- self.ww_ImagT[odata],
2216
- strides=[1, 1, 1, 1],
2217
- padding=self.padding,
2218
- )
2201
+ rr1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_RealT[1])
2202
+ ii1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_ImagT[1])
2203
+ rr2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_RealT[1])
2204
+ ii2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_ImagT[1])
2219
2205
  res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2220
2206
  else:
2221
- rr = self.backend.conv2d(
2222
- tim,
2223
- self.ww_RealT[odata],
2224
- strides=[1, 1, 1, 1],
2225
- padding=self.padding,
2226
- )
2227
- ii = self.backend.conv2d(
2228
- tim,
2229
- self.ww_ImagT[odata],
2230
- strides=[1, 1, 1, 1],
2231
- padding=self.padding,
2232
- )
2207
+ rr = self.backend.conv2d(tim, self.ww_RealT[1])
2208
+ ii = self.backend.conv2d(tim, self.ww_ImagT[1])
2233
2209
  res = self.backend.bk_complex(rr, ii)
2234
2210
 
2235
- if axis == 0:
2236
- if len(ishape) == 2:
2237
- return self.backend.bk_reshape(
2238
- res, [res.shape[1], res.shape[2], self.NORIENT]
2239
- )
2240
- else:
2241
- return self.backend.bk_reshape(
2242
- res,
2243
- [res.shape[1], res.shape[2], self.NORIENT] + ishape[axis + 2 :],
2244
- )
2245
- else:
2246
- if len(ishape) == axis + 2:
2247
- return self.backend.bk_reshape(
2248
- res, ishape[0:axis] + [res.shape[1], res.shape[2], self.NORIENT]
2249
- )
2250
- else:
2251
- return self.backend.bk_reshape(
2252
- res,
2253
- ishape[0:axis]
2254
- + [res.shape[1], res.shape[2], self.NORIENT]
2255
- + ishape[axis + 2 :],
2256
- )
2211
+ return self.backend.bk_reshape(
2212
+ res, ishape[0:-2] + [self.NORIENT, npix, npiy]
2213
+ )
2257
2214
 
2258
- return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2259
2215
  elif self.use_1D:
2260
2216
  ishape = list(in_image.shape)
2261
- if len(ishape) < axis + 1:
2262
- if not self.silent:
2263
- print("Use of 1D scat with data that has less than 1D")
2264
- return None
2265
2217
 
2266
- npix = ishape[axis]
2267
- odata = 1
2268
- if len(ishape) > axis + 1:
2269
- for k in range(axis + 1, len(ishape)):
2270
- odata = odata * ishape[k]
2218
+ npix = ishape[-1]
2271
2219
 
2272
2220
  ndata = 1
2273
- for k in range(axis):
2221
+ for k in range(len(ishape) - 1):
2274
2222
  ndata = ndata * ishape[k]
2275
2223
 
2276
- tim = self.backend.bk_reshape(
2277
- self.backend.bk_cast(in_image), [ndata, npix, odata]
2278
- )
2224
+ tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix])
2279
2225
 
2280
2226
  if self.backend.bk_is_complex(tim):
2281
- rr1 = self.backend.conv1d(
2282
- self.backend.bk_real(tim),
2283
- self.ww_RealT[odata],
2284
- strides=[1, 1, 1],
2285
- padding=self.padding,
2286
- )
2287
- ii1 = self.backend.conv1d(
2288
- self.backend.bk_real(tim),
2289
- self.ww_ImagT[odata],
2290
- strides=[1, 1, 1],
2291
- padding=self.padding,
2292
- )
2293
- rr2 = self.backend.conv1d(
2294
- self.backend.bk_imag(tim),
2295
- self.ww_RealT[odata],
2296
- strides=[1, 1, 1],
2297
- padding=self.padding,
2298
- )
2299
- ii2 = self.backend.conv1d(
2300
- self.backend.bk_imag(tim),
2301
- self.ww_ImagT[odata],
2302
- strides=[1, 1, 1],
2303
- padding=self.padding,
2304
- )
2227
+ rr1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_RealT[1])
2228
+ ii1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_ImagT[1])
2229
+ rr2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_RealT[1])
2230
+ ii2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_ImagT[1])
2305
2231
  res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2306
2232
  else:
2307
- rr = self.backend.conv1d(
2308
- tim, self.ww_RealT[odata], strides=[1, 1, 1], padding=self.padding
2309
- )
2310
- ii = self.backend.conv1d(
2311
- tim, self.ww_ImagT[odata], strides=[1, 1, 1], padding=self.padding
2312
- )
2233
+ rr = self.backend.conv1d(tim, self.ww_RealT[1])
2234
+ ii = self.backend.conv1d(tim, self.ww_ImagT[1])
2313
2235
  res = self.backend.bk_complex(rr, ii)
2314
2236
 
2315
- if axis == 0:
2316
- if len(ishape) == 1:
2317
- return self.backend.bk_reshape(res, [res.shape[1]])
2318
- else:
2319
- return self.backend.bk_reshape(
2320
- res, [res.shape[1]] + ishape[axis + 2 :]
2321
- )
2322
- else:
2323
- if len(ishape) == axis + 1:
2324
- return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
2325
- else:
2326
- return self.backend.bk_reshape(
2327
- res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2328
- )
2329
-
2330
- return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2237
+ return self.backend.bk_reshape(res, ishape)
2331
2238
 
2332
2239
  else:
2333
2240
  ishape = list(image.shape)
2334
-
2241
+ """
2335
2242
  if cell_ids is not None:
2336
2243
  if cell_ids.shape[0] not in self.padding_conv:
2244
+ print(image.shape,cell_ids.shape)
2337
2245
  import healpix_convolution as hc
2338
2246
  from xdggs.healpix import HealpixInfo
2339
2247
 
2340
2248
  res = self.backend.bk_zeros(
2341
- ishape + [self.NORIENT], dtype=self.backend.all_cbk_type
2249
+ ishape[0:-1] + [self.NORIENT]+ishape[-1], dtype=self.backend.all_cbk_type
2342
2250
  )
2343
2251
 
2344
2252
  grid_info = HealpixInfo(
@@ -2384,14 +2292,15 @@ class FoCUS:
2384
2292
  padded_data
2385
2293
  ) + 1j * kernelI.matmul(padded_data)
2386
2294
  return res
2387
-
2388
- nside = int(np.sqrt(image.shape[axis] // 12))
2295
+ """
2296
+ if nside is None:
2297
+ nside = int(np.sqrt(image.shape[-1] // 12))
2389
2298
 
2390
2299
  if self.Idx_Neighbours[nside] is None:
2391
2300
  if self.InitWave is None:
2392
- wr, wi, ws, widx = self.init_index(nside)
2301
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2393
2302
  else:
2394
- wr, wi, ws, widx = self.InitWave(self, nside)
2303
+ wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
2395
2304
 
2396
2305
  self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
2397
2306
  self.ww_Real[nside] = wr
@@ -2401,156 +2310,63 @@ class FoCUS:
2401
2310
  l_ww_real = self.ww_Real[nside]
2402
2311
  l_ww_imag = self.ww_Imag[nside]
2403
2312
 
2404
- odata = 1
2405
- for k in range(axis + 1, len(ishape)):
2406
- odata = odata * ishape[k]
2313
+ # always convolve the last dimension
2407
2314
 
2408
- if axis > 0:
2409
- ndata = 1
2410
- for k in range(axis):
2315
+ ndata = 1
2316
+ if len(ishape) > 1:
2317
+ for k in range(len(ishape) - 1):
2411
2318
  ndata = ndata * ishape[k]
2412
- tim = self.backend.bk_reshape(
2413
- self.backend.bk_cast(image), [ndata, 12 * nside**2, odata]
2414
- )
2415
- if tim.dtype == self.all_cbk_type:
2416
- rr1 = self.backend.bk_reshape(
2417
- self.backend.bk_sparse_dense_matmul(
2418
- l_ww_real, self.backend.bk_real(tim[0])
2419
- ),
2420
- [1, 12 * nside**2, self.NORIENT, odata],
2421
- )
2422
- ii1 = self.backend.bk_reshape(
2423
- self.backend.bk_sparse_dense_matmul(
2424
- l_ww_imag, self.backend.bk_real(tim[0])
2425
- ),
2426
- [1, 12 * nside**2, self.NORIENT, odata],
2427
- )
2428
- rr2 = self.backend.bk_reshape(
2429
- self.backend.bk_sparse_dense_matmul(
2430
- l_ww_real, self.backend.bk_imag(tim[0])
2431
- ),
2432
- [1, 12 * nside**2, self.NORIENT, odata],
2433
- )
2434
- ii2 = self.backend.bk_reshape(
2435
- self.backend.bk_sparse_dense_matmul(
2436
- l_ww_imag, self.backend.bk_imag(tim[0])
2437
- ),
2438
- [1, 12 * nside**2, self.NORIENT, odata],
2439
- )
2440
- res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2441
- else:
2442
- rr = self.backend.bk_reshape(
2443
- self.backend.bk_sparse_dense_matmul(l_ww_real, tim[0]),
2444
- [1, 12 * nside**2, self.NORIENT, odata],
2445
- )
2446
- ii = self.backend.bk_reshape(
2447
- self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[0]),
2448
- [1, 12 * nside**2, self.NORIENT, odata],
2449
- )
2450
- res = self.backend.bk_complex(rr, ii)
2451
-
2452
- for k in range(1, ndata):
2453
- if tim.dtype == self.all_cbk_type:
2454
- rr1 = self.backend.bk_reshape(
2455
- self.backend.bk_sparse_dense_matmul(
2456
- l_ww_real, self.backend.bk_real(tim[k])
2457
- ),
2458
- [1, 12 * nside**2, self.NORIENT, odata],
2459
- )
2460
- ii1 = self.backend.bk_reshape(
2461
- self.backend.bk_sparse_dense_matmul(
2462
- l_ww_imag, self.backend.bk_real(tim[k])
2463
- ),
2464
- [1, 12 * nside**2, self.NORIENT, odata],
2465
- )
2466
- rr2 = self.backend.bk_reshape(
2467
- self.backend.bk_sparse_dense_matmul(
2468
- l_ww_real, self.backend.bk_imag(tim[k])
2469
- ),
2470
- [1, 12 * nside**2, self.NORIENT, odata],
2471
- )
2472
- ii2 = self.backend.bk_reshape(
2473
- self.backend.bk_sparse_dense_matmul(
2474
- l_ww_imag, self.backend.bk_imag(tim[k])
2475
- ),
2476
- [1, 12 * nside**2, self.NORIENT, odata],
2477
- )
2478
- res = self.backend.bk_concat(
2479
- [res, self.backend.bk_complex(rr1 - ii2, ii1 + rr2)], 0
2480
- )
2481
- else:
2482
- rr = self.backend.bk_reshape(
2483
- self.backend.bk_sparse_dense_matmul(l_ww_real, tim[k]),
2484
- [1, 12 * nside**2, self.NORIENT, odata],
2485
- )
2486
- ii = self.backend.bk_reshape(
2487
- self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[k]),
2488
- [1, 12 * nside**2, self.NORIENT, odata],
2489
- )
2490
- res = self.backend.bk_concat(
2491
- [res, self.backend.bk_complex(rr, ii)], 0
2492
- )
2493
-
2494
- if len(ishape) == axis + 1:
2495
- return self.backend.bk_reshape(
2496
- res, ishape[0:axis] + [12 * nside**2, self.NORIENT]
2497
- )
2498
- else:
2499
- return self.backend.bk_reshape(
2500
- res,
2501
- ishape[0:axis]
2502
- + [12 * nside**2]
2503
- + ishape[axis + 1 :]
2504
- + [self.NORIENT],
2505
- )
2319
+ tim = self.backend.bk_reshape(
2320
+ self.backend.bk_cast(image), [ndata, ishape[-1]]
2321
+ )
2506
2322
 
2507
- if axis == 0:
2508
- tim = self.backend.bk_reshape(
2509
- self.backend.bk_cast(image), [12 * nside**2, odata]
2323
+ if tim.dtype == self.all_cbk_type:
2324
+ rr1 = self.backend.bk_reshape(
2325
+ self.backend.bk_sparse_dense_matmul(
2326
+ self.backend.bk_real(tim),
2327
+ l_ww_real,
2328
+ ),
2329
+ [ndata, self.NORIENT, ishape[-1]],
2510
2330
  )
2511
- if tim.dtype == self.all_cbk_type:
2512
- rr1 = self.backend.bk_reshape(
2513
- self.backend.bk_sparse_dense_matmul(
2514
- l_ww_real, self.backend.bk_real(tim)
2515
- ),
2516
- [12 * nside**2, self.NORIENT, odata],
2517
- )
2518
- ii1 = self.backend.bk_reshape(
2519
- self.backend.bk_sparse_dense_matmul(
2520
- l_ww_imag, self.backend.bk_real(tim)
2521
- ),
2522
- [12 * nside**2, self.NORIENT, odata],
2523
- )
2524
- rr2 = self.backend.bk_reshape(
2525
- self.backend.bk_sparse_dense_matmul(
2526
- l_ww_real, self.backend.bk_imag(tim)
2527
- ),
2528
- [12 * nside**2, self.NORIENT, odata],
2529
- )
2530
- ii2 = self.backend.bk_reshape(
2531
- self.backend.bk_sparse_dense_matmul(
2532
- l_ww_imag, self.backend.bk_imag(tim)
2533
- ),
2534
- [12 * nside**2, self.NORIENT, odata],
2535
- )
2536
- res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2537
- else:
2538
- rr = self.backend.bk_reshape(
2539
- self.backend.bk_sparse_dense_matmul(l_ww_real, tim),
2540
- [12 * nside**2, self.NORIENT, odata],
2541
- )
2542
- ii = self.backend.bk_reshape(
2543
- self.backend.bk_sparse_dense_matmul(l_ww_imag, tim),
2544
- [12 * nside**2, self.NORIENT, odata],
2545
- )
2546
- res = self.backend.bk_complex(rr, ii)
2331
+ ii1 = self.backend.bk_reshape(
2332
+ self.backend.bk_sparse_dense_matmul(
2333
+ self.backend.bk_real(tim),
2334
+ l_ww_imag,
2335
+ ),
2336
+ [ndata, self.NORIENT, ishape[-1]],
2337
+ )
2338
+ rr2 = self.backend.bk_reshape(
2339
+ self.backend.bk_sparse_dense_matmul(
2340
+ self.backend.bk_imag(tim),
2341
+ l_ww_real,
2342
+ ),
2343
+ [ndata, self.NORIENT, ishape[-1]],
2344
+ )
2345
+ ii2 = self.backend.bk_reshape(
2346
+ self.backend.bk_sparse_dense_matmul(
2347
+ self.backend.bk_imag(tim),
2348
+ l_ww_imag,
2349
+ ),
2350
+ [ndata, self.NORIENT, ishape[-1]],
2351
+ )
2352
+ res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
2353
+ else:
2354
+ rr = self.backend.bk_reshape(
2355
+ self.backend.bk_sparse_dense_matmul(tim, l_ww_real),
2356
+ [ndata, self.NORIENT, ishape[-1]],
2357
+ )
2358
+ ii = self.backend.bk_reshape(
2359
+ self.backend.bk_sparse_dense_matmul(tim, l_ww_imag),
2360
+ [ndata, self.NORIENT, ishape[-1]],
2361
+ )
2362
+ res = self.backend.bk_complex(rr, ii)
2363
+ if len(ishape) > 1:
2364
+ return self.backend.bk_reshape(
2365
+ res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
2366
+ )
2367
+ else:
2368
+ return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
2547
2369
 
2548
- if len(ishape) == 1:
2549
- return self.backend.bk_reshape(res, [12 * nside**2, self.NORIENT])
2550
- else:
2551
- return self.backend.bk_reshape(
2552
- res, [12 * nside**2] + ishape[axis + 1 :] + [self.NORIENT]
2553
- )
2554
2370
  return res
2555
2371
 
2556
2372
  # ---------------------------------------------−---------
@@ -2578,114 +2394,43 @@ class FoCUS:
2578
2394
  ndata = ndata * ishape[k]
2579
2395
 
2580
2396
  tim = self.backend.bk_reshape(
2581
- self.backend.bk_cast(in_image), [ndata, npix, npiy, odata]
2397
+ self.backend.bk_cast(in_image), [ndata, npix, npiy]
2582
2398
  )
2583
2399
 
2584
2400
  if self.backend.bk_is_complex(tim):
2585
- rr = self.backend.conv2d(
2586
- self.backend.bk_real(tim),
2587
- self.ww_SmoothT[odata],
2588
- strides=[1, 1, 1, 1],
2589
- padding=self.padding,
2590
- )
2591
- ii = self.backend.conv2d(
2592
- self.backend.bk_imag(tim),
2593
- self.ww_SmoothT[odata],
2594
- strides=[1, 1, 1, 1],
2595
- padding=self.padding,
2596
- )
2401
+ rr = self.backend.conv2d(self.backend.bk_real(tim), self.ww_SmoothT[1])
2402
+ ii = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_SmoothT[1])
2597
2403
  res = self.backend.bk_complex(rr, ii)
2598
2404
  else:
2599
- res = self.backend.conv2d(
2600
- tim,
2601
- self.ww_SmoothT[odata],
2602
- strides=[1, 1, 1, 1],
2603
- padding=self.padding,
2604
- )
2405
+ res = self.backend.conv2d(tim, self.ww_SmoothT[1])
2605
2406
 
2606
- if axis == 0:
2607
- if len(ishape) == 2:
2608
- return self.backend.bk_reshape(res, [res.shape[1], res.shape[2]])
2609
- else:
2610
- return self.backend.bk_reshape(
2611
- res, [res.shape[1], res.shape[2]] + ishape[axis + 2 :]
2612
- )
2613
- else:
2614
- if len(ishape) == axis + 2:
2615
- return self.backend.bk_reshape(
2616
- res, ishape[0:axis] + [res.shape[1], res.shape[2]]
2617
- )
2618
- else:
2619
- return self.backend.bk_reshape(
2620
- res,
2621
- ishape[0:axis]
2622
- + [res.shape[1], res.shape[2]]
2623
- + ishape[axis + 2 :],
2624
- )
2407
+ return self.backend.bk_reshape(res, ishape)
2625
2408
 
2626
- return self.backend.bk_reshape(res, in_image.shape)
2627
2409
  elif self.use_1D:
2628
2410
 
2629
2411
  ishape = list(in_image.shape)
2630
- if len(ishape) < axis + 1:
2631
- if not self.silent:
2632
- print("Use of 1D scat with data that has less than 1D")
2633
- return None
2634
2412
 
2635
- npix = ishape[axis]
2636
- odata = 1
2637
- if len(ishape) > axis + 1:
2638
- for k in range(axis + 1, len(ishape)):
2639
- odata = odata * ishape[k]
2413
+ npix = ishape[-1]
2640
2414
 
2641
2415
  ndata = 1
2642
- for k in range(axis):
2416
+ for k in range(len(ishape) - 1):
2643
2417
  ndata = ndata * ishape[k]
2644
2418
 
2645
- tim = self.backend.bk_reshape(
2646
- self.backend.bk_cast(in_image), [ndata, npix, odata]
2647
- )
2419
+ tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix])
2648
2420
 
2649
2421
  if self.backend.bk_is_complex(tim):
2650
- rr = self.backend.conv1d(
2651
- self.backend.bk_real(tim),
2652
- self.ww_SmoothT[odata],
2653
- strides=[1, 1, 1],
2654
- padding=self.padding,
2655
- )
2656
- ii = self.backend.conv1d(
2657
- self.backend.bk_imag(tim),
2658
- self.ww_SmoothT[odata],
2659
- strides=[1, 1, 1],
2660
- padding=self.padding,
2661
- )
2422
+ rr = self.backend.conv1d(self.backend.bk_real(tim), self.ww_SmoothT[1])
2423
+ ii = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_SmoothT[1])
2662
2424
  res = self.backend.bk_complex(rr, ii)
2663
2425
  else:
2664
- res = self.backend.conv1d(
2665
- tim, self.ww_SmoothT[odata], strides=[1, 1, 1], padding=self.padding
2666
- )
2426
+ res = self.backend.conv1d(tim, self.ww_SmoothT[1])
2667
2427
 
2668
- if axis == 0:
2669
- if len(ishape) == 1:
2670
- return self.backend.bk_reshape(res, [res.shape[1]])
2671
- else:
2672
- return self.backend.bk_reshape(
2673
- res, [res.shape[1]] + ishape[axis + 1 :]
2674
- )
2675
- else:
2676
- if len(ishape) == axis + 1:
2677
- return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
2678
- else:
2679
- return self.backend.bk_reshape(
2680
- res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
2681
- )
2682
-
2683
- return self.backend.bk_reshape(res, in_image.shape)
2428
+ return self.backend.bk_reshape(res, ishape)
2684
2429
 
2685
2430
  else:
2686
2431
 
2687
2432
  ishape = list(image.shape)
2688
-
2433
+ """
2689
2434
  if cell_ids is not None:
2690
2435
  if cell_ids.shape[0] not in self.padding_smooth:
2691
2436
  import healpix_convolution as hc
@@ -2726,15 +2471,16 @@ class FoCUS:
2726
2471
  padded_data = padding.apply(image[l, :, k2], is_torch=True)
2727
2472
  res[l, :, k2] = kernel.matmul(padded_data)
2728
2473
  return res
2729
-
2730
- nside = int(np.sqrt(image.shape[axis] // 12))
2474
+ """
2475
+ if nside is None:
2476
+ nside = int(np.sqrt(image.shape[-1] // 12))
2731
2477
 
2732
2478
  if self.Idx_Neighbours[nside] is None:
2733
2479
 
2734
2480
  if self.InitWave is None:
2735
- wr, wi, ws, widx = self.init_index(nside)
2481
+ wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
2736
2482
  else:
2737
- wr, wi, ws, widx = self.InitWave(self, nside)
2483
+ wr, wi, ws, widx = self.InitWave(self, nside, cell_ids=cell_ids)
2738
2484
 
2739
2485
  self.Idx_Neighbours[nside] = 1
2740
2486
  self.ww_Real[nside] = wr
@@ -2744,92 +2490,24 @@ class FoCUS:
2744
2490
  l_w_smooth = self.w_smooth[nside]
2745
2491
 
2746
2492
  odata = 1
2747
- for k in range(axis + 1, len(ishape)):
2493
+ for k in range(0, len(ishape) - 1):
2748
2494
  odata = odata * ishape[k]
2749
2495
 
2750
- if axis == 0:
2751
- tim = self.backend.bk_reshape(image, [12 * nside**2, odata])
2752
- if tim.dtype == self.all_cbk_type:
2753
- rr = self.backend.bk_sparse_dense_matmul(
2754
- l_w_smooth, self.backend.bk_real(tim)
2755
- )
2756
- ri = self.backend.bk_sparse_dense_matmul(
2757
- l_w_smooth, self.backend.bk_imag(tim)
2758
- )
2759
- res = self.backend.bk_complex(rr, ri)
2760
- else:
2761
- res = self.backend.bk_sparse_dense_matmul(l_w_smooth, tim)
2762
- if len(ishape) == 1:
2763
- return self.backend.bk_reshape(res, [12 * nside**2])
2764
- else:
2765
- return self.backend.bk_reshape(
2766
- res, [12 * nside**2] + ishape[axis + 1 :]
2767
- )
2768
-
2769
- if axis > 0:
2770
- ndata = ishape[0]
2771
- for k in range(1, axis):
2772
- ndata = ndata * ishape[k]
2773
- tim = self.backend.bk_reshape(image, [ndata, 12 * nside**2, odata])
2774
- if tim.dtype == self.all_cbk_type:
2775
- rr = self.backend.bk_reshape(
2776
- self.backend.bk_sparse_dense_matmul(
2777
- l_w_smooth, self.backend.bk_real(tim[0])
2778
- ),
2779
- [1, 12 * nside**2, odata],
2780
- )
2781
- ri = self.backend.bk_reshape(
2782
- self.backend.bk_sparse_dense_matmul(
2783
- l_w_smooth, self.backend.bk_imag(tim[0])
2784
- ),
2785
- [1, 12 * nside**2, odata],
2786
- )
2787
- res = self.backend.bk_complex(rr, ri)
2788
- else:
2789
- res = self.backend.bk_reshape(
2790
- self.backend.bk_sparse_dense_matmul(l_w_smooth, tim[0]),
2791
- [1, 12 * nside**2, odata],
2792
- )
2793
-
2794
- for k in range(1, ndata):
2795
- if tim.dtype == self.all_cbk_type:
2796
- rr = self.backend.bk_reshape(
2797
- self.backend.bk_sparse_dense_matmul(
2798
- l_w_smooth, self.backend.bk_real(tim[k])
2799
- ),
2800
- [1, 12 * nside**2, odata],
2801
- )
2802
- ri = self.backend.bk_reshape(
2803
- self.backend.bk_sparse_dense_matmul(
2804
- l_w_smooth, self.backend.bk_imag(tim[k])
2805
- ),
2806
- [1, 12 * nside**2, odata],
2807
- )
2808
- res = self.backend.bk_concat(
2809
- [res, self.backend.bk_complex(rr, ri)], 0
2810
- )
2811
- else:
2812
- res = self.backend.bk_concat(
2813
- [
2814
- res,
2815
- self.backend.bk_reshape(
2816
- self.backend.bk_sparse_dense_matmul(
2817
- l_w_smooth, tim[k]
2818
- ),
2819
- [1, 12 * nside**2, odata],
2820
- ),
2821
- ],
2822
- 0,
2823
- )
2824
-
2825
- if len(ishape) == axis + 1:
2826
- return self.backend.bk_reshape(
2827
- res, ishape[0:axis] + [12 * nside**2]
2828
- )
2829
- else:
2830
- return self.backend.bk_reshape(
2831
- res, ishape[0:axis] + [12 * nside**2] + ishape[axis + 1 :]
2832
- )
2496
+ tim = self.backend.bk_reshape(image, [odata, ishape[-1]])
2497
+ if tim.dtype == self.all_cbk_type:
2498
+ rr = self.backend.bk_sparse_dense_matmul(
2499
+ self.backend.bk_real(tim), l_w_smooth
2500
+ )
2501
+ ri = self.backend.bk_sparse_dense_matmul(
2502
+ self.backend.bk_imag(tim), l_w_smooth
2503
+ )
2504
+ res = self.backend.bk_complex(rr, ri)
2505
+ else:
2506
+ res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
2507
+ if len(ishape) == 1:
2508
+ return self.backend.bk_reshape(res, [ishape[-1]])
2509
+ else:
2510
+ return self.backend.bk_reshape(res, ishape[0:-1] + [ishape[-1]])
2833
2511
 
2834
2512
  return res
2835
2513