foscat 2025.7.2__py3-none-any.whl → 2025.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -33,7 +33,14 @@ testwarn = 0
33
33
 
34
34
 
35
35
  class scat_cov:
36
- def __init__(self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False):
36
+ def __init__(self,
37
+ s0, s2, s3, s4,
38
+ s1=None,
39
+ s3p=None,
40
+ backend=None,
41
+ use_1D=False,
42
+ return_data=False
43
+ ):
37
44
  self.S0 = s0
38
45
  self.S2 = s2
39
46
  self.S3 = s3
@@ -44,12 +51,13 @@ class scat_cov:
44
51
  self.idx1 = None
45
52
  self.idx2 = None
46
53
  self.use_1D = use_1D
47
- self.numel = self.backend.bk_len(s0)+ \
48
- self.backend.bk_len(s1)+ \
49
- self.backend.bk_len(s2)+ \
50
- self.backend.bk_len(s3)+ \
51
- self.backend.bk_len(s4)+ \
52
- self.backend.bk_len(s3p)
54
+ if not return_data:
55
+ self.numel = self.backend.bk_len(s0)+ \
56
+ self.backend.bk_len(s1)+ \
57
+ self.backend.bk_len(s2)+ \
58
+ self.backend.bk_len(s3)+ \
59
+ self.backend.bk_len(s4)+ \
60
+ self.backend.bk_len(s3p)
53
61
 
54
62
  def numpy(self):
55
63
  if self.BACKEND == "numpy":
@@ -106,12 +114,13 @@ class scat_cov:
106
114
 
107
115
  # ---------------------------------------------−---------
108
116
  def flatten(self):
109
- tmp = [
110
- self.conv2complex(
111
- self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
112
- )
113
- ]
114
117
  if self.use_1D:
118
+ tmp = [
119
+ self.conv2complex(
120
+ self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]])
121
+ )
122
+ ]
123
+
115
124
  if self.S1 is not None:
