foscat 3.7.3__py3-none-any.whl → 3.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -2277,14 +2277,14 @@ class funct(FOC.FoCUS):
2277
2277
  if tmp.S1 is not None:
2278
2278
  nS1 = np.expand_dims(tmp.S1, 0)
2279
2279
  else:
2280
- nS0 = np.expand_dims(tmp.S0.numpy(), 0)
2281
- nS2 = np.expand_dims(tmp.S2.numpy(), 0)
2282
- nS3 = np.expand_dims(tmp.S3.numpy(), 0)
2283
- nS4 = np.expand_dims(tmp.S4.numpy(), 0)
2280
+ nS0 = np.expand_dims(self.backend.to_numpy(tmp.S0), 0)
2281
+ nS2 = np.expand_dims(self.backend.to_numpy(tmp.S2), 0)
2282
+ nS3 = np.expand_dims(self.backend.to_numpy(tmp.S3), 0)
2283
+ nS4 = np.expand_dims(self.backend.to_numpy(tmp.S4), 0)
2284
2284
  if tmp.S3P is not None:
2285
- nS3P = np.expand_dims(tmp.S3P.numpy(), 0)
2285
+ nS3P = np.expand_dims(self.backend.to_numpy(tmp.S3P), 0)
2286
2286
  if tmp.S1 is not None:
2287
- nS1 = np.expand_dims(tmp.S1.numpy(), 0)
2287
+ nS1 = np.expand_dims(self.backend.to_numpy(tmp.S1), 0)
2288
2288
 
2289
2289
  if S0 is None:
2290
2290
  S0 = nS0
@@ -2304,24 +2304,24 @@ class funct(FOC.FoCUS):
2304
2304
  S3P = np.concatenate([S3P, nS3P], 0)
2305
2305
  if tmp.S1 is not None:
2306
2306
  S1 = np.concatenate([S1, nS1], 0)
2307
- sS0 = np.std(S0, 0)
2308
- sS2 = np.std(S2, 0)
2309
- sS3 = np.std(S3, 0)
2310
- sS4 = np.std(S4, 0)
2311
- mS0 = np.mean(S0, 0)
2312
- mS2 = np.mean(S2, 0)
2313
- mS3 = np.mean(S3, 0)
2314
- mS4 = np.mean(S4, 0)
2307
+ sS0 = self.backend.bk_cast(np.std(S0, 0))
2308
+ sS2 = self.backend.bk_cast(np.std(S2, 0))
2309
+ sS3 = self.backend.bk_cast(np.std(S3, 0))
2310
+ sS4 = self.backend.bk_cast(np.std(S4, 0))
2311
+ mS0 = self.backend.bk_cast(np.mean(S0, 0))
2312
+ mS2 = self.backend.bk_cast(np.mean(S2, 0))
2313
+ mS3 = self.backend.bk_cast(np.mean(S3, 0))
2314
+ mS4 = self.backend.bk_cast(np.mean(S4, 0))
2315
2315
  if tmp.S3P is not None:
2316
- sS3P = np.std(S3P, 0)
2317
- mS3P = np.mean(S3P, 0)
2316
+ sS3P = self.backend.bk_cast(np.std(S3P, 0))
2317
+ mS3P = self.backend.bk_cast(np.mean(S3P, 0))
2318
2318
  else:
2319
2319
  sS3P = None
2320
2320
  mS3P = None
2321
2321
 
2322
2322
  if tmp.S1 is not None:
2323
- sS1 = np.std(S1, 0)
2324
- mS1 = np.mean(S1, 0)
2323
+ sS1 = self.backend.bk_cast(np.std(S1, 0))
2324
+ mS1 = self.backend.bk_cast(np.mean(S1, 0))
2325
2325
  else:
2326
2326
  sS1 = None
2327
2327
  mS1 = None
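The change above swaps direct .numpy() calls for self.backend.to_numpy(...), so the statistics can be accumulated whatever framework backs the tensors. A rough, hypothetical sketch of such a helper (not foscat's actual implementation) could dispatch on the tensor type:

import numpy as np

def to_numpy(x):
    # Hypothetical backend-agnostic conversion: NumPy arrays pass through,
    # torch tensors are detached and moved to host memory, TensorFlow
    # tensors expose .numpy() directly.
    if isinstance(x, np.ndarray):
        return x
    if hasattr(x, "detach"):      # torch.Tensor
        return x.detach().cpu().numpy()
    if hasattr(x, "numpy"):       # tf.Tensor / tf.Variable
        return x.numpy()
    return np.asarray(x)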
@@ -2375,14 +2375,13 @@ class funct(FOC.FoCUS):
2375
2375
 
2376
2376
  # instead of the difference between "opposite" channels, use a weighted average
2377
2377
  # of cosine and sine contributions using all channels
2378
- angles = (
2378
+ angles = self.backend.bk_cast((
2379
2379
  2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
2380
- ) # shape: (NORIENT,)
2381
- angles = angles.reshape(1, 1, self.NORIENT)
2380
+ ).reshape(1, 1, self.NORIENT)) # shape: (NORIENT,)
2382
2381
 
2383
2382
  # we use cosines and sines as weights for sim
2384
- weighted_cos = self.backend.bk_reduce_mean(sim * np.cos(angles), axis=-1)
2385
- weighted_sin = self.backend.bk_reduce_mean(sim * np.sin(angles), axis=-1)
2383
+ weighted_cos = self.backend.bk_reduce_mean(sim * self.backend.bk_cos(angles), axis=-1)
2384
+ weighted_sin = self.backend.bk_reduce_mean(sim * self.backend.bk_sin(angles), axis=-1)
2386
2385
  # For simplicity, take first element of the batch
2387
2386
  cc = weighted_cos[0]
2388
2387
  ss = weighted_sin[0]
@@ -2402,7 +2401,8 @@ class funct(FOC.FoCUS):
2402
2401
  phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
2403
2402
  else:
2404
2403
  phase = np.fmod(
2405
- np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
2404
+ np.arctan2(self.backend.to_numpy(ss),
2405
+ self.backend.to_numpy(cc)) + 2 * np.pi, 2 * np.pi
2406
2406
  )
2407
2407
 
2408
2408
  # cosine-based interpolation instead of linear interpolation
@@ -2435,7 +2435,8 @@ class funct(FOC.FoCUS):
2435
2435
 
2436
2436
  sim2 = self.backend.bk_reduce_sum(
2437
2437
  self.backend.bk_reshape(
2438
- mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
2438
+ self.backend.bk_cast(
2439
+ mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT))
2439
2440
  * tmp2,
2440
2441
  [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
2441
2442
  ),
@@ -2445,10 +2446,10 @@ class funct(FOC.FoCUS):
2445
2446
  sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
2446
2447
 
2447
2448
  weighted_cos2 = self.backend.bk_reduce_mean(
2448
- sim2 * np.cos(angles), axis=-1
2449
+ sim2 * self.backend.bk_cos(angles), axis=-1
2449
2450
  )
2450
2451
  weighted_sin2 = self.backend.bk_reduce_mean(
2451
- sim2 * np.sin(angles), axis=-1
2452
+ sim2 * self.backend.bk_sin(angles), axis=-1
2452
2453
  )
2453
2454
 
2454
2455
  cc2 = weighted_cos2[0]
@@ -2469,7 +2470,8 @@ class funct(FOC.FoCUS):
2469
2470
  phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
2470
2471
  else:
2471
2472
  phase2 = np.fmod(
2472
- np.arctan2(ss2.numpy(), cc2.numpy()) + 2 * np.pi, 2 * np.pi
2473
+ np.arctan2(self.backend.to_numpy(ss2),
2474
+ self.backend.to_numpy(cc2)) + 2 * np.pi, 2 * np.pi
2473
2475
  )
2474
2476
 
2475
2477
  phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
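The comments above describe estimating a dominant orientation by weighting every channel with the cosine and sine of its angle (a circular mean) rather than differencing opposite channels. A minimal NumPy sketch of that computation, with sim standing in for the package's response tensor and norient for self.NORIENT:

import numpy as np

def circular_phase(sim, norient):
    # Weight each orientation channel by cos/sin of its angle, average,
    # and fold the resulting angle into [0, 2*pi), as the code above does.
    angles = 2 * np.pi * np.arange(norient) / norient        # shape (norient,)
    weighted_cos = np.mean(sim * np.cos(angles), axis=-1)
    weighted_sin = np.mean(sim * np.sin(angles), axis=-1)
    return np.fmod(np.arctan2(weighted_sin, weighted_cos) + 2 * np.pi, 2 * np.pi)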
@@ -4791,17 +4793,37 @@ class funct(FOC.FoCUS):
4791
4793
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4792
4794
  Done by Sihao Cheng and Rudy Morel.
4793
4795
  '''
4794
- if_xodd = (data_f.shape[-2]%2==1)
4795
- if_yodd = (data_f.shape[-1]%2==1)
4796
- result = self.backend.backend.cat(
4797
- (self.backend.backend.cat(
4798
- ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4799
- ), -2),
4800
- self.backend.backend.cat(
4801
- ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4802
- ), -2)
4803
- ),-1)
4804
- return result
4796
+
4797
+ if self.backend.BACKEND=='torch':
4798
+ if_xodd = (data_f.shape[-2]%2==1)
4799
+ if_yodd = (data_f.shape[-1]%2==1)
4800
+ result = self.backend.backend.cat(
4801
+ (self.backend.backend.cat(
4802
+ ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4803
+ ), -2),
4804
+ self.backend.backend.cat(
4805
+ ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4806
+ ), -2)
4807
+ ),-1)
4808
+ return result
4809
+ else:
4810
+ # Check if the last two dimensions are odd
4811
+ if_xodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-2] % 2 == 1, self.backend.backend.int32)
4812
+ if_yodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-1] % 2 == 1, self.backend.backend.int32)
4813
+
4814
+ # Extract four regions
4815
+ top_left = data_f[..., :dx+if_xodd, :dy+if_yodd]
4816
+ top_right = data_f[..., -dx:, :dy+if_yodd]
4817
+ bottom_left = data_f[..., :dx+if_xodd, -dy:]
4818
+ bottom_right = data_f[..., -dx:, -dy:]
4819
+
4820
+ # Concatenate along the last two dimensions
4821
+ top = self.backend.backend.concat([top_left, top_right], axis=-2)
4822
+ bottom = self.backend.backend.concat([bottom_left, bottom_right], axis=-2)
4823
+ result = self.backend.backend.concat([top, bottom], axis=-1)
4824
+
4825
+ return result
4826
+
4805
4827
  # ---------------------------------------------------------------------------
4806
4828
  #
4807
4829
  # utility functions for computing scattering coef and covariance
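Both branches of cut_high_k_off above implement the same idea: for an unshifted FFT the lowest frequencies sit in the four corners of the array, so keeping dx/dy rows and columns from each corner and stitching them back together low-pass filters and shrinks the spectrum in one step. A NumPy-only sketch of that corner cropping (illustrative, not the package code):

import numpy as np

def cut_high_k_off(data_f, dx, dy):
    # Keep the four low-frequency corners of an unshifted 2-D FFT and
    # reassemble them into a smaller spectrum.
    if_xodd = data_f.shape[-2] % 2 == 1
    if_yodd = data_f.shape[-1] % 2 == 1
    top = np.concatenate((data_f[..., :dx + if_xodd, :dy + if_yodd],
                          data_f[..., -dx:, :dy + if_yodd]), axis=-2)
    bottom = np.concatenate((data_f[..., :dx + if_xodd, -dy:],
                             data_f[..., -dx:, -dy:]), axis=-2)
    return np.concatenate((top, bottom), axis=-1)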
@@ -4824,13 +4846,15 @@ class funct(FOC.FoCUS):
4824
4846
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4825
4847
  Done by Sihao Cheng and Rudy Morel.
4826
4848
  '''
4827
- edge_masks = self.backend.backend.empty((J, M, N))
4828
- X, Y = self.backend.backend.meshgrid(self.backend.backend.arange(M), self.backend.backend.arange(N), indexing='ij')
4849
+ edge_masks = np.empty((J, M, N))
4850
+ X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing='ij')
4829
4851
  for j in range(J):
4830
4852
  edge_dx = min(M//4, 2**j*d0)
4831
4853
  edge_dy = min(N//4, 2**j*d0)
4832
4854
  edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
4833
- return edge_masks.to(self.backend.torch_device)
4855
+ edge_masks = edge_masks[:,None,:,:]
4856
+ edge_masks = edge_masks / edge_masks.mean((-2,-1))[:,:,None,None]
4857
+ return self.backend.bk_cast(edge_masks)
4834
4858
 
4835
4859
  # ---------------------------------------------------------------------------
4836
4860
  #
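The updated get_edge_masks above now adds an orientation axis and normalises each mask to unit mean, so masked averages keep the same scale as unmasked ones. A NumPy sketch under the assumption that d0 defaults to 1 (names and default are illustrative):

import numpy as np

def get_edge_masks(M, N, J, d0=1):
    # One mask per scale j: zero a border of width ~2**j * d0 (capped at a
    # quarter of the image), then normalise each mask to mean 1.
    masks = np.empty((J, M, N))
    X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing='ij')
    for j in range(J):
        ex = min(M // 4, 2 ** j * d0)
        ey = min(N // 4, 2 ** j * d0)
        masks[j] = (X >= ex) * (X <= M - ex) * (Y >= ey) * (Y <= N - ey)
    masks = masks[:, None, :, :]                 # add an orientation axis
    return masks / masks.mean((-2, -1))[:, :, None, None]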
@@ -4914,6 +4938,11 @@ class funct(FOC.FoCUS):
4914
4938
  if S4_criteria is None:
4915
4939
  S4_criteria = 'j2>=j1'
4916
4940
 
4941
+ if self.all_bk_type == "float32":
4942
+ C_ONE=np.complex64(1.0)
4943
+ else:
4944
+ C_ONE=np.complex128(1.0)
4945
+
4917
4946
  # determine jmax and nside corresponding to the input map
4918
4947
  im_shape = data.shape
4919
4948
  if self.use_2D:
@@ -4921,18 +4950,24 @@ class funct(FOC.FoCUS):
4921
4950
  nside = np.min([im_shape[0], im_shape[1]])
4922
4951
  M,N = im_shape[0],im_shape[1]
4923
4952
  N_image = 1
4953
+ N_image2 = 1
4924
4954
  else:
4925
4955
  nside = np.min([im_shape[1], im_shape[2]])
4926
4956
  M,N = im_shape[1],im_shape[2]
4927
4957
  N_image = data.shape[0]
4958
+ if data2 is not None:
4959
+ N_image2 = data2.shape[0]
4928
4960
  J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4929
4961
  elif self.use_1D:
4930
4962
  if len(data.shape) == 2:
4931
4963
  npix = int(im_shape[1]) # Number of pixels
4932
4964
  N_image = 1
4965
+ N_image2 = 1
4933
4966
  else:
4934
4967
  npix = int(im_shape[0]) # Number of pixels
4935
4968
  N_image = data.shape[0]
4969
+ if data2 is not None:
4970
+ N_image2 = data2.shape[0]
4936
4971
 
4937
4972
  nside = int(npix)
4938
4973
 
@@ -4941,9 +4976,12 @@ class funct(FOC.FoCUS):
4941
4976
  if len(data.shape) == 2:
4942
4977
  npix = int(im_shape[1]) # Number of pixels
4943
4978
  N_image = 1
4979
+ N_image2 = 1
4944
4980
  else:
4945
4981
  npix = int(im_shape[0]) # Number of pixels
4946
4982
  N_image = data.shape[0]
4983
+ if data2 is not None:
4984
+ N_image2 = data2.shape[0]
4947
4985
 
4948
4986
  nside = int(np.sqrt(npix // 12))
4949
4987
 
@@ -4957,7 +4995,475 @@ class funct(FOC.FoCUS):
4957
4995
  print('\n\n==========')
4958
4996
 
4959
4997
  L=self.NORIENT
4998
+ norm_factor_S3=1.0
4960
4999
 
5000
+ if self.backend.BACKEND=='torch':
5001
+ if (M,N,J,L) not in self.filters_set:
5002
+ self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
5003
+
5004
+ filters_set = self.filters_set[(M,N,J,L)]
5005
+
5006
+ #weight = self.weight
5007
+ if use_ref:
5008
+ if normalization=='S2':
5009
+ ref_S2 = self.ref_scattering_cov_S2
5010
+ else:
5011
+ ref_P11 = self.ref_scattering_cov['P11']
5012
+
5013
+ # convert numpy array input into self.backend.bk_ tensors
5014
+ data = self.backend.bk_cast(data)
5015
+ data_f = self.backend.bk_fftn(data, dim=(-2,-1))
5016
+ if data2 is not None:
5017
+ data2 = self.backend.bk_cast(data2)
5018
+ data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
5019
+
5020
+ # initialize tensors for scattering coefficients
5021
+ S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5022
+ S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5023
+
5024
+ Ndata_S3 = J*(J+1)//2
5025
+ Ndata_S4 = J*(J+1)*(J+2)//6
5026
+ J_S4={}
5027
+
5028
+ S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5029
+ if data2 is not None:
5030
+ S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5031
+ S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5032
+ S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5033
+
5034
+ # variance
5035
+ if get_variance:
5036
+ S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5037
+ S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5038
+ S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5039
+ if data2 is not None:
5040
+ S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5041
+ S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5042
+
5043
+ if iso_ang:
5044
+ S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5045
+ S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5046
+ if get_variance:
5047
+ S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5048
+ S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5049
+ if data2 is not None:
5050
+ S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5051
+ if get_variance:
5052
+ S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5053
+
5054
+ #
5055
+ if edge:
5056
+ if (M,N,J) not in self.edge_masks:
5057
+ self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5058
+ edge_mask=self.edge_masks[(M,N,J)]
5059
+ else:
5060
+ edge_mask = 1
5061
+
5062
+ # calculate scattering fields
5063
+ if data2 is None:
5064
+ if self.use_2D:
5065
+ if len(data.shape) == 2:
5066
+ I1 = self.backend.bk_ifftn(
5067
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5068
+ ).abs()
5069
+ else:
5070
+ I1 = self.backend.bk_ifftn(
5071
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5072
+ ).abs()
5073
+ elif self.use_1D:
5074
+ if len(data.shape) == 1:
5075
+ I1 = self.backend.bk_ifftn(
5076
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5077
+ ).abs()
5078
+ else:
5079
+ I1 = self.backend.bk_ifftn(
5080
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5081
+ ).abs()
5082
+ else:
5083
+ print('todo')
5084
+
5085
+ S2 = (I1**2 * edge_mask).mean((-2,-1))
5086
+ S1 = (I1 * edge_mask).mean((-2,-1))
5087
+
5088
+ if get_variance:
5089
+ S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5090
+ S1_sigma = (I1 * edge_mask).std((-2,-1))
5091
+
5092
+ else:
5093
+ if self.use_2D:
5094
+ if len(data.shape) == 2:
5095
+ I1 = self.backend.bk_ifftn(
5096
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5097
+ )
5098
+ I2 = self.backend.bk_ifftn(
5099
+ data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5100
+ )
5101
+ else:
5102
+ I1 = self.backend.bk_ifftn(
5103
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5104
+ )
5105
+ I2 = self.backend.bk_ifftn(
5106
+ data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5107
+ )
5108
+ elif self.use_1D:
5109
+ if len(data.shape) == 1:
5110
+ I1 = self.backend.bk_ifftn(
5111
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5112
+ )
5113
+ I2 = self.backend.bk_ifftn(
5114
+ data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5115
+ )
5116
+ else:
5117
+ I1 = self.backend.bk_ifftn(
5118
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5119
+ )
5120
+ I2 = self.backend.bk_ifftn(
5121
+ data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5122
+ )
5123
+ else:
5124
+ print('todo')
5125
+
5126
+ I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5127
+
5128
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5129
+ if get_variance:
5130
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5131
+
5132
+ I1=self.backend.bk_L1(I1)
5133
+
5134
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5135
+
5136
+ if get_variance:
5137
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5138
+
5139
+ I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5140
+
5141
+ if pseudo_coef != 1:
5142
+ I1 = I1**pseudo_coef
5143
+
5144
+ Ndata_S3=0
5145
+ Ndata_S4=0
5146
+
5147
+ # calculate the covariance and correlations of the scattering fields
5148
+ # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5149
+ for j3 in range(0,J):
5150
+ J_S4[j3]=Ndata_S4
5151
+
5152
+ dx3, dy3 = self.get_dxdy(j3,M,N)
5153
+ I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
5154
+ data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5155
+ if data2 is not None:
5156
+ data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5157
+ if edge:
5158
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5159
+ data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
5160
+ if data2 is not None:
5161
+ data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
5162
+
5163
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5164
+ _, M3, N3 = wavelet_f3.shape
5165
+ wavelet_f3_squared = wavelet_f3**2
5166
+ edge_dx = min(4, int(2**j3*dx3*2/M))
5167
+ edge_dy = min(4, int(2**j3*dy3*2/N))
5168
+
5169
+ # a normalization change due to the cutoff of frequency space
5170
+ fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5171
+ for j2 in range(0,j3+1):
5172
+ I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5173
+ I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5174
+ if edge:
5175
+ I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5176
+ I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
5177
+ if use_ref:
5178
+ if normalization=='P11':
5179
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5180
+ if normalization=='S2':
5181
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5182
+ else:
5183
+ if normalization=='P11':
5184
+ # [N_image,l2,l3,x,y]
5185
+ P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5186
+ norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5187
+ if normalization=='S2':
5188
+ norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
5189
+
5190
+ if not edge:
5191
+ S3[:,Ndata_S3,:,:] = (
5192
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5193
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5194
+
5195
+ if get_variance:
5196
+ S3_sigma[:,Ndata_S3,:,:] = (
5197
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5198
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5199
+ else:
5200
+ S3[:,Ndata_S3,:,:] = (
5201
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5202
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5203
+ if get_variance:
5204
+ S3_sigma[:,Ndata_S3,:,:] = (
5205
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5206
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5207
+ if data2 is not None:
5208
+ if not edge:
5209
+ S3p[:,Ndata_S3,:,:] = (
5210
+ data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5211
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5212
+
5213
+ if get_variance:
5214
+ S3p_sigma[:,Ndata_S3,:,:] = (
5215
+ data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5216
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5217
+ else:
5218
+ S3p[:,Ndata_S3,:,:] = (
5219
+ data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5220
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5221
+ if get_variance:
5222
+ S3p_sigma[:,Ndata_S3,:,:] = (
5223
+ data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5224
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5225
+ Ndata_S3+=1
5226
+ if j2 <= j3:
5227
+ beg_n=Ndata_S4
5228
+ for j1 in range(0, j2+1):
5229
+ if eval(S4_criteria):
5230
+ if not edge:
5231
+ if not if_large_batch:
5232
+ # [N_image,l1,l2,l3,x,y]
5233
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5234
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5235
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5236
+ ).mean((-2,-1)) * fft_factor
5237
+ if get_variance:
5238
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5239
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5240
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5241
+ ).std((-2,-1)) * fft_factor
5242
+ else:
5243
+ for l1 in range(L):
5244
+ # [N_image,l2,l3,x,y]
5245
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5246
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5247
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5248
+ ).mean((-2,-1)) * fft_factor
5249
+ if get_variance:
5250
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5251
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5252
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5253
+ ).std((-2,-1)) * fft_factor
5254
+ else:
5255
+ if not if_large_batch:
5256
+ # [N_image,l1,l2,l3,x,y]
5257
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5258
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5259
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5260
+ )
5261
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5262
+ if get_variance:
5263
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5264
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5265
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5266
+ )
5267
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
5268
+ else:
5269
+ for l1 in range(L):
5270
+ # [N_image,l2,l3,x,y]
5271
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5272
+ I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5273
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5274
+ )
5275
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5276
+ if get_variance:
5277
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5278
+ I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5279
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5280
+ )
5281
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5282
+
5283
+ Ndata_S4+=1
5284
+
5285
+ if normalization=='S2':
5286
+ if use_ref:
5287
+ P = ((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5288
+ else:
5289
+ P = ((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5290
+
5291
+ S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/P
5292
+
5293
+ if get_variance:
5294
+ S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:]/P
5295
+ else:
5296
+ S4=S4_pre_norm
5297
+
5298
+ # average over l1 to obtain simple isotropic statistics
5299
+ if iso_ang:
5300
+ S2_iso = S2.mean(-1)
5301
+ S1_iso = S1.mean(-1)
5302
+ for l1 in range(L):
5303
+ for l2 in range(L):
5304
+ S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
5305
+ if data2 is not None:
5306
+ S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
5307
+ for l3 in range(L):
5308
+ S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5309
+ S3_iso /= L; S4_iso /= L
5310
+ if data2 is not None:
5311
+ S3p_iso /= L
5312
+
5313
+ if get_variance:
5314
+ S2_sigma_iso = S2_sigma.mean(-1)
5315
+ S1_sigma_iso = S1_sigma.mean(-1)
5316
+ for l1 in range(L):
5317
+ for l2 in range(L):
5318
+ S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
5319
+ if data2 is not None:
5320
+ S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
5321
+ for l3 in range(L):
5322
+ S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5323
+ S3_sigma_iso /= L; S4_sigma_iso /= L
5324
+ if data2 is not None:
5325
+ S3p_sigma_iso /= L
5326
+
5327
+ mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5328
+ std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5329
+
5330
+ if data2 is None:
5331
+ mean_data[:,0]=data.mean((-2,-1))
5332
+ std_data[:,0]=data.std((-2,-1))
5333
+ else:
5334
+ mean_data[:,0]=(data2*data).mean((-2,-1))
5335
+ std_data[:,0]=(data2*data).std((-2,-1))
5336
+
5337
+ if get_variance:
5338
+ ref_sigma={}
5339
+ if iso_ang:
5340
+ ref_sigma['std_data']=std_data
5341
+ ref_sigma['S1_sigma']=S1_sigma_iso
5342
+ ref_sigma['S2_sigma']=S2_sigma_iso
5343
+ ref_sigma['S3_sigma']=S3_sigma_iso
5344
+ if data2 is not None:
5345
+ ref_sigma['S3p_sigma']=S3p_sigma_iso
5346
+ ref_sigma['S4_sigma']=S4_sigma_iso
5347
+ else:
5348
+ ref_sigma['std_data']=std_data
5349
+ ref_sigma['S1_sigma']=S1_sigma
5350
+ ref_sigma['S2_sigma']=S2_sigma
5351
+ ref_sigma['S3_sigma']=S3_sigma
5352
+ if data2 is not None:
5353
+ ref_sigma['S3p_sigma']=S3p_sigma
5354
+ ref_sigma['S4_sigma']=S4_sigma
5355
+
5356
+ if data2 is None:
5357
+ if iso_ang:
5358
+ if ref_sigma is not None:
5359
+ for_synthesis = self.backend.backend.cat((
5360
+ mean_data/ref_sigma['std_data'],
5361
+ std_data/ref_sigma['std_data'],
5362
+ (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5363
+ (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5364
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5365
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5366
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5367
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5368
+ ),dim=-1)
5369
+ else:
5370
+ for_synthesis = self.backend.backend.cat((
5371
+ mean_data/std_data,
5372
+ std_data,
5373
+ S2_iso.reshape((N_image, -1)).log(),
5374
+ S1_iso.reshape((N_image, -1)).log(),
5375
+ S3_iso.reshape((N_image, -1)).real,
5376
+ S3_iso.reshape((N_image, -1)).imag,
5377
+ S4_iso.reshape((N_image, -1)).real,
5378
+ S4_iso.reshape((N_image, -1)).imag,
5379
+ ),dim=-1)
5380
+ else:
5381
+ if ref_sigma is not None:
5382
+ for_synthesis = self.backend.backend.cat((
5383
+ mean_data/ref_sigma['std_data'],
5384
+ std_data/ref_sigma['std_data'],
5385
+ (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5386
+ (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5387
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5388
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5389
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5390
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5391
+ ),dim=-1)
5392
+ else:
5393
+ for_synthesis = self.backend.backend.cat((
5394
+ mean_data/std_data,
5395
+ std_data,
5396
+ S2.reshape((N_image, -1)).log(),
5397
+ S1.reshape((N_image, -1)).log(),
5398
+ S3.reshape((N_image, -1)).real,
5399
+ S3.reshape((N_image, -1)).imag,
5400
+ S4.reshape((N_image, -1)).real,
5401
+ S4.reshape((N_image, -1)).imag,
5402
+ ),dim=-1)
5403
+ else:
5404
+ if iso_ang:
5405
+ if ref_sigma is not None:
5406
+ for_synthesis = self.backend.backend.cat((
5407
+ mean_data/ref_sigma['std_data'],
5408
+ std_data/ref_sigma['std_data'],
5409
+ (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5410
+ (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5411
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5412
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5413
+ (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5414
+ (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5415
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5416
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5417
+ ),dim=-1)
5418
+ else:
5419
+ for_synthesis = self.backend.backend.cat((
5420
+ mean_data/std_data,
5421
+ std_data,
5422
+ S2_iso.reshape((N_image, -1)),
5423
+ S1_iso.reshape((N_image, -1)),
5424
+ S3_iso.reshape((N_image, -1)).real,
5425
+ S3_iso.reshape((N_image, -1)).imag,
5426
+ S3p_iso.reshape((N_image, -1)).real,
5427
+ S3p_iso.reshape((N_image, -1)).imag,
5428
+ S4_iso.reshape((N_image, -1)).real,
5429
+ S4_iso.reshape((N_image, -1)).imag,
5430
+ ),dim=-1)
5431
+ else:
5432
+ if ref_sigma is not None:
5433
+ for_synthesis = self.backend.backend.cat((
5434
+ mean_data/ref_sigma['std_data'],
5435
+ std_data/ref_sigma['std_data'],
5436
+ (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5437
+ (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5438
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5439
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5440
+ (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5441
+ (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5442
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5443
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5444
+ ),dim=-1)
5445
+ else:
5446
+ for_synthesis = self.backend.backend.cat((
5447
+ mean_data/std_data,
5448
+ std_data,
5449
+ S2.reshape((N_image, -1)),
5450
+ S1.reshape((N_image, -1)),
5451
+ S3.reshape((N_image, -1)).real,
5452
+ S3.reshape((N_image, -1)).imag,
5453
+ S3p.reshape((N_image, -1)).real,
5454
+ S3p.reshape((N_image, -1)).imag,
5455
+ S4.reshape((N_image, -1)).real,
5456
+ S4.reshape((N_image, -1)).imag,
5457
+ ),dim=-1)
5458
+
5459
+ if not use_ref:
5460
+ self.ref_scattering_cov_S2=S2
5461
+
5462
+ if get_variance:
5463
+ return for_synthesis,ref_sigma
5464
+
5465
+ return for_synthesis
5466
+
4961
5467
  if (M,N,J,L) not in self.filters_set:
4962
5468
  self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
4963
5469
 
@@ -4978,46 +5484,41 @@ class funct(FOC.FoCUS):
4978
5484
  data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
4979
5485
 
4980
5486
  # initialize tensors for scattering coefficients
4981
- S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4982
- S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4983
5487
 
4984
5488
  Ndata_S3 = J*(J+1)//2
4985
5489
  Ndata_S4 = J*(J+1)*(J+2)//6
4986
5490
  J_S4={}
4987
5491
 
4988
- S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5492
+ S3 = []
4989
5493
  if data2 is not None:
4990
- S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4991
- S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4992
- S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5494
+ S3p = []
5495
+ S4_pre_norm = []
5496
+ S4 = []
4993
5497
 
4994
5498
  # variance
4995
5499
  if get_variance:
4996
- S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4997
- S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4998
- S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5500
+ S3_sigma = []
4999
5501
  if data2 is not None:
5000
- S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5001
- S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5502
+ S3p_sigma = []
5503
+ S4_sigma = []
5002
5504
 
5003
5505
  if iso_ang:
5004
- S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5506
+ S3_iso = []
5005
5507
  if data2 is not None:
5006
- S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5508
+ S3p_iso = []
5007
5509
 
5008
- S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5510
+ S4_iso = []
5009
5511
  if get_variance:
5010
- S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5512
+ S3_sigma_iso = []
5011
5513
  if data2 is not None:
5012
- S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5013
- S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5514
+ S3p_sigma_iso = []
5515
+ S4_sigma_iso = []
5014
5516
 
5015
5517
  #
5016
5518
  if edge:
5017
5519
  if (M,N,J) not in self.edge_masks:
5018
5520
  self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5019
- edge_mask = self.edge_masks[(M,N,J)][:,None,:,:]
5020
- edge_mask = edge_mask / edge_mask.mean((-2,-1))[:,:,None,None]
5521
+ edge_mask = self.edge_masks[(M,N,J)]
5021
5522
  else:
5022
5523
  edge_mask = 1
5023
5524
 
@@ -5025,31 +5526,31 @@ class funct(FOC.FoCUS):
5025
5526
  if data2 is None:
5026
5527
  if self.use_2D:
5027
5528
  if len(data.shape) == 2:
5028
- I1 = self.backend.bk_ifftn(
5529
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5029
5530
  data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5030
- ).abs()
5531
+ ))
5031
5532
  else:
5032
- I1 = self.backend.bk_ifftn(
5533
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5033
5534
  data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5034
- ).abs()
5535
+ ))
5035
5536
  elif self.use_1D:
5036
5537
  if len(data.shape) == 1:
5037
- I1 = self.backend.bk_ifftn(
5538
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5038
5539
  data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5039
- ).abs()
5540
+ ))
5040
5541
  else:
5041
- I1 = self.backend.bk_ifftn(
5542
+ I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5042
5543
  data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5043
- ).abs()
5544
+ ))
5044
5545
  else:
5045
5546
  print('todo')
5046
5547
 
5047
- S2 = (I1**2 * edge_mask).mean((-2,-1))
5048
- S1 = (I1 * edge_mask).mean((-2,-1))
5548
+ S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask),axis=(-2,-1))
5549
+ S1 = self.backend.bk_reduce_mean(I1 * edge_mask,axis=(-2,-1))
5049
5550
 
5050
5551
  if get_variance:
5051
- S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5052
- S1_sigma = (I1 * edge_mask).std((-2,-1))
5552
+ S2_sigma = self.backend.bk_reduce_std((I1**2 * edge_mask),axis=(-2,-1))
5553
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5053
5554
 
5054
5555
  I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5055
5556
 
@@ -5089,16 +5590,16 @@ class funct(FOC.FoCUS):
5089
5590
 
5090
5591
  I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5091
5592
 
5092
- S2 = (I1 * edge_mask).mean((-2,-1))
5593
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5093
5594
  if get_variance:
5094
- S2_sigma = (I1 * edge_mask).std((-2,-1))
5595
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5095
5596
 
5096
5597
  I1=self.backend.bk_L1(I1)
5097
5598
 
5098
- S1 = (I1 * edge_mask).mean((-2,-1))
5599
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
5099
5600
 
5100
5601
  if get_variance:
5101
- S1_sigma = (I1 * edge_mask).std((-2,-1))
5602
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5102
5603
 
5103
5604
  I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5104
5605
 
@@ -5129,170 +5630,156 @@ class funct(FOC.FoCUS):
5129
5630
  edge_dx = min(4, int(2**j3*dx3*2/M))
5130
5631
  edge_dy = min(4, int(2**j3*dy3*2/N))
5131
5632
  # a normalization change due to the cutoff of frequency space
5132
- fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5633
+ if self.all_bk_type == "float32":
5634
+ fft_factor = np.complex64(1 /(M3*N3) * (M3*N3/M/N)**2)
5635
+ else:
5636
+ fft_factor = np.complex128(1 /(M3*N3) * (M3*N3/M/N)**2)
5133
5637
  for j2 in range(0,j3+1):
5134
- I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5135
- I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5638
+ #I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5639
+ #I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5640
+ I1_f2_wf3_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3,[1,1,1,L,M3,N3])
5641
+ I1_f2_wf3_2_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3_squared,[1,1,1,L,M3,N3])
5136
5642
  if edge:
5137
5643
  I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5138
5644
  I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
5139
5645
  if use_ref:
5140
5646
  if normalization=='P11':
5141
5647
  norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5648
+ norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
5142
5649
  elif normalization=='S2':
5143
5650
  norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5144
- norm_factor_S3 = 1.0
5651
+ norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
5652
+ else:
5653
+ norm_factor_S3 = C_ONE
5145
5654
  else:
5146
5655
  if normalization=='P11':
5147
5656
  # [N_image,l2,l3,x,y]
5148
- P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5657
+ P11_temp = self.backend.bk_reduce_mean((I1_f2_wf3_small.abs()**2),axis=(-2,-1)) * fft_factor
5149
5658
  norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5659
+ norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
5150
5660
  elif normalization=='S2':
5151
- norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
5152
- norm_factor_S3 = 1.0
5661
+ norm_factor_S3 = (S2[:,None,j3,None,:] * S2[:,None,j2,:,None]**pseudo_coef)**0.5
5662
+ norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
5663
+ else:
5664
+ norm_factor_S3 = C_ONE
5665
+
5153
5666
 
5154
5667
  if not edge:
5155
- S3[:,Ndata_S3,:,:] = (
5156
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5157
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5158
-
5668
+ S3.append(self.backend.bk_reduce_mean(
5669
+ self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5670
+ ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5159
5671
  if get_variance:
5160
- S3_sigma[:,Ndata_S3,:,:] = (
5161
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5162
- ).std((-2,-1)) * fft_factor / norm_factor_S3
5672
+ S3_sigma.append(self.backend.bk_reduce_std(
5673
+ self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5674
+ ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5163
5675
  else:
5164
-
5165
- S3[:,Ndata_S3,:,:] = (
5166
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5167
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5676
+ S3.append(self.backend.bk_reduce_mean(
5677
+ (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5678
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5168
5679
  if get_variance:
5169
- S3_sigma[:,Ndata_S3,:,:] = (
5170
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5171
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5680
+ S3_sigma.append(self.backend.bk_reduce_std(
5681
+ (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5682
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5172
5683
  if data2 is not None:
5173
5684
  if not edge:
5174
- S3p[:,Ndata_S3,:,:] = (
5175
- data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5176
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5685
+ S3p.append(self.backend.bk_reduce_mean(
5686
+ (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5687
+ ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5177
5688
 
5178
5689
  if get_variance:
5179
- S3p_sigma[:,Ndata_S3,:,:] = (
5180
- data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5181
- ).std((-2,-1)) * fft_factor / norm_factor_S3
5690
+ S3p_sigma.append(self.backend.bk_reduce_std(
5691
+ (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5692
+ ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5182
5693
  else:
5183
5694
 
5184
- S3p[:,Ndata_S3,:,:] = (
5185
- data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5186
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5695
+ S3p.append(self.backend.bk_reduce_mean(
5696
+ (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5697
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5187
5698
  if get_variance:
5188
- S3p_sigma[:,Ndata_S3,:,:] = (
5189
- data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5190
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5191
-
5192
- Ndata_S3+=1
5699
+ S3p_sigma.append(self.backend.bk_reduce_std(
5700
+ (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5701
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5702
+
5193
5703
  if j2 <= j3:
5194
- beg_n=Ndata_S4
5704
+ if normalization=='S2':
5705
+ if use_ref:
5706
+ P = 1/((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5707
+ else:
5708
+ P = 1/(((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef)))
5709
+ P=self.backend.bk_complex(P,0.0*P)
5710
+ else:
5711
+ P=C_ONE
5712
+
5195
5713
  for j1 in range(0, j2+1):
5196
- if eval(S4_criteria):
5197
5714
  if not edge:
5198
5715
  if not if_large_batch:
5199
5716
  # [N_image,l1,l2,l3,x,y]
5200
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5201
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5202
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5203
- ).mean((-2,-1)) * fft_factor
5717
+ S4.append(self.backend.bk_reduce_mean(
5718
+ (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5719
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5720
+ ),axis=(-2,-1)) * fft_factor*P)
5204
5721
  if get_variance:
5205
- S4_sigma[:,Ndata_S4,:,:,:] = (
5206
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5207
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5208
- ).std((-2,-1)) * fft_factor
5722
+ S4_sigma.append(self.backend.bk_reduce_std(
5723
+ (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5724
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5725
+ ),axis=(-2,-1)) * fft_factor*P)
5209
5726
  else:
5210
5727
  for l1 in range(L):
5211
5728
  # [N_image,l2,l3,x,y]
5212
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5213
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5214
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5215
- ).mean((-2,-1)) * fft_factor
5729
+ S4.append(self.backend.bk_reduce_mean(
5730
+ (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5731
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5732
+ ),axis=(-2,-1)) * fft_factor*P)
5216
5733
  if get_variance:
5217
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5218
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5219
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5220
- ).std((-2,-1)) * fft_factor
5734
+ S4_sigma.append(self.backend.bk_reduce_std(
5735
+ (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5736
+ self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5737
+ ),axis=(-2,-1)) * fft_factor*P)
5221
5738
  else:
5222
5739
  if not if_large_batch:
5223
5740
  # [N_image,l1,l2,l3,x,y]
5224
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5225
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5226
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5741
+ S4.append(self.backend.bk_reduce_mean(
5742
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5743
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5227
5744
  )
5228
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5745
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5229
5746
  if get_variance:
5230
- S4_sigma[:,Ndata_S4,:,:,:] = (
5231
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5232
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5747
+ S4_sigma.append(self.backend.bk_reduce_std(
5748
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5749
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5233
5750
  )
5234
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
5751
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5235
5752
  else:
5236
5753
  for l1 in range(L):
5237
5754
  # [N_image,l2,l3,x,y]
5238
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5239
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5240
- I12_w3_2_small.view(N_image,L,L,M3,N3)
5755
+ S4.append(self.backend.bk_reduce_mean(
5756
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5757
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5241
5758
  )
5242
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5759
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5243
5760
  if get_variance:
5244
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5245
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5246
- I12_w3_2_small.view(N_image,L,L,M3,N3)
5761
+ S4_sigma.append(self.backend.bk_reduce_std(
5762
+ (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5763
+ self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5247
5764
  )
5248
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5765
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5249
5766
 
5250
- Ndata_S4+=1
5251
-
5252
- if normalization=='S2':
5253
- if use_ref:
5254
- P = (ref_S2[:,j3,:,None,None] * ref_S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
5255
- else:
5256
- P = ((S2[:,j3,:,None,None] * S2[:,j2,None,:,None] )**(0.5*pseudo_coef))
5257
- S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/(P.clone())
5258
-
5259
- if get_variance:
5260
- S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:] / (P)
5261
- else:
5262
- S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()
5263
-
5264
- if get_variance:
5265
- S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:]
5266
-
5267
- """
5268
- # define P11 from diagonals of S4
5269
- for j1 in range(J):
5270
- for l1 in range(L):
5271
- P11[:,j1,:,l1,:] = S4_pre_norm[:,j1,j1,:,l1,l1,:].real
5272
-
5273
-
5274
- if normalization=='S4':
5275
- if use_ref:
5276
- P = ref_P11
5277
- else:
5278
- P = P11
5279
- #.view(N_image,J,1,J,L,1,L) * .view(N_image,1,J,J,1,L,L)
5280
- S4 = S4_pre_norm / (
5281
- P[:,:,None,:,:,None,:] * P[:,None,:,:,None,:,:]
5282
- )**(0.5*pseudo_coef)
5283
-
5284
-
5767
+ S3=self.backend.bk_concat(S3,axis=1)
5768
+ S4=self.backend.bk_concat(S4,axis=1)
5285
5769
 
5770
+ if get_variance:
5771
+ S3_sigma=self.backend.bk_concat(S3_sigma,axis=1)
5772
+ S4_sigma=self.backend.bk_concat(S4_sigma,axis=1)
5773
+
5774
+ if data2 is not None:
5775
+ S3p=self.backend.bk_concat(S3p,axis=1)
5776
+ if get_variance:
5777
+ S3p_sigma=self.backend.bk_concat(S3p_sigma,axis=1)
5286
5778
 
5287
- # get a single, flattened data vector for_synthesis
5288
- select_and_index = self.get_scattering_index(J, L, normalization, S4_criteria)
5289
- index_for_synthesis = select_and_index['index_for_synthesis']
5290
- index_for_synthesis_iso = select_and_index['index_for_synthesis_iso']
5291
- """
5292
5779
  # average over l1 to obtain simple isotropic statistics
5293
5780
  if iso_ang:
5294
- S2_iso = S2.mean(-1)
5295
- S1_iso = S1.mean(-1)
5781
+ S2_iso = self.backend.bk_reduce_mean(S2,axis=(-1))
5782
+ S1_iso = self.backend.bk_reduce_mean(S1,axis=(-1))
5296
5783
  for l1 in range(L):
5297
5784
  for l2 in range(L):
5298
5785
  S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
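The "average over l1" loops above and below turn the (l1, l2[, l3]) orientation axes into relative-orientation axes: every pair is binned by (l2 - l1) % L and the result is divided by L, which removes the dependence on the absolute orientation l1. A self-contained NumPy illustration of the S3 case (shapes are arbitrary stand-ins):

import numpy as np

L = 4
S3 = np.random.randn(2, 10, L, L)      # stand-in for (N_image, Ndata_S3, L, L)
S3_iso = np.zeros(S3.shape[:-2] + (L,))
for l1 in range(L):
    for l2 in range(L):
        S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
S3_iso /= L                            # isotropic (rotation-averaged) statistic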
@@ -5305,8 +5792,8 @@ class funct(FOC.FoCUS):
5305
5792
  S3p_iso /= L
5306
5793
 
5307
5794
  if get_variance:
5308
- S2_sigma_iso = S2_sigma.mean(-1)
5309
- S1_sigma_iso = S1_sigma.mean(-1)
5795
+ S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma,axis=(-1))
5796
+ S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma,axis=(-1))
5310
5797
  for l1 in range(L):
5311
5798
  for l2 in range(L):
5312
5799
  S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
@@ -5318,10 +5805,12 @@ class funct(FOC.FoCUS):
5318
5805
  if data2 is not None:
5319
5806
  S3p_sigma_iso /= L
5320
5807
 
5321
- mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5322
- std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5323
- mean_data[:,0]=data.mean((-2,-1))
5324
- std_data[:,0]=data.std((-2,-1))
5808
+ if data2 is None:
5809
+ mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data,axis=(-2,-1)),[N_image,1])
5810
+ std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data,axis=(-2,-1)),[N_image,1])
5811
+ else:
5812
+ mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data*data2,axis=(-2,-1)),[N_image,1])
5813
+ std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data*data2,axis=(-2,-1)),[N_image,1])
5325
5814
 
5326
5815
  if get_variance:
5327
5816
  ref_sigma={}
@@ -5345,105 +5834,105 @@ class funct(FOC.FoCUS):
5345
5834
  if data2 is None:
5346
5835
  if iso_ang:
5347
5836
  if ref_sigma is not None:
5348
- for_synthesis = self.backend.backend.cat((
5837
+ for_synthesis = self.backend.bk_concat((
5349
5838
  mean_data/ref_sigma['std_data'],
5350
5839
  std_data/ref_sigma['std_data'],
5351
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5352
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5353
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5354
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5355
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5356
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5357
- ),dim=-1)
5840
+ self.backend.bk_reshape(self.backend.bk_log(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5841
+ self.backend.bk_reshape(self.backend.bk_log(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5842
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5843
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5844
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5845
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5846
+ ),axis=-1)
5358
5847
  else:
5359
- for_synthesis = self.backend.backend.cat((
5848
+ for_synthesis = self.backend.bk_concat((
5360
5849
  mean_data/std_data,
5361
5850
  std_data,
5362
- S2_iso.reshape((N_image, -1)).log(),
5363
- S1_iso.reshape((N_image, -1)).log(),
5364
- S3_iso.reshape((N_image, -1)).real,
5365
- S3_iso.reshape((N_image, -1)).imag,
5366
- S4_iso.reshape((N_image, -1)).real,
5367
- S4_iso.reshape((N_image, -1)).imag,
5368
- ),dim=-1)
5851
+ self.backend.bk_reshape(self.backend.bk_log(S2_iso),[N_image, -1]),
5852
+ self.backend.bk_reshape(self.backend.bk_log(S1_iso),[N_image, -1]),
5853
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5854
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5855
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5856
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5857
+ ),axis=-1)
5369
5858
  else:
5370
5859
  if ref_sigma is not None:
5371
- for_synthesis = self.backend.backend.cat((
5860
+ for_synthesis = self.backend.bk_concat((
5372
5861
  mean_data/ref_sigma['std_data'],
5373
5862
  std_data/ref_sigma['std_data'],
5374
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5375
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5376
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5377
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5378
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5379
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5380
- ),dim=-1)
5863
+ self.backend.bk_reshape(self.backend.bk_log(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5864
+ self.backend.bk_reshape(self.backend.bk_log(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5865
+ self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5866
+ self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5867
+ self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5868
+ self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5869
+ ),axis=-1)
5381
5870
  else:
5382
- for_synthesis = self.backend.backend.cat((
5383
- mean_data/std_data,
5384
- std_data,
5385
- S2.reshape((N_image, -1)).log(),
5386
- S1.reshape((N_image, -1)).log(),
5387
- S3.reshape((N_image, -1)).real,
5388
- S3.reshape((N_image, -1)).imag,
5389
- S4.reshape((N_image, -1)).real,
5390
- S4.reshape((N_image, -1)).imag,
5391
- ),dim=-1)
5871
+ for_synthesis = self.backend.bk_concat((
5872
+ mean_data/std_data,
5873
+ std_data,
5874
+ self.backend.bk_reshape(self.backend.bk_log(S2),[N_image, -1]),
5875
+ self.backend.bk_reshape(self.backend.bk_log(S1),[N_image, -1]),
5876
+ self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5877
+ self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5878
+ self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5879
+ self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5880
+ ),axis=-1)
5392
5881
  else:
5393
5882
  if iso_ang:
5394
5883
  if ref_sigma is not None:
5395
5884
  for_synthesis = self.backend.backend.cat((
5396
5885
  mean_data/ref_sigma['std_data'],
5397
5886
  std_data/ref_sigma['std_data'],
5398
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5399
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5400
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5401
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5402
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5403
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5404
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5405
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5406
- ),dim=-1)
5887
+ self.backend.bk_reshape(self.backend.bk_real(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5888
+ self.backend.bk_reshape(self.backend.bk_real(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5889
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5890
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5891
+ self.backend.bk_reshape(self.backend.bk_real(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5892
+ self.backend.bk_reshape(self.backend.bk_imag(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5893
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5894
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5895
+ ),axis=-1)
5407
5896
  else:
5408
5897
  for_synthesis = self.backend.backend.cat((
5409
5898
  mean_data/std_data,
5410
5899
  std_data,
5411
- S2_iso.reshape((N_image, -1)),
5412
- S1_iso.reshape((N_image, -1)),
5413
- S3_iso.reshape((N_image, -1)).real,
5414
- S3_iso.reshape((N_image, -1)).imag,
5415
- S3p_iso.reshape((N_image, -1)).real,
5416
- S3p_iso.reshape((N_image, -1)).imag,
5417
- S4_iso.reshape((N_image, -1)).real,
5418
- S4_iso.reshape((N_image, -1)).imag,
5419
- ),dim=-1)
5900
+ self.backend.bk_reshape(self.backend.bk_real(S2_iso),[N_image, -1]),
5901
+ self.backend.bk_reshape(self.backend.bk_real(S1_iso),[N_image, -1]),
5902
+ self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5903
+ self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5904
+ self.backend.bk_reshape(self.backend.bk_real(S3p_iso),[N_image, -1]),
5905
+ self.backend.bk_reshape(self.backend.bk_imag(S3p_iso),[N_image, -1]),
5906
+ self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5907
+ self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5908
+ ),axis=-1)
5420
5909
  else:
5421
5910
  if ref_sigma is not None:
5422
5911
  for_synthesis = self.backend.backend.cat((
5423
5912
  mean_data/ref_sigma['std_data'],
5424
5913
  std_data/ref_sigma['std_data'],
5425
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5426
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5427
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5428
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5429
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5430
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5431
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5432
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5433
- ),dim=-1)
5914
+ self.backend.bk_reshape(self.backend.bk_real(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5915
+ self.backend.bk_reshape(self.backend.bk_real(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5916
+ self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5917
+ self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5918
+ self.backend.bk_reshape(self.backend.bk_real(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5919
+ self.backend.bk_reshape(self.backend.bk_imag(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5920
+ self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5921
+ self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5922
+ ),axis=-1)
5434
5923
  else:
5435
- for_synthesis = self.backend.backend.cat((
5436
- mean_data/std_data,
5437
- std_data,
5438
- S2.reshape((N_image, -1)),
5439
- S1.reshape((N_image, -1)),
5440
- S3.reshape((N_image, -1)).real,
5441
- S3.reshape((N_image, -1)).imag,
5442
- S3p.reshape((N_image, -1)).real,
5443
- S3p.reshape((N_image, -1)).imag,
5444
- S4.reshape((N_image, -1)).real,
5445
- S4.reshape((N_image, -1)).imag,
5446
- ),dim=-1)
5924
+ for_synthesis = self.backend.bk_concat((
5925
+ mean_data/std_data,
5926
+ std_data,
5927
+ self.backend.bk_reshape(self.backend.bk_real(S2),[N_image, -1]),
5928
+ self.backend.bk_reshape(self.backend.bk_real(S1),[N_image, -1]),
5929
+ self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5930
+ self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5931
+ self.backend.bk_reshape(self.backend.bk_real(S3p),[N_image, -1]),
5932
+ self.backend.bk_reshape(self.backend.bk_imag(S3p),[N_image, -1]),
5933
+ self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5934
+ self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5935
+ ),axis=-1)
5447
5936
 
5448
5937
  if not use_ref:
5449
5938
  self.ref_scattering_cov_S2=S2
@@ -5835,6 +6324,7 @@ class funct(FOC.FoCUS):
5835
6324
  to_gaussian=True,
5836
6325
  use_variance=False,
5837
6326
  synthesised_N=1,
6327
+ input_image=None,
5838
6328
  iso_ang=False,
5839
6329
  EVAL_FREQUENCY=100,
5840
6330
  NUM_EPOCHS = 300):
@@ -5892,14 +6382,22 @@ class funct(FOC.FoCUS):
5892
6382
 
5893
6383
  for k in range(nstep):
5894
6384
  if k==0:
5895
- np.random.seed(seed)
5896
- if self.use_2D:
5897
- imap=np.random.randn(synthesised_N,
5898
- tmp[k].shape[1],
5899
- tmp[k].shape[2])
6385
+ if input_image is None:
6386
+ np.random.seed(seed)
6387
+ if self.use_2D:
6388
+ imap=np.random.randn(synthesised_N,
6389
+ tmp[k].shape[1],
6390
+ tmp[k].shape[2])
6391
+ else:
6392
+ imap=np.random.randn(synthesised_N,
6393
+ tmp[k].shape[1])
5900
6394
  else:
5901
- imap=np.random.randn(synthesised_N,
5902
- tmp[k].shape[1])
6395
+ if self.use_2D:
6396
+ imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6397
+ [synthesised_N,tmp[k].shape[1],tmp[k].shape[2]])
6398
+ else:
6399
+ imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6400
+ [synthesised_N,tmp[k].shape[1]])
5903
6401
  else:
5904
6402
  # Increase the resolution between each step
5905
6403
  if self.use_2D:
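The new input_image option above replaces the white-noise starting map: the supplied image is flattened, tiled synthesised_N times and reshaped, so every synthesised map starts from the same input. A small NumPy illustration of that tiling (values and sizes are arbitrary):

import numpy as np

input_image = np.random.randn(64, 64)   # stand-in for a user-supplied map
synthesised_N = 4
imap = np.tile(input_image.flatten(), synthesised_N).reshape(
    synthesised_N, input_image.shape[0], input_image.shape[1])
assert all(np.allclose(imap[i], input_image) for i in range(synthesised_N))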