foscat 3.7.2__py3-none-any.whl → 3.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -4791,17 +4791,37 @@ class funct(FOC.FoCUS):
4791
4791
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4792
4792
  Done by Sihao Cheng and Rudy Morel.
4793
4793
  '''
4794
- if_xodd = (data_f.shape[-2]%2==1)
4795
- if_yodd = (data_f.shape[-1]%2==1)
4796
- result = self.backend.backend.cat(
4797
- (self.backend.backend.cat(
4798
- ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4799
- ), -2),
4800
- self.backend.backend.cat(
4801
- ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4802
- ), -2)
4803
- ),-1)
4804
- return result
4794
+
4795
+ if self.backend.BACKEND=='torch':
4796
+ if_xodd = (data_f.shape[-2]%2==1)
4797
+ if_yodd = (data_f.shape[-1]%2==1)
4798
+ result = self.backend.backend.cat(
4799
+ (self.backend.backend.cat(
4800
+ ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4801
+ ), -2),
4802
+ self.backend.backend.cat(
4803
+ ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4804
+ ), -2)
4805
+ ),-1)
4806
+ return result
4807
+ else:
4808
+ # Check if the last two dimensions are odd
4809
+ if_xodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-2] % 2 == 1, self.backend.backend.int32)
4810
+ if_yodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-1] % 2 == 1, self.backend.backend.int32)
4811
+
4812
+ # Extract four regions
4813
+ top_left = data_f[..., :dx+if_xodd, :dy+if_yodd]
4814
+ top_right = data_f[..., -dx:, :dy+if_yodd]
4815
+ bottom_left = data_f[..., :dx+if_xodd, -dy:]
4816
+ bottom_right = data_f[..., -dx:, -dy:]
4817
+
4818
+ # Concatenate along the last two dimensions
4819
+ top = self.backend.backend.concat([top_left, top_right], axis=-2)
4820
+ bottom = self.backend.backend.concat([bottom_left, bottom_right], axis=-2)
4821
+ result = self.backend.backend.concat([top, bottom], axis=-1)
4822
+
4823
+ return result
4824
+
4805
4825
  # ---------------------------------------------------------------------------
4806
4826
  #
4807
4827
  # utility functions for computing scattering coef and covariance
@@ -4824,13 +4844,15 @@ class funct(FOC.FoCUS):
4824
4844
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4825
4845
  Done by Sihao Cheng and Rudy Morel.
4826
4846
  '''
4827
- edge_masks = self.backend.backend.empty((J, M, N))
4828
- X, Y = self.backend.backend.meshgrid(self.backend.backend.arange(M), self.backend.backend.arange(N), indexing='ij')
4847
+ edge_masks = np.empty((J, M, N))
4848
+ X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing='ij')
4829
4849
  for j in range(J):
