foscat 3.8.2__py3-none-any.whl → 2025.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkBase.py +36 -35
- foscat/BkNumpy.py +53 -62
- foscat/BkTensorflow.py +87 -88
- foscat/BkTorch.py +159 -72
- foscat/FoCUS.py +228 -89
- foscat/Synthesis.py +3 -3
- foscat/alm.py +188 -170
- foscat/backend.py +84 -70
- foscat/scat_cov.py +2138 -2220
- foscat/scat_cov2D.py +146 -53
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info}/METADATA +3 -2
- foscat-2025.3.0.dist-info/RECORD +30 -0
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info}/WHEEL +1 -1
- foscat-3.8.2.dist-info/RECORD +0 -30
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info/licenses}/LICENSE +0 -0
- {foscat-3.8.2.dist-info → foscat-2025.3.0.dist-info}/top_level.txt +0 -0
foscat/FoCUS.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import os
|
|
2
|
-
import os
|
|
3
2
|
import sys
|
|
4
3
|
|
|
5
4
|
import healpy as hp
|
|
@@ -11,32 +10,32 @@ TMPFILE_VERSION = "V4_0"
|
|
|
11
10
|
|
|
12
11
|
class FoCUS:
|
|
13
12
|
def __init__(
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
13
|
+
self,
|
|
14
|
+
NORIENT=4,
|
|
15
|
+
LAMBDA=1.2,
|
|
16
|
+
KERNELSZ=3,
|
|
17
|
+
slope=1.0,
|
|
18
|
+
all_type="float32",
|
|
19
|
+
nstep_max=16,
|
|
20
|
+
padding="SAME",
|
|
21
|
+
gpupos=0,
|
|
22
|
+
mask_thres=None,
|
|
23
|
+
mask_norm=False,
|
|
24
|
+
isMPI=False,
|
|
25
|
+
TEMPLATE_PATH="data",
|
|
26
|
+
BACKEND="tensorflow",
|
|
27
|
+
use_2D=False,
|
|
28
|
+
use_1D=False,
|
|
29
|
+
return_data=False,
|
|
30
|
+
JmaxDelta=0,
|
|
31
|
+
DODIV=False,
|
|
32
|
+
InitWave=None,
|
|
33
|
+
silent=True,
|
|
34
|
+
mpi_size=1,
|
|
35
|
+
mpi_rank=0,
|
|
37
36
|
):
|
|
38
37
|
|
|
39
|
-
self.__version__ = "
|
|
38
|
+
self.__version__ = "2025.03.0"
|
|
40
39
|
# P00 coeff for normalization for scat_cov
|
|
41
40
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
42
41
|
self.P1_dic = None
|
|
@@ -45,12 +44,18 @@ class FoCUS:
|
|
|
45
44
|
self.mask_thres = mask_thres
|
|
46
45
|
self.mask_norm = mask_norm
|
|
47
46
|
self.InitWave = InitWave
|
|
48
|
-
self.mask_mask=None
|
|
47
|
+
self.mask_mask = None
|
|
49
48
|
self.mpi_size = mpi_size
|
|
50
49
|
self.mpi_rank = mpi_rank
|
|
51
50
|
self.return_data = return_data
|
|
52
51
|
self.silent = silent
|
|
53
52
|
|
|
53
|
+
self.kernel_smooth = {}
|
|
54
|
+
self.padding_smooth = {}
|
|
55
|
+
self.kernelR_conv = {}
|
|
56
|
+
self.kernelI_conv = {}
|
|
57
|
+
self.padding_conv = {}
|
|
58
|
+
|
|
54
59
|
if not self.silent:
|
|
55
60
|
print("================================================")
|
|
56
61
|
print(" START FOSCAT CONFIGURATION")
|
|
@@ -69,10 +74,7 @@ class FoCUS:
|
|
|
69
74
|
if not self.silent:
|
|
70
75
|
print("The directory %s is created")
|
|
71
76
|
except:
|
|
72
|
-
|
|
73
|
-
print(
|
|
74
|
-
"Impossible to create the directory %s" % (self.TEMPLATE_PATH)
|
|
75
|
-
)
|
|
77
|
+
print("Impossible to create the directory %s" % (self.TEMPLATE_PATH))
|
|
76
78
|
return None
|
|
77
79
|
|
|
78
80
|
self.number_of_loss = 0
|
|
@@ -82,10 +84,9 @@ class FoCUS:
|
|
|
82
84
|
self.padding = padding
|
|
83
85
|
|
|
84
86
|
if JmaxDelta != 0:
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
)
|
|
87
|
+
print(
|
|
88
|
+
"OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
|
|
89
|
+
)
|
|
89
90
|
return None
|
|
90
91
|
|
|
91
92
|
self.OSTEP = JmaxDelta
|
|
@@ -105,31 +106,34 @@ class FoCUS:
|
|
|
105
106
|
|
|
106
107
|
self.all_type = all_type
|
|
107
108
|
self.BACKEND = BACKEND
|
|
108
|
-
|
|
109
|
-
if BACKEND==
|
|
109
|
+
|
|
110
|
+
if BACKEND == "torch":
|
|
110
111
|
from foscat.BkTorch import BkTorch
|
|
112
|
+
|
|
111
113
|
self.backend = BkTorch(
|
|
112
114
|
all_type=all_type,
|
|
113
115
|
mpi_rank=mpi_rank,
|
|
114
116
|
gpupos=gpupos,
|
|
115
117
|
silent=self.silent,
|
|
116
|
-
|
|
117
|
-
elif BACKEND==
|
|
118
|
+
)
|
|
119
|
+
elif BACKEND == "tensorflow":
|
|
118
120
|
from foscat.BkTensorflow import BkTensorflow
|
|
121
|
+
|
|
119
122
|
self.backend = BkTensorflow(
|
|
120
123
|
all_type=all_type,
|
|
121
124
|
mpi_rank=mpi_rank,
|
|
122
125
|
gpupos=gpupos,
|
|
123
126
|
silent=self.silent,
|
|
124
|
-
|
|
127
|
+
)
|
|
125
128
|
else:
|
|
126
129
|
from foscat.BkNumpy import BkNumpy
|
|
130
|
+
|
|
127
131
|
self.backend = BkNumpy(
|
|
128
132
|
all_type=all_type,
|
|
129
133
|
mpi_rank=mpi_rank,
|
|
130
134
|
gpupos=gpupos,
|
|
131
135
|
silent=self.silent,
|
|
132
|
-
|
|
136
|
+
)
|
|
133
137
|
|
|
134
138
|
self.all_bk_type = self.backend.all_bk_type
|
|
135
139
|
self.all_cbk_type = self.backend.all_cbk_type
|
|
@@ -172,9 +176,9 @@ class FoCUS:
|
|
|
172
176
|
self.Y_CNN = {}
|
|
173
177
|
self.Z_CNN = {}
|
|
174
178
|
|
|
175
|
-
self.filters_set={}
|
|
176
|
-
self.edge_masks={}
|
|
177
|
-
|
|
179
|
+
self.filters_set = {}
|
|
180
|
+
self.edge_masks = {}
|
|
181
|
+
|
|
178
182
|
wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
|
|
179
183
|
wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
|
|
180
184
|
|
|
@@ -209,15 +213,27 @@ class FoCUS:
|
|
|
209
213
|
w_smooth = w_smooth.flatten()
|
|
210
214
|
else:
|
|
211
215
|
for i in range(NORIENT):
|
|
212
|
-
a = (
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
+
a = (
|
|
217
|
+
(NORIENT - 1 - i) / float(NORIENT) * np.pi
|
|
218
|
+
) # get the same angle number than scattering lib
|
|
219
|
+
if KERNELSZ < 5:
|
|
220
|
+
xx = (
|
|
221
|
+
(3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
|
|
222
|
+
)
|
|
223
|
+
yy = (
|
|
224
|
+
(3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
|
|
225
|
+
)
|
|
216
226
|
else:
|
|
217
|
-
xx = (3 /5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
|
|
218
|
-
yy = (3 /5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
|
|
227
|
+
xx = (3 / 5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
|
|
228
|
+
yy = (3 / 5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
|
|
219
229
|
if KERNELSZ == 5:
|
|
220
|
-
w_smooth=np.exp(
|
|
230
|
+
w_smooth = np.exp(
|
|
231
|
+
-2
|
|
232
|
+
* (
|
|
233
|
+
(3.0 / float(KERNELSZ) * xx) ** 2
|
|
234
|
+
+ (3.0 / float(KERNELSZ) * yy) ** 2
|
|
235
|
+
)
|
|
236
|
+
)
|
|
221
237
|
else:
|
|
222
238
|
w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
|
|
223
239
|
tmp1 = np.cos(yy * np.pi) * w_smooth
|
|
@@ -225,7 +241,7 @@ class FoCUS:
|
|
|
225
241
|
|
|
226
242
|
wwc[:, i] = tmp1.flatten() - tmp1.mean()
|
|
227
243
|
wws[:, i] = tmp2.flatten() - tmp2.mean()
|
|
228
|
-
#sigma = np.sqrt((wwc[:, i] ** 2).mean())
|
|
244
|
+
# sigma = np.sqrt((wwc[:, i] ** 2).mean())
|
|
229
245
|
sigma = np.mean(w_smooth)
|
|
230
246
|
wwc[:, i] /= sigma
|
|
231
247
|
wws[:, i] /= sigma
|
|
@@ -239,7 +255,7 @@ class FoCUS:
|
|
|
239
255
|
|
|
240
256
|
wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
|
|
241
257
|
wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
|
|
242
|
-
#sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
|
|
258
|
+
# sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
|
|
243
259
|
sigma = np.mean(w_smooth)
|
|
244
260
|
|
|
245
261
|
wwc[:, NORIENT] /= sigma
|
|
@@ -249,13 +265,13 @@ class FoCUS:
|
|
|
249
265
|
|
|
250
266
|
wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
|
|
251
267
|
wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
|
|
252
|
-
#sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
|
|
268
|
+
# sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
|
|
253
269
|
sigma = np.mean(w_smooth)
|
|
254
270
|
wwc[:, NORIENT + 1] /= sigma
|
|
255
271
|
wws[:, NORIENT + 1] /= sigma
|
|
256
272
|
|
|
257
273
|
w_smooth = w_smooth.flatten()
|
|
258
|
-
|
|
274
|
+
|
|
259
275
|
if self.use_1D:
|
|
260
276
|
KERNELSZ = 5
|
|
261
277
|
|
|
@@ -723,19 +739,19 @@ class FoCUS:
|
|
|
723
739
|
def ud_grade(self, im, j, axis=0):
|
|
724
740
|
rim = im
|
|
725
741
|
for k in range(j):
|
|
726
|
-
#rim = self.smooth(rim, axis=axis)
|
|
742
|
+
# rim = self.smooth(rim, axis=axis)
|
|
727
743
|
rim = self.ud_grade_2(rim, axis=axis)
|
|
728
744
|
return rim
|
|
729
745
|
|
|
730
746
|
# --------------------------------------------------------
|
|
731
|
-
def ud_grade_2(self, im, axis=0):
|
|
747
|
+
def ud_grade_2(self, im, axis=0, cell_ids=None, nside=None):
|
|
732
748
|
|
|
733
749
|
if self.use_2D:
|
|
734
750
|
ishape = list(im.shape)
|
|
735
751
|
if len(ishape) < axis + 2:
|
|
736
752
|
if not self.silent:
|
|
737
753
|
print("Use of 2D scat with data that has less than 2D")
|
|
738
|
-
return None
|
|
754
|
+
return None, None
|
|
739
755
|
|
|
740
756
|
npix = im.shape[axis]
|
|
741
757
|
npiy = im.shape[axis + 1]
|
|
@@ -760,29 +776,40 @@ class FoCUS:
|
|
|
760
776
|
|
|
761
777
|
if axis == 0:
|
|
762
778
|
if len(ishape) == 2:
|
|
763
|
-
return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
|
|
779
|
+
return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
|
|
764
780
|
else:
|
|
765
|
-
return
|
|
766
|
-
|
|
781
|
+
return (
|
|
782
|
+
self.backend.bk_reshape(
|
|
783
|
+
res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
|
|
784
|
+
),
|
|
785
|
+
None,
|
|
767
786
|
)
|
|
768
787
|
else:
|
|
769
788
|
if len(ishape) == axis + 2:
|
|
770
|
-
return
|
|
771
|
-
|
|
789
|
+
return (
|
|
790
|
+
self.backend.bk_reshape(
|
|
791
|
+
res, ishape[0:axis] + [npix // 2, npiy // 2]
|
|
792
|
+
),
|
|
793
|
+
None,
|
|
772
794
|
)
|
|
773
795
|
else:
|
|
774
|
-
return
|
|
775
|
-
|
|
776
|
-
|
|
796
|
+
return (
|
|
797
|
+
self.backend.bk_reshape(
|
|
798
|
+
res,
|
|
799
|
+
ishape[0:axis]
|
|
800
|
+
+ [npix // 2, npiy // 2]
|
|
801
|
+
+ ishape[axis + 2 :],
|
|
802
|
+
),
|
|
803
|
+
None,
|
|
777
804
|
)
|
|
778
805
|
|
|
779
|
-
return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
|
|
806
|
+
return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
|
|
780
807
|
elif self.use_1D:
|
|
781
808
|
ishape = list(im.shape)
|
|
782
809
|
if len(ishape) < axis + 1:
|
|
783
810
|
if not self.silent:
|
|
784
811
|
print("Use of 1D scat with data that has less than 1D")
|
|
785
|
-
return None
|
|
812
|
+
return None, None
|
|
786
813
|
|
|
787
814
|
npix = im.shape[axis]
|
|
788
815
|
odata = 1
|
|
@@ -805,23 +832,33 @@ class FoCUS:
|
|
|
805
832
|
|
|
806
833
|
if axis == 0:
|
|
807
834
|
if len(ishape) == 1:
|
|
808
|
-
return self.backend.bk_reshape(res, [npix // 2])
|
|
835
|
+
return self.backend.bk_reshape(res, [npix // 2]), None
|
|
809
836
|
else:
|
|
810
|
-
return
|
|
811
|
-
res, [npix // 2] + ishape[axis + 1 :]
|
|
837
|
+
return (
|
|
838
|
+
self.backend.bk_reshape(res, [npix // 2] + ishape[axis + 1 :]),
|
|
839
|
+
None,
|
|
812
840
|
)
|
|
813
841
|
else:
|
|
814
842
|
if len(ishape) == axis + 1:
|
|
815
|
-
return
|
|
843
|
+
return (
|
|
844
|
+
self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2]),
|
|
845
|
+
None,
|
|
846
|
+
)
|
|
816
847
|
else:
|
|
817
|
-
return
|
|
818
|
-
|
|
848
|
+
return (
|
|
849
|
+
self.backend.bk_reshape(
|
|
850
|
+
res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
|
|
851
|
+
),
|
|
852
|
+
None,
|
|
819
853
|
)
|
|
820
854
|
|
|
821
|
-
return self.backend.bk_reshape(res, [npix // 2])
|
|
855
|
+
return self.backend.bk_reshape(res, [npix // 2]), None
|
|
822
856
|
|
|
823
857
|
else:
|
|
824
858
|
shape = list(im.shape)
|
|
859
|
+
if cell_ids is not None:
|
|
860
|
+
sim, new_cell_ids = self.backend.binned_mean(im, cell_ids)
|
|
861
|
+
return sim, new_cell_ids
|
|
825
862
|
|
|
826
863
|
lout = int(np.sqrt(shape[axis] // 12))
|
|
827
864
|
if im.__class__ == np.zeros([0]).__class__:
|
|
@@ -840,8 +877,11 @@ class FoCUS:
|
|
|
840
877
|
if len(shape) > axis:
|
|
841
878
|
oshape = oshape + shape[axis + 1 :]
|
|
842
879
|
|
|
843
|
-
return
|
|
844
|
-
self.backend.
|
|
880
|
+
return (
|
|
881
|
+
self.backend.bk_reduce_mean(
|
|
882
|
+
self.backend.bk_reshape(im, oshape), axis=axis + 1
|
|
883
|
+
),
|
|
884
|
+
None,
|
|
845
885
|
)
|
|
846
886
|
|
|
847
887
|
# --------------------------------------------------------
|
|
@@ -1794,14 +1834,14 @@ class FoCUS:
|
|
|
1794
1834
|
if self.padding == "VALID":
|
|
1795
1835
|
l_mask = l_mask[
|
|
1796
1836
|
:,
|
|
1797
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1798
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1837
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
1838
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
1799
1839
|
]
|
|
1800
1840
|
if shape[axis] != l_mask.shape[1]:
|
|
1801
1841
|
l_mask = l_mask[
|
|
1802
1842
|
:,
|
|
1803
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1804
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1843
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
1844
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
1805
1845
|
]
|
|
1806
1846
|
|
|
1807
1847
|
ichannel = 1
|
|
@@ -1868,10 +1908,10 @@ class FoCUS:
|
|
|
1868
1908
|
l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
|
|
1869
1909
|
|
|
1870
1910
|
if self.use_2D:
|
|
1871
|
-
#if self.padding == "VALID":
|
|
1911
|
+
# if self.padding == "VALID":
|
|
1872
1912
|
mtmp = l_mask
|
|
1873
1913
|
vtmp = l_x
|
|
1874
|
-
#else:
|
|
1914
|
+
# else:
|
|
1875
1915
|
# mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
1876
1916
|
# vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
1877
1917
|
|
|
@@ -2125,7 +2165,7 @@ class FoCUS:
|
|
|
2125
2165
|
return self.backend.bk_reduce_sum(r)
|
|
2126
2166
|
|
|
2127
2167
|
# ---------------------------------------------−---------
|
|
2128
|
-
def convol(self, in_image, axis=0):
|
|
2168
|
+
def convol(self, in_image, axis=0, cell_ids=None, nside=None):
|
|
2129
2169
|
|
|
2130
2170
|
image = self.backend.bk_cast(in_image)
|
|
2131
2171
|
|
|
@@ -2290,6 +2330,61 @@ class FoCUS:
|
|
|
2290
2330
|
return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
|
|
2291
2331
|
|
|
2292
2332
|
else:
|
|
2333
|
+
ishape = list(image.shape)
|
|
2334
|
+
|
|
2335
|
+
if cell_ids is not None:
|
|
2336
|
+
if cell_ids.shape[0] not in self.padding_conv:
|
|
2337
|
+
import healpix_convolution as hc
|
|
2338
|
+
from xdggs.healpix import HealpixInfo
|
|
2339
|
+
|
|
2340
|
+
res = self.backend.bk_zeros(
|
|
2341
|
+
ishape + [self.NORIENT], dtype=self.backend.all_cbk_type
|
|
2342
|
+
)
|
|
2343
|
+
|
|
2344
|
+
grid_info = HealpixInfo(
|
|
2345
|
+
level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
|
|
2346
|
+
)
|
|
2347
|
+
|
|
2348
|
+
for k in range(self.NORIENT):
|
|
2349
|
+
kernelR, kernelI = hc.kernels.wavelet_kernel(
|
|
2350
|
+
cell_ids, grid_info=grid_info, orientation=k, is_torch=True
|
|
2351
|
+
)
|
|
2352
|
+
self.kernelR_conv[(cell_ids.shape[0], k)] = kernelR.to(
|
|
2353
|
+
self.backend.all_bk_type
|
|
2354
|
+
).to(image.device)
|
|
2355
|
+
self.kernelI_conv[(cell_ids.shape[0], k)] = kernelI.to(
|
|
2356
|
+
self.backend.all_bk_type
|
|
2357
|
+
).to(image.device)
|
|
2358
|
+
self.padding_conv[(cell_ids.shape[0], k)] = hc.pad(
|
|
2359
|
+
cell_ids,
|
|
2360
|
+
grid_info=grid_info,
|
|
2361
|
+
ring=5 // 2, # wavelet kernel_size=5 is hard coded
|
|
2362
|
+
mode="mean",
|
|
2363
|
+
constant_value=0,
|
|
2364
|
+
)
|
|
2365
|
+
|
|
2366
|
+
for k in range(self.NORIENT):
|
|
2367
|
+
|
|
2368
|
+
kernelR = self.kernelR_conv[(cell_ids.shape[0], k)]
|
|
2369
|
+
kernelI = self.kernelI_conv[(cell_ids.shape[0], k)]
|
|
2370
|
+
padding = self.padding_conv[(cell_ids.shape[0], k)]
|
|
2371
|
+
if len(ishape) == 2:
|
|
2372
|
+
for l in range(ishape[0]):
|
|
2373
|
+
padded_data = padding.apply(image[l], is_torch=True)
|
|
2374
|
+
res[l, :, k] = kernelR.matmul(
|
|
2375
|
+
padded_data
|
|
2376
|
+
) + 1j * kernelI.matmul(padded_data)
|
|
2377
|
+
else:
|
|
2378
|
+
for l in range(ishape[0]):
|
|
2379
|
+
for k2 in range(ishape[2]):
|
|
2380
|
+
padded_data = padding.apply(
|
|
2381
|
+
image[l, :, k2], is_torch=True
|
|
2382
|
+
)
|
|
2383
|
+
res[l, :, k2, k] = kernelR.matmul(
|
|
2384
|
+
padded_data
|
|
2385
|
+
) + 1j * kernelI.matmul(padded_data)
|
|
2386
|
+
return res
|
|
2387
|
+
|
|
2293
2388
|
nside = int(np.sqrt(image.shape[axis] // 12))
|
|
2294
2389
|
|
|
2295
2390
|
if self.Idx_Neighbours[nside] is None:
|
|
@@ -2306,7 +2401,6 @@ class FoCUS:
|
|
|
2306
2401
|
l_ww_real = self.ww_Real[nside]
|
|
2307
2402
|
l_ww_imag = self.ww_Imag[nside]
|
|
2308
2403
|
|
|
2309
|
-
ishape = list(image.shape)
|
|
2310
2404
|
odata = 1
|
|
2311
2405
|
for k in range(axis + 1, len(ishape)):
|
|
2312
2406
|
odata = odata * ishape[k]
|
|
@@ -2460,7 +2554,7 @@ class FoCUS:
|
|
|
2460
2554
|
return res
|
|
2461
2555
|
|
|
2462
2556
|
# ---------------------------------------------−---------
|
|
2463
|
-
def smooth(self, in_image, axis=0):
|
|
2557
|
+
def smooth(self, in_image, axis=0, cell_ids=None, nside=None):
|
|
2464
2558
|
|
|
2465
2559
|
image = self.backend.bk_cast(in_image)
|
|
2466
2560
|
|
|
@@ -2589,6 +2683,50 @@ class FoCUS:
|
|
|
2589
2683
|
return self.backend.bk_reshape(res, in_image.shape)
|
|
2590
2684
|
|
|
2591
2685
|
else:
|
|
2686
|
+
|
|
2687
|
+
ishape = list(image.shape)
|
|
2688
|
+
|
|
2689
|
+
if cell_ids is not None:
|
|
2690
|
+
if cell_ids.shape[0] not in self.padding_smooth:
|
|
2691
|
+
import healpix_convolution as hc
|
|
2692
|
+
from xdggs.healpix import HealpixInfo
|
|
2693
|
+
|
|
2694
|
+
grid_info = HealpixInfo(
|
|
2695
|
+
level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
|
|
2696
|
+
)
|
|
2697
|
+
|
|
2698
|
+
kernel = hc.kernels.wavelet_smooth_kernel(
|
|
2699
|
+
cell_ids, grid_info=grid_info, is_torch=True
|
|
2700
|
+
)
|
|
2701
|
+
|
|
2702
|
+
self.kernel_smooth[cell_ids.shape[0]] = kernel.to(
|
|
2703
|
+
self.backend.all_bk_type
|
|
2704
|
+
).to(image.device)
|
|
2705
|
+
|
|
2706
|
+
self.padding_smooth[cell_ids.shape[0]] = hc.pad(
|
|
2707
|
+
cell_ids,
|
|
2708
|
+
grid_info=grid_info,
|
|
2709
|
+
ring=5 // 2, # wavelet kernel_size=5 is hard coded
|
|
2710
|
+
mode="mean",
|
|
2711
|
+
constant_value=0,
|
|
2712
|
+
)
|
|
2713
|
+
|
|
2714
|
+
kernel = self.kernel_smooth[cell_ids.shape[0]]
|
|
2715
|
+
padding = self.padding_smooth[cell_ids.shape[0]]
|
|
2716
|
+
|
|
2717
|
+
res = self.backend.bk_zeros(ishape, dtype=self.backend.all_cbk_type)
|
|
2718
|
+
|
|
2719
|
+
if len(ishape) == 2:
|
|
2720
|
+
for l in range(ishape[0]):
|
|
2721
|
+
padded_data = padding.apply(image[l], is_torch=True)
|
|
2722
|
+
res[l] = kernel.matmul(padded_data)
|
|
2723
|
+
else:
|
|
2724
|
+
for l in range(ishape[0]):
|
|
2725
|
+
for k2 in range(ishape[2]):
|
|
2726
|
+
padded_data = padding.apply(image[l, :, k2], is_torch=True)
|
|
2727
|
+
res[l, :, k2] = kernel.matmul(padded_data)
|
|
2728
|
+
return res
|
|
2729
|
+
|
|
2592
2730
|
nside = int(np.sqrt(image.shape[axis] // 12))
|
|
2593
2731
|
|
|
2594
2732
|
if self.Idx_Neighbours[nside] is None:
|
|
@@ -2604,7 +2742,6 @@ class FoCUS:
|
|
|
2604
2742
|
self.w_smooth[nside] = ws
|
|
2605
2743
|
|
|
2606
2744
|
l_w_smooth = self.w_smooth[nside]
|
|
2607
|
-
ishape = list(image.shape)
|
|
2608
2745
|
|
|
2609
2746
|
odata = 1
|
|
2610
2747
|
for k in range(axis + 1, len(ishape)):
|
|
@@ -2707,9 +2844,11 @@ class FoCUS:
|
|
|
2707
2844
|
# ---------------------------------------------−---------
|
|
2708
2845
|
def get_ww(self, nside=1):
|
|
2709
2846
|
if self.use_2D:
|
|
2710
|
-
|
|
2711
|
-
return (
|
|
2712
|
-
|
|
2847
|
+
|
|
2848
|
+
return (
|
|
2849
|
+
self.ww_RealT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT),
|
|
2850
|
+
self.ww_ImagT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT),
|
|
2851
|
+
)
|
|
2713
2852
|
else:
|
|
2714
2853
|
return (self.ww_Real[nside], self.ww_Imag[nside])
|
|
2715
2854
|
|
foscat/Synthesis.py
CHANGED
|
@@ -240,9 +240,9 @@ class Synthesis:
|
|
|
240
240
|
grd_mask = self.grd_mask
|
|
241
241
|
|
|
242
242
|
if grd_mask is not None:
|
|
243
|
-
g_tot =
|
|
243
|
+
g_tot = self.operation.backend.to_numpy(g_tot * grd_mask)
|
|
244
244
|
else:
|
|
245
|
-
g_tot = self.to_numpy(g_tot)
|
|
245
|
+
g_tot = self.operation.backend.to_numpy(g_tot)
|
|
246
246
|
|
|
247
247
|
g_tot[np.isnan(g_tot)] = 0.0
|
|
248
248
|
|
|
@@ -426,7 +426,7 @@ class Synthesis:
|
|
|
426
426
|
factr=factr,
|
|
427
427
|
maxiter=maxitt,
|
|
428
428
|
)
|
|
429
|
-
print(
|
|
429
|
+
print("Final Loss ", loss)
|
|
430
430
|
# update bias input data
|
|
431
431
|
if iteration < NUM_STEP_BIAS - 1:
|
|
432
432
|
# if self.mpi_rank==0:
|