foscat 2025.7.1__py3-none-any.whl → 2025.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/BkTorch.py CHANGED
@@ -196,10 +196,41 @@ class BkTorch(BackendBase.BackendBase):
196
196
  y = y.reshape(*leading_dims, O_c, Nx, Ny)
197
197
 
198
198
  return y
199
-
199
+
200
200
  def conv1d(self, x, w, strides=[1, 1, 1], padding="SAME"):
201
- # to be written!!!
202
- return x
201
+ """
202
+ Performs 1D convolution along the last axis of a 2D tensor x[n, m] with kernel w[K].
203
+
204
+ Parameters:
205
+ - x: torch.Tensor of shape [n, m]
206
+ - w: torch.Tensor of shape [K]
207
+ - strides: list of 3 ints; only strides[1] (along axis -1) is used
208
+ - padding: "SAME" or "VALID"
209
+
210
+ Returns:
211
+ - torch.Tensor of shape [n, m] (if SAME) or smaller (if VALID)
212
+ """
213
+ assert x.ndim == 2, "Input x must be a 2D tensor [n, m]"
214
+ assert w.ndim == 1, "Kernel w must be a 1D tensor [K]"
215
+ stride = strides[1]
216
+
217
+ # Reshape for PyTorch conv1d: [batch, channels, width]
218
+ x_reshaped = x.unsqueeze(1) # [n, 1, m]
219
+ w_flipped = w.flip(0).view(1, 1, -1) # [out_channels=1, in_channels=1, kernel_size]
220
+
221
+ if padding.upper() == "SAME":
222
+ pad_total = w.shape[0] - 1
223
+ pad_left = pad_total // 2
224
+ pad_right = pad_total - pad_left
225
+ x_reshaped = F.pad(x_reshaped, (pad_left, pad_right), mode='constant', value=0)
226
+ padding_mode = 'valid'
227
+ elif padding.upper() == "VALID":
228
+ padding_mode = 'valid'
229
+ else:
230
+ raise ValueError("padding must be either 'SAME' or 'VALID'")
231
+
232
+ out = F.conv1d(x_reshaped, w_flipped, stride=stride, padding=0) # manual padding applied above
233
+ return out.squeeze(1) # [n, m_out]
203
234
 
204
235
  def bk_threshold(self, x, threshold, greater=True):
205
236
 
foscat/FoCUS.py CHANGED
@@ -35,7 +35,7 @@ class FoCUS:
35
35
  mpi_rank=0
36
36
  ):
37
37
 
38
- self.__version__ = "2025.07.1"
38
+ self.__version__ = "2025.07.3"
39
39
  # P00 coeff for normalization for scat_cov
40
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
41
41
  self.P1_dic = None
@@ -790,13 +790,11 @@ class FoCUS:
790
790
 
791
791
  npix = im.shape[axis]
792
792
  odata = 1
793
- if len(ishape) > axis + 1:
794
- for k in range(axis + 1, len(ishape)):
795
- odata = odata * ishape[k]
796
-
793
+
797
794
  ndata = 1
798
- for k in range(axis):
799
- ndata = ndata * ishape[k]
795
+ if len(ishape)>1:
796
+ for k in range(len(ishape)-1):
797
+ ndata = ndata * ishape[k]
800
798
 
801
799
  tim = self.backend.bk_reshape(
802
800
  self.backend.bk_cast(im), [ndata, npix, odata]
@@ -819,21 +817,7 @@ class FoCUS:
819
817
  self.backend.bk_concat([res1, res2], -2),
820
818
  [ndata, tim.shape[1] * 2, odata],
821
819
  )
822
-
823
- if axis == 0:
824
- if len(ishape) == 1:
825
- return self.backend.bk_reshape(tim, [nout])
826
- else:
827
- return self.backend.bk_reshape(tim, [nout] + ishape[axis + 1 :])
828
- else:
829
- if len(ishape) == axis + 1:
830
- return self.backend.bk_reshape(tim, ishape[0:axis] + [nout])
831
- else:
832
- return self.backend.bk_reshape(
833
- tim, ishape[0:axis] + [nout] + ishape[axis + 1 :]
834
- )
835
-
836
- return self.backend.bk_reshape(tim, [nout])
820
+ return self.backend.bk_reshape(tim, ishape[0:-1] + [nout])
837
821
 
838
822
  else:
839
823
 
@@ -1691,9 +1675,10 @@ class FoCUS:
1691
1675
  except:
1692
1676
  lcell_ids=self.to_numpy(cell_ids)
1693
1677
  idx_map[lcell_ids]=np.arange(lcell_ids.shape[0],dtype='int32')
1694
-
1678
+
1695
1679
  lidx=np.where(idx_map[tmp[:,1]%(12*nside**2)]!=-1)[0]
1696
1680
  orientation=tmp[lidx,1]//(12*nside**2)
1681
+ orientation2=tmp[lidx,0]//(12*nside**2)
1697
1682
  tmp=tmp[lidx]
1698
1683
  wr=wr[lidx]
1699
1684
  wi=wi[lidx]
@@ -1703,16 +1688,21 @@ class FoCUS:
1703
1688
  wi[lidx]=0.0
1704
1689
  tmp[lidx,0]=0
1705
1690
  tmp[:,1]+=orientation*lcell_ids.shape[0]
1691
+ tmp[:,0]+=orientation2*lcell_ids.shape[0]
1706
1692
 
1707
1693
  idx_map=-np.ones([12*nside**2],dtype='int32')
1708
1694
  idx_map[lcell_ids]=np.arange(cell_ids.shape[0],dtype='int32')
1709
- lidx=np.where(idx_map[tmp2[:,1]]!=-1)[0]
1695
+ lidx=np.where(idx_map[tmp2[:,1]%(12*nside**2)]!=-1)[0]
1696
+ i_id=tmp2[lidx,1]//(12*nside**2)
1697
+ i_id2=tmp2[lidx,0]//(12*nside**2)
1710
1698
  tmp2=tmp2[lidx]
1711
1699
  ws=ws[lidx]
1712
- tmp2=idx_map[tmp2]
1700
+ tmp2=idx_map[tmp2%(12*nside**2)]
1713
1701
  lidx=np.where(tmp2[:,0]==-1)[0]
1714
1702
  ws[lidx]=0.0
1715
1703
  tmp2[lidx,0]=0
1704
+ tmp2[:,1]+=i_id*lcell_ids.shape[0]
1705
+ tmp2[:,0]+=i_id2*lcell_ids.shape[0]
1716
1706
 
1717
1707
  else:
1718
1708
  tmp = indice
@@ -2110,8 +2100,8 @@ class FoCUS:
2110
2100
  ichannel = 1
2111
2101
  for i in range(1, len(shape) - 1):
2112
2102
  ichannel *= shape[i]