116
125
  tmp = tmp + [
117
126
  self.conv2complex(
@@ -156,6 +165,11 @@ class scat_cov:
156
165
 
157
166
  return self.backend.bk_concat(tmp, 1)
158
167
 
168
+ tmp = [
169
+ self.conv2complex(
170
+ self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
171
+ )
172
+ ]
159
173
  if self.S1 is not None:
160
174
  tmp = tmp + [
161
175
  self.conv2complex(
@@ -2819,6 +2833,7 @@ class funct(FOC.FoCUS):
2819
2833
  P1_dic = {}
2820
2834
  if cross:
2821
2835
  P2_dic = {}
2836
+
2822
2837
  elif (norm == "auto") and (self.P1_dic is not None):
2823
2838
  P1_dic = self.P1_dic
2824
2839
  if cross:
@@ -3610,23 +3625,40 @@ class funct(FOC.FoCUS):
3610
3625
  self.P2_dic = P2_dic
3611
3626
 
3612
3627
  if not return_data:
3613
- S1 = self.backend.bk_concat(S1, -2)
3614
- S2 = self.backend.bk_concat(S2, -2)
3615
- S3 = self.backend.bk_concat(S3, -3)
3616
- S4 = self.backend.bk_concat(S4, -4)
3617
- if cross:
3618
- S3P = self.backend.bk_concat(S3P, -3)
3619
- if calc_var:
3620
- VS1 = self.backend.bk_concat(VS1, -2)
3621
- VS2 = self.backend.bk_concat(VS2, -2)
3622
- VS3 = self.backend.bk_concat(VS3, -3)
3623
- VS4 = self.backend.bk_concat(VS4, -4)
3628
+ if not self.use_1D:
3629
+ S1 = self.backend.bk_concat(S1, -2)
3630
+ S2 = self.backend.bk_concat(S2, -2)
3631
+ S3 = self.backend.bk_concat(S3, -3)
3632
+ S4 = self.backend.bk_concat(S4, -4)
3624
3633
  if cross:
3625
- VS3P = self.backend.bk_concat(VS3P, -3)
3634
+ S3P = self.backend.bk_concat(S3P, -3)
3635
+ if calc_var:
3636
+ VS1 = self.backend.bk_concat(VS1, -2)
3637
+ VS2 = self.backend.bk_concat(VS2, -2)
3638
+ VS3 = self.backend.bk_concat(VS3, -3)
3639
+ VS4 = self.backend.bk_concat(VS4, -4)
3640
+ if cross:
3641
+ VS3P = self.backend.bk_concat(VS3P, -3)
3642
+ else:
3643
+ S1 = self.backend.bk_concat(S1, -1)
3644
+ S2 = self.backend.bk_concat(S2, -1)
3645
+ S3 = self.backend.bk_concat(S3, -1)
3646
+ S4 = self.backend.bk_concat(S4, -1)
3647
+ if cross:
3648
+ S3P = self.backend.bk_concat(S3P, -1)
3649
+ if calc_var:
3650
+ VS1 = self.backend.bk_concat(VS1, -1)
3651
+ VS2 = self.backend.bk_concat(VS2, -1)
3652
+ VS3 = self.backend.bk_concat(VS3, -1)
3653
+ VS4 = self.backend.bk_concat(VS4, -1)
3654
+ if cross:
3655
+ VS3P = self.backend.bk_concat(VS3P, -1)
3626
3656
  if calc_var:
3627
3657
  if not cross:
3628
3658
  return scat_cov(
3629
- s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
3659
+ s0, S2, S3, S4, s1=S1, backend=self.backend,
3660
+ use_1D=self.use_1D,
3661
+ return_data=self.return_data
3630
3662
  ), scat_cov(
3631
3663
  vs0,
3632
3664
  VS2,
@@ -3635,6 +3667,7 @@ class funct(FOC.FoCUS):
3635
3667
  s1=VS1,
3636
3668
  backend=self.backend,
3637
3669
  use_1D=self.use_1D,
3670
+ return_data=self.return_data
3638
3671
  )
3639
3672
  else:
3640
3673
  return scat_cov(
@@ -3646,6 +3679,7 @@ class funct(FOC.FoCUS):
3646
3679
  s3p=S3P,
3647
3680
  backend=self.backend,
3648
3681
  use_1D=self.use_1D,
3682
+ return_data=self.return_data
3649
3683
  ), scat_cov(
3650
3684
  vs0,
3651
3685
  VS2,
@@ -3655,11 +3689,16 @@ class funct(FOC.FoCUS):
3655
3689
  s3p=VS3P,
3656
3690
  backend=self.backend,
3657
3691
  use_1D=self.use_1D,
3692
+ return_data=self.return_data
3658
3693
  )
3659
3694
  else:
3660
3695
  if not cross:
3661
3696
  return scat_cov(
3662
- s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
3697
+ s0, S2, S3, S4,
3698
+ s1=S1,
3699
+ backend=self.backend,
3700
+ use_1D=self.use_1D,
3701
+ return_data=self.return_data
3663
3702
  )
3664
3703
  else:
3665
3704
  return scat_cov(
@@ -3671,6 +3710,7 @@ class funct(FOC.FoCUS):
3671
3710
  s3p=S3P,
3672
3711
  backend=self.backend,
3673
3712
  use_1D=self.use_1D,
3713
+ return_data=self.return_data
3674
3714
  )
3675
3715
 
3676
3716
  def clean_norm(self):
@@ -3748,8 +3788,12 @@ class funct(FOC.FoCUS):
3748
3788
  # cconv, sconv are [Nbatch, Norient3, Npix_j3]
3749
3789
  if self.use_1D:
3750
3790
  s3 = conv * self.backend.bk_conjugate(MconvPsi)
3791
+ elif self.use_2D:
3792
+ s3 = self.backend.bk_expand_dims(conv, -4)* self.backend.bk_conjugate(
3793
+ MconvPsi
3794
+ ) # [Nbatch, Norient3, Norient2, Npix_j3]
3751
3795
  else:
3752
- s3 = self.backend.bk_expand_dims(conv, -3) * self.backend.bk_conjugate(
3796
+ s3 = self.backend.bk_expand_dims(conv, -3)* self.backend.bk_conjugate(
3753
3797
  MconvPsi
3754
3798
  ) # [Nbatch, Norient3, Norient2, Npix_j3]
3755
3799
  ### Apply the mask [Nmask, Npix_j3] and sum over pixels
@@ -3816,62 +3860,121 @@ class funct(FOC.FoCUS):
3816
3860
  Done by Sihao Cheng and Rudy Morel.
3817
3861
  """
3818
3862
 
3819
- filter = np.zeros([J, L, M, N], dtype="complex64")
3863
+ if N!=0:
3864
+ filter = np.zeros([J, L, M, N], dtype="complex64")
3820
3865
 
3821
- slant = 4.0 / L
3866
+ slant = 4.0 / L
3822
3867
 
3823
- for j in range(J):
3868
+ for j in range(J):
3824
3869
 
3825
- for ell in range(L):
3870
+ for ell in range(L):
3826
3871
 
3827
- theta = (int(L - L / 2 - 1) - ell) * np.pi / L
3828
- sigma = 0.8 * 2**j
3829
- xi = 3.0 / 4.0 * np.pi / 2**j
3872
+ theta = (int(L - L / 2 - 1) - ell) * np.pi / L
3873
+ sigma = 0.8 * 2**j
3874
+ xi = 3.0 / 4.0 * np.pi / 2**j
3830
3875
 
3831
- R = np.array(
3832
- [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
3833
- np.float64,
3834
- )
3835
- R_inv = np.array(
3836
- [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
3837
- np.float64,
3838
- )
3839
- D = np.array([[1, 0], [0, slant * slant]])
3840
- curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
3841
-
3842
- gab = np.zeros((M, N), np.complex128)
3843
- xx = np.empty((2, 2, M, N))
3844
- yy = np.empty((2, 2, M, N))
3845
-
3846
- for ii, ex in enumerate([-1, 0]):
3847
- for jj, ey in enumerate([-1, 0]):
3848
- xx[ii, jj], yy[ii, jj] = np.mgrid[
3849
- ex * M : M + ex * M, ey * N : N + ey * N
3850
- ]
3851
-
3852
- arg = -(
3853
- curv[0, 0] * xx * xx
3854
- + (curv[0, 1] + curv[1, 0]) * xx * yy
3855
- + curv[1, 1] * yy * yy
3856
- )
3857
- argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
3876
+ R = np.array(
3877
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
3878
+ np.float64,
3879
+ )
3880
+ R_inv = np.array(
3881
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
3882
+ np.float64,
3883
+ )
3884
+ D = np.array([[1, 0], [0, slant * slant]])
3885
+ curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
3886
+
3887
+ gab = np.zeros((M, N), np.complex128)
3888
+ xx = np.empty((2, 2, M, N))
3889
+ yy = np.empty((2, 2, M, N))
3858
3890
 
3859
- gabi = np.exp(argi).sum((0, 1))
3860
- gab = np.exp(arg).sum((0, 1))
3891
+ for ii, ex in enumerate([-1, 0]):
3892
+ for jj, ey in enumerate([-1, 0]):
3893
+ xx[ii, jj], yy[ii, jj] = np.mgrid[
3894
+ ex * M : M + ex * M, ey * N : N + ey * N
3895
+ ]
3896
+
3897
+ arg = -(
3898
+ curv[0, 0] * xx * xx
3899
+ + (curv[0, 1] + curv[1, 0]) * xx * yy
3900
+ + curv[1, 1] * yy * yy
3901
+ )
3902
+ argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
3903
+
3904
+ gabi = np.exp(argi).sum((0, 1))
3905
+ gab = np.exp(arg).sum((0, 1))
3906
+
3907
+ norm_factor = 2 * np.pi * sigma * sigma / slant
3908
+
3909
+ gab = gab / norm_factor
3910
+
3911
+ gabi = gabi / norm_factor
3912
+
3913
+ K = gabi.sum() / gab.sum()
3914
+
3915
+ # Apply the Gaussian
3916
+ filter[j, ell] = np.fft.fft2(gabi - K * gab)
3917
+ filter[j, ell, 0, 0] = 0.0
3861
3918
 
3862
- norm_factor = 2 * np.pi * sigma * sigma / slant
3919
+ return self.backend.bk_cast(filter)
3920
+ else:
3921
+ filter = np.zeros([J, L, M], dtype="complex64")
3922
+ #TODO
3923
+ print('filter for 1D not yet available')
3924
+ exit(0)
3925
+ slant = 4.0 / L
3926
+
3927
+ for j in range(J):
3863
3928
 
3864
- gab = gab / norm_factor
3929
+ for ell in range(L):
3865
3930
 
3866
- gabi = gabi / norm_factor
3931
+ theta = (int(L - L / 2 - 1) - ell) * np.pi / L
3932
+ sigma = 0.8 * 2**j
3933
+ xi = 3.0 / 4.0 * np.pi / 2**j
3867
3934
 
3868
- K = gabi.sum() / gab.sum()
3935
+ R = np.array(
3936
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
3937
+ np.float64,
3938
+ )
3939
+ R_inv = np.array(
3940
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
3941
+ np.float64,
3942
+ )
3943
+ D = np.array([[1, 0], [0, slant * slant]])
3944
+ curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
3869
3945
 
3870
- # Apply the Gaussian
3871
- filter[j, ell] = np.fft.fft2(gabi - K * gab)
3872
- filter[j, ell, 0, 0] = 0.0
3946
+ gab = np.zeros((M), np.complex128)
3947
+ xx = np.empty((M))
3948
+
3949
+ for ii, ex in enumerate([-1, 0]):
3950
+ for jj, ey in enumerate([-1, 0]):
3951
+ xx[ii, jj], yy[ii, jj] = np.mgrid[
3952
+ ex * M : M + ex * M, ey * N : N + ey * N
3953
+ ]
3954
+
3955
+ arg = -(
3956
+ curv[0, 0] * xx * xx
3957
+ + (curv[0, 1] + curv[1, 0]) * xx * yy
3958
+ + curv[1, 1] * yy * yy
3959
+ )
3960
+ argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
3873
3961
 
3874
- return self.backend.bk_cast(filter)
3962
+ gabi = np.exp(argi).sum((0, 1))
3963
+ gab = np.exp(arg).sum((0, 1))
3964
+
3965
+ norm_factor = 2 * np.pi * sigma * sigma / slant
3966
+
3967
+ gab = gab / norm_factor
3968
+
3969
+ gabi = gabi / norm_factor
3970
+
3971
+ K = gabi.sum() / gab.sum()
3972
+
3973
+ # Apply the Gaussian
3974
+ filter[j, ell] = np.fft.fft2(gabi - K * gab)
3975
+ filter[j, ell, 0, 0] = 0.0
3976
+
3977
+ return self.backend.bk_cast(filter)
3875
3978
 
3876
3979
  # ------------------------------------------------------------------------------------------
3877
3980
  #
@@ -4107,18 +4210,24 @@ class funct(FOC.FoCUS):
4107
4210
  if data2 is not None:
4108
4211
  N_image2 = data2.shape[0]
4109
4212
  J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4213
+ dim=(-2,-1)
4110
4214
  elif self.use_1D:
4111
4215
  if len(data.shape) == 2:
4112
4216
  npix = int(im_shape[1]) # Number of pixels
4217
+ M = im_shape[1]
4218
+ N=0
4113
4219
  N_image = 1
4114
4220
  N_image2 = 1
4115
4221
  else:
4116
4222
  npix = int(im_shape[0]) # Number of pixels
4117
4223
  N_image = data.shape[0]
4224
+ M = im_shape[0]
4225
+ N=0
4118
4226
  if data2 is not None:
4119
4227
  N_image2 = data2.shape[0]
4120
4228
 
4121
4229
  nside = int(npix)
4230
+ dim=(-1)
4122
4231
 
4123
4232
  J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4124
4233
  else:
@@ -4169,10 +4278,10 @@ class funct(FOC.FoCUS):
4169
4278
 
4170
4279
  # convert numpy array input into self.backend.bk_ tensors
4171
4280
  data = self.backend.bk_cast(data)
4172
- data_f = self.backend.bk_fftn(data, dim=(-2, -1))
4281
+ data_f = self.backend.bk_fftn(data, dim=dim)
4173
4282
  if data2 is not None:
4174
4283
  data2 = self.backend.bk_cast(data2)
4175
- data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
4284
+ data2_f = self.backend.bk_fftn(data2, dim=dim)
4176
4285
 
4177
4286
  # initialize tensors for scattering coefficients
4178
4287
  S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
@@ -4248,13 +4357,13 @@ class funct(FOC.FoCUS):
4248
4357
  I1 = self.backend.bk_ifftn(
4249
4358
  data_f[None, None, None, :, :]
4250
4359
  * filters_set[None, :J, :, :, :],
4251
- dim=(-2, -1),
4360
+ dim=dim,
4252
4361
  ).abs()
4253
4362
  else:
4254
4363
  I1 = self.backend.bk_ifftn(
4255
4364
  data_f[:, None, None, :, :]
4256
4365
  * filters_set[None, :J, :, :, :],
4257
- dim=(-2, -1),
4366
+ dim=dim,
4258
4367
  ).abs()
4259
4368
  elif self.use_1D:
4260
4369
  if len(data.shape) == 1:
@@ -4270,12 +4379,12 @@ class funct(FOC.FoCUS):
4270
4379
  else:
4271
4380
  print("todo")
4272
4381
 
4273
- S2 = (I1**2 * edge_mask).mean((-2, -1))
4274
- S1 = (I1 * edge_mask).mean((-2, -1))
4382
+ S2 = (I1**2 * edge_mask).mean(dim)
4383
+ S1 = (I1 * edge_mask).mean(dim)
4275
4384
 
4276
4385
  if get_variance:
4277
- S2_sigma = (I1**2 * edge_mask).std((-2, -1))
4278
- S1_sigma = (I1 * edge_mask).std((-2, -1))
4386
+ S2_sigma = (I1**2 * edge_mask).std(dim)
4387
+ S1_sigma = (I1 * edge_mask).std(dim)
4279
4388
 
4280
4389
  else:
4281
4390
  if self.use_2D:
@@ -4283,64 +4392,64 @@ class funct(FOC.FoCUS):
4283
4392
  I1 = self.backend.bk_ifftn(
4284
4393
  data_f[None, None, None, :, :]
4285
4394
  * filters_set[None, :J, :, :, :],
4286
- dim=(-2, -1),
4395
+ dim=dim,
4287
4396
  )
4288
4397
  I2 = self.backend.bk_ifftn(
4289
4398
  data2_f[None, None, None, :, :]
4290
4399
  * filters_set[None, :J, :, :, :],
4291
- dim=(-2, -1),
4400
+ dim=dim,
4292
4401
  )
4293
4402
  else:
4294
4403
  I1 = self.backend.bk_ifftn(
4295
4404
  data_f[:, None, None, :, :]
4296
4405
  * filters_set[None, :J, :, :, :],
4297
- dim=(-2, -1),
4406
+ dim=dim,
4298
4407
  )
4299
4408
  I2 = self.backend.bk_ifftn(
4300
4409
  data2_f[:, None, None, :, :]
4301
4410
  * filters_set[None, :J, :, :, :],
4302
- dim=(-2, -1),
4411
+ dim=dim,
4303
4412
  )
4304
4413
  elif self.use_1D:
4305
4414
  if len(data.shape) == 1:
4306
4415
  I1 = self.backend.bk_ifftn(
4307
4416
  data_f[None, None, None, :] * filters_set[None, :J, :, :],
4308
- dim=(-1),
4417
+ dim=dim,
4309
4418
  )
4310
4419
  I2 = self.backend.bk_ifftn(
4311
4420
  data2_f[None, None, None, :] * filters_set[None, :J, :, :],
4312
- dim=(-1),
4421
+ dim=dim,
4313
4422
  )
4314
4423
  else:
4315
4424
  I1 = self.backend.bk_ifftn(
4316
4425
  data_f[:, None, None, :] * filters_set[None, :J, :, :],
4317
- dim=(-1),
4426
+ dim=dim,
4318
4427
  )
4319
4428
  I2 = self.backend.bk_ifftn(
4320
4429
  data2_f[:, None, None, :] * filters_set[None, :J, :, :],
4321
- dim=(-1),
4430
+ dim=dim,
4322
4431
  )
4323
4432
  else:
4324
4433
  print("todo")
4325
4434
 
4326
4435
  I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
4327
4436
 
4328
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
4437
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
4329
4438
  if get_variance:
4330
4439
  S2_sigma = self.backend.bk_reduce_std(
4331
- (I1 * edge_mask), axis=(-2, -1)
4440
+ (I1 * edge_mask), axis=dim
4332
4441
  )
4333
4442
 
4334
4443
  I1 = self.backend.bk_L1(I1)
4335
4444
 
4336
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
4445
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
4337
4446
 
4338
4447
  if get_variance:
4339
4448
  S1_sigma = self.backend.bk_reduce_std(
4340
- (I1 * edge_mask), axis=(-2, -1)
4449
+ (I1 * edge_mask), axis=dim
4341
4450
  )
4342
4451
 
4343
- I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4452
+ I1_f = self.backend.bk_fftn(I1, dim=dim)
4344
4453
 
4345
4454
  if pseudo_coef != 1:
4346
4455
  I1 = I1**pseudo_coef
@@ -4362,14 +4471,14 @@ class funct(FOC.FoCUS):
4362
4471
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
4363
4472
  if edge:
4364
4473
  I1_small = self.backend.bk_ifftn(
4365
- I1_f_small, dim=(-2, -1), norm="ortho"
4474
+ I1_f_small, dim=dim, norm="ortho"
4366
4475
  )
4367
4476
  data_small = self.backend.bk_ifftn(
4368
- data_f_small, dim=(-2, -1), norm="ortho"
4477
+ data_f_small, dim=dim, norm="ortho"
4369
4478
  )
4370
4479
  if data2 is not None:
4371
4480
  data2_small = self.backend.bk_ifftn(
4372
- data2_f_small, dim=(-2, -1), norm="ortho"
4481
+ data2_f_small, dim=dim, norm="ortho"
4373
4482
  )
4374
4483
 
4375
4484
  wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
@@ -4400,10 +4509,10 @@ class funct(FOC.FoCUS):
4400
4509
  ) * wavelet_f3_squared.view(1, 1, L, M3, N3)
4401
4510
  if edge:
4402
4511
  I12_w3_small = self.backend.bk_ifftn(
4403
- I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
4512
+ I1_f2_wf3_small, dim=dim, norm="ortho"
4404
4513
  )
4405
4514
  I12_w3_2_small = self.backend.bk_ifftn(
4406
- I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
4515
+ I1_f2_wf3_2_small, dim=dim, norm="ortho"
4407
4516
  )
4408
4517
  if use_ref:
4409
4518
  if normalization == "P11":
@@ -4420,7 +4529,7 @@ class funct(FOC.FoCUS):
4420
4529
  if normalization == "P11":
4421
4530
  # [N_image,l2,l3,x,y]
4422
4531
  P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
4423
- (-2, -1)
4532
+ dim
4424
4533
  ) * fft_factor
4425
4534
  norm_factor_S3 = (
4426
4535
  S2[:, None, j3, :] * P11_temp**pseudo_coef
@@ -4435,7 +4544,7 @@ class funct(FOC.FoCUS):
4435
4544
  (
4436
4545
  data_f_small.view(N_image, 1, 1, M3, N3)
4437
4546
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4438
- ).mean((-2, -1))
4547
+ ).mean(dim)
4439
4548
  * fft_factor
4440
4549
  / norm_factor_S3
4441
4550
  )
@@ -4445,7 +4554,7 @@ class funct(FOC.FoCUS):
4445
4554
  (
4446
4555
  data_f_small.view(N_image, 1, 1, M3, N3)
4447
4556
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4448
- ).std((-2, -1))
4557
+ ).std(dim)
4449
4558
  * fft_factor
4450
4559
  / norm_factor_S3
4451
4560
  )
@@ -4456,7 +4565,7 @@ class funct(FOC.FoCUS):
4456
4565
  * self.backend.bk_conjugate(I12_w3_small)
4457
4566
  * edge_mask[None, None, None, :, :]
4458
4567
  ).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
4459
- (-2, -1)
4568
+ dim
4460
4569
  )
4461
4570
  * fft_factor
4462
4571
  / norm_factor_S3
@@ -4467,7 +4576,7 @@ class funct(FOC.FoCUS):
4467
4576
  data_small.view(N_image, 1, 1, M3, N3)
4468
4577
  * self.backend.bk_conjugate(I12_w3_small)
4469
4578
  * edge_mask[None, None, None, :, :]
4470
- ).std((-2, -1))
4579
+ ).std(dim)
4471
4580
  * fft_factor
4472
4581
  / norm_factor_S3
4473
4582
  )
@@ -4477,7 +4586,7 @@ class funct(FOC.FoCUS):
4477
4586
  (
4478
4587
  data2_f_small.view(N_image2, 1, 1, M3, N3)
4479
4588
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4480
- ).mean((-2, -1))
4589
+ ).mean(dim)
4481
4590
  * fft_factor
4482
4591
  / norm_factor_S3
4483
4592
  )
@@ -4487,7 +4596,7 @@ class funct(FOC.FoCUS):
4487
4596
  (
4488
4597
  data2_f_small.view(N_image2, 1, 1, M3, N3)
4489
4598
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4490
- ).std((-2, -1))
4599
+ ).std(dim)
4491
4600
  * fft_factor
4492
4601
  / norm_factor_S3
4493
4602
  )
@@ -4497,7 +4606,7 @@ class funct(FOC.FoCUS):
4497
4606
  data2_small.view(N_image2, 1, 1, M3, N3)
4498
4607
  * self.backend.bk_conjugate(I12_w3_small)
4499
4608
  * edge_mask[None, None, None, :, :]
4500
- ).mean((-2, -1))
4609
+ ).mean(dim)
4501
4610
  * fft_factor
4502
4611
  / norm_factor_S3
4503
4612
  )
@@ -4507,7 +4616,7 @@ class funct(FOC.FoCUS):
4507
4616
  data2_small.view(N_image2, 1, 1, M3, N3)
4508
4617
  * self.backend.bk_conjugate(I12_w3_small)
4509
4618
  * edge_mask[None, None, None, :, :]
4510
- ).std((-2, -1))
4619
+ ).std(dim)
4511
4620
  * fft_factor
4512
4621
  / norm_factor_S3
4513
4622
  )
@@ -4528,7 +4637,7 @@ class funct(FOC.FoCUS):
4528
4637
  N_image, 1, L, L, M3, N3
4529
4638
  )
4530
4639
  )
4531
- ).mean((-2, -1)) * fft_factor
4640
+ ).mean(dim) * fft_factor
4532
4641
  if get_variance:
4533
4642
  S4_sigma[:, Ndata_S4, :, :, :] = (
4534
4643
  I1_f_small[:, j1].view(
@@ -4539,7 +4648,7 @@ class funct(FOC.FoCUS):
4539
4648
  N_image, 1, L, L, M3, N3
4540
4649
  )
4541
4650
  )
4542
- ).std((-2, -1)) * fft_factor
4651
+ ).std(dim) * fft_factor
4543
4652
  else:
4544
4653
  for l1 in range(L):
4545
4654
  # [N_image,l2,l3,x,y]
@@ -4552,7 +4661,7 @@ class funct(FOC.FoCUS):
4552
4661
  N_image, L, L, M3, N3
4553
4662
  )
4554
4663
  )
4555
- ).mean((-2, -1)) * fft_factor
4664
+ ).mean(dim) * fft_factor
4556
4665
  if get_variance:
4557
4666
  S4_sigma[:, Ndata_S4, l1, :, :] = (
4558
4667
  I1_f_small[:, j1, l1].view(
@@ -4563,7 +4672,7 @@ class funct(FOC.FoCUS):
4563
4672
  N_image, L, L, M3, N3
4564
4673
  )
4565
4674
  )
4566
- ).std((-2, -1)) * fft_factor
4675
+ ).std(dim) * fft_factor
4567
4676
  else:
4568
4677
  if not if_large_batch:
4569
4678
  # [N_image,l1,l2,l3,x,y]
@@ -4577,7 +4686,7 @@ class funct(FOC.FoCUS):
4577
4686
  )
4578
4687
  )
4579
4688
  * edge_mask[None, None, None, None, :, :]
4580
- ).mean((-2, -1)) * fft_factor
4689
+ ).mean(dim) * fft_factor
4581
4690
  if get_variance:
4582
4691
  S4_sigma[:, Ndata_S4, :, :, :] = (
4583
4692
  I1_small[:, j1].view(
@@ -4591,7 +4700,7 @@ class funct(FOC.FoCUS):
4591
4700
  * edge_mask[
4592
4701
  None, None, None, None, :, :
4593
4702
  ]
4594
- ).std((-2, -1)) * fft_factor
4703
+ ).std(dim) * fft_factor
4595
4704
  else:
4596
4705
  for l1 in range(L):
4597
4706
  # [N_image,l2,l3,x,y]
@@ -4607,7 +4716,7 @@ class funct(FOC.FoCUS):
4607
4716
  * edge_mask[
4608
4717
  None, None, None, None, :, :
4609
4718
  ]
4610
- ).mean((-2, -1)) * fft_factor
4719
+ ).mean(dim) * fft_factor
4611
4720
  if get_variance:
4612
4721
  S4_sigma[:, Ndata_S4, l1, :, :] = (
4613
4722
  I1_small[:, j1].view(
@@ -4621,7 +4730,7 @@ class funct(FOC.FoCUS):
4621
4730
  * edge_mask[
4622
4731
  None, None, None, None, :, :
4623
4732
  ]
4624
- ).std((-2, -1)) * fft_factor
4733
+ ).std(dim) * fft_factor
4625
4734
 
4626
4735
  Ndata_S4 += 1
4627
4736
 
@@ -4689,11 +4798,11 @@ class funct(FOC.FoCUS):
4689
4798
  std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
4690
4799
 
4691
4800
  if data2 is None:
4692
- mean_data[:, 0] = data.mean((-2, -1))
4693
- std_data[:, 0] = data.std((-2, -1))
4801
+ mean_data[:, 0] = data.mean(dim)
4802
+ std_data[:, 0] = data.std(dim)
4694
4803
  else:
4695
- mean_data[:, 0] = (data2 * data).mean((-2, -1))
4696
- std_data[:, 0] = (data2 * data).std((-2, -1))
4804
+ mean_data[:, 0] = (data2 * data).mean(dim)
4805
+ std_data[:, 0] = (data2 * data).std(dim)
4697
4806
 
4698
4807
  if get_variance:
4699
4808
  ref_sigma = {}
@@ -4913,10 +5022,10 @@ class funct(FOC.FoCUS):
4913
5022
 
4914
5023
  # convert numpy array input into self.backend.bk_ tensors
4915
5024
  data = self.backend.bk_cast(data)
4916
- data_f = self.backend.bk_fftn(data, dim=(-2, -1))
5025
+ data_f = self.backend.bk_fftn(data, dim=dim)
4917
5026
  if data2 is not None:
4918
5027
  data2 = self.backend.bk_cast(data2)
4919
- data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
5028
+ data2_f = self.backend.bk_fftn(data2, dim=dim)
4920
5029
 
4921
5030
  # initialize tensors for scattering coefficients
4922
5031
 
@@ -4967,7 +5076,7 @@ class funct(FOC.FoCUS):
4967
5076
  self.backend.bk_ifftn(
4968
5077
  data_f[None, None, None, :, :]
4969
5078
  * filters_set[None, :J, :, :, :],
4970
- dim=(-2, -1),
5079
+ dim=dim,
4971
5080
  )
4972
5081
  )
4973
5082
  else:
@@ -4975,7 +5084,7 @@ class funct(FOC.FoCUS):
4975
5084
  self.backend.bk_ifftn(
4976
5085
  data_f[:, None, None, :, :]
4977
5086
  * filters_set[None, :J, :, :, :],
4978
- dim=(-2, -1),
5087
+ dim=dim,
4979
5088
  )
4980
5089
  )
4981
5090
  elif self.use_1D:
@@ -4996,37 +5105,37 @@ class funct(FOC.FoCUS):
4996
5105
  else:
4997
5106
  print("todo")
4998
5107
 
4999
- S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=(-2, -1))
5000
- S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=(-2, -1))
5108
+ S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=dim)
5109
+ S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=dim)
5001
5110
 
5002
5111
  if get_variance:
5003
5112
  S2_sigma = self.backend.bk_reduce_std(
5004
- (I1**2 * edge_mask), axis=(-2, -1)
5113
+ (I1**2 * edge_mask), axis=dim
5005
5114
  )
5006
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
5115
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
5007
5116
 
5008
- I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
5117
+ I1_f = self.backend.bk_fftn(I1, dim=dim)
5009
5118
 
5010
5119
  else:
5011
5120
  if self.use_2D:
5012
5121
  if len(data.shape) == 2:
5013
5122
  I1 = self.backend.bk_ifftn(
5014
5123
  data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
5015
- dim=(-2, -1),
5124
+ dim=dim,
5016
5125
  )
5017
5126
  I2 = self.backend.bk_ifftn(
5018
5127
  data2_f[None, None, None, :, :]
5019
5128
  * filters_set[None, :J, :, :, :],
5020
- dim=(-2, -1),
5129
+ dim=dim,
5021
5130
  )
5022
5131
  else:
5023
5132
  I1 = self.backend.bk_ifftn(
5024
5133
  data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
5025
- dim=(-2, -1),
5134
+ dim=dim,
5026
5135
  )
5027
5136
  I2 = self.backend.bk_ifftn(
5028
5137
  data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
5029
- dim=(-2, -1),
5138
+ dim=dim,
5030
5139
  )
5031
5140
  elif self.use_1D:
5032
5141
  if len(data.shape) == 1:
@@ -5051,18 +5160,18 @@ class funct(FOC.FoCUS):
5051
5160
 
5052
5161
  I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
5053
5162
 
5054
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5163
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
5055
5164
  if get_variance:
5056
- S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
5165
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
5057
5166
 
5058
5167
  I1 = self.backend.bk_L1(I1)
5059
5168
 
5060
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5169
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
5061
5170
 
5062
5171
  if get_variance:
5063
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
5172
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
5064
5173
 
5065
- I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
5174
+ I1_f = self.backend.bk_fftn(I1, dim=dim)
5066
5175
 
5067
5176
  if pseudo_coef != 1:
5068
5177
  I1 = I1**pseudo_coef
@@ -5083,13 +5192,13 @@ class funct(FOC.FoCUS):
5083
5192
  if data2 is not None:
5084
5193
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5085
5194
  if edge:
5086
- I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2, -1), norm="ortho")
5195
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=dim, norm="ortho")
5087
5196
  data_small = self.backend.bk_ifftn(
5088
- data_f_small, dim=(-2, -1), norm="ortho"
5197
+ data_f_small, dim=dim, norm="ortho"
5089
5198
  )
5090
5199
  if data2 is not None:
5091
5200
  data2_small = self.backend.bk_ifftn(
5092
- data2_f_small, dim=(-2, -1), norm="ortho"
5201
+ data2_f_small, dim=dim, norm="ortho"
5093
5202
  )
5094
5203
  wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5095
5204
  _, M3, N3 = wavelet_f3.shape
@@ -5112,10 +5221,10 @@ class funct(FOC.FoCUS):
5112
5221
  ) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
5113
5222
  if edge:
5114
5223
  I12_w3_small = self.backend.bk_ifftn(
5115
- I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
5224
+ I1_f2_wf3_small, dim=dim, norm="ortho"
5116
5225
  )
5117
5226
  I12_w3_2_small = self.backend.bk_ifftn(
5118
- I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
5227
+ I1_f2_wf3_2_small, dim=dim, norm="ortho"
5119
5228
  )
5120
5229
  if use_ref:
5121
5230
  if normalization == "P11":
@@ -5141,7 +5250,7 @@ class funct(FOC.FoCUS):
5141
5250
  # [N_image,l2,l3,x,y]
5142
5251
  P11_temp = (
5143
5252
  self.backend.bk_reduce_mean(
5144
- (I1_f2_wf3_small.abs() ** 2), axis=(-2, -1)
5253
+ (I1_f2_wf3_small.abs() ** 2), axis=dim
5145
5254
  )
5146
5255
  * fft_factor
5147
5256
  )
@@ -5169,7 +5278,7 @@ class funct(FOC.FoCUS):
5169
5278
  data_f_small, [N_image, 1, 1, 1, M3, N3]
5170
5279
  )
5171
5280
  * self.backend.bk_conjugate(I1_f2_wf3_small),
5172
- axis=(-2, -1),
5281
+ axis=dim,
5173
5282
  )
5174
5283
  * fft_factor
5175
5284
  / norm_factor_S3
@@ -5181,7 +5290,7 @@ class funct(FOC.FoCUS):
5181
5290
  data_f_small, [N_image, 1, 1, 1, M3, N3]
5182
5291
  )
5183
5292
  * self.backend.bk_conjugate(I1_f2_wf3_small),
5184
- axis=(-2, -1),
5293
+ axis=dim,
5185
5294
  )
5186
5295
  * fft_factor
5187
5296
  / norm_factor_S3
@@ -5195,7 +5304,7 @@ class funct(FOC.FoCUS):
5195
5304
  )
5196
5305
  * self.backend.bk_conjugate(I12_w3_small)
5197
5306
  )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5198
- axis=(-2, -1),
5307
+ axis=dim,
5199
5308
  )
5200
5309
  * fft_factor
5201
5310
  / norm_factor_S3
@@ -5209,7 +5318,7 @@ class funct(FOC.FoCUS):
5209
5318
  )
5210
5319
  * self.backend.bk_conjugate(I12_w3_small)
5211
5320
  )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5212
- axis=(-2, -1),
5321
+ axis=dim,
5213
5322
  )
5214
5323
  * fft_factor
5215
5324
  / norm_factor_S3
@@ -5224,7 +5333,7 @@ class funct(FOC.FoCUS):
5224
5333
  )
5225
5334
  * self.backend.bk_conjugate(I1_f2_wf3_small)
5226
5335
  ),
5227
- axis=(-2, -1),
5336
+ axis=dim,
5228
5337
  )
5229
5338
  * fft_factor
5230
5339
  / norm_factor_S3
@@ -5239,7 +5348,7 @@ class funct(FOC.FoCUS):
5239
5348
  )
5240
5349
  * self.backend.bk_conjugate(I1_f2_wf3_small)
5241
5350
  ),
5242
- axis=(-2, -1),
5351
+ axis=dim,
5243
5352
  )
5244
5353
  * fft_factor
5245
5354
  / norm_factor_S3
@@ -5254,7 +5363,7 @@ class funct(FOC.FoCUS):
5254
5363
  )
5255
5364
  * self.backend.bk_conjugate(I12_w3_small)
5256
5365
  )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5257
- axis=(-2, -1),
5366
+ axis=dim,
5258
5367
  )
5259
5368
  * fft_factor
5260
5369
  / norm_factor_S3
@@ -5272,7 +5381,7 @@ class funct(FOC.FoCUS):
5272
5381
  edge_dx : M3 - edge_dx,
5273
5382
  edge_dy : N3 - edge_dy,
5274
5383
  ],
5275
- axis=(-2, -1),
5384
+ axis=dim,
5276
5385
  )
5277
5386
  * fft_factor
5278
5387
  / norm_factor_S3
@@ -5318,7 +5427,7 @@ class funct(FOC.FoCUS):
5318
5427
  )
5319
5428
  )
5320
5429
  ),
5321
- axis=(-2, -1),
5430
+ axis=dim,
5322
5431
  )
5323
5432
  * fft_factor
5324
5433
  * P
@@ -5338,7 +5447,7 @@ class funct(FOC.FoCUS):
5338
5447
  )
5339
5448
  )
5340
5449
  ),
5341
- axis=(-2, -1),
5450
+ axis=dim,
5342
5451
  )
5343
5452
  * fft_factor
5344
5453
  * P
@@ -5360,7 +5469,7 @@ class funct(FOC.FoCUS):
5360
5469
  )
5361
5470
  )
5362
5471
  ),
5363
- axis=(-2, -1),
5472
+ axis=dim,
5364
5473
  )
5365
5474
  * fft_factor
5366
5475
  * P
@@ -5380,7 +5489,7 @@ class funct(FOC.FoCUS):
5380
5489
  )
5381
5490
  )
5382
5491
  ),
5383
- axis=(-2, -1),
5492
+ axis=dim,
5384
5493
  )
5385
5494
  * fft_factor
5386
5495
  * P
@@ -5402,7 +5511,7 @@ class funct(FOC.FoCUS):
5402
5511
  )
5403
5512
  )
5404
5513
  )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5405
- axis=(-2, -1),
5514
+ axis=dim,
5406
5515
  )
5407
5516
  * fft_factor
5408
5517
  * P
@@ -5422,7 +5531,7 @@ class funct(FOC.FoCUS):
5422
5531
  )
5423
5532
  )
5424
5533
  )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5425
- axis=(-2, -1),
5534
+ axis=dim,
5426
5535
  )
5427
5536
  * fft_factor
5428
5537
  * P
@@ -5444,7 +5553,7 @@ class funct(FOC.FoCUS):
5444
5553
  )
5445
5554
  )
5446
5555
  )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5447
- axis=(-2, -1),
5556
+ axis=dim,
5448
5557
  )
5449
5558
  * fft_factor
5450
5559
  * P
@@ -5468,7 +5577,7 @@ class funct(FOC.FoCUS):
5468
5577
  edge_dx:-edge_dx,
5469
5578
  edge_dy:-edge_dy,
5470
5579
  ],
5471
- axis=(-2, -1),
5580
+ axis=dim,
5472
5581
  )
5473
5582
  * fft_factor
5474
5583
  * P
@@ -5521,17 +5630,17 @@ class funct(FOC.FoCUS):
5521
5630
 
5522
5631
  if data2 is None:
5523
5632
  mean_data = self.backend.bk_reshape(
5524
- self.backend.bk_reduce_mean(data, axis=(-2, -1)), [N_image, 1]
5633
+ self.backend.bk_reduce_mean(data, axis=dim), [N_image, 1]
5525
5634
  )
5526
5635
  std_data = self.backend.bk_reshape(
5527
- self.backend.bk_reduce_std(data, axis=(-2, -1)), [N_image, 1]
5636
+ self.backend.bk_reduce_std(data, axis=dim), [N_image, 1]
5528
5637
  )
5529
5638
  else:
5530
5639
  mean_data = self.backend.bk_reshape(
5531
- self.backend.bk_reduce_mean(data * data2, axis=(-2, -1)), [N_image, 1]
5640
+ self.backend.bk_reduce_mean(data * data2, axis=dim), [N_image, 1]
5532
5641
  )
5533
5642
  std_data = self.backend.bk_reshape(
5534
- self.backend.bk_reduce_std(data * data2, axis=(-2, -1)), [N_image, 1]
5643
+ self.backend.bk_reduce_std(data * data2, axis=dim), [N_image, 1]
5535
5644
  )
5536
5645
 
5537
5646
  if get_variance:
@@ -5869,8 +5978,8 @@ class funct(FOC.FoCUS):
5869
5978
  def from_gaussian(self, x):
5870
5979
 
5871
5980
  x = self.backend.bk_clip_by_value(x,
5872
- self.val_min+1E-4*abs(self.val_min),
5873
- self.val_max-1E-4*abs(self.val_max))
5981
+ self.val_min+1E-7*(self.val_max-self.val_min),
5982
+ self.val_max-1E-7*(self.val_max-self.val_min))
5874
5983
  return self.f_gaussian(self.backend.to_numpy(x))
5875
5984
 
5876
5985
  def square(self, x):
@@ -6041,13 +6150,12 @@ class funct(FOC.FoCUS):
6041
6150
  return result
6042
6151
  else:
6043
6152
  if sigma is None:
6044
- tmp = x - y
6153
+ tmp = self.diff_data(x,y)
6045
6154
  else:
6046
- tmp = (x - y) / sigma
6155
+ tmp = self.diff_data(x,y,sigma=sigma)
6156
+
6047
6157
  # do abs in case of complex values
6048
- return self.backend.bk_abs(
6049
- self.backend.bk_reduce_mean(self.backend.bk_square(tmp))
6050
- )
6158
+ return tmp/x.shape[0]
6051
6159
 
6052
6160
  def reduce_sum(self, x):
6053
6161
 
@@ -6209,7 +6317,7 @@ class funct(FOC.FoCUS):
6209
6317
  Jmax=None,
6210
6318
  edge=False,
6211
6319
  to_gaussian=True,
6212
- use_variance=False,
6320
+ use_variance=True,
6213
6321
  synthesised_N=1,
6214
6322
  input_image=None,
6215
6323
  grd_mask=None,
@@ -6276,13 +6384,14 @@ class funct(FOC.FoCUS):
6276
6384
  use_v = args[2]
6277
6385
  ljmax = args[3]
6278
6386
 
6279
- learn = scat_operator.reduce_mean_batch(
6280
- scat_operator.eval(
6281
- u,
6282
- Jmax=ljmax,
6283
- norm='self'
6387
+ learn = scat_operator.eval(
6388
+ u,
6389
+ Jmax=ljmax,
6390
+ norm='auto'
6284
6391
  )
6285
- )
6392
+
6393
+ if synthesised_N>1:
6394
+ learn = scat_operator.reduce_mean_batch(learn)
6286
6395
 
6287
6396
  # compute scattering covariance of the current synthetised map called u
6288
6397
  if use_v:
@@ -6488,6 +6597,8 @@ class funct(FOC.FoCUS):
6488
6597
  )
6489
6598
  sref = ref
6490
6599
  else:
6600
+ self.clean_norm()
6601
+
6491
6602
  ref = self.eval(
6492
6603
  tmp[k],
6493
6604
  image2=l_ref[k],
@@ -6504,7 +6615,7 @@ class funct(FOC.FoCUS):
6504
6615
  mask=l_in_mask[k],
6505
6616
  Jmax=l_jmax[k],
6506
6617
  calc_var=True,
6507
- norm='self'
6618
+ norm='auto'
6508
6619
  )
6509
6620
  else:
6510
6621
  ref = self.eval(
@@ -6512,7 +6623,7 @@ class funct(FOC.FoCUS):
6512
6623
  image2=l_ref[k],
6513
6624
  mask=l_in_mask[k],
6514
6625
  Jmax=l_jmax[k],
6515
- norm='self'
6626
+ norm='auto'
6516
6627
  )
6517
6628
  sref = ref
6518
6629