4830
4850
  edge_dx = min(M//4, 2**j*d0)
4831
4851
  edge_dy = min(N//4, 2**j*d0)
4832
4852
  edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
4833
- return edge_masks.to(self.backend.torch_device)
4853
+ edge_masks = edge_masks[:,None,:,:]
4854
+ edge_masks = edge_masks / edge_masks.mean((-2,-1))[:,:,None,None]
4855
+ return self.backend.bk_cast(edge_masks)
4834
4856
 
4835
4857
  # ---------------------------------------------------------------------------
4836
4858
  #
@@ -4921,18 +4943,24 @@ class funct(FOC.FoCUS):
4921
4943
  nside = np.min([im_shape[0], im_shape[1]])
4922
4944
  M,N = im_shape[0],im_shape[1]
4923
4945
  N_image = 1
4946
+ N_image2 = 1
4924
4947
  else:
4925
4948
  nside = np.min([im_shape[1], im_shape[2]])
4926
4949
  M,N = im_shape[1],im_shape[2]
4927
4950
  N_image = data.shape[0]
4951
+ if data2 is not None:
4952
+ N_image2 = data2.shape[0]
4928
4953
  J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4929
4954
  elif self.use_1D:
4930
4955
  if len(data.shape) == 2:
4931
4956
  npix = int(im_shape[1]) # Number of pixels
4932
4957
  N_image = 1
4958
+ N_image2 = 1
4933
4959
  else:
4934
4960
  npix = int(im_shape[0]) # Number of pixels
4935
4961
  N_image = data.shape[0]
4962
+ if data2 is not None:
4963
+ N_image2 = data2.shape[0]
4936
4964
 
4937
4965
  nside = int(npix)
4938
4966
 
@@ -4941,9 +4969,12 @@ class funct(FOC.FoCUS):
4941
4969
  if len(data.shape) == 2:
4942
4970
  npix = int(im_shape[1]) # Number of pixels
4943
4971
  N_image = 1
4972
+ N_image2 = 1
4944
4973
  else:
4945
4974
  npix = int(im_shape[0]) # Number of pixels
4946
4975
  N_image = data.shape[0]
4976
+ if data2 is not None:
4977
+ N_image2 = data2.shape[0]
4947
4978
 
4948
4979
  nside = int(np.sqrt(npix // 12))
4949
4980
 
@@ -4957,7 +4988,475 @@ class funct(FOC.FoCUS):
4957
4988
  print('\n\n==========')
4958
4989
 
4959
4990
  L=self.NORIENT
4991
+ norm_factor_S3=1.0
4992
+
4993
+ if self.backend.BACKEND=='torch':
4994
+ if (M,N,J,L) not in self.filters_set:
4995
+ self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
4996
+
4997
+ filters_set = self.filters_set[(M,N,J,L)]
4998
+
4999
+ #weight = self.weight
5000
+ if use_ref:
5001
+ if normalization=='S2':
5002
+ ref_S2 = self.ref_scattering_cov_S2
5003
+ else:
5004
+ ref_P11 = self.ref_scattering_cov['P11']
5005
+
5006
+ # convert numpy array input into self.backend.bk_ tensors
5007
+ data = self.backend.bk_cast(data)
5008
+ data_f = self.backend.bk_fftn(data, dim=(-2,-1))
5009
+ if data2 is not None:
5010
+ data2 = self.backend.bk_cast(data2)
5011
+ data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
5012
+
5013
+ # initialize tensors for scattering coefficients
5014
+ S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5015
+ S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5016
+
5017
+ Ndata_S3 = J*(J+1)//2
5018
+ Ndata_S4 = J*(J+1)*(J+2)//6
5019
+ J_S4={}
5020
+
5021
+ S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5022
+ if data2 is not None:
5023
+ S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5024
+ S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5025
+ S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5026
+
5027
+ # variance
5028
+ if get_variance:
5029
+ S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5030
+ S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5031
+ S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5032
+ if data2 is not None:
5033
+ S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5034
+ S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5035
+
5036
+ if iso_ang:
5037
+ S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5038
+ S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5039
+ if get_variance:
5040
+ S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5041
+ S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5042
+ if data2 is not None:
5043
+ S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5044
+ if get_variance:
5045
+ S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5046
+
5047
+ #
5048
+ if edge:
5049
+ if (M,N,J) not in self.edge_masks:
5050
+ self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5051
+ edge_mask=self.edge_masks[(M,N,J)]
5052
+ else:
5053
+ edge_mask = 1
5054
+
5055
+ # calculate scattering fields
5056
+ if data2 is None:
5057
+ if self.use_2D:
5058
+ if len(data.shape) == 2:
5059
+ I1 = self.backend.bk_ifftn(
5060
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5061
+ ).abs()
5062
+ else:
5063
+ I1 = self.backend.bk_ifftn(
5064
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5065
+ ).abs()
5066
+ elif self.use_1D:
5067
+ if len(data.shape) == 1:
5068
+ I1 = self.backend.bk_ifftn(
5069
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5070
+ ).abs()
5071
+ else:
5072
+ I1 = self.backend.bk_ifftn(
5073
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5074
+ ).abs()
5075
+ else:
5076
+ print('todo')
5077
+
5078
+ S2 = (I1**2 * edge_mask).mean((-2,-1))
5079
+ S1 = (I1 * edge_mask).mean((-2,-1))
5080
+
5081
+ if get_variance:
5082
+ S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5083
+ S1_sigma = (I1 * edge_mask).std((-2,-1))
5084
+
5085
+ else:
5086
+ if self.use_2D:
5087
+ if len(data.shape) == 2:
5088
+ I1 = self.backend.bk_ifftn(
5089
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5090
+ )
5091
+ I2 = self.backend.bk_ifftn(
5092
+ data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5093
+ )
5094
+ else:
5095
+ I1 = self.backend.bk_ifftn(
5096
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5097
+ )
5098
+ I2 = self.backend.bk_ifftn(
5099
+ data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5100
+ )
5101
+ elif self.use_1D:
5102
+ if len(data.shape) == 1:
5103
+ I1 = self.backend.bk_ifftn(
5104
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5105
+ )
5106
+ I2 = self.backend.bk_ifftn(
5107
+ data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5108
+ )
5109
+ else:
5110
+ I1 = self.backend.bk_ifftn(
5111
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5112
+ )
5113
+ I2 = self.backend.bk_ifftn(
5114
+ data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5115
+ )
5116
+ else:
5117
+ print('todo')
4960
5118
 
5119
+ I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5120
+
5121
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5122
+ if get_variance:
5123
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5124
+
5125
+ I1=self.backend.bk_L1(I1)
5126
+
5127
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5128
+
5129
+ if get_variance:
5130
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5131
+
5132
+ I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5133
+
5134
+ if pseudo_coef != 1:
5135
+ I1 = I1**pseudo_coef
5136
+
5137
+ Ndata_S3=0
5138
+ Ndata_S4=0
5139
+
5140
+ # calculate the covariance and correlations of the scattering fields
5141
+ # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5142
+ for j3 in range(0,J):
5143
+ J_S4[j3]=Ndata_S4
5144
+
5145
+ dx3, dy3 = self.get_dxdy(j3,M,N)
5146
+ I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
5147
+ data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5148
+ if data2 is not None:
5149
+ data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5150
+ if edge:
5151
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5152
+ data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
5153
+ if data2 is not None:
5154
+ data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
5155
+
5156
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5157
+ _, M3, N3 = wavelet_f3.shape
5158
+ wavelet_f3_squared = wavelet_f3**2
5159
+ edge_dx = min(4, int(2**j3*dx3*2/M))
5160
+ edge_dy = min(4, int(2**j3*dy3*2/N))
5161
+
5162
+ # a normalization change due to the cutoff of frequency space
5163
+ fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5164
+ for j2 in range(0,j3+1):
5165
+ I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5166
+ I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5167
+ if edge:
5168
+ I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5169
+ I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
5170
+ if use_ref:
5171
+ if normalization=='P11':
5172
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5173
+ if normalization=='S2':
5174
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5175
+ else:
5176
+ if normalization=='P11':
5177
+ # [N_image,l2,l3,x,y]
5178
+ P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5179
+ norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5180
+ if normalization=='S2':
5181
+ norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
5182
+
5183
+ if not edge:
5184
+ S3[:,Ndata_S3,:,:] = (
5185
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5186
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5187
+
5188
+ if get_variance:
5189
+ S3_sigma[:,Ndata_S3,:,:] = (
5190
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5191
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5192
+ else:
5193
+ S3[:,Ndata_S3,:,:] = (
5194
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5195
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5196
+ if get_variance:
5197
+ S3_sigma[:,Ndata_S3,:,:] = (
5198
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5199
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5200
+ if data2 is not None:
5201
+ if not edge:
5202
+ S3p[:,Ndata_S3,:,:] = (
5203
+ data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5204
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5205
+
5206
+ if get_variance:
5207
+ S3p_sigma[:,Ndata_S3,:,:] = (
5208
+ data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5209
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5210
+ else:
5211
+ S3p[:,Ndata_S3,:,:] = (
5212
+ data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5213
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5214
+ if get_variance:
5215
+ S3p_sigma[:,Ndata_S3,:,:] = (
5216
+ data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5217
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5218
+ Ndata_S3+=1
5219
+ if j2 <= j3:
5220
+ beg_n=Ndata_S4
5221
+ for j1 in range(0, j2+1):
5222
+ if eval(S4_criteria):
5223
+ if not edge:
5224
+ if not if_large_batch:
5225
+ # [N_image,l1,l2,l3,x,y]
5226
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5227
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5228
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5229
+ ).mean((-2,-1)) * fft_factor
5230
+ if get_variance:
5231
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5232
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5233
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5234
+ ).std((-2,-1)) * fft_factor
5235
+ else:
5236
+ for l1 in range(L):
5237
+ # [N_image,l2,l3,x,y]
5238
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5239
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5240
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5241
+ ).mean((-2,-1)) * fft_factor
5242
+ if get_variance:
5243
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5244
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5245
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5246
+ ).std((-2,-1)) * fft_factor
5247
+ else:
5248
+ if not if_large_batch:
5249
+ # [N_image,l1,l2,l3,x,y]
5250
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5251
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5252
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5253
+ )
5254
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5255
+ if get_variance:
5256
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5257
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5258
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5259
+ )
5260
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
5261
+ else:
5262
+ for l1 in range(L):
5263
+ # [N_image,l2,l3,x,y]
5264
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5265
+ I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5266
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5267
+ )
5268
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5269
+ if get_variance:
5270
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5271
+ I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5272
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5273
+ )
5274
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5275
+
5276
+ Ndata_S4+=1
5277
+
5278
+ if normalization=='S2':
5279
+ if use_ref:
5280
+ P = ((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5281
+ else:
5282
+ P = ((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5283
+
5284
+ S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/P
5285
+
5286
+ if get_variance:
5287
+ S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:]/P
5288
+ else:
5289
+ S4=S4_pre_norm
5290
+
5291
+ # average over l1 to obtain simple isotropic statistics
5292
+ if iso_ang:
5293
+ S2_iso = S2.mean(-1)
5294
+ S1_iso = S1.mean(-1)
5295
+ for l1 in range(L):
5296
+ for l2 in range(L):
5297
+ S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
5298
+ if data2 is not None:
5299
+ S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
5300
+ for l3 in range(L):
5301
+ S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5302
+ S3_iso /= L; S4_iso /= L
5303
+ if data2 is not None:
5304
+ S3p_iso /= L
5305
+
5306
+ if get_variance:
5307
+ S2_sigma_iso = S2_sigma.mean(-1)
5308
+ S1_sigma_iso = S1_sigma.mean(-1)
5309
+ for l1 in range(L):
5310
+ for l2 in range(L):
5311
+ S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
5312
+ if data2 is not None:
5313
+ S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
5314
+ for l3 in range(L):
5315
+ S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5316
+ S3_sigma_iso /= L; S4_sigma_iso /= L
5317
+ if data2 is not None:
5318
+ S3p_sigma_iso /= L
5319
+
5320
+ mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5321
+ std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5322
+
5323
+ if data2 is None:
5324
+ mean_data[:,0]=data.mean((-2,-1))
5325
+ std_data[:,0]=data.std((-2,-1))
5326
+ else:
5327
+ mean_data[:,0]=(data2*data).mean((-2,-1))
5328
+ std_data[:,0]=(data2*data).std((-2,-1))
5329
+
5330
+ if get_variance:
5331
+ ref_sigma={}
5332
+ if iso_ang:
5333
+ ref_sigma['std_data']=std_data
5334
+ ref_sigma['S1_sigma']=S1_sigma_iso
5335
+ ref_sigma['S2_sigma']=S2_sigma_iso
5336
+ ref_sigma['S3_sigma']=S3_sigma_iso
5337
+ if data2 is not None:
5338
+ ref_sigma['S3p_sigma']=S3p_sigma_iso
5339
+ ref_sigma['S4_sigma']=S4_sigma_iso
5340
+ else:
5341
+ ref_sigma['std_data']=std_data
5342
+ ref_sigma['S1_sigma']=S1_sigma
5343
+ ref_sigma['S2_sigma']=S2_sigma
5344
+ ref_sigma['S3_sigma']=S3_sigma
5345
+ if data2 is not None:
5346
+ ref_sigma['S3p_sigma']=S3p_sigma
5347
+ ref_sigma['S4_sigma']=S4_sigma
5348
+
5349
+ if data2 is None:
5350
+ if iso_ang:
5351
+ if ref_sigma is not None:
5352
+ for_synthesis = self.backend.backend.cat((
5353
+ mean_data/ref_sigma['std_data'],
5354
+ std_data/ref_sigma['std_data'],
5355
+ (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5356
+ (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5357
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5358
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5359
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5360
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5361
+ ),dim=-1)
5362
+ else:
5363
+ for_synthesis = self.backend.backend.cat((
5364
+ mean_data/std_data,
5365
+ std_data,
5366
+ S2_iso.reshape((N_image, -1)).log(),
5367
+ S1_iso.reshape((N_image, -1)).log(),
5368
+ S3_iso.reshape((N_image, -1)).real,
5369
+ S3_iso.reshape((N_image, -1)).imag,
5370
+ S4_iso.reshape((N_image, -1)).real,
5371
+ S4_iso.reshape((N_image, -1)).imag,
5372
+ ),dim=-1)
5373
+ else:
5374
+ if ref_sigma is not None:
5375
+ for_synthesis = self.backend.backend.cat((
5376
+ mean_data/ref_sigma['std_data'],
5377
+ std_data/ref_sigma['std_data'],
5378
+ (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5379
+ (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5380
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5381
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5382
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5383
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5384
+ ),dim=-1)
5385
+ else:
5386
+ for_synthesis = self.backend.backend.cat((
5387
+ mean_data/std_data,
5388
+ std_data,
5389
+ S2.reshape((N_image, -1)).log(),
5390
+ S1.reshape((N_image, -1)).log(),
5391
+ S3.reshape((N_image, -1)).real,
5392
+ S3.reshape((N_image, -1)).imag,
5393
+ S4.reshape((N_image, -1)).real,
5394
+ S4.reshape((N_image, -1)).imag,
5395
+ ),dim=-1)
5396
+ else:
5397
+ if iso_ang:
5398
+ if ref_sigma is not None:
5399
+ for_synthesis = self.backend.backend.cat((
5400
+ mean_data/ref_sigma['std_data'],
5401
+ std_data/ref_sigma['std_data'],
5402
+ (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5403
+ (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5404
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5405
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5406
+ (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5407
+ (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5408
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5409
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5410
+ ),dim=-1)
5411
+ else:
5412
+ for_synthesis = self.backend.backend.cat((
5413
+ mean_data/std_data,
5414
+ std_data,
5415
+ S2_iso.reshape((N_image, -1)),
5416
+ S1_iso.reshape((N_image, -1)),
5417
+ S3_iso.reshape((N_image, -1)).real,
5418
+ S3_iso.reshape((N_image, -1)).imag,
5419
+ S3p_iso.reshape((N_image, -1)).real,
5420
+ S3p_iso.reshape((N_image, -1)).imag,
5421
+ S4_iso.reshape((N_image, -1)).real,
5422
+ S4_iso.reshape((N_image, -1)).imag,
5423
+ ),dim=-1)
5424
+ else:
5425
+ if ref_sigma is not None:
5426
+ for_synthesis = self.backend.backend.cat((
5427
+ mean_data/ref_sigma['std_data'],
5428
+ std_data/ref_sigma['std_data'],
5429
+ (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5430
+ (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5431
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5432
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5433
+ (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5434
+ (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5435
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5436
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5437
+ ),dim=-1)
5438
+ else:
5439
+ for_synthesis = self.backend.backend.cat((
5440
+ mean_data/std_data,
5441
+ std_data,
5442
+ S2.reshape((N_image, -1)),
5443
+ S1.reshape((N_image, -1)),
5444
+ S3.reshape((N_image, -1)).real,
5445
+ S3.reshape((N_image, -1)).imag,
5446
+ S3p.reshape((N_image, -1)).real,
5447
+ S3p.reshape((N_image, -1)).imag,
5448
+ S4.reshape((N_image, -1)).real,
5449
+ S4.reshape((N_image, -1)).imag,
5450
+ ),dim=-1)
5451
+
5452
+ if not use_ref:
5453
+ self.ref_scattering_cov_S2=S2
5454
+
5455
+ if get_variance:
5456
+ return for_synthesis,ref_sigma
5457
+
5458
+ return for_synthesis
5459
+
4961
5460
  if (M,N,J,L) not in self.filters_set:
4962
5461
  self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
4963
5462
 
@@ -4978,46 +5477,41 @@ class funct(FOC.FoCUS):
4978
5477
  data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
4979
5478
 
4980
5479
  # initialize tensors for scattering coefficients
4981
- S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4982
- S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4983
5480
 
4984
5481
  Ndata_S3 = J*(J+1)//2
4985
5482
  Ndata_S4 = J*(J+1)*(J+2)//6
4986
5483
  J_S4={}
4987
5484
 
4988
- S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5485
+ S3 = []
4989
5486
  if data2 is not None:
4990
- S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4991
- S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4992
- S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5487
+ S3p = []
5488
+ S4_pre_norm = []
5489
+ S4 = []
4993
5490
 
4994
5491
  # variance
4995
5492
  if get_variance:
4996
- S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4997
- S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4998
- S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5493
+ S3_sigma = []
4999
5494
  if data2 is not None:
5000
- S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5001
- S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5495
+ S3p_sigma = []
5496
+ S4_sigma = []
5002
5497
 
5003
5498
  if iso_ang:
5004
- S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5499
+ S3_iso = []
5005
5500
  if data2 is not None:
5006
- S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5501
+ S3p_iso = []
5007
5502
 
5008
- S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5503
+ S4_iso = []
5009
5504
  if get_variance:
5010
- S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5505
+ S3_sigma_iso = []
5011
5506
  if data2 is not None:
5012
- S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5013
- S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5507
+ S3p_sigma_iso = []
5508
+ S4_sigma_iso = []
5014
5509
 
5015
5510
  #
5016
5511
  if edge:
5017
5512
  if (M,N,J) not in self.edge_masks:
5018
5513
  self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5019
- edge_mask = self.edge_masks[(M,N,J)][:,None,:,:]
5020
- edge_mask = edge_mask / edge_mask.mean((-2,-1))[:,:,None,None]
5514
+ edge_mask = self.edge_masks[(M,N,J)]
5021
5515
  else:
5022
5516
  edge_mask = 1
5023
5517
 
@@ -5025,31 +5519,31 @@ class funct(FOC.FoCUS):
5025
5519
  if data2 is None:
5026
5520
  if self.use_2D:
5027
5521
  if len(data.shape) == 2:
5028
- I1 = self.backend.bk_ifftn(
5522
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5029
5523
  data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5030
- ).abs()
5524
+ ))
5031
5525
  else:
5032
- I1 = self.backend.bk_ifftn(
5526
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5033
5527
  data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5034
- ).abs()
5528
+ ))
5035
5529
  elif self.use_1D:
5036
5530
  if len(data.shape) == 1:
5037
- I1 = self.backend.bk_ifftn(
5531
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5038
5532
  data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5039
- ).abs()
5533
+ ))
5040
5534
  else:
5041
- I1 = self.backend.bk_ifftn(
5535
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5042
5536
  data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5043
- ).abs()
5537
+ ))
5044
5538
  else:
5045
5539
  print('todo')
5046
5540
 
5047
- S2 = (I1**2 * edge_mask).mean((-2,-1))
5048
- S1 = (I1 * edge_mask).mean((-2,-1))
5541
+ S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask),axis=(-2,-1))
5542
+ S1 = self.backend.bk_reduce_mean(I1 * edge_mask,axis=(-2,-1))
5049
5543
 
5050
5544
  if get_variance:
5051
- S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5052
- S1_sigma = (I1 * edge_mask).std((-2,-1))
5545
+ S2_sigma = self.backend.bk_reduce_std((I1**2 * edge_mask),axis=(-2,-1))
5546
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5053
5547
 
5054
5548
  I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5055
5549
 
@@ -5089,16 +5583,16 @@ class funct(FOC.FoCUS):
5089
5583
 
5090
5584
  I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5091
5585
 
5092
- S2 = (I1 * edge_mask).mean((-2,-1))
5586
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5093
5587
  if get_variance:
5094
- S2_sigma = (I1 * edge_mask).std((-2,-1))
5588
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5095
5589
 
5096
5590
  I1=self.backend.bk_L1(I1)
5097
5591
 
5098
- S1 = (I1 * edge_mask).mean((-2,-1))
5592
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5099
5593
 
5100
5594
  if get_variance:
5101
- S1_sigma = (I1 * edge_mask).std((-2,-1))
5595
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5102
5596
 
5103
5597
  I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5104
5598
 
@@ -5131,8 +5625,10 @@ class funct(FOC.FoCUS):
5131
5625
  # a normalization change due to the cutoff of frequency space
5132
5626
  fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5133
5627
  for j2 in range(0,j3+1):
5134
- I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5135
- I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5628
+ #I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5629
+ #I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5630
+ I1_f2_wf3_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3,[1,1,1,L,M3,N3])
5631
+ I1_f2_wf3_2_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3_squared,[1,1,1,L,M3,N3])
5136
5632
  if edge:
5137
5633
  I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5138
5634
  I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
@@ -5141,158 +5637,136 @@ class funct(FOC.FoCUS):
5141
5637
  norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5142
5638
  elif normalization=='S2':
5143
5639
  norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5144
- norm_factor_S3 = 1.0
5640
+ else:
5641
+ norm_factor_S3 = 1.0
5145
5642
  else:
5146
5643
  if normalization=='P11':
5147
5644
  # [N_image,l2,l3,x,y]
5148
- P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5645
+ P11_temp = self.backend.bk_reduce_mean((I1_f2_wf3_small.abs()**2),axis=(-2,-1)) * fft_factor
5149
5646
  norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5150
5647
  elif normalization=='S2':
5151
- norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
5152
- norm_factor_S3 = 1.0
5648
+ norm_factor_S3 = (S2[:,None,j3,None,:] * S2[:,None,j2,:,None]**pseudo_coef)**0.5
5649
+ else:
5650
+ norm_factor_S3 = 1.0
5651
+
5652
+ norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
5153
5653
 
5154
5654
  if not edge:
5155
- S3[:,Ndata_S3,:,:] = (
5156
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5157
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5158
-
5655
+ S3.append(self.backend.bk_reduce_mean(
5656
+ self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5657
+ ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5159
5658
  if get_variance:
5160
- S3_sigma[:,Ndata_S3,:,:] = (
5161
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5162
- ).std((-2,-1)) * fft_factor / norm_factor_S3
5659
+ S3_sigma.append(self.backend.bk_reduce_std(
5660
+ self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5661
+ ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5163
5662
  else:
5164
-
5165
- S3[:,Ndata_S3,:,:] = (
5166
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5167
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5663
+ S3.append(self.backend.bk_reduce_mean(
5664
+ (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5665
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5168
5666
  if get_variance:
5169
- S3_sigma[:,Ndata_S3,:,:] = (
5170
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5171
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5667
+ S3_sigma.apend(self.backend.bk_reduce_std(
5668
+ (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5669
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5172
5670
  if data2 is not None:
5173
5671
  if not edge:
5174
- S3p[:,Ndata_S3,:,:] = (
5175
- data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5176
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5672
+ S3p.append(self.backend.bk_reduce_mean(
5673
+ (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5674
+ ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5177
5675
 
5178
5676
  if get_variance:
5179
- S3p_sigma[:,Ndata_S3,:,:] = (
5180
- data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5181
- ).std((-2,-1)) * fft_factor / norm_factor_S3
5677
+ S3p_sigma.append(self.backend.bk_reduce_std(
5678
+ (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5679
+ ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5182
5680
  else:
5183
5681
 
5184
- S3p[:,Ndata_S3,:,:] = (
5185
- data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5186
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5682
+ S3p.append(self.backend.bk_reduce_mean(
5683
+ (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5684
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5187
5685
  if get_variance:
5188
- S3p_sigma[:,Ndata_S3,:,:] = (
5189
- data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5190
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5191
-
5192
- Ndata_S3+=1
5686
+ S3p_sigma.append(self.backend.bk_reduce_std(
5687
+ (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5688
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5689
+
5193
5690
  if j2 <= j3:
5194
- beg_n=Ndata_S4
5691
+ if normalization=='S2':
5692
+ if use_ref:
5693
+ P = 1/((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5694
+ else:
5695
+ P = 1/(((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef)))
5696
+ P=self.backend.bk_complex(P,0.0*P)
5697
+ else:
5698
+ P=self.backend.bk_complex(1.0,0.0)
5699
+
5195
5700
  for j1 in range(0, j2+1):
5196
- if eval(S4_criteria):
5197
5701
  if not edge:
5198
5702
  if not if_large_batch:
5199
5703
  # [N_image,l1,l2,l3,x,y]
5200
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5201
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5202
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5203
- ).mean((-2,-1)) * fft_factor
5704
+ S4.append(self.backend.bk_reduce_mean(
5705
+ (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5706
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5707
+ ),axis=(-2,-1)) * fft_factor*P)
5204
5708
  if get_variance:
5205
- S4_sigma[:,Ndata_S4,:,:,:] = (
5206
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5207
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5208
- ).std((-2,-1)) * fft_factor
5709
+ S4_sigma.append(self.backend.bk_reduce_std(
5710
+ (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5711
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5712
+ ),axis=(-2,-1)) * fft_factor*P)
5209
5713
  else:
5210
5714
  for l1 in range(L):
5211
5715
  # [N_image,l2,l3,x,y]
5212
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5213
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5214
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5215
- ).mean((-2,-1)) * fft_factor
5716
+ S4.append(self.backend.bk_reduce_mean(
5717
+ (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5718
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5719
+ ),axis=(-2,-1)) * fft_factor*P)
5216
5720
  if get_variance:
5217
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5218
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5219
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5220
- ).std((-2,-1)) * fft_factor
5721
+ S4_sigma.append(self.backend.bk_reduce_std(
5722
+ (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5723
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5724
+ ),axis=(-2,-1)) * fft_factor*P)
5221
5725
  else:
5222
5726
  if not if_large_batch:
5223
5727
  # [N_image,l1,l2,l3,x,y]
5224
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5225
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5226
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5728
+ S4.append(self.backend.bk_reduce_mean(
5729
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5730
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5227
5731
  )
5228
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5732
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5229
5733
  if get_variance:
5230
- S4_sigma[:,Ndata_S4,:,:,:] = (
5231
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5232
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5734
+ S4_sigma.append(self.backend.bk_reduce_std(
5735
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5736
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5233
5737
  )
5234
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
5738
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5235
5739
  else:
5236
5740
  for l1 in range(L):
5237
5741
  # [N_image,l2,l3,x,y]
5238
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5239
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5240
- I12_w3_2_small.view(N_image,L,L,M3,N3)
5742
+ S4.append(self.backend.bk_reduce_mean(
5743
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5744
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5241
5745
  )
5242
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5746
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5243
5747
  if get_variance:
5244
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5245
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5246
- I12_w3_2_small.view(N_image,L,L,M3,N3)
5748
+ S4_sigma.append(self.backend.bk_reduce_std(
5749
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5750
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5247
5751
  )
5248
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5752
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5249
5753
 
5250
- Ndata_S4+=1
5251
-
5252
- if normalization=='S2':
5253
- if use_ref:
5254
- P = (ref_S2[:,j3,:,None,None] * ref_S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
5255
- else:
5256
- P = ((S2[:,j3,:,None,None] * S2[:,j2,None,:,None] )**(0.5*pseudo_coef))
5257
- S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/(P.clone())
5258
-
5259
- if get_variance:
5260
- S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:] / (P)
5261
- else:
5262
- S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()
5263
-
5264
- if get_variance:
5265
- S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:]
5266
-
5267
- """
5268
- # define P11 from diagonals of S4
5269
- for j1 in range(J):
5270
- for l1 in range(L):
5271
- P11[:,j1,:,l1,:] = S4_pre_norm[:,j1,j1,:,l1,l1,:].real
5272
-
5273
-
5274
- if normalization=='S4':
5275
- if use_ref:
5276
- P = ref_P11
5277
- else:
5278
- P = P11
5279
- #.view(N_image,J,1,J,L,1,L) * .view(N_image,1,J,J,1,L,L)
5280
- S4 = S4_pre_norm / (
5281
- P[:,:,None,:,:,None,:] * P[:,None,:,:,None,:,:]
5282
- )**(0.5*pseudo_coef)
5283
-
5284
-
5754
+ S3=self.backend.bk_concat(S3,axis=1)
5755
+ S4=self.backend.bk_concat(S4,axis=1)
5285
5756
 
5757
+ if get_variance:
5758
+ S3_sigma=self.backend.bk_concat(S3_sigma,axis=1)
5759
+ S4_sigma=self.backend.bk_concat(S4_sigma,axis=1)
5760
+
5761
+ if data2 is not None:
5762
+ S3p=self.backend.bk_concat(S3p,axis=1)
5763
+ if get_variance:
5764
+ S3p_sigma=self.backend.bk_concat(S3p_sigma,axis=1)
5286
5765
 
5287
- # get a single, flattened data vector for_synthesis
5288
- select_and_index = self.get_scattering_index(J, L, normalization, S4_criteria)
5289
- index_for_synthesis = select_and_index['index_for_synthesis']
5290
- index_for_synthesis_iso = select_and_index['index_for_synthesis_iso']
5291
- """
5292
5766
  # average over l1 to obtain simple isotropic statistics
5293
5767
  if iso_ang:
5294
- S2_iso = S2.mean(-1)
5295
- S1_iso = S1.mean(-1)
5768
+ S2_iso = self.backend.bk_reduce_mean(S2,axis=(-1))
5769
+ S1_iso = self.backend.bk_reduce_mean(S1,axis=(-1))
5296
5770
  for l1 in range(L):
5297
5771
  for l2 in range(L):
5298
5772
  S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
@@ -5305,8 +5779,8 @@ class funct(FOC.FoCUS):
5305
5779
  S3p_iso /= L
5306
5780
 
5307
5781
  if get_variance:
5308
- S2_sigma_iso = S2_sigma.mean(-1)
5309
- S1_sigma_iso = S1_sigma.mean(-1)
5782
+ S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma,axis=(-1))
5783
+ S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma,axis=(-1))
5310
5784
  for l1 in range(L):
5311
5785
  for l2 in range(L):
5312
5786
  S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
@@ -5318,10 +5792,12 @@ class funct(FOC.FoCUS):
5318
5792
  if data2 is not None:
5319
5793
  S3p_sigma_iso /= L
5320
5794
 
5321
- mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5322
- std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5323
- mean_data[:,0]=data.mean((-2,-1))
5324
- std_data[:,0]=data.std((-2,-1))
5795
+ if data2 is None:
5796
+ mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data,axis=(-2,-1)),[N_image,1])
5797
+ std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data,axis=(-2,-1)),[N_image,1])
5798
+ else:
5799
+ mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data*data2,axis=(-2,-1)),[N_image,1])
5800
+ std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data*data2,axis=(-2,-1)),[N_image,1])
5325
5801
 
5326
5802
  if get_variance:
5327
5803
  ref_sigma={}
@@ -5345,105 +5821,105 @@ class funct(FOC.FoCUS):
5345
5821
  if data2 is None:
5346
5822
  if iso_ang:
5347
5823
  if ref_sigma is not None:
5348
- for_synthesis = self.backend.backend.cat((
5824
+ for_synthesis = self.backend.bk_concat((
5349
5825
  mean_data/ref_sigma['std_data'],
5350
5826
  std_data/ref_sigma['std_data'],
5351
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5352
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5353
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5354
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5355
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5356
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5357
- ),dim=-1)
5827
+ self.backend.bk_reshape(self.backend.bk_log(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5828
+ self.backend.bk_reshape(self.backend.bk_log(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5829
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5830
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5831
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5832
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5833
+ ),axis=-1)
5358
5834
  else:
5359
- for_synthesis = self.backend.backend.cat((
5835
+ for_synthesis = self.backend.bk_concat((
5360
5836
  mean_data/std_data,
5361
5837
  std_data,
5362
- S2_iso.reshape((N_image, -1)).log(),
5363
- S1_iso.reshape((N_image, -1)).log(),
5364
- S3_iso.reshape((N_image, -1)).real,
5365
- S3_iso.reshape((N_image, -1)).imag,
5366
- S4_iso.reshape((N_image, -1)).real,
5367
- S4_iso.reshape((N_image, -1)).imag,
5368
- ),dim=-1)
5838
+ self.backend.bk_reshape(self.backend.bk_log(S2_iso),[N_image, -1]),
5839
+ self.backend.bk_reshape(self.backend.bk_log(S1_iso),[N_image, -1]),
5840
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5841
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5842
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5843
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5844
+ ),axis=-1)
5369
5845
  else:
5370
5846
  if ref_sigma is not None:
5371
- for_synthesis = self.backend.backend.cat((
5847
+ for_synthesis = self.backend.bk_concat((
5372
5848
  mean_data/ref_sigma['std_data'],
5373
5849
  std_data/ref_sigma['std_data'],
5374
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5375
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5376
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5377
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5378
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5379
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5380
- ),dim=-1)
5850
+ self.backend.bk_reshape(self.backend.bk_log(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5851
+ self.backend.bk_reshape(self.backend.bk_log(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5852
+ self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5853
+ self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5854
+ self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5855
+ self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5856
+ ),axis=-1)
5381
5857
  else:
5382
- for_synthesis = self.backend.backend.cat((
5383
- mean_data/std_data,
5384
- std_data,
5385
- S2.reshape((N_image, -1)).log(),
5386
- S1.reshape((N_image, -1)).log(),
5387
- S3.reshape((N_image, -1)).real,
5388
- S3.reshape((N_image, -1)).imag,
5389
- S4.reshape((N_image, -1)).real,
5390
- S4.reshape((N_image, -1)).imag,
5391
- ),dim=-1)
5858
+ for_synthesis = self.backend.bk_concat((
5859
+ mean_data/std_data,
5860
+ std_data,
5861
+ self.backend.bk_reshape(self.backend.bk_log(S2),[N_image, -1]),
5862
+ self.backend.bk_reshape(self.backend.bk_log(S1),[N_image, -1]),
5863
+ self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5864
+ self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5865
+ self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5866
+ self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5867
+ ),axis=-1)
5392
5868
  else:
5393
5869
  if iso_ang:
5394
5870
  if ref_sigma is not None:
5395
5871
  for_synthesis = self.backend.backend.cat((
5396
5872
  mean_data/ref_sigma['std_data'],
5397
5873
  std_data/ref_sigma['std_data'],
5398
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5399
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5400
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5401
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5402
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5403
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5404
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5405
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5406
- ),dim=-1)
5874
+ self.backend.bk_reshape(self.backend.bk_real(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5875
+ self.backend.bk_reshape(self.backend.bk_real(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5876
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5877
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5878
+ self.backend.bk_reshape(self.backend.bk_real(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5879
+ self.backend.bk_reshape(self.backend.bk_imag(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5880
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5881
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5882
+ ),axis=-1)
5407
5883
  else:
5408
5884
  for_synthesis = self.backend.backend.cat((
5409
5885
  mean_data/std_data,
5410
5886
  std_data,
5411
- S2_iso.reshape((N_image, -1)),
5412
- S1_iso.reshape((N_image, -1)),
5413
- S3_iso.reshape((N_image, -1)).real,
5414
- S3_iso.reshape((N_image, -1)).imag,
5415
- S3p_iso.reshape((N_image, -1)).real,
5416
- S3p_iso.reshape((N_image, -1)).imag,
5417
- S4_iso.reshape((N_image, -1)).real,
5418
- S4_iso.reshape((N_image, -1)).imag,
5419
- ),dim=-1)
5887
+ self.backend.bk_reshape(self.backend.bk_real(S2_iso),[N_image, -1]),
5888
+ self.backend.bk_reshape(self.backend.bk_real(S1_iso),[N_image, -1]),
5889
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5890
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5891
+ self.backend.bk_reshape(self.backend.bk_real(S3p_iso),[N_image, -1]),
5892
+ self.backend.bk_reshape(self.backend.bk_imag(S3p_iso),[N_image, -1]),
5893
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5894
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5895
+ ),axis=-1)
5420
5896
  else:
5421
5897
  if ref_sigma is not None:
5422
5898
  for_synthesis = self.backend.backend.cat((
5423
5899
  mean_data/ref_sigma['std_data'],
5424
5900
  std_data/ref_sigma['std_data'],
5425
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5426
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5427
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5428
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5429
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5430
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5431
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5432
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5433
- ),dim=-1)
5901
+ self.backend.bk_reshape(self.backend.bk_real(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5902
+ self.backend.bk_reshape(self.backend.bk_real(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5903
+ self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5904
+ self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5905
+ self.backend.bk_reshape(self.backend.bk_real(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5906
+ self.backend.bk_reshape(self.backend.bk_imag(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5907
+ self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5908
+ self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5909
+ ),axis=-1)
5434
5910
  else:
5435
- for_synthesis = self.backend.backend.cat((
5436
- mean_data/std_data,
5437
- std_data,
5438
- S2.reshape((N_image, -1)),
5439
- S1.reshape((N_image, -1)),
5440
- S3.reshape((N_image, -1)).real,
5441
- S3.reshape((N_image, -1)).imag,
5442
- S3p.reshape((N_image, -1)).real,
5443
- S3p.reshape((N_image, -1)).imag,
5444
- S4.reshape((N_image, -1)).real,
5445
- S4.reshape((N_image, -1)).imag,
5446
- ),dim=-1)
5911
+ for_synthesis = self.backend.bk_concat((
5912
+ mean_data/std_data,
5913
+ std_data,
5914
+ self.backend.bk_reshape(self.backend.bk_real(S2),[N_image, -1]),
5915
+ self.backend.bk_reshape(self.backend.bk_real(S1),[N_image, -1]),
5916
+ self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5917
+ self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5918
+ self.backend.bk_reshape(self.backend.bk_real(S3p),[N_image, -1]),
5919
+ self.backend.bk_reshape(self.backend.bk_imag(S3p),[N_image, -1]),
5920
+ self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5921
+ self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5922
+ ),axis=-1)
5447
5923
 
5448
5924
  if not use_ref:
5449
5925
  self.ref_scattering_cov_S2=S2
@@ -5835,6 +6311,7 @@ class funct(FOC.FoCUS):
5835
6311
  to_gaussian=True,
5836
6312
  use_variance=False,
5837
6313
  synthesised_N=1,
6314
+ input_image=None,
5838
6315
  iso_ang=False,
5839
6316
  EVAL_FREQUENCY=100,
5840
6317
  NUM_EPOCHS = 300):
@@ -5892,14 +6369,22 @@ class funct(FOC.FoCUS):
5892
6369
 
5893
6370
  for k in range(nstep):
5894
6371
  if k==0:
5895
- np.random.seed(seed)
5896
- if self.use_2D:
5897
- imap=np.random.randn(synthesised_N,
5898
- tmp[k].shape[1],
5899
- tmp[k].shape[2])
6372
+ if input_image is None:
6373
+ np.random.seed(seed)
6374
+ if self.use_2D:
6375
+ imap=np.random.randn(synthesised_N,
6376
+ tmp[k].shape[1],
6377
+ tmp[k].shape[2])
6378
+ else:
6379
+ imap=np.random.randn(synthesised_N,
6380
+ tmp[k].shape[1])
5900
6381
  else:
5901
- imap=np.random.randn(synthesised_N,
5902
- tmp[k].shape[1])
6382
+ if self.use_2D:
6383
+ imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6384
+ [synthesised_N,tmp[k].shape[1],tmp[k].shape[2]])
6385
+ else:
6386
+ imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6387
+ [synthesised_N,tmp[k].shape[1]])
5903
6388
  else:
5904
6389
  # Increase the resolution between each step
5905
6390
  if self.use_2D: