foscat 2025.7.2__py3-none-any.whl → 2025.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +34 -3
- foscat/CNN.py +1 -0
- foscat/FoCUS.py +387 -165
- foscat/HOrientedConvol.py +546 -0
- foscat/HealSpline.py +8 -5
- foscat/Synthesis.py +27 -18
- foscat/UNET.py +200 -0
- foscat/scat_cov.py +289 -178
- foscat/scat_cov_map2D.py +1 -1
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/METADATA +1 -1
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/RECORD +14 -12
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/WHEEL +0 -0
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.7.2.dist-info → foscat-2025.8.3.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -33,7 +33,14 @@ testwarn = 0
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class scat_cov:
|
|
36
|
-
def __init__(self,
|
|
36
|
+
def __init__(self,
|
|
37
|
+
s0, s2, s3, s4,
|
|
38
|
+
s1=None,
|
|
39
|
+
s3p=None,
|
|
40
|
+
backend=None,
|
|
41
|
+
use_1D=False,
|
|
42
|
+
return_data=False
|
|
43
|
+
):
|
|
37
44
|
self.S0 = s0
|
|
38
45
|
self.S2 = s2
|
|
39
46
|
self.S3 = s3
|
|
@@ -44,12 +51,13 @@ class scat_cov:
|
|
|
44
51
|
self.idx1 = None
|
|
45
52
|
self.idx2 = None
|
|
46
53
|
self.use_1D = use_1D
|
|
47
|
-
|
|
48
|
-
self.backend.bk_len(
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
54
|
+
if not return_data:
|
|
55
|
+
self.numel = self.backend.bk_len(s0)+ \
|
|
56
|
+
self.backend.bk_len(s1)+ \
|
|
57
|
+
self.backend.bk_len(s2)+ \
|
|
58
|
+
self.backend.bk_len(s3)+ \
|
|
59
|
+
self.backend.bk_len(s4)+ \
|
|
60
|
+
self.backend.bk_len(s3p)
|
|
53
61
|
|
|
54
62
|
def numpy(self):
|
|
55
63
|
if self.BACKEND == "numpy":
|
|
@@ -106,12 +114,13 @@ class scat_cov:
|
|
|
106
114
|
|
|
107
115
|
# ---------------------------------------------−---------
|
|
108
116
|
def flatten(self):
|
|
109
|
-
tmp = [
|
|
110
|
-
self.conv2complex(
|
|
111
|
-
self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
|
|
112
|
-
)
|
|
113
|
-
]
|
|
114
117
|
if self.use_1D:
|
|
118
|
+
tmp = [
|
|
119
|
+
self.conv2complex(
|
|
120
|
+
self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]])
|
|
121
|
+
)
|
|
122
|
+
]
|
|
123
|
+
|
|
115
124
|
if self.S1 is not None:
|
|
116
125
|
tmp = tmp + [
|
|
117
126
|
self.conv2complex(
|
|
@@ -156,6 +165,11 @@ class scat_cov:
|
|
|
156
165
|
|
|
157
166
|
return self.backend.bk_concat(tmp, 1)
|
|
158
167
|
|
|
168
|
+
tmp = [
|
|
169
|
+
self.conv2complex(
|
|
170
|
+
self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
|
|
171
|
+
)
|
|
172
|
+
]
|
|
159
173
|
if self.S1 is not None:
|
|
160
174
|
tmp = tmp + [
|
|
161
175
|
self.conv2complex(
|
|
@@ -2819,6 +2833,7 @@ class funct(FOC.FoCUS):
|
|
|
2819
2833
|
P1_dic = {}
|
|
2820
2834
|
if cross:
|
|
2821
2835
|
P2_dic = {}
|
|
2836
|
+
|
|
2822
2837
|
elif (norm == "auto") and (self.P1_dic is not None):
|
|
2823
2838
|
P1_dic = self.P1_dic
|
|
2824
2839
|
if cross:
|
|
@@ -3610,23 +3625,40 @@ class funct(FOC.FoCUS):
|
|
|
3610
3625
|
self.P2_dic = P2_dic
|
|
3611
3626
|
|
|
3612
3627
|
if not return_data:
|
|
3613
|
-
|
|
3614
|
-
|
|
3615
|
-
|
|
3616
|
-
|
|
3617
|
-
|
|
3618
|
-
S3P = self.backend.bk_concat(S3P, -3)
|
|
3619
|
-
if calc_var:
|
|
3620
|
-
VS1 = self.backend.bk_concat(VS1, -2)
|
|
3621
|
-
VS2 = self.backend.bk_concat(VS2, -2)
|
|
3622
|
-
VS3 = self.backend.bk_concat(VS3, -3)
|
|
3623
|
-
VS4 = self.backend.bk_concat(VS4, -4)
|
|
3628
|
+
if not self.use_1D:
|
|
3629
|
+
S1 = self.backend.bk_concat(S1, -2)
|
|
3630
|
+
S2 = self.backend.bk_concat(S2, -2)
|
|
3631
|
+
S3 = self.backend.bk_concat(S3, -3)
|
|
3632
|
+
S4 = self.backend.bk_concat(S4, -4)
|
|
3624
3633
|
if cross:
|
|
3625
|
-
|
|
3634
|
+
S3P = self.backend.bk_concat(S3P, -3)
|
|
3635
|
+
if calc_var:
|
|
3636
|
+
VS1 = self.backend.bk_concat(VS1, -2)
|
|
3637
|
+
VS2 = self.backend.bk_concat(VS2, -2)
|
|
3638
|
+
VS3 = self.backend.bk_concat(VS3, -3)
|
|
3639
|
+
VS4 = self.backend.bk_concat(VS4, -4)
|
|
3640
|
+
if cross:
|
|
3641
|
+
VS3P = self.backend.bk_concat(VS3P, -3)
|
|
3642
|
+
else:
|
|
3643
|
+
S1 = self.backend.bk_concat(S1, -1)
|
|
3644
|
+
S2 = self.backend.bk_concat(S2, -1)
|
|
3645
|
+
S3 = self.backend.bk_concat(S3, -1)
|
|
3646
|
+
S4 = self.backend.bk_concat(S4, -1)
|
|
3647
|
+
if cross:
|
|
3648
|
+
S3P = self.backend.bk_concat(S3P, -1)
|
|
3649
|
+
if calc_var:
|
|
3650
|
+
VS1 = self.backend.bk_concat(VS1, -1)
|
|
3651
|
+
VS2 = self.backend.bk_concat(VS2, -1)
|
|
3652
|
+
VS3 = self.backend.bk_concat(VS3, -1)
|
|
3653
|
+
VS4 = self.backend.bk_concat(VS4, -1)
|
|
3654
|
+
if cross:
|
|
3655
|
+
VS3P = self.backend.bk_concat(VS3P, -1)
|
|
3626
3656
|
if calc_var:
|
|
3627
3657
|
if not cross:
|
|
3628
3658
|
return scat_cov(
|
|
3629
|
-
s0, S2, S3, S4, s1=S1, backend=self.backend,
|
|
3659
|
+
s0, S2, S3, S4, s1=S1, backend=self.backend,
|
|
3660
|
+
use_1D=self.use_1D,
|
|
3661
|
+
return_data=self.return_data
|
|
3630
3662
|
), scat_cov(
|
|
3631
3663
|
vs0,
|
|
3632
3664
|
VS2,
|
|
@@ -3635,6 +3667,7 @@ class funct(FOC.FoCUS):
|
|
|
3635
3667
|
s1=VS1,
|
|
3636
3668
|
backend=self.backend,
|
|
3637
3669
|
use_1D=self.use_1D,
|
|
3670
|
+
return_data=self.return_data
|
|
3638
3671
|
)
|
|
3639
3672
|
else:
|
|
3640
3673
|
return scat_cov(
|
|
@@ -3646,6 +3679,7 @@ class funct(FOC.FoCUS):
|
|
|
3646
3679
|
s3p=S3P,
|
|
3647
3680
|
backend=self.backend,
|
|
3648
3681
|
use_1D=self.use_1D,
|
|
3682
|
+
return_data=self.return_data
|
|
3649
3683
|
), scat_cov(
|
|
3650
3684
|
vs0,
|
|
3651
3685
|
VS2,
|
|
@@ -3655,11 +3689,16 @@ class funct(FOC.FoCUS):
|
|
|
3655
3689
|
s3p=VS3P,
|
|
3656
3690
|
backend=self.backend,
|
|
3657
3691
|
use_1D=self.use_1D,
|
|
3692
|
+
return_data=self.return_data
|
|
3658
3693
|
)
|
|
3659
3694
|
else:
|
|
3660
3695
|
if not cross:
|
|
3661
3696
|
return scat_cov(
|
|
3662
|
-
s0, S2, S3, S4,
|
|
3697
|
+
s0, S2, S3, S4,
|
|
3698
|
+
s1=S1,
|
|
3699
|
+
backend=self.backend,
|
|
3700
|
+
use_1D=self.use_1D,
|
|
3701
|
+
return_data=self.return_data
|
|
3663
3702
|
)
|
|
3664
3703
|
else:
|
|
3665
3704
|
return scat_cov(
|
|
@@ -3671,6 +3710,7 @@ class funct(FOC.FoCUS):
|
|
|
3671
3710
|
s3p=S3P,
|
|
3672
3711
|
backend=self.backend,
|
|
3673
3712
|
use_1D=self.use_1D,
|
|
3713
|
+
return_data=self.return_data
|
|
3674
3714
|
)
|
|
3675
3715
|
|
|
3676
3716
|
def clean_norm(self):
|
|
@@ -3748,8 +3788,12 @@ class funct(FOC.FoCUS):
|
|
|
3748
3788
|
# cconv, sconv are [Nbatch, Norient3, Npix_j3]
|
|
3749
3789
|
if self.use_1D:
|
|
3750
3790
|
s3 = conv * self.backend.bk_conjugate(MconvPsi)
|
|
3791
|
+
elif self.use_2D:
|
|
3792
|
+
s3 = self.backend.bk_expand_dims(conv, -4)* self.backend.bk_conjugate(
|
|
3793
|
+
MconvPsi
|
|
3794
|
+
) # [Nbatch, Norient3, Norient2, Npix_j3]
|
|
3751
3795
|
else:
|
|
3752
|
-
s3 = self.backend.bk_expand_dims(conv, -3)
|
|
3796
|
+
s3 = self.backend.bk_expand_dims(conv, -3)* self.backend.bk_conjugate(
|
|
3753
3797
|
MconvPsi
|
|
3754
3798
|
) # [Nbatch, Norient3, Norient2, Npix_j3]
|
|
3755
3799
|
### Apply the mask [Nmask, Npix_j3] and sum over pixels
|
|
@@ -3816,62 +3860,121 @@ class funct(FOC.FoCUS):
|
|
|
3816
3860
|
Done by Sihao Cheng and Rudy Morel.
|
|
3817
3861
|
"""
|
|
3818
3862
|
|
|
3819
|
-
|
|
3863
|
+
if N!=0:
|
|
3864
|
+
filter = np.zeros([J, L, M, N], dtype="complex64")
|
|
3820
3865
|
|
|
3821
|
-
|
|
3866
|
+
slant = 4.0 / L
|
|
3822
3867
|
|
|
3823
|
-
|
|
3868
|
+
for j in range(J):
|
|
3824
3869
|
|
|
3825
|
-
|
|
3870
|
+
for ell in range(L):
|
|
3826
3871
|
|
|
3827
|
-
|
|
3828
|
-
|
|
3829
|
-
|
|
3872
|
+
theta = (int(L - L / 2 - 1) - ell) * np.pi / L
|
|
3873
|
+
sigma = 0.8 * 2**j
|
|
3874
|
+
xi = 3.0 / 4.0 * np.pi / 2**j
|
|
3830
3875
|
|
|
3831
|
-
|
|
3832
|
-
|
|
3833
|
-
|
|
3834
|
-
|
|
3835
|
-
|
|
3836
|
-
|
|
3837
|
-
|
|
3838
|
-
|
|
3839
|
-
|
|
3840
|
-
|
|
3841
|
-
|
|
3842
|
-
|
|
3843
|
-
|
|
3844
|
-
|
|
3845
|
-
|
|
3846
|
-
for ii, ex in enumerate([-1, 0]):
|
|
3847
|
-
for jj, ey in enumerate([-1, 0]):
|
|
3848
|
-
xx[ii, jj], yy[ii, jj] = np.mgrid[
|
|
3849
|
-
ex * M : M + ex * M, ey * N : N + ey * N
|
|
3850
|
-
]
|
|
3851
|
-
|
|
3852
|
-
arg = -(
|
|
3853
|
-
curv[0, 0] * xx * xx
|
|
3854
|
-
+ (curv[0, 1] + curv[1, 0]) * xx * yy
|
|
3855
|
-
+ curv[1, 1] * yy * yy
|
|
3856
|
-
)
|
|
3857
|
-
argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
|
|
3876
|
+
R = np.array(
|
|
3877
|
+
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
|
|
3878
|
+
np.float64,
|
|
3879
|
+
)
|
|
3880
|
+
R_inv = np.array(
|
|
3881
|
+
[[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
|
|
3882
|
+
np.float64,
|
|
3883
|
+
)
|
|
3884
|
+
D = np.array([[1, 0], [0, slant * slant]])
|
|
3885
|
+
curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
|
|
3886
|
+
|
|
3887
|
+
gab = np.zeros((M, N), np.complex128)
|
|
3888
|
+
xx = np.empty((2, 2, M, N))
|
|
3889
|
+
yy = np.empty((2, 2, M, N))
|
|
3858
3890
|
|
|
3859
|
-
|
|
3860
|
-
|
|
3891
|
+
for ii, ex in enumerate([-1, 0]):
|
|
3892
|
+
for jj, ey in enumerate([-1, 0]):
|
|
3893
|
+
xx[ii, jj], yy[ii, jj] = np.mgrid[
|
|
3894
|
+
ex * M : M + ex * M, ey * N : N + ey * N
|
|
3895
|
+
]
|
|
3896
|
+
|
|
3897
|
+
arg = -(
|
|
3898
|
+
curv[0, 0] * xx * xx
|
|
3899
|
+
+ (curv[0, 1] + curv[1, 0]) * xx * yy
|
|
3900
|
+
+ curv[1, 1] * yy * yy
|
|
3901
|
+
)
|
|
3902
|
+
argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
|
|
3903
|
+
|
|
3904
|
+
gabi = np.exp(argi).sum((0, 1))
|
|
3905
|
+
gab = np.exp(arg).sum((0, 1))
|
|
3906
|
+
|
|
3907
|
+
norm_factor = 2 * np.pi * sigma * sigma / slant
|
|
3908
|
+
|
|
3909
|
+
gab = gab / norm_factor
|
|
3910
|
+
|
|
3911
|
+
gabi = gabi / norm_factor
|
|
3912
|
+
|
|
3913
|
+
K = gabi.sum() / gab.sum()
|
|
3914
|
+
|
|
3915
|
+
# Apply the Gaussian
|
|
3916
|
+
filter[j, ell] = np.fft.fft2(gabi - K * gab)
|
|
3917
|
+
filter[j, ell, 0, 0] = 0.0
|
|
3861
3918
|
|
|
3862
|
-
|
|
3919
|
+
return self.backend.bk_cast(filter)
|
|
3920
|
+
else:
|
|
3921
|
+
filter = np.zeros([J, L, M], dtype="complex64")
|
|
3922
|
+
#TODO
|
|
3923
|
+
print('filter for 1D not yet available')
|
|
3924
|
+
exit(0)
|
|
3925
|
+
slant = 4.0 / L
|
|
3926
|
+
|
|
3927
|
+
for j in range(J):
|
|
3863
3928
|
|
|
3864
|
-
|
|
3929
|
+
for ell in range(L):
|
|
3865
3930
|
|
|
3866
|
-
|
|
3931
|
+
theta = (int(L - L / 2 - 1) - ell) * np.pi / L
|
|
3932
|
+
sigma = 0.8 * 2**j
|
|
3933
|
+
xi = 3.0 / 4.0 * np.pi / 2**j
|
|
3867
3934
|
|
|
3868
|
-
|
|
3935
|
+
R = np.array(
|
|
3936
|
+
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
|
|
3937
|
+
np.float64,
|
|
3938
|
+
)
|
|
3939
|
+
R_inv = np.array(
|
|
3940
|
+
[[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
|
|
3941
|
+
np.float64,
|
|
3942
|
+
)
|
|
3943
|
+
D = np.array([[1, 0], [0, slant * slant]])
|
|
3944
|
+
curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
|
|
3869
3945
|
|
|
3870
|
-
|
|
3871
|
-
|
|
3872
|
-
|
|
3946
|
+
gab = np.zeros((M), np.complex128)
|
|
3947
|
+
xx = np.empty((M))
|
|
3948
|
+
|
|
3949
|
+
for ii, ex in enumerate([-1, 0]):
|
|
3950
|
+
for jj, ey in enumerate([-1, 0]):
|
|
3951
|
+
xx[ii, jj], yy[ii, jj] = np.mgrid[
|
|
3952
|
+
ex * M : M + ex * M, ey * N : N + ey * N
|
|
3953
|
+
]
|
|
3954
|
+
|
|
3955
|
+
arg = -(
|
|
3956
|
+
curv[0, 0] * xx * xx
|
|
3957
|
+
+ (curv[0, 1] + curv[1, 0]) * xx * yy
|
|
3958
|
+
+ curv[1, 1] * yy * yy
|
|
3959
|
+
)
|
|
3960
|
+
argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
|
|
3873
3961
|
|
|
3874
|
-
|
|
3962
|
+
gabi = np.exp(argi).sum((0, 1))
|
|
3963
|
+
gab = np.exp(arg).sum((0, 1))
|
|
3964
|
+
|
|
3965
|
+
norm_factor = 2 * np.pi * sigma * sigma / slant
|
|
3966
|
+
|
|
3967
|
+
gab = gab / norm_factor
|
|
3968
|
+
|
|
3969
|
+
gabi = gabi / norm_factor
|
|
3970
|
+
|
|
3971
|
+
K = gabi.sum() / gab.sum()
|
|
3972
|
+
|
|
3973
|
+
# Apply the Gaussian
|
|
3974
|
+
filter[j, ell] = np.fft.fft2(gabi - K * gab)
|
|
3975
|
+
filter[j, ell, 0, 0] = 0.0
|
|
3976
|
+
|
|
3977
|
+
return self.backend.bk_cast(filter)
|
|
3875
3978
|
|
|
3876
3979
|
# ------------------------------------------------------------------------------------------
|
|
3877
3980
|
#
|
|
@@ -4107,18 +4210,24 @@ class funct(FOC.FoCUS):
|
|
|
4107
4210
|
if data2 is not None:
|
|
4108
4211
|
N_image2 = data2.shape[0]
|
|
4109
4212
|
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
|
|
4213
|
+
dim=(-2,-1)
|
|
4110
4214
|
elif self.use_1D:
|
|
4111
4215
|
if len(data.shape) == 2:
|
|
4112
4216
|
npix = int(im_shape[1]) # Number of pixels
|
|
4217
|
+
M = im_shape[1]
|
|
4218
|
+
N=0
|
|
4113
4219
|
N_image = 1
|
|
4114
4220
|
N_image2 = 1
|
|
4115
4221
|
else:
|
|
4116
4222
|
npix = int(im_shape[0]) # Number of pixels
|
|
4117
4223
|
N_image = data.shape[0]
|
|
4224
|
+
M = im_shape[0]
|
|
4225
|
+
N=0
|
|
4118
4226
|
if data2 is not None:
|
|
4119
4227
|
N_image2 = data2.shape[0]
|
|
4120
4228
|
|
|
4121
4229
|
nside = int(npix)
|
|
4230
|
+
dim=(-1)
|
|
4122
4231
|
|
|
4123
4232
|
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
|
|
4124
4233
|
else:
|
|
@@ -4169,10 +4278,10 @@ class funct(FOC.FoCUS):
|
|
|
4169
4278
|
|
|
4170
4279
|
# convert numpy array input into self.backend.bk_ tensors
|
|
4171
4280
|
data = self.backend.bk_cast(data)
|
|
4172
|
-
data_f = self.backend.bk_fftn(data, dim=
|
|
4281
|
+
data_f = self.backend.bk_fftn(data, dim=dim)
|
|
4173
4282
|
if data2 is not None:
|
|
4174
4283
|
data2 = self.backend.bk_cast(data2)
|
|
4175
|
-
data2_f = self.backend.bk_fftn(data2, dim=
|
|
4284
|
+
data2_f = self.backend.bk_fftn(data2, dim=dim)
|
|
4176
4285
|
|
|
4177
4286
|
# initialize tensors for scattering coefficients
|
|
4178
4287
|
S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
@@ -4248,13 +4357,13 @@ class funct(FOC.FoCUS):
|
|
|
4248
4357
|
I1 = self.backend.bk_ifftn(
|
|
4249
4358
|
data_f[None, None, None, :, :]
|
|
4250
4359
|
* filters_set[None, :J, :, :, :],
|
|
4251
|
-
dim=
|
|
4360
|
+
dim=dim,
|
|
4252
4361
|
).abs()
|
|
4253
4362
|
else:
|
|
4254
4363
|
I1 = self.backend.bk_ifftn(
|
|
4255
4364
|
data_f[:, None, None, :, :]
|
|
4256
4365
|
* filters_set[None, :J, :, :, :],
|
|
4257
|
-
dim=
|
|
4366
|
+
dim=dim,
|
|
4258
4367
|
).abs()
|
|
4259
4368
|
elif self.use_1D:
|
|
4260
4369
|
if len(data.shape) == 1:
|
|
@@ -4270,12 +4379,12 @@ class funct(FOC.FoCUS):
|
|
|
4270
4379
|
else:
|
|
4271
4380
|
print("todo")
|
|
4272
4381
|
|
|
4273
|
-
S2 = (I1**2 * edge_mask).mean(
|
|
4274
|
-
S1 = (I1 * edge_mask).mean(
|
|
4382
|
+
S2 = (I1**2 * edge_mask).mean(dim)
|
|
4383
|
+
S1 = (I1 * edge_mask).mean(dim)
|
|
4275
4384
|
|
|
4276
4385
|
if get_variance:
|
|
4277
|
-
S2_sigma = (I1**2 * edge_mask).std(
|
|
4278
|
-
S1_sigma = (I1 * edge_mask).std(
|
|
4386
|
+
S2_sigma = (I1**2 * edge_mask).std(dim)
|
|
4387
|
+
S1_sigma = (I1 * edge_mask).std(dim)
|
|
4279
4388
|
|
|
4280
4389
|
else:
|
|
4281
4390
|
if self.use_2D:
|
|
@@ -4283,64 +4392,64 @@ class funct(FOC.FoCUS):
|
|
|
4283
4392
|
I1 = self.backend.bk_ifftn(
|
|
4284
4393
|
data_f[None, None, None, :, :]
|
|
4285
4394
|
* filters_set[None, :J, :, :, :],
|
|
4286
|
-
dim=
|
|
4395
|
+
dim=dim,
|
|
4287
4396
|
)
|
|
4288
4397
|
I2 = self.backend.bk_ifftn(
|
|
4289
4398
|
data2_f[None, None, None, :, :]
|
|
4290
4399
|
* filters_set[None, :J, :, :, :],
|
|
4291
|
-
dim=
|
|
4400
|
+
dim=dim,
|
|
4292
4401
|
)
|
|
4293
4402
|
else:
|
|
4294
4403
|
I1 = self.backend.bk_ifftn(
|
|
4295
4404
|
data_f[:, None, None, :, :]
|
|
4296
4405
|
* filters_set[None, :J, :, :, :],
|
|
4297
|
-
dim=
|
|
4406
|
+
dim=dim,
|
|
4298
4407
|
)
|
|
4299
4408
|
I2 = self.backend.bk_ifftn(
|
|
4300
4409
|
data2_f[:, None, None, :, :]
|
|
4301
4410
|
* filters_set[None, :J, :, :, :],
|
|
4302
|
-
dim=
|
|
4411
|
+
dim=dim,
|
|
4303
4412
|
)
|
|
4304
4413
|
elif self.use_1D:
|
|
4305
4414
|
if len(data.shape) == 1:
|
|
4306
4415
|
I1 = self.backend.bk_ifftn(
|
|
4307
4416
|
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4308
|
-
dim=
|
|
4417
|
+
dim=dim,
|
|
4309
4418
|
)
|
|
4310
4419
|
I2 = self.backend.bk_ifftn(
|
|
4311
4420
|
data2_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4312
|
-
dim=
|
|
4421
|
+
dim=dim,
|
|
4313
4422
|
)
|
|
4314
4423
|
else:
|
|
4315
4424
|
I1 = self.backend.bk_ifftn(
|
|
4316
4425
|
data_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4317
|
-
dim=
|
|
4426
|
+
dim=dim,
|
|
4318
4427
|
)
|
|
4319
4428
|
I2 = self.backend.bk_ifftn(
|
|
4320
4429
|
data2_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4321
|
-
dim=
|
|
4430
|
+
dim=dim,
|
|
4322
4431
|
)
|
|
4323
4432
|
else:
|
|
4324
4433
|
print("todo")
|
|
4325
4434
|
|
|
4326
4435
|
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
|
|
4327
4436
|
|
|
4328
|
-
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=
|
|
4437
|
+
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
|
|
4329
4438
|
if get_variance:
|
|
4330
4439
|
S2_sigma = self.backend.bk_reduce_std(
|
|
4331
|
-
(I1 * edge_mask), axis=
|
|
4440
|
+
(I1 * edge_mask), axis=dim
|
|
4332
4441
|
)
|
|
4333
4442
|
|
|
4334
4443
|
I1 = self.backend.bk_L1(I1)
|
|
4335
4444
|
|
|
4336
|
-
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=
|
|
4445
|
+
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
|
|
4337
4446
|
|
|
4338
4447
|
if get_variance:
|
|
4339
4448
|
S1_sigma = self.backend.bk_reduce_std(
|
|
4340
|
-
(I1 * edge_mask), axis=
|
|
4449
|
+
(I1 * edge_mask), axis=dim
|
|
4341
4450
|
)
|
|
4342
4451
|
|
|
4343
|
-
I1_f = self.backend.bk_fftn(I1, dim=
|
|
4452
|
+
I1_f = self.backend.bk_fftn(I1, dim=dim)
|
|
4344
4453
|
|
|
4345
4454
|
if pseudo_coef != 1:
|
|
4346
4455
|
I1 = I1**pseudo_coef
|
|
@@ -4362,14 +4471,14 @@ class funct(FOC.FoCUS):
|
|
|
4362
4471
|
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
|
|
4363
4472
|
if edge:
|
|
4364
4473
|
I1_small = self.backend.bk_ifftn(
|
|
4365
|
-
I1_f_small, dim=
|
|
4474
|
+
I1_f_small, dim=dim, norm="ortho"
|
|
4366
4475
|
)
|
|
4367
4476
|
data_small = self.backend.bk_ifftn(
|
|
4368
|
-
data_f_small, dim=
|
|
4477
|
+
data_f_small, dim=dim, norm="ortho"
|
|
4369
4478
|
)
|
|
4370
4479
|
if data2 is not None:
|
|
4371
4480
|
data2_small = self.backend.bk_ifftn(
|
|
4372
|
-
data2_f_small, dim=
|
|
4481
|
+
data2_f_small, dim=dim, norm="ortho"
|
|
4373
4482
|
)
|
|
4374
4483
|
|
|
4375
4484
|
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
@@ -4400,10 +4509,10 @@ class funct(FOC.FoCUS):
|
|
|
4400
4509
|
) * wavelet_f3_squared.view(1, 1, L, M3, N3)
|
|
4401
4510
|
if edge:
|
|
4402
4511
|
I12_w3_small = self.backend.bk_ifftn(
|
|
4403
|
-
I1_f2_wf3_small, dim=
|
|
4512
|
+
I1_f2_wf3_small, dim=dim, norm="ortho"
|
|
4404
4513
|
)
|
|
4405
4514
|
I12_w3_2_small = self.backend.bk_ifftn(
|
|
4406
|
-
I1_f2_wf3_2_small, dim=
|
|
4515
|
+
I1_f2_wf3_2_small, dim=dim, norm="ortho"
|
|
4407
4516
|
)
|
|
4408
4517
|
if use_ref:
|
|
4409
4518
|
if normalization == "P11":
|
|
@@ -4420,7 +4529,7 @@ class funct(FOC.FoCUS):
|
|
|
4420
4529
|
if normalization == "P11":
|
|
4421
4530
|
# [N_image,l2,l3,x,y]
|
|
4422
4531
|
P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
|
|
4423
|
-
|
|
4532
|
+
dim
|
|
4424
4533
|
) * fft_factor
|
|
4425
4534
|
norm_factor_S3 = (
|
|
4426
4535
|
S2[:, None, j3, :] * P11_temp**pseudo_coef
|
|
@@ -4435,7 +4544,7 @@ class funct(FOC.FoCUS):
|
|
|
4435
4544
|
(
|
|
4436
4545
|
data_f_small.view(N_image, 1, 1, M3, N3)
|
|
4437
4546
|
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4438
|
-
).mean(
|
|
4547
|
+
).mean(dim)
|
|
4439
4548
|
* fft_factor
|
|
4440
4549
|
/ norm_factor_S3
|
|
4441
4550
|
)
|
|
@@ -4445,7 +4554,7 @@ class funct(FOC.FoCUS):
|
|
|
4445
4554
|
(
|
|
4446
4555
|
data_f_small.view(N_image, 1, 1, M3, N3)
|
|
4447
4556
|
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4448
|
-
).std(
|
|
4557
|
+
).std(dim)
|
|
4449
4558
|
* fft_factor
|
|
4450
4559
|
/ norm_factor_S3
|
|
4451
4560
|
)
|
|
@@ -4456,7 +4565,7 @@ class funct(FOC.FoCUS):
|
|
|
4456
4565
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4457
4566
|
* edge_mask[None, None, None, :, :]
|
|
4458
4567
|
).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
|
|
4459
|
-
|
|
4568
|
+
dim
|
|
4460
4569
|
)
|
|
4461
4570
|
* fft_factor
|
|
4462
4571
|
/ norm_factor_S3
|
|
@@ -4467,7 +4576,7 @@ class funct(FOC.FoCUS):
|
|
|
4467
4576
|
data_small.view(N_image, 1, 1, M3, N3)
|
|
4468
4577
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4469
4578
|
* edge_mask[None, None, None, :, :]
|
|
4470
|
-
).std(
|
|
4579
|
+
).std(dim)
|
|
4471
4580
|
* fft_factor
|
|
4472
4581
|
/ norm_factor_S3
|
|
4473
4582
|
)
|
|
@@ -4477,7 +4586,7 @@ class funct(FOC.FoCUS):
|
|
|
4477
4586
|
(
|
|
4478
4587
|
data2_f_small.view(N_image2, 1, 1, M3, N3)
|
|
4479
4588
|
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4480
|
-
).mean(
|
|
4589
|
+
).mean(dim)
|
|
4481
4590
|
* fft_factor
|
|
4482
4591
|
/ norm_factor_S3
|
|
4483
4592
|
)
|
|
@@ -4487,7 +4596,7 @@ class funct(FOC.FoCUS):
|
|
|
4487
4596
|
(
|
|
4488
4597
|
data2_f_small.view(N_image2, 1, 1, M3, N3)
|
|
4489
4598
|
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4490
|
-
).std(
|
|
4599
|
+
).std(dim)
|
|
4491
4600
|
* fft_factor
|
|
4492
4601
|
/ norm_factor_S3
|
|
4493
4602
|
)
|
|
@@ -4497,7 +4606,7 @@ class funct(FOC.FoCUS):
|
|
|
4497
4606
|
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4498
4607
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4499
4608
|
* edge_mask[None, None, None, :, :]
|
|
4500
|
-
).mean(
|
|
4609
|
+
).mean(dim)
|
|
4501
4610
|
* fft_factor
|
|
4502
4611
|
/ norm_factor_S3
|
|
4503
4612
|
)
|
|
@@ -4507,7 +4616,7 @@ class funct(FOC.FoCUS):
|
|
|
4507
4616
|
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4508
4617
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4509
4618
|
* edge_mask[None, None, None, :, :]
|
|
4510
|
-
).std(
|
|
4619
|
+
).std(dim)
|
|
4511
4620
|
* fft_factor
|
|
4512
4621
|
/ norm_factor_S3
|
|
4513
4622
|
)
|
|
@@ -4528,7 +4637,7 @@ class funct(FOC.FoCUS):
|
|
|
4528
4637
|
N_image, 1, L, L, M3, N3
|
|
4529
4638
|
)
|
|
4530
4639
|
)
|
|
4531
|
-
).mean(
|
|
4640
|
+
).mean(dim) * fft_factor
|
|
4532
4641
|
if get_variance:
|
|
4533
4642
|
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4534
4643
|
I1_f_small[:, j1].view(
|
|
@@ -4539,7 +4648,7 @@ class funct(FOC.FoCUS):
|
|
|
4539
4648
|
N_image, 1, L, L, M3, N3
|
|
4540
4649
|
)
|
|
4541
4650
|
)
|
|
4542
|
-
).std(
|
|
4651
|
+
).std(dim) * fft_factor
|
|
4543
4652
|
else:
|
|
4544
4653
|
for l1 in range(L):
|
|
4545
4654
|
# [N_image,l2,l3,x,y]
|
|
@@ -4552,7 +4661,7 @@ class funct(FOC.FoCUS):
|
|
|
4552
4661
|
N_image, L, L, M3, N3
|
|
4553
4662
|
)
|
|
4554
4663
|
)
|
|
4555
|
-
).mean(
|
|
4664
|
+
).mean(dim) * fft_factor
|
|
4556
4665
|
if get_variance:
|
|
4557
4666
|
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4558
4667
|
I1_f_small[:, j1, l1].view(
|
|
@@ -4563,7 +4672,7 @@ class funct(FOC.FoCUS):
|
|
|
4563
4672
|
N_image, L, L, M3, N3
|
|
4564
4673
|
)
|
|
4565
4674
|
)
|
|
4566
|
-
).std(
|
|
4675
|
+
).std(dim) * fft_factor
|
|
4567
4676
|
else:
|
|
4568
4677
|
if not if_large_batch:
|
|
4569
4678
|
# [N_image,l1,l2,l3,x,y]
|
|
@@ -4577,7 +4686,7 @@ class funct(FOC.FoCUS):
|
|
|
4577
4686
|
)
|
|
4578
4687
|
)
|
|
4579
4688
|
* edge_mask[None, None, None, None, :, :]
|
|
4580
|
-
).mean(
|
|
4689
|
+
).mean(dim) * fft_factor
|
|
4581
4690
|
if get_variance:
|
|
4582
4691
|
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4583
4692
|
I1_small[:, j1].view(
|
|
@@ -4591,7 +4700,7 @@ class funct(FOC.FoCUS):
|
|
|
4591
4700
|
* edge_mask[
|
|
4592
4701
|
None, None, None, None, :, :
|
|
4593
4702
|
]
|
|
4594
|
-
).std(
|
|
4703
|
+
).std(dim) * fft_factor
|
|
4595
4704
|
else:
|
|
4596
4705
|
for l1 in range(L):
|
|
4597
4706
|
# [N_image,l2,l3,x,y]
|
|
@@ -4607,7 +4716,7 @@ class funct(FOC.FoCUS):
|
|
|
4607
4716
|
* edge_mask[
|
|
4608
4717
|
None, None, None, None, :, :
|
|
4609
4718
|
]
|
|
4610
|
-
).mean(
|
|
4719
|
+
).mean(dim) * fft_factor
|
|
4611
4720
|
if get_variance:
|
|
4612
4721
|
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4613
4722
|
I1_small[:, j1].view(
|
|
@@ -4621,7 +4730,7 @@ class funct(FOC.FoCUS):
|
|
|
4621
4730
|
* edge_mask[
|
|
4622
4731
|
None, None, None, None, :, :
|
|
4623
4732
|
]
|
|
4624
|
-
).std(
|
|
4733
|
+
).std(dim) * fft_factor
|
|
4625
4734
|
|
|
4626
4735
|
Ndata_S4 += 1
|
|
4627
4736
|
|
|
@@ -4689,11 +4798,11 @@ class funct(FOC.FoCUS):
|
|
|
4689
4798
|
std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
|
|
4690
4799
|
|
|
4691
4800
|
if data2 is None:
|
|
4692
|
-
mean_data[:, 0] = data.mean(
|
|
4693
|
-
std_data[:, 0] = data.std(
|
|
4801
|
+
mean_data[:, 0] = data.mean(dim)
|
|
4802
|
+
std_data[:, 0] = data.std(dim)
|
|
4694
4803
|
else:
|
|
4695
|
-
mean_data[:, 0] = (data2 * data).mean(
|
|
4696
|
-
std_data[:, 0] = (data2 * data).std(
|
|
4804
|
+
mean_data[:, 0] = (data2 * data).mean(dim)
|
|
4805
|
+
std_data[:, 0] = (data2 * data).std(dim)
|
|
4697
4806
|
|
|
4698
4807
|
if get_variance:
|
|
4699
4808
|
ref_sigma = {}
|
|
@@ -4913,10 +5022,10 @@ class funct(FOC.FoCUS):
|
|
|
4913
5022
|
|
|
4914
5023
|
# convert numpy array input into self.backend.bk_ tensors
|
|
4915
5024
|
data = self.backend.bk_cast(data)
|
|
4916
|
-
data_f = self.backend.bk_fftn(data, dim=
|
|
5025
|
+
data_f = self.backend.bk_fftn(data, dim=dim)
|
|
4917
5026
|
if data2 is not None:
|
|
4918
5027
|
data2 = self.backend.bk_cast(data2)
|
|
4919
|
-
data2_f = self.backend.bk_fftn(data2, dim=
|
|
5028
|
+
data2_f = self.backend.bk_fftn(data2, dim=dim)
|
|
4920
5029
|
|
|
4921
5030
|
# initialize tensors for scattering coefficients
|
|
4922
5031
|
|
|
@@ -4967,7 +5076,7 @@ class funct(FOC.FoCUS):
|
|
|
4967
5076
|
self.backend.bk_ifftn(
|
|
4968
5077
|
data_f[None, None, None, :, :]
|
|
4969
5078
|
* filters_set[None, :J, :, :, :],
|
|
4970
|
-
dim=
|
|
5079
|
+
dim=dim,
|
|
4971
5080
|
)
|
|
4972
5081
|
)
|
|
4973
5082
|
else:
|
|
@@ -4975,7 +5084,7 @@ class funct(FOC.FoCUS):
|
|
|
4975
5084
|
self.backend.bk_ifftn(
|
|
4976
5085
|
data_f[:, None, None, :, :]
|
|
4977
5086
|
* filters_set[None, :J, :, :, :],
|
|
4978
|
-
dim=
|
|
5087
|
+
dim=dim,
|
|
4979
5088
|
)
|
|
4980
5089
|
)
|
|
4981
5090
|
elif self.use_1D:
|
|
@@ -4996,37 +5105,37 @@ class funct(FOC.FoCUS):
|
|
|
4996
5105
|
else:
|
|
4997
5106
|
print("todo")
|
|
4998
5107
|
|
|
4999
|
-
S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=
|
|
5000
|
-
S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=
|
|
5108
|
+
S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=dim)
|
|
5109
|
+
S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=dim)
|
|
5001
5110
|
|
|
5002
5111
|
if get_variance:
|
|
5003
5112
|
S2_sigma = self.backend.bk_reduce_std(
|
|
5004
|
-
(I1**2 * edge_mask), axis=
|
|
5113
|
+
(I1**2 * edge_mask), axis=dim
|
|
5005
5114
|
)
|
|
5006
|
-
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=
|
|
5115
|
+
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
|
|
5007
5116
|
|
|
5008
|
-
I1_f = self.backend.bk_fftn(I1, dim=
|
|
5117
|
+
I1_f = self.backend.bk_fftn(I1, dim=dim)
|
|
5009
5118
|
|
|
5010
5119
|
else:
|
|
5011
5120
|
if self.use_2D:
|
|
5012
5121
|
if len(data.shape) == 2:
|
|
5013
5122
|
I1 = self.backend.bk_ifftn(
|
|
5014
5123
|
data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
5015
|
-
dim=
|
|
5124
|
+
dim=dim,
|
|
5016
5125
|
)
|
|
5017
5126
|
I2 = self.backend.bk_ifftn(
|
|
5018
5127
|
data2_f[None, None, None, :, :]
|
|
5019
5128
|
* filters_set[None, :J, :, :, :],
|
|
5020
|
-
dim=
|
|
5129
|
+
dim=dim,
|
|
5021
5130
|
)
|
|
5022
5131
|
else:
|
|
5023
5132
|
I1 = self.backend.bk_ifftn(
|
|
5024
5133
|
data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
5025
|
-
dim=
|
|
5134
|
+
dim=dim,
|
|
5026
5135
|
)
|
|
5027
5136
|
I2 = self.backend.bk_ifftn(
|
|
5028
5137
|
data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
5029
|
-
dim=
|
|
5138
|
+
dim=dim,
|
|
5030
5139
|
)
|
|
5031
5140
|
elif self.use_1D:
|
|
5032
5141
|
if len(data.shape) == 1:
|
|
@@ -5051,18 +5160,18 @@ class funct(FOC.FoCUS):
|
|
|
5051
5160
|
|
|
5052
5161
|
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
|
|
5053
5162
|
|
|
5054
|
-
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=
|
|
5163
|
+
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
|
|
5055
5164
|
if get_variance:
|
|
5056
|
-
S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=
|
|
5165
|
+
S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
|
|
5057
5166
|
|
|
5058
5167
|
I1 = self.backend.bk_L1(I1)
|
|
5059
5168
|
|
|
5060
|
-
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=
|
|
5169
|
+
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
|
|
5061
5170
|
|
|
5062
5171
|
if get_variance:
|
|
5063
|
-
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=
|
|
5172
|
+
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
|
|
5064
5173
|
|
|
5065
|
-
I1_f = self.backend.bk_fftn(I1, dim=
|
|
5174
|
+
I1_f = self.backend.bk_fftn(I1, dim=dim)
|
|
5066
5175
|
|
|
5067
5176
|
if pseudo_coef != 1:
|
|
5068
5177
|
I1 = I1**pseudo_coef
|
|
@@ -5083,13 +5192,13 @@ class funct(FOC.FoCUS):
|
|
|
5083
5192
|
if data2 is not None:
|
|
5084
5193
|
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
|
|
5085
5194
|
if edge:
|
|
5086
|
-
I1_small = self.backend.bk_ifftn(I1_f_small, dim=
|
|
5195
|
+
I1_small = self.backend.bk_ifftn(I1_f_small, dim=dim, norm="ortho")
|
|
5087
5196
|
data_small = self.backend.bk_ifftn(
|
|
5088
|
-
data_f_small, dim=
|
|
5197
|
+
data_f_small, dim=dim, norm="ortho"
|
|
5089
5198
|
)
|
|
5090
5199
|
if data2 is not None:
|
|
5091
5200
|
data2_small = self.backend.bk_ifftn(
|
|
5092
|
-
data2_f_small, dim=
|
|
5201
|
+
data2_f_small, dim=dim, norm="ortho"
|
|
5093
5202
|
)
|
|
5094
5203
|
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
5095
5204
|
_, M3, N3 = wavelet_f3.shape
|
|
@@ -5112,10 +5221,10 @@ class funct(FOC.FoCUS):
|
|
|
5112
5221
|
) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
|
|
5113
5222
|
if edge:
|
|
5114
5223
|
I12_w3_small = self.backend.bk_ifftn(
|
|
5115
|
-
I1_f2_wf3_small, dim=
|
|
5224
|
+
I1_f2_wf3_small, dim=dim, norm="ortho"
|
|
5116
5225
|
)
|
|
5117
5226
|
I12_w3_2_small = self.backend.bk_ifftn(
|
|
5118
|
-
I1_f2_wf3_2_small, dim=
|
|
5227
|
+
I1_f2_wf3_2_small, dim=dim, norm="ortho"
|
|
5119
5228
|
)
|
|
5120
5229
|
if use_ref:
|
|
5121
5230
|
if normalization == "P11":
|
|
@@ -5141,7 +5250,7 @@ class funct(FOC.FoCUS):
|
|
|
5141
5250
|
# [N_image,l2,l3,x,y]
|
|
5142
5251
|
P11_temp = (
|
|
5143
5252
|
self.backend.bk_reduce_mean(
|
|
5144
|
-
(I1_f2_wf3_small.abs() ** 2), axis=
|
|
5253
|
+
(I1_f2_wf3_small.abs() ** 2), axis=dim
|
|
5145
5254
|
)
|
|
5146
5255
|
* fft_factor
|
|
5147
5256
|
)
|
|
@@ -5169,7 +5278,7 @@ class funct(FOC.FoCUS):
|
|
|
5169
5278
|
data_f_small, [N_image, 1, 1, 1, M3, N3]
|
|
5170
5279
|
)
|
|
5171
5280
|
* self.backend.bk_conjugate(I1_f2_wf3_small),
|
|
5172
|
-
axis=
|
|
5281
|
+
axis=dim,
|
|
5173
5282
|
)
|
|
5174
5283
|
* fft_factor
|
|
5175
5284
|
/ norm_factor_S3
|
|
@@ -5181,7 +5290,7 @@ class funct(FOC.FoCUS):
|
|
|
5181
5290
|
data_f_small, [N_image, 1, 1, 1, M3, N3]
|
|
5182
5291
|
)
|
|
5183
5292
|
* self.backend.bk_conjugate(I1_f2_wf3_small),
|
|
5184
|
-
axis=
|
|
5293
|
+
axis=dim,
|
|
5185
5294
|
)
|
|
5186
5295
|
* fft_factor
|
|
5187
5296
|
/ norm_factor_S3
|
|
@@ -5195,7 +5304,7 @@ class funct(FOC.FoCUS):
|
|
|
5195
5304
|
)
|
|
5196
5305
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
5197
5306
|
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5198
|
-
axis=
|
|
5307
|
+
axis=dim,
|
|
5199
5308
|
)
|
|
5200
5309
|
* fft_factor
|
|
5201
5310
|
/ norm_factor_S3
|
|
@@ -5209,7 +5318,7 @@ class funct(FOC.FoCUS):
|
|
|
5209
5318
|
)
|
|
5210
5319
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
5211
5320
|
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5212
|
-
axis=
|
|
5321
|
+
axis=dim,
|
|
5213
5322
|
)
|
|
5214
5323
|
* fft_factor
|
|
5215
5324
|
/ norm_factor_S3
|
|
@@ -5224,7 +5333,7 @@ class funct(FOC.FoCUS):
|
|
|
5224
5333
|
)
|
|
5225
5334
|
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5226
5335
|
),
|
|
5227
|
-
axis=
|
|
5336
|
+
axis=dim,
|
|
5228
5337
|
)
|
|
5229
5338
|
* fft_factor
|
|
5230
5339
|
/ norm_factor_S3
|
|
@@ -5239,7 +5348,7 @@ class funct(FOC.FoCUS):
|
|
|
5239
5348
|
)
|
|
5240
5349
|
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5241
5350
|
),
|
|
5242
|
-
axis=
|
|
5351
|
+
axis=dim,
|
|
5243
5352
|
)
|
|
5244
5353
|
* fft_factor
|
|
5245
5354
|
/ norm_factor_S3
|
|
@@ -5254,7 +5363,7 @@ class funct(FOC.FoCUS):
|
|
|
5254
5363
|
)
|
|
5255
5364
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
5256
5365
|
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5257
|
-
axis=
|
|
5366
|
+
axis=dim,
|
|
5258
5367
|
)
|
|
5259
5368
|
* fft_factor
|
|
5260
5369
|
/ norm_factor_S3
|
|
@@ -5272,7 +5381,7 @@ class funct(FOC.FoCUS):
|
|
|
5272
5381
|
edge_dx : M3 - edge_dx,
|
|
5273
5382
|
edge_dy : N3 - edge_dy,
|
|
5274
5383
|
],
|
|
5275
|
-
axis=
|
|
5384
|
+
axis=dim,
|
|
5276
5385
|
)
|
|
5277
5386
|
* fft_factor
|
|
5278
5387
|
/ norm_factor_S3
|
|
@@ -5318,7 +5427,7 @@ class funct(FOC.FoCUS):
|
|
|
5318
5427
|
)
|
|
5319
5428
|
)
|
|
5320
5429
|
),
|
|
5321
|
-
axis=
|
|
5430
|
+
axis=dim,
|
|
5322
5431
|
)
|
|
5323
5432
|
* fft_factor
|
|
5324
5433
|
* P
|
|
@@ -5338,7 +5447,7 @@ class funct(FOC.FoCUS):
|
|
|
5338
5447
|
)
|
|
5339
5448
|
)
|
|
5340
5449
|
),
|
|
5341
|
-
axis=
|
|
5450
|
+
axis=dim,
|
|
5342
5451
|
)
|
|
5343
5452
|
* fft_factor
|
|
5344
5453
|
* P
|
|
@@ -5360,7 +5469,7 @@ class funct(FOC.FoCUS):
|
|
|
5360
5469
|
)
|
|
5361
5470
|
)
|
|
5362
5471
|
),
|
|
5363
|
-
axis=
|
|
5472
|
+
axis=dim,
|
|
5364
5473
|
)
|
|
5365
5474
|
* fft_factor
|
|
5366
5475
|
* P
|
|
@@ -5380,7 +5489,7 @@ class funct(FOC.FoCUS):
|
|
|
5380
5489
|
)
|
|
5381
5490
|
)
|
|
5382
5491
|
),
|
|
5383
|
-
axis=
|
|
5492
|
+
axis=dim,
|
|
5384
5493
|
)
|
|
5385
5494
|
* fft_factor
|
|
5386
5495
|
* P
|
|
@@ -5402,7 +5511,7 @@ class funct(FOC.FoCUS):
|
|
|
5402
5511
|
)
|
|
5403
5512
|
)
|
|
5404
5513
|
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5405
|
-
axis=
|
|
5514
|
+
axis=dim,
|
|
5406
5515
|
)
|
|
5407
5516
|
* fft_factor
|
|
5408
5517
|
* P
|
|
@@ -5422,7 +5531,7 @@ class funct(FOC.FoCUS):
|
|
|
5422
5531
|
)
|
|
5423
5532
|
)
|
|
5424
5533
|
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5425
|
-
axis=
|
|
5534
|
+
axis=dim,
|
|
5426
5535
|
)
|
|
5427
5536
|
* fft_factor
|
|
5428
5537
|
* P
|
|
@@ -5444,7 +5553,7 @@ class funct(FOC.FoCUS):
|
|
|
5444
5553
|
)
|
|
5445
5554
|
)
|
|
5446
5555
|
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5447
|
-
axis=
|
|
5556
|
+
axis=dim,
|
|
5448
5557
|
)
|
|
5449
5558
|
* fft_factor
|
|
5450
5559
|
* P
|
|
@@ -5468,7 +5577,7 @@ class funct(FOC.FoCUS):
|
|
|
5468
5577
|
edge_dx:-edge_dx,
|
|
5469
5578
|
edge_dy:-edge_dy,
|
|
5470
5579
|
],
|
|
5471
|
-
axis=
|
|
5580
|
+
axis=dim,
|
|
5472
5581
|
)
|
|
5473
5582
|
* fft_factor
|
|
5474
5583
|
* P
|
|
@@ -5521,17 +5630,17 @@ class funct(FOC.FoCUS):
|
|
|
5521
5630
|
|
|
5522
5631
|
if data2 is None:
|
|
5523
5632
|
mean_data = self.backend.bk_reshape(
|
|
5524
|
-
self.backend.bk_reduce_mean(data, axis=
|
|
5633
|
+
self.backend.bk_reduce_mean(data, axis=dim), [N_image, 1]
|
|
5525
5634
|
)
|
|
5526
5635
|
std_data = self.backend.bk_reshape(
|
|
5527
|
-
self.backend.bk_reduce_std(data, axis=
|
|
5636
|
+
self.backend.bk_reduce_std(data, axis=dim), [N_image, 1]
|
|
5528
5637
|
)
|
|
5529
5638
|
else:
|
|
5530
5639
|
mean_data = self.backend.bk_reshape(
|
|
5531
|
-
self.backend.bk_reduce_mean(data * data2, axis=
|
|
5640
|
+
self.backend.bk_reduce_mean(data * data2, axis=dim), [N_image, 1]
|
|
5532
5641
|
)
|
|
5533
5642
|
std_data = self.backend.bk_reshape(
|
|
5534
|
-
self.backend.bk_reduce_std(data * data2, axis=
|
|
5643
|
+
self.backend.bk_reduce_std(data * data2, axis=dim), [N_image, 1]
|
|
5535
5644
|
)
|
|
5536
5645
|
|
|
5537
5646
|
if get_variance:
|
|
@@ -5869,8 +5978,8 @@ class funct(FOC.FoCUS):
|
|
|
5869
5978
|
def from_gaussian(self, x):
|
|
5870
5979
|
|
|
5871
5980
|
x = self.backend.bk_clip_by_value(x,
|
|
5872
|
-
self.val_min+1E-
|
|
5873
|
-
self.val_max-1E-
|
|
5981
|
+
self.val_min+1E-7*(self.val_max-self.val_min),
|
|
5982
|
+
self.val_max-1E-7*(self.val_max-self.val_min))
|
|
5874
5983
|
return self.f_gaussian(self.backend.to_numpy(x))
|
|
5875
5984
|
|
|
5876
5985
|
def square(self, x):
|
|
@@ -6041,13 +6150,12 @@ class funct(FOC.FoCUS):
|
|
|
6041
6150
|
return result
|
|
6042
6151
|
else:
|
|
6043
6152
|
if sigma is None:
|
|
6044
|
-
tmp = x
|
|
6153
|
+
tmp = self.diff_data(x,y)
|
|
6045
6154
|
else:
|
|
6046
|
-
tmp = (x
|
|
6155
|
+
tmp = self.diff_data(x,y,sigma=sigma)
|
|
6156
|
+
|
|
6047
6157
|
# do abs in case of complex values
|
|
6048
|
-
return
|
|
6049
|
-
self.backend.bk_reduce_mean(self.backend.bk_square(tmp))
|
|
6050
|
-
)
|
|
6158
|
+
return tmp/x.shape[0]
|
|
6051
6159
|
|
|
6052
6160
|
def reduce_sum(self, x):
|
|
6053
6161
|
|
|
@@ -6209,7 +6317,7 @@ class funct(FOC.FoCUS):
|
|
|
6209
6317
|
Jmax=None,
|
|
6210
6318
|
edge=False,
|
|
6211
6319
|
to_gaussian=True,
|
|
6212
|
-
use_variance=
|
|
6320
|
+
use_variance=True,
|
|
6213
6321
|
synthesised_N=1,
|
|
6214
6322
|
input_image=None,
|
|
6215
6323
|
grd_mask=None,
|
|
@@ -6276,13 +6384,14 @@ class funct(FOC.FoCUS):
|
|
|
6276
6384
|
use_v = args[2]
|
|
6277
6385
|
ljmax = args[3]
|
|
6278
6386
|
|
|
6279
|
-
learn = scat_operator.
|
|
6280
|
-
|
|
6281
|
-
|
|
6282
|
-
|
|
6283
|
-
norm='self'
|
|
6387
|
+
learn = scat_operator.eval(
|
|
6388
|
+
u,
|
|
6389
|
+
Jmax=ljmax,
|
|
6390
|
+
norm='auto'
|
|
6284
6391
|
)
|
|
6285
|
-
|
|
6392
|
+
|
|
6393
|
+
if synthesised_N>1:
|
|
6394
|
+
learn = scat_operator.reduce_mean_batch(learn)
|
|
6286
6395
|
|
|
6287
6396
|
# compute scattering covariance of the current synthetised map called u
|
|
6288
6397
|
if use_v:
|
|
@@ -6488,6 +6597,8 @@ class funct(FOC.FoCUS):
|
|
|
6488
6597
|
)
|
|
6489
6598
|
sref = ref
|
|
6490
6599
|
else:
|
|
6600
|
+
self.clean_norm()
|
|
6601
|
+
|
|
6491
6602
|
ref = self.eval(
|
|
6492
6603
|
tmp[k],
|
|
6493
6604
|
image2=l_ref[k],
|
|
@@ -6504,7 +6615,7 @@ class funct(FOC.FoCUS):
|
|
|
6504
6615
|
mask=l_in_mask[k],
|
|
6505
6616
|
Jmax=l_jmax[k],
|
|
6506
6617
|
calc_var=True,
|
|
6507
|
-
norm='
|
|
6618
|
+
norm='auto'
|
|
6508
6619
|
)
|
|
6509
6620
|
else:
|
|
6510
6621
|
ref = self.eval(
|
|
@@ -6512,7 +6623,7 @@ class funct(FOC.FoCUS):
|
|
|
6512
6623
|
image2=l_ref[k],
|
|
6513
6624
|
mask=l_in_mask[k],
|
|
6514
6625
|
Jmax=l_jmax[k],
|
|
6515
|
-
norm='
|
|
6626
|
+
norm='auto'
|
|
6516
6627
|
)
|
|
6517
6628
|
sref = ref
|
|
6518
6629
|
|