2113
-
2114
- l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel, shape[-1]])
2103
+
2104
+ l_x = self.backend.bk_reshape(x, [shape[0], 1, ichannel,shape[-1]])
2115
2105
 
2116
2106
  if self.padding == "VALID":
2117
2107
  oshape = [k for k in shape]
@@ -2198,9 +2188,9 @@ class FoCUS:
2198
2188
  elif self.use_1D:
2199
2189
  mtmp = l_mask
2200
2190
  vtmp = l_x
2201
- v1 = self.backend.bk_reduce_sum(l_mask[1,:,...,:] * vtmp, axis=-1)
2202
- v2 = self.backend.bk_reduce_sum(mtmp * vtmp * vtmp, axis=-1)
2203
- vh = self.backend.bk_reduce_sum(mtmp, axis=-1)
2191
+ v1 = self.backend.bk_reduce_sum(l_mask * vtmp, axis=-1)
2192
+ v2 = self.backend.bk_reduce_sum(l_mask * vtmp * vtmp, axis=-1)
2193
+ vh = self.backend.bk_reduce_sum(l_mask , axis=-1)
2204
2194
 
2205
2195
  res = v1 / vh
2206
2196
 
@@ -2599,7 +2589,7 @@ class FoCUS:
2599
2589
  ishape = list(in_image.shape)
2600
2590
 
2601
2591
  npix = ishape[-1]
2602
-
2592
+
2603
2593
  ndata = 1
2604
2594
  for k in range(len(ishape) - 1):
2605
2595
  ndata = ndata * ishape[k]
@@ -2612,7 +2602,7 @@ class FoCUS:
2612
2602
  res = self.backend.bk_complex(rr, ii)
2613
2603
  else:
2614
2604
  res = self.backend.conv1d(tim, self.ww_SmoothT[1])
2615
-
2605
+
2616
2606
  return self.backend.bk_reshape(res, ishape)
2617
2607
 
2618
2608
  else:
foscat/HealSpline.py CHANGED
@@ -6,11 +6,14 @@ import healpy as hp
6
6
 
7
7
  class heal_spline:
8
8
  def __init__(
9
- self,
10
- level):
9
+ self,
10
+ level,
11
+ gamma=1,
12
+ ):
11
13
  nside=2**level
12
14
  self.nside_store=2**(level//2)
13
15
  self.spline_tree={}
16
+ self.gamma=gamma
14
17
 
15
18
  self.nside=nside
16
19
  #compute colatitude
@@ -79,6 +82,8 @@ class heal_spline:
79
82
 
80
83
  hit=np.bincount(all_idx.flatten(),weights=www.flatten())
81
84
  www[hit[all_idx]<threshold]=0.0
85
+ if self.gamma!=1:
86
+ www=www**self.gamma
82
87
  www=www/np.sum(www,0)[None,:]
83
88
  return www,all_idx,heal_idx
84
89
 
foscat/scat_cov.py CHANGED
@@ -106,12 +106,13 @@ class scat_cov:
106
106
 
107
107
  # ---------------------------------------------−---------
108
108
  def flatten(self):
109
- tmp = [
110
- self.conv2complex(
111
- self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
112
- )
113
- ]
114
109
  if self.use_1D:
110
+ tmp = [
111
+ self.conv2complex(
112
+ self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]])
113
+ )
114
+ ]
115
+
115
116
  if self.S1 is not None:
