foscat 3.8.2__py3-none-any.whl → 3.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -92,7 +92,12 @@ class scat_cov:
         )

     def conv2complex(self, val):
-        if val.dtype == "complex64" or val.dtype=="complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
+        if (
+            val.dtype == "complex64"
+            or val.dtype == "complex128"
+            or val.dtype == "torch.complex64"
+            or val.dtype == "torch.complex128"
+        ):
             return val
         else:
             return self.backend.bk_complex(val, 0 * val)
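The reformatted guard above only re-wraps a dtype check whose job is to promote real inputs to complex with a zero imaginary part, so downstream correlations can conjugate uniformly. A minimal NumPy-only sketch of the same pattern (an illustrative stand-in, not the foscat backend API):

```python
import numpy as np

def to_complex(val):
    # Pass complex arrays through; promote real ones to complex
    # with a zero imaginary part (mirrors bk_complex(val, 0 * val)).
    if np.iscomplexobj(val):
        return val
    return val + 0j * val

print(to_complex(np.ones(4, dtype="float32")).dtype)  # complex64
```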
@@ -2043,9 +2048,11 @@ class scat_cov:
                            )
                        )
                    else:
-                        s3[i, j, idx[noff:], k, l_orient] = self.backend.to_numpy(self.S3)[
-                            i, j, j2 == ij - noff, k, l_orient
-                        ]
+                        s3[i, j, idx[noff:], k, l_orient] = (
+                            self.backend.to_numpy(self.S3)[
+                                i, j, j2 == ij - noff, k, l_orient
+                            ]
+                        )
                    s3[i, j, idx[:noff], k, l_orient] = (
                        self.add_data_from_slope(
                            self.backend.to_numpy(self.S3)[
@@ -2208,7 +2215,7 @@ class scat_cov:


 class funct(FOC.FoCUS):
-
+
     def fill(self, im, nullval=hp.UNSEEN):
         if self.use_2D:
             return self.fill_2d(im, nullval=nullval)
@@ -2375,13 +2382,19 @@ class funct(FOC.FoCUS):

        # instead of difference between "opposite" channels use weighted average
        # of cosine and sine contributions using all channels
-        angles = self.backend.bk_cast((
-            2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
-        ).reshape(1, 1, self.NORIENT))  # shape: (NORIENT,)
+        angles = self.backend.bk_cast(
+            (2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
+                1, 1, self.NORIENT
+            )
+        )  # shape: (NORIENT,)

        # we use cosines and sines as weights for sim
-        weighted_cos = self.backend.bk_reduce_mean(sim * self.backend.bk_cos(angles), axis=-1)
-        weighted_sin = self.backend.bk_reduce_mean(sim * self.backend.bk_sin(angles), axis=-1)
+        weighted_cos = self.backend.bk_reduce_mean(
+            sim * self.backend.bk_cos(angles), axis=-1
+        )
+        weighted_sin = self.backend.bk_reduce_mean(
+            sim * self.backend.bk_sin(angles), axis=-1
+        )
        # For simplicity, take first element of the batch
        cc = weighted_cos[0]
        ss = weighted_sin[0]
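The cosine/sine weighting above is a circular (vector) mean over the orientation channels; the hunks that follow recover the dominant orientation from it with `arctan2`. A self-contained sketch of that recovery, with made-up channel responses:

```python
import numpy as np

norient = 4
angles = 2 * np.pi * np.arange(norient) / norient
sim = np.array([0.1, 1.0, 0.2, 0.1])  # response of each orientation channel

cc = np.mean(sim * np.cos(angles))    # weighted_cos analogue
ss = np.mean(sim * np.sin(angles))    # weighted_sin analogue
phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
print(phase)  # ~1.68 rad, near pi/2, the angle of the strongest channel
```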
@@ -2401,8 +2414,9 @@ class funct(FOC.FoCUS):
            phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
        else:
            phase = np.fmod(
-                np.arctan2(self.backend.to_numpy(ss),
-                           self.backend.to_numpy(cc)) + 2 * np.pi, 2 * np.pi
+                np.arctan2(self.backend.to_numpy(ss), self.backend.to_numpy(cc))
+                + 2 * np.pi,
+                2 * np.pi,
            )

        # instead of linear interpolation cosine‐based interpolation
@@ -2416,10 +2430,10 @@ class funct(FOC.FoCUS):
        # build rotation matrix
        mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
        lidx = np.arange(sim.shape[1])
-        for l in range(self.NORIENT):
+        for ell in range(self.NORIENT):
            # Instead of simple linear weights, we use the cosine weights w0 and w1.
-            col0 = self.NORIENT * ((l + iph) % self.NORIENT) + l
-            col1 = self.NORIENT * ((l + iph + 1) % self.NORIENT) + l
+            col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
+            col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
            mat[lidx, col0] = w0
            mat[lidx, col1] = w1

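This hunk renames the loop variable `l` to `ell` (the single letter is easily confused with `1` and is flagged by linters) without touching the logic: each orientation is routed to the two nearest rotated channels with weights `w0` and `w1`. A runnable sketch with invented per-pixel weights:

```python
import numpy as np

rng = np.random.default_rng(0)
npix, norient = 5, 4
iph = rng.integers(0, norient, npix)  # per-pixel integer phase shift
w0 = rng.random(npix)
w1 = 1.0 - w0                         # interpolation weights, w0 + w1 = 1

mat = np.zeros([npix, norient * norient])
lidx = np.arange(npix)
for ell in range(norient):
    col0 = norient * ((ell + iph) % norient) + ell
    col1 = norient * ((ell + iph + 1) % norient) + ell
    mat[lidx, col0] = w0
    mat[lidx, col1] = w1
print(mat.sum(1))  # norient per pixel: each channel's weights sum to 1
```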
@@ -2436,7 +2450,8 @@ class funct(FOC.FoCUS):
            sim2 = self.backend.bk_reduce_sum(
                self.backend.bk_reshape(
                    self.backend.bk_cast(
-                        mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT))
+                        mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
+                    )
                    * tmp2,
                    [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
                ),
@@ -2470,8 +2485,11 @@ class funct(FOC.FoCUS):
            phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
        else:
            phase2 = np.fmod(
-                np.arctan2(self.backend.to_numpy(ss2),
-                           self.backend.to_numpy(cc2)) + 2 * np.pi, 2 * np.pi
+                np.arctan2(
+                    self.backend.to_numpy(ss2), self.backend.to_numpy(cc2)
+                )
+                + 2 * np.pi,
+                2 * np.pi,
            )

        phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
@@ -2482,9 +2500,11 @@ class funct(FOC.FoCUS):
        lidx = np.arange(sim.shape[1])

        for m in range(self.NORIENT):
-            for l in range(self.NORIENT):
-                col0 = self.NORIENT * ((l + iph2[:, m]) % self.NORIENT) + l
-                col1 = self.NORIENT * ((l + iph2[:, m] + 1) % self.NORIENT) + l
+            for ell in range(self.NORIENT):
+                col0 = self.NORIENT * ((ell + iph2[:, m]) % self.NORIENT) + ell
+                col1 = (
+                    self.NORIENT * ((ell + iph2[:, m] + 1) % self.NORIENT) + ell
+                )
                mat2[k2, lidx, m, col0] = w0_2[:, m]
                mat2[k2, lidx, m, col1] = w1_2[:, m]
        cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
@@ -2502,17 +2522,17 @@ class funct(FOC.FoCUS):
        )

    def eval(
-            self,
-            image1,
-            image2=None,
-            mask=None,
-            norm=None,
-            calc_var=False,
-            cmat=None,
-            cmat2=None,
-            Jmax=None,
-            out_nside=None,
-            edge=True
+        self,
+        image1,
+        image2=None,
+        mask=None,
+        norm=None,
+        calc_var=False,
+        cmat=None,
+        cmat2=None,
+        Jmax=None,
+        out_nside=None,
+        edge=True,
    ):
        """
        Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2542,9 +2562,9 @@ class funct(FOC.FoCUS):
        -------
        S1, S2, S3, S4 normalized
        """
-
+
        return_data = self.return_data
-
+
        # Check input consistency
        if image2 is not None:
            if list(image1.shape) != list(image2.shape):
@@ -2554,7 +2574,10 @@ class funct(FOC.FoCUS):
                return None
        if mask is not None:
            if self.use_2D:
-                if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
+                if (
+                    image1.shape[-2] != mask.shape[1]
+                    or image1.shape[-1] != mask.shape[2]
+                ):
                    print(
                        "The LAST 2 COLUMNs of the mask should have the same size ",
                        mask.shape,
@@ -2564,7 +2587,7 @@ class funct(FOC.FoCUS):
                    )
                    return None
            else:
-                if image1.shape[-1] != mask.shape[1]:
+                if image1.shape[-1] != mask.shape[1]:
                    print(
                        "The LAST COLUMN of the mask should have the same size ",
                        mask.shape,
@@ -2618,16 +2641,17 @@ class funct(FOC.FoCUS):
            nside = int(np.sqrt(npix // 12))

            J = int(np.log(nside) / np.log(2))  # Number of j scales
-
-        if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
-            J-=1
+
+        if (self.use_2D or self.use_1D) and self.KERNELSZ > 3:
+            J -= 1
        if Jmax is None:
            Jmax = J  # Number of steps for the loop on scales
-        if Jmax>J:
-            print('==========\n\n')
-            print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
-            print('\n\n==========')
-
+        if Jmax > J:
+            print("==========\n\n")
+            print(
+                "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
+            )
+            print("\n\n==========")

        ### LOCAL VARIABLES (IMAGES and MASK)
        if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
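The new `Jmax > J` warning guards the dyadic scale count, where `J` follows from the map size. A worked example with the HEALPix formulas shown in this hunk:

```python
import numpy as np

# HEALPix case: nside = sqrt(npix / 12), J = log2(nside).
npix = 12 * 256**2
nside = int(np.sqrt(npix // 12))
J = int(np.log(nside) / np.log(2))
print(nside, J)  # 256 8 -> requesting Jmax > 8 would trigger the warning
```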
@@ -2770,32 +2794,37 @@ class funct(FOC.FoCUS):
        nside_j3 = nside  # NSIDE start (nside_j3 = nside / 2^j3)

        # a remettre comme avant
-        M1_dic={}
-        M2_dic={}
-
+        M1_dic = {}
+        M2_dic = {}
+
        for j3 in range(Jmax):
-
+
            if edge:
                if self.mask_mask is None:
-                    self.mask_mask={}
+                    self.mask_mask = {}
                if self.use_2D:
-                    if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
-                        mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
-                        mask_mask[0,
-                                  self.KERNELSZ//2:-self.KERNELSZ//2+1,
-                                  self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
-                        self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
-                    vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
-                    #print(self.KERNELSZ//2,vmask,mask_mask)
-
+                    if (vmask.shape[1], vmask.shape[2]) not in self.mask_mask:
+                        mask_mask = np.zeros([1, vmask.shape[1], vmask.shape[2]])
+                        mask_mask[
+                            0,
+                            self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
+                            self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
+                        ] = 1.0
+                        self.mask_mask[(vmask.shape[1], vmask.shape[2])] = (
+                            self.backend.bk_cast(mask_mask)
+                        )
+                    vmask = vmask * self.mask_mask[(vmask.shape[1], vmask.shape[2])]
+                    # print(self.KERNELSZ//2,vmask,mask_mask)
+
                if self.use_1D:
                    if (vmask.shape[1]) not in self.mask_mask:
-                        mask_mask=np.zeros([1,vmask.shape[1]])
-                        mask_mask[0,
-                                  self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
-                        self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
-                    vmask=vmask*self.mask_mask[(vmask.shape[1])]
-
+                        mask_mask = np.zeros([1, vmask.shape[1]])
+                        mask_mask[0, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] = 1.0
+                        self.mask_mask[(vmask.shape[1])] = self.backend.bk_cast(
+                            mask_mask
+                        )
+                    vmask = vmask * self.mask_mask[(vmask.shape[1])]
+
            if return_data:
                S3[j3] = None
                S3P[j3] = None
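The `edge` branch builds (and caches by shape) a mask that zeroes a border of width `KERNELSZ // 2`, so pixels whose wavelet support crosses the map edge drop out of the averages. A standalone sketch of the 2-D mask it constructs:

```python
import numpy as np

kernelsz, h, w = 5, 8, 8
mask_mask = np.zeros([1, h, w])
# Note: -kernelsz // 2 + 1 == -2 for kernelsz = 5 (floor division),
# so the slice keeps rows/columns 2..5 of an 8-pixel axis.
mask_mask[
    0,
    kernelsz // 2 : -kernelsz // 2 + 1,
    kernelsz // 2 : -kernelsz // 2 + 1,
] = 1.0
print(mask_mask[0].astype(int))  # 1s in the interior, 0s on the border
```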
@@ -3449,9 +3478,9 @@ class funct(FOC.FoCUS):
                            M2_dic[j2], axis=1
                        )  # [Nbatch, Npix_j3, Norient3]
                        M2_dic[j2] = self.ud_grade_2(
-                            M2, axis=1
+                            M2_smooth, axis=1
                        )  # [Nbatch, Npix_j3, Norient3]
-
+
                ### Mask
                vmask = self.ud_grade_2(vmask, axis=1)

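This hunk is a behavioral fix, not just reformatting: the cross branch previously decimated the raw modulus `M2` instead of the freshly smoothed `M2_smooth`, so the anti-aliasing smoothing was computed and then discarded. A 1-D stand-in for the smooth-then-decimate pattern (hypothetical helpers, not foscat's):

```python
import numpy as np

def smooth(x):
    return np.convolve(x, [0.25, 0.5, 0.25], mode="same")

def ud_grade_2(x):
    return x.reshape(-1, 2).mean(1)  # halve the resolution

m2 = np.arange(8, dtype=float)
m2 = ud_grade_2(smooth(m2))  # the fix: decimate the *smoothed* field
print(m2)
```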
@@ -3543,1278 +3572,232 @@ class funct(FOC.FoCUS):
                    use_1D=self.use_1D,
                )

-    def eval_new(
-            self,
-            image1,
-            image2=None,
-            mask=None,
-            norm=None,
-            calc_var=False,
-            cmat=None,
-            cmat2=None,
-            Jmax=None,
-            out_nside=None,
-            edge=True
+    def clean_norm(self):
+        self.P1_dic = None
+        self.P2_dic = None
+        return
+
+    def _compute_S3(
+        self,
+        j2,
+        j3,
+        conv,
+        vmask,
+        M_dic,
+        MconvPsi_dic,
+        calc_var=False,
+        return_data=False,
+        cmat2=None,
    ):
        """
-        Calculates the scattering correlations for a batch of images. Mean are done over pixels.
-        mean of modulus:
-                        S1 = <|I * Psi_j3|>
-        Normalization : take the log
-        power spectrum:
-                        S2 = <|I * Psi_j3|^2>
-        Normalization : take the log
-        orig. x modulus:
-                        S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
-        Normalization : divide by (S2_j2 * S2_j3)^0.5
-        modulus x modulus:
-                        S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
-        Normalization : divide by (S2_j1 * S2_j2)^0.5
+        Compute the S3 coefficients (auto or cross)
+        S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
        Parameters
        ----------
-        image1: tensor
-            Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
-        image2: tensor
-            Second image. If not None, we compute cross-scattering covariance coefficients.
-        mask:
-        norm: None or str
-            If None no normalization is applied, if 'auto' normalize by the reference S2,
-            if 'self' normalize by the current S2.
        Returns
        -------
-        S1, S2, S3, S4 normalized
+        cs3, ss3: real and imag parts of S3 coeff
        """
-        return_data = self.return_data
-        NORIENT=self.NORIENT
-        # Check input consistency
-        if image2 is not None:
-            if list(image1.shape) != list(image2.shape):
-                print(
-                    "The two input image should have the same size to eval Scattering Covariance"
-                )
-                return None
-        if mask is not None:
-            if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
-                print(
-                    "The LAST COLUMN of the mask should have the same size ",
-                    mask.shape,
-                    "than the input image ",
-                    image1.shape,
-                    "to eval Scattering Covariance",
-                )
-                return None
-        if self.use_2D and len(image1.shape) < 2:
-            print(
-                "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
+        ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
+        # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
+        MconvPsi = self.convol(
+            M_dic[j2], axis=1
+        )  # [Nbatch, Npix_j3, Norient3, Norient2]
+        if cmat2 is not None:
+            tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
+            MconvPsi = self.backend.bk_reduce_sum(
+                self.backend.bk_reshape(
+                    cmat2[j3][j2] * tmp2,
+                    [
+                        tmp2.shape[0],
+                        cmat2[j3].shape[1],
+                        self.NORIENT,
+                        self.NORIENT,
+                        self.NORIENT,
+                    ],
+                ),
+                3,
            )
-            return None
-
-        ### AUTO OR CROSS
-        cross = False
-        if image2 is not None:
-            cross = True

-        ### PARAMETERS
-        axis = 1
-        # determine jmax and nside corresponding to the input map
-        im_shape = image1.shape
-        if self.use_2D:
-            if len(image1.shape) == 2:
-                nside = np.min([im_shape[0], im_shape[1]])
-                npix = im_shape[0] * im_shape[1]  # Number of pixels
-                x1 = im_shape[0]
-                x2 = im_shape[1]
-            else:
-                nside = np.min([im_shape[1], im_shape[2]])
-                npix = im_shape[1] * im_shape[2]  # Number of pixels
-                x1 = im_shape[1]
-                x2 = im_shape[2]
-            J = int(np.log(nside - self.KERNELSZ) / np.log(2))  # Number of j scales
-        elif self.use_1D:
-            if len(image1.shape) == 2:
-                npix = int(im_shape[1])  # Number of pixels
-            else:
-                npix = int(im_shape[0])  # Number of pixels
+        # Store it so we can use it in S4 computation
+        MconvPsi_dic[j2] = MconvPsi  # [Nbatch, Npix_j3, Norient3, Norient2]

-            nside = int(npix)
+        ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
+        # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
+        # cconv, sconv are [Nbatch, Npix_j3, Norient3]
+        if self.use_1D:
+            s3 = conv * self.backend.bk_conjugate(MconvPsi)
+        else:
+            s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
+                MconvPsi
+            )  # [Nbatch, Npix_j3, Norient3, Norient2]

-            J = int(np.log(nside) / np.log(2))  # Number of j scales
+        ### Apply the mask [Nmask, Npix_j3] and sum over pixels
+        if return_data:
+            return s3
        else:
-            if len(image1.shape) == 2:
-                npix = int(im_shape[1])  # Number of pixels
+            if calc_var:
+                s3, vs3 = self.masked_mean(
+                    s3, vmask, axis=1, rank=j2, calc_var=True
+                )  # [Nbatch, Nmask, Norient3, Norient2]
+                return s3, vs3
            else:
-                npix = int(im_shape[0])  # Number of pixels
+                s3 = self.masked_mean(
+                    s3, vmask, axis=1, rank=j2
+                )  # [Nbatch, Nmask, Norient3, Norient2]
+                return s3

-        nside = int(np.sqrt(npix // 12))
+    def _compute_S4(
+        self,
+        j1,
+        j2,
+        vmask,
+        M1convPsi_dic,
+        M2convPsi_dic=None,
+        calc_var=False,
+        return_data=False,
+    ):
+        #### Simplify notations
+        M1 = M1convPsi_dic[j1]  # [Nbatch, Npix_j3, Norient3, Norient1]

-        J = int(np.log(nside) / np.log(2))  # Number of j scales
-
-        if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
-            J-=1
-        if Jmax is None:
-            Jmax = J  # Number of steps for the loop on scales
-        if Jmax>J:
-            print('==========\n\n')
-            print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
-            print('\n\n==========')
-
+        # Auto or Cross coefficients
+        if M2convPsi_dic is None:  # Auto
+            M2 = M1convPsi_dic[j2]  # [Nbatch, Npix_j3, Norient3, Norient2]
+        else:  # Cross
+            M2 = M2convPsi_dic[j2]

-        ### LOCAL VARIABLES (IMAGES and MASK)
-        if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
-            I1 = self.backend.bk_cast(
-                self.backend.bk_expand_dims(image1, 0)
-            )  # Local image1 [Nbatch, Npix]
-            if cross:
-                I2 = self.backend.bk_cast(
-                    self.backend.bk_expand_dims(image2, 0)
-                )  # Local image2 [Nbatch, Npix]
+        ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
+        # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
+        if self.use_1D:
+            s4 = M1 * self.backend.bk_conjugate(M2)
        else:
-            I1 = self.backend.bk_cast(image1)  # Local image1 [Nbatch, Npix]
-            if cross:
-                I2 = self.backend.bk_cast(image2)  # Local image2 [Nbatch, Npix]
+            s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
+                self.backend.bk_expand_dims(M2, -1)
+            )  # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]

-        if mask is None:
-            if self.use_2D:
-                vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
-            else:
-                vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
+        ### Apply the mask and sum over pixels
+        if return_data:
+            return s4
        else:
-            vmask = self.backend.bk_cast(mask)  # [Nmask, Npix]
+            if calc_var:
+                s4, vs4 = self.masked_mean(
+                    s4, vmask, axis=1, rank=j2, calc_var=True
+                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+                return s4, vs4
+            else:
+                s4 = self.masked_mean(
+                    s4, vmask, axis=1, rank=j2
+                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+                return s4

-        if self.KERNELSZ > 3 and not self.use_2D:
-            # if the kernel size is bigger than 3 increase the binning before smoothing
-            if self.use_2D:
-                vmask = self.up_grade(
-                    vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
+    def computer_filter(self, M, N, J, L):
+        """
+        This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
+        Done by Sihao Cheng and Rudy Morel.
+        """
+
+        filter = np.zeros([J, L, M, N], dtype="complex64")
+
+        slant = 4.0 / L
+
+        for j in range(J):
+
+            for ell in range(L):
+
+                theta = (int(L - L / 2 - 1) - ell) * np.pi / L
+                sigma = 0.8 * 2**j
+                xi = 3.0 / 4.0 * np.pi / 2**j
+
+                R = np.array(
+                    [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
+                    np.float64,
                )
-                I1 = self.up_grade(
-                    I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
+                R_inv = np.array(
+                    [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
+                    np.float64,
                )
-                if cross:
-                    I2 = self.up_grade(
-                        I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
-                    )
-            elif self.use_1D:
-                vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
-                I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
-                if cross:
-                    I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
-            else:
-                I1 = self.up_grade(I1, nside * 2, axis=axis)
-                vmask = self.up_grade(vmask, nside * 2, axis=1)
-                if cross:
-                    I2 = self.up_grade(I2, nside * 2, axis=axis)
-
-        if self.KERNELSZ > 5 and not self.use_2D:
-            # if the kernel size is bigger than 3 increase the binning before smoothing
-            if self.use_2D:
-                vmask = self.up_grade(
-                    vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
-                )
-                I1 = self.up_grade(
-                    I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
-                )
-                if cross:
-                    I2 = self.up_grade(
-                        I2,
-                        I2.shape[axis] * 2,
-                        axis=axis,
-                        nouty=I2.shape[axis + 1] * 2,
-                    )
-            elif self.use_1D:
-                vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
-                I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
-                if cross:
-                    I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
-            else:
-                I1 = self.up_grade(I1, nside * 4, axis=axis)
-                vmask = self.up_grade(vmask, nside * 4, axis=1)
-                if cross:
-                    I2 = self.up_grade(I2, nside * 4, axis=axis)
-
-        # Normalize the masks because they have different pixel numbers
-        # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None]  # [Nmask, Npix]
-
-        ### INITIALIZATION
-        # Coefficients
-        if return_data:
-            S1 = {}
-            S2 = {}
-            S3 = {}
-            S3P = {}
-            S4 = {}
-        else:
-            result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
-                                              dtype=self.backend.backend.float32,
-                                              device=self.backend.torch_device)
-            vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
-                                               dtype=self.backend.backend.float32,
-                                               device=self.backend.torch_device)
-            S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
-            S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
-            S3 = []
-            S4 = []
-            S3P = []
-            VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
-            VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
-            VS3 = []
-            VS3P = []
-            VS4 = []
-
-        off_S2 = -2
-        off_S3 = -3
-        off_S4 = -4
-        if self.use_1D:
-            off_S2 = -1
-            off_S3 = -1
-            off_S4 = -1
-
-        # S2 for normalization
-        cond_init_P1_dic = (norm == "self") or (
-            (norm == "auto") and (self.P1_dic is None)
-        )
-        if norm is None:
-            pass
-        elif cond_init_P1_dic:
-            P1_dic = {}
-            if cross:
-                P2_dic = {}
-        elif (norm == "auto") and (self.P1_dic is not None):
-            P1_dic = self.P1_dic
-            if cross:
-                P2_dic = self.P2_dic
-
-        if return_data:
-            s0 = I1
-            if out_nside is not None:
-                s0 = self.backend.bk_reduce_mean(
-                    self.backend.bk_reshape(
-                        s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
-                    ),
-                    2,
-                )
-        else:
-            if not cross:
-                s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
-            else:
-                s0, l_vs0 = self.masked_mean(
-                    self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
-                )
-            #vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
-            #s0 = self.backend.bk_concat([s0, l_vs0], 1)
-            result[:,:,0]=s0
-            result[:,:,1]=l_vs0
-            vresult[:,:,0]=l_vs0
-            vresult[:,:,1]=l_vs0
-        #### COMPUTE S1, S2, S3 and S4
-        nside_j3 = nside  # NSIDE start (nside_j3 = nside / 2^j3)
-
-        # a remettre comme avant
-        M1_dic={}
-        M2_dic={}
-
-        for j3 in range(Jmax):
-
-            if edge:
-                if self.mask_mask is None:
-                    self.mask_mask={}
-                if self.use_2D:
-                    if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
-                        mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
-                        mask_mask[0,
-                                  self.KERNELSZ//2:-self.KERNELSZ//2+1,
-                                  self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
-                        self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
-                    vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
-                    #print(self.KERNELSZ//2,vmask,mask_mask)
-
-                if self.use_1D:
-                    if (vmask.shape[1]) not in self.mask_mask:
-                        mask_mask=np.zeros([1,vmask.shape[1]])
-                        mask_mask[0,
-                                  self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
-                        self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
-                    vmask=vmask*self.mask_mask[(vmask.shape[1])]
-
-            if return_data:
-                S3[j3] = None
-                S3P[j3] = None
-
-                if S4 is None:
-                    S4 = {}
-                S4[j3] = None
-
-            ####### S1 and S2
-            ### Make the convolution I1 * Psi_j3
-            conv1 = self.convol(I1, axis=1)  # [Nbatch, Npix_j3, Norient3]
-            if cmat is not None:
-                tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
-                conv1 = self.backend.bk_reduce_sum(
-                    self.backend.bk_reshape(
-                        cmat[j3] * tmp2,
-                        [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
-                    ),
-                    2,
-                )
-
-            ### Take the module M1 = |I1 * Psi_j3|
-            M1_square = conv1 * self.backend.bk_conjugate(
-                conv1
-            )  # [Nbatch, Npix_j3, Norient3]
-
-            M1 = self.backend.bk_L1(M1_square)  # [Nbatch, Npix_j3, Norient3]
-
-            # Store M1_j3 in a dictionary
-            M1_dic[j3] = M1
-
-            if not cross:  # Auto
-                M1_square = self.backend.bk_real(M1_square)
-
-                ### S2_auto = < M1^2 >_pix
-                # Apply the mask [Nmask, Npix_j3] and average over pixels
-                if return_data:
-                    s2 = M1_square
-                else:
-                    if calc_var:
-                        s2, vs2 = self.masked_mean(
-                            M1_square, vmask, axis=1, rank=j3, calc_var=True
-                        )
-                        #s2=self.backend.bk_flatten(self.backend.bk_real(s2))
-                        #vs2=self.backend.bk_flatten(vs2)
-                    else:
-                        s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
-
-                if cond_init_P1_dic:
-                    # We fill P1_dic with S2 for normalisation of S3 and S4
-                    P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2))  # [Nbatch, Nmask, Norient3]
-
-                # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
-                if return_data:
-                    if S2 is None:
-                        S2 = {}
-                    if out_nside is not None and out_nside < nside_j3:
-                        s2 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(
-                                s2,
-                                [
-                                    s2.shape[0],
-                                    12 * out_nside**2,
-                                    (nside_j3 // out_nside) ** 2,
-                                    s2.shape[2],
-                                ],
-                            ),
-                            2,
-                        )
-                    S2[j3] = s2
-                else:
-                    if norm == "auto":  # Normalize S2
-                        s2 /= P1_dic[j3]
-                    """
-                    S2.append(
-                        self.backend.bk_expand_dims(s2, off_S2)
-                    )  # Add a dimension for NS2
-                    if calc_var:
-                        VS2.append(
-                            self.backend.bk_expand_dims(vs2, off_S2)
-                        )  # Add a dimension for NS2
-                    """
-                    #print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
-                    result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
-                    if calc_var:
-                        vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
-                #### S1_auto computation
-                ### Image 1 : S1 = < M1 >_pix
-                # Apply the mask [Nmask, Npix_j3] and average over pixels
-                if return_data:
-                    s1 = M1
-                else:
-                    if calc_var:
-                        s1, vs1 = self.masked_mean(
-                            M1, vmask, axis=1, rank=j3, calc_var=True
-                        )  # [Nbatch, Nmask, Norient3]
-                        #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
-                        #vs1=self.backend.bk_flatten(vs1)
-                    else:
-                        s1 = self.masked_mean(
-                            M1, vmask, axis=1, rank=j3
-                        )  # [Nbatch, Nmask, Norient3]
-                        #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
-
-                if return_data:
-                    if out_nside is not None and out_nside < nside_j3:
-                        s1 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(
-                                s1,
-                                [
-                                    s1.shape[0],
-                                    12 * out_nside**2,
-                                    (nside_j3 // out_nside) ** 2,
-                                    s1.shape[2],
-                                ],
-                            ),
-                            2,
-                        )
-                    S1[j3] = s1
-                else:
-                    ### Normalize S1
-                    if norm is not None:
-                        self.div_norm(s1, (P1_dic[j3]) ** 0.5)
-                    result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
-                    if calc_var:
-                        vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
-                    """
-                    ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
-                    S1.append(
-                        self.backend.bk_expand_dims(s1, off_S2)
-                    )  # Add a dimension for NS1
-                    if calc_var:
-                        VS1.append(
-                            self.backend.bk_expand_dims(vs1, off_S2)
-                        )  # Add a dimension for NS1
-                    """
-
-            else:  # Cross
-                ### Make the convolution I2 * Psi_j3
-                conv2 = self.convol(I2, axis=1)  # [Nbatch, Npix_j3, Norient3]
-                if cmat is not None:
-                    tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
-                    conv2 = self.backend.bk_reduce_sum(
-                        self.backend.bk_reshape(
-                            cmat[j3] * tmp2,
-                            [
-                                tmp2.shape[0],
-                                cmat[j3].shape[0],
-                                self.NORIENT,
-                                self.NORIENT,
-                            ],
-                        ),
-                        2,
-                    )
-                ### Take the module M2 = |I2 * Psi_j3|
-                M2_square = conv2 * self.backend.bk_conjugate(
-                    conv2
-                )  # [Nbatch, Npix_j3, Norient3]
-                M2 = self.backend.bk_L1(M2_square)  # [Nbatch, Npix_j3, Norient3]
-                # Store M2_j3 in a dictionary
-                M2_dic[j3] = M2
-
-                ### S2_auto = < M2^2 >_pix
-                # Not returned, only for normalization
-                if cond_init_P1_dic:
-                    # Apply the mask [Nmask, Npix_j3] and average over pixels
-                    if return_data:
-                        p1 = M1_square
-                        p2 = M2_square
-                    else:
-                        if calc_var:
-                            p1, vp1 = self.masked_mean(
-                                M1_square, vmask, axis=1, rank=j3, calc_var=True
-                            )  # [Nbatch, Nmask, Norient3]
-                            p2, vp2 = self.masked_mean(
-                                M2_square, vmask, axis=1, rank=j3, calc_var=True
-                            )  # [Nbatch, Nmask, Norient3]
-                        else:
-                            p1 = self.masked_mean(
-                                M1_square, vmask, axis=1, rank=j3
-                            )  # [Nbatch, Nmask, Norient3]
-                            p2 = self.masked_mean(
-                                M2_square, vmask, axis=1, rank=j3
-                            )  # [Nbatch, Nmask, Norient3]
-                    # We fill P1_dic with S2 for normalisation of S3 and S4
-                    P1_dic[j3] = self.backend.bk_real(p1)  # [Nbatch, Nmask, Norient3]
-                    P2_dic[j3] = self.backend.bk_real(p2)  # [Nbatch, Nmask, Norient3]
-
-                ### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
-                # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
-                s2 = conv1 * self.backend.bk_conjugate(conv2)
-                MX = self.backend.bk_L1(s2)
-                # Apply the mask [Nmask, Npix_j3] and average over pixels
-                if return_data:
-                    s2 = s2
-                else:
-                    if calc_var:
-                        s2, vs2 = self.masked_mean(
-                            s2, vmask, axis=1, rank=j3, calc_var=True
-                        )
-                    else:
-                        s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
-
-                if return_data:
-                    if out_nside is not None and out_nside < nside_j3:
-                        s2 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(
-                                s2,
-                                [
-                                    s2.shape[0],
-                                    12 * out_nside**2,
-                                    (nside_j3 // out_nside) ** 2,
-                                    s2.shape[2],
-                                ],
-                            ),
-                            2,
-                        )
-                    S2[j3] = s2
-                else:
-                    ### Normalize S2_cross
-                    if norm == "auto":
-                        s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
-
-                    ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
-                    s2 = self.backend.bk_real(s2)
-
-                    S2.append(
-                        self.backend.bk_expand_dims(s2, off_S2)
-                    )  # Add a dimension for NS2
-                    if calc_var:
-                        VS2.append(
-                            self.backend.bk_expand_dims(vs2, off_S2)
-                        )  # Add a dimension for NS2
-
-                #### S1_auto computation
-                ### Image 1 : S1 = < M1 >_pix
-                # Apply the mask [Nmask, Npix_j3] and average over pixels
-                if return_data:
-                    s1 = MX
-                else:
-                    if calc_var:
-                        s1, vs1 = self.masked_mean(
-                            MX, vmask, axis=1, rank=j3, calc_var=True
-                        )  # [Nbatch, Nmask, Norient3]
-                    else:
-                        s1 = self.masked_mean(
-                            MX, vmask, axis=1, rank=j3
-                        )  # [Nbatch, Nmask, Norient3]
-                if return_data:
-                    if out_nside is not None and out_nside < nside_j3:
-                        s1 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(
-                                s1,
-                                [
-                                    s1.shape[0],
-                                    12 * out_nside**2,
-                                    (nside_j3 // out_nside) ** 2,
-                                    s1.shape[2],
-                                ],
-                            ),
-                            2,
-                        )
-                    S1[j3] = s1
-                else:
-                    ### Normalize S1
-                    if norm is not None:
-                        self.div_norm(s1, (P1_dic[j3]) ** 0.5)
-                    ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
-                    S1.append(
-                        self.backend.bk_expand_dims(s1, off_S2)
-                    )  # Add a dimension for NS1
-                    if calc_var:
-                        VS1.append(
-                            self.backend.bk_expand_dims(vs1, off_S2)
-                        )  # Add a dimension for NS1
-
-            # Initialize dictionaries for |I1*Psi_j| * Psi_j3
-            M1convPsi_dic = {}
-            if cross:
-                # Initialize dictionaries for |I2*Psi_j| * Psi_j3
-                M2convPsi_dic = {}
-
-            ###### S3
-            nside_j2 = nside_j3
-            for j2 in range(0,-1): # j3 + 1):  # j2 <= j3
-                if return_data:
-                    if S4[j3] is None:
-                        S4[j3] = {}
-                    S4[j3][j2] = None
-
-                ### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
-                if not cross:
-                    if calc_var:
-                        s3, vs3 = self._compute_S3(
-                            j2,
-                            j3,
-                            conv1,
-                            vmask,
-                            M1_dic,
-                            M1convPsi_dic,
-                            calc_var=True,
-                            cmat2=cmat2,
-                        )  # [Nbatch, Nmask, Norient3, Norient2]
-                    else:
-                        s3 = self._compute_S3(
-                            j2,
-                            j3,
-                            conv1,
-                            vmask,
-                            M1_dic,
-                            M1convPsi_dic,
-                            return_data=return_data,
-                            cmat2=cmat2,
-                        )  # [Nbatch, Nmask, Norient3, Norient2]
-
-                    if return_data:
-                        if S3[j3] is None:
-                            S3[j3] = {}
-                        if out_nside is not None and out_nside < nside_j2:
-                            s3 = self.backend.bk_reduce_mean(
-                                self.backend.bk_reshape(
-                                    s3,
-                                    [
-                                        s3.shape[0],
-                                        12 * out_nside**2,
-                                        (nside_j2 // out_nside) ** 2,
-                                        s3.shape[2],
-                                        s3.shape[3],
-                                    ],
-                                ),
-                                2,
-                            )
-                        S3[j3][j2] = s3
-                    else:
-                        ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
-                        if norm is not None:
-                            self.div_norm(
-                                s3,
-                                (
-                                    self.backend.bk_expand_dims(P1_dic[j2], off_S2)
-                                    * self.backend.bk_expand_dims(P1_dic[j3], -1)
-                                )
-                                ** 0.5,
-                            )  # [Nbatch, Nmask, Norient3, Norient2]
-
-                        ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
-
-                        # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
-                        #                                      s3.shape[2]*s3.shape[3]]))
-                        S3.append(
-                            self.backend.bk_expand_dims(s3, off_S3)
-                        )  # Add a dimension for NS3
-                        if calc_var:
-                            VS3.append(
-                                self.backend.bk_expand_dims(vs3, off_S3)
-                            )  # Add a dimension for NS3
-                            # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
-                            #                                      s3.shape[2]*s3.shape[3]]))
-
-                ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
-                ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
-                else:
-                    if calc_var:
-                        s3, vs3 = self._compute_S3(
-                            j2,
-                            j3,
-                            conv1,
-                            vmask,
-                            M2_dic,
-                            M2convPsi_dic,
-                            calc_var=True,
-                            cmat2=cmat2,
-                        )
-                        s3p, vs3p = self._compute_S3(
-                            j2,
-                            j3,
-                            conv2,
-                            vmask,
-                            M1_dic,
-                            M1convPsi_dic,
-                            calc_var=True,
-                            cmat2=cmat2,
-                        )
-                    else:
-                        s3 = self._compute_S3(
-                            j2,
-                            j3,
-                            conv1,
-                            vmask,
-                            M2_dic,
-                            M2convPsi_dic,
-                            return_data=return_data,
-                            cmat2=cmat2,
-                        )
-                        s3p = self._compute_S3(
-                            j2,
-                            j3,
-                            conv2,
-                            vmask,
-                            M1_dic,
-                            M1convPsi_dic,
-                            return_data=return_data,
-                            cmat2=cmat2,
-                        )
-
-                    if return_data:
-                        if S3[j3] is None:
-                            S3[j3] = {}
-                            S3P[j3] = {}
-                        if out_nside is not None and out_nside < nside_j2:
-                            s3 = self.backend.bk_reduce_mean(
-                                self.backend.bk_reshape(
-                                    s3,
-                                    [
-                                        s3.shape[0],
-                                        12 * out_nside**2,
-                                        (nside_j2 // out_nside) ** 2,
-                                        s3.shape[2],
-                                        s3.shape[3],
-                                    ],
-                                ),
-                                2,
-                            )
-                            s3p = self.backend.bk_reduce_mean(
-                                self.backend.bk_reshape(
-                                    s3p,
-                                    [
-                                        s3.shape[0],
-                                        12 * out_nside**2,
-                                        (nside_j2 // out_nside) ** 2,
-                                        s3.shape[2],
-                                        s3.shape[3],
-                                    ],
-                                ),
-                                2,
-                            )
-                        S3[j3][j2] = s3
-                        S3P[j3][j2] = s3p
-                    else:
-                        ### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
-                        if norm is not None:
-                            self.div_norm(
-                                s3,
-                                (
-                                    self.backend.bk_expand_dims(P2_dic[j2], off_S2)
-                                    * self.backend.bk_expand_dims(P1_dic[j3], -1)
-                                )
-                                ** 0.5,
-                            )  # [Nbatch, Nmask, Norient3, Norient2]
-                            self.div_norm(
-                                s3p,
-                                (
-                                    self.backend.bk_expand_dims(P1_dic[j2], off_S2)
-                                    * self.backend.bk_expand_dims(P2_dic[j3], -1)
-                                )
-                                ** 0.5,
-                            )  # [Nbatch, Nmask, Norient3, Norient2]
-
-                        ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
-
-                        # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
-                        #                                      s3.shape[2]*s3.shape[3]]))
-                        S3.append(
-                            self.backend.bk_expand_dims(s3, off_S3)
-                        )  # Add a dimension for NS3
-                        if calc_var:
-                            VS3.append(
-                                self.backend.bk_expand_dims(vs3, off_S3)
-                            )  # Add a dimension for NS3
-
-                            # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
-                            #                                      s3.shape[2]*s3.shape[3]]))
-
-                        # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
-                        #                                      s3.shape[2]*s3.shape[3]]))
-                        S3P.append(
-                            self.backend.bk_expand_dims(s3p, off_S3)
-                        )  # Add a dimension for NS3
-                        if calc_var:
-                            VS3P.append(
-                                self.backend.bk_expand_dims(vs3p, off_S3)
-                            )  # Add a dimension for NS3
-                            # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
-                            #                                      s3.shape[2]*s3.shape[3]]))
-
-                ##### S4
-                nside_j1 = nside_j2
-                for j1 in range(0, j2 + 1):  # j1 <= j2
-                    ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
-                    if not cross:
-                        if calc_var:
-                            s4, vs4 = self._compute_S4(
-                                j1,
-                                j2,
-                                vmask,
-                                M1convPsi_dic,
-                                M2convPsi_dic=None,
-                                calc_var=True,
-                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-                        else:
-                            s4 = self._compute_S4(
-                                j1,
-                                j2,
-                                vmask,
-                                M1convPsi_dic,
-                                M2convPsi_dic=None,
-                                return_data=return_data,
-                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-
-                        if return_data:
-                            if S4[j3][j2] is None:
-                                S4[j3][j2] = {}
-                            if out_nside is not None and out_nside < nside_j1:
-                                s4 = self.backend.bk_reduce_mean(
-                                    self.backend.bk_reshape(
-                                        s4,
-                                        [
-                                            s4.shape[0],
-                                            12 * out_nside**2,
-                                            (nside_j1 // out_nside) ** 2,
-                                            s4.shape[2],
-                                            s4.shape[3],
-                                            s4.shape[4],
-                                        ],
-                                    ),
-                                    2,
-                                )
-                            S4[j3][j2][j1] = s4
-                        else:
-                            ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
-                            if norm is not None:
-                                self.div_norm(
-                                    s4,
-                                    (
-                                        self.backend.bk_expand_dims(
-                                            self.backend.bk_expand_dims(
-                                                P1_dic[j1], off_S2
-                                            ),
-                                            off_S2,
-                                        )
-                                        * self.backend.bk_expand_dims(
-                                            self.backend.bk_expand_dims(
-                                                P1_dic[j2], off_S2
-                                            ),
-                                            -1,
-                                        )
-                                    )
-                                    ** 0.5,
-                                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-                            ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
-
-                            # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
-                            #                                      s4.shape[2]*s4.shape[3]*s4.shape[4]]))
-                            S4.append(
-                                self.backend.bk_expand_dims(s4, off_S4)
-                            )  # Add a dimension for NS4
-                            if calc_var:
-                                # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
-                                #                                      s4.shape[2]*s4.shape[3]*s4.shape[4]]))
-                                VS4.append(
-                                    self.backend.bk_expand_dims(vs4, off_S4)
-                                )  # Add a dimension for NS4
-
-                    ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
-                    else:
-                        if calc_var:
-                            s4, vs4 = self._compute_S4(
-                                j1,
-                                j2,
-                                vmask,
-                                M1convPsi_dic,
-                                M2convPsi_dic=M2convPsi_dic,
-                                calc_var=True,
-                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-                        else:
-                            s4 = self._compute_S4(
-                                j1,
-                                j2,
-                                vmask,
-                                M1convPsi_dic,
-                                M2convPsi_dic=M2convPsi_dic,
-                                return_data=return_data,
-                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-
-                        if return_data:
-                            if S4[j3][j2] is None:
-                                S4[j3][j2] = {}
-                            if out_nside is not None and out_nside < nside_j1:
-                                s4 = self.backend.bk_reduce_mean(
-                                    self.backend.bk_reshape(
-                                        s4,
-                                        [
-                                            s4.shape[0],
-                                            12 * out_nside**2,
-                                            (nside_j1 // out_nside) ** 2,
-                                            s4.shape[2],
-                                            s4.shape[3],
-                                            s4.shape[4],
-                                        ],
-                                    ),
-                                    2,
-                                )
-                            S4[j3][j2][j1] = s4
-                        else:
-                            ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
-                            if norm is not None:
-                                self.div_norm(
-                                    s4,
-                                    (
-                                        self.backend.bk_expand_dims(
-                                            self.backend.bk_expand_dims(
-                                                P1_dic[j1], off_S2
-                                            ),
-                                            off_S2,
-                                        )
-                                        * self.backend.bk_expand_dims(
-                                            self.backend.bk_expand_dims(
-                                                P2_dic[j2], off_S2
-                                            ),
-                                            -1,
-                                        )
-                                    )
-                                    ** 0.5,
-                                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-                            ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
-                            # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
-                            #                                      s4.shape[2]*s4.shape[3]*s4.shape[4]]))
-                            S4.append(
-                                self.backend.bk_expand_dims(s4, off_S4)
-                            )  # Add a dimension for NS4
-                            if calc_var:
-
-                                # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
-                                #                                      s4.shape[2]*s4.shape[3]*s4.shape[4]]))
-                                VS4.append(
-                                    self.backend.bk_expand_dims(vs4, off_S4)
-                                )  # Add a dimension for NS4
-
-                    nside_j1 = nside_j1 // 2
-                nside_j2 = nside_j2 // 2
-
-            ###### Reshape for next iteration on j3
-            ### Image I1,
-            # downscale the I1 [Nbatch, Npix_j3]
-            if j3 != Jmax - 1:
-                I1 = self.smooth(I1, axis=1)
-                I1 = self.ud_grade_2(I1, axis=1)
-
-                ### Image I2
-                if cross:
-                    I2 = self.smooth(I2, axis=1)
-                    I2 = self.ud_grade_2(I2, axis=1)
-
-                ### Modules
-                for j2 in range(0, j3 + 1):  # j2 =< j3
-                    ### Dictionary M1_dic[j2]
-                    M1_smooth = self.smooth(
-                        M1_dic[j2], axis=1
-                    )  # [Nbatch, Npix_j3, Norient3]
-                    M1_dic[j2] = self.ud_grade_2(
-                        M1_smooth, axis=1
-                    )  # [Nbatch, Npix_j3, Norient3]
-
-                    ### Dictionary M2_dic[j2]
-                    if cross:
-                        M2_smooth = self.smooth(
-                            M2_dic[j2], axis=1
-                        )  # [Nbatch, Npix_j3, Norient3]
-                        M2_dic[j2] = self.ud_grade_2(
-                            M2, axis=1
-                        )  # [Nbatch, Npix_j3, Norient3]
-
-                ### Mask
-                vmask = self.ud_grade_2(vmask, axis=1)
-
-                if self.mask_thres is not None:
-                    vmask = self.backend.bk_threshold(vmask, self.mask_thres)
-
-                ### NSIDE_j3
-                nside_j3 = nside_j3 // 2
-
-        ### Store P1_dic and P2_dic in self
-        if (norm == "auto") and (self.P1_dic is None):
-            self.P1_dic = P1_dic
-            if cross:
-                self.P2_dic = P2_dic
-        """
-        Sout=[s0]+S1+S2+S3+S4
-
-        if cross:
-            Sout=Sout+S3P
-        if calc_var:
-            SVout=[vs0]+VS1+VS2+VS3+VS4
-            if cross:
-                VSout=VSout+VS3P
-            return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
-
-        return self.backend.bk_concat(Sout, 2)
-        """
-        if calc_var:
-            return result,vresult
-        else:
-            return result
-        if calc_var:
-            for k in S1:
-                print(k.shape,k.dtype)
-            for k in S2:
-                print(k.shape,k.dtype)
-            print(s0.shape,s0.dtype)
-            return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
-        else:
-            return self.backend.bk_concat([s0]+S1+S2,axis=1)
-
-        if not return_data:
-            S1 = self.backend.bk_concat(S1, 2)
-            S2 = self.backend.bk_concat(S2, 2)
-            S3 = self.backend.bk_concat(S3, 2)
-            S4 = self.backend.bk_concat(S4, 2)
-            if cross:
-                S3P = self.backend.bk_concat(S3P, 2)
-            if calc_var:
-                VS1 = self.backend.bk_concat(VS1, 2)
-                VS2 = self.backend.bk_concat(VS2, 2)
-                VS3 = self.backend.bk_concat(VS3, 2)
-                VS4 = self.backend.bk_concat(VS4, 2)
-                if cross:
-                    VS3P = self.backend.bk_concat(VS3P, 2)
-        if calc_var:
-            if not cross:
-                return scat_cov(
-                    s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
-                ), scat_cov(
-                    vs0,
-                    VS2,
-                    VS3,
-                    VS4,
-                    s1=VS1,
-                    backend=self.backend,
-                    use_1D=self.use_1D,
-                )
-            else:
-                return scat_cov(
-                    s0,
-                    S2,
-                    S3,
-                    S4,
-                    s1=S1,
-                    s3p=S3P,
-                    backend=self.backend,
-                    use_1D=self.use_1D,
-                ), scat_cov(
-                    vs0,
-                    VS2,
-                    VS3,
-                    VS4,
-                    s1=VS1,
-                    s3p=VS3P,
-                    backend=self.backend,
-                    use_1D=self.use_1D,
-                )
-        else:
-            if not cross:
-                return scat_cov(
-                    s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
-                )
-            else:
-                return scat_cov(
-                    s0,
-                    S2,
-                    S3,
-                    S4,
-                    s1=S1,
-                    s3p=S3P,
-                    backend=self.backend,
-                    use_1D=self.use_1D,
-                )
-    def clean_norm(self):
-        self.P1_dic = None
-        self.P2_dic = None
-        return
-
-    def _compute_S3(
-        self,
-        j2,
-        j3,
-        conv,
-        vmask,
-        M_dic,
-        MconvPsi_dic,
-        calc_var=False,
-        return_data=False,
-        cmat2=None,
-    ):
-        """
-        Compute the S3 coefficients (auto or cross)
-        S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
-        Parameters
-        ----------
-        Returns
-        -------
-        cs3, ss3: real and imag parts of S3 coeff
-        """
-        ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
-        # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
-        MconvPsi = self.convol(
-            M_dic[j2], axis=1
-        )  # [Nbatch, Npix_j3, Norient3, Norient2]
-        if cmat2 is not None:
-            tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
-            MconvPsi = self.backend.bk_reduce_sum(
-                self.backend.bk_reshape(
-                    cmat2[j3][j2] * tmp2,
-                    [
-                        tmp2.shape[0],
-                        cmat2[j3].shape[1],
-                        self.NORIENT,
-                        self.NORIENT,
-                        self.NORIENT,
-                    ],
-                ),
-                3,
-            )
-
-        # Store it so we can use it in S4 computation
-        MconvPsi_dic[j2] = MconvPsi  # [Nbatch, Npix_j3, Norient3, Norient2]
-
-        ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
-        # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
-        # cconv, sconv are [Nbatch, Npix_j3, Norient3]
-        if self.use_1D:
-            s3 = conv * self.backend.bk_conjugate(MconvPsi)
-        else:
-            s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
-                MconvPsi
-            )  # [Nbatch, Npix_j3, Norient3, Norient2]
-
-        ### Apply the mask [Nmask, Npix_j3] and sum over pixels
-        if return_data:
-            return s3
-        else:
-            if calc_var:
-                s3, vs3 = self.masked_mean(
-                    s3, vmask, axis=1, rank=j2, calc_var=True
-                )  # [Nbatch, Nmask, Norient3, Norient2]
-                return s3, vs3
-            else:
-                s3 = self.masked_mean(
-                    s3, vmask, axis=1, rank=j2
-                )  # [Nbatch, Nmask, Norient3, Norient2]
-                return s3
-
-    def _compute_S4(
-        self,
-        j1,
-        j2,
-        vmask,
-        M1convPsi_dic,
-        M2convPsi_dic=None,
-        calc_var=False,
-        return_data=False,
-    ):
-        #### Simplify notations
-        M1 = M1convPsi_dic[j1]  # [Nbatch, Npix_j3, Norient3, Norient1]
-
-        # Auto or Cross coefficients
-        if M2convPsi_dic is None:  # Auto
-            M2 = M1convPsi_dic[j2]  # [Nbatch, Npix_j3, Norient3, Norient2]
-        else:  # Cross
-            M2 = M2convPsi_dic[j2]
-
-        ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
-        # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
-        if self.use_1D:
-            s4 = M1 * self.backend.bk_conjugate(M2)
-        else:
-            s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
-                self.backend.bk_expand_dims(M2, -1)
-            )  # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
-
-        ### Apply the mask and sum over pixels
-        if return_data:
-            return s4
-        else:
-            if calc_var:
-                s4, vs4 = self.masked_mean(
-                    s4, vmask, axis=1, rank=j2, calc_var=True
-                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-                return s4, vs4
-            else:
-                s4 = self.masked_mean(
-                    s4, vmask, axis=1, rank=j2
-                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
-                return s4
-
-    def computer_filter(self,M,N,J,L):
-        '''
-        This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
-        Done by Sihao Cheng and Rudy Morel.
-        '''
-
-        filter = np.zeros([J, L, M, N],dtype='complex64')
-
-        slant=4.0 / L
-
-        for j in range(J):
-
-            for l in range(L):
-
-                theta = (int(L-L/2-1)-l) * np.pi / L
-                sigma = 0.8 * 2**j
-                xi = 3.0 / 4.0 * np.pi /2**j
-
-                R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
-                R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
                D = np.array([[1, 0], [0, slant * slant]])
-                curv = np.matmul(R, np.matmul(D, R_inv)) / ( 2 * sigma * sigma)
-
+                curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
+
                gab = np.zeros((M, N), np.complex128)
-                xx = np.empty((2,2, M, N))
-                yy = np.empty((2,2, M, N))
-
+                xx = np.empty((2, 2, M, N))
+                yy = np.empty((2, 2, M, N))
+
                for ii, ex in enumerate([-1, 0]):
                    for jj, ey in enumerate([-1, 0]):
-                        xx[ii,jj], yy[ii,jj] = np.mgrid[
-                            ex * M : M + ex * M,
-                            ey * N : N + ey * N]
-
-                arg = -(curv[0, 0] * xx * xx + (curv[0, 1] + curv[1, 0]) * xx * yy + curv[1, 1] * yy * yy)
-                argi = arg + 1.j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
-
-                gabi = np.exp(argi).sum((0,1))
-                gab = np.exp(arg).sum((0,1))
-
+                        xx[ii, jj], yy[ii, jj] = np.mgrid[
+                            ex * M : M + ex * M, ey * N : N + ey * N
+                        ]
+
+                arg = -(
+                    curv[0, 0] * xx * xx
+                    + (curv[0, 1] + curv[1, 0]) * xx * yy
+                    + curv[1, 1] * yy * yy
+                )
+                argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
+
+                gabi = np.exp(argi).sum((0, 1))
+                gab = np.exp(arg).sum((0, 1))
+
                norm_factor = 2 * np.pi * sigma * sigma / slant
-
+
                gab = gab / norm_factor
-
+
                gabi = gabi / norm_factor

                K = gabi.sum() / gab.sum()

                # Apply the Gaussian
-                filter[j, l] = np.fft.fft2(gabi-K*gab)
-                filter[j,l,0,0]=0.0
-
+                filter[j, ell] = np.fft.fft2(gabi - K * gab)
+                filter[j, ell, 0, 0] = 0.0
+
        return self.backend.bk_cast(filter)
-
+
    # ------------------------------------------------------------------------------------------
    #
-    # utility functions
+    # utility functions
    #
    # ------------------------------------------------------------------------------------------
-    def cut_high_k_off(self,data_f, dx, dy):
-        '''
+    def cut_high_k_off(self, data_f, dx, dy):
+        """
        This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
        Done by Sihao Cheng and Rudy Morel.
-        '''
-
-        if self.backend.BACKEND=='torch':
-            if_xodd = (data_f.shape[-2]%2==1)
-            if_yodd = (data_f.shape[-1]%2==1)
+        """
+
+        if self.backend.BACKEND == "torch":
+            if_xodd = data_f.shape[-2] % 2 == 1
+            if_yodd = data_f.shape[-1] % 2 == 1
            result = self.backend.backend.cat(
-                (self.backend.backend.cat(
-                    ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
-                    ), -2),
-                 self.backend.backend.cat(
-                    ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
-                    ), -2)
-                ),-1)
+                (
+                    self.backend.backend.cat(
+                        (
+                            data_f[..., : dx + if_xodd, : dy + if_yodd],
+                            data_f[..., -dx:, : dy + if_yodd],
+                        ),
+                        -2,
+                    ),
+                    self.backend.backend.cat(
+                        (data_f[..., : dx + if_xodd, -dy:], data_f[..., -dx:, -dy:]), -2
+                    ),
+                ),
+                -1,
+            )
            return result
        else:
            # Check if the last two dimensions are odd
-            if_xodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-2] % 2 == 1, self.backend.backend.int32)
-            if_yodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-1] % 2 == 1, self.backend.backend.int32)
+            if_xodd = self.backend.backend.cast(
+                self.backend.backend.shape(data_f)[-2] % 2 == 1,
+                self.backend.backend.int32,
+            )
+            if_yodd = self.backend.backend.cast(
+                self.backend.backend.shape(data_f)[-1] % 2 == 1,
+                self.backend.backend.int32,
+            )

            # Extract four regions
-            top_left = data_f[..., :dx+if_xodd, :dy+if_yodd]
-            top_right = data_f[..., -dx:, :dy+if_yodd]
-            bottom_left = data_f[..., :dx+if_xodd, -dy:]
+            top_left = data_f[..., : dx + if_xodd, : dy + if_yodd]
+            top_right = data_f[..., -dx:, : dy + if_yodd]
+            bottom_left = data_f[..., : dx + if_xodd, -dy:]
            bottom_right = data_f[..., -dx:, -dy:]

            # Concatenate along the last two dimensions
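Besides deleting the dead `eval_new` path, this hunk re-indents `computer_filter`, a Gabor filter bank construction credited to Sihao Cheng and Rudy Morel. A minimal single-filter version of the idea, simplified to a centered grid instead of the four-fold periodized one used above, so only the zero-mean correction is demonstrated:

```python
import numpy as np

M = N = 32
j, L = 1, 8
theta = np.pi / L
sigma, xi, slant = 0.8 * 2**j, 0.75 * np.pi / 2**j, 4.0 / L

# Anisotropic Gaussian envelope rotated by theta (R @ D @ R^-1 quadratic form)
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
D = np.array([[1.0, 0.0], [0.0, slant**2]])
curv = R @ D @ R.T / (2 * sigma**2)

x, y = np.mgrid[-M // 2 : M // 2, -N // 2 : N // 2]
arg = -(curv[0, 0] * x * x + 2 * curv[0, 1] * x * y + curv[1, 1] * y * y)
gab = np.exp(arg)                                          # plain Gaussian
gabi = np.exp(arg + 1j * xi * (x * np.cos(theta) + y * np.sin(theta)))

K = gabi.sum() / gab.sum()   # subtracting K * gab removes the DC component
psi = gabi - K * gab
print(abs(psi.sum()) < 1e-10)  # True: the wavelet has zero mean
```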
@@ -4829,70 +3812,74 @@ class funct(FOC.FoCUS):
4829
3812
  # utility functions for computing scattering coef and covariance
4830
3813
  #
4831
3814
  # ---------------------------------------------------------------------------
4832
-
4833
- def get_dxdy(self, j,M,N):
4834
- '''
3815
+
3816
+ def get_dxdy(self, j, M, N):
3817
+ """
4835
3818
  This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4836
3819
  Done by Sihao Cheng and Rudy Morel.
4837
- '''
4838
- dx = int(max( 8, min( np.ceil(M/2**j), M//2 ) ))
4839
- dy = int(max( 8, min( np.ceil(N/2**j), N//2 ) ))
3820
+ """
3821
+ dx = int(max(8, min(np.ceil(M / 2**j), M // 2)))
3822
+ dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
4840
3823
  return dx, dy
4841
-
4842
-
4843
3824
 
4844
- def get_edge_masks(self,M, N, J, d0=1):
4845
- '''
3825
+ def get_edge_masks(self, M, N, J, d0=1):
3826
+ """
4846
3827
  This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4847
3828
  Done by Sihao Cheng and Rudy Morel.
4848
- '''
3829
+ """
4849
3830
  edge_masks = np.empty((J, M, N))
4850
- X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing='ij')
3831
+ X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
4851
3832
  for j in range(J):
4852
- edge_dx = min(M//4, 2**j*d0)
4853
- edge_dy = min(N//4, 2**j*d0)
4854
- edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
4855
- edge_masks = edge_masks[:,None,:,:]
4856
- edge_masks = edge_masks / edge_masks.mean((-2,-1))[:,:,None,None]
3833
+ edge_dx = min(M // 4, 2**j * d0)
3834
+ edge_dy = min(N // 4, 2**j * d0)
3835
+ edge_masks[j] = (
3836
+ (X >= edge_dx)
3837
+ * (X <= M - edge_dx)
3838
+ * (Y >= edge_dy)
3839
+ * (Y <= N - edge_dy)
3840
+ )
3841
+ edge_masks = edge_masks[:, None, :, :]
3842
+ edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
4857
3843
  return self.backend.bk_cast(edge_masks)
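The masks built here exclude a border whose width grows as 2**j * d0 (capped at a quarter of the image) and are rescaled so each mask averages to one, which keeps the masked global mean unbiased. A standalone NumPy sketch of the same construction:

    import numpy as np

    def edge_masks_sketch(M, N, J, d0=1):
        X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
        masks = np.empty((J, M, N))
        for j in range(J):
            ex = min(M // 4, 2**j * d0)
            ey = min(N // 4, 2**j * d0)
            masks[j] = (X >= ex) * (X <= M - ex) * (Y >= ey) * (Y <= N - ey)
        masks = masks[:, None, :, :]
        # renormalize so every mask has unit mean over (x, y)
        return masks / masks.mean((-2, -1))[:, :, None, None]

    m = edge_masks_sketch(64, 64, 4)
    print(m.mean((-2, -1)).ravel())  # all ones by construction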
4858
-
3844
+
4859
3845
  # ---------------------------------------------------------------------------
4860
3846
  #
4861
3847
  # scattering cov
4862
3848
  #
4863
3849
  # ---------------------------------------------------------------------------
4864
3850
  def scattering_cov(
4865
- self, data,
3851
+ self,
3852
+ data,
4866
3853
  data2=None,
4867
3854
  Jmax=None,
4868
- if_large_batch=False,
4869
- S4_criteria=None,
4870
- use_ref=False,
4871
- normalization='S2',
3855
+ if_large_batch=False,
3856
+ S4_criteria=None,
3857
+ use_ref=False,
3858
+ normalization="S2",
4872
3859
  edge=False,
4873
- pseudo_coef=1,
4874
- get_variance=False,
3860
+ pseudo_coef=1,
3861
+ get_variance=False,
4875
3862
  ref_sigma=None,
4876
- iso_ang=False
3863
+ iso_ang=False,
4877
3864
  ):
4878
- '''
3865
+ """
4879
3866
  Calculates the scattering correlations for a batch of images, including:
4880
-
3867
+
4881
3868
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4882
3869
  Done by Sihao Cheng and Rudy Morel.
4883
-
4884
- orig. x orig.:
3870
+
3871
+ orig. x orig.:
4885
3872
  P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
4886
- orig. x modulus:
3873
+ orig. x modulus:
4887
3874
  C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
4888
3875
  when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
4889
3876
  when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
4890
- modulus x modulus:
3877
+ modulus x modulus:
4891
3878
  C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
4892
3879
  C11 = C11_pre_norm / factor
4893
3880
  when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
4894
3881
  when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
4895
- modulus x modulus (auto):
3882
+ modulus x modulus (auto):
4896
3883
  P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
4897
3884
  Parameters
4898
3885
  ----------
@@ -4902,7 +3889,7 @@ class funct(FOC.FoCUS):
4902
3889
  It is recommended to use "False" unless one runs into memory issues.
4903
3890
  S4_criteria : str or None (=None)
4904
3891
  Only S4 (C11) coefficients that satisfy this criterion will be computed.
4905
- Any expressions of j1, j2, and j3 that can be evaluated as a Bool
3892
+ Any expressions of j1, j2, and j3 that can be evaluated as a Bool
4906
3893
  are accepted. The default "None" corresponds to "j1 <= j2 <= j3".
4907
3894
  use_ref : Bool (=False)
4908
3895
  When normalizing, whether or not to use the normalization factor
@@ -4916,7 +3903,7 @@ class funct(FOC.FoCUS):
4916
3903
  If true, the edge region with a width of roughly the size of the largest
4917
3904
  wavelet involved is excluded when taking the global average to obtain
4918
3905
  the scattering coefficients.
4919
-
3906
+
4920
3907
  Returns
4921
3908
  -------
4922
3909
  'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
@@ -4934,34 +3921,34 @@ class funct(FOC.FoCUS):
4934
3921
  j1 <= j3 are set to np.nan and not computed.
4935
3922
  'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
4936
3923
  'P11' averaged over l1 while keeping l2-l1 constant.
4937
- '''
3924
+ """
4938
3925
  if S4_criteria is None:
4939
- S4_criteria = 'j2>=j1'
4940
-
3926
+ S4_criteria = "j2>=j1"
3927
+
4941
3928
  if self.all_bk_type == "float32":
4942
- C_ONE=np.complex64(1.0)
3929
+ C_ONE = np.complex64(1.0)
4943
3930
  else:
4944
- C_ONE=np.complex128(1.0)
4945
-
3931
+ C_ONE = np.complex128(1.0)
3932
+
4946
3933
  # determine jmax and nside corresponding to the input map
4947
3934
  im_shape = data.shape
4948
3935
  if self.use_2D:
4949
3936
  if len(data.shape) == 2:
4950
3937
  nside = np.min([im_shape[0], im_shape[1]])
4951
- M,N = im_shape[0],im_shape[1]
4952
- N_image = 1
3938
+ M, N = im_shape[0], im_shape[1]
3939
+ N_image = 1
4953
3940
  N_image2 = 1
4954
3941
  else:
4955
3942
  nside = np.min([im_shape[1], im_shape[2]])
4956
- M,N = im_shape[1],im_shape[2]
3943
+ M, N = im_shape[1], im_shape[2]
4957
3944
  N_image = data.shape[0]
4958
3945
  if data2 is not None:
4959
3946
  N_image2 = data2.shape[0]
4960
- J = int(np.log(nside) / np.log(2))-1 # Number of j scales
3947
+ J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4961
3948
  elif self.use_1D:
4962
3949
  if len(data.shape) == 2:
4963
3950
  npix = int(im_shape[1]) # Number of pixels
4964
- N_image = 1
3951
+ N_image = 1
4965
3952
  N_image2 = 1
4966
3953
  else:
4967
3954
  npix = int(im_shape[0]) # Number of pixels
@@ -4971,12 +3958,12 @@ class funct(FOC.FoCUS):
4971
3958
 
4972
3959
  nside = int(npix)
4973
3960
 
4974
- J = int(np.log(nside) / np.log(2))-1 # Number of j scales
3961
+ J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4975
3962
  else:
4976
3963
  if len(data.shape) == 2:
4977
3964
  npix = int(im_shape[1]) # Number of pixels
4978
- N_image = 1
4979
- N_image2 = 1
3965
+ N_image = 1
3966
+ N_image2 = 1
4980
3967
  else:
4981
3968
  npix = int(im_shape[0]) # Number of pixels
4982
3969
  N_image = data.shape[0]
@@ -4986,984 +3973,1715 @@ class funct(FOC.FoCUS):
4986
3973
  nside = int(np.sqrt(npix // 12))
4987
3974
 
4988
3975
  J = int(np.log(nside) / np.log(2)) # Number of j scales
4989
-
3976
+
4990
3977
  if Jmax is None:
4991
3978
  Jmax = J # Number of steps for the loop on scales
4992
- if Jmax>J:
4993
- print('==========\n\n')
4994
- print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
4995
- print('\n\n==========')
4996
-
4997
- L=self.NORIENT
4998
- norm_factor_S3=1.0
4999
-
5000
- if self.backend.BACKEND=='torch':
5001
- if (M,N,J,L) not in self.filters_set:
5002
- self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
5003
-
5004
- filters_set = self.filters_set[(M,N,J,L)]
5005
-
5006
- #weight = self.weight
3979
+ if Jmax > J:
3980
+ print("==========\n\n")
3981
+ print(
3982
+ "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
3983
+ )
3984
+ print("\n\n==========")
3985
+
3986
+ L = self.NORIENT
3987
+ norm_factor_S3 = 1.0
3988
+
3989
+ if self.backend.BACKEND == "torch":
3990
+ if (M, N, J, L) not in self.filters_set:
3991
+ self.filters_set[(M, N, J, L)] = self.computer_filter(
3992
+ M, N, J, L
3993
+ ) # self.computer_filter(M,N,J,L)
3994
+
3995
+ filters_set = self.filters_set[(M, N, J, L)]
3996
+
3997
+ # weight = self.weight
5007
3998
  if use_ref:
5008
- if normalization=='S2':
3999
+ if normalization == "S2":
5009
4000
  ref_S2 = self.ref_scattering_cov_S2
5010
- else:
5011
- ref_P11 = self.ref_scattering_cov['P11']
4001
+ else:
4002
+ ref_P11 = self.ref_scattering_cov["P11"]
5012
4003
 
5013
4004
  # convert numpy array input into self.backend.bk_ tensors
5014
4005
  data = self.backend.bk_cast(data)
5015
- data_f = self.backend.bk_fftn(data, dim=(-2,-1))
4006
+ data_f = self.backend.bk_fftn(data, dim=(-2, -1))
5016
4007
  if data2 is not None:
5017
4008
  data2 = self.backend.bk_cast(data2)
5018
- data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
5019
-
4009
+ data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
4010
+
5020
4011
  # initialize tensors for scattering coefficients
5021
- S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5022
- S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5023
-
5024
- Ndata_S3 = J*(J+1)//2
5025
- Ndata_S4 = J*(J+1)*(J+2)//6
5026
- J_S4={}
5027
-
5028
- S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4012
+ S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4013
+ S1 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4014
+
4015
+ Ndata_S3 = J * (J + 1) // 2
4016
+ Ndata_S4 = J * (J + 1) * (J + 2) // 6
4017
+ J_S4 = {}
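These buffer sizes are just counts of ordered scale pairs and triples: S3 holds one slot per pair j2 <= j3 and S4 one per triple j1 <= j2 <= j3, matching the loops further down. A two-line check:

    J = 8
    n_pairs = sum(1 for j3 in range(J) for j2 in range(j3 + 1))
    n_triples = sum(1 for j3 in range(J) for j2 in range(j3 + 1) for j1 in range(j2 + 1))
    assert n_pairs == J * (J + 1) // 2 and n_triples == J * (J + 1) * (J + 2) // 6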
4018
+
4019
+ S3 = self.backend.bk_zeros((N_image, Ndata_S3, L, L), dtype=data_f.dtype)
5029
4020
  if data2 is not None:
5030
- S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5031
- S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5032
- S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5033
-
4021
+ S3p = self.backend.bk_zeros(
4022
+ (N_image, Ndata_S3, L, L), dtype=data_f.dtype
4023
+ )
4024
+ S4_pre_norm = self.backend.bk_zeros(
4025
+ (N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
4026
+ )
4027
+ S4 = self.backend.bk_zeros((N_image, Ndata_S4, L, L, L), dtype=data_f.dtype)
4028
+
5034
4029
  # variance
5035
4030
  if get_variance:
5036
- S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5037
- S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5038
- S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4031
+ S2_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4032
+ S1_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4033
+ S3_sigma = self.backend.bk_zeros(
4034
+ (N_image, Ndata_S3, L, L), dtype=data_f.dtype
4035
+ )
5039
4036
  if data2 is not None:
5040
- S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5041
- S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5042
-
4037
+ S3p_sigma = self.backend.bk_zeros(
4038
+ (N_image, Ndata_S3, L, L), dtype=data_f.dtype
4039
+ )
4040
+ S4_sigma = self.backend.bk_zeros(
4041
+ (N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
4042
+ )
4043
+
5043
4044
  if iso_ang:
5044
- S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5045
- S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4045
+ S3_iso = self.backend.bk_zeros(
4046
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4047
+ )
4048
+ S4_iso = self.backend.bk_zeros(
4049
+ (N_image, Ndata_S4, L, L), dtype=data_f.dtype
4050
+ )
5046
4051
  if get_variance:
5047
- S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5048
- S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4052
+ S3_sigma_iso = self.backend.bk_zeros(
4053
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4054
+ )
4055
+ S4_sigma_iso = self.backend.bk_zeros(
4056
+ (N_image, Ndata_S4, L, L), dtype=data_f.dtype
4057
+ )
5049
4058
  if data2 is not None:
5050
- S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
4059
+ S3p_iso = self.backend.bk_zeros(
4060
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4061
+ )
5051
4062
  if get_variance:
5052
- S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5053
-
4063
+ S3p_sigma_iso = self.backend.bk_zeros(
4064
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4065
+ )
4066
+
5054
4067
  #
5055
- if edge:
5056
- if (M,N,J) not in self.edge_masks:
5057
- self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5058
- edge_mask=self.edge_masks[(M,N,J)]
5059
- else:
4068
+ if edge:
4069
+ if (M, N, J) not in self.edge_masks:
4070
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4071
+ edge_mask = self.edge_masks[(M, N, J)]
4072
+ else:
5060
4073
  edge_mask = 1
5061
-
4074
+
5062
4075
  # calculate scattering fields
5063
4076
  if data2 is None:
5064
4077
  if self.use_2D:
5065
4078
  if len(data.shape) == 2:
5066
4079
  I1 = self.backend.bk_ifftn(
5067
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4080
+ data_f[None, None, None, :, :]
4081
+ * filters_set[None, :J, :, :, :],
4082
+ dim=(-2, -1),
5068
4083
  ).abs()
5069
4084
  else:
5070
4085
  I1 = self.backend.bk_ifftn(
5071
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4086
+ data_f[:, None, None, :, :]
4087
+ * filters_set[None, :J, :, :, :],
4088
+ dim=(-2, -1),
5072
4089
  ).abs()
5073
4090
  elif self.use_1D:
5074
4091
  if len(data.shape) == 1:
5075
4092
  I1 = self.backend.bk_ifftn(
5076
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4093
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4094
+ dim=(-1),
5077
4095
  ).abs()
5078
4096
  else:
5079
4097
  I1 = self.backend.bk_ifftn(
5080
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4098
+ data_f[:, None, None, :] * filters_set[None, :J, :, :],
4099
+ dim=(-1),
5081
4100
  ).abs()
5082
4101
  else:
5083
- print('todo')
5084
-
5085
- S2 = (I1**2 * edge_mask).mean((-2,-1))
5086
- S1 = (I1 * edge_mask).mean((-2,-1))
4102
+ print("todo")
4103
+
4104
+ S2 = (I1**2 * edge_mask).mean((-2, -1))
4105
+ S1 = (I1 * edge_mask).mean((-2, -1))
5087
4106
 
5088
4107
  if get_variance:
5089
- S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5090
- S1_sigma = (I1 * edge_mask).std((-2,-1))
5091
-
4108
+ S2_sigma = (I1**2 * edge_mask).std((-2, -1))
4109
+ S1_sigma = (I1 * edge_mask).std((-2, -1))
4110
+
5092
4111
  else:
5093
4112
  if self.use_2D:
5094
4113
  if len(data.shape) == 2:
5095
4114
  I1 = self.backend.bk_ifftn(
5096
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4115
+ data_f[None, None, None, :, :]
4116
+ * filters_set[None, :J, :, :, :],
4117
+ dim=(-2, -1),
5097
4118
  )
5098
4119
  I2 = self.backend.bk_ifftn(
5099
- data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4120
+ data2_f[None, None, None, :, :]
4121
+ * filters_set[None, :J, :, :, :],
4122
+ dim=(-2, -1),
5100
4123
  )
5101
4124
  else:
5102
4125
  I1 = self.backend.bk_ifftn(
5103
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4126
+ data_f[:, None, None, :, :]
4127
+ * filters_set[None, :J, :, :, :],
4128
+ dim=(-2, -1),
5104
4129
  )
5105
4130
  I2 = self.backend.bk_ifftn(
5106
- data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4131
+ data2_f[:, None, None, :, :]
4132
+ * filters_set[None, :J, :, :, :],
4133
+ dim=(-2, -1),
5107
4134
  )
5108
4135
  elif self.use_1D:
5109
4136
  if len(data.shape) == 1:
5110
4137
  I1 = self.backend.bk_ifftn(
5111
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4138
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4139
+ dim=(-1),
5112
4140
  )
5113
4141
  I2 = self.backend.bk_ifftn(
5114
- data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4142
+ data2_f[None, None, None, :] * filters_set[None, :J, :, :],
4143
+ dim=(-1),
5115
4144
  )
5116
4145
  else:
5117
4146
  I1 = self.backend.bk_ifftn(
5118
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4147
+ data_f[:, None, None, :] * filters_set[None, :J, :, :],
4148
+ dim=(-1),
5119
4149
  )
5120
4150
  I2 = self.backend.bk_ifftn(
5121
- data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4151
+ data2_f[:, None, None, :] * filters_set[None, :J, :, :],
4152
+ dim=(-1),
5122
4153
  )
5123
4154
  else:
5124
- print('todo')
5125
-
5126
- I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5127
-
5128
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4155
+ print("todo")
4156
+
4157
+ I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
4158
+
4159
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5129
4160
  if get_variance:
5130
- S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5131
-
5132
- I1=self.backend.bk_L1(I1)
5133
-
5134
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4161
+ S2_sigma = self.backend.bk_reduce_std(
4162
+ (I1 * edge_mask), axis=(-2, -1)
4163
+ )
4164
+
4165
+ I1 = self.backend.bk_L1(I1)
4166
+
4167
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5135
4168
 
5136
4169
  if get_variance:
5137
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5138
-
5139
- I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5140
-
4170
+ S1_sigma = self.backend.bk_reduce_std(
4171
+ (I1 * edge_mask), axis=(-2, -1)
4172
+ )
4173
+
4174
+ I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4175
+
5141
4176
  if pseudo_coef != 1:
5142
4177
  I1 = I1**pseudo_coef
5143
-
5144
- Ndata_S3=0
5145
- Ndata_S4=0
5146
-
4178
+
4179
+ Ndata_S3 = 0
4180
+ Ndata_S4 = 0
4181
+
5147
4182
  # calculate the covariance and correlations of the scattering fields
5148
4183
  # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5149
- for j3 in range(0,J):
5150
- J_S4[j3]=Ndata_S4
5151
-
5152
- dx3, dy3 = self.get_dxdy(j3,M,N)
5153
- I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
4184
+ for j3 in range(0, J):
4185
+ J_S4[j3] = Ndata_S4
4186
+
4187
+ dx3, dy3 = self.get_dxdy(j3, M, N)
4188
+ I1_f_small = self.cut_high_k_off(
4189
+ I1_f[:, : j3 + 1], dx3, dy3
4190
+ ) # Nimage, J, L, x, y
5154
4191
  data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5155
4192
  if data2 is not None:
5156
4193
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5157
4194
  if edge:
5158
- I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5159
- data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
4195
+ I1_small = self.backend.bk_ifftn(
4196
+ I1_f_small, dim=(-2, -1), norm="ortho"
4197
+ )
4198
+ data_small = self.backend.bk_ifftn(
4199
+ data_f_small, dim=(-2, -1), norm="ortho"
4200
+ )
5160
4201
  if data2 is not None:
5161
- data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
5162
-
5163
- wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
4202
+ data2_small = self.backend.bk_ifftn(
4203
+ data2_f_small, dim=(-2, -1), norm="ortho"
4204
+ )
4205
+
4206
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5164
4207
  _, M3, N3 = wavelet_f3.shape
5165
4208
  wavelet_f3_squared = wavelet_f3**2
5166
- edge_dx = min(4, int(2**j3*dx3*2/M))
5167
- edge_dy = min(4, int(2**j3*dy3*2/N))
5168
-
4209
+ edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4210
+ edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
4211
+
5169
4212
  # a normalization change due to the cutoff of frequency space
5170
- fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5171
- for j2 in range(0,j3+1):
5172
- I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5173
- I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
4213
+ fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
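The first 1/(M3*N3) factor makes a mean over Fourier coefficients behave like a spatial covariance (Parseval), and the (M3*N3/M/N)**2 factor compensates for the change of grid size after the high-k cutoff; the exact bookkeeping depends on where the means are taken. A sketch of the Parseval part on an uncropped toy grid:

    import numpy as np

    f, g = np.random.rand(32, 32), np.random.rand(32, 32)
    F, G = np.fft.fft2(f), np.fft.fft2(g)
    # the mean of F * conj(G), divided once more by the pixel count,
    # equals the spatial mean of f * g for numpy's unnormalized FFT
    print(np.allclose((F * np.conj(G)).mean() / (32 * 32), (f * g).mean()))  # True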
4214
+ for j2 in range(0, j3 + 1):
4215
+ I1_f2_wf3_small = I1_f_small[:, j2].view(
4216
+ N_image, L, 1, M3, N3
4217
+ ) * wavelet_f3.view(1, 1, L, M3, N3)
4218
+ I1_f2_wf3_2_small = I1_f_small[:, j2].view(
4219
+ N_image, L, 1, M3, N3
4220
+ ) * wavelet_f3_squared.view(1, 1, L, M3, N3)
5174
4221
  if edge:
5175
- I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5176
- I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
4222
+ I12_w3_small = self.backend.bk_ifftn(
4223
+ I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
4224
+ )
4225
+ I12_w3_2_small = self.backend.bk_ifftn(
4226
+ I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
4227
+ )
5177
4228
  if use_ref:
5178
- if normalization=='P11':
5179
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5180
- if normalization=='S2':
5181
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
4229
+ if normalization == "P11":
4230
+ norm_factor_S3 = (
4231
+ ref_S2[:, None, j3, :]
4232
+ * ref_P11[:, j2, j3, :, :] ** pseudo_coef
4233
+ ) ** 0.5
4234
+ if normalization == "S2":
4235
+ norm_factor_S3 = (
4236
+ ref_S2[:, None, j3, :]
4237
+ * ref_S2[:, j2, :, None] ** pseudo_coef
4238
+ ) ** 0.5
5182
4239
  else:
5183
- if normalization=='P11':
4240
+ if normalization == "P11":
5184
4241
  # [N_image,l2,l3,x,y]
5185
- P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5186
- norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5187
- if normalization=='S2':
5188
- norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
4242
+ P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
4243
+ (-2, -1)
4244
+ ) * fft_factor
4245
+ norm_factor_S3 = (
4246
+ S2[:, None, j3, :] * P11_temp**pseudo_coef
4247
+ ) ** 0.5
4248
+ if normalization == "S2":
4249
+ norm_factor_S3 = (
4250
+ S2[:, None, j3, :] * S2[:, j2, :, None] ** pseudo_coef
4251
+ ) ** 0.5
5189
4252
 
5190
4253
  if not edge:
5191
- S3[:,Ndata_S3,:,:] = (
5192
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5193
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5194
-
4254
+ S3[:, Ndata_S3, :, :] = (
4255
+ (
4256
+ data_f_small.view(N_image, 1, 1, M3, N3)
4257
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4258
+ ).mean((-2, -1))
4259
+ * fft_factor
4260
+ / norm_factor_S3
4261
+ )
4262
+
5195
4263
  if get_variance:
5196
- S3_sigma[:,Ndata_S3,:,:] = (
5197
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5198
- ).std((-2,-1)) * fft_factor / norm_factor_S3
4264
+ S3_sigma[:, Ndata_S3, :, :] = (
4265
+ (
4266
+ data_f_small.view(N_image, 1, 1, M3, N3)
4267
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4268
+ ).std((-2, -1))
4269
+ * fft_factor
4270
+ / norm_factor_S3
4271
+ )
5199
4272
  else:
5200
- S3[:,Ndata_S3,:,:] = (
5201
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5202
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
4273
+ S3[:, Ndata_S3, :, :] = (
4274
+ (
4275
+ data_small.view(N_image, 1, 1, M3, N3)
4276
+ * self.backend.bk_conjugate(I12_w3_small)
4277
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy].mean(
4278
+ (-2, -1)
4279
+ )
4280
+ * fft_factor
4281
+ / norm_factor_S3
4282
+ )
5203
4283
  if get_variance:
5204
- S3_sigma[:,Ndata_S3,:,:] = (
5205
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5206
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
4284
+ S3_sigma[:, Ndata_S3, :, :] = (
4285
+ (
4286
+ data_small.view(N_image, 1, 1, M3, N3)
4287
+ * self.backend.bk_conjugate(I12_w3_small)
4288
+ )[
4289
+ ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4290
+ ].std(
4291
+ (-2, -1)
4292
+ )
4293
+ * fft_factor
4294
+ / norm_factor_S3
4295
+ )
5207
4296
  if data2 is not None:
5208
4297
  if not edge:
5209
- S3p[:,Ndata_S3,:,:] = (
5210
- data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5211
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5212
-
4298
+ S3p[:, Ndata_S3, :, :] = (
4299
+ (
4300
+ data2_f_small.view(N_image2, 1, 1, M3, N3)
4301
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4302
+ ).mean((-2, -1))
4303
+ * fft_factor
4304
+ / norm_factor_S3
4305
+ )
4306
+
5213
4307
  if get_variance:
5214
- S3p_sigma[:,Ndata_S3,:,:] = (
5215
- data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5216
- ).std((-2,-1)) * fft_factor / norm_factor_S3
4308
+ S3p_sigma[:, Ndata_S3, :, :] = (
4309
+ (
4310
+ data2_f_small.view(N_image2, 1, 1, M3, N3)
4311
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4312
+ ).std((-2, -1))
4313
+ * fft_factor
4314
+ / norm_factor_S3
4315
+ )
5217
4316
  else:
5218
- S3p[:,Ndata_S3,:,:] = (
5219
- data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5220
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
4317
+ S3p[:, Ndata_S3, :, :] = (
4318
+ (
4319
+ data2_small.view(N_image2, 1, 1, M3, N3)
4320
+ * self.backend.bk_conjugate(I12_w3_small)
4321
+ )[
4322
+ ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4323
+ ].mean(
4324
+ (-2, -1)
4325
+ )
4326
+ * fft_factor
4327
+ / norm_factor_S3
4328
+ )
5221
4329
  if get_variance:
5222
- S3p_sigma[:,Ndata_S3,:,:] = (
5223
- data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5224
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5225
- Ndata_S3+=1
4330
+ S3p_sigma[:, Ndata_S3, :, :] = (
4331
+ (
4332
+ data2_small.view(N_image2, 1, 1, M3, N3)
4333
+ * self.backend.bk_conjugate(I12_w3_small)
4334
+ )[
4335
+ ...,
4336
+ edge_dx : M3 - edge_dx,
4337
+ edge_dy : N3 - edge_dy,
4338
+ ].std(
4339
+ (-2, -1)
4340
+ )
4341
+ * fft_factor
4342
+ / norm_factor_S3
4343
+ )
4344
+ Ndata_S3 += 1
5226
4345
  if j2 <= j3:
5227
- beg_n=Ndata_S4
5228
- for j1 in range(0, j2+1):
4346
+ beg_n = Ndata_S4
4347
+ for j1 in range(0, j2 + 1):
5229
4348
  if eval(S4_criteria):
5230
4349
  if not edge:
5231
4350
  if not if_large_batch:
5232
4351
  # [N_image,l1,l2,l3,x,y]
5233
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5234
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5235
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5236
- ).mean((-2,-1)) * fft_factor
4352
+ S4_pre_norm[:, Ndata_S4, :, :, :] = (
4353
+ I1_f_small[:, j1].view(
4354
+ N_image, L, 1, 1, M3, N3
4355
+ )
4356
+ * self.backend.bk_conjugate(
4357
+ I1_f2_wf3_2_small.view(
4358
+ N_image, 1, L, L, M3, N3
4359
+ )
4360
+ )
4361
+ ).mean((-2, -1)) * fft_factor
5237
4362
  if get_variance:
5238
- S4_sigma[:,Ndata_S4,:,:,:] = (
5239
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5240
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5241
- ).std((-2,-1)) * fft_factor
4363
+ S4_sigma[:, Ndata_S4, :, :, :] = (
4364
+ I1_f_small[:, j1].view(
4365
+ N_image, L, 1, 1, M3, N3
4366
+ )
4367
+ * self.backend.bk_conjugate(
4368
+ I1_f2_wf3_2_small.view(
4369
+ N_image, 1, L, L, M3, N3
4370
+ )
4371
+ )
4372
+ ).std((-2, -1)) * fft_factor
5242
4373
  else:
5243
4374
  for l1 in range(L):
5244
4375
  # [N_image,l2,l3,x,y]
5245
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5246
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5247
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5248
- ).mean((-2,-1)) * fft_factor
4376
+ S4_pre_norm[:, Ndata_S4, l1, :, :] = (
4377
+ I1_f_small[:, j1, l1].view(
4378
+ N_image, 1, 1, M3, N3
4379
+ )
4380
+ * self.backend.bk_conjugate(
4381
+ I1_f2_wf3_2_small.view(
4382
+ N_image, L, L, M3, N3
4383
+ )
4384
+ )
4385
+ ).mean((-2, -1)) * fft_factor
5249
4386
  if get_variance:
5250
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5251
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5252
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5253
- ).std((-2,-1)) * fft_factor
4387
+ S4_sigma[:, Ndata_S4, l1, :, :] = (
4388
+ I1_f_small[:, j1, l1].view(
4389
+ N_image, 1, 1, M3, N3
4390
+ )
4391
+ * self.backend.bk_conjugate(
4392
+ I1_f2_wf3_2_small.view(
4393
+ N_image, L, L, M3, N3
4394
+ )
4395
+ )
4396
+ ).std((-2, -1)) * fft_factor
5254
4397
  else:
5255
4398
  if not if_large_batch:
5256
4399
  # [N_image,l1,l2,l3,x,y]
5257
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5258
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5259
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
4400
+ S4_pre_norm[:, Ndata_S4, :, :, :] = (
4401
+ I1_small[:, j1].view(
4402
+ N_image, L, 1, 1, M3, N3
4403
+ )
4404
+ * self.backend.bk_conjugate(
4405
+ I12_w3_2_small.view(
4406
+ N_image, 1, L, L, M3, N3
4407
+ )
5260
4408
  )
5261
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
4409
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
4410
+ (-2, -1)
4411
+ ) * fft_factor
5262
4412
  if get_variance:
5263
- S4_sigma[:,Ndata_S4,:,:,:] = (
5264
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5265
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
4413
+ S4_sigma[:, Ndata_S4, :, :, :] = (
4414
+ I1_small[:, j1].view(
4415
+ N_image, L, 1, 1, M3, N3
4416
+ )
4417
+ * self.backend.bk_conjugate(
4418
+ I12_w3_2_small.view(
4419
+ N_image, 1, L, L, M3, N3
4420
+ )
5266
4421
  )
5267
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
4422
+ )[
4423
+ ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4424
+ ].std(
4425
+ (-2, -1)
4426
+ ) * fft_factor
5268
4427
  else:
5269
4428
  for l1 in range(L):
5270
- # [N_image,l2,l3,x,y]
5271
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5272
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5273
- I12_w3_2_small.view(N_image,L,L,M3,N3)
4429
+ # [N_image,l2,l3,x,y]
4430
+ S4_pre_norm[:, Ndata_S4, l1, :, :] = (
4431
+ I1_small[:, j1].view(
4432
+ N_image, 1, 1, M3, N3
5274
4433
  )
5275
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
4434
+ * self.backend.bk_conjugate(
4435
+ I12_w3_2_small.view(
4436
+ N_image, L, L, M3, N3
4437
+ )
4438
+ )
4439
+ )[
4440
+ ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4441
+ ].mean(
4442
+ (-2, -1)
4443
+ ) * fft_factor
5276
4444
  if get_variance:
5277
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5278
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5279
- I12_w3_2_small.view(N_image,L,L,M3,N3)
4445
+ S4_sigma[:, Ndata_S4, l1, :, :] = (
4446
+ I1_small[:, j1].view(
4447
+ N_image, 1, 1, M3, N3
5280
4448
  )
5281
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5282
-
5283
- Ndata_S4+=1
5284
-
5285
- if normalization=='S2':
5286
- if use_ref:
5287
- P = ((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5288
- else:
5289
- P = ((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5290
-
5291
- S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/P
5292
-
4449
+ * self.backend.bk_conjugate(
4450
+ I12_w3_2_small.view(
4451
+ N_image, L, L, M3, N3
4452
+ )
4453
+ )
4454
+ )[
4455
+ ...,
4456
+ edge_dx:-edge_dx,
4457
+ edge_dy:-edge_dy,
4458
+ ].mean(
4459
+ (-2, -1)
4460
+ ) * fft_factor
4461
+
4462
+ Ndata_S4 += 1
4463
+
4464
+ if normalization == "S2":
4465
+ if use_ref:
4466
+ P = (
4467
+ ref_S2[:, j3 : j3 + 1, :, None, None]
4468
+ * ref_S2[:, j2 : j2 + 1, None, :, None]
4469
+ ) ** (0.5 * pseudo_coef)
4470
+ else:
4471
+ P = (
4472
+ S2[:, j3 : j3 + 1, :, None, None]
4473
+ * S2[:, j2 : j2 + 1, None, :, None]
4474
+ ) ** (0.5 * pseudo_coef)
4475
+
4476
+ S4[:, beg_n:Ndata_S4, :, :, :] = (
4477
+ S4_pre_norm[:, beg_n:Ndata_S4, :, :, :].clone() / P
4478
+ )
4479
+
5293
4480
  if get_variance:
5294
- S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:]/P
4481
+ S4_sigma[:, beg_n:Ndata_S4, :, :, :] = (
4482
+ S4_sigma[:, beg_n:Ndata_S4, :, :, :] / P
4483
+ )
5295
4484
  else:
5296
- S4=S4_pre_norm
5297
-
4485
+ S4 = S4_pre_norm
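With normalization == "S2" the fourth-order block is divided by the geometric mean of the two second-order powers at scales j2 and j3 (raised to pseudo_coef), mirroring the C11 factor described in the docstring. Schematically, for one (j2, j3) slot, with random stand-in arrays:

    import numpy as np

    N_image, J, L = 2, 4, 4
    S2 = 1.0 + np.random.rand(N_image, J, L)
    S4_pre = np.random.rand(N_image, L, L, L)     # orientation axes (l1, l2, l3)
    j2, j3, pseudo_coef = 1, 2, 1
    P = (S2[:, j3, :, None, None] * S2[:, j2, None, :, None]) ** (0.5 * pseudo_coef)
    S4_norm = S4_pre / P                          # broadcast over the last axis
    print(S4_norm.shape)                          # (2, 4, 4, 4)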
4486
+
5298
4487
  # average over l1 to obtain simple isotropic statistics
5299
4488
  if iso_ang:
5300
4489
  S2_iso = S2.mean(-1)
5301
4490
  S1_iso = S1.mean(-1)
5302
4491
  for l1 in range(L):
5303
4492
  for l2 in range(L):
5304
- S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
4493
+ S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
5305
4494
  if data2 is not None:
5306
- S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
4495
+ S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
5307
4496
  for l3 in range(L):
5308
- S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5309
- S3_iso /= L; S4_iso /= L
4497
+ S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[
4498
+ ..., l1, l2, l3
4499
+ ]
4500
+ S3_iso /= L
4501
+ S4_iso /= L
5310
4502
  if data2 is not None:
5311
4503
  S3p_iso /= L
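The iso_ang reduction averages over the absolute orientation l1 while binning by the relative orientation (l2 - l1) mod L, which makes the summary rotation-invariant by construction. The double loop is equivalent to this small sketch:

    import numpy as np

    L = 4
    S3 = np.random.rand(2, 10, L, L)         # (N_image, Ndata_S3, l1, l2)
    S3_iso = np.zeros((2, 10, L))
    for l1 in range(L):
        for l2 in range(L):
            S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
    S3_iso /= L                              # average over l1 at fixed l2 - l1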
5312
-
4504
+
5313
4505
  if get_variance:
5314
4506
  S2_sigma_iso = S2_sigma.mean(-1)
5315
4507
  S1_sigma_iso = S1_sigma.mean(-1)
5316
4508
  for l1 in range(L):
5317
4509
  for l2 in range(L):
5318
- S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
4510
+ S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
5319
4511
  if data2 is not None:
5320
- S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
4512
+ S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[
4513
+ ..., l1, l2
4514
+ ]
5321
4515
  for l3 in range(L):
5322
- S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5323
- S3_sigma_iso /= L; S4_sigma_iso /= L
4516
+ S4_sigma_iso[
4517
+ ..., (l2 - l1) % L, (l3 - l1) % L
4518
+ ] += S4_sigma[..., l1, l2, l3]
4519
+ S3_sigma_iso /= L
4520
+ S4_sigma_iso /= L
5324
4521
  if data2 is not None:
5325
4522
  S3p_sigma_iso /= L
5326
-
5327
- mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5328
- std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5329
-
4523
+
4524
+ mean_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
4525
+ std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
4526
+
5330
4527
  if data2 is None:
5331
- mean_data[:,0]=data.mean((-2,-1))
5332
- std_data[:,0]=data.std((-2,-1))
4528
+ mean_data[:, 0] = data.mean((-2, -1))
4529
+ std_data[:, 0] = data.std((-2, -1))
5333
4530
  else:
5334
- mean_data[:,0]=(data2*data).mean((-2,-1))
5335
- std_data[:,0]=(data2*data).std((-2,-1))
5336
-
4531
+ mean_data[:, 0] = (data2 * data).mean((-2, -1))
4532
+ std_data[:, 0] = (data2 * data).std((-2, -1))
4533
+
5337
4534
  if get_variance:
5338
- ref_sigma={}
4535
+ ref_sigma = {}
5339
4536
  if iso_ang:
5340
- ref_sigma['std_data']=std_data
5341
- ref_sigma['S1_sigma']=S1_sigma_iso
5342
- ref_sigma['S2_sigma']=S2_sigma_iso
5343
- ref_sigma['S3_sigma']=S3_sigma_iso
4537
+ ref_sigma["std_data"] = std_data
4538
+ ref_sigma["S1_sigma"] = S1_sigma_iso
4539
+ ref_sigma["S2_sigma"] = S2_sigma_iso
4540
+ ref_sigma["S3_sigma"] = S3_sigma_iso
5344
4541
  if data2 is not None:
5345
- ref_sigma['S3p_sigma']=S3p_sigma_iso
5346
- ref_sigma['S4_sigma']=S4_sigma_iso
4542
+ ref_sigma["S3p_sigma"] = S3p_sigma_iso
4543
+ ref_sigma["S4_sigma"] = S4_sigma_iso
5347
4544
  else:
5348
- ref_sigma['std_data']=std_data
5349
- ref_sigma['S1_sigma']=S1_sigma
5350
- ref_sigma['S2_sigma']=S2_sigma
5351
- ref_sigma['S3_sigma']=S3_sigma
4545
+ ref_sigma["std_data"] = std_data
4546
+ ref_sigma["S1_sigma"] = S1_sigma
4547
+ ref_sigma["S2_sigma"] = S2_sigma
4548
+ ref_sigma["S3_sigma"] = S3_sigma
5352
4549
  if data2 is not None:
5353
- ref_sigma['S3p_sigma']=S3p_sigma
5354
- ref_sigma['S4_sigma']=S4_sigma
5355
-
4550
+ ref_sigma["S3p_sigma"] = S3p_sigma
4551
+ ref_sigma["S4_sigma"] = S4_sigma
4552
+
5356
4553
  if data2 is None:
5357
4554
  if iso_ang:
5358
4555
  if ref_sigma is not None:
5359
- for_synthesis = self.backend.backend.cat((
5360
- mean_data/ref_sigma['std_data'],
5361
- std_data/ref_sigma['std_data'],
5362
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5363
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5364
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5365
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5366
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5367
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5368
- ),dim=-1)
4556
+ for_synthesis = self.backend.backend.cat(
4557
+ (
4558
+ mean_data / ref_sigma["std_data"],
4559
+ std_data / ref_sigma["std_data"],
4560
+ (S2_iso / ref_sigma["S2_sigma"])
4561
+ .reshape((N_image, -1))
4562
+ .log(),
4563
+ (S1_iso / ref_sigma["S1_sigma"])
4564
+ .reshape((N_image, -1))
4565
+ .log(),
4566
+ (S3_iso / ref_sigma["S3_sigma"])
4567
+ .reshape((N_image, -1))
4568
+ .real,
4569
+ (S3_iso / ref_sigma["S3_sigma"])
4570
+ .reshape((N_image, -1))
4571
+ .imag,
4572
+ (S4_iso / ref_sigma["S4_sigma"])
4573
+ .reshape((N_image, -1))
4574
+ .real,
4575
+ (S4_iso / ref_sigma["S4_sigma"])
4576
+ .reshape((N_image, -1))
4577
+ .imag,
4578
+ ),
4579
+ dim=-1,
4580
+ )
5369
4581
  else:
5370
- for_synthesis = self.backend.backend.cat((
5371
- mean_data/std_data,
5372
- std_data,
5373
- S2_iso.reshape((N_image, -1)).log(),
5374
- S1_iso.reshape((N_image, -1)).log(),
5375
- S3_iso.reshape((N_image, -1)).real,
5376
- S3_iso.reshape((N_image, -1)).imag,
5377
- S4_iso.reshape((N_image, -1)).real,
5378
- S4_iso.reshape((N_image, -1)).imag,
5379
- ),dim=-1)
4582
+ for_synthesis = self.backend.backend.cat(
4583
+ (
4584
+ mean_data / std_data,
4585
+ std_data,
4586
+ S2_iso.reshape((N_image, -1)).log(),
4587
+ S1_iso.reshape((N_image, -1)).log(),
4588
+ S3_iso.reshape((N_image, -1)).real,
4589
+ S3_iso.reshape((N_image, -1)).imag,
4590
+ S4_iso.reshape((N_image, -1)).real,
4591
+ S4_iso.reshape((N_image, -1)).imag,
4592
+ ),
4593
+ dim=-1,
4594
+ )
5380
4595
  else:
5381
4596
  if ref_sigma is not None:
5382
- for_synthesis = self.backend.backend.cat((
5383
- mean_data/ref_sigma['std_data'],
5384
- std_data/ref_sigma['std_data'],
5385
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5386
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5387
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5388
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5389
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5390
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5391
- ),dim=-1)
4597
+ for_synthesis = self.backend.backend.cat(
4598
+ (
4599
+ mean_data / ref_sigma["std_data"],
4600
+ std_data / ref_sigma["std_data"],
4601
+ (S2 / ref_sigma["S2_sigma"])
4602
+ .reshape((N_image, -1))
4603
+ .log(),
4604
+ (S1 / ref_sigma["S1_sigma"])
4605
+ .reshape((N_image, -1))
4606
+ .log(),
4607
+ (S3 / ref_sigma["S3_sigma"])
4608
+ .reshape((N_image, -1))
4609
+ .real,
4610
+ (S3 / ref_sigma["S3_sigma"])
4611
+ .reshape((N_image, -1))
4612
+ .imag,
4613
+ (S4 / ref_sigma["S4_sigma"])
4614
+ .reshape((N_image, -1))
4615
+ .real,
4616
+ (S4 / ref_sigma["S4_sigma"])
4617
+ .reshape((N_image, -1))
4618
+ .imag,
4619
+ ),
4620
+ dim=-1,
4621
+ )
5392
4622
  else:
5393
- for_synthesis = self.backend.backend.cat((
5394
- mean_data/std_data,
5395
- std_data,
5396
- S2.reshape((N_image, -1)).log(),
5397
- S1.reshape((N_image, -1)).log(),
5398
- S3.reshape((N_image, -1)).real,
5399
- S3.reshape((N_image, -1)).imag,
5400
- S4.reshape((N_image, -1)).real,
5401
- S4.reshape((N_image, -1)).imag,
5402
- ),dim=-1)
4623
+ for_synthesis = self.backend.backend.cat(
4624
+ (
4625
+ mean_data / std_data,
4626
+ std_data,
4627
+ S2.reshape((N_image, -1)).log(),
4628
+ S1.reshape((N_image, -1)).log(),
4629
+ S3.reshape((N_image, -1)).real,
4630
+ S3.reshape((N_image, -1)).imag,
4631
+ S4.reshape((N_image, -1)).real,
4632
+ S4.reshape((N_image, -1)).imag,
4633
+ ),
4634
+ dim=-1,
4635
+ )
5403
4636
  else:
5404
4637
  if iso_ang:
5405
4638
  if ref_sigma is not None:
5406
- for_synthesis = self.backend.backend.cat((
5407
- mean_data/ref_sigma['std_data'],
5408
- std_data/ref_sigma['std_data'],
5409
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5410
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5411
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5412
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5413
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5414
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5415
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5416
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5417
- ),dim=-1)
4639
+ for_synthesis = self.backend.backend.cat(
4640
+ (
4641
+ mean_data / ref_sigma["std_data"],
4642
+ std_data / ref_sigma["std_data"],
4643
+ (S2_iso / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
4644
+ (S1_iso / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
4645
+ (S3_iso / ref_sigma["S3_sigma"])
4646
+ .reshape((N_image, -1))
4647
+ .real,
4648
+ (S3_iso / ref_sigma["S3_sigma"])
4649
+ .reshape((N_image, -1))
4650
+ .imag,
4651
+ (S3p_iso / ref_sigma["S3p_sigma"])
4652
+ .reshape((N_image, -1))
4653
+ .real,
4654
+ (S3p_iso / ref_sigma["S3p_sigma"])
4655
+ .reshape((N_image, -1))
4656
+ .imag,
4657
+ (S4_iso / ref_sigma["S4_sigma"])
4658
+ .reshape((N_image, -1))
4659
+ .real,
4660
+ (S4_iso / ref_sigma["S4_sigma"])
4661
+ .reshape((N_image, -1))
4662
+ .imag,
4663
+ ),
4664
+ dim=-1,
4665
+ )
5418
4666
  else:
5419
- for_synthesis = self.backend.backend.cat((
5420
- mean_data/std_data,
5421
- std_data,
5422
- S2_iso.reshape((N_image, -1)),
5423
- S1_iso.reshape((N_image, -1)),
5424
- S3_iso.reshape((N_image, -1)).real,
5425
- S3_iso.reshape((N_image, -1)).imag,
5426
- S3p_iso.reshape((N_image, -1)).real,
5427
- S3p_iso.reshape((N_image, -1)).imag,
5428
- S4_iso.reshape((N_image, -1)).real,
5429
- S4_iso.reshape((N_image, -1)).imag,
5430
- ),dim=-1)
4667
+ for_synthesis = self.backend.backend.cat(
4668
+ (
4669
+ mean_data / std_data,
4670
+ std_data,
4671
+ S2_iso.reshape((N_image, -1)),
4672
+ S1_iso.reshape((N_image, -1)),
4673
+ S3_iso.reshape((N_image, -1)).real,
4674
+ S3_iso.reshape((N_image, -1)).imag,
4675
+ S3p_iso.reshape((N_image, -1)).real,
4676
+ S3p_iso.reshape((N_image, -1)).imag,
4677
+ S4_iso.reshape((N_image, -1)).real,
4678
+ S4_iso.reshape((N_image, -1)).imag,
4679
+ ),
4680
+ dim=-1,
4681
+ )
5431
4682
  else:
5432
4683
  if ref_sigma is not None:
5433
- for_synthesis = self.backend.backend.cat((
5434
- mean_data/ref_sigma['std_data'],
5435
- std_data/ref_sigma['std_data'],
5436
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5437
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5438
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5439
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5440
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5441
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5442
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5443
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5444
- ),dim=-1)
4684
+ for_synthesis = self.backend.backend.cat(
4685
+ (
4686
+ mean_data / ref_sigma["std_data"],
4687
+ std_data / ref_sigma["std_data"],
4688
+ (S2 / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
4689
+ (S1 / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
4690
+ (S3 / ref_sigma["S3_sigma"])
4691
+ .reshape((N_image, -1))
4692
+ .real,
4693
+ (S3 / ref_sigma["S3_sigma"])
4694
+ .reshape((N_image, -1))
4695
+ .imag,
4696
+ (S3p / ref_sigma["S3p_sigma"])
4697
+ .reshape((N_image, -1))
4698
+ .real,
4699
+ (S3p / ref_sigma["S3p_sigma"])
4700
+ .reshape((N_image, -1))
4701
+ .imag,
4702
+ (S4 / ref_sigma["S4_sigma"])
4703
+ .reshape((N_image, -1))
4704
+ .real,
4705
+ (S4 / ref_sigma["S4_sigma"])
4706
+ .reshape((N_image, -1))
4707
+ .imag,
4708
+ ),
4709
+ dim=-1,
4710
+ )
5445
4711
  else:
5446
- for_synthesis = self.backend.backend.cat((
5447
- mean_data/std_data,
5448
- std_data,
5449
- S2.reshape((N_image, -1)),
5450
- S1.reshape((N_image, -1)),
5451
- S3.reshape((N_image, -1)).real,
5452
- S3.reshape((N_image, -1)).imag,
5453
- S3p.reshape((N_image, -1)).real,
5454
- S3p.reshape((N_image, -1)).imag,
5455
- S4.reshape((N_image, -1)).real,
5456
- S4.reshape((N_image, -1)).imag,
5457
- ),dim=-1)
5458
-
5459
- if not use_ref:
5460
- self.ref_scattering_cov_S2=S2
5461
-
4712
+ for_synthesis = self.backend.backend.cat(
4713
+ (
4714
+ mean_data / std_data,
4715
+ std_data,
4716
+ S2.reshape((N_image, -1)),
4717
+ S1.reshape((N_image, -1)),
4718
+ S3.reshape((N_image, -1)).real,
4719
+ S3.reshape((N_image, -1)).imag,
4720
+ S3p.reshape((N_image, -1)).real,
4721
+ S3p.reshape((N_image, -1)).imag,
4722
+ S4.reshape((N_image, -1)).real,
4723
+ S4.reshape((N_image, -1)).imag,
4724
+ ),
4725
+ dim=-1,
4726
+ )
4727
+
4728
+ if not use_ref:
4729
+ self.ref_scattering_cov_S2 = S2
4730
+
5462
4731
  if get_variance:
5463
- return for_synthesis,ref_sigma
5464
-
4732
+ return for_synthesis, ref_sigma
4733
+
5465
4734
  return for_synthesis
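For reference, the length of the concatenated for_synthesis vector in the single-input, non-iso branch follows directly from the shapes concatenated above (a bookkeeping sketch with placeholder sizes):

    J, L, Ndata_S3, Ndata_S4 = 4, 4, 10, 20   # placeholder sizes
    n_coeffs = (
        2                                     # mean_data, std_data
        + 2 * J * L                           # log S2 and log S1
        + 2 * Ndata_S3 * L * L                # S3 real + imag
        + 2 * Ndata_S4 * L * L * L            # S4 real + imag
    )
    print(n_coeffs)                           # last-axis length of for_synthesis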
5466
-
5467
- if (M,N,J,L) not in self.filters_set:
5468
- self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
5469
-
5470
- filters_set = self.filters_set[(M,N,J,L)]
5471
-
5472
- #weight = self.weight
4735
+
4736
+ if (M, N, J, L) not in self.filters_set:
4737
+ self.filters_set[(M, N, J, L)] = self.computer_filter(
4738
+ M, N, J, L
4739
+ ) # self.computer_filter(M,N,J,L)
4740
+
4741
+ filters_set = self.filters_set[(M, N, J, L)]
4742
+
4743
+ # weight = self.weight
5473
4744
  if use_ref:
5474
- if normalization=='S2':
4745
+ if normalization == "S2":
5475
4746
  ref_S2 = self.ref_scattering_cov_S2
5476
- else:
5477
- ref_P11 = self.ref_scattering_cov['P11']
4747
+ else:
4748
+ ref_P11 = self.ref_scattering_cov["P11"]
5478
4749
 
5479
4750
  # convert numpy array input into self.backend.bk_ tensors
5480
4751
  data = self.backend.bk_cast(data)
5481
- data_f = self.backend.bk_fftn(data, dim=(-2,-1))
4752
+ data_f = self.backend.bk_fftn(data, dim=(-2, -1))
5482
4753
  if data2 is not None:
5483
4754
  data2 = self.backend.bk_cast(data2)
5484
- data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
5485
-
4755
+ data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
4756
+
5486
4757
  # initialize tensors for scattering coefficients
5487
-
5488
- Ndata_S3 = J*(J+1)//2
5489
- Ndata_S4 = J*(J+1)*(J+2)//6
5490
- J_S4={}
5491
-
4758
+
4759
+ Ndata_S3 = J * (J + 1) // 2
4760
+ Ndata_S4 = J * (J + 1) * (J + 2) // 6
4761
+ J_S4 = {}
4762
+
5492
4763
  S3 = []
5493
4764
  if data2 is not None:
5494
4765
  S3p = []
5495
- S4_pre_norm = []
5496
- S4 = []
5497
-
4766
+ S4_pre_norm = []
4767
+ S4 = []
4768
+
5498
4769
  # variance
5499
4770
  if get_variance:
5500
- S3_sigma = []
4771
+ S3_sigma = []
5501
4772
  if data2 is not None:
5502
4773
  S3p_sigma = []
5503
- S4_sigma = []
5504
-
4774
+ S4_sigma = []
4775
+
5505
4776
  if iso_ang:
5506
4777
  S3_iso = []
5507
4778
  if data2 is not None:
5508
4779
  S3p_iso = []
5509
-
4780
+
5510
4781
  S4_iso = []
5511
4782
  if get_variance:
5512
4783
  S3_sigma_iso = []
5513
4784
  if data2 is not None:
5514
- S3p_sigma_iso = []
5515
- S4_sigma_iso = []
5516
-
4785
+ S3p_sigma_iso = []
4786
+ S4_sigma_iso = []
4787
+
5517
4788
  #
5518
- if edge:
5519
- if (M,N,J) not in self.edge_masks:
5520
- self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5521
- edge_mask = self.edge_masks[(M,N,J)]
5522
- else:
4789
+ if edge:
4790
+ if (M, N, J) not in self.edge_masks:
4791
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4792
+ edge_mask = self.edge_masks[(M, N, J)]
4793
+ else:
5523
4794
  edge_mask = 1
5524
-
4795
+
5525
4796
  # calculate scattering fields
5526
4797
  if data2 is None:
5527
4798
  if self.use_2D:
5528
4799
  if len(data.shape) == 2:
5529
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5530
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5531
- ))
4800
+ I1 = self.backend.bk_abs(
4801
+ self.backend.bk_ifftn(
4802
+ data_f[None, None, None, :, :]
4803
+ * filters_set[None, :J, :, :, :],
4804
+ dim=(-2, -1),
4805
+ )
4806
+ )
5532
4807
  else:
5533
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5534
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5535
- ))
4808
+ I1 = self.backend.bk_abs(
4809
+ self.backend.bk_ifftn(
4810
+ data_f[:, None, None, :, :]
4811
+ * filters_set[None, :J, :, :, :],
4812
+ dim=(-2, -1),
4813
+ )
4814
+ )
5536
4815
  elif self.use_1D:
5537
4816
  if len(data.shape) == 1:
5538
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5539
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5540
- ))
4817
+ I1 = self.backend.bk_abs(
4818
+ self.backend.bk_ifftn(
4819
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4820
+ dim=(-1),
4821
+ )
4822
+ )
5541
4823
  else:
5542
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5543
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5544
- ))
4824
+ I1 = self.backend.bk_abs(
4825
+ self.backend.bk_ifftn(
4826
+ data_f[:, None, None, :] * filters_set[None, :J, :, :],
4827
+ dim=(-1),
4828
+ )
4829
+ )
5545
4830
  else:
5546
- print('todo')
5547
-
5548
- S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask),axis=(-2,-1))
5549
- S1 = self.backend.bk_reduce_mean(I1 * edge_mask,axis=(-2,-1))
4831
+ print("todo")
4832
+
4833
+ S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=(-2, -1))
4834
+ S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=(-2, -1))
5550
4835
 
5551
4836
  if get_variance:
5552
- S2_sigma = self.backend.bk_reduce_std((I1**2 * edge_mask),axis=(-2,-1))
5553
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5554
-
5555
- I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5556
-
5557
- else:
4837
+ S2_sigma = self.backend.bk_reduce_std(
4838
+ (I1**2 * edge_mask), axis=(-2, -1)
4839
+ )
4840
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
4841
+
4842
+ I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4843
+
4844
+ else:
5558
4845
  if self.use_2D:
5559
4846
  if len(data.shape) == 2:
5560
4847
  I1 = self.backend.bk_ifftn(
5561
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4848
+ data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
4849
+ dim=(-2, -1),
5562
4850
  )
5563
4851
  I2 = self.backend.bk_ifftn(
5564
- data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4852
+ data2_f[None, None, None, :, :]
4853
+ * filters_set[None, :J, :, :, :],
4854
+ dim=(-2, -1),
5565
4855
  )
5566
4856
  else:
5567
4857
  I1 = self.backend.bk_ifftn(
5568
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4858
+ data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
4859
+ dim=(-2, -1),
5569
4860
  )
5570
4861
  I2 = self.backend.bk_ifftn(
5571
- data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4862
+ data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
4863
+ dim=(-2, -1),
5572
4864
  )
5573
4865
  elif self.use_1D:
5574
4866
  if len(data.shape) == 1:
5575
4867
  I1 = self.backend.bk_ifftn(
5576
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4868
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4869
+ dim=(-1),
5577
4870
  )
5578
4871
  I2 = self.backend.bk_ifftn(
5579
- data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4872
+ data2_f[None, None, None, :] * filters_set[None, :J, :, :],
4873
+ dim=(-1),
5580
4874
  )
5581
4875
  else:
5582
4876
  I1 = self.backend.bk_ifftn(
5583
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4877
+ data_f[:, None, None, :] * filters_set[None, :J, :, :], dim=(-1)
5584
4878
  )
5585
4879
  I2 = self.backend.bk_ifftn(
5586
- data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4880
+ data2_f[:, None, None, :] * filters_set[None, :J, :, :],
4881
+ dim=(-1),
5587
4882
  )
5588
4883
  else:
5589
- print('todo')
5590
-
5591
- I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5592
-
5593
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4884
+ print("todo")
4885
+
4886
+ I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
4887
+
4888
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5594
4889
  if get_variance:
5595
- S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5596
-
5597
- I1=self.backend.bk_L1(I1)
5598
-
5599
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4890
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
4891
+
4892
+ I1 = self.backend.bk_L1(I1)
4893
+
4894
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5600
4895
 
5601
4896
  if get_variance:
5602
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5603
-
5604
- I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5605
-
4897
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
4898
+
4899
+ I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4900
+
5606
4901
  if pseudo_coef != 1:
5607
4902
  I1 = I1**pseudo_coef
5608
-
5609
- Ndata_S3=0
5610
- Ndata_S4=0
5611
-
4903
+
4904
+ Ndata_S3 = 0
4905
+ Ndata_S4 = 0
4906
+
5612
4907
  # calculate the covariance and correlations of the scattering fields
5613
4908
  # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5614
- for j3 in range(0,J):
5615
- J_S4[j3]=Ndata_S4
5616
-
5617
- dx3, dy3 = self.get_dxdy(j3,M,N)
5618
- I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
4909
+ for j3 in range(0, J):
4910
+ J_S4[j3] = Ndata_S4
4911
+
4912
+ dx3, dy3 = self.get_dxdy(j3, M, N)
4913
+ I1_f_small = self.cut_high_k_off(
4914
+ I1_f[:, : j3 + 1], dx3, dy3
4915
+ ) # Nimage, J, L, x, y
5619
4916
  data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5620
4917
  if data2 is not None:
5621
4918
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5622
4919
  if edge:
5623
- I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5624
- data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
4920
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2, -1), norm="ortho")
4921
+ data_small = self.backend.bk_ifftn(
4922
+ data_f_small, dim=(-2, -1), norm="ortho"
4923
+ )
5625
4924
  if data2 is not None:
5626
- data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
5627
- wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
4925
+ data2_small = self.backend.bk_ifftn(
4926
+ data2_f_small, dim=(-2, -1), norm="ortho"
4927
+ )
4928
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5628
4929
  _, M3, N3 = wavelet_f3.shape
5629
4930
  wavelet_f3_squared = wavelet_f3**2
5630
- edge_dx = min(4, int(2**j3*dx3*2/M))
5631
- edge_dy = min(4, int(2**j3*dy3*2/N))
4931
+ edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4932
+ edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
5632
4933
  # a normalization change due to the cutoff of frequency space
5633
4934
  if self.all_bk_type == "float32":
5634
- fft_factor = np.complex64(1 /(M3*N3) * (M3*N3/M/N)**2)
4935
+ fft_factor = np.complex64(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
5635
4936
  else:
5636
- fft_factor = np.complex128(1 /(M3*N3) * (M3*N3/M/N)**2)
5637
- for j2 in range(0,j3+1):
5638
- #I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5639
- #I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5640
- I1_f2_wf3_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3,[1,1,1,L,M3,N3])
5641
- I1_f2_wf3_2_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3_squared,[1,1,1,L,M3,N3])
4937
+ fft_factor = np.complex128(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
4938
+ for j2 in range(0, j3 + 1):
4939
+ # I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
4940
+ # I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
4941
+ I1_f2_wf3_small = self.backend.bk_reshape(
4942
+ I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
4943
+ ) * self.backend.bk_reshape(wavelet_f3, [1, 1, 1, L, M3, N3])
4944
+ I1_f2_wf3_2_small = self.backend.bk_reshape(
4945
+ I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
4946
+ ) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
5642
4947
  if edge:
5643
- I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5644
- I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
4948
+ I12_w3_small = self.backend.bk_ifftn(
4949
+ I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
4950
+ )
4951
+ I12_w3_2_small = self.backend.bk_ifftn(
4952
+ I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
4953
+ )
5645
4954
  if use_ref:
5646
- if normalization=='P11':
5647
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5648
- norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
5649
- elif normalization=='S2':
5650
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5651
- norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
4955
+ if normalization == "P11":
4956
+ norm_factor_S3 = (
4957
+ ref_S2[:, None, j3, :]
4958
+ * ref_P11[:, j2, j3, :, :] ** pseudo_coef
4959
+ ) ** 0.5
4960
+ norm_factor_S3 = self.backend.bk_complex(
4961
+ norm_factor_S3, 0 * norm_factor_S3
4962
+ )
4963
+ elif normalization == "S2":
4964
+ norm_factor_S3 = (
4965
+ ref_S2[:, None, j3, :]
4966
+ * ref_S2[:, j2, :, None] ** pseudo_coef
4967
+ ) ** 0.5
4968
+ norm_factor_S3 = self.backend.bk_complex(
4969
+ norm_factor_S3, 0 * norm_factor_S3
4970
+ )
5652
4971
  else:
5653
4972
  norm_factor_S3 = C_ONE
5654
4973
  else:
5655
- if normalization=='P11':
4974
+ if normalization == "P11":
5656
4975
  # [N_image,l2,l3,x,y]
5657
- P11_temp = self.backend.bk_reduce_mean((I1_f2_wf3_small.abs()**2),axis=(-2,-1)) * fft_factor
5658
- norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5659
- norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
5660
- elif normalization=='S2':
5661
- norm_factor_S3 = (S2[:,None,j3,None,:] * S2[:,None,j2,:,None]**pseudo_coef)**0.5
5662
- norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
4976
+ P11_temp = (
4977
+ self.backend.bk_reduce_mean(
4978
+ (I1_f2_wf3_small.abs() ** 2), axis=(-2, -1)
4979
+ )
4980
+ * fft_factor
4981
+ )
4982
+ norm_factor_S3 = (
4983
+ S2[:, None, j3, :] * P11_temp**pseudo_coef
4984
+ ) ** 0.5
4985
+ norm_factor_S3 = self.backend.bk_complex(
4986
+ norm_factor_S3, 0 * norm_factor_S3
4987
+ )
4988
+ elif normalization == "S2":
4989
+ norm_factor_S3 = (
4990
+ S2[:, None, j3, None, :]
4991
+ * S2[:, None, j2, :, None] ** pseudo_coef
4992
+ ) ** 0.5
4993
+ norm_factor_S3 = self.backend.bk_complex(
4994
+ norm_factor_S3, 0 * norm_factor_S3
4995
+ )
5663
4996
  else:
5664
4997
  norm_factor_S3 = C_ONE
5665
-
5666
4998
 
5667
4999
  if not edge:
5668
- S3.append(self.backend.bk_reduce_mean(
5669
- self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5670
- ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5000
+ S3.append(
5001
+ self.backend.bk_reduce_mean(
5002
+ self.backend.bk_reshape(
5003
+ data_f_small, [N_image, 1, 1, 1, M3, N3]
5004
+ )
5005
+ * self.backend.bk_conjugate(I1_f2_wf3_small),
5006
+ axis=(-2, -1),
5007
+ )
5008
+ * fft_factor
5009
+ / norm_factor_S3
5010
+ )
5671
5011
  if get_variance:
5672
- S3_sigma.append(self.backend.bk_reduce_std(
5673
- self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5674
- ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5012
+ S3_sigma.append(
5013
+ self.backend.bk_reduce_std(
5014
+ self.backend.bk_reshape(
5015
+ data_f_small, [N_image, 1, 1, 1, M3, N3]
5016
+ )
5017
+ * self.backend.bk_conjugate(I1_f2_wf3_small),
5018
+ axis=(-2, -1),
5019
+ )
5020
+ * fft_factor
5021
+ / norm_factor_S3
5022
+ )
5675
5023
  else:
5676
- S3.append(self.backend.bk_reduce_mean(
5677
- (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5678
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5024
+ S3.append(
5025
+ self.backend.bk_reduce_mean(
5026
+ (
5027
+ self.backend.bk_reshape(
5028
+ data_small, [N_image, 1, 1, 1, M3, N3]
5029
+ )
5030
+ * self.backend.bk_conjugate(I12_w3_small)
5031
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5032
+ axis=(-2, -1),
5033
+ )
5034
+ * fft_factor
5035
+ / norm_factor_S3
5036
+ )
5679
5037
  if get_variance:
5680
- S3_sigma.apend(self.backend.bk_reduce_std(
5681
- (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5682
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5038
+ S3_sigma.append(
5039
+ self.backend.bk_reduce_std(
5040
+ (
5041
+ self.backend.bk_reshape(
5042
+ data_small, [N_image, 1, 1, 1, M3, N3]
5043
+ )
5044
+ * self.backend.bk_conjugate(I12_w3_small)
5045
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5046
+ axis=(-2, -1),
5047
+ )
5048
+ * fft_factor
5049
+ / norm_factor_S3
5050
+ )
5683
5051
  if data2 is not None:
5684
5052
  if not edge:
5685
- S3p.append(self.backend.bk_reduce_mean(
5686
- (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5687
- ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5688
-
5053
+ S3p.append(
5054
+ self.backend.bk_reduce_mean(
5055
+ (
5056
+ self.backend.bk_reshape(
5057
+ data2_f_small, [N_image2, 1, 1, 1, M3, N3]
5058
+ )
5059
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
5060
+ ),
5061
+ axis=(-2, -1),
5062
+ )
5063
+ * fft_factor
5064
+ / norm_factor_S3
5065
+ )
5066
+
5689
5067
  if get_variance:
5690
- S3p_sigma.append(self.backend.bk_reduce_std(
5691
- (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5692
- ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5068
+ S3p_sigma.append(
5069
+ self.backend.bk_reduce_std(
5070
+ (
5071
+ self.backend.bk_reshape(
5072
+ data2_f_small, [N_image2, 1, 1, 1, M3, N3]
5073
+ )
5074
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
5075
+ ),
5076
+ axis=(-2, -1),
5077
+ )
5078
+ * fft_factor
5079
+ / norm_factor_S3
5080
+ )
5693
5081
  else:
5694
-
5695
- S3p.append(self.backend.bk_reduce_mean(
5696
- (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5697
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5082
+
5083
+ S3p.append(
5084
+ self.backend.bk_reduce_mean(
5085
+ (
5086
+ self.backend.bk_reshape(
5087
+ data2_small, [N_image2, 1, 1, 1, M3, N3]
5088
+ )
5089
+ * self.backend.bk_conjugate(I12_w3_small)
5090
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5091
+ axis=(-2, -1),
5092
+ )
5093
+ * fft_factor
5094
+ / norm_factor_S3
5095
+ )
5698
5096
  if get_variance:
5699
- S3p_sigma.append(self.backend.bk_reduce_std(
5700
- (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5701
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5702
-
5097
+ S3p_sigma.append(
5098
+ self.backend.bk_reduce_std(
5099
+ (
5100
+ self.backend.bk_reshape(
5101
+ data2_small, [N_image2, 1, 1, 1, M3, N3]
5102
+ )
5103
+ * self.backend.bk_conjugate(I12_w3_small)
5104
+ )[
5105
+ ...,
5106
+ edge_dx : M3 - edge_dx,
5107
+ edge_dy : N3 - edge_dy,
5108
+ ],
5109
+ axis=(-2, -1),
5110
+ )
5111
+ * fft_factor
5112
+ / norm_factor_S3
5113
+ )
5114
+
5703
5115
  if j2 <= j3:
5704
- if normalization=='S2':
5705
- if use_ref:
5706
- P = 1/((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5707
- else:
5708
- P = 1/(((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef)))
5709
- P=self.backend.bk_complex(P,0.0*P)
5116
+ if normalization == "S2":
5117
+ if use_ref:
5118
+ P = 1 / (
5119
+ (
5120
+ ref_S2[:, j3 : j3 + 1, :, None, None]
5121
+ * ref_S2[:, j2 : j2 + 1, None, :, None]
5122
+ )
5123
+ ** (0.5 * pseudo_coef)
5124
+ )
5125
+ else:
5126
+ P = 1 / (
5127
+ (
5128
+ S2[:, j3 : j3 + 1, :, None, None]
5129
+ * S2[:, j2 : j2 + 1, None, :, None]
5130
+ )
5131
+ ** (0.5 * pseudo_coef)
5132
+ )
5133
+ P = self.backend.bk_complex(P, 0.0 * P)
5710
5134
  else:
5711
- P=C_ONE
5712
-
5713
- for j1 in range(0, j2+1):
5714
- if not edge:
5715
- if not if_large_batch:
5716
- # [N_image,l1,l2,l3,x,y]
5717
- S4.append(self.backend.bk_reduce_mean(
5718
- (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5719
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5720
- ),axis=(-2,-1)) * fft_factor*P)
5721
- if get_variance:
5722
- S4_sigma.append(self.backend.bk_reduce_std(
5723
- (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5724
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5725
- ),axis=(-2,-1)) * fft_factor*P)
5726
- else:
5727
- for l1 in range(L):
5728
- # [N_image,l2,l3,x,y]
5729
- S4.append(self.backend.bk_reduce_mean(
5730
- (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5731
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5732
- ),axis=(-2,-1)) * fft_factor*P)
5733
- if get_variance:
5734
- S4_sigma.append(self.backend.bk_reduce_std(
5735
- (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5736
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5737
- ),axis=(-2,-1)) * fft_factor*P)
5135
+ P = C_ONE
5136
+
5137
+ for j1 in range(0, j2 + 1):
5138
+ if not edge:
5139
+ if not if_large_batch:
5140
+ # [N_image,l1,l2,l3,x,y]
5141
+ S4.append(
5142
+ self.backend.bk_reduce_mean(
5143
+ (
5144
+ self.backend.bk_reshape(
5145
+ I1_f_small[:, j1],
5146
+ [N_image, 1, L, 1, 1, M3, N3],
5147
+ )
5148
+ * self.backend.bk_conjugate(
5149
+ self.backend.bk_reshape(
5150
+ I1_f2_wf3_2_small,
5151
+ [N_image, 1, 1, L, L, M3, N3],
5152
+ )
5153
+ )
5154
+ ),
5155
+ axis=(-2, -1),
5156
+ )
5157
+ * fft_factor
5158
+ * P
5159
+ )
5160
+ if get_variance:
5161
+ S4_sigma.append(
5162
+ self.backend.bk_reduce_std(
5163
+ (
5164
+ self.backend.bk_reshape(
5165
+ I1_f_small[:, j1],
5166
+ [N_image, 1, L, 1, 1, M3, N3],
5167
+ )
5168
+ * self.backend.bk_conjugate(
5169
+ self.backend.bk_reshape(
5170
+ I1_f2_wf3_2_small,
5171
+ [N_image, 1, 1, L, L, M3, N3],
5172
+ )
5173
+ )
5174
+ ),
5175
+ axis=(-2, -1),
5176
+ )
5177
+ * fft_factor
5178
+ * P
5179
+ )
5738
5180
  else:
5739
- if not if_large_batch:
5740
- # [N_image,l1,l2,l3,x,y]
5741
- S4.append(self.backend.bk_reduce_mean(
5742
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5743
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5181
+ for l1 in range(L):
5182
+ # [N_image,l2,l3,x,y]
5183
+ S4.append(
5184
+ self.backend.bk_reduce_mean(
5185
+ (
5186
+ self.backend.bk_reshape(
5187
+ I1_f_small[:, j1, l1],
5188
+ [N_image, 1, 1, 1, M3, N3],
5189
+ )
5190
+ * self.backend.bk_conjugate(
5191
+ self.backend.bk_reshape(
5192
+ I1_f2_wf3_2_small,
5193
+ [N_image, 1, L, L, M3, N3],
5194
+ )
5195
+ )
5196
+ ),
5197
+ axis=(-2, -1),
5744
5198
  )
5745
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5199
+ * fft_factor
5200
+ * P
5201
+ )
5746
5202
  if get_variance:
5747
- S4_sigma.append(self.backend.bk_reduce_std(
5748
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5749
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5203
+ S4_sigma.append(
5204
+ self.backend.bk_reduce_std(
5205
+ (
5206
+ self.backend.bk_reshape(
5207
+ I1_f_small[:, j1, l1],
5208
+ [N_image, 1, 1, 1, M3, N3],
5209
+ )
5210
+ * self.backend.bk_conjugate(
5211
+ self.backend.bk_reshape(
5212
+ I1_f2_wf3_2_small,
5213
+ [N_image, 1, L, L, M3, N3],
5214
+ )
5215
+ )
5216
+ ),
5217
+ axis=(-2, -1),
5750
5218
  )
5751
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5752
- else:
5753
- for l1 in range(L):
5754
- # [N_image,l2,l3,x,y]
5755
- S4.append(self.backend.bk_reduce_mean(
5756
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5757
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5219
+ * fft_factor
5220
+ * P
5221
+ )
5222
+ else:
5223
+ if not if_large_batch:
5224
+ # [N_image,l1,l2,l3,x,y]
5225
+ S4.append(
5226
+ self.backend.bk_reduce_mean(
5227
+ (
5228
+ self.backend.bk_reshape(
5229
+ I1_small[:, j1],
5230
+ [N_image, 1, L, 1, 1, M3, N3],
5758
5231
  )
5759
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5760
- if get_variance:
5761
- S4_sigma.append(self.backend.bk_reduce_std(
5762
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5763
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5232
+ * self.backend.bk_conjugate(
5233
+ self.backend.bk_reshape(
5234
+ I12_w3_2_small,
5235
+ [N_image, 1, 1, L, L, M3, N3],
5764
5236
  )
5765
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5766
-
5767
- S3=self.backend.bk_concat(S3,axis=1)
5768
- S4=self.backend.bk_concat(S4,axis=1)
5769
-
5237
+ )
5238
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5239
+ axis=(-2, -1),
5240
+ )
5241
+ * fft_factor
5242
+ * P
5243
+ )
5244
+ if get_variance:
5245
+ S4_sigma.append(
5246
+ self.backend.bk_reduce_std(
5247
+ (
5248
+ self.backend.bk_reshape(
5249
+ I1_small[:, j1],
5250
+ [N_image, 1, L, 1, 1, M3, N3],
5251
+ )
5252
+ * self.backend.bk_conjugate(
5253
+ self.backend.bk_reshape(
5254
+ I12_w3_2_small,
5255
+ [N_image, 1, 1, L, L, M3, N3],
5256
+ )
5257
+ )
5258
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5259
+ axis=(-2, -1),
5260
+ )
5261
+ * fft_factor
5262
+ * P
5263
+ )
5264
+ else:
5265
+ for l1 in range(L):
5266
+ # [N_image,l2,l3,x,y]
5267
+ S4.append(
5268
+ self.backend.bk_reduce_mean(
5269
+ (
5270
+ self.backend.bk_reshape(
5271
+ I1_small[:, j1],
5272
+ [N_image, 1, 1, 1, M3, N3],
5273
+ )
5274
+ * self.backend.bk_conjugate(
5275
+ self.backend.bk_reshape(
5276
+ I12_w3_2_small,
5277
+ [N_image, 1, L, L, M3, N3],
5278
+ )
5279
+ )
5280
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5281
+ axis=(-2, -1),
5282
+ )
5283
+ * fft_factor
5284
+ * P
5285
+ )
5286
+ if get_variance:
5287
+ S4_sigma.append(
5288
+ self.backend.bk_reduce_std(
5289
+ (
5290
+ self.backend.bk_reshape(
5291
+ I1_small[:, j1],
5292
+ [N_image, 1, 1, 1, M3, N3],
5293
+ )
5294
+ * self.backend.bk_conjugate(
5295
+ self.backend.bk_reshape(
5296
+ I12_w3_2_small,
5297
+ [N_image, 1, L, L, M3, N3],
5298
+ )
5299
+ )
5300
+ )[
5301
+ ...,
5302
+ edge_dx:-edge_dx,
5303
+ edge_dy:-edge_dy,
5304
+ ],
5305
+ axis=(-2, -1),
5306
+ )
5307
+ * fft_factor
5308
+ * P
5309
+ )
5310
+
5311
+ S3 = self.backend.bk_concat(S3, axis=1)
5312
+ S4 = self.backend.bk_concat(S4, axis=1)
5313
+
5770
5314
  if get_variance:
5771
- S3_sigma=self.backend.bk_concat(S3_sigma,axis=1)
5772
- S4_sigma=self.backend.bk_concat(S4_sigma,axis=1)
5773
-
5315
+ S3_sigma = self.backend.bk_concat(S3_sigma, axis=1)
5316
+ S4_sigma = self.backend.bk_concat(S4_sigma, axis=1)
5317
+
5774
5318
  if data2 is not None:
5775
- S3p=self.backend.bk_concat(S3p,axis=1)
5319
+ S3p = self.backend.bk_concat(S3p, axis=1)
5776
5320
  if get_variance:
5777
- S3p_sigma=self.backend.bk_concat(S3p_sigma,axis=1)
5778
-
5321
+ S3p_sigma = self.backend.bk_concat(S3p_sigma, axis=1)
5322
+
5779
5323
  # average over l1 to obtain simple isotropic statistics
5780
5324
  if iso_ang:
5781
- S2_iso = self.backend.bk_reduce_mean(S2,axis=(-1))
5782
- S1_iso = self.backend.bk_reduce_mean(S1,axis=(-1))
5325
+ S2_iso = self.backend.bk_reduce_mean(S2, axis=(-1))
5326
+ S1_iso = self.backend.bk_reduce_mean(S1, axis=(-1))
5783
5327
  for l1 in range(L):
5784
5328
  for l2 in range(L):
5785
- S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
5329
+ S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
5786
5330
  if data2 is not None:
5787
- S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
5331
+ S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
5788
5332
  for l3 in range(L):
5789
- S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5790
- S3_iso /= L; S4_iso /= L
5333
+ S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[..., l1, l2, l3]
5334
+ S3_iso /= L
5335
+ S4_iso /= L
5791
5336
  if data2 is not None:
5792
5337
  S3p_iso /= L
5793
-
5338
+
5794
5339
  if get_variance:
5795
- S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma,axis=(-1))
5796
- S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma,axis=(-1))
5340
+ S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma, axis=(-1))
5341
+ S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma, axis=(-1))
5797
5342
  for l1 in range(L):
5798
5343
  for l2 in range(L):
5799
- S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
5344
+ S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
5800
5345
  if data2 is not None:
5801
- S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
5346
+ S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[..., l1, l2]
5802
5347
  for l3 in range(L):
5803
- S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5804
- S3_sigma_iso /= L; S4_sigma_iso /= L
5348
+ S4_sigma_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4_sigma[
5349
+ ..., l1, l2, l3
5350
+ ]
5351
+ S3_sigma_iso /= L
5352
+ S4_sigma_iso /= L
5805
5353
  if data2 is not None:
5806
5354
  S3p_sigma_iso /= L
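The loops above fold coefficients indexed by absolute orientations onto the relative orientation (l2 - l1) mod L, which removes the dependence on a global rotation; a self-contained numpy sketch with illustrative shapes:

    import numpy as np

    # Average S3[..., l1, l2] over pairs sharing the same relative angle.
    L = 4
    S3 = np.random.randn(2, 3, L, L)            # e.g. [N_image, j, l1, l2]
    S3_iso = np.zeros(S3.shape[:-2] + (L,))
    for l1 in range(L):
        for l2 in range(L):
            S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
    S3_iso /= L                                  # isotropic statistics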
5807
-
5355
+
5808
5356
  if data2 is None:
5809
- mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data,axis=(-2,-1)),[N_image,1])
5810
- std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data,axis=(-2,-1)),[N_image,1])
5357
+ mean_data = self.backend.bk_reshape(
5358
+ self.backend.bk_reduce_mean(data, axis=(-2, -1)), [N_image, 1]
5359
+ )
5360
+ std_data = self.backend.bk_reshape(
5361
+ self.backend.bk_reduce_std(data, axis=(-2, -1)), [N_image, 1]
5362
+ )
5811
5363
  else:
5812
- mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data*data2,axis=(-2,-1)),[N_image,1])
5813
- std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data*data2,axis=(-2,-1)),[N_image,1])
5814
-
5364
+ mean_data = self.backend.bk_reshape(
5365
+ self.backend.bk_reduce_mean(data * data2, axis=(-2, -1)), [N_image, 1]
5366
+ )
5367
+ std_data = self.backend.bk_reshape(
5368
+ self.backend.bk_reduce_std(data * data2, axis=(-2, -1)), [N_image, 1]
5369
+ )
5370
+
5815
5371
  if get_variance:
5816
- ref_sigma={}
5372
+ ref_sigma = {}
5817
5373
  if iso_ang:
5818
- ref_sigma['std_data']=std_data
5819
- ref_sigma['S1_sigma']=S1_sigma_iso
5820
- ref_sigma['S2_sigma']=S2_sigma_iso
5821
- ref_sigma['S3_sigma']=S3_sigma_iso
5822
- ref_sigma['S4_sigma']=S4_sigma_iso
5374
+ ref_sigma["std_data"] = std_data
5375
+ ref_sigma["S1_sigma"] = S1_sigma_iso
5376
+ ref_sigma["S2_sigma"] = S2_sigma_iso
5377
+ ref_sigma["S3_sigma"] = S3_sigma_iso
5378
+ ref_sigma["S4_sigma"] = S4_sigma_iso
5823
5379
  if data2 is not None:
5824
- ref_sigma['S3p_sigma']=S3p_sigma_iso
5380
+ ref_sigma["S3p_sigma"] = S3p_sigma_iso
5825
5381
  else:
5826
- ref_sigma['std_data']=std_data
5827
- ref_sigma['S1_sigma']=S1_sigma
5828
- ref_sigma['S2_sigma']=S2_sigma
5829
- ref_sigma['S3_sigma']=S3_sigma
5830
- ref_sigma['S4_sigma']=S4_sigma
5382
+ ref_sigma["std_data"] = std_data
5383
+ ref_sigma["S1_sigma"] = S1_sigma
5384
+ ref_sigma["S2_sigma"] = S2_sigma
5385
+ ref_sigma["S3_sigma"] = S3_sigma
5386
+ ref_sigma["S4_sigma"] = S4_sigma
5831
5387
  if data2 is not None:
5832
- ref_sigma['S3p_sigma']=S3_sigma
5833
-
5388
+ ref_sigma["S3p_sigma"] = S3_sigma
5389
+
5834
5390
  if data2 is None:
5835
5391
  if iso_ang:
5836
5392
  if ref_sigma is not None:
5837
- for_synthesis = self.backend.bk_concat((
5838
- mean_data/ref_sigma['std_data'],
5839
- std_data/ref_sigma['std_data'],
5840
- self.backend.bk_reshape(self.backend.bk_log(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5841
- self.backend.bk_reshape(self.backend.bk_log(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5842
- self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5843
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5844
- self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5845
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5846
- ),axis=-1)
5393
+ for_synthesis = self.backend.bk_concat(
5394
+ (
5395
+ mean_data / ref_sigma["std_data"],
5396
+ std_data / ref_sigma["std_data"],
5397
+ self.backend.bk_reshape(
5398
+ self.backend.bk_log(S2_iso / ref_sigma["S2_sigma"]),
5399
+ [N_image, -1],
5400
+ ),
5401
+ self.backend.bk_reshape(
5402
+ self.backend.bk_log(S1_iso / ref_sigma["S1_sigma"]),
5403
+ [N_image, -1],
5404
+ ),
5405
+ self.backend.bk_reshape(
5406
+ self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
5407
+ [N_image, -1],
5408
+ ),
5409
+ self.backend.bk_reshape(
5410
+ self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
5411
+ [N_image, -1],
5412
+ ),
5413
+ self.backend.bk_reshape(
5414
+ self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
5415
+ [N_image, -1],
5416
+ ),
5417
+ self.backend.bk_reshape(
5418
+ self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
5419
+ [N_image, -1],
5420
+ ),
5421
+ ),
5422
+ axis=-1,
5423
+ )
5847
5424
  else:
5848
- for_synthesis = self.backend.bk_concat((
5849
- mean_data/std_data,
5850
- std_data,
5851
- self.backend.bk_reshape(self.backend.bk_log(S2_iso),[N_image, -1]),
5852
- self.backend.bk_reshape(self.backend.bk_log(S1_iso),[N_image, -1]),
5853
- self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5854
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5855
- self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5856
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5857
- ),axis=-1)
5425
+ for_synthesis = self.backend.bk_concat(
5426
+ (
5427
+ mean_data / std_data,
5428
+ std_data,
5429
+ self.backend.bk_reshape(
5430
+ self.backend.bk_log(S2_iso), [N_image, -1]
5431
+ ),
5432
+ self.backend.bk_reshape(
5433
+ self.backend.bk_log(S1_iso), [N_image, -1]
5434
+ ),
5435
+ self.backend.bk_reshape(
5436
+ self.backend.bk_real(S3_iso), [N_image, -1]
5437
+ ),
5438
+ self.backend.bk_reshape(
5439
+ self.backend.bk_imag(S3_iso), [N_image, -1]
5440
+ ),
5441
+ self.backend.bk_reshape(
5442
+ self.backend.bk_real(S4_iso), [N_image, -1]
5443
+ ),
5444
+ self.backend.bk_reshape(
5445
+ self.backend.bk_imag(S4_iso), [N_image, -1]
5446
+ ),
5447
+ ),
5448
+ axis=-1,
5449
+ )
5858
5450
  else:
5859
5451
  if ref_sigma is not None:
5860
- for_synthesis = self.backend.bk_concat((
5861
- mean_data/ref_sigma['std_data'],
5862
- std_data/ref_sigma['std_data'],
5863
- self.backend.bk_reshape(self.backend.bk_log(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5864
- self.backend.bk_reshape(self.backend.bk_log(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5865
- self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5866
- self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5867
- self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5868
- self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5869
- ),axis=-1)
5452
+ for_synthesis = self.backend.bk_concat(
5453
+ (
5454
+ mean_data / ref_sigma["std_data"],
5455
+ std_data / ref_sigma["std_data"],
5456
+ self.backend.bk_reshape(
5457
+ self.backend.bk_log(S2 / ref_sigma["S2_sigma"]),
5458
+ [N_image, -1],
5459
+ ),
5460
+ self.backend.bk_reshape(
5461
+ self.backend.bk_log(S1 / ref_sigma["S1_sigma"]),
5462
+ [N_image, -1],
5463
+ ),
5464
+ self.backend.bk_reshape(
5465
+ self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
5466
+ [N_image, -1],
5467
+ ),
5468
+ self.backend.bk_reshape(
5469
+ self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
5470
+ [N_image, -1],
5471
+ ),
5472
+ self.backend.bk_reshape(
5473
+ self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
5474
+ [N_image, -1],
5475
+ ),
5476
+ self.backend.bk_reshape(
5477
+ self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
5478
+ [N_image, -1],
5479
+ ),
5480
+ ),
5481
+ axis=-1,
5482
+ )
5870
5483
  else:
5871
- for_synthesis = self.backend.bk_concat((
5872
- mean_data/std_data,
5484
+ for_synthesis = self.backend.bk_concat(
5485
+ (
5486
+ mean_data / std_data,
5873
5487
  std_data,
5874
- self.backend.bk_reshape(self.backend.bk_log(S2),[N_image, -1]),
5875
- self.backend.bk_reshape(self.backend.bk_log(S1),[N_image, -1]),
5876
- self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5877
- self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5878
- self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5879
- self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5880
- ),axis=-1)
5488
+ self.backend.bk_reshape(
5489
+ self.backend.bk_log(S2), [N_image, -1]
5490
+ ),
5491
+ self.backend.bk_reshape(
5492
+ self.backend.bk_log(S1), [N_image, -1]
5493
+ ),
5494
+ self.backend.bk_reshape(
5495
+ self.backend.bk_real(S3), [N_image, -1]
5496
+ ),
5497
+ self.backend.bk_reshape(
5498
+ self.backend.bk_imag(S3), [N_image, -1]
5499
+ ),
5500
+ self.backend.bk_reshape(
5501
+ self.backend.bk_real(S4), [N_image, -1]
5502
+ ),
5503
+ self.backend.bk_reshape(
5504
+ self.backend.bk_imag(S4), [N_image, -1]
5505
+ ),
5506
+ ),
5507
+ axis=-1,
5508
+ )
5881
5509
  else:
5882
5510
  if iso_ang:
5883
5511
  if ref_sigma is not None:
5884
- for_synthesis = self.backend.backend.cat((
5885
- mean_data/ref_sigma['std_data'],
5886
- std_data/ref_sigma['std_data'],
5887
- self.backend.bk_reshape(self.backend.bk_real(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5888
- self.backend.bk_reshape(self.backend.bk_real(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5889
- self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5890
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5891
- self.backend.bk_reshape(self.backend.bk_real(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5892
- self.backend.bk_reshape(self.backend.bk_imag(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5893
- self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5894
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5895
- ),axis=-1)
5512
+ for_synthesis = self.backend.backend.cat(
5513
+ (
5514
+ mean_data / ref_sigma["std_data"],
5515
+ std_data / ref_sigma["std_data"],
5516
+ self.backend.bk_reshape(
5517
+ self.backend.bk_real(S2_iso / ref_sigma["S2_sigma"]),
5518
+ [N_image, -1],
5519
+ ),
5520
+ self.backend.bk_reshape(
5521
+ self.backend.bk_real(S1_iso / ref_sigma["S1_sigma"]),
5522
+ [N_image, -1],
5523
+ ),
5524
+ self.backend.bk_reshape(
5525
+ self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
5526
+ [N_image, -1],
5527
+ ),
5528
+ self.backend.bk_reshape(
5529
+ self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
5530
+ [N_image, -1],
5531
+ ),
5532
+ self.backend.bk_reshape(
5533
+ self.backend.bk_real(S3p_iso / ref_sigma["S3p_sigma"]),
5534
+ [N_image, -1],
5535
+ ),
5536
+ self.backend.bk_reshape(
5537
+ self.backend.bk_imag(S3p_iso / ref_sigma["S3p_sigma"]),
5538
+ [N_image, -1],
5539
+ ),
5540
+ self.backend.bk_reshape(
5541
+ self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
5542
+ [N_image, -1],
5543
+ ),
5544
+ self.backend.bk_reshape(
5545
+ self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
5546
+ [N_image, -1],
5547
+ ),
5548
+ ),
5549
+ axis=-1,
5550
+ )
5896
5551
  else:
5897
- for_synthesis = self.backend.backend.cat((
5898
- mean_data/std_data,
5899
- std_data,
5900
- self.backend.bk_reshape(self.backend.bk_real(S2_iso),[N_image, -1]),
5901
- self.backend.bk_reshape(self.backend.bk_real(S1_iso),[N_image, -1]),
5902
- self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5903
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5904
- self.backend.bk_reshape(self.backend.bk_real(S3p_iso),[N_image, -1]),
5905
- self.backend.bk_reshape(self.backend.bk_imag(S3p_iso),[N_image, -1]),
5906
- self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5907
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5908
- ),axis=-1)
5552
+ for_synthesis = self.backend.backend.cat(
5553
+ (
5554
+ mean_data / std_data,
5555
+ std_data,
5556
+ self.backend.bk_reshape(
5557
+ self.backend.bk_real(S2_iso), [N_image, -1]
5558
+ ),
5559
+ self.backend.bk_reshape(
5560
+ self.backend.bk_real(S1_iso), [N_image, -1]
5561
+ ),
5562
+ self.backend.bk_reshape(
5563
+ self.backend.bk_real(S3_iso), [N_image, -1]
5564
+ ),
5565
+ self.backend.bk_reshape(
5566
+ self.backend.bk_imag(S3_iso), [N_image, -1]
5567
+ ),
5568
+ self.backend.bk_reshape(
5569
+ self.backend.bk_real(S3p_iso), [N_image, -1]
5570
+ ),
5571
+ self.backend.bk_reshape(
5572
+ self.backend.bk_imag(S3p_iso), [N_image, -1]
5573
+ ),
5574
+ self.backend.bk_reshape(
5575
+ self.backend.bk_real(S4_iso), [N_image, -1]
5576
+ ),
5577
+ self.backend.bk_reshape(
5578
+ self.backend.bk_imag(S4_iso), [N_image, -1]
5579
+ ),
5580
+ ),
5581
+ axis=-1,
5582
+ )
5909
5583
  else:
5910
5584
  if ref_sigma is not None:
5911
- for_synthesis = self.backend.backend.cat((
5912
- mean_data/ref_sigma['std_data'],
5913
- std_data/ref_sigma['std_data'],
5914
- self.backend.bk_reshape(self.backend.bk_real(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5915
- self.backend.bk_reshape(self.backend.bk_real(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5916
- self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5917
- self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5918
- self.backend.bk_reshape(self.backend.bk_real(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5919
- self.backend.bk_reshape(self.backend.bk_imag(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5920
- self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5921
- self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5922
- ),axis=-1)
5585
+ for_synthesis = self.backend.backend.cat(
5586
+ (
5587
+ mean_data / ref_sigma["std_data"],
5588
+ std_data / ref_sigma["std_data"],
5589
+ self.backend.bk_reshape(
5590
+ self.backend.bk_real(S2 / ref_sigma["S2_sigma"]),
5591
+ [N_image, -1],
5592
+ ),
5593
+ self.backend.bk_reshape(
5594
+ self.backend.bk_real(S1 / ref_sigma["S1_sigma"]),
5595
+ [N_image, -1],
5596
+ ),
5597
+ self.backend.bk_reshape(
5598
+ self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
5599
+ [N_image, -1],
5600
+ ),
5601
+ self.backend.bk_reshape(
5602
+ self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
5603
+ [N_image, -1],
5604
+ ),
5605
+ self.backend.bk_reshape(
5606
+ self.backend.bk_real(S3p / ref_sigma["S3p_sigma"]),
5607
+ [N_image, -1],
5608
+ ),
5609
+ self.backend.bk_reshape(
5610
+ self.backend.bk_imag(S3p / ref_sigma["S3p_sigma"]),
5611
+ [N_image, -1],
5612
+ ),
5613
+ self.backend.bk_reshape(
5614
+ self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
5615
+ [N_image, -1],
5616
+ ),
5617
+ self.backend.bk_reshape(
5618
+ self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
5619
+ [N_image, -1],
5620
+ ),
5621
+ ),
5622
+ axis=-1,
5623
+ )
5923
5624
  else:
5924
- for_synthesis = self.backend.bk_concat((
5925
- mean_data/std_data,
5625
+ for_synthesis = self.backend.bk_concat(
5626
+ (
5627
+ mean_data / std_data,
5926
5628
  std_data,
5927
- self.backend.bk_reshape(self.backend.bk_real(S2),[N_image, -1]),
5928
- self.backend.bk_reshape(self.backend.bk_real(S1),[N_image, -1]),
5929
- self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5930
- self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5931
- self.backend.bk_reshape(self.backend.bk_real(S3p),[N_image, -1]),
5932
- self.backend.bk_reshape(self.backend.bk_imag(S3p),[N_image, -1]),
5933
- self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5934
- self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5935
- ),axis=-1)
5936
-
5937
- if not use_ref:
5938
- self.ref_scattering_cov_S2=S2
5939
-
5629
+ self.backend.bk_reshape(
5630
+ self.backend.bk_real(S2), [N_image, -1]
5631
+ ),
5632
+ self.backend.bk_reshape(
5633
+ self.backend.bk_real(S1), [N_image, -1]
5634
+ ),
5635
+ self.backend.bk_reshape(
5636
+ self.backend.bk_real(S3), [N_image, -1]
5637
+ ),
5638
+ self.backend.bk_reshape(
5639
+ self.backend.bk_imag(S3), [N_image, -1]
5640
+ ),
5641
+ self.backend.bk_reshape(
5642
+ self.backend.bk_real(S3p), [N_image, -1]
5643
+ ),
5644
+ self.backend.bk_reshape(
5645
+ self.backend.bk_imag(S3p), [N_image, -1]
5646
+ ),
5647
+ self.backend.bk_reshape(
5648
+ self.backend.bk_real(S4), [N_image, -1]
5649
+ ),
5650
+ self.backend.bk_reshape(
5651
+ self.backend.bk_imag(S4), [N_image, -1]
5652
+ ),
5653
+ ),
5654
+ axis=-1,
5655
+ )
5656
+
5657
+ if not use_ref:
5658
+ self.ref_scattering_cov_S2 = S2
5659
+
5940
5660
  if get_variance:
5941
- return for_synthesis,ref_sigma
5942
-
5661
+ return for_synthesis, ref_sigma
5662
+
5943
5663
  return for_synthesis
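In every branch above, for_synthesis concatenates the mean, the standard deviation, and the flattened S1, S2, S3 (real and imaginary parts), optionally S3p, and S4 statistics along the last axis; a schematic numpy sketch of that layout, with block sizes that are purely illustrative:

    import numpy as np

    # One coefficient vector per image, built from per-statistic blocks.
    N_image = 2
    blocks = [
        np.random.randn(N_image, 1),    # mean_data (normalized)
        np.random.randn(N_image, 1),    # std_data
        np.random.randn(N_image, 12),   # S2, flattened
        np.random.randn(N_image, 12),   # S1, flattened
        np.random.randn(N_image, 40),   # Re(S3)
        np.random.randn(N_image, 40),   # Im(S3)
        np.random.randn(N_image, 160),  # Re(S4)
        np.random.randn(N_image, 160),  # Im(S4)
    ]
    for_synthesis = np.concatenate(blocks, axis=-1)  # (N_image, n_coeff)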
5944
-
5945
-
5946
- def to_gaussian(self,x):
5947
- from scipy.stats import norm
5664
+
5665
+ def to_gaussian(self, x):
5948
5666
  from scipy.interpolate import interp1d
5667
+ from scipy.stats import norm
5949
5668
 
5950
- idx=np.argsort(x.flatten())
5669
+ idx = np.argsort(x.flatten())
5951
5670
  p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
5952
- im_target=x.flatten()
5671
+ im_target = x.flatten()
5953
5672
  im_target[idx] = norm.ppf(p)
5954
-
5673
+
5955
5674
  # Cubic interpolation
5956
- self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind='cubic')
5957
- self.val_min=im_target[idx[0]]
5958
- self.val_max=im_target[idx[-1]]
5675
+ self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
5676
+ self.val_min = im_target[idx[0]]
5677
+ self.val_max = im_target[idx[-1]]
5959
5678
  return im_target.reshape(x.shape)
5960
5679
 
5680
+ def from_gaussian(self, x):
5961
5681
 
5962
- def from_gaussian(self,x):
5963
-
5964
- x=self.backend.bk_clip_by_value(x,self.val_min,self.val_max)
5682
+ x = self.backend.bk_clip_by_value(x, self.val_min, self.val_max)
5965
5683
  return self.f_gaussian(self.backend.to_numpy(x))
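Together, to_gaussian and from_gaussian implement a rank-based Gaussianization with a cubic interpolant kept for the inverse map; a standalone numpy/scipy sketch of the same round trip, with illustrative variable names:

    import numpy as np
    from scipy.interpolate import interp1d
    from scipy.stats import norm

    x = np.random.rand(256) ** 2          # any non-Gaussian sample
    idx = np.argsort(x)
    p = (np.arange(1, x.size + 1) - 0.5) / x.size
    g = np.empty_like(x)
    g[idx] = norm.ppf(p)                  # Gaussianized values
    inverse = interp1d(g[idx], x[idx], kind="cubic")
    x_back = inverse(np.clip(g, g[idx[0]], g[idx[-1]]))
    assert np.allclose(x_back, x, atol=1e-6)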
5966
-
5684
+
5967
5685
  def square(self, x):
5968
5686
  if isinstance(x, scat_cov):
5969
5687
  if x.S1 is None:
@@ -6315,89 +6033,133 @@ class funct(FOC.FoCUS):
6315
6033
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
6316
6034
  )
6317
6035
 
6318
- def synthesis(self,
6319
- image_target,
6320
- nstep=4,
6321
- seed=1234,
6322
- Jmax=None,
6323
- edge=False,
6324
- to_gaussian=True,
6325
- use_variance=False,
6326
- synthesised_N=1,
6327
- input_image=None,
6328
- iso_ang=False,
6329
- EVAL_FREQUENCY=100,
6330
- NUM_EPOCHS = 300):
6331
-
6332
- import foscat.Synthesis as synthe
6036
+ def synthesis(
6037
+ self,
6038
+ image_target,
6039
+ nstep=4,
6040
+ seed=1234,
6041
+ Jmax=None,
6042
+ edge=False,
6043
+ to_gaussian=True,
6044
+ use_variance=False,
6045
+ synthesised_N=1,
6046
+ input_image=None,
6047
+ grd_mask=None,
6048
+ iso_ang=False,
6049
+ EVAL_FREQUENCY=100,
6050
+ NUM_EPOCHS=300,
6051
+ ):
6052
+
6333
6053
  import time
6334
6054
 
6335
- def The_loss(u,scat_operator,args):
6336
- ref = args[0]
6055
+ import foscat.Synthesis as synthe
6056
+
6057
+ def The_loss(u, scat_operator, args):
6058
+ ref = args[0]
6337
6059
  sref = args[1]
6338
- use_v= args[2]
6339
-
6060
+ use_v = args[2]
6061
+
6340
6062
  # compute the scattering covariance of the current synthesized map u
6341
6063
  if use_v:
6342
- learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,ref_sigma=sref,use_ref=True,iso_ang=iso_ang))
6064
+ learn = scat_operator.reduce_mean_batch(
6065
+ scat_operator.scattering_cov(
6066
+ u,
6067
+ edge=edge,
6068
+ Jmax=Jmax,
6069
+ ref_sigma=sref,
6070
+ use_ref=True,
6071
+ iso_ang=iso_ang,
6072
+ )
6073
+ )
6343
6074
  else:
6344
- learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,use_ref=True,iso_ang=iso_ang))
6345
-
6075
+ learn = scat_operator.reduce_mean_batch(
6076
+ scat_operator.scattering_cov(
6077
+ u, edge=edge, Jmax=Jmax, use_ref=True, iso_ang=iso_ang
6078
+ )
6079
+ )
6080
+
6346
6081
  # make the difference withe the reference coordinates
6347
- loss=scat_operator.backend.bk_reduce_mean(scat_operator.backend.bk_square((learn-ref)))
6082
+ loss = scat_operator.backend.bk_reduce_mean(
6083
+ scat_operator.backend.bk_square(learn - ref)
6084
+ )
6348
6085
  return loss
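The_loss is a mean squared error between the scattering-covariance vector of the current map and the reference vector of the target; a minimal numpy stand-in where the scattering operator is replaced by precomputed vectors:

    import numpy as np

    def the_loss(learn, ref):
        # mean squared difference between coefficient vectors
        return np.mean((learn - ref) ** 2)

    ref = np.random.randn(128)              # target coefficients
    learn = ref + 0.1 * np.random.randn(128)
    print(the_loss(learn, ref))             # shrinks as the synthesis converges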
6349
6086
 
6350
6087
  if to_gaussian:
6351
6088
  # Remap the data histogram to a Gaussian distribution
6352
- im_target=self.to_gaussian(image_target)
6089
+ im_target = self.to_gaussian(image_target)
6353
6090
  else:
6354
- im_target=image_target
6355
-
6356
- axis=len(im_target.shape)-1
6091
+ im_target = image_target
6092
+
6093
+ axis = len(im_target.shape) - 1
6357
6094
  if self.use_2D:
6358
- axis-=1
6359
- if axis==0:
6360
- im_target=self.backend.bk_expand_dims(im_target,0)
6095
+ axis -= 1
6096
+ if axis == 0:
6097
+ im_target = self.backend.bk_expand_dims(im_target, 0)
6361
6098
 
6362
6099
  # compute the number of possible steps
6363
6100
  if self.use_2D:
6364
- jmax=int(np.min([np.log(im_target.shape[1]),np.log(im_target.shape[2])])/np.log(2))
6101
+ jmax = int(
6102
+ np.min([np.log(im_target.shape[1]), np.log(im_target.shape[2])])
6103
+ / np.log(2)
6104
+ )
6365
6105
  elif self.use_1D:
6366
- jmax=int(np.log(im_target.shape[1])/np.log(2))
6106
+ jmax = int(np.log(im_target.shape[1]) / np.log(2))
6367
6107
  else:
6368
- jmax=int((np.log(im_target.shape[1]//12)/np.log(2))/2)
6369
- nside=2**jmax
6108
+ jmax = int((np.log(im_target.shape[1] // 12) / np.log(2)) / 2)
6109
+ nside = 2**jmax
6370
6110
 
6371
- if nstep>jmax-1:
6372
- nstep=jmax-1
6111
+ if nstep > jmax - 1:
6112
+ nstep = jmax - 1
6373
6113
 
6374
- t1=time.time()
6375
- tmp={}
6376
- tmp[nstep-1]=im_target
6377
- for l in range(nstep-2,-1,-1):
6378
- tmp[l]=self.ud_grade_2(tmp[l+1],axis=1)
6114
+ t1 = time.time()
6115
+ tmp = {}
6116
+
6117
+ l_grd_mask={}
6118
+
6119
+ tmp[nstep - 1] = self.backend.bk_cast(im_target)
6120
+ if grd_mask is not None:
6121
+ l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
6122
+ else:
6123
+ l_grd_mask[nstep - 1] = None
6379
6124
 
6380
- if not self.use_2D and not self.use_1D:
6381
- l_nside=nside//(2**(nstep-1))
6125
+ for ell in range(nstep - 2, -1, -1):
6126
+ tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
6127
+ if grd_mask is not None:
6128
+ l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
6129
+ else:
6130
+ l_grd_mask[ell] = None
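A minimal sketch of the coarse-to-fine pyramid built here, where ud_grade_2 is mimicked by 2x2 block averaging (an assumption about its 2D behaviour); the finest level is the target and each lower level halves the resolution:

    import numpy as np

    def block_average(im):
        # stand-in for ud_grade_2 on a 2D grid
        return 0.25 * (im[..., ::2, ::2] + im[..., 1::2, ::2]
                       + im[..., ::2, 1::2] + im[..., 1::2, 1::2])

    nstep = 3
    tmp = {nstep - 1: np.random.randn(1, 32, 32)}
    for ell in range(nstep - 2, -1, -1):
        tmp[ell] = block_average(tmp[ell + 1])
    print([tmp[ell].shape for ell in range(nstep)])
    # [(1, 8, 8), (1, 16, 16), (1, 32, 32)]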
6382
6131
 
6132
+
6133
+ if not self.use_2D and not self.use_1D:
6134
+ l_nside = nside // (2 ** (nstep - 1))
6135
+
6383
6136
  for k in range(nstep):
6384
- if k==0:
6137
+ if k == 0:
6385
6138
  if input_image is None:
6386
6139
  np.random.seed(seed)
6387
6140
  if self.use_2D:
6388
- imap=np.random.randn(synthesised_N,
6389
- tmp[k].shape[1],
6390
- tmp[k].shape[2])
6141
+ imap = self.backend.bk_cast(
6142
+ np.random.randn(synthesised_N, tmp[k].shape[1], tmp[k].shape[2])
6143
+ )
6391
6144
  else:
6392
- imap=np.random.randn(synthesised_N,
6393
- tmp[k].shape[1])
6145
+ imap = self.backend.bk_cast(np.random.randn(synthesised_N, tmp[k].shape[1]))
6394
6146
  else:
6395
6147
  if self.use_2D:
6396
- imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6397
- [synthesised_N,tmp[k].shape[1],tmp[k].shape[2]])
6148
+ imap = self.backend.bk_reshape(
6149
+ self.backend.bk_tile(
6150
+ self.backend.bk_cast(input_image.flatten()),
6151
+ synthesised_N,
6152
+ ),
6153
+ [synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
6154
+ )
6398
6155
  else:
6399
- imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6400
- [synthesised_N,tmp[k].shape[1]])
6156
+ imap = self.backend.bk_reshape(
6157
+ self.backend.bk_tile(
6158
+ self.backend.bk_cast(input_image.flatten()),
6159
+ synthesised_N,
6160
+ ),
6161
+ [synthesised_N, tmp[k].shape[1]],
6162
+ )
6401
6163
  else:
6402
6164
  # Increase the resolution between each step
6403
6165
  if self.use_2D:
@@ -6408,48 +6170,49 @@ class funct(FOC.FoCUS):
6408
6170
  imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
6409
6171
  else:
6410
6172
  imap = self.up_grade(omap, l_nside, axis=1)
6411
-
6173
+
6174
+ if grd_mask is not None:
6175
+ imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
6176
+
6412
6177
  # compute the coefficients for the target image
6413
6178
  if use_variance:
6414
- ref,sref=self.scattering_cov(tmp[k],get_variance=True,edge=edge,Jmax=Jmax,iso_ang=iso_ang)
6179
+ ref, sref = self.scattering_cov(
6180
+ tmp[k], get_variance=True, edge=edge, Jmax=Jmax, iso_ang=iso_ang
6181
+ )
6415
6182
  else:
6416
- ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
6417
- sref=ref
6418
-
6183
+ ref = self.scattering_cov(tmp[k], edge=edge, Jmax=Jmax, iso_ang=iso_ang)
6184
+ sref = ref
6185
+
6419
6186
  # compute the mean of the population; does nothing if only one map is given
6420
- ref=self.reduce_mean_batch(ref)
6421
-
6187
+ ref = self.reduce_mean_batch(ref)
6188
+
6422
6189
  # define a loss to minimize
6423
- loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
6424
-
6190
+ loss = synthe.Loss(The_loss, self, ref, sref, use_variance)
6191
+
6425
6192
  sy = synthe.Synthesis([loss])
6426
6193
 
6427
6194
  # initialize the synthesised map
6428
6195
  if self.use_2D:
6429
- print('Synthesis scale [ %d x %d ]'%(imap.shape[1],imap.shape[2]))
6196
+ print("Synthesis scale [ %d x %d ]" % (imap.shape[1], imap.shape[2]))
6430
6197
  elif self.use_1D:
6431
- print('Synthesis scale [ %d ]'%(imap.shape[1]))
6198
+ print("Synthesis scale [ %d ]" % (imap.shape[1]))
6432
6199
  else:
6433
- print('Synthesis scale nside=%d'%(l_nside))
6434
- l_nside*=2
6435
-
6200
+ print("Synthesis scale nside=%d" % (l_nside))
6201
+ l_nside *= 2
6202
+
6436
6203
  # do the minimization
6437
- omap=sy.run(imap,
6438
- EVAL_FREQUENCY=EVAL_FREQUENCY,
6439
- NUM_EPOCHS = NUM_EPOCHS)
6440
-
6441
-
6204
+ omap = sy.run(
+ imap, EVAL_FREQUENCY=EVAL_FREQUENCY, NUM_EPOCHS=NUM_EPOCHS, grd_mask=l_grd_mask[k]
+ )
6442
6205
 
6443
- t2=time.time()
6444
- print('Total computation %.2fs'%(t2-t1))
6206
+ t2 = time.time()
6207
+ print("Total computation %.2fs" % (t2 - t1))
6445
6208
 
6446
6209
  if to_gaussian:
6447
- omap=self.from_gaussian(omap)
6210
+ omap = self.from_gaussian(omap)
6448
6211
 
6449
- if axis==0 and synthesised_N==1:
6212
+ if axis == 0 and synthesised_N == 1:
6450
6213
  return omap[0]
6451
6214
  else:
6452
6215
  return omap
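A hypothetical usage sketch of the new grd_mask argument; only the synthesis signature comes from this diff, while the constructor options and array sizes are assumptions:

    import numpy as np
    import foscat.scat_cov as sc

    # Pixels where the mask is 1 are synthesised freely; pixels where it
    # is 0 are pinned to the target through the blending step above.
    op = sc.funct(NORIENT=4, use_2D=True)   # assumed constructor options
    target = np.random.randn(64, 64).astype("float32")
    mask = np.ones((64, 64), dtype="float32")
    mask[16:48, 16:48] = 0.0                # keep the central patch fixed
    out = op.synthesis(target, nstep=3, grd_mask=mask, NUM_EPOCHS=100)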
6453
-
6454
- def to_numpy(self,x):
6216
+
6217
+ def to_numpy(self, x):
6455
6218
  return self.backend.to_numpy(x)