foscat 2025.5.0__py3-none-any.whl → 2025.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTensorflow.py +138 -14
- foscat/BkTorch.py +90 -57
- foscat/CNN.py +31 -30
- foscat/FoCUS.py +640 -917
- foscat/GCNN.py +48 -150
- foscat/Softmax.py +1 -0
- foscat/alm.py +2 -2
- foscat/heal_NN.py +432 -0
- foscat/scat_cov.py +139 -96
- foscat/scat_cov_map2D.py +2 -2
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/METADATA +1 -1
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/RECORD +15 -14
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/WHEEL +1 -1
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.5.0.dist-info → foscat-2025.6.1.dist-info}/top_level.txt +0 -0
foscat/FoCUS.py
CHANGED
|
@@ -5,7 +5,7 @@ import healpy as hp
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from scipy.interpolate import griddata
|
|
7
7
|
|
|
8
|
-
TMPFILE_VERSION = "
|
|
8
|
+
TMPFILE_VERSION = "V5_0"
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class FoCUS:
|
|
@@ -35,7 +35,7 @@ class FoCUS:
|
|
|
35
35
|
mpi_rank=0,
|
|
36
36
|
):
|
|
37
37
|
|
|
38
|
-
self.__version__ = "2025.
|
|
38
|
+
self.__version__ = "2025.06.1"
|
|
39
39
|
# P00 coeff for normalization for scat_cov
|
|
40
40
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
41
41
|
self.P1_dic = None
|
|
@@ -176,11 +176,14 @@ class FoCUS:
|
|
|
176
176
|
self.Y_CNN = {}
|
|
177
177
|
self.Z_CNN = {}
|
|
178
178
|
|
|
179
|
+
self.Idx_CNN = {}
|
|
180
|
+
self.Idx_WCNN = {}
|
|
181
|
+
|
|
179
182
|
self.filters_set = {}
|
|
180
183
|
self.edge_masks = {}
|
|
181
184
|
|
|
182
|
-
wwc = np.zeros([KERNELSZ**2
|
|
183
|
-
wws = np.zeros([KERNELSZ**2
|
|
185
|
+
wwc = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type)
|
|
186
|
+
wws = np.zeros([l_NORIENT, KERNELSZ**2]).astype(all_type)
|
|
184
187
|
|
|
185
188
|
x = np.repeat(np.arange(KERNELSZ) - KERNELSZ // 2, KERNELSZ).reshape(
|
|
186
189
|
KERNELSZ, KERNELSZ
|
|
@@ -203,12 +206,12 @@ class FoCUS:
|
|
|
203
206
|
-0.5 * (xx**2 + yy**2)
|
|
204
207
|
)
|
|
205
208
|
|
|
206
|
-
wwc[
|
|
209
|
+
wwc[0] = tmp.flatten() - tmp.mean()
|
|
207
210
|
tmp = 0 * w_smooth
|
|
208
|
-
wws[
|
|
211
|
+
wws[0] = tmp.flatten()
|
|
209
212
|
sigma = np.sqrt((wwc[:, 0] ** 2).mean())
|
|
210
|
-
wwc[
|
|
211
|
-
wws[
|
|
213
|
+
wwc[0] /= sigma
|
|
214
|
+
wws[0] /= sigma
|
|
212
215
|
|
|
213
216
|
w_smooth = w_smooth.flatten()
|
|
214
217
|
else:
|
|
@@ -239,12 +242,12 @@ class FoCUS:
|
|
|
239
242
|
tmp1 = np.cos(yy * np.pi) * w_smooth
|
|
240
243
|
tmp2 = np.sin(yy * np.pi) * w_smooth
|
|
241
244
|
|
|
242
|
-
wwc[
|
|
243
|
-
wws[
|
|
245
|
+
wwc[i] = tmp1.flatten() - tmp1.mean()
|
|
246
|
+
wws[i] = tmp2.flatten() - tmp2.mean()
|
|
244
247
|
# sigma = np.sqrt((wwc[:, i] ** 2).mean())
|
|
245
248
|
sigma = np.mean(w_smooth)
|
|
246
|
-
wwc[
|
|
247
|
-
wws[
|
|
249
|
+
wwc[i] /= sigma
|
|
250
|
+
wws[i] /= sigma
|
|
248
251
|
|
|
249
252
|
if DODIV and i == 0:
|
|
250
253
|
r = xx**2 + yy**2
|
|
@@ -253,22 +256,22 @@ class FoCUS:
|
|
|
253
256
|
tmp1 = r * np.cos(2 * theta) * w_smooth
|
|
254
257
|
tmp2 = r * np.sin(2 * theta) * w_smooth
|
|
255
258
|
|
|
256
|
-
wwc[
|
|
257
|
-
wws[
|
|
259
|
+
wwc[NORIENT] = tmp1.flatten() - tmp1.mean()
|
|
260
|
+
wws[NORIENT] = tmp2.flatten() - tmp2.mean()
|
|
258
261
|
# sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
|
|
259
262
|
sigma = np.mean(w_smooth)
|
|
260
263
|
|
|
261
|
-
wwc[
|
|
262
|
-
wws[
|
|
264
|
+
wwc[NORIENT] /= sigma
|
|
265
|
+
wws[NORIENT] /= sigma
|
|
263
266
|
tmp1 = r * np.cos(2 * theta + np.pi)
|
|
264
267
|
tmp2 = r * np.sin(2 * theta + np.pi)
|
|
265
268
|
|
|
266
|
-
wwc[
|
|
267
|
-
wws[
|
|
269
|
+
wwc[NORIENT + 1] = tmp1.flatten() - tmp1.mean()
|
|
270
|
+
wws[NORIENT + 1] = tmp2.flatten() - tmp2.mean()
|
|
268
271
|
# sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
|
|
269
272
|
sigma = np.mean(w_smooth)
|
|
270
|
-
wwc[
|
|
271
|
-
wws[
|
|
273
|
+
wwc[NORIENT + 1] /= sigma
|
|
274
|
+
wws[NORIENT + 1] /= sigma
|
|
272
275
|
|
|
273
276
|
w_smooth = w_smooth.flatten()
|
|
274
277
|
|
|
@@ -316,14 +319,14 @@ class FoCUS:
|
|
|
316
319
|
r = np.sum(np.sqrt(c * c + s * s))
|
|
317
320
|
c = c / r
|
|
318
321
|
s = s / r
|
|
319
|
-
self.ww_RealT[1] = self.backend.
|
|
320
|
-
np.array(c).reshape(xx.shape[0]
|
|
322
|
+
self.ww_RealT[1] = self.backend.bk_cast(
|
|
323
|
+
self.backend.bk_constant(np.array(c).reshape(xx.shape[0]))
|
|
321
324
|
)
|
|
322
|
-
self.ww_ImagT[1] = self.backend.
|
|
323
|
-
np.array(s).reshape(xx.shape[0]
|
|
325
|
+
self.ww_ImagT[1] = self.backend.bk_cast(
|
|
326
|
+
self.backend.bk_constant(np.array(s).reshape(xx.shape[0]))
|
|
324
327
|
)
|
|
325
|
-
self.ww_SmoothT[1] = self.backend.
|
|
326
|
-
np.array(w).reshape(xx.shape[0]
|
|
328
|
+
self.ww_SmoothT[1] = self.backend.bk_cast(
|
|
329
|
+
self.backend.bk_constant(np.array(w).reshape(xx.shape[0]))
|
|
327
330
|
)
|
|
328
331
|
|
|
329
332
|
else:
|
|
@@ -333,22 +336,16 @@ class FoCUS:
|
|
|
333
336
|
self.ww_SmoothT = {}
|
|
334
337
|
|
|
335
338
|
self.ww_SmoothT[1] = self.backend.bk_constant(
|
|
336
|
-
self.w_smooth.reshape(
|
|
337
|
-
)
|
|
338
|
-
www = np.zeros([KERNELSZ, KERNELSZ, NORIENT, NORIENT], dtype=self.all_type)
|
|
339
|
-
for k in range(NORIENT):
|
|
340
|
-
www[:, :, k, k] = self.w_smooth.reshape(KERNELSZ, KERNELSZ)
|
|
341
|
-
self.ww_SmoothT[NORIENT] = self.backend.bk_constant(
|
|
342
|
-
www.reshape(KERNELSZ, KERNELSZ, NORIENT, NORIENT)
|
|
339
|
+
self.w_smooth.reshape(1, KERNELSZ, KERNELSZ)
|
|
343
340
|
)
|
|
344
341
|
self.ww_RealT[1] = self.backend.bk_constant(
|
|
345
342
|
self.backend.bk_reshape(
|
|
346
|
-
wwc.astype(self.all_type), [
|
|
343
|
+
wwc.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ]
|
|
347
344
|
)
|
|
348
345
|
)
|
|
349
346
|
self.ww_ImagT[1] = self.backend.bk_constant(
|
|
350
347
|
self.backend.bk_reshape(
|
|
351
|
-
wws.astype(self.all_type), [
|
|
348
|
+
wws.astype(self.all_type), [NORIENT, KERNELSZ, KERNELSZ]
|
|
352
349
|
)
|
|
353
350
|
)
|
|
354
351
|
|
|
@@ -505,211 +502,27 @@ class FoCUS:
|
|
|
505
502
|
)
|
|
506
503
|
return indices, weights, xc, yc, zc
|
|
507
504
|
|
|
508
|
-
# ---------------------------------------------−---------
|
|
509
|
-
def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
|
|
510
|
-
nside = int(np.sqrt(im.shape[1] // 12))
|
|
511
|
-
l_kernel = self.KERNELSZ * self.KERNELSZ
|
|
512
|
-
norient = 32
|
|
513
|
-
w = np.zeros([l_kernel, 1, 2 * norient])
|
|
514
|
-
ca = np.cos(np.arange(norient) / norient * np.pi)
|
|
515
|
-
sa = np.sin(np.arange(norient) / norient * np.pi)
|
|
516
|
-
stat = np.zeros([12 * nside**2, norient])
|
|
517
|
-
|
|
518
|
-
if self.ww_CNN[nside] is None:
|
|
519
|
-
self.init_CNN_index(nside, transpose=False)
|
|
520
|
-
|
|
521
|
-
y = self.Y_CNN[nside]
|
|
522
|
-
z = self.Z_CNN[nside]
|
|
523
|
-
|
|
524
|
-
for k in range(norient):
|
|
525
|
-
w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
|
|
526
|
-
nside * (y * ca[k] + z * sa[k]) * np.pi / 2
|
|
527
|
-
)
|
|
528
|
-
w[:, 0, k + norient] = np.exp(
|
|
529
|
-
-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
|
|
530
|
-
) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
|
|
531
|
-
w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
|
|
532
|
-
w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
|
|
533
|
-
|
|
534
|
-
for k in range(im.shape[0]):
|
|
535
|
-
tmp = im[k].reshape(12 * nside**2, 1)
|
|
536
|
-
im2 = self.healpix_layer(tmp, w)
|
|
537
|
-
stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
|
|
538
|
-
|
|
539
|
-
rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
|
|
540
|
-
|
|
541
|
-
indices, weights, x, y, z = self.calc_indices_convol(
|
|
542
|
-
nside, 9, rotation=rotation
|
|
543
|
-
)
|
|
544
|
-
|
|
545
|
-
return indices, weights
|
|
546
|
-
|
|
547
|
-
def init_CNN_index(self, nside, transpose=False):
|
|
548
|
-
l_kernel = int(self.KERNELSZ * self.KERNELSZ)
|
|
549
|
-
try:
|
|
550
|
-
indices = np.load(
|
|
551
|
-
"%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
|
|
552
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
553
|
-
)
|
|
554
|
-
weights = np.load(
|
|
555
|
-
"%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
|
|
556
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
557
|
-
)
|
|
558
|
-
xc = np.load(
|
|
559
|
-
"%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
|
|
560
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
561
|
-
)
|
|
562
|
-
yc = np.load(
|
|
563
|
-
"%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
|
|
564
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
565
|
-
)
|
|
566
|
-
zc = np.load(
|
|
567
|
-
"%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
|
|
568
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
569
|
-
)
|
|
570
|
-
except:
|
|
571
|
-
indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
|
|
572
|
-
np.save(
|
|
573
|
-
"%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
|
|
574
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
575
|
-
indices,
|
|
576
|
-
)
|
|
577
|
-
np.save(
|
|
578
|
-
"%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
|
|
579
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
580
|
-
weights,
|
|
581
|
-
)
|
|
582
|
-
np.save(
|
|
583
|
-
"%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
|
|
584
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
585
|
-
xc,
|
|
586
|
-
)
|
|
587
|
-
np.save(
|
|
588
|
-
"%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
|
|
589
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
590
|
-
yc,
|
|
591
|
-
)
|
|
592
|
-
np.save(
|
|
593
|
-
"%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
|
|
594
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
595
|
-
zc,
|
|
596
|
-
)
|
|
597
|
-
if not self.silent:
|
|
598
|
-
print(
|
|
599
|
-
"Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
|
|
600
|
-
% (
|
|
601
|
-
self.TEMPLATE_PATH,
|
|
602
|
-
TMPFILE_VERSION,
|
|
603
|
-
l_kernel,
|
|
604
|
-
self.NORIENT,
|
|
605
|
-
nside,
|
|
606
|
-
)
|
|
607
|
-
)
|
|
608
|
-
|
|
609
|
-
self.X_CNN[nside] = xc
|
|
610
|
-
self.Y_CNN[nside] = yc
|
|
611
|
-
self.Z_CNN[nside] = zc
|
|
612
|
-
self.ww_CNN[nside] = self.backend.bk_SparseTensor(
|
|
613
|
-
indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
|
|
614
|
-
)
|
|
615
|
-
|
|
616
|
-
# ---------------------------------------------−---------
|
|
617
|
-
def healpix_layer_coord(self, im, axis=0):
|
|
618
|
-
nside = int(np.sqrt(im.shape[axis] // 12))
|
|
619
|
-
if self.ww_CNN[nside] is None:
|
|
620
|
-
self.init_CNN_index(nside)
|
|
621
|
-
return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
|
|
622
|
-
|
|
623
|
-
# ---------------------------------------------−---------
|
|
624
|
-
def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
|
|
625
|
-
nside = int(np.sqrt(im.shape[axis] // 12))
|
|
626
|
-
|
|
627
|
-
if im.shape[1 + axis] != ww.shape[1]:
|
|
628
|
-
if not self.silent:
|
|
629
|
-
print("Weights channels should be equal to the input image channels")
|
|
630
|
-
return -1
|
|
631
|
-
if axis == 1:
|
|
632
|
-
results = []
|
|
633
|
-
|
|
634
|
-
for k in range(im.shape[0]):
|
|
635
|
-
|
|
636
|
-
tmp = self.healpix_layer(
|
|
637
|
-
im[k], ww, indices=indices, weights=weights, axis=0
|
|
638
|
-
)
|
|
639
|
-
tmp = self.backend.bk_reshape(
|
|
640
|
-
self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
|
|
641
|
-
)
|
|
642
|
-
|
|
643
|
-
results.append(tmp)
|
|
644
|
-
|
|
645
|
-
return self.backend.bk_stack(results, axis=0)
|
|
646
|
-
else:
|
|
647
|
-
tmp = self.healpix_layer(
|
|
648
|
-
im, ww, indices=indices, weights=weights, axis=axis
|
|
649
|
-
)
|
|
650
|
-
|
|
651
|
-
return self.up_grade(tmp, 2 * nside)
|
|
652
|
-
|
|
653
505
|
# ---------------------------------------------−---------
|
|
654
506
|
# ---------------------------------------------−---------
|
|
655
|
-
def healpix_layer(self, im, ww, indices=None, weights=None
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
if im.shape[1 + axis] != ww.shape[1]:
|
|
660
|
-
if not self.silent:
|
|
661
|
-
print("Weights channels should be equal to the input image channels")
|
|
662
|
-
return -1
|
|
663
|
-
|
|
507
|
+
def healpix_layer(self, im, ww, indices=None, weights=None):
|
|
508
|
+
#ww [N_i,NORIENT,KERNELSZ*KERNELSZ//2,N_o,NORIENT]
|
|
509
|
+
#im [N_batch,N_i, NORIENT,N]
|
|
510
|
+
nside=int(np.sqrt(im.shape[-1]//12))
|
|
664
511
|
if indices is None:
|
|
665
|
-
if self.
|
|
666
|
-
self.
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
if axis == 1:
|
|
680
|
-
results = []
|
|
681
|
-
|
|
682
|
-
for k in range(im.shape[0]):
|
|
683
|
-
|
|
684
|
-
tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
|
|
685
|
-
|
|
686
|
-
density = self.backend.bk_reshape(
|
|
687
|
-
tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
|
|
688
|
-
)
|
|
689
|
-
|
|
690
|
-
density = self.backend.bk_matmul(
|
|
691
|
-
density,
|
|
692
|
-
self.backend.bk_reshape(
|
|
693
|
-
ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
|
|
694
|
-
),
|
|
695
|
-
)
|
|
696
|
-
|
|
697
|
-
results.append(
|
|
698
|
-
self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
|
|
699
|
-
)
|
|
700
|
-
|
|
701
|
-
return self.backend.bk_stack(results, axis=0)
|
|
702
|
-
else:
|
|
703
|
-
tmp = self.backend.bk_sparse_dense_matmul(mat, im)
|
|
704
|
-
|
|
705
|
-
density = self.backend.bk_reshape(
|
|
706
|
-
tmp, [12 * nside * nside, l_kernel * im.shape[1]]
|
|
707
|
-
)
|
|
708
|
-
|
|
709
|
-
return self.backend.bk_matmul(
|
|
710
|
-
density,
|
|
711
|
-
self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
|
|
712
|
-
)
|
|
512
|
+
if (nside,self.NORIENT,self.KERNELSZ) not in self.ww_CNN:
|
|
513
|
+
self.init_index_cnn(nside,self.NORIENT)
|
|
514
|
+
indices = self.Idx_CNN[(nside,self.NORIENT,self.KERNELSZ)]
|
|
515
|
+
mat = self.Idx_WCNN[(nside,self.NORIENT,self.KERNELSZ)]
|
|
516
|
+
|
|
517
|
+
wim = self.backend.bk_gather(im,indices.flatten(),axis=3) #[N_batch,N_i,NORIENT,K*(K+1),N_o,NORIENT,N,N_w]
|
|
518
|
+
|
|
519
|
+
wim = self.backend.bk_reshape(wim,[im.shape[0],im.shape[1],im.shape[2]]+list(indices.shape))*mat[None,...]
|
|
520
|
+
#win is [N_batch,N_i, NORIENT,K*(K+1),1, NORIENT,N,N_w]
|
|
521
|
+
#ww is [1, N_i, NORIENT,K*(K+1),N_o,NORIENT]
|
|
522
|
+
wim = self.backend.bk_reduce_sum(wim[:,:,:,:,None]*ww[None,:,:,:,:,:,None,None],[1,2,3])
|
|
523
|
+
|
|
524
|
+
wim = self.backend.bk_reduce_sum(wim,-1)
|
|
525
|
+
return self.backend.bk_reshape(wim,[im.shape[0],ww.shape[3],ww.shape[4],im.shape[-1]])
|
|
713
526
|
|
|
714
527
|
# ---------------------------------------------−---------
|
|
715
528
|
|
|
@@ -806,53 +619,20 @@ class FoCUS:
|
|
|
806
619
|
return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
|
|
807
620
|
elif self.use_1D:
|
|
808
621
|
ishape = list(im.shape)
|
|
809
|
-
if len(ishape) < axis + 1:
|
|
810
|
-
if not self.silent:
|
|
811
|
-
print("Use of 1D scat with data that has less than 1D")
|
|
812
|
-
return None, None
|
|
813
622
|
|
|
814
|
-
npix =
|
|
815
|
-
odata = 1
|
|
816
|
-
if len(ishape) > axis + 1:
|
|
817
|
-
for k in range(axis + 1, len(ishape)):
|
|
818
|
-
odata = odata * ishape[k]
|
|
623
|
+
npix = ishape[-1]
|
|
819
624
|
|
|
820
625
|
ndata = 1
|
|
821
|
-
for k in range(
|
|
626
|
+
for k in range(len(ishape) - 1):
|
|
822
627
|
ndata = ndata * ishape[k]
|
|
823
628
|
|
|
824
629
|
tim = self.backend.bk_reshape(
|
|
825
|
-
self.backend.bk_cast(im), [ndata, npix,
|
|
826
|
-
)
|
|
827
|
-
tim = self.backend.bk_reshape(
|
|
828
|
-
tim[:, 0 : 2 * (npix // 2), :], [ndata, npix // 2, 2, odata]
|
|
630
|
+
self.backend.bk_cast(im), [ndata, npix // 2, 2]
|
|
829
631
|
)
|
|
830
632
|
|
|
831
|
-
res = self.backend.bk_reduce_mean(tim,
|
|
832
|
-
|
|
833
|
-
if axis == 0:
|
|
834
|
-
if len(ishape) == 1:
|
|
835
|
-
return self.backend.bk_reshape(res, [npix // 2]), None
|
|
836
|
-
else:
|
|
837
|
-
return (
|
|
838
|
-
self.backend.bk_reshape(res, [npix // 2] + ishape[axis + 1 :]),
|
|
839
|
-
None,
|
|
840
|
-
)
|
|
841
|
-
else:
|
|
842
|
-
if len(ishape) == axis + 1:
|
|
843
|
-
return (
|
|
844
|
-
self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2]),
|
|
845
|
-
None,
|
|
846
|
-
)
|
|
847
|
-
else:
|
|
848
|
-
return (
|
|
849
|
-
self.backend.bk_reshape(
|
|
850
|
-
res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
|
|
851
|
-
),
|
|
852
|
-
None,
|
|
853
|
-
)
|
|
633
|
+
res = self.backend.bk_reduce_mean(tim, -1)
|
|
854
634
|
|
|
855
|
-
return self.backend.bk_reshape(res, [npix // 2]), None
|
|
635
|
+
return self.backend.bk_reshape(res, ishape[0:-1] + [npix // 2]), None
|
|
856
636
|
|
|
857
637
|
else:
|
|
858
638
|
shape = list(im.shape)
|
|
@@ -1384,13 +1164,18 @@ class FoCUS:
|
|
|
1384
1164
|
return res
|
|
1385
1165
|
|
|
1386
1166
|
# ---------------------------------------------−---------
|
|
1387
|
-
def init_index(self, nside, kernel=-1):
|
|
1167
|
+
def init_index(self, nside, kernel=-1, cell_ids=None):
|
|
1388
1168
|
|
|
1389
1169
|
if kernel == -1:
|
|
1390
1170
|
l_kernel = self.KERNELSZ
|
|
1391
1171
|
else:
|
|
1392
1172
|
l_kernel = kernel
|
|
1393
1173
|
|
|
1174
|
+
if cell_ids is not None:
|
|
1175
|
+
ncell = cell_ids.shape[0]
|
|
1176
|
+
else:
|
|
1177
|
+
ncell = 12 * nside * nside
|
|
1178
|
+
|
|
1394
1179
|
try:
|
|
1395
1180
|
if self.use_2D:
|
|
1396
1181
|
tmp = np.load(
|
|
@@ -1398,16 +1183,29 @@ class FoCUS:
|
|
|
1398
1183
|
% (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
|
|
1399
1184
|
)
|
|
1400
1185
|
else:
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
|
|
1186
|
+
if cell_ids is not None:
|
|
1187
|
+
tmp = np.load(
|
|
1188
|
+
"%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
|
|
1189
|
+
% (
|
|
1190
|
+
self.TEMPLATE_PATH,
|
|
1191
|
+
TMPFILE_VERSION,
|
|
1192
|
+
l_kernel**2,
|
|
1193
|
+
self.NORIENT,
|
|
1194
|
+
nside, # if cell_ids computes the index
|
|
1195
|
+
)
|
|
1196
|
+
)
|
|
1197
|
+
|
|
1198
|
+
else:
|
|
1199
|
+
tmp = np.load(
|
|
1200
|
+
"%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1201
|
+
% (
|
|
1202
|
+
self.TEMPLATE_PATH,
|
|
1203
|
+
TMPFILE_VERSION,
|
|
1204
|
+
l_kernel**2,
|
|
1205
|
+
self.NORIENT,
|
|
1206
|
+
nside, # if cell_ids computes the index
|
|
1207
|
+
)
|
|
1409
1208
|
)
|
|
1410
|
-
)
|
|
1411
1209
|
except:
|
|
1412
1210
|
if not self.use_2D:
|
|
1413
1211
|
|
|
@@ -1426,36 +1224,64 @@ class FoCUS:
|
|
|
1426
1224
|
pw2 = 0.25
|
|
1427
1225
|
threshold = 4e-5
|
|
1428
1226
|
|
|
1429
|
-
|
|
1430
|
-
|
|
1227
|
+
if cell_ids is not None:
|
|
1228
|
+
if not isinstance(cell_ids, np.ndarray):
|
|
1229
|
+
cell_ids = self.backend.to_numpy(cell_ids)
|
|
1230
|
+
th, ph = hp.pix2ang(nside, cell_ids, nest=True)
|
|
1231
|
+
x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
|
|
1431
1232
|
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1233
|
+
t, p = hp.pix2ang(nside, cell_ids, nest=True)
|
|
1234
|
+
phi = [p[k] / np.pi * 180 for k in range(ncell)]
|
|
1235
|
+
thi = [t[k] / np.pi * 180 for k in range(ncell)]
|
|
1435
1236
|
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
|
|
1439
|
-
|
|
1440
|
-
wav = np.zeros(
|
|
1441
|
-
[12 * nside * nside * 64 * self.NORIENT], dtype="complex"
|
|
1442
|
-
)
|
|
1443
|
-
wwav = np.zeros([12 * nside * nside * 64 * self.NORIENT], dtype="float")
|
|
1237
|
+
indice2 = np.zeros([ncell * 64, 2], dtype="int")
|
|
1238
|
+
indice = np.zeros([ncell * 64 * self.NORIENT, 2], dtype="int")
|
|
1239
|
+
wav = np.zeros([ncell * 64 * self.NORIENT], dtype="complex")
|
|
1240
|
+
wwav = np.zeros([ncell * 64 * self.NORIENT], dtype="float")
|
|
1444
1241
|
|
|
1242
|
+
else:
|
|
1243
|
+
|
|
1244
|
+
th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
|
|
1245
|
+
x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
|
|
1246
|
+
|
|
1247
|
+
t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
|
|
1248
|
+
phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
|
|
1249
|
+
thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
|
|
1250
|
+
|
|
1251
|
+
indice2 = np.zeros([12 * nside * nside * 64, 2], dtype="int")
|
|
1252
|
+
indice = np.zeros(
|
|
1253
|
+
[12 * nside * nside * 64 * self.NORIENT, 2], dtype="int"
|
|
1254
|
+
)
|
|
1255
|
+
wav = np.zeros(
|
|
1256
|
+
[12 * nside * nside * 64 * self.NORIENT], dtype="complex"
|
|
1257
|
+
)
|
|
1258
|
+
wwav = np.zeros(
|
|
1259
|
+
[12 * nside * nside * 64 * self.NORIENT], dtype="float"
|
|
1260
|
+
)
|
|
1445
1261
|
iv = 0
|
|
1446
1262
|
iv2 = 0
|
|
1447
|
-
for iii in range(12 * nside * nside):
|
|
1448
1263
|
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1264
|
+
for iii in range(ncell):
|
|
1265
|
+
if cell_ids is None:
|
|
1266
|
+
if iii % (nside * nside) == nside * nside - 1:
|
|
1267
|
+
if not self.silent:
|
|
1268
|
+
print(
|
|
1269
|
+
"Pre-compute nside=%6d %.2f%%"
|
|
1270
|
+
% (nside, 100 * iii / (12 * nside * nside))
|
|
1271
|
+
)
|
|
1455
1272
|
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
1273
|
+
if cell_ids is not None:
|
|
1274
|
+
hidx = np.where(
|
|
1275
|
+
(x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
|
|
1276
|
+
< (2 * np.pi / nside) ** 2
|
|
1277
|
+
)[0]
|
|
1278
|
+
else:
|
|
1279
|
+
hidx = hp.query_disc(
|
|
1280
|
+
nside,
|
|
1281
|
+
[x[iii], y[iii], z[iii]],
|
|
1282
|
+
2 * np.pi / nside,
|
|
1283
|
+
nest=True,
|
|
1284
|
+
)
|
|
1459
1285
|
|
|
1460
1286
|
R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
|
|
1461
1287
|
|
|
@@ -1474,8 +1300,8 @@ class FoCUS:
|
|
|
1474
1300
|
)
|
|
1475
1301
|
idx = np.where((ww**2) > threshold)[0]
|
|
1476
1302
|
nval2 = len(idx)
|
|
1477
|
-
indice2[iv2 : iv2 + nval2,
|
|
1478
|
-
indice2[iv2 : iv2 + nval2,
|
|
1303
|
+
indice2[iv2 : iv2 + nval2, 1] = iii
|
|
1304
|
+
indice2[iv2 : iv2 + nval2, 0] = hidx[idx]
|
|
1479
1305
|
wwav[iv2 : iv2 + nval2] = ww[idx] / np.sum(ww[idx])
|
|
1480
1306
|
iv2 += nval2
|
|
1481
1307
|
|
|
@@ -1497,15 +1323,18 @@ class FoCUS:
|
|
|
1497
1323
|
idx = np.where(vnorm > threshold)[0]
|
|
1498
1324
|
|
|
1499
1325
|
nval = len(idx)
|
|
1500
|
-
indice[iv : iv + nval,
|
|
1501
|
-
indice[iv : iv + nval,
|
|
1326
|
+
indice[iv : iv + nval, 1] = iii + l_rotation * ncell
|
|
1327
|
+
indice[iv : iv + nval, 0] = hidx[idx]
|
|
1502
1328
|
# print([hidx[k] for k in idx])
|
|
1503
1329
|
# print(hp.query_disc(nside, [x[iii],y[iii],z[iii]], np.pi/nside,nest=True))
|
|
1504
1330
|
normr = np.mean(wresr[idx])
|
|
1505
1331
|
normi = np.mean(wresi[idx])
|
|
1506
1332
|
|
|
1507
1333
|
val = wresr[idx] - normr + 1j * (wresi[idx] - normi)
|
|
1508
|
-
|
|
1334
|
+
r = abs(val).sum()
|
|
1335
|
+
|
|
1336
|
+
if r > 0:
|
|
1337
|
+
val = val / r
|
|
1509
1338
|
|
|
1510
1339
|
wav[iv : iv + nval] = val
|
|
1511
1340
|
iv += nval
|
|
@@ -1609,56 +1438,57 @@ class FoCUS:
|
|
|
1609
1438
|
wav=w.flatten()
|
|
1610
1439
|
wwav=wwav.flatten()
|
|
1611
1440
|
"""
|
|
1612
|
-
if
|
|
1613
|
-
|
|
1614
|
-
|
|
1615
|
-
|
|
1441
|
+
if cell_ids is None:
|
|
1442
|
+
if not self.silent:
|
|
1443
|
+
print(
|
|
1444
|
+
"Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1445
|
+
% (TMPFILE_VERSION, self.KERNELSZ**2, self.NORIENT, nside)
|
|
1446
|
+
)
|
|
1447
|
+
np.save(
|
|
1448
|
+
"%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1449
|
+
% (
|
|
1450
|
+
self.TEMPLATE_PATH,
|
|
1451
|
+
TMPFILE_VERSION,
|
|
1452
|
+
self.KERNELSZ**2,
|
|
1453
|
+
self.NORIENT,
|
|
1454
|
+
nside,
|
|
1455
|
+
),
|
|
1456
|
+
indice,
|
|
1457
|
+
)
|
|
1458
|
+
np.save(
|
|
1459
|
+
"%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
|
|
1460
|
+
% (
|
|
1461
|
+
self.TEMPLATE_PATH,
|
|
1462
|
+
TMPFILE_VERSION,
|
|
1463
|
+
self.KERNELSZ**2,
|
|
1464
|
+
self.NORIENT,
|
|
1465
|
+
nside,
|
|
1466
|
+
),
|
|
1467
|
+
wav,
|
|
1468
|
+
)
|
|
1469
|
+
np.save(
|
|
1470
|
+
"%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
|
|
1471
|
+
% (
|
|
1472
|
+
self.TEMPLATE_PATH,
|
|
1473
|
+
TMPFILE_VERSION,
|
|
1474
|
+
self.KERNELSZ**2,
|
|
1475
|
+
self.NORIENT,
|
|
1476
|
+
nside,
|
|
1477
|
+
),
|
|
1478
|
+
indice2,
|
|
1479
|
+
)
|
|
1480
|
+
np.save(
|
|
1481
|
+
"%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
|
|
1482
|
+
% (
|
|
1483
|
+
self.TEMPLATE_PATH,
|
|
1484
|
+
TMPFILE_VERSION,
|
|
1485
|
+
self.KERNELSZ**2,
|
|
1486
|
+
self.NORIENT,
|
|
1487
|
+
nside,
|
|
1488
|
+
),
|
|
1489
|
+
wwav,
|
|
1616
1490
|
)
|
|
1617
|
-
|
|
1618
|
-
"%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1619
|
-
% (
|
|
1620
|
-
self.TEMPLATE_PATH,
|
|
1621
|
-
TMPFILE_VERSION,
|
|
1622
|
-
self.KERNELSZ**2,
|
|
1623
|
-
self.NORIENT,
|
|
1624
|
-
nside,
|
|
1625
|
-
),
|
|
1626
|
-
indice,
|
|
1627
|
-
)
|
|
1628
|
-
np.save(
|
|
1629
|
-
"%s/FOSCAT_%s_W%d_%d_%d_WAVE.npy"
|
|
1630
|
-
% (
|
|
1631
|
-
self.TEMPLATE_PATH,
|
|
1632
|
-
TMPFILE_VERSION,
|
|
1633
|
-
self.KERNELSZ**2,
|
|
1634
|
-
self.NORIENT,
|
|
1635
|
-
nside,
|
|
1636
|
-
),
|
|
1637
|
-
wav,
|
|
1638
|
-
)
|
|
1639
|
-
np.save(
|
|
1640
|
-
"%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
|
|
1641
|
-
% (
|
|
1642
|
-
self.TEMPLATE_PATH,
|
|
1643
|
-
TMPFILE_VERSION,
|
|
1644
|
-
self.KERNELSZ**2,
|
|
1645
|
-
self.NORIENT,
|
|
1646
|
-
nside,
|
|
1647
|
-
),
|
|
1648
|
-
indice2,
|
|
1649
|
-
)
|
|
1650
|
-
np.save(
|
|
1651
|
-
"%s/FOSCAT_%s_W%d_%d_%d_SMOO.npy"
|
|
1652
|
-
% (
|
|
1653
|
-
self.TEMPLATE_PATH,
|
|
1654
|
-
TMPFILE_VERSION,
|
|
1655
|
-
self.KERNELSZ**2,
|
|
1656
|
-
self.NORIENT,
|
|
1657
|
-
nside,
|
|
1658
|
-
),
|
|
1659
|
-
wwav,
|
|
1660
|
-
)
|
|
1661
|
-
else:
|
|
1491
|
+
if self.use_2D:
|
|
1662
1492
|
if l_kernel**2 == 9:
|
|
1663
1493
|
if self.rank == 0:
|
|
1664
1494
|
self.comp_idx_w9(nside)
|
|
@@ -1674,23 +1504,24 @@ class FoCUS:
|
|
|
1674
1504
|
)
|
|
1675
1505
|
return None
|
|
1676
1506
|
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
|
-
|
|
1682
|
-
|
|
1683
|
-
|
|
1684
|
-
|
|
1685
|
-
|
|
1686
|
-
|
|
1687
|
-
|
|
1688
|
-
|
|
1689
|
-
|
|
1690
|
-
|
|
1691
|
-
|
|
1507
|
+
if cell_ids is None:
|
|
1508
|
+
self.barrier()
|
|
1509
|
+
if self.use_2D:
|
|
1510
|
+
tmp = np.load(
|
|
1511
|
+
"%s/W%d_%s_%d_IDX.npy"
|
|
1512
|
+
% (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
|
|
1513
|
+
)
|
|
1514
|
+
else:
|
|
1515
|
+
tmp = np.load(
|
|
1516
|
+
"%s/FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1517
|
+
% (
|
|
1518
|
+
self.TEMPLATE_PATH,
|
|
1519
|
+
TMPFILE_VERSION,
|
|
1520
|
+
self.KERNELSZ**2,
|
|
1521
|
+
self.NORIENT,
|
|
1522
|
+
nside,
|
|
1523
|
+
)
|
|
1692
1524
|
)
|
|
1693
|
-
)
|
|
1694
1525
|
tmp2 = np.load(
|
|
1695
1526
|
"%s/FOSCAT_%s_W%d_%d_%d_PIDX2.npy"
|
|
1696
1527
|
% (
|
|
@@ -1731,22 +1562,28 @@ class FoCUS:
|
|
|
1731
1562
|
nside,
|
|
1732
1563
|
)
|
|
1733
1564
|
)
|
|
1734
|
-
|
|
1735
|
-
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
|
|
1739
|
-
|
|
1740
|
-
|
|
1741
|
-
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
1746
|
-
|
|
1747
|
-
|
|
1748
|
-
|
|
1749
|
-
|
|
1565
|
+
else:
|
|
1566
|
+
tmp = indice
|
|
1567
|
+
tmp2 = indice2
|
|
1568
|
+
wr = wav.real
|
|
1569
|
+
wi = wav.imag
|
|
1570
|
+
ws = self.slope * wwav
|
|
1571
|
+
|
|
1572
|
+
wr = self.backend.bk_SparseTensor(
|
|
1573
|
+
self.backend.bk_constant(tmp),
|
|
1574
|
+
self.backend.bk_constant(self.backend.bk_cast(wr)),
|
|
1575
|
+
dense_shape=[ncell, self.NORIENT * ncell],
|
|
1576
|
+
)
|
|
1577
|
+
wi = self.backend.bk_SparseTensor(
|
|
1578
|
+
self.backend.bk_constant(tmp),
|
|
1579
|
+
self.backend.bk_constant(self.backend.bk_cast(wi)),
|
|
1580
|
+
dense_shape=[ncell, self.NORIENT * ncell],
|
|
1581
|
+
)
|
|
1582
|
+
ws = self.backend.bk_SparseTensor(
|
|
1583
|
+
self.backend.bk_constant(tmp2),
|
|
1584
|
+
self.backend.bk_constant(self.backend.bk_cast(ws)),
|
|
1585
|
+
dense_shape=[ncell, ncell],
|
|
1586
|
+
)
|
|
1750
1587
|
|
|
1751
1588
|
if kernel == -1:
|
|
1752
1589
|
self.Idx_Neighbours[nside] = tmp
|
|
@@ -1757,42 +1594,268 @@ class FoCUS:
|
|
|
1757
1594
|
|
|
1758
1595
|
return wr, wi, ws, tmp
|
|
1759
1596
|
|
|
1597
|
+
|
|
1760
1598
|
# ---------------------------------------------−---------
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
|
|
1765
|
-
laxis1 = len(shape) + axis1
|
|
1766
|
-
else:
|
|
1767
|
-
laxis1 = axis1
|
|
1768
|
-
if axis2 < 0:
|
|
1769
|
-
laxis2 = len(shape) + axis2
|
|
1599
|
+
def init_index_cnn(self, nside, NORIENT=4,kernel=-1, cell_ids=None):
|
|
1600
|
+
|
|
1601
|
+
if kernel == -1:
|
|
1602
|
+
l_kernel = self.KERNELSZ
|
|
1770
1603
|
else:
|
|
1771
|
-
|
|
1604
|
+
l_kernel = kernel
|
|
1772
1605
|
|
|
1773
|
-
|
|
1774
|
-
|
|
1775
|
-
|
|
1776
|
-
|
|
1777
|
-
return self.backend.bk_transpose(x, thelist)
|
|
1606
|
+
if cell_ids is not None:
|
|
1607
|
+
ncell = cell_ids.shape[0]
|
|
1608
|
+
else:
|
|
1609
|
+
ncell = 12 * nside * nside
|
|
1778
1610
|
|
|
1779
|
-
|
|
1780
|
-
|
|
1781
|
-
|
|
1782
|
-
|
|
1783
|
-
|
|
1611
|
+
try:
|
|
1612
|
+
|
|
1613
|
+
if cell_ids is not None:
|
|
1614
|
+
tmp = np.load(
|
|
1615
|
+
"%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
|
|
1616
|
+
% (
|
|
1617
|
+
self.TEMPLATE_PATH,
|
|
1618
|
+
TMPFILE_VERSION,
|
|
1619
|
+
l_kernel**2,
|
|
1620
|
+
NORIENT,
|
|
1621
|
+
nside, # if cell_ids computes the index
|
|
1622
|
+
)
|
|
1623
|
+
)
|
|
1784
1624
|
|
|
1785
|
-
|
|
1786
|
-
|
|
1787
|
-
|
|
1788
|
-
|
|
1789
|
-
|
|
1790
|
-
|
|
1625
|
+
else:
|
|
1626
|
+
tmp = np.load(
|
|
1627
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1628
|
+
% (
|
|
1629
|
+
self.TEMPLATE_PATH,
|
|
1630
|
+
TMPFILE_VERSION,
|
|
1631
|
+
l_kernel**2,
|
|
1632
|
+
NORIENT,
|
|
1633
|
+
nside, # if cell_ids computes the index
|
|
1634
|
+
)
|
|
1635
|
+
)
|
|
1636
|
+
except:
|
|
1791
1637
|
|
|
1792
|
-
|
|
1638
|
+
pw = 8.0
|
|
1639
|
+
pw2 = 1.0
|
|
1640
|
+
threshold = 1e-3
|
|
1641
|
+
|
|
1642
|
+
if l_kernel == 5:
|
|
1643
|
+
pw = 8.0
|
|
1644
|
+
pw2 = 0.5
|
|
1645
|
+
threshold = 2e-4
|
|
1793
1646
|
|
|
1794
|
-
|
|
1795
|
-
|
|
1647
|
+
elif l_kernel == 3:
|
|
1648
|
+
pw = 8.0
|
|
1649
|
+
pw2 = 1.0
|
|
1650
|
+
threshold = 1e-3
|
|
1651
|
+
|
|
1652
|
+
elif l_kernel == 7:
|
|
1653
|
+
pw = 8.0
|
|
1654
|
+
pw2 = 0.25
|
|
1655
|
+
threshold = 4e-5
|
|
1656
|
+
|
|
1657
|
+
n_weights = self.KERNELSZ*(self.KERNELSZ//2+1)
|
|
1658
|
+
|
|
1659
|
+
if cell_ids is not None:
|
|
1660
|
+
if not isinstance(cell_ids, np.ndarray):
|
|
1661
|
+
cell_ids = self.backend.to_numpy(cell_ids)
|
|
1662
|
+
th, ph = hp.pix2ang(nside, cell_ids, nest=True)
|
|
1663
|
+
x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
|
|
1664
|
+
|
|
1665
|
+
t, p = hp.pix2ang(nside, cell_ids, nest=True)
|
|
1666
|
+
phi = [p[k] / np.pi * 180 for k in range(ncell)]
|
|
1667
|
+
thi = [t[k] / np.pi * 180 for k in range(ncell)]
|
|
1668
|
+
|
|
1669
|
+
indice = np.zeros([n_weights, NORIENT, ncell,4], dtype="int")
|
|
1670
|
+
|
|
1671
|
+
wav = np.zeros([n_weights, NORIENT, ncell,4], dtype="float")
|
|
1672
|
+
|
|
1673
|
+
else:
|
|
1674
|
+
|
|
1675
|
+
th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
|
|
1676
|
+
x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
|
|
1677
|
+
|
|
1678
|
+
t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
|
|
1679
|
+
phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
|
|
1680
|
+
thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
|
|
1681
|
+
|
|
1682
|
+
indice = np.zeros(
|
|
1683
|
+
[n_weights, NORIENT, 12 * nside * nside,4], dtype="int"
|
|
1684
|
+
)
|
|
1685
|
+
wav = np.zeros(
|
|
1686
|
+
[n_weights, NORIENT, 12 * nside * nside,4], dtype="float"
|
|
1687
|
+
)
|
|
1688
|
+
iv = 0
|
|
1689
|
+
iv2 = 0
|
|
1690
|
+
|
|
1691
|
+
for iii in range(ncell):
|
|
1692
|
+
if cell_ids is None:
|
|
1693
|
+
if iii % (nside * nside) == nside * nside - 1:
|
|
1694
|
+
if not self.silent:
|
|
1695
|
+
print(
|
|
1696
|
+
"Pre-compute nside=%6d %.2f%%"
|
|
1697
|
+
% (nside, 100 * iii / (12 * nside * nside))
|
|
1698
|
+
)
|
|
1699
|
+
|
|
1700
|
+
if cell_ids is not None:
|
|
1701
|
+
hidx = np.where(
|
|
1702
|
+
(x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
|
|
1703
|
+
< (2 * np.pi / nside) ** 2
|
|
1704
|
+
)[0]
|
|
1705
|
+
else:
|
|
1706
|
+
hidx = hp.query_disc(
|
|
1707
|
+
nside,
|
|
1708
|
+
[x[iii], y[iii], z[iii]],
|
|
1709
|
+
2 * np.pi / nside,
|
|
1710
|
+
nest=True,
|
|
1711
|
+
)
|
|
1712
|
+
|
|
1713
|
+
R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
|
|
1714
|
+
|
|
1715
|
+
t2, p2 = R(th[hidx], ph[hidx])
|
|
1716
|
+
|
|
1717
|
+
vec2 = hp.ang2vec(t2, p2)
|
|
1718
|
+
|
|
1719
|
+
x2 = vec2[:, 0]
|
|
1720
|
+
y2 = vec2[:, 1]
|
|
1721
|
+
z2 = vec2[:, 2]
|
|
1722
|
+
|
|
1723
|
+
for l_rotation in range(NORIENT):
|
|
1724
|
+
|
|
1725
|
+
angle = (
|
|
1726
|
+
l_rotation / 4.0 * np.pi
|
|
1727
|
+
- phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
|
|
1728
|
+
- (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
|
|
1729
|
+
)
|
|
1730
|
+
|
|
1731
|
+
|
|
1732
|
+
axes = y2 * np.cos(angle) - x2 * np.sin(angle)
|
|
1733
|
+
axes2 = -y2 * np.sin(angle) - x2 * np.cos(angle)
|
|
1734
|
+
|
|
1735
|
+
for k_weights in range(self.KERNELSZ//2+1):
|
|
1736
|
+
for l_weights in range(self.KERNELSZ):
|
|
1737
|
+
|
|
1738
|
+
val=np.exp(-(pw*(axes2*(nside)-(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))+ \
|
|
1739
|
+
np.exp(-(pw*(axes2*(nside)+(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))
|
|
1740
|
+
|
|
1741
|
+
idx = np.argsort(-val)
|
|
1742
|
+
idx = idx[0:4]
|
|
1743
|
+
|
|
1744
|
+
nval = len(idx)
|
|
1745
|
+
val=val[idx]
|
|
1746
|
+
|
|
1747
|
+
r = abs(val).sum()
|
|
1748
|
+
|
|
1749
|
+
if r > 0:
|
|
1750
|
+
val = val / r
|
|
1751
|
+
|
|
1752
|
+
indice[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = hidx[idx]
|
|
1753
|
+
wav[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = val
|
|
1754
|
+
|
|
1755
|
+
if not self.silent:
|
|
1756
|
+
print("Kernel Size ", iv / (NORIENT * 12 * nside * nside))
|
|
1757
|
+
|
|
1758
|
+
if cell_ids is None:
|
|
1759
|
+
if not self.silent:
|
|
1760
|
+
print(
|
|
1761
|
+
"Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1762
|
+
% (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside)
|
|
1763
|
+
)
|
|
1764
|
+
np.save(
|
|
1765
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1766
|
+
% (
|
|
1767
|
+
self.TEMPLATE_PATH,
|
|
1768
|
+
TMPFILE_VERSION,
|
|
1769
|
+
self.KERNELSZ**2,
|
|
1770
|
+
NORIENT,
|
|
1771
|
+
nside,
|
|
1772
|
+
),
|
|
1773
|
+
indice,
|
|
1774
|
+
)
|
|
1775
|
+
np.save(
|
|
1776
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
|
|
1777
|
+
% (
|
|
1778
|
+
self.TEMPLATE_PATH,
|
|
1779
|
+
TMPFILE_VERSION,
|
|
1780
|
+
self.KERNELSZ**2,
|
|
1781
|
+
NORIENT,
|
|
1782
|
+
nside,
|
|
1783
|
+
),
|
|
1784
|
+
wav,
|
|
1785
|
+
)
|
|
1786
|
+
|
|
1787
|
+
if cell_ids is None:
|
|
1788
|
+
self.barrier()
|
|
1789
|
+
if self.use_2D:
|
|
1790
|
+
tmp = np.load(
|
|
1791
|
+
"%s/W%d_%s_%d_IDX.npy"
|
|
1792
|
+
% (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
|
|
1793
|
+
)
|
|
1794
|
+
else:
|
|
1795
|
+
tmp = np.load(
|
|
1796
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1797
|
+
% (
|
|
1798
|
+
self.TEMPLATE_PATH,
|
|
1799
|
+
TMPFILE_VERSION,
|
|
1800
|
+
self.KERNELSZ**2,
|
|
1801
|
+
NORIENT,
|
|
1802
|
+
nside,
|
|
1803
|
+
)
|
|
1804
|
+
)
|
|
1805
|
+
wav = np.load(
|
|
1806
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
|
|
1807
|
+
% (
|
|
1808
|
+
self.TEMPLATE_PATH,
|
|
1809
|
+
TMPFILE_VERSION,
|
|
1810
|
+
self.KERNELSZ**2,
|
|
1811
|
+
NORIENT,
|
|
1812
|
+
nside,
|
|
1813
|
+
)
|
|
1814
|
+
)
|
|
1815
|
+
else:
|
|
1816
|
+
tmp = indice
|
|
1817
|
+
|
|
1818
|
+
self.Idx_CNN[(nside,NORIENT,self.KERNELSZ)] = tmp
|
|
1819
|
+
self.Idx_WCNN[(nside,NORIENT,self.KERNELSZ)] = self.backend.bk_cast(wav)
|
|
1820
|
+
|
|
1821
|
+
return wav, tmp
|
|
1822
|
+
|
|
1823
|
+
# ---------------------------------------------−---------
|
|
1824
|
+
# convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
|
|
1825
|
+
def swapaxes(self, x, axis1, axis2):
|
|
1826
|
+
shape = list(x.shape)
|
|
1827
|
+
if axis1 < 0:
|
|
1828
|
+
laxis1 = len(shape) + axis1
|
|
1829
|
+
else:
|
|
1830
|
+
laxis1 = axis1
|
|
1831
|
+
if axis2 < 0:
|
|
1832
|
+
laxis2 = len(shape) + axis2
|
|
1833
|
+
else:
|
|
1834
|
+
laxis2 = axis2
|
|
1835
|
+
|
|
1836
|
+
naxes = len(shape)
|
|
1837
|
+
thelist = [i for i in range(naxes)]
|
|
1838
|
+
thelist[laxis1] = laxis2
|
|
1839
|
+
thelist[laxis2] = laxis1
|
|
1840
|
+
return self.backend.bk_transpose(x, thelist)
|
|
1841
|
+
|
|
1842
|
+
# ---------------------------------------------−---------
|
|
1843
|
+
# Mean using mask x [....,Npix,....], mask[Nmask,Npix] to [....,Nmask,....]
|
|
1844
|
+
# if use_2D
|
|
1845
|
+
# Mean using mask x [....,12,Nside+2*off,Nside+2*off,....], mask[Nmask,12,Nside+2*off,Nside+2*off] to [....,Nmask,....]
|
|
1846
|
+
def masked_mean(self, x, mask, axis=0, rank=0, calc_var=False):
|
|
1847
|
+
|
|
1848
|
+
# ==========================================================================
|
|
1849
|
+
# in input data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]]
|
|
1850
|
+
# in input mask=[Nmask,X[,Y]]
|
|
1851
|
+
# if self.use_2D : X[,Y]] = [X,Y]
|
|
1852
|
+
# if second level: NORIENT[,NORIENT]= NORIENT,NORIENT
|
|
1853
|
+
# ==========================================================================
|
|
1854
|
+
|
|
1855
|
+
shape = list(x.shape)
|
|
1856
|
+
|
|
1857
|
+
if not self.use_2D and not self.use_1D:
|
|
1858
|
+
nside = int(np.sqrt(x.shape[axis] // 12))
|
|
1796
1859
|
|
|
1797
1860
|
l_mask = mask
|
|
1798
1861
|
if self.mask_norm:
|
|
@@ -1802,6 +1865,7 @@ class FoCUS:
|
|
|
1802
1865
|
),
|
|
1803
1866
|
1,
|
|
1804
1867
|
)
|
|
1868
|
+
|
|
1805
1869
|
if not self.use_2D:
|
|
1806
1870
|
l_mask = (
|
|
1807
1871
|
12
|
|
@@ -1845,13 +1909,11 @@ class FoCUS:
|
|
|
1845
1909
|
]
|
|
1846
1910
|
|
|
1847
1911
|
ichannel = 1
|
|
1848
|
-
for i in range(
|
|
1912
|
+
for i in range(1, len(shape) - 2):
|
|
1849
1913
|
ichannel *= shape[i]
|
|
1850
|
-
|
|
1851
|
-
for i in range(axis + 2, len(shape)):
|
|
1852
|
-
ochannel *= shape[i]
|
|
1914
|
+
|
|
1853
1915
|
l_x = self.backend.bk_reshape(
|
|
1854
|
-
x, [
|
|
1916
|
+
x, [shape[0], 1, ichannel, shape[-2], shape[-1]]
|
|
1855
1917
|
)
|
|
1856
1918
|
|
|
1857
1919
|
if self.padding == "VALID":
|
|
@@ -1876,12 +1938,10 @@ class FoCUS:
|
|
|
1876
1938
|
l_mask = l_mask[:, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1]
|
|
1877
1939
|
|
|
1878
1940
|
ichannel = 1
|
|
1879
|
-
for i in range(
|
|
1941
|
+
for i in range(1, len(shape) - 1):
|
|
1880
1942
|
ichannel *= shape[i]
|
|
1881
|
-
|
|
1882
|
-
|
|
1883
|
-
ochannel *= shape[i]
|
|
1884
|
-
l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
|
|
1943
|
+
|
|
1944
|
+
l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel, shape[-1]])
|
|
1885
1945
|
|
|
1886
1946
|
if self.padding == "VALID":
|
|
1887
1947
|
oshape = [k for k in shape]
|
|
@@ -1891,18 +1951,14 @@ class FoCUS:
|
|
|
1891
1951
|
)
|
|
1892
1952
|
else:
|
|
1893
1953
|
ichannel = 1
|
|
1894
|
-
for i in range(
|
|
1954
|
+
for i in range(len(shape) - 1):
|
|
1895
1955
|
ichannel *= shape[i]
|
|
1896
|
-
ochannel = 1
|
|
1897
|
-
for i in range(axis + 1, len(shape)):
|
|
1898
|
-
ochannel *= shape[i]
|
|
1899
|
-
l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[axis], ochannel])
|
|
1900
1956
|
|
|
1901
|
-
|
|
1902
|
-
|
|
1903
|
-
|
|
1904
|
-
# mask=[
|
|
1905
|
-
l_mask = self.backend.bk_expand_dims(l_mask,
|
|
1957
|
+
l_x = self.backend.bk_reshape(x, [ichannel, 1, shape[-1]])
|
|
1958
|
+
|
|
1959
|
+
# data=[Nbatch,...,NORIENT[,NORIENT],X[,Y]] => data=[Nbatch,1,...,NORIENT[,NORIENT],X[,Y]]
|
|
1960
|
+
# mask=[Nmask,X[,Y]] => mask=[1,Nmask,....,X[,Y]]
|
|
1961
|
+
l_mask = self.backend.bk_expand_dims(self.backend.bk_expand_dims(l_mask, 0), 0)
|
|
1906
1962
|
|
|
1907
1963
|
if l_x.dtype == self.all_cbk_type:
|
|
1908
1964
|
l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
|
|
@@ -1916,21 +1972,23 @@ class FoCUS:
|
|
|
1916
1972
|
# vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
1917
1973
|
|
|
1918
1974
|
v1 = self.backend.bk_reduce_sum(
|
|
1919
|
-
self.backend.bk_reduce_sum(mtmp * vtmp, axis
|
|
1975
|
+
self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1), -1
|
|
1920
1976
|
)
|
|
1921
1977
|
v2 = self.backend.bk_reduce_sum(
|
|
1922
|
-
self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis
|
|
1978
|
+
self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1), -1
|
|
1979
|
+
)
|
|
1980
|
+
vh = self.backend.bk_reduce_sum(
|
|
1981
|
+
self.backend.bk_reduce_sum(mtmp, axis=-1), -1
|
|
1923
1982
|
)
|
|
1924
|
-
vh = self.backend.bk_reduce_sum(self.backend.bk_reduce_sum(mtmp, axis=2), 2)
|
|
1925
1983
|
|
|
1926
1984
|
res = v1 / vh
|
|
1927
1985
|
|
|
1928
|
-
oshape = []
|
|
1986
|
+
oshape = [x.shape[0]] + [mask.shape[0]]
|
|
1929
1987
|
if axis > 0:
|
|
1930
|
-
oshape = oshape + list(x.shape[
|
|
1931
|
-
|
|
1932
|
-
if
|
|
1933
|
-
oshape = oshape + list(x.shape[axis
|
|
1988
|
+
oshape = oshape + list(x.shape[1:axis])
|
|
1989
|
+
|
|
1990
|
+
if len(x.shape[axis:-2]) > 0:
|
|
1991
|
+
oshape = oshape + list(x.shape[axis:-2])
|
|
1934
1992
|
|
|
1935
1993
|
if calc_var:
|
|
1936
1994
|
if self.backend.bk_is_complex(vtmp):
|
|
@@ -1960,19 +2018,15 @@ class FoCUS:
|
|
|
1960
2018
|
elif self.use_1D:
|
|
1961
2019
|
mtmp = l_mask
|
|
1962
2020
|
vtmp = l_x
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
vh = self.backend.bk_reduce_sum(mtmp, axis=2)
|
|
2021
|
+
v1 = self.backend.bk_reduce_sum(mtmp * vtmp, axis=-1)
|
|
2022
|
+
v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
|
|
2023
|
+
vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
|
|
1967
2024
|
|
|
1968
2025
|
res = v1 / vh
|
|
1969
2026
|
|
|
1970
|
-
oshape = []
|
|
1971
|
-
if
|
|
1972
|
-
oshape = oshape + list(x.shape[
|
|
1973
|
-
oshape = oshape + [mask.shape[0]]
|
|
1974
|
-
if axis + 1 < len(x.shape):
|
|
1975
|
-
oshape = oshape + list(x.shape[axis + 1 :])
|
|
2027
|
+
oshape = [x.shape[0]] + [mask.shape[0]]
|
|
2028
|
+
if len(x.shape) > 1:
|
|
2029
|
+
oshape = oshape + list(x.shape[1:-1])
|
|
1976
2030
|
|
|
1977
2031
|
if calc_var:
|
|
1978
2032
|
if self.backend.bk_is_complex(vtmp):
|
|
@@ -1991,7 +2045,6 @@ class FoCUS:
|
|
|
1991
2045
|
)
|
|
1992
2046
|
else:
|
|
1993
2047
|
res2 = self.backend.bk_sqrt((v2 / vh - res * res) / (vh))
|
|
1994
|
-
|
|
1995
2048
|
res = self.backend.bk_reshape(res, oshape)
|
|
1996
2049
|
res2 = self.backend.bk_reshape(res2, oshape)
|
|
1997
2050
|
return res, res2
|
|
@@ -2000,18 +2053,20 @@ class FoCUS:
|
|
|
2000
2053
|
return res
|
|
2001
2054
|
|
|
2002
2055
|
else:
|
|
2003
|
-
v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis
|
|
2004
|
-
v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis
|
|
2005
|
-
vh = self.backend.bk_reduce_sum(l_mask, axis
|
|
2056
|
+
v1 = self.backend.bk_reduce_sum(l_mask * l_x, axis=-1)
|
|
2057
|
+
v2 = self.backend.bk_reduce_sum(l_mask * l_x * l_x, axis=-1)
|
|
2058
|
+
vh = self.backend.bk_reduce_sum(l_mask, axis=-1)
|
|
2006
2059
|
|
|
2007
2060
|
res = v1 / vh
|
|
2008
2061
|
|
|
2009
2062
|
oshape = []
|
|
2010
2063
|
if axis > 0:
|
|
2011
|
-
oshape =
|
|
2064
|
+
oshape = [x.shape[0]]
|
|
2065
|
+
else:
|
|
2066
|
+
oshape = [1]
|
|
2012
2067
|
oshape = oshape + [mask.shape[0]]
|
|
2013
|
-
if axis
|
|
2014
|
-
oshape = oshape + list(x.shape[
|
|
2068
|
+
if axis > 1:
|
|
2069
|
+
oshape = oshape + list(x.shape[1:-1])
|
|
2015
2070
|
|
|
2016
2071
|
if calc_var:
|
|
2017
2072
|
if self.backend.bk_is_complex(l_x):
|
|
@@ -2176,169 +2231,67 @@ class FoCUS:
|
|
|
2176
2231
|
print("Use of 2D scat with data that has less than 2D")
|
|
2177
2232
|
return None
|
|
2178
2233
|
|
|
2179
|
-
npix = ishape[
|
|
2180
|
-
npiy = ishape[
|
|
2181
|
-
odata = 1
|
|
2182
|
-
if len(ishape) > axis + 2:
|
|
2183
|
-
for k in range(axis + 2, len(ishape)):
|
|
2184
|
-
odata = odata * ishape[k]
|
|
2234
|
+
npix = ishape[-2]
|
|
2235
|
+
npiy = ishape[-1]
|
|
2185
2236
|
|
|
2186
2237
|
ndata = 1
|
|
2187
|
-
for k in range(
|
|
2238
|
+
for k in range(len(ishape) - 2):
|
|
2188
2239
|
ndata = ndata * ishape[k]
|
|
2189
2240
|
|
|
2190
2241
|
tim = self.backend.bk_reshape(
|
|
2191
|
-
self.backend.bk_cast(in_image), [ndata, npix, npiy
|
|
2242
|
+
self.backend.bk_cast(in_image), [ndata, npix, npiy]
|
|
2192
2243
|
)
|
|
2193
2244
|
|
|
2194
2245
|
if self.backend.bk_is_complex(tim):
|
|
2195
|
-
rr1 = self.backend.conv2d(
|
|
2196
|
-
|
|
2197
|
-
|
|
2198
|
-
|
|
2199
|
-
padding=self.padding,
|
|
2200
|
-
)
|
|
2201
|
-
ii1 = self.backend.conv2d(
|
|
2202
|
-
self.backend.bk_real(tim),
|
|
2203
|
-
self.ww_ImagT[odata],
|
|
2204
|
-
strides=[1, 1, 1, 1],
|
|
2205
|
-
padding=self.padding,
|
|
2206
|
-
)
|
|
2207
|
-
rr2 = self.backend.conv2d(
|
|
2208
|
-
self.backend.bk_imag(tim),
|
|
2209
|
-
self.ww_RealT[odata],
|
|
2210
|
-
strides=[1, 1, 1, 1],
|
|
2211
|
-
padding=self.padding,
|
|
2212
|
-
)
|
|
2213
|
-
ii2 = self.backend.conv2d(
|
|
2214
|
-
self.backend.bk_imag(tim),
|
|
2215
|
-
self.ww_ImagT[odata],
|
|
2216
|
-
strides=[1, 1, 1, 1],
|
|
2217
|
-
padding=self.padding,
|
|
2218
|
-
)
|
|
2246
|
+
rr1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_RealT[1])
|
|
2247
|
+
ii1 = self.backend.conv2d(self.backend.bk_real(tim), self.ww_ImagT[1])
|
|
2248
|
+
rr2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_RealT[1])
|
|
2249
|
+
ii2 = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_ImagT[1])
|
|
2219
2250
|
res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
|
|
2220
2251
|
else:
|
|
2221
|
-
rr = self.backend.conv2d(
|
|
2222
|
-
|
|
2223
|
-
self.ww_RealT[odata],
|
|
2224
|
-
strides=[1, 1, 1, 1],
|
|
2225
|
-
padding=self.padding,
|
|
2226
|
-
)
|
|
2227
|
-
ii = self.backend.conv2d(
|
|
2228
|
-
tim,
|
|
2229
|
-
self.ww_ImagT[odata],
|
|
2230
|
-
strides=[1, 1, 1, 1],
|
|
2231
|
-
padding=self.padding,
|
|
2232
|
-
)
|
|
2252
|
+
rr = self.backend.conv2d(tim, self.ww_RealT[1])
|
|
2253
|
+
ii = self.backend.conv2d(tim, self.ww_ImagT[1])
|
|
2233
2254
|
res = self.backend.bk_complex(rr, ii)
|
|
2234
2255
|
|
|
2235
|
-
|
|
2236
|
-
|
|
2237
|
-
|
|
2238
|
-
res, [res.shape[1], res.shape[2], self.NORIENT]
|
|
2239
|
-
)
|
|
2240
|
-
else:
|
|
2241
|
-
return self.backend.bk_reshape(
|
|
2242
|
-
res,
|
|
2243
|
-
[res.shape[1], res.shape[2], self.NORIENT] + ishape[axis + 2 :],
|
|
2244
|
-
)
|
|
2245
|
-
else:
|
|
2246
|
-
if len(ishape) == axis + 2:
|
|
2247
|
-
return self.backend.bk_reshape(
|
|
2248
|
-
res, ishape[0:axis] + [res.shape[1], res.shape[2], self.NORIENT]
|
|
2249
|
-
)
|
|
2250
|
-
else:
|
|
2251
|
-
return self.backend.bk_reshape(
|
|
2252
|
-
res,
|
|
2253
|
-
ishape[0:axis]
|
|
2254
|
-
+ [res.shape[1], res.shape[2], self.NORIENT]
|
|
2255
|
-
+ ishape[axis + 2 :],
|
|
2256
|
-
)
|
|
2256
|
+
return self.backend.bk_reshape(
|
|
2257
|
+
res, ishape[0:-2] + [self.NORIENT, npix, npiy]
|
|
2258
|
+
)
|
|
2257
2259
|
|
|
2258
|
-
return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
|
|
2259
2260
|
elif self.use_1D:
|
|
2260
2261
|
ishape = list(in_image.shape)
|
|
2261
|
-
if len(ishape) < axis + 1:
|
|
2262
|
-
if not self.silent:
|
|
2263
|
-
print("Use of 1D scat with data that has less than 1D")
|
|
2264
|
-
return None
|
|
2265
2262
|
|
|
2266
|
-
npix = ishape[
|
|
2267
|
-
odata = 1
|
|
2268
|
-
if len(ishape) > axis + 1:
|
|
2269
|
-
for k in range(axis + 1, len(ishape)):
|
|
2270
|
-
odata = odata * ishape[k]
|
|
2263
|
+
npix = ishape[-1]
|
|
2271
2264
|
|
|
2272
2265
|
ndata = 1
|
|
2273
|
-
for k in range(
|
|
2266
|
+
for k in range(len(ishape) - 1):
|
|
2274
2267
|
ndata = ndata * ishape[k]
|
|
2275
2268
|
|
|
2276
|
-
tim = self.backend.bk_reshape(
|
|
2277
|
-
self.backend.bk_cast(in_image), [ndata, npix, odata]
|
|
2278
|
-
)
|
|
2269
|
+
tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix])
|
|
2279
2270
|
|
|
2280
2271
|
if self.backend.bk_is_complex(tim):
|
|
2281
|
-
rr1 = self.backend.conv1d(
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
padding=self.padding,
|
|
2286
|
-
)
|
|
2287
|
-
ii1 = self.backend.conv1d(
|
|
2288
|
-
self.backend.bk_real(tim),
|
|
2289
|
-
self.ww_ImagT[odata],
|
|
2290
|
-
strides=[1, 1, 1],
|
|
2291
|
-
padding=self.padding,
|
|
2292
|
-
)
|
|
2293
|
-
rr2 = self.backend.conv1d(
|
|
2294
|
-
self.backend.bk_imag(tim),
|
|
2295
|
-
self.ww_RealT[odata],
|
|
2296
|
-
strides=[1, 1, 1],
|
|
2297
|
-
padding=self.padding,
|
|
2298
|
-
)
|
|
2299
|
-
ii2 = self.backend.conv1d(
|
|
2300
|
-
self.backend.bk_imag(tim),
|
|
2301
|
-
self.ww_ImagT[odata],
|
|
2302
|
-
strides=[1, 1, 1],
|
|
2303
|
-
padding=self.padding,
|
|
2304
|
-
)
|
|
2272
|
+
rr1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_RealT[1])
|
|
2273
|
+
ii1 = self.backend.conv1d(self.backend.bk_real(tim), self.ww_ImagT[1])
|
|
2274
|
+
rr2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_RealT[1])
|
|
2275
|
+
ii2 = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_ImagT[1])
|
|
2305
2276
|
res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
|
|
2306
2277
|
else:
|
|
2307
|
-
rr = self.backend.conv1d(
|
|
2308
|
-
|
|
2309
|
-
)
|
|
2310
|
-
ii = self.backend.conv1d(
|
|
2311
|
-
tim, self.ww_ImagT[odata], strides=[1, 1, 1], padding=self.padding
|
|
2312
|
-
)
|
|
2278
|
+
rr = self.backend.conv1d(tim, self.ww_RealT[1])
|
|
2279
|
+
ii = self.backend.conv1d(tim, self.ww_ImagT[1])
|
|
2313
2280
|
res = self.backend.bk_complex(rr, ii)
|
|
2314
2281
|
|
|
2315
|
-
|
|
2316
|
-
if len(ishape) == 1:
|
|
2317
|
-
return self.backend.bk_reshape(res, [res.shape[1]])
|
|
2318
|
-
else:
|
|
2319
|
-
return self.backend.bk_reshape(
|
|
2320
|
-
res, [res.shape[1]] + ishape[axis + 2 :]
|
|
2321
|
-
)
|
|
2322
|
-
else:
|
|
2323
|
-
if len(ishape) == axis + 1:
|
|
2324
|
-
return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
|
|
2325
|
-
else:
|
|
2326
|
-
return self.backend.bk_reshape(
|
|
2327
|
-
res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
|
|
2328
|
-
)
|
|
2329
|
-
|
|
2330
|
-
return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
|
|
2282
|
+
return self.backend.bk_reshape(res, ishape)
|
|
2331
2283
|
|
|
2332
2284
|
else:
|
|
2333
2285
|
ishape = list(image.shape)
|
|
2334
|
-
|
|
2286
|
+
"""
|
|
2335
2287
|
if cell_ids is not None:
|
|
2336
2288
|
if cell_ids.shape[0] not in self.padding_conv:
|
|
2289
|
+
print(image.shape,cell_ids.shape)
|
|
2337
2290
|
import healpix_convolution as hc
|
|
2338
2291
|
from xdggs.healpix import HealpixInfo
|
|
2339
2292
|
|
|
2340
2293
|
res = self.backend.bk_zeros(
|
|
2341
|
-
ishape + [self.NORIENT], dtype=self.backend.all_cbk_type
|
|
2294
|
+
ishape[0:-1] + [self.NORIENT]+ishape[-1], dtype=self.backend.all_cbk_type
|
|
2342
2295
|
)
|
|
2343
2296
|
|
|
2344
2297
|
grid_info = HealpixInfo(
|
|
@@ -2384,14 +2337,15 @@ class FoCUS:
|
|
|
2384
2337
|
padded_data
|
|
2385
2338
|
) + 1j * kernelI.matmul(padded_data)
|
|
2386
2339
|
return res
|
|
2387
|
-
|
|
2388
|
-
nside
|
|
2340
|
+
"""
|
|
2341
|
+
if nside is None:
|
|
2342
|
+
nside = int(np.sqrt(image.shape[-1] // 12))
|
|
2389
2343
|
|
|
2390
2344
|
if self.Idx_Neighbours[nside] is None:
|
|
2391
2345
|
if self.InitWave is None:
|
|
2392
|
-
wr, wi, ws, widx = self.init_index(nside)
|
|
2346
|
+
wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
|
|
2393
2347
|
else:
|
|
2394
|
-
wr, wi, ws, widx = self.InitWave(
|
|
2348
|
+
wr, wi, ws, widx = self.InitWave(nside, cell_ids=cell_ids)
|
|
2395
2349
|
|
|
2396
2350
|
self.Idx_Neighbours[nside] = 1 # self.backend.bk_constant(tmp)
|
|
2397
2351
|
self.ww_Real[nside] = wr
|
|
@@ -2401,156 +2355,63 @@ class FoCUS:
|
|
|
2401
2355
|
l_ww_real = self.ww_Real[nside]
|
|
2402
2356
|
l_ww_imag = self.ww_Imag[nside]
|
|
2403
2357
|
|
|
2404
|
-
|
|
2405
|
-
for k in range(axis + 1, len(ishape)):
|
|
2406
|
-
odata = odata * ishape[k]
|
|
2358
|
+
# always convolve the last dimension
|
|
2407
2359
|
|
|
2408
|
-
|
|
2409
|
-
|
|
2410
|
-
for k in range(
|
|
2360
|
+
ndata = 1
|
|
2361
|
+
if len(ishape) > 1:
|
|
2362
|
+
for k in range(len(ishape) - 1):
|
|
2411
2363
|
ndata = ndata * ishape[k]
|
|
2412
|
-
|
|
2413
|
-
|
|
2414
|
-
|
|
2415
|
-
if tim.dtype == self.all_cbk_type:
|
|
2416
|
-
rr1 = self.backend.bk_reshape(
|
|
2417
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2418
|
-
l_ww_real, self.backend.bk_real(tim[0])
|
|
2419
|
-
),
|
|
2420
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2421
|
-
)
|
|
2422
|
-
ii1 = self.backend.bk_reshape(
|
|
2423
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2424
|
-
l_ww_imag, self.backend.bk_real(tim[0])
|
|
2425
|
-
),
|
|
2426
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2427
|
-
)
|
|
2428
|
-
rr2 = self.backend.bk_reshape(
|
|
2429
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2430
|
-
l_ww_real, self.backend.bk_imag(tim[0])
|
|
2431
|
-
),
|
|
2432
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2433
|
-
)
|
|
2434
|
-
ii2 = self.backend.bk_reshape(
|
|
2435
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2436
|
-
l_ww_imag, self.backend.bk_imag(tim[0])
|
|
2437
|
-
),
|
|
2438
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2439
|
-
)
|
|
2440
|
-
res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
|
|
2441
|
-
else:
|
|
2442
|
-
rr = self.backend.bk_reshape(
|
|
2443
|
-
self.backend.bk_sparse_dense_matmul(l_ww_real, tim[0]),
|
|
2444
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2445
|
-
)
|
|
2446
|
-
ii = self.backend.bk_reshape(
|
|
2447
|
-
self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[0]),
|
|
2448
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2449
|
-
)
|
|
2450
|
-
res = self.backend.bk_complex(rr, ii)
|
|
2451
|
-
|
|
2452
|
-
for k in range(1, ndata):
|
|
2453
|
-
if tim.dtype == self.all_cbk_type:
|
|
2454
|
-
rr1 = self.backend.bk_reshape(
|
|
2455
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2456
|
-
l_ww_real, self.backend.bk_real(tim[k])
|
|
2457
|
-
),
|
|
2458
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2459
|
-
)
|
|
2460
|
-
ii1 = self.backend.bk_reshape(
|
|
2461
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2462
|
-
l_ww_imag, self.backend.bk_real(tim[k])
|
|
2463
|
-
),
|
|
2464
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2465
|
-
)
|
|
2466
|
-
rr2 = self.backend.bk_reshape(
|
|
2467
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2468
|
-
l_ww_real, self.backend.bk_imag(tim[k])
|
|
2469
|
-
),
|
|
2470
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2471
|
-
)
|
|
2472
|
-
ii2 = self.backend.bk_reshape(
|
|
2473
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2474
|
-
l_ww_imag, self.backend.bk_imag(tim[k])
|
|
2475
|
-
),
|
|
2476
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2477
|
-
)
|
|
2478
|
-
res = self.backend.bk_concat(
|
|
2479
|
-
[res, self.backend.bk_complex(rr1 - ii2, ii1 + rr2)], 0
|
|
2480
|
-
)
|
|
2481
|
-
else:
|
|
2482
|
-
rr = self.backend.bk_reshape(
|
|
2483
|
-
self.backend.bk_sparse_dense_matmul(l_ww_real, tim[k]),
|
|
2484
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2485
|
-
)
|
|
2486
|
-
ii = self.backend.bk_reshape(
|
|
2487
|
-
self.backend.bk_sparse_dense_matmul(l_ww_imag, tim[k]),
|
|
2488
|
-
[1, 12 * nside**2, self.NORIENT, odata],
|
|
2489
|
-
)
|
|
2490
|
-
res = self.backend.bk_concat(
|
|
2491
|
-
[res, self.backend.bk_complex(rr, ii)], 0
|
|
2492
|
-
)
|
|
2493
|
-
|
|
2494
|
-
if len(ishape) == axis + 1:
|
|
2495
|
-
return self.backend.bk_reshape(
|
|
2496
|
-
res, ishape[0:axis] + [12 * nside**2, self.NORIENT]
|
|
2497
|
-
)
|
|
2498
|
-
else:
|
|
2499
|
-
return self.backend.bk_reshape(
|
|
2500
|
-
res,
|
|
2501
|
-
ishape[0:axis]
|
|
2502
|
-
+ [12 * nside**2]
|
|
2503
|
-
+ ishape[axis + 1 :]
|
|
2504
|
-
+ [self.NORIENT],
|
|
2505
|
-
)
|
|
2364
|
+
tim = self.backend.bk_reshape(
|
|
2365
|
+
self.backend.bk_cast(image), [ndata, ishape[-1]]
|
|
2366
|
+
)
|
|
2506
2367
|
|
|
2507
|
-
if
|
|
2508
|
-
|
|
2509
|
-
self.backend.
|
|
2368
|
+
if tim.dtype == self.all_cbk_type:
|
|
2369
|
+
rr1 = self.backend.bk_reshape(
|
|
2370
|
+
self.backend.bk_sparse_dense_matmul(
|
|
2371
|
+
self.backend.bk_real(tim),
|
|
2372
|
+
l_ww_real,
|
|
2373
|
+
),
|
|
2374
|
+
[ndata, self.NORIENT, ishape[-1]],
|
|
2510
2375
|
)
|
|
2511
|
-
|
|
2512
|
-
|
|
2513
|
-
self.backend.
|
|
2514
|
-
|
|
2515
|
-
|
|
2516
|
-
|
|
2517
|
-
|
|
2518
|
-
|
|
2519
|
-
|
|
2520
|
-
|
|
2521
|
-
|
|
2522
|
-
|
|
2523
|
-
|
|
2524
|
-
|
|
2525
|
-
|
|
2526
|
-
|
|
2527
|
-
),
|
|
2528
|
-
|
|
2529
|
-
)
|
|
2530
|
-
|
|
2531
|
-
|
|
2532
|
-
|
|
2533
|
-
|
|
2534
|
-
|
|
2535
|
-
)
|
|
2536
|
-
|
|
2537
|
-
|
|
2538
|
-
|
|
2539
|
-
|
|
2540
|
-
|
|
2541
|
-
|
|
2542
|
-
|
|
2543
|
-
|
|
2544
|
-
|
|
2545
|
-
|
|
2546
|
-
|
|
2376
|
+
ii1 = self.backend.bk_reshape(
|
|
2377
|
+
self.backend.bk_sparse_dense_matmul(
|
|
2378
|
+
self.backend.bk_real(tim),
|
|
2379
|
+
l_ww_imag,
|
|
2380
|
+
),
|
|
2381
|
+
[ndata, self.NORIENT, ishape[-1]],
|
|
2382
|
+
)
|
|
2383
|
+
rr2 = self.backend.bk_reshape(
|
|
2384
|
+
self.backend.bk_sparse_dense_matmul(
|
|
2385
|
+
self.backend.bk_imag(tim),
|
|
2386
|
+
l_ww_real,
|
|
2387
|
+
),
|
|
2388
|
+
[ndata, self.NORIENT, ishape[-1]],
|
|
2389
|
+
)
|
|
2390
|
+
ii2 = self.backend.bk_reshape(
|
|
2391
|
+
self.backend.bk_sparse_dense_matmul(
|
|
2392
|
+
self.backend.bk_imag(tim),
|
|
2393
|
+
l_ww_imag,
|
|
2394
|
+
),
|
|
2395
|
+
[ndata, self.NORIENT, ishape[-1]],
|
|
2396
|
+
)
|
|
2397
|
+
res = self.backend.bk_complex(rr1 - ii2, ii1 + rr2)
|
|
2398
|
+
else:
|
|
2399
|
+
rr = self.backend.bk_reshape(
|
|
2400
|
+
self.backend.bk_sparse_dense_matmul(tim, l_ww_real),
|
|
2401
|
+
[ndata, self.NORIENT, ishape[-1]],
|
|
2402
|
+
)
|
|
2403
|
+
ii = self.backend.bk_reshape(
|
|
2404
|
+
self.backend.bk_sparse_dense_matmul(tim, l_ww_imag),
|
|
2405
|
+
[ndata, self.NORIENT, ishape[-1]],
|
|
2406
|
+
)
|
|
2407
|
+
res = self.backend.bk_complex(rr, ii)
|
|
2408
|
+
if len(ishape) > 1:
|
|
2409
|
+
return self.backend.bk_reshape(
|
|
2410
|
+
res, ishape[0:-1] + [self.NORIENT, ishape[-1]]
|
|
2411
|
+
)
|
|
2412
|
+
else:
|
|
2413
|
+
return self.backend.bk_reshape(res, [self.NORIENT, ishape[-1]])
|
|
2547
2414
|
|
|
2548
|
-
if len(ishape) == 1:
|
|
2549
|
-
return self.backend.bk_reshape(res, [12 * nside**2, self.NORIENT])
|
|
2550
|
-
else:
|
|
2551
|
-
return self.backend.bk_reshape(
|
|
2552
|
-
res, [12 * nside**2] + ishape[axis + 1 :] + [self.NORIENT]
|
|
2553
|
-
)
|
|
2554
2415
|
return res
|
|
2555
2416
|
|
|
2556
2417
|
# ---------------------------------------------−---------
|
|
@@ -2578,114 +2439,43 @@ class FoCUS:
|
|
|
2578
2439
|
ndata = ndata * ishape[k]
|
|
2579
2440
|
|
|
2580
2441
|
tim = self.backend.bk_reshape(
|
|
2581
|
-
self.backend.bk_cast(in_image), [ndata, npix, npiy
|
|
2442
|
+
self.backend.bk_cast(in_image), [ndata, npix, npiy]
|
|
2582
2443
|
)
|
|
2583
2444
|
|
|
2584
2445
|
if self.backend.bk_is_complex(tim):
|
|
2585
|
-
rr = self.backend.conv2d(
|
|
2586
|
-
|
|
2587
|
-
self.ww_SmoothT[odata],
|
|
2588
|
-
strides=[1, 1, 1, 1],
|
|
2589
|
-
padding=self.padding,
|
|
2590
|
-
)
|
|
2591
|
-
ii = self.backend.conv2d(
|
|
2592
|
-
self.backend.bk_imag(tim),
|
|
2593
|
-
self.ww_SmoothT[odata],
|
|
2594
|
-
strides=[1, 1, 1, 1],
|
|
2595
|
-
padding=self.padding,
|
|
2596
|
-
)
|
|
2446
|
+
rr = self.backend.conv2d(self.backend.bk_real(tim), self.ww_SmoothT[1])
|
|
2447
|
+
ii = self.backend.conv2d(self.backend.bk_imag(tim), self.ww_SmoothT[1])
|
|
2597
2448
|
res = self.backend.bk_complex(rr, ii)
|
|
2598
2449
|
else:
|
|
2599
|
-
res = self.backend.conv2d(
|
|
2600
|
-
tim,
|
|
2601
|
-
self.ww_SmoothT[odata],
|
|
2602
|
-
strides=[1, 1, 1, 1],
|
|
2603
|
-
padding=self.padding,
|
|
2604
|
-
)
|
|
2450
|
+
res = self.backend.conv2d(tim, self.ww_SmoothT[1])
|
|
2605
2451
|
|
|
2606
|
-
|
|
2607
|
-
if len(ishape) == 2:
|
|
2608
|
-
return self.backend.bk_reshape(res, [res.shape[1], res.shape[2]])
|
|
2609
|
-
else:
|
|
2610
|
-
return self.backend.bk_reshape(
|
|
2611
|
-
res, [res.shape[1], res.shape[2]] + ishape[axis + 2 :]
|
|
2612
|
-
)
|
|
2613
|
-
else:
|
|
2614
|
-
if len(ishape) == axis + 2:
|
|
2615
|
-
return self.backend.bk_reshape(
|
|
2616
|
-
res, ishape[0:axis] + [res.shape[1], res.shape[2]]
|
|
2617
|
-
)
|
|
2618
|
-
else:
|
|
2619
|
-
return self.backend.bk_reshape(
|
|
2620
|
-
res,
|
|
2621
|
-
ishape[0:axis]
|
|
2622
|
-
+ [res.shape[1], res.shape[2]]
|
|
2623
|
-
+ ishape[axis + 2 :],
|
|
2624
|
-
)
|
|
2452
|
+
return self.backend.bk_reshape(res, ishape)
|
|
2625
2453
|
|
|
2626
|
-
return self.backend.bk_reshape(res, in_image.shape)
|
|
2627
2454
|
elif self.use_1D:
|
|
2628
2455
|
|
|
2629
2456
|
ishape = list(in_image.shape)
|
|
2630
|
-
if len(ishape) < axis + 1:
|
|
2631
|
-
if not self.silent:
|
|
2632
|
-
print("Use of 1D scat with data that has less than 1D")
|
|
2633
|
-
return None
|
|
2634
2457
|
|
|
2635
|
-
npix = ishape[
|
|
2636
|
-
odata = 1
|
|
2637
|
-
if len(ishape) > axis + 1:
|
|
2638
|
-
for k in range(axis + 1, len(ishape)):
|
|
2639
|
-
odata = odata * ishape[k]
|
|
2458
|
+
npix = ishape[-1]
|
|
2640
2459
|
|
|
2641
2460
|
ndata = 1
|
|
2642
|
-
for k in range(
|
|
2461
|
+
for k in range(len(ishape) - 1):
|
|
2643
2462
|
ndata = ndata * ishape[k]
|
|
2644
2463
|
|
|
2645
|
-
tim = self.backend.bk_reshape(
|
|
2646
|
-
self.backend.bk_cast(in_image), [ndata, npix, odata]
|
|
2647
|
-
)
|
|
2464
|
+
tim = self.backend.bk_reshape(self.backend.bk_cast(in_image), [ndata, npix])
|
|
2648
2465
|
|
|
2649
2466
|
if self.backend.bk_is_complex(tim):
|
|
2650
|
-
rr = self.backend.conv1d(
|
|
2651
|
-
|
|
2652
|
-
self.ww_SmoothT[odata],
|
|
2653
|
-
strides=[1, 1, 1],
|
|
2654
|
-
padding=self.padding,
|
|
2655
|
-
)
|
|
2656
|
-
ii = self.backend.conv1d(
|
|
2657
|
-
self.backend.bk_imag(tim),
|
|
2658
|
-
self.ww_SmoothT[odata],
|
|
2659
|
-
strides=[1, 1, 1],
|
|
2660
|
-
padding=self.padding,
|
|
2661
|
-
)
|
|
2467
|
+
rr = self.backend.conv1d(self.backend.bk_real(tim), self.ww_SmoothT[1])
|
|
2468
|
+
ii = self.backend.conv1d(self.backend.bk_imag(tim), self.ww_SmoothT[1])
|
|
2662
2469
|
res = self.backend.bk_complex(rr, ii)
|
|
2663
2470
|
else:
|
|
2664
|
-
res = self.backend.conv1d(
|
|
2665
|
-
tim, self.ww_SmoothT[odata], strides=[1, 1, 1], padding=self.padding
|
|
2666
|
-
)
|
|
2667
|
-
|
|
2668
|
-
if axis == 0:
|
|
2669
|
-
if len(ishape) == 1:
|
|
2670
|
-
return self.backend.bk_reshape(res, [res.shape[1]])
|
|
2671
|
-
else:
|
|
2672
|
-
return self.backend.bk_reshape(
|
|
2673
|
-
res, [res.shape[1]] + ishape[axis + 1 :]
|
|
2674
|
-
)
|
|
2675
|
-
else:
|
|
2676
|
-
if len(ishape) == axis + 1:
|
|
2677
|
-
return self.backend.bk_reshape(res, ishape[0:axis] + [res.shape[1]])
|
|
2678
|
-
else:
|
|
2679
|
-
return self.backend.bk_reshape(
|
|
2680
|
-
res, ishape[0:axis] + [res.shape[1]] + ishape[axis + 1 :]
|
|
2681
|
-
)
|
|
2471
|
+
res = self.backend.conv1d(tim, self.ww_SmoothT[1])
|
|
2682
2472
|
|
|
2683
|
-
return self.backend.bk_reshape(res,
|
|
2473
|
+
return self.backend.bk_reshape(res, ishape)
|
|
2684
2474
|
|
|
2685
2475
|
else:
|
|
2686
2476
|
|
|
2687
2477
|
ishape = list(image.shape)
|
|
2688
|
-
|
|
2478
|
+
"""
|
|
2689
2479
|
if cell_ids is not None:
|
|
2690
2480
|
if cell_ids.shape[0] not in self.padding_smooth:
|
|
2691
2481
|
import healpix_convolution as hc
|
|
@@ -2726,15 +2516,16 @@ class FoCUS:
|
|
|
2726
2516
|
padded_data = padding.apply(image[l, :, k2], is_torch=True)
|
|
2727
2517
|
res[l, :, k2] = kernel.matmul(padded_data)
|
|
2728
2518
|
return res
|
|
2729
|
-
|
|
2730
|
-
nside
|
|
2519
|
+
"""
|
|
2520
|
+
if nside is None:
|
|
2521
|
+
nside = int(np.sqrt(image.shape[-1] // 12))
|
|
2731
2522
|
|
|
2732
2523
|
if self.Idx_Neighbours[nside] is None:
|
|
2733
2524
|
|
|
2734
2525
|
if self.InitWave is None:
|
|
2735
|
-
wr, wi, ws, widx = self.init_index(nside)
|
|
2526
|
+
wr, wi, ws, widx = self.init_index(nside, cell_ids=cell_ids)
|
|
2736
2527
|
else:
|
|
2737
|
-
wr, wi, ws, widx = self.InitWave(self, nside)
|
|
2528
|
+
wr, wi, ws, widx = self.InitWave(self, nside, cell_ids=cell_ids)
|
|
2738
2529
|
|
|
2739
2530
|
self.Idx_Neighbours[nside] = 1
|
|
2740
2531
|
self.ww_Real[nside] = wr
|
|
@@ -2744,92 +2535,24 @@ class FoCUS:
|
|
|
2744
2535
|
l_w_smooth = self.w_smooth[nside]
|
|
2745
2536
|
|
|
2746
2537
|
odata = 1
|
|
2747
|
-
for k in range(
|
|
2538
|
+
for k in range(0, len(ishape) - 1):
|
|
2748
2539
|
odata = odata * ishape[k]
|
|
2749
2540
|
|
|
2750
|
-
|
|
2751
|
-
|
|
2752
|
-
|
|
2753
|
-
|
|
2754
|
-
|
|
2755
|
-
|
|
2756
|
-
|
|
2757
|
-
|
|
2758
|
-
|
|
2759
|
-
|
|
2760
|
-
|
|
2761
|
-
|
|
2762
|
-
|
|
2763
|
-
|
|
2764
|
-
|
|
2765
|
-
return self.backend.bk_reshape(
|
|
2766
|
-
res, [12 * nside**2] + ishape[axis + 1 :]
|
|
2767
|
-
)
|
|
2768
|
-
|
|
2769
|
-
if axis > 0:
|
|
2770
|
-
ndata = ishape[0]
|
|
2771
|
-
for k in range(1, axis):
|
|
2772
|
-
ndata = ndata * ishape[k]
|
|
2773
|
-
tim = self.backend.bk_reshape(image, [ndata, 12 * nside**2, odata])
|
|
2774
|
-
if tim.dtype == self.all_cbk_type:
|
|
2775
|
-
rr = self.backend.bk_reshape(
|
|
2776
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2777
|
-
l_w_smooth, self.backend.bk_real(tim[0])
|
|
2778
|
-
),
|
|
2779
|
-
[1, 12 * nside**2, odata],
|
|
2780
|
-
)
|
|
2781
|
-
ri = self.backend.bk_reshape(
|
|
2782
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2783
|
-
l_w_smooth, self.backend.bk_imag(tim[0])
|
|
2784
|
-
),
|
|
2785
|
-
[1, 12 * nside**2, odata],
|
|
2786
|
-
)
|
|
2787
|
-
res = self.backend.bk_complex(rr, ri)
|
|
2788
|
-
else:
|
|
2789
|
-
res = self.backend.bk_reshape(
|
|
2790
|
-
self.backend.bk_sparse_dense_matmul(l_w_smooth, tim[0]),
|
|
2791
|
-
[1, 12 * nside**2, odata],
|
|
2792
|
-
)
|
|
2793
|
-
|
|
2794
|
-
for k in range(1, ndata):
|
|
2795
|
-
if tim.dtype == self.all_cbk_type:
|
|
2796
|
-
rr = self.backend.bk_reshape(
|
|
2797
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2798
|
-
l_w_smooth, self.backend.bk_real(tim[k])
|
|
2799
|
-
),
|
|
2800
|
-
[1, 12 * nside**2, odata],
|
|
2801
|
-
)
|
|
2802
|
-
ri = self.backend.bk_reshape(
|
|
2803
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2804
|
-
l_w_smooth, self.backend.bk_imag(tim[k])
|
|
2805
|
-
),
|
|
2806
|
-
[1, 12 * nside**2, odata],
|
|
2807
|
-
)
|
|
2808
|
-
res = self.backend.bk_concat(
|
|
2809
|
-
[res, self.backend.bk_complex(rr, ri)], 0
|
|
2810
|
-
)
|
|
2811
|
-
else:
|
|
2812
|
-
res = self.backend.bk_concat(
|
|
2813
|
-
[
|
|
2814
|
-
res,
|
|
2815
|
-
self.backend.bk_reshape(
|
|
2816
|
-
self.backend.bk_sparse_dense_matmul(
|
|
2817
|
-
l_w_smooth, tim[k]
|
|
2818
|
-
),
|
|
2819
|
-
[1, 12 * nside**2, odata],
|
|
2820
|
-
),
|
|
2821
|
-
],
|
|
2822
|
-
0,
|
|
2823
|
-
)
|
|
2824
|
-
|
|
2825
|
-
if len(ishape) == axis + 1:
|
|
2826
|
-
return self.backend.bk_reshape(
|
|
2827
|
-
res, ishape[0:axis] + [12 * nside**2]
|
|
2828
|
-
)
|
|
2829
|
-
else:
|
|
2830
|
-
return self.backend.bk_reshape(
|
|
2831
|
-
res, ishape[0:axis] + [12 * nside**2] + ishape[axis + 1 :]
|
|
2832
|
-
)
|
|
2541
|
+
tim = self.backend.bk_reshape(image, [odata, ishape[-1]])
|
|
2542
|
+
if tim.dtype == self.all_cbk_type:
|
|
2543
|
+
rr = self.backend.bk_sparse_dense_matmul(
|
|
2544
|
+
self.backend.bk_real(tim), l_w_smooth
|
|
2545
|
+
)
|
|
2546
|
+
ri = self.backend.bk_sparse_dense_matmul(
|
|
2547
|
+
self.backend.bk_imag(tim), l_w_smooth
|
|
2548
|
+
)
|
|
2549
|
+
res = self.backend.bk_complex(rr, ri)
|
|
2550
|
+
else:
|
|
2551
|
+
res = self.backend.bk_sparse_dense_matmul(tim, l_w_smooth)
|
|
2552
|
+
if len(ishape) == 1:
|
|
2553
|
+
return self.backend.bk_reshape(res, [ishape[-1]])
|
|
2554
|
+
else:
|
|
2555
|
+
return self.backend.bk_reshape(res, ishape[0:-1] + [ishape[-1]])
|
|
2833
2556
|
|
|
2834
2557
|
return res
|
|
2835
2558
|
|