116
117
  tmp = tmp + [
117
118
  self.conv2complex(
@@ -156,6 +157,11 @@ class scat_cov:
156
157
 
157
158
  return self.backend.bk_concat(tmp, 1)
158
159
 
160
+ tmp = [
161
+ self.conv2complex(
162
+ self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
163
+ )
164
+ ]
159
165
  if self.S1 is not None:
160
166
  tmp = tmp + [
161
167
  self.conv2complex(
@@ -3610,19 +3616,34 @@ class funct(FOC.FoCUS):
3610
3616
  self.P2_dic = P2_dic
3611
3617
 
3612
3618
  if not return_data:
3613
- S1 = self.backend.bk_concat(S1, -2)
3614
- S2 = self.backend.bk_concat(S2, -2)
3615
- S3 = self.backend.bk_concat(S3, -3)
3616
- S4 = self.backend.bk_concat(S4, -4)
3617
- if cross:
3618
- S3P = self.backend.bk_concat(S3P, -3)
3619
- if calc_var:
3620
- VS1 = self.backend.bk_concat(VS1, -2)
3621
- VS2 = self.backend.bk_concat(VS2, -2)
3622
- VS3 = self.backend.bk_concat(VS3, -3)
3623
- VS4 = self.backend.bk_concat(VS4, -4)
3619
+ if not self.use_1D:
3620
+ S1 = self.backend.bk_concat(S1, -2)
3621
+ S2 = self.backend.bk_concat(S2, -2)
3622
+ S3 = self.backend.bk_concat(S3, -3)
3623
+ S4 = self.backend.bk_concat(S4, -4)
3624
3624
  if cross:
3625
- VS3P = self.backend.bk_concat(VS3P, -3)
3625
+ S3P = self.backend.bk_concat(S3P, -3)
3626
+ if calc_var:
3627
+ VS1 = self.backend.bk_concat(VS1, -2)
3628
+ VS2 = self.backend.bk_concat(VS2, -2)
3629
+ VS3 = self.backend.bk_concat(VS3, -3)
3630
+ VS4 = self.backend.bk_concat(VS4, -4)
3631
+ if cross:
3632
+ VS3P = self.backend.bk_concat(VS3P, -3)
3633
+ else:
3634
+ S1 = self.backend.bk_concat(S1, -1)
3635
+ S2 = self.backend.bk_concat(S2, -1)
3636
+ S3 = self.backend.bk_concat(S3, -1)
3637
+ S4 = self.backend.bk_concat(S4, -1)
3638
+ if cross:
3639
+ S3P = self.backend.bk_concat(S3P, -1)
3640
+ if calc_var:
3641
+ VS1 = self.backend.bk_concat(VS1, -1)
3642
+ VS2 = self.backend.bk_concat(VS2, -1)
3643
+ VS3 = self.backend.bk_concat(VS3, -1)
3644
+ VS4 = self.backend.bk_concat(VS4, -1)
3645
+ if cross:
3646
+ VS3P = self.backend.bk_concat(VS3P, -1)
3626
3647
  if calc_var:
3627
3648
  if not cross:
3628
3649
  return scat_cov(
@@ -3816,62 +3837,121 @@ class funct(FOC.FoCUS):
3816
3837
  Done by Sihao Cheng and Rudy Morel.
3817
3838
  """
3818
3839
 
3819
- filter = np.zeros([J, L, M, N], dtype="complex64")
3840
+ if N!=0:
3841
+ filter = np.zeros([J, L, M, N], dtype="complex64")
3820
3842
 
3821
- slant = 4.0 / L
3843
+ slant = 4.0 / L
3822
3844
 
3823
- for j in range(J):
3845
+ for j in range(J):
3824
3846
 
3825
- for ell in range(L):
3847
+ for ell in range(L):
3826
3848
 
3827
- theta = (int(L - L / 2 - 1) - ell) * np.pi / L
3828
- sigma = 0.8 * 2**j
3829
- xi = 3.0 / 4.0 * np.pi / 2**j
3849
+ theta = (int(L - L / 2 - 1) - ell) * np.pi / L
3850
+ sigma = 0.8 * 2**j
3851
+ xi = 3.0 / 4.0 * np.pi / 2**j
3830
3852
 
3831
- R = np.array(
3832
- [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
3833
- np.float64,
3834
- )
3835
- R_inv = np.array(
3836
- [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
3837
- np.float64,
3838
- )
3839
- D = np.array([[1, 0], [0, slant * slant]])
3840
- curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
3841
-
3842
- gab = np.zeros((M, N), np.complex128)
3843
- xx = np.empty((2, 2, M, N))
3844
- yy = np.empty((2, 2, M, N))
3845
-
3846
- for ii, ex in enumerate([-1, 0]):
3847
- for jj, ey in enumerate([-1, 0]):
3848
- xx[ii, jj], yy[ii, jj] = np.mgrid[
3849
- ex * M : M + ex * M, ey * N : N + ey * N
3850
- ]
3851
-
3852
- arg = -(
3853
- curv[0, 0] * xx * xx
3854
- + (curv[0, 1] + curv[1, 0]) * xx * yy
3855
- + curv[1, 1] * yy * yy
3856
- )
3857
- argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
3853
+ R = np.array(
3854
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
3855
+ np.float64,
3856
+ )
3857
+ R_inv = np.array(
3858
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
3859
+ np.float64,
3860
+ )
3861
+ D = np.array([[1, 0], [0, slant * slant]])
3862
+ curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
3858
3863
 
3859
- gabi = np.exp(argi).sum((0, 1))
3860
- gab = np.exp(arg).sum((0, 1))
3864
+ gab = np.zeros((M, N), np.complex128)
3865
+ xx = np.empty((2, 2, M, N))
3866
+ yy = np.empty((2, 2, M, N))
3861
3867
 
3862
- norm_factor = 2 * np.pi * sigma * sigma / slant
3868
+ for ii, ex in enumerate([-1, 0]):
3869
+ for jj, ey in enumerate([-1, 0]):
3870
+ xx[ii, jj], yy[ii, jj] = np.mgrid[
3871
+ ex * M : M + ex * M, ey * N : N + ey * N
3872
+ ]
3863
3873
 
3864
- gab = gab / norm_factor
3874
+ arg = -(
3875
+ curv[0, 0] * xx * xx
3876
+ + (curv[0, 1] + curv[1, 0]) * xx * yy
3877
+ + curv[1, 1] * yy * yy
3878
+ )
3879
+ argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
3865
3880
 
3866
- gabi = gabi / norm_factor
3881
+ gabi = np.exp(argi).sum((0, 1))
3882
+ gab = np.exp(arg).sum((0, 1))
3867
3883
 
3868
- K = gabi.sum() / gab.sum()
3884
+ norm_factor = 2 * np.pi * sigma * sigma / slant
3869
3885
 
3870
- # Apply the Gaussian
3871
- filter[j, ell] = np.fft.fft2(gabi - K * gab)
3872
- filter[j, ell, 0, 0] = 0.0
3886
+ gab = gab / norm_factor
3873
3887
 
3874
- return self.backend.bk_cast(filter)
3888
+ gabi = gabi / norm_factor
3889
+
3890
+ K = gabi.sum() / gab.sum()
3891
+
3892
+ # Apply the Gaussian
3893
+ filter[j, ell] = np.fft.fft2(gabi - K * gab)
3894
+ filter[j, ell, 0, 0] = 0.0
3895
+
3896
+ return self.backend.bk_cast(filter)
3897
+ else:
3898
+ filter = np.zeros([J, L, M], dtype="complex64")
3899
+ #TODO
3900
+ print('filter for 1D not yet available')
3901
+ exit(0)
3902
+ slant = 4.0 / L
3903
+
3904
+ for j in range(J):
3905
+
3906
+ for ell in range(L):
3907
+
3908
+ theta = (int(L - L / 2 - 1) - ell) * np.pi / L
3909
+ sigma = 0.8 * 2**j
3910
+ xi = 3.0 / 4.0 * np.pi / 2**j
3911
+
3912
+ R = np.array(
3913
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
3914
+ np.float64,
3915
+ )
3916
+ R_inv = np.array(
3917
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
3918
+ np.float64,
3919
+ )
3920
+ D = np.array([[1, 0], [0, slant * slant]])
3921
+ curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
3922
+
3923
+ gab = np.zeros((M), np.complex128)
3924
+ xx = np.empty((M))
3925
+
3926
+ for ii, ex in enumerate([-1, 0]):
3927
+ for jj, ey in enumerate([-1, 0]):
3928
+ xx[ii, jj], yy[ii, jj] = np.mgrid[
3929
+ ex * M : M + ex * M, ey * N : N + ey * N
3930
+ ]
3931
+
3932
+ arg = -(
3933
+ curv[0, 0] * xx * xx
3934
+ + (curv[0, 1] + curv[1, 0]) * xx * yy
3935
+ + curv[1, 1] * yy * yy
3936
+ )
3937
+ argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
3938
+
3939
+ gabi = np.exp(argi).sum((0, 1))
3940
+ gab = np.exp(arg).sum((0, 1))
3941
+
3942
+ norm_factor = 2 * np.pi * sigma * sigma / slant
3943
+
3944
+ gab = gab / norm_factor
3945
+
3946
+ gabi = gabi / norm_factor
3947
+
3948
+ K = gabi.sum() / gab.sum()
3949
+
3950
+ # Apply the Gaussian
3951
+ filter[j, ell] = np.fft.fft2(gabi - K * gab)
3952
+ filter[j, ell, 0, 0] = 0.0
3953
+
3954
+ return self.backend.bk_cast(filter)
3875
3955
 
3876
3956
  # ------------------------------------------------------------------------------------------
3877
3957
  #
@@ -4107,18 +4187,24 @@ class funct(FOC.FoCUS):
4107
4187
  if data2 is not None:
4108
4188
  N_image2 = data2.shape[0]
4109
4189
  J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4190
+ dim=(-2,-1)
4110
4191
  elif self.use_1D:
4111
4192
  if len(data.shape) == 2:
4112
4193
  npix = int(im_shape[1]) # Number of pixels
4194
+ M = im_shape[1]
4195
+ N=0
4113
4196
  N_image = 1
4114
4197
  N_image2 = 1
4115
4198
  else:
4116
4199
  npix = int(im_shape[0]) # Number of pixels
4117
4200
  N_image = data.shape[0]
4201
+ M = im_shape[0]
4202
+ N=0
4118
4203
  if data2 is not None:
4119
4204
  N_image2 = data2.shape[0]
4120
4205
 
4121
4206
  nside = int(npix)
4207
+ dim=(-1)
4122
4208
 
4123
4209
  J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4124
4210
  else:
@@ -4169,10 +4255,10 @@ class funct(FOC.FoCUS):
4169
4255
 
4170
4256
  # convert numpy array input into self.backend.bk_ tensors
4171
4257
  data = self.backend.bk_cast(data)
4172
- data_f = self.backend.bk_fftn(data, dim=(-2, -1))
4258
+ data_f = self.backend.bk_fftn(data, dim=dim)
4173
4259
  if data2 is not None:
4174
4260
  data2 = self.backend.bk_cast(data2)
4175
- data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
4261
+ data2_f = self.backend.bk_fftn(data2, dim=dim)
4176
4262
 
4177
4263
  # initialize tensors for scattering coefficients
4178
4264
  S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
@@ -4248,13 +4334,13 @@ class funct(FOC.FoCUS):
4248
4334
  I1 = self.backend.bk_ifftn(
4249
4335
  data_f[None, None, None, :, :]
4250
4336
  * filters_set[None, :J, :, :, :],
4251
- dim=(-2, -1),
4337
+ dim=dim,
4252
4338
  ).abs()
4253
4339
  else:
4254
4340
  I1 = self.backend.bk_ifftn(
4255
4341
  data_f[:, None, None, :, :]
4256
4342
  * filters_set[None, :J, :, :, :],
4257
- dim=(-2, -1),
4343
+ dim=dim,
4258
4344
  ).abs()
4259
4345
  elif self.use_1D:
4260
4346
  if len(data.shape) == 1:
@@ -4270,12 +4356,12 @@ class funct(FOC.FoCUS):
4270
4356
  else:
4271
4357
  print("todo")
4272
4358
 
4273
- S2 = (I1**2 * edge_mask).mean((-2, -1))
4274
- S1 = (I1 * edge_mask).mean((-2, -1))
4359
+ S2 = (I1**2 * edge_mask).mean(dim)
4360
+ S1 = (I1 * edge_mask).mean(dim)
4275
4361
 
4276
4362
  if get_variance:
4277
- S2_sigma = (I1**2 * edge_mask).std((-2, -1))
4278
- S1_sigma = (I1 * edge_mask).std((-2, -1))
4363
+ S2_sigma = (I1**2 * edge_mask).std(dim)
4364
+ S1_sigma = (I1 * edge_mask).std(dim)
4279
4365
 
4280
4366
  else:
4281
4367
  if self.use_2D:
@@ -4283,64 +4369,64 @@ class funct(FOC.FoCUS):
4283
4369
  I1 = self.backend.bk_ifftn(
4284
4370
  data_f[None, None, None, :, :]
4285
4371
  * filters_set[None, :J, :, :, :],
4286
- dim=(-2, -1),
4372
+ dim=dim,
4287
4373
  )
4288
4374
  I2 = self.backend.bk_ifftn(
4289
4375
  data2_f[None, None, None, :, :]
4290
4376
  * filters_set[None, :J, :, :, :],
4291
- dim=(-2, -1),
4377
+ dim=dim,
4292
4378
  )
4293
4379
  else:
4294
4380
  I1 = self.backend.bk_ifftn(
4295
4381
  data_f[:, None, None, :, :]
4296
4382
  * filters_set[None, :J, :, :, :],
4297
- dim=(-2, -1),
4383
+ dim=dim,
4298
4384
  )
4299
4385
  I2 = self.backend.bk_ifftn(
4300
4386
  data2_f[:, None, None, :, :]
4301
4387
  * filters_set[None, :J, :, :, :],
4302
- dim=(-2, -1),
4388
+ dim=dim,
4303
4389
  )
4304
4390
  elif self.use_1D:
4305
4391
  if len(data.shape) == 1:
4306
4392
  I1 = self.backend.bk_ifftn(
4307
4393
  data_f[None, None, None, :] * filters_set[None, :J, :, :],
4308
- dim=(-1),
4394
+ dim=dim,
4309
4395
  )
4310
4396
  I2 = self.backend.bk_ifftn(
4311
4397
  data2_f[None, None, None, :] * filters_set[None, :J, :, :],
4312
- dim=(-1),
4398
+ dim=dim,
4313
4399
  )
4314
4400
  else:
4315
4401
  I1 = self.backend.bk_ifftn(
4316
4402
  data_f[:, None, None, :] * filters_set[None, :J, :, :],
4317
- dim=(-1),
4403
+ dim=dim,
4318
4404
  )
4319
4405
  I2 = self.backend.bk_ifftn(
4320
4406
  data2_f[:, None, None, :] * filters_set[None, :J, :, :],
4321
- dim=(-1),
4407
+ dim=dim,
4322
4408
  )
4323
4409
  else:
4324
4410
  print("todo")
4325
4411
 
4326
4412
  I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
4327
4413
 
4328
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
4414
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
4329
4415
  if get_variance:
4330
4416
  S2_sigma = self.backend.bk_reduce_std(
4331
- (I1 * edge_mask), axis=(-2, -1)
4417
+ (I1 * edge_mask), axis=dim
4332
4418
  )
4333
4419
 
4334
4420
  I1 = self.backend.bk_L1(I1)
4335
4421
 
4336
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
4422
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
4337
4423
 
4338
4424
  if get_variance:
4339
4425
  S1_sigma = self.backend.bk_reduce_std(
4340
- (I1 * edge_mask), axis=(-2, -1)
4426
+ (I1 * edge_mask), axis=dim
4341
4427
  )
4342
4428
 
4343
- I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4429
+ I1_f = self.backend.bk_fftn(I1, dim=dim)
4344
4430
 
4345
4431
  if pseudo_coef != 1:
4346
4432
  I1 = I1**pseudo_coef
@@ -4362,14 +4448,14 @@ class funct(FOC.FoCUS):
4362
4448
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
4363
4449
  if edge:
4364
4450
  I1_small = self.backend.bk_ifftn(
4365
- I1_f_small, dim=(-2, -1), norm="ortho"
4451
+ I1_f_small, dim=dim, norm="ortho"
4366
4452
  )
4367
4453
  data_small = self.backend.bk_ifftn(
4368
- data_f_small, dim=(-2, -1), norm="ortho"
4454
+ data_f_small, dim=dim, norm="ortho"
4369
4455
  )
4370
4456
  if data2 is not None:
4371
4457
  data2_small = self.backend.bk_ifftn(
4372
- data2_f_small, dim=(-2, -1), norm="ortho"
4458
+ data2_f_small, dim=dim, norm="ortho"
4373
4459
  )
4374
4460
 
4375
4461
  wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
@@ -4400,10 +4486,10 @@ class funct(FOC.FoCUS):
4400
4486
  ) * wavelet_f3_squared.view(1, 1, L, M3, N3)
4401
4487
  if edge:
4402
4488
  I12_w3_small = self.backend.bk_ifftn(
4403
- I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
4489
+ I1_f2_wf3_small, dim=dim, norm="ortho"
4404
4490
  )
4405
4491
  I12_w3_2_small = self.backend.bk_ifftn(
4406
- I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
4492
+ I1_f2_wf3_2_small, dim=dim, norm="ortho"
4407
4493
  )
4408
4494
  if use_ref:
4409
4495
  if normalization == "P11":
@@ -4420,7 +4506,7 @@ class funct(FOC.FoCUS):
4420
4506
  if normalization == "P11":
4421
4507
  # [N_image,l2,l3,x,y]
4422
4508
  P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
4423
- (-2, -1)
4509
+ dim
4424
4510
  ) * fft_factor
4425
4511
  norm_factor_S3 = (
4426
4512
  S2[:, None, j3, :] * P11_temp**pseudo_coef
@@ -4435,7 +4521,7 @@ class funct(FOC.FoCUS):
4435
4521
  (
4436
4522
  data_f_small.view(N_image, 1, 1, M3, N3)
4437
4523
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4438
- ).mean((-2, -1))
4524
+ ).mean(dim)
4439
4525
  * fft_factor
4440
4526
  / norm_factor_S3
4441
4527
  )
@@ -4445,7 +4531,7 @@ class funct(FOC.FoCUS):
4445
4531
  (
4446
4532
  data_f_small.view(N_image, 1, 1, M3, N3)
4447
4533
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4448
- ).std((-2, -1))
4534
+ ).std(dim)
4449
4535
  * fft_factor
4450
4536
  / norm_factor_S3
4451
4537
  )
@@ -4456,7 +4542,7 @@ class funct(FOC.FoCUS):
4456
4542
  * self.backend.bk_conjugate(I12_w3_small)
4457
4543
  * edge_mask[None, None, None, :, :]
4458
4544
  ).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
4459
- (-2, -1)
4545
+ dim
4460
4546
  )
4461
4547
  * fft_factor
4462
4548
  / norm_factor_S3
@@ -4467,7 +4553,7 @@ class funct(FOC.FoCUS):
4467
4553
  data_small.view(N_image, 1, 1, M3, N3)
4468
4554
  * self.backend.bk_conjugate(I12_w3_small)
4469
4555
  * edge_mask[None, None, None, :, :]
4470
- ).std((-2, -1))
4556
+ ).std(dim)
4471
4557
  * fft_factor
4472
4558
  / norm_factor_S3
4473
4559
  )
@@ -4477,7 +4563,7 @@ class funct(FOC.FoCUS):
4477
4563
  (
4478
4564
  data2_f_small.view(N_image2, 1, 1, M3, N3)
4479
4565
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4480
- ).mean((-2, -1))
4566
+ ).mean(dim)
4481
4567
  * fft_factor
4482
4568
  / norm_factor_S3
4483
4569
  )
@@ -4487,7 +4573,7 @@ class funct(FOC.FoCUS):
4487
4573
  (
4488
4574
  data2_f_small.view(N_image2, 1, 1, M3, N3)
4489
4575
  * self.backend.bk_conjugate(I1_f2_wf3_small)
4490
- ).std((-2, -1))
4576
+ ).std(dim)
4491
4577
  * fft_factor
4492
4578
  / norm_factor_S3
4493
4579
  )
@@ -4497,7 +4583,7 @@ class funct(FOC.FoCUS):
4497
4583
  data2_small.view(N_image2, 1, 1, M3, N3)
4498
4584
  * self.backend.bk_conjugate(I12_w3_small)
4499
4585
  * edge_mask[None, None, None, :, :]
4500
- ).mean((-2, -1))
4586
+ ).mean(dim)
4501
4587
  * fft_factor
4502
4588
  / norm_factor_S3
4503
4589
  )
@@ -4507,7 +4593,7 @@ class funct(FOC.FoCUS):
4507
4593
  data2_small.view(N_image2, 1, 1, M3, N3)
4508
4594
  * self.backend.bk_conjugate(I12_w3_small)
4509
4595
  * edge_mask[None, None, None, :, :]
4510
- ).std((-2, -1))
4596
+ ).std(dim)
4511
4597
  * fft_factor
4512
4598
  / norm_factor_S3
4513
4599
  )
@@ -4528,7 +4614,7 @@ class funct(FOC.FoCUS):
4528
4614
  N_image, 1, L, L, M3, N3
4529
4615
  )
4530
4616
  )
4531
- ).mean((-2, -1)) * fft_factor
4617
+ ).mean(dim) * fft_factor
4532
4618
  if get_variance:
4533
4619
  S4_sigma[:, Ndata_S4, :, :, :] = (
4534
4620
  I1_f_small[:, j1].view(
@@ -4539,7 +4625,7 @@ class funct(FOC.FoCUS):
4539
4625
  N_image, 1, L, L, M3, N3
4540
4626
  )
4541
4627
  )
4542
- ).std((-2, -1)) * fft_factor
4628
+ ).std(dim) * fft_factor
4543
4629
  else:
4544
4630
  for l1 in range(L):
4545
4631
  # [N_image,l2,l3,x,y]
@@ -4552,7 +4638,7 @@ class funct(FOC.FoCUS):
4552
4638
  N_image, L, L, M3, N3
4553
4639
  )
4554
4640
  )
4555
- ).mean((-2, -1)) * fft_factor
4641
+ ).mean(dim) * fft_factor
4556
4642
  if get_variance:
4557
4643
  S4_sigma[:, Ndata_S4, l1, :, :] = (
4558
4644
  I1_f_small[:, j1, l1].view(
@@ -4563,7 +4649,7 @@ class funct(FOC.FoCUS):
4563
4649
  N_image, L, L, M3, N3
4564
4650
  )
4565
4651
  )
4566
- ).std((-2, -1)) * fft_factor
4652
+ ).std(dim) * fft_factor
4567
4653
  else:
4568
4654
  if not if_large_batch:
4569
4655
  # [N_image,l1,l2,l3,x,y]
@@ -4577,7 +4663,7 @@ class funct(FOC.FoCUS):
4577
4663
  )
4578
4664
  )
4579
4665
  * edge_mask[None, None, None, None, :, :]
4580
- ).mean((-2, -1)) * fft_factor
4666
+ ).mean(dim) * fft_factor
4581
4667
  if get_variance:
4582
4668
  S4_sigma[:, Ndata_S4, :, :, :] = (
4583
4669
  I1_small[:, j1].view(
@@ -4591,7 +4677,7 @@ class funct(FOC.FoCUS):
4591
4677
  * edge_mask[
4592
4678
  None, None, None, None, :, :
4593
4679
  ]
4594
- ).std((-2, -1)) * fft_factor
4680
+ ).std(dim) * fft_factor
4595
4681
  else:
4596
4682
  for l1 in range(L):
4597
4683
  # [N_image,l2,l3,x,y]
@@ -4607,7 +4693,7 @@ class funct(FOC.FoCUS):
4607
4693
  * edge_mask[
4608
4694
  None, None, None, None, :, :
4609
4695
  ]
4610
- ).mean((-2, -1)) * fft_factor
4696
+ ).mean(dim) * fft_factor
4611
4697
  if get_variance:
4612
4698
  S4_sigma[:, Ndata_S4, l1, :, :] = (
4613
4699
  I1_small[:, j1].view(
@@ -4621,7 +4707,7 @@ class funct(FOC.FoCUS):
4621
4707
  * edge_mask[
4622
4708
  None, None, None, None, :, :
4623
4709
  ]
4624
- ).std((-2, -1)) * fft_factor
4710
+ ).std(dim) * fft_factor
4625
4711
 
4626
4712
  Ndata_S4 += 1
4627
4713
 
@@ -4689,11 +4775,11 @@ class funct(FOC.FoCUS):
4689
4775
  std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
4690
4776
 
4691
4777
  if data2 is None:
4692
- mean_data[:, 0] = data.mean((-2, -1))
4693
- std_data[:, 0] = data.std((-2, -1))
4778
+ mean_data[:, 0] = data.mean(dim)
4779
+ std_data[:, 0] = data.std(dim)
4694
4780
  else:
4695
- mean_data[:, 0] = (data2 * data).mean((-2, -1))
4696
- std_data[:, 0] = (data2 * data).std((-2, -1))
4781
+ mean_data[:, 0] = (data2 * data).mean(dim)
4782
+ std_data[:, 0] = (data2 * data).std(dim)
4697
4783
 
4698
4784
  if get_variance:
4699
4785
  ref_sigma = {}
@@ -4913,10 +4999,10 @@ class funct(FOC.FoCUS):
4913
4999
 
4914
5000
  # convert numpy array input into self.backend.bk_ tensors
4915
5001
  data = self.backend.bk_cast(data)
4916
- data_f = self.backend.bk_fftn(data, dim=(-2, -1))
5002
+ data_f = self.backend.bk_fftn(data, dim=dim)
4917
5003
  if data2 is not None:
4918
5004
  data2 = self.backend.bk_cast(data2)
4919
- data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
5005
+ data2_f = self.backend.bk_fftn(data2, dim=dim)
4920
5006
 
4921
5007
  # initialize tensors for scattering coefficients
4922
5008
 
@@ -4967,7 +5053,7 @@ class funct(FOC.FoCUS):
4967
5053
  self.backend.bk_ifftn(
4968
5054
  data_f[None, None, None, :, :]
4969
5055
  * filters_set[None, :J, :, :, :],
4970
- dim=(-2, -1),
5056
+ dim=dim,
4971
5057
  )
4972
5058
  )
4973
5059
  else:
@@ -4975,7 +5061,7 @@ class funct(FOC.FoCUS):
4975
5061
  self.backend.bk_ifftn(
4976
5062
  data_f[:, None, None, :, :]
4977
5063
  * filters_set[None, :J, :, :, :],
4978
- dim=(-2, -1),
5064
+ dim=dim,
4979
5065
  )
4980
5066
  )
4981
5067
  elif self.use_1D:
@@ -4996,37 +5082,37 @@ class funct(FOC.FoCUS):
4996
5082
  else:
4997
5083
  print("todo")
4998
5084
 
4999
- S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=(-2, -1))
5000
- S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=(-2, -1))
5085
+ S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=dim)
5086
+ S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=dim)
5001
5087
 
5002
5088
  if get_variance:
5003
5089
  S2_sigma = self.backend.bk_reduce_std(
5004
- (I1**2 * edge_mask), axis=(-2, -1)
5090
+ (I1**2 * edge_mask), axis=dim
5005
5091
  )
5006
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
5092
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
5007
5093
 
5008
- I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
5094
+ I1_f = self.backend.bk_fftn(I1, dim=dim)
5009
5095
 
5010
5096
  else:
5011
5097
  if self.use_2D:
5012
5098
  if len(data.shape) == 2:
5013
5099
  I1 = self.backend.bk_ifftn(
5014
5100
  data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
5015
- dim=(-2, -1),
5101
+ dim=dim,
5016
5102
  )
5017
5103
  I2 = self.backend.bk_ifftn(
5018
5104
  data2_f[None, None, None, :, :]
5019
5105
  * filters_set[None, :J, :, :, :],
5020
- dim=(-2, -1),
5106
+ dim=dim,
5021
5107
  )
5022
5108
  else:
5023
5109
  I1 = self.backend.bk_ifftn(
5024
5110
  data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
5025
- dim=(-2, -1),
5111
+ dim=dim,
5026
5112
  )
5027
5113
  I2 = self.backend.bk_ifftn(
5028
5114
  data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
5029
- dim=(-2, -1),
5115
+ dim=dim,
5030
5116
  )
5031
5117
  elif self.use_1D:
5032
5118
  if len(data.shape) == 1:
@@ -5051,18 +5137,18 @@ class funct(FOC.FoCUS):
5051
5137
 
5052
5138
  I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
5053
5139
 
5054
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5140
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
5055
5141
  if get_variance:
5056
- S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
5142
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
5057
5143
 
5058
5144
  I1 = self.backend.bk_L1(I1)
5059
5145
 
5060
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5146
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=dim)
5061
5147
 
5062
5148
  if get_variance:
5063
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
5149
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=dim)
5064
5150
 
5065
- I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
5151
+ I1_f = self.backend.bk_fftn(I1, dim=dim)
5066
5152
 
5067
5153
  if pseudo_coef != 1:
5068
5154
  I1 = I1**pseudo_coef
@@ -5083,13 +5169,13 @@ class funct(FOC.FoCUS):
5083
5169
  if data2 is not None:
5084
5170
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5085
5171
  if edge:
5086
- I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2, -1), norm="ortho")
5172
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=dim, norm="ortho")
5087
5173
  data_small = self.backend.bk_ifftn(
5088
- data_f_small, dim=(-2, -1), norm="ortho"
5174
+ data_f_small, dim=dim, norm="ortho"
5089
5175
  )
5090
5176
  if data2 is not None:
5091
5177
  data2_small = self.backend.bk_ifftn(
5092
- data2_f_small, dim=(-2, -1), norm="ortho"
5178
+ data2_f_small, dim=dim, norm="ortho"
5093
5179
  )
5094
5180
  wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5095
5181
  _, M3, N3 = wavelet_f3.shape
@@ -5112,10 +5198,10 @@ class funct(FOC.FoCUS):
5112
5198
  ) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
5113
5199
  if edge:
5114
5200
  I12_w3_small = self.backend.bk_ifftn(
5115
- I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
5201
+ I1_f2_wf3_small, dim=dim, norm="ortho"
5116
5202
  )
5117
5203
  I12_w3_2_small = self.backend.bk_ifftn(
5118
- I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
5204
+ I1_f2_wf3_2_small, dim=dim, norm="ortho"
5119
5205
  )
5120
5206
  if use_ref:
5121
5207
  if normalization == "P11":
@@ -5141,7 +5227,7 @@ class funct(FOC.FoCUS):
5141
5227
  # [N_image,l2,l3,x,y]
5142
5228
  P11_temp = (
5143
5229
  self.backend.bk_reduce_mean(
5144
- (I1_f2_wf3_small.abs() ** 2), axis=(-2, -1)
5230
+ (I1_f2_wf3_small.abs() ** 2), axis=dim
5145
5231
  )
5146
5232
  * fft_factor
5147
5233
  )
@@ -5169,7 +5255,7 @@ class funct(FOC.FoCUS):
5169
5255
  data_f_small, [N_image, 1, 1, 1, M3, N3]
5170
5256
  )
5171
5257
  * self.backend.bk_conjugate(I1_f2_wf3_small),
5172
- axis=(-2, -1),
5258
+ axis=dim,
5173
5259
  )
5174
5260
  * fft_factor
5175
5261
  / norm_factor_S3
@@ -5181,7 +5267,7 @@ class funct(FOC.FoCUS):
5181
5267
  data_f_small, [N_image, 1, 1, 1, M3, N3]
5182
5268
  )
5183
5269
  * self.backend.bk_conjugate(I1_f2_wf3_small),
5184
- axis=(-2, -1),
5270
+ axis=dim,
5185
5271
  )
5186
5272
  * fft_factor
5187
5273
  / norm_factor_S3
@@ -5195,7 +5281,7 @@ class funct(FOC.FoCUS):
5195
5281
  )
5196
5282
  * self.backend.bk_conjugate(I12_w3_small)
5197
5283
  )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5198
- axis=(-2, -1),
5284
+ axis=dim,
5199
5285
  )
5200
5286
  * fft_factor
5201
5287
  / norm_factor_S3
@@ -5209,7 +5295,7 @@ class funct(FOC.FoCUS):
5209
5295
  )
5210
5296
  * self.backend.bk_conjugate(I12_w3_small)
5211
5297
  )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5212
- axis=(-2, -1),
5298
+ axis=dim,
5213
5299
  )
5214
5300
  * fft_factor
5215
5301
  / norm_factor_S3
@@ -5224,7 +5310,7 @@ class funct(FOC.FoCUS):
5224
5310
  )
5225
5311
  * self.backend.bk_conjugate(I1_f2_wf3_small)
5226
5312
  ),
5227
- axis=(-2, -1),
5313
+ axis=dim,
5228
5314
  )
5229
5315
  * fft_factor
5230
5316
  / norm_factor_S3
@@ -5239,7 +5325,7 @@ class funct(FOC.FoCUS):
5239
5325
  )
5240
5326
  * self.backend.bk_conjugate(I1_f2_wf3_small)
5241
5327
  ),
5242
- axis=(-2, -1),
5328
+ axis=dim,
5243
5329
  )
5244
5330
  * fft_factor
5245
5331
  / norm_factor_S3
@@ -5254,7 +5340,7 @@ class funct(FOC.FoCUS):
5254
5340
  )
5255
5341
  * self.backend.bk_conjugate(I12_w3_small)
5256
5342
  )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5257
- axis=(-2, -1),
5343
+ axis=dim,
5258
5344
  )
5259
5345
  * fft_factor
5260
5346
  / norm_factor_S3
@@ -5272,7 +5358,7 @@ class funct(FOC.FoCUS):
5272
5358
  edge_dx : M3 - edge_dx,
5273
5359
  edge_dy : N3 - edge_dy,
5274
5360
  ],
5275
- axis=(-2, -1),
5361
+ axis=dim,
5276
5362
  )
5277
5363
  * fft_factor
5278
5364
  / norm_factor_S3
@@ -5318,7 +5404,7 @@ class funct(FOC.FoCUS):
5318
5404
  )
5319
5405
  )
5320
5406
  ),
5321
- axis=(-2, -1),
5407
+ axis=dim,
5322
5408
  )
5323
5409
  * fft_factor
5324
5410
  * P
@@ -5338,7 +5424,7 @@ class funct(FOC.FoCUS):
5338
5424
  )
5339
5425
  )
5340
5426
  ),
5341
- axis=(-2, -1),
5427
+ axis=dim,
5342
5428
  )
5343
5429
  * fft_factor
5344
5430
  * P
@@ -5360,7 +5446,7 @@ class funct(FOC.FoCUS):
5360
5446
  )
5361
5447
  )
5362
5448
  ),
5363
- axis=(-2, -1),
5449
+ axis=dim,
5364
5450
  )
5365
5451
  * fft_factor
5366
5452
  * P
@@ -5380,7 +5466,7 @@ class funct(FOC.FoCUS):
5380
5466
  )
5381
5467
  )
5382
5468
  ),
5383
- axis=(-2, -1),
5469
+ axis=dim,
5384
5470
  )
5385
5471
  * fft_factor
5386
5472
  * P
@@ -5402,7 +5488,7 @@ class funct(FOC.FoCUS):
5402
5488
  )
5403
5489
  )
5404
5490
  )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5405
- axis=(-2, -1),
5491
+ axis=dim,
5406
5492
  )
5407
5493
  * fft_factor
5408
5494
  * P
@@ -5422,7 +5508,7 @@ class funct(FOC.FoCUS):
5422
5508
  )
5423
5509
  )
5424
5510
  )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5425
- axis=(-2, -1),
5511
+ axis=dim,
5426
5512
  )
5427
5513
  * fft_factor
5428
5514
  * P
@@ -5444,7 +5530,7 @@ class funct(FOC.FoCUS):
5444
5530
  )
5445
5531
  )
5446
5532
  )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5447
- axis=(-2, -1),
5533
+ axis=dim,
5448
5534
  )
5449
5535
  * fft_factor
5450
5536
  * P
@@ -5468,7 +5554,7 @@ class funct(FOC.FoCUS):
5468
5554
  edge_dx:-edge_dx,
5469
5555
  edge_dy:-edge_dy,
5470
5556
  ],
5471
- axis=(-2, -1),
5557
+ axis=dim,
5472
5558
  )
5473
5559
  * fft_factor
5474
5560
  * P
@@ -5521,17 +5607,17 @@ class funct(FOC.FoCUS):
5521
5607
 
5522
5608
  if data2 is None:
5523
5609
  mean_data = self.backend.bk_reshape(
5524
- self.backend.bk_reduce_mean(data, axis=(-2, -1)), [N_image, 1]
5610
+ self.backend.bk_reduce_mean(data, axis=dim), [N_image, 1]
5525
5611
  )
5526
5612
  std_data = self.backend.bk_reshape(
5527
- self.backend.bk_reduce_std(data, axis=(-2, -1)), [N_image, 1]
5613
+ self.backend.bk_reduce_std(data, axis=dim), [N_image, 1]
5528
5614
  )
5529
5615
  else:
5530
5616
  mean_data = self.backend.bk_reshape(
5531
- self.backend.bk_reduce_mean(data * data2, axis=(-2, -1)), [N_image, 1]
5617
+ self.backend.bk_reduce_mean(data * data2, axis=dim), [N_image, 1]
5532
5618
  )
5533
5619
  std_data = self.backend.bk_reshape(
5534
- self.backend.bk_reduce_std(data * data2, axis=(-2, -1)), [N_image, 1]
5620
+ self.backend.bk_reduce_std(data * data2, axis=dim), [N_image, 1]
5535
5621
  )
5536
5622
 
5537
5623
  if get_variance:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: foscat
3
- Version: 2025.7.1
3
+ Version: 2025.7.3
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,12 +1,12 @@
1
1
  foscat/BkBase.py,sha256=2buIR9RK6g7HLoHJbzVCYhi1PkjDW6SXlu7IlF7SfA4,21611
2
2
  foscat/BkNumpy.py,sha256=qvKxDoAPQD52Ui9qv_D_GZvWpXX2n9S9dOGlXz5uNdQ,10683
3
3
  foscat/BkTensorflow.py,sha256=iIdLx6VTOfOEocfZBOGyizQn5geDLTfdWWAwDeQr9YA,20056
4
- foscat/BkTorch.py,sha256=bAF1zpuQ03y5YR6mEgB4XS9F0TCm9TmFA4ZFfhXikHM,18244
4
+ foscat/BkTorch.py,sha256=fWkNTrgK1MkpkS-bNVmC0ihJY_WlPs98ndperSh63i8,19593
5
5
  foscat/CNN.py,sha256=gQ9V76wmcowo2BaNp5sJYcSDCVOjc18TS9cE6-qEUso,5153
6
6
  foscat/CircSpline.py,sha256=CXi49FxF8ZoeZ17Ua8c1AZXe2B5ICEC9aCXb97atB3s,4028
7
- foscat/FoCUS.py,sha256=pcSHO3vlpC5zi6wzFNxFcytW6_mMBuATIzdkPNRsMQI,103903
7
+ foscat/FoCUS.py,sha256=81LmzcWitLGS0CrqddoRmER95JlGxX3N0dQAi1M7i-g,103645
8
8
  foscat/GCNN.py,sha256=q7yWHCMJpP7-m3WvR3OQnp5taeYWaMxIY2hQ6SIb9gs,4487
9
- foscat/HealSpline.py,sha256=DuVR_n0sAIUATmXEkPMqvFySdRMEFZhM8tIcHg2iKRY,7269
9
+ foscat/HealSpline.py,sha256=Y05LLtsAVdkzf_u6UZtQx8u1DwfBnmso9OzHAT35ZJM,7387
10
10
  foscat/Softmax.py,sha256=UDZGrTroYtmGEyokGUVpwNO_cgbICi9QVuRr8Yx52_k,2917
11
11
  foscat/Spline1D.py,sha256=rKzzenduaZZ-yBDJd35it6Gyrj1spqb7hoIaUgISPzY,2983
12
12
  foscat/Synthesis.py,sha256=tC5hvpam19QwDdvghVax7dA7gMgKA6ZtxQEcV9HjdC0,13824
@@ -20,13 +20,13 @@ foscat/loss_backend_torch.py,sha256=k3z18Dj3SaLKK6ZIKcm7GO4U_YKYVP6LtHG1aIbxkYk,
20
20
  foscat/scat.py,sha256=qGYiBIysPt65MdmF07WWA4piVlTfA9-lFDTaicnqC2w,72822
21
21
  foscat/scat1D.py,sha256=W5Uu6wdQ4ZsFKXpof0f1OBl-1wjJmW7ruvddRWxe7uM,53726
22
22
  foscat/scat2D.py,sha256=boKj0ASqMMSy7uQLK6hPniG87m3hZGJBYBiq5v8F9IQ,532
23
- foscat/scat_cov.py,sha256=fIUwQsq7v3l5AKz9nJZoXDWWxWm18dnlbVTLneeCVG8,267077
23
+ foscat/scat_cov.py,sha256=TgYa78os4vbFI4DRK5hccnH-6bpRpXFEbN52Z2L2Xbs,270032
24
24
  foscat/scat_cov1D.py,sha256=XOxsZZ5TYq8f34i2tUgIfzyaqaTDlICB3HzD2l_puro,531
25
25
  foscat/scat_cov2D.py,sha256=pAm0fKw8wyXram0TFbtw8tGcc8QPKuPXpQk0kh10r4U,7078
26
26
  foscat/scat_cov_map.py,sha256=9MzpwT2g9S3dmnjHEMK7PPLQ27oGQg2VFVsP_TDUU5E,2869
27
27
  foscat/scat_cov_map2D.py,sha256=1dS4P1KHqZYkYCLA1sYpPSZulJrCTd_2eL8HFOjlcz4,2841
28
- foscat-2025.7.1.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
29
- foscat-2025.7.1.dist-info/METADATA,sha256=43ThyYojmtu36vazeDUoHJQk4nylTDF9xF9AO7FC_bQ,7215
30
- foscat-2025.7.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
- foscat-2025.7.1.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
32
- foscat-2025.7.1.dist-info/RECORD,,
28
+ foscat-2025.7.3.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
29
+ foscat-2025.7.3.dist-info/METADATA,sha256=uIMQ9XWdbP1sN8ftu_8flj1OgiNhRXHcn6Rw-tbLoWs,7215
30
+ foscat-2025.7.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
31
+ foscat-2025.7.3.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
32
+ foscat-2025.7.3.dist-info/RECORD,,