foscat 2025.5.2__py3-none-any.whl → 2025.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +10 -12
- foscat/CNN.py +31 -30
- foscat/FoCUS.py +248 -203
- foscat/GCNN.py +48 -150
- foscat/Softmax.py +1 -0
- foscat/alm.py +2 -2
- foscat/heal_NN.py +432 -0
- foscat/scat_cov.py +32 -1
- {foscat-2025.5.2.dist-info → foscat-2025.6.1.dist-info}/METADATA +1 -1
- {foscat-2025.5.2.dist-info → foscat-2025.6.1.dist-info}/RECORD +13 -12
- {foscat-2025.5.2.dist-info → foscat-2025.6.1.dist-info}/WHEEL +0 -0
- {foscat-2025.5.2.dist-info → foscat-2025.6.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.5.2.dist-info → foscat-2025.6.1.dist-info}/top_level.txt +0 -0
foscat/BkTorch.py
CHANGED
|
@@ -149,11 +149,7 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
149
149
|
# -- BACKEND DEFINITION --
|
|
150
150
|
# ---------------------------------------------−---------
|
|
151
151
|
def bk_SparseTensor(self, indice, w, dense_shape=[]):
|
|
152
|
-
return (
|
|
153
|
-
self.backend.sparse_coo_tensor(indice.T, w, dense_shape)
|
|
154
|
-
.to_sparse_csr()
|
|
155
|
-
.to(self.torch_device)
|
|
156
|
-
)
|
|
152
|
+
return self.backend.sparse_coo_tensor(indice.T, w, dense_shape).to_sparse_csr().to(self.torch_device)
|
|
157
153
|
|
|
158
154
|
def bk_stack(self, list, axis=0):
|
|
159
155
|
return self.backend.stack(list, axis=axis).to(self.torch_device)
|
|
@@ -246,13 +242,13 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
246
242
|
xr = self.bk_real(x)
|
|
247
243
|
# xi = self.bk_imag(x)
|
|
248
244
|
|
|
249
|
-
r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr)
|
|
245
|
+
r = self.backend.sign(xr) * self.backend.sqrt(self.backend.sign(xr) * xr + 1E-16)
|
|
250
246
|
# return r
|
|
251
247
|
# i = self.backend.sign(xi) * self.backend.sqrt(self.backend.sign(xi) * xi)
|
|
252
248
|
|
|
253
249
|
return r
|
|
254
250
|
else:
|
|
255
|
-
return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x)
|
|
251
|
+
return self.backend.sign(x) * self.backend.sqrt(self.backend.sign(x) * x + 1E-16)
|
|
256
252
|
|
|
257
253
|
def bk_square_comp(self, x):
|
|
258
254
|
if x.dtype == self.all_cbk_type:
|
|
@@ -391,9 +387,9 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
391
387
|
return self.backend.argmax(data)
|
|
392
388
|
|
|
393
389
|
def bk_reshape(self, data, shape):
|
|
394
|
-
if isinstance(data, np.ndarray):
|
|
395
|
-
|
|
396
|
-
return data.
|
|
390
|
+
#if isinstance(data, np.ndarray):
|
|
391
|
+
# return data.reshape(shape)
|
|
392
|
+
return data.reshape(shape)
|
|
397
393
|
|
|
398
394
|
def bk_repeat(self, data, nn, axis=0):
|
|
399
395
|
return self.backend.repeat_interleave(data, repeats=nn, dim=axis)
|
|
@@ -440,7 +436,9 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
440
436
|
return self.backend.zeros(shape, dtype=dtype).to(self.torch_device)
|
|
441
437
|
|
|
442
438
|
def bk_gather(self, data, idx, axis=0):
|
|
443
|
-
if axis ==
|
|
439
|
+
if axis == -1:
|
|
440
|
+
return data[...,idx]
|
|
441
|
+
elif axis == 0:
|
|
444
442
|
return data[idx]
|
|
445
443
|
elif axis == 1:
|
|
446
444
|
return data[:, idx]
|
|
@@ -448,7 +446,7 @@ class BkTorch(BackendBase.BackendBase):
|
|
|
448
446
|
return data[:, :, idx]
|
|
449
447
|
elif axis == 3:
|
|
450
448
|
return data[:, :, :, idx]
|
|
451
|
-
return data[
|
|
449
|
+
return data[idx,...]
|
|
452
450
|
|
|
453
451
|
def bk_reverse(self, data, axis=0):
|
|
454
452
|
return self.backend.flip(data, dims=[axis])
|
foscat/CNN.py
CHANGED
|
@@ -9,13 +9,12 @@ class CNN:
|
|
|
9
9
|
|
|
10
10
|
def __init__(
|
|
11
11
|
self,
|
|
12
|
-
scat_operator=None,
|
|
13
12
|
nparam=1,
|
|
14
|
-
|
|
13
|
+
KERNELSZ=3,
|
|
14
|
+
NORIENT=4,
|
|
15
15
|
chanlist=[],
|
|
16
16
|
in_nside=1,
|
|
17
17
|
n_chan_in=1,
|
|
18
|
-
nbatch=1,
|
|
19
18
|
SEED=1234,
|
|
20
19
|
filename=None,
|
|
21
20
|
):
|
|
@@ -31,31 +30,30 @@ class CNN:
|
|
|
31
30
|
self.in_nside = outlist[4]
|
|
32
31
|
self.nbatch = outlist[1]
|
|
33
32
|
self.n_chan_in = outlist[8]
|
|
33
|
+
self.NORIENT = outlist[9]
|
|
34
34
|
self.x = self.scat_operator.backend.bk_cast(outlist[6])
|
|
35
35
|
self.out_nside = self.in_nside // (2**self.nscale)
|
|
36
36
|
else:
|
|
37
|
-
self.nscale =
|
|
38
|
-
self.nbatch = nbatch
|
|
37
|
+
self.nscale = len(chanlist)-1
|
|
39
38
|
self.npar = nparam
|
|
40
39
|
self.n_chan_in = n_chan_in
|
|
41
40
|
self.scat_operator = scat_operator
|
|
42
|
-
if
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
)
|
|
47
|
-
return None
|
|
41
|
+
if self.scat_operator is None:
|
|
42
|
+
self.scat_operator = sc.funct(
|
|
43
|
+
KERNELSZ=KERNELSZ,
|
|
44
|
+
NORIENT=NORIENT)
|
|
48
45
|
|
|
49
46
|
self.chanlist = chanlist
|
|
50
|
-
self.KERNELSZ = scat_operator.KERNELSZ
|
|
51
|
-
self.
|
|
47
|
+
self.KERNELSZ = self.scat_operator.KERNELSZ
|
|
48
|
+
self.NORIENT = self.scat_operator.NORIENT
|
|
49
|
+
self.all_type = self.scat_operator.all_type
|
|
52
50
|
self.in_nside = in_nside
|
|
53
51
|
self.out_nside = self.in_nside // (2**self.nscale)
|
|
54
|
-
|
|
52
|
+
self.backend = self.scat_operator.backend
|
|
55
53
|
np.random.seed(SEED)
|
|
56
|
-
self.x = scat_operator.backend.bk_cast(
|
|
57
|
-
np.random.
|
|
58
|
-
/ (self.KERNELSZ * self.KERNELSZ)
|
|
54
|
+
self.x = self.scat_operator.backend.bk_cast(
|
|
55
|
+
np.random.rand(self.get_number_of_weights())
|
|
56
|
+
/ (self.KERNELSZ * (self.KERNELSZ//2+1)*self.NORIENT)
|
|
59
57
|
)
|
|
60
58
|
|
|
61
59
|
def save(self, filename):
|
|
@@ -70,6 +68,7 @@ class CNN:
|
|
|
70
68
|
self.get_weights().numpy(),
|
|
71
69
|
self.all_type,
|
|
72
70
|
self.n_chan_in,
|
|
71
|
+
self.NORIENT,
|
|
73
72
|
]
|
|
74
73
|
|
|
75
74
|
myout = open("%s.pkl" % (filename), "wb")
|
|
@@ -82,8 +81,8 @@ class CNN:
|
|
|
82
81
|
totnchan = totnchan + self.chanlist[i] * self.chanlist[i + 1]
|
|
83
82
|
return (
|
|
84
83
|
self.npar * 12 * self.out_nside**2 * self.chanlist[self.nscale]
|
|
85
|
-
+ totnchan * self.KERNELSZ * self.KERNELSZ
|
|
86
|
-
+ self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
|
|
84
|
+
+ totnchan * self.KERNELSZ * (self.KERNELSZ//2+1)
|
|
85
|
+
+ self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]
|
|
87
86
|
)
|
|
88
87
|
|
|
89
88
|
def set_weights(self, x):
|
|
@@ -95,30 +94,32 @@ class CNN:
|
|
|
95
94
|
def eval(self, im, indices=None, weights=None):
|
|
96
95
|
|
|
97
96
|
x = self.x
|
|
98
|
-
ww = self.
|
|
99
|
-
x[0 : self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]],
|
|
100
|
-
[self.KERNELSZ * self.KERNELSZ,
|
|
97
|
+
ww = self.backend.bk_reshape(
|
|
98
|
+
x[0 : self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]],
|
|
99
|
+
[self.n_chan_in, self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[0]],
|
|
101
100
|
)
|
|
102
|
-
nn = self.KERNELSZ * self.KERNELSZ * self.n_chan_in * self.chanlist[0]
|
|
101
|
+
nn = self.KERNELSZ * (self.KERNELSZ//2+1) * self.n_chan_in * self.chanlist[0]
|
|
103
102
|
|
|
104
103
|
im = self.scat_operator.healpix_layer(im, ww)
|
|
105
|
-
im = self.
|
|
104
|
+
im = self.backend.bk_relu(im)
|
|
105
|
+
|
|
106
|
+
im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],im.shape[2]//4,4]),3)
|
|
106
107
|
|
|
107
108
|
for k in range(self.nscale):
|
|
108
109
|
ww = self.scat_operator.backend.bk_reshape(
|
|
109
110
|
x[
|
|
110
111
|
nn : nn
|
|
111
112
|
+ self.KERNELSZ
|
|
112
|
-
* self.KERNELSZ
|
|
113
|
+
* (self.KERNELSZ//2+1)
|
|
113
114
|
* self.chanlist[k]
|
|
114
115
|
* self.chanlist[k + 1]
|
|
115
116
|
],
|
|
116
|
-
[self.KERNELSZ * self.KERNELSZ,
|
|
117
|
+
[self.chanlist[k], self.KERNELSZ * (self.KERNELSZ//2+1), self.chanlist[k + 1]],
|
|
117
118
|
)
|
|
118
119
|
nn = (
|
|
119
120
|
nn
|
|
120
121
|
+ self.KERNELSZ
|
|
121
|
-
* self.KERNELSZ
|
|
122
|
+
* (self.KERNELSZ//2)
|
|
122
123
|
* self.chanlist[k]
|
|
123
124
|
* self.chanlist[k + 1]
|
|
124
125
|
)
|
|
@@ -129,7 +130,7 @@ class CNN:
|
|
|
129
130
|
im, ww, indices=indices[k], weights=weights[k]
|
|
130
131
|
)
|
|
131
132
|
im = self.scat_operator.backend.bk_relu(im)
|
|
132
|
-
im = self.
|
|
133
|
+
im = self.backend.bk_reduce_mean(self.backend.bk_reshape(im,[im.shape[0],im.shape[1],im.shape[2]//4,4]),3)
|
|
133
134
|
|
|
134
135
|
ww = self.scat_operator.backend.bk_reshape(
|
|
135
136
|
x[
|
|
@@ -141,11 +142,11 @@ class CNN:
|
|
|
141
142
|
|
|
142
143
|
im = self.scat_operator.backend.bk_matmul(
|
|
143
144
|
self.scat_operator.backend.bk_reshape(
|
|
144
|
-
im, [
|
|
145
|
+
im, [im.shape[0], im.shape[1] * im.shape[2]]
|
|
145
146
|
),
|
|
146
147
|
ww,
|
|
147
148
|
)
|
|
148
|
-
im = self.scat_operator.backend.bk_reshape(im, [self.npar])
|
|
149
|
+
#im = self.scat_operator.backend.bk_reshape(im, [self.npar])
|
|
149
150
|
im = self.scat_operator.backend.bk_relu(im)
|
|
150
151
|
|
|
151
152
|
return im
|
foscat/FoCUS.py
CHANGED
|
@@ -35,7 +35,7 @@ class FoCUS:
|
|
|
35
35
|
mpi_rank=0,
|
|
36
36
|
):
|
|
37
37
|
|
|
38
|
-
self.__version__ = "2025.
|
|
38
|
+
self.__version__ = "2025.06.1"
|
|
39
39
|
# P00 coeff for normalization for scat_cov
|
|
40
40
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
41
41
|
self.P1_dic = None
|
|
@@ -176,6 +176,9 @@ class FoCUS:
|
|
|
176
176
|
self.Y_CNN = {}
|
|
177
177
|
self.Z_CNN = {}
|
|
178
178
|
|
|
179
|
+
self.Idx_CNN = {}
|
|
180
|
+
self.Idx_WCNN = {}
|
|
181
|
+
|
|
179
182
|
self.filters_set = {}
|
|
180
183
|
self.edge_masks = {}
|
|
181
184
|
|
|
@@ -500,210 +503,26 @@ class FoCUS:
|
|
|
500
503
|
return indices, weights, xc, yc, zc
|
|
501
504
|
|
|
502
505
|
# ---------------------------------------------−---------
|
|
503
|
-
def calc_orientation(self, im): # im is [Ndata,12*Nside**2]
|
|
504
|
-
nside = int(np.sqrt(im.shape[1] // 12))
|
|
505
|
-
l_kernel = self.KERNELSZ * self.KERNELSZ
|
|
506
|
-
norient = 32
|
|
507
|
-
w = np.zeros([l_kernel, 1, 2 * norient])
|
|
508
|
-
ca = np.cos(np.arange(norient) / norient * np.pi)
|
|
509
|
-
sa = np.sin(np.arange(norient) / norient * np.pi)
|
|
510
|
-
stat = np.zeros([12 * nside**2, norient])
|
|
511
|
-
|
|
512
|
-
if self.ww_CNN[nside] is None:
|
|
513
|
-
self.init_CNN_index(nside, transpose=False)
|
|
514
|
-
|
|
515
|
-
y = self.Y_CNN[nside]
|
|
516
|
-
z = self.Z_CNN[nside]
|
|
517
|
-
|
|
518
|
-
for k in range(norient):
|
|
519
|
-
w[:, 0, k] = np.exp(-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)) * np.cos(
|
|
520
|
-
nside * (y * ca[k] + z * sa[k]) * np.pi / 2
|
|
521
|
-
)
|
|
522
|
-
w[:, 0, k + norient] = np.exp(
|
|
523
|
-
-0.5 * nside**2 * ((y) ** 2 + (z) ** 2)
|
|
524
|
-
) * np.sin(nside * (y * ca[k] + z * sa[k]) * np.pi / 2)
|
|
525
|
-
w[:, 0, k] = w[:, 0, k] - np.mean(w[:, 0, k])
|
|
526
|
-
w[:, 0, k + norient] = w[:, 0, k] - np.mean(w[:, 0, k + norient])
|
|
527
|
-
|
|
528
|
-
for k in range(im.shape[0]):
|
|
529
|
-
tmp = im[k].reshape(12 * nside**2, 1)
|
|
530
|
-
im2 = self.healpix_layer(tmp, w)
|
|
531
|
-
stat = stat + im2[:, 0:norient] ** 2 + im2[:, norient:] ** 2
|
|
532
|
-
|
|
533
|
-
rotation = (np.argmax(stat, 1)).astype("float") / 32.0 * 180.0
|
|
534
|
-
|
|
535
|
-
indices, weights, x, y, z = self.calc_indices_convol(
|
|
536
|
-
nside, 9, rotation=rotation
|
|
537
|
-
)
|
|
538
|
-
|
|
539
|
-
return indices, weights
|
|
540
|
-
|
|
541
|
-
def init_CNN_index(self, nside, transpose=False):
|
|
542
|
-
l_kernel = int(self.KERNELSZ * self.KERNELSZ)
|
|
543
|
-
try:
|
|
544
|
-
indices = np.load(
|
|
545
|
-
"%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
|
|
546
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
547
|
-
)
|
|
548
|
-
weights = np.load(
|
|
549
|
-
"%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
|
|
550
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
551
|
-
)
|
|
552
|
-
xc = np.load(
|
|
553
|
-
"%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
|
|
554
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
555
|
-
)
|
|
556
|
-
yc = np.load(
|
|
557
|
-
"%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
|
|
558
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
559
|
-
)
|
|
560
|
-
zc = np.load(
|
|
561
|
-
"%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
|
|
562
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside)
|
|
563
|
-
)
|
|
564
|
-
except:
|
|
565
|
-
indices, weights, xc, yc, zc = self.calc_indices_convol(nside, l_kernel)
|
|
566
|
-
np.save(
|
|
567
|
-
"%s/FOSCAT_%s_I%d_%d_%d_CNNV3.npy"
|
|
568
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
569
|
-
indices,
|
|
570
|
-
)
|
|
571
|
-
np.save(
|
|
572
|
-
"%s/FOSCAT_%s_W%d_%d_%d_CNNV3.npy"
|
|
573
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
574
|
-
weights,
|
|
575
|
-
)
|
|
576
|
-
np.save(
|
|
577
|
-
"%s/FOSCAT_%s_X%d_%d_%d_CNNV3.npy"
|
|
578
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
579
|
-
xc,
|
|
580
|
-
)
|
|
581
|
-
np.save(
|
|
582
|
-
"%s/FOSCAT_%s_Y%d_%d_%d_CNNV3.npy"
|
|
583
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
584
|
-
yc,
|
|
585
|
-
)
|
|
586
|
-
np.save(
|
|
587
|
-
"%s/FOSCAT_%s_Z%d_%d_%d_CNNV3.npy"
|
|
588
|
-
% (self.TEMPLATE_PATH, TMPFILE_VERSION, l_kernel, self.NORIENT, nside),
|
|
589
|
-
zc,
|
|
590
|
-
)
|
|
591
|
-
if not self.silent:
|
|
592
|
-
print(
|
|
593
|
-
"Write %s/FOSCAT_%s_W%d_%d_%d_CNNV2.npy"
|
|
594
|
-
% (
|
|
595
|
-
self.TEMPLATE_PATH,
|
|
596
|
-
TMPFILE_VERSION,
|
|
597
|
-
l_kernel,
|
|
598
|
-
self.NORIENT,
|
|
599
|
-
nside,
|
|
600
|
-
)
|
|
601
|
-
)
|
|
602
|
-
|
|
603
|
-
self.X_CNN[nside] = xc
|
|
604
|
-
self.Y_CNN[nside] = yc
|
|
605
|
-
self.Z_CNN[nside] = zc
|
|
606
|
-
self.ww_CNN[nside] = self.backend.bk_SparseTensor(
|
|
607
|
-
indices, weights, [12 * nside * nside * l_kernel, 12 * nside * nside]
|
|
608
|
-
)
|
|
609
|
-
|
|
610
|
-
# ---------------------------------------------−---------
|
|
611
|
-
def healpix_layer_coord(self, im, axis=0):
|
|
612
|
-
nside = int(np.sqrt(im.shape[axis] // 12))
|
|
613
|
-
if self.ww_CNN[nside] is None:
|
|
614
|
-
self.init_CNN_index(nside)
|
|
615
|
-
return self.X_CNN[nside], self.Y_CNN[nside], self.Z_CNN[nside]
|
|
616
|
-
|
|
617
|
-
# ---------------------------------------------−---------
|
|
618
|
-
def healpix_layer_transpose(self, im, ww, indices=None, weights=None, axis=0):
|
|
619
|
-
nside = int(np.sqrt(im.shape[axis] // 12))
|
|
620
|
-
|
|
621
|
-
if im.shape[1 + axis] != ww.shape[1]:
|
|
622
|
-
if not self.silent:
|
|
623
|
-
print("Weights channels should be equal to the input image channels")
|
|
624
|
-
return -1
|
|
625
|
-
if axis == 1:
|
|
626
|
-
results = []
|
|
627
|
-
|
|
628
|
-
for k in range(im.shape[0]):
|
|
629
|
-
|
|
630
|
-
tmp = self.healpix_layer(
|
|
631
|
-
im[k], ww, indices=indices, weights=weights, axis=0
|
|
632
|
-
)
|
|
633
|
-
tmp = self.backend.bk_reshape(
|
|
634
|
-
self.up_grade(tmp, 2 * nside), [12 * 4 * nside**2, ww.shape[2]]
|
|
635
|
-
)
|
|
636
|
-
|
|
637
|
-
results.append(tmp)
|
|
638
|
-
|
|
639
|
-
return self.backend.bk_stack(results, axis=0)
|
|
640
|
-
else:
|
|
641
|
-
tmp = self.healpix_layer(
|
|
642
|
-
im, ww, indices=indices, weights=weights, axis=axis
|
|
643
|
-
)
|
|
644
|
-
|
|
645
|
-
return self.up_grade(tmp, 2 * nside)
|
|
646
|
-
|
|
647
506
|
# ---------------------------------------------−---------
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
if im.shape[1 + axis] != ww.shape[1]:
|
|
654
|
-
if not self.silent:
|
|
655
|
-
print("Weights channels should be equal to the input image channels")
|
|
656
|
-
return -1
|
|
657
|
-
|
|
507
|
+
def healpix_layer(self, im, ww, indices=None, weights=None):
|
|
508
|
+
#ww [N_i,NORIENT,KERNELSZ*KERNELSZ//2,N_o,NORIENT]
|
|
509
|
+
#im [N_batch,N_i, NORIENT,N]
|
|
510
|
+
nside=int(np.sqrt(im.shape[-1]//12))
|
|
658
511
|
if indices is None:
|
|
659
|
-
if self.
|
|
660
|
-
self.
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
if axis == 1:
|
|
674
|
-
results = []
|
|
675
|
-
|
|
676
|
-
for k in range(im.shape[0]):
|
|
677
|
-
|
|
678
|
-
tmp = self.backend.bk_sparse_dense_matmul(mat, im[k])
|
|
679
|
-
|
|
680
|
-
density = self.backend.bk_reshape(
|
|
681
|
-
tmp, [12 * nside * nside, l_kernel * im.shape[1 + axis]]
|
|
682
|
-
)
|
|
683
|
-
|
|
684
|
-
density = self.backend.bk_matmul(
|
|
685
|
-
density,
|
|
686
|
-
self.backend.bk_reshape(
|
|
687
|
-
ww, [l_kernel * im.shape[1 + axis], ww.shape[2]]
|
|
688
|
-
),
|
|
689
|
-
)
|
|
690
|
-
|
|
691
|
-
results.append(
|
|
692
|
-
self.backend.bk_reshape(density, [12 * nside**2, ww.shape[2]])
|
|
693
|
-
)
|
|
694
|
-
|
|
695
|
-
return self.backend.bk_stack(results, axis=0)
|
|
696
|
-
else:
|
|
697
|
-
tmp = self.backend.bk_sparse_dense_matmul(mat, im)
|
|
698
|
-
|
|
699
|
-
density = self.backend.bk_reshape(
|
|
700
|
-
tmp, [12 * nside * nside, l_kernel * im.shape[1]]
|
|
701
|
-
)
|
|
702
|
-
|
|
703
|
-
return self.backend.bk_matmul(
|
|
704
|
-
density,
|
|
705
|
-
self.backend.bk_reshape(ww, [l_kernel * im.shape[1], ww.shape[2]]),
|
|
706
|
-
)
|
|
512
|
+
if (nside,self.NORIENT,self.KERNELSZ) not in self.ww_CNN:
|
|
513
|
+
self.init_index_cnn(nside,self.NORIENT)
|
|
514
|
+
indices = self.Idx_CNN[(nside,self.NORIENT,self.KERNELSZ)]
|
|
515
|
+
mat = self.Idx_WCNN[(nside,self.NORIENT,self.KERNELSZ)]
|
|
516
|
+
|
|
517
|
+
wim = self.backend.bk_gather(im,indices.flatten(),axis=3) #[N_batch,N_i,NORIENT,K*(K+1),N_o,NORIENT,N,N_w]
|
|
518
|
+
|
|
519
|
+
wim = self.backend.bk_reshape(wim,[im.shape[0],im.shape[1],im.shape[2]]+list(indices.shape))*mat[None,...]
|
|
520
|
+
#win is [N_batch,N_i, NORIENT,K*(K+1),1, NORIENT,N,N_w]
|
|
521
|
+
#ww is [1, N_i, NORIENT,K*(K+1),N_o,NORIENT]
|
|
522
|
+
wim = self.backend.bk_reduce_sum(wim[:,:,:,:,None]*ww[None,:,:,:,:,:,None,None],[1,2,3])
|
|
523
|
+
|
|
524
|
+
wim = self.backend.bk_reduce_sum(wim,-1)
|
|
525
|
+
return self.backend.bk_reshape(wim,[im.shape[0],ww.shape[3],ww.shape[4],im.shape[-1]])
|
|
707
526
|
|
|
708
527
|
# ---------------------------------------------−---------
|
|
709
528
|
|
|
@@ -1775,6 +1594,232 @@ class FoCUS:
|
|
|
1775
1594
|
|
|
1776
1595
|
return wr, wi, ws, tmp
|
|
1777
1596
|
|
|
1597
|
+
|
|
1598
|
+
# ---------------------------------------------−---------
|
|
1599
|
+
def init_index_cnn(self, nside, NORIENT=4,kernel=-1, cell_ids=None):
|
|
1600
|
+
|
|
1601
|
+
if kernel == -1:
|
|
1602
|
+
l_kernel = self.KERNELSZ
|
|
1603
|
+
else:
|
|
1604
|
+
l_kernel = kernel
|
|
1605
|
+
|
|
1606
|
+
if cell_ids is not None:
|
|
1607
|
+
ncell = cell_ids.shape[0]
|
|
1608
|
+
else:
|
|
1609
|
+
ncell = 12 * nside * nside
|
|
1610
|
+
|
|
1611
|
+
try:
|
|
1612
|
+
|
|
1613
|
+
if cell_ids is not None:
|
|
1614
|
+
tmp = np.load(
|
|
1615
|
+
"%s/XXXX_%s_W%d_%d_%d_PIDX.npy" # can not work
|
|
1616
|
+
% (
|
|
1617
|
+
self.TEMPLATE_PATH,
|
|
1618
|
+
TMPFILE_VERSION,
|
|
1619
|
+
l_kernel**2,
|
|
1620
|
+
NORIENT,
|
|
1621
|
+
nside, # if cell_ids computes the index
|
|
1622
|
+
)
|
|
1623
|
+
)
|
|
1624
|
+
|
|
1625
|
+
else:
|
|
1626
|
+
tmp = np.load(
|
|
1627
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1628
|
+
% (
|
|
1629
|
+
self.TEMPLATE_PATH,
|
|
1630
|
+
TMPFILE_VERSION,
|
|
1631
|
+
l_kernel**2,
|
|
1632
|
+
NORIENT,
|
|
1633
|
+
nside, # if cell_ids computes the index
|
|
1634
|
+
)
|
|
1635
|
+
)
|
|
1636
|
+
except:
|
|
1637
|
+
|
|
1638
|
+
pw = 8.0
|
|
1639
|
+
pw2 = 1.0
|
|
1640
|
+
threshold = 1e-3
|
|
1641
|
+
|
|
1642
|
+
if l_kernel == 5:
|
|
1643
|
+
pw = 8.0
|
|
1644
|
+
pw2 = 0.5
|
|
1645
|
+
threshold = 2e-4
|
|
1646
|
+
|
|
1647
|
+
elif l_kernel == 3:
|
|
1648
|
+
pw = 8.0
|
|
1649
|
+
pw2 = 1.0
|
|
1650
|
+
threshold = 1e-3
|
|
1651
|
+
|
|
1652
|
+
elif l_kernel == 7:
|
|
1653
|
+
pw = 8.0
|
|
1654
|
+
pw2 = 0.25
|
|
1655
|
+
threshold = 4e-5
|
|
1656
|
+
|
|
1657
|
+
n_weights = self.KERNELSZ*(self.KERNELSZ//2+1)
|
|
1658
|
+
|
|
1659
|
+
if cell_ids is not None:
|
|
1660
|
+
if not isinstance(cell_ids, np.ndarray):
|
|
1661
|
+
cell_ids = self.backend.to_numpy(cell_ids)
|
|
1662
|
+
th, ph = hp.pix2ang(nside, cell_ids, nest=True)
|
|
1663
|
+
x, y, z = hp.pix2vec(nside, cell_ids, nest=True)
|
|
1664
|
+
|
|
1665
|
+
t, p = hp.pix2ang(nside, cell_ids, nest=True)
|
|
1666
|
+
phi = [p[k] / np.pi * 180 for k in range(ncell)]
|
|
1667
|
+
thi = [t[k] / np.pi * 180 for k in range(ncell)]
|
|
1668
|
+
|
|
1669
|
+
indice = np.zeros([n_weights, NORIENT, ncell,4], dtype="int")
|
|
1670
|
+
|
|
1671
|
+
wav = np.zeros([n_weights, NORIENT, ncell,4], dtype="float")
|
|
1672
|
+
|
|
1673
|
+
else:
|
|
1674
|
+
|
|
1675
|
+
th, ph = hp.pix2ang(nside, np.arange(12 * nside**2), nest=True)
|
|
1676
|
+
x, y, z = hp.pix2vec(nside, np.arange(12 * nside**2), nest=True)
|
|
1677
|
+
|
|
1678
|
+
t, p = hp.pix2ang(nside, np.arange(12 * nside * nside), nest=True)
|
|
1679
|
+
phi = [p[k] / np.pi * 180 for k in range(12 * nside * nside)]
|
|
1680
|
+
thi = [t[k] / np.pi * 180 for k in range(12 * nside * nside)]
|
|
1681
|
+
|
|
1682
|
+
indice = np.zeros(
|
|
1683
|
+
[n_weights, NORIENT, 12 * nside * nside,4], dtype="int"
|
|
1684
|
+
)
|
|
1685
|
+
wav = np.zeros(
|
|
1686
|
+
[n_weights, NORIENT, 12 * nside * nside,4], dtype="float"
|
|
1687
|
+
)
|
|
1688
|
+
iv = 0
|
|
1689
|
+
iv2 = 0
|
|
1690
|
+
|
|
1691
|
+
for iii in range(ncell):
|
|
1692
|
+
if cell_ids is None:
|
|
1693
|
+
if iii % (nside * nside) == nside * nside - 1:
|
|
1694
|
+
if not self.silent:
|
|
1695
|
+
print(
|
|
1696
|
+
"Pre-compute nside=%6d %.2f%%"
|
|
1697
|
+
% (nside, 100 * iii / (12 * nside * nside))
|
|
1698
|
+
)
|
|
1699
|
+
|
|
1700
|
+
if cell_ids is not None:
|
|
1701
|
+
hidx = np.where(
|
|
1702
|
+
(x - x[iii]) ** 2 + (y - y[iii]) ** 2 + (z - z[iii]) ** 2
|
|
1703
|
+
< (2 * np.pi / nside) ** 2
|
|
1704
|
+
)[0]
|
|
1705
|
+
else:
|
|
1706
|
+
hidx = hp.query_disc(
|
|
1707
|
+
nside,
|
|
1708
|
+
[x[iii], y[iii], z[iii]],
|
|
1709
|
+
2 * np.pi / nside,
|
|
1710
|
+
nest=True,
|
|
1711
|
+
)
|
|
1712
|
+
|
|
1713
|
+
R = hp.Rotator(rot=[phi[iii], -thi[iii]], eulertype="ZYZ")
|
|
1714
|
+
|
|
1715
|
+
t2, p2 = R(th[hidx], ph[hidx])
|
|
1716
|
+
|
|
1717
|
+
vec2 = hp.ang2vec(t2, p2)
|
|
1718
|
+
|
|
1719
|
+
x2 = vec2[:, 0]
|
|
1720
|
+
y2 = vec2[:, 1]
|
|
1721
|
+
z2 = vec2[:, 2]
|
|
1722
|
+
|
|
1723
|
+
for l_rotation in range(NORIENT):
|
|
1724
|
+
|
|
1725
|
+
angle = (
|
|
1726
|
+
l_rotation / 4.0 * np.pi
|
|
1727
|
+
- phi[iii] / 180.0 * np.pi * (z[hidx] > 0)
|
|
1728
|
+
- (180.0 - phi[iii]) / 180.0 * np.pi * (z[hidx] < 0)
|
|
1729
|
+
)
|
|
1730
|
+
|
|
1731
|
+
|
|
1732
|
+
axes = y2 * np.cos(angle) - x2 * np.sin(angle)
|
|
1733
|
+
axes2 = -y2 * np.sin(angle) - x2 * np.cos(angle)
|
|
1734
|
+
|
|
1735
|
+
for k_weights in range(self.KERNELSZ//2+1):
|
|
1736
|
+
for l_weights in range(self.KERNELSZ):
|
|
1737
|
+
|
|
1738
|
+
val=np.exp(-(pw*(axes2*(nside)-(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))+ \
|
|
1739
|
+
np.exp(-(pw*(axes2*(nside)+(k_weights-self.KERNELSZ//2))**2+pw*(axes*(nside)-(l_weights-self.KERNELSZ//2))**2))
|
|
1740
|
+
|
|
1741
|
+
idx = np.argsort(-val)
|
|
1742
|
+
idx = idx[0:4]
|
|
1743
|
+
|
|
1744
|
+
nval = len(idx)
|
|
1745
|
+
val=val[idx]
|
|
1746
|
+
|
|
1747
|
+
r = abs(val).sum()
|
|
1748
|
+
|
|
1749
|
+
if r > 0:
|
|
1750
|
+
val = val / r
|
|
1751
|
+
|
|
1752
|
+
indice[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = hidx[idx]
|
|
1753
|
+
wav[k_weights*self.KERNELSZ+l_weights,l_rotation,iii,:] = val
|
|
1754
|
+
|
|
1755
|
+
if not self.silent:
|
|
1756
|
+
print("Kernel Size ", iv / (NORIENT * 12 * nside * nside))
|
|
1757
|
+
|
|
1758
|
+
if cell_ids is None:
|
|
1759
|
+
if not self.silent:
|
|
1760
|
+
print(
|
|
1761
|
+
"Write FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1762
|
+
% (TMPFILE_VERSION, self.KERNELSZ**2, NORIENT, nside)
|
|
1763
|
+
)
|
|
1764
|
+
np.save(
|
|
1765
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1766
|
+
% (
|
|
1767
|
+
self.TEMPLATE_PATH,
|
|
1768
|
+
TMPFILE_VERSION,
|
|
1769
|
+
self.KERNELSZ**2,
|
|
1770
|
+
NORIENT,
|
|
1771
|
+
nside,
|
|
1772
|
+
),
|
|
1773
|
+
indice,
|
|
1774
|
+
)
|
|
1775
|
+
np.save(
|
|
1776
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
|
|
1777
|
+
% (
|
|
1778
|
+
self.TEMPLATE_PATH,
|
|
1779
|
+
TMPFILE_VERSION,
|
|
1780
|
+
self.KERNELSZ**2,
|
|
1781
|
+
NORIENT,
|
|
1782
|
+
nside,
|
|
1783
|
+
),
|
|
1784
|
+
wav,
|
|
1785
|
+
)
|
|
1786
|
+
|
|
1787
|
+
if cell_ids is None:
|
|
1788
|
+
self.barrier()
|
|
1789
|
+
if self.use_2D:
|
|
1790
|
+
tmp = np.load(
|
|
1791
|
+
"%s/W%d_%s_%d_IDX.npy"
|
|
1792
|
+
% (self.TEMPLATE_PATH, l_kernel**2, TMPFILE_VERSION, nside)
|
|
1793
|
+
)
|
|
1794
|
+
else:
|
|
1795
|
+
tmp = np.load(
|
|
1796
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_PIDX.npy"
|
|
1797
|
+
% (
|
|
1798
|
+
self.TEMPLATE_PATH,
|
|
1799
|
+
TMPFILE_VERSION,
|
|
1800
|
+
self.KERNELSZ**2,
|
|
1801
|
+
NORIENT,
|
|
1802
|
+
nside,
|
|
1803
|
+
)
|
|
1804
|
+
)
|
|
1805
|
+
wav = np.load(
|
|
1806
|
+
"%s/CNN_FOSCAT_%s_W%d_%d_%d_WAVE.npy"
|
|
1807
|
+
% (
|
|
1808
|
+
self.TEMPLATE_PATH,
|
|
1809
|
+
TMPFILE_VERSION,
|
|
1810
|
+
self.KERNELSZ**2,
|
|
1811
|
+
NORIENT,
|
|
1812
|
+
nside,
|
|
1813
|
+
)
|
|
1814
|
+
)
|
|
1815
|
+
else:
|
|
1816
|
+
tmp = indice
|
|
1817
|
+
|
|
1818
|
+
self.Idx_CNN[(nside,NORIENT,self.KERNELSZ)] = tmp
|
|
1819
|
+
self.Idx_WCNN[(nside,NORIENT,self.KERNELSZ)] = self.backend.bk_cast(wav)
|
|
1820
|
+
|
|
1821
|
+
return wav, tmp
|
|
1822
|
+
|
|
1778
1823
|
# ---------------------------------------------−---------
|
|
1779
1824
|
# convert swap axes tensor x [....,a,....,b,....] to [....,b,....,a,....]
|
|
1780
1825
|
def swapaxes(self, x, axis1, axis2):
|