foscat 3.8.0__py3-none-any.whl → 3.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -92,7 +92,12 @@ class scat_cov:
92
92
  )
93
93
 
94
94
  def conv2complex(self, val):
95
- if val.dtype == "complex64" or val.dtype=="complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
95
+ if (
96
+ val.dtype == "complex64"
97
+ or val.dtype == "complex128"
98
+ or val.dtype == "torch.complex64"
99
+ or val.dtype == "torch.complex128"
100
+ ):
96
101
  return val
97
102
  else:
98
103
  return self.backend.bk_complex(val, 0 * val)
@@ -2043,9 +2048,11 @@ class scat_cov:
2043
2048
  )
2044
2049
  )
2045
2050
  else:
2046
- s3[i, j, idx[noff:], k, l_orient] = self.backend.to_numpy(self.S3)[
2047
- i, j, j2 == ij - noff, k, l_orient
2048
- ]
2051
+ s3[i, j, idx[noff:], k, l_orient] = (
2052
+ self.backend.to_numpy(self.S3)[
2053
+ i, j, j2 == ij - noff, k, l_orient
2054
+ ]
2055
+ )
2049
2056
  s3[i, j, idx[:noff], k, l_orient] = (
2050
2057
  self.add_data_from_slope(
2051
2058
  self.backend.to_numpy(self.S3)[
@@ -2208,7 +2215,7 @@ class scat_cov:
2208
2215
 
2209
2216
 
2210
2217
  class funct(FOC.FoCUS):
2211
-
2218
+
2212
2219
  def fill(self, im, nullval=hp.UNSEEN):
2213
2220
  if self.use_2D:
2214
2221
  return self.fill_2d(im, nullval=nullval)
@@ -2277,14 +2284,14 @@ class funct(FOC.FoCUS):
2277
2284
  if tmp.S1 is not None:
2278
2285
  nS1 = np.expand_dims(tmp.S1, 0)
2279
2286
  else:
2280
- nS0 = np.expand_dims(tmp.S0.numpy(), 0)
2281
- nS2 = np.expand_dims(tmp.S2.numpy(), 0)
2282
- nS3 = np.expand_dims(tmp.S3.numpy(), 0)
2283
- nS4 = np.expand_dims(tmp.S4.numpy(), 0)
2287
+ nS0 = np.expand_dims(self.backend.to_numpy(tmp.S0), 0)
2288
+ nS2 = np.expand_dims(self.backend.to_numpy(tmp.S2), 0)
2289
+ nS3 = np.expand_dims(self.backend.to_numpy(tmp.S3), 0)
2290
+ nS4 = np.expand_dims(self.backend.to_numpy(tmp.S4), 0)
2284
2291
  if tmp.S3P is not None:
2285
- nS3P = np.expand_dims(tmp.S3P.numpy(), 0)
2292
+ nS3P = np.expand_dims(self.backend.to_numpy(tmp.S3P), 0)
2286
2293
  if tmp.S1 is not None:
2287
- nS1 = np.expand_dims(tmp.S1.numpy(), 0)
2294
+ nS1 = np.expand_dims(self.backend.to_numpy(tmp.S1), 0)
2288
2295
 
2289
2296
  if S0 is None:
2290
2297
  S0 = nS0
@@ -2304,24 +2311,24 @@ class funct(FOC.FoCUS):
2304
2311
  S3P = np.concatenate([S3P, nS3P], 0)
2305
2312
  if tmp.S1 is not None:
2306
2313
  S1 = np.concatenate([S1, nS1], 0)
2307
- sS0 = np.std(S0, 0)
2308
- sS2 = np.std(S2, 0)
2309
- sS3 = np.std(S3, 0)
2310
- sS4 = np.std(S4, 0)
2311
- mS0 = np.mean(S0, 0)
2312
- mS2 = np.mean(S2, 0)
2313
- mS3 = np.mean(S3, 0)
2314
- mS4 = np.mean(S4, 0)
2314
+ sS0 = self.backend.bk_cast(np.std(S0, 0))
2315
+ sS2 = self.backend.bk_cast(np.std(S2, 0))
2316
+ sS3 = self.backend.bk_cast(np.std(S3, 0))
2317
+ sS4 = self.backend.bk_cast(np.std(S4, 0))
2318
+ mS0 = self.backend.bk_cast(np.mean(S0, 0))
2319
+ mS2 = self.backend.bk_cast(np.mean(S2, 0))
2320
+ mS3 = self.backend.bk_cast(np.mean(S3, 0))
2321
+ mS4 = self.backend.bk_cast(np.mean(S4, 0))
2315
2322
  if tmp.S3P is not None:
2316
- sS3P = np.std(S3P, 0)
2317
- mS3P = np.mean(S3P, 0)
2323
+ sS3P = self.backend.bk_cast(np.std(S3P, 0))
2324
+ mS3P = self.backend.bk_cast(np.mean(S3P, 0))
2318
2325
  else:
2319
2326
  sS3P = None
2320
2327
  mS3P = None
2321
2328
 
2322
2329
  if tmp.S1 is not None:
2323
- sS1 = np.std(S1, 0)
2324
- mS1 = np.mean(S1, 0)
2330
+ sS1 = self.backend.bk_cast(np.std(S1, 0))
2331
+ mS1 = self.backend.bk_cast(np.mean(S1, 0))
2325
2332
  else:
2326
2333
  sS1 = None
2327
2334
  mS1 = None
@@ -2375,14 +2382,19 @@ class funct(FOC.FoCUS):
2375
2382
 
2376
2383
  # instead of difference between "opposite" channels use weighted average
2377
2384
  # of cosine and sine contributions using all channels
2378
- angles = (
2379
- 2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
2385
+ angles = self.backend.bk_cast(
2386
+ (2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
2387
+ 1, 1, self.NORIENT
2388
+ )
2380
2389
  ) # shape: (NORIENT,)
2381
- angles = angles.reshape(1, 1, self.NORIENT)
2382
2390
 
2383
2391
  # we use cosines and sines as weights for sim
2384
- weighted_cos = self.backend.bk_reduce_mean(sim * np.cos(angles), axis=-1)
2385
- weighted_sin = self.backend.bk_reduce_mean(sim * np.sin(angles), axis=-1)
2392
+ weighted_cos = self.backend.bk_reduce_mean(
2393
+ sim * self.backend.bk_cos(angles), axis=-1
2394
+ )
2395
+ weighted_sin = self.backend.bk_reduce_mean(
2396
+ sim * self.backend.bk_sin(angles), axis=-1
2397
+ )
2386
2398
  # For simplicity, take first element of the batch
2387
2399
  cc = weighted_cos[0]
2388
2400
  ss = weighted_sin[0]
@@ -2402,7 +2414,9 @@ class funct(FOC.FoCUS):
2402
2414
  phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
2403
2415
  else:
2404
2416
  phase = np.fmod(
2405
- np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
2417
+ np.arctan2(self.backend.to_numpy(ss), self.backend.to_numpy(cc))
2418
+ + 2 * np.pi,
2419
+ 2 * np.pi,
2406
2420
  )
2407
2421
 
2408
2422
  # instead of linear interpolation cosine‐based interpolation
@@ -2416,10 +2430,10 @@ class funct(FOC.FoCUS):
2416
2430
  # build rotation matrix
2417
2431
  mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
2418
2432
  lidx = np.arange(sim.shape[1])
2419
- for l in range(self.NORIENT):
2433
+ for ell in range(self.NORIENT):
2420
2434
  # Instead of simple linear weights, we use the cosine weights w0 and w1.
2421
- col0 = self.NORIENT * ((l + iph) % self.NORIENT) + l
2422
- col1 = self.NORIENT * ((l + iph + 1) % self.NORIENT) + l
2435
+ col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
2436
+ col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
2423
2437
  mat[lidx, col0] = w0
2424
2438
  mat[lidx, col1] = w1
2425
2439
 
@@ -2435,7 +2449,9 @@ class funct(FOC.FoCUS):
2435
2449
 
2436
2450
  sim2 = self.backend.bk_reduce_sum(
2437
2451
  self.backend.bk_reshape(
2438
- mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
2452
+ self.backend.bk_cast(
2453
+ mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
2454
+ )
2439
2455
  * tmp2,
2440
2456
  [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
2441
2457
  ),
@@ -2445,10 +2461,10 @@ class funct(FOC.FoCUS):
2445
2461
  sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
2446
2462
 
2447
2463
  weighted_cos2 = self.backend.bk_reduce_mean(
2448
- sim2 * np.cos(angles), axis=-1
2464
+ sim2 * self.backend.bk_cos(angles), axis=-1
2449
2465
  )
2450
2466
  weighted_sin2 = self.backend.bk_reduce_mean(
2451
- sim2 * np.sin(angles), axis=-1
2467
+ sim2 * self.backend.bk_sin(angles), axis=-1
2452
2468
  )
2453
2469
 
2454
2470
  cc2 = weighted_cos2[0]
@@ -2469,7 +2485,11 @@ class funct(FOC.FoCUS):
2469
2485
  phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
2470
2486
  else:
2471
2487
  phase2 = np.fmod(
2472
- np.arctan2(ss2.numpy(), cc2.numpy()) + 2 * np.pi, 2 * np.pi
2488
+ np.arctan2(
2489
+ self.backend.to_numpy(ss2), self.backend.to_numpy(cc2)
2490
+ )
2491
+ + 2 * np.pi,
2492
+ 2 * np.pi,
2473
2493
  )
2474
2494
 
2475
2495
  phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
@@ -2480,9 +2500,11 @@ class funct(FOC.FoCUS):
2480
2500
  lidx = np.arange(sim.shape[1])
2481
2501
 
2482
2502
  for m in range(self.NORIENT):
2483
- for l in range(self.NORIENT):
2484
- col0 = self.NORIENT * ((l + iph2[:, m]) % self.NORIENT) + l
2485
- col1 = self.NORIENT * ((l + iph2[:, m] + 1) % self.NORIENT) + l
2503
+ for ell in range(self.NORIENT):
2504
+ col0 = self.NORIENT * ((ell + iph2[:, m]) % self.NORIENT) + ell
2505
+ col1 = (
2506
+ self.NORIENT * ((ell + iph2[:, m] + 1) % self.NORIENT) + ell
2507
+ )
2486
2508
  mat2[k2, lidx, m, col0] = w0_2[:, m]
2487
2509
  mat2[k2, lidx, m, col1] = w1_2[:, m]
2488
2510
  cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
@@ -2500,17 +2522,17 @@ class funct(FOC.FoCUS):
2500
2522
  )
2501
2523
 
2502
2524
  def eval(
2503
- self,
2504
- image1,
2505
- image2=None,
2506
- mask=None,
2507
- norm=None,
2508
- calc_var=False,
2509
- cmat=None,
2510
- cmat2=None,
2511
- Jmax=None,
2512
- out_nside=None,
2513
- edge=True
2525
+ self,
2526
+ image1,
2527
+ image2=None,
2528
+ mask=None,
2529
+ norm=None,
2530
+ calc_var=False,
2531
+ cmat=None,
2532
+ cmat2=None,
2533
+ Jmax=None,
2534
+ out_nside=None,
2535
+ edge=True,
2514
2536
  ):
2515
2537
  """
2516
2538
  Calculates the scattering correlations for a batch of images. Means are taken over pixels.
@@ -2540,9 +2562,9 @@ class funct(FOC.FoCUS):
2540
2562
  -------
2541
2563
  S1, S2, S3, S4 normalized
2542
2564
  """
2543
-
2565
+
2544
2566
  return_data = self.return_data
2545
-
2567
+
2546
2568
  # Check input consistency
2547
2569
  if image2 is not None:
2548
2570
  if list(image1.shape) != list(image2.shape):
@@ -2552,7 +2574,10 @@ class funct(FOC.FoCUS):
2552
2574
  return None
2553
2575
  if mask is not None:
2554
2576
  if self.use_2D:
2555
- if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
2577
+ if (
2578
+ image1.shape[-2] != mask.shape[1]
2579
+ or image1.shape[-1] != mask.shape[2]
2580
+ ):
2556
2581
  print(
2557
2582
  "The LAST 2 COLUMNs of the mask should have the same size ",
2558
2583
  mask.shape,
@@ -2562,7 +2587,7 @@ class funct(FOC.FoCUS):
2562
2587
  )
2563
2588
  return None
2564
2589
  else:
2565
- if image1.shape[-1] != mask.shape[1]:
2590
+ if image1.shape[-1] != mask.shape[1]:
2566
2591
  print(
2567
2592
  "The LAST COLUMN of the mask should have the same size ",
2568
2593
  mask.shape,
@@ -2616,16 +2641,17 @@ class funct(FOC.FoCUS):
2616
2641
  nside = int(np.sqrt(npix // 12))
2617
2642
 
2618
2643
  J = int(np.log(nside) / np.log(2)) # Number of j scales
2619
-
2620
- if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
2621
- J-=1
2644
+
2645
+ if (self.use_2D or self.use_1D) and self.KERNELSZ > 3:
2646
+ J -= 1
2622
2647
  if Jmax is None:
2623
2648
  Jmax = J # Number of steps for the loop on scales
2624
- if Jmax>J:
2625
- print('==========\n\n')
2626
- print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
2627
- print('\n\n==========')
2628
-
2649
+ if Jmax > J:
2650
+ print("==========\n\n")
2651
+ print(
2652
+ "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
2653
+ )
2654
+ print("\n\n==========")
2629
2655
 
2630
2656
  ### LOCAL VARIABLES (IMAGES and MASK)
2631
2657
  if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
@@ -2768,32 +2794,37 @@ class funct(FOC.FoCUS):
2768
2794
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2769
2795
 
2770
2796
  # a remettre comme avant
2771
- M1_dic={}
2772
- M2_dic={}
2773
-
2797
+ M1_dic = {}
2798
+ M2_dic = {}
2799
+
2774
2800
  for j3 in range(Jmax):
2775
-
2801
+
2776
2802
  if edge:
2777
2803
  if self.mask_mask is None:
2778
- self.mask_mask={}
2804
+ self.mask_mask = {}
2779
2805
  if self.use_2D:
2780
- if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
2781
- mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
2782
- mask_mask[0,
2783
- self.KERNELSZ//2:-self.KERNELSZ//2+1,
2784
- self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2785
- self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
2786
- vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
2787
- #print(self.KERNELSZ//2,vmask,mask_mask)
2788
-
2806
+ if (vmask.shape[1], vmask.shape[2]) not in self.mask_mask:
2807
+ mask_mask = np.zeros([1, vmask.shape[1], vmask.shape[2]])
2808
+ mask_mask[
2809
+ 0,
2810
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
2811
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
2812
+ ] = 1.0
2813
+ self.mask_mask[(vmask.shape[1], vmask.shape[2])] = (
2814
+ self.backend.bk_cast(mask_mask)
2815
+ )
2816
+ vmask = vmask * self.mask_mask[(vmask.shape[1], vmask.shape[2])]
2817
+ # print(self.KERNELSZ//2,vmask,mask_mask)
2818
+
2789
2819
  if self.use_1D:
2790
2820
  if (vmask.shape[1]) not in self.mask_mask:
2791
- mask_mask=np.zeros([1,vmask.shape[1]])
2792
- mask_mask[0,
2793
- self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2794
- self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
2795
- vmask=vmask*self.mask_mask[(vmask.shape[1])]
2796
-
2821
+ mask_mask = np.zeros([1, vmask.shape[1]])
2822
+ mask_mask[0, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] = 1.0
2823
+ self.mask_mask[(vmask.shape[1])] = self.backend.bk_cast(
2824
+ mask_mask
2825
+ )
2826
+ vmask = vmask * self.mask_mask[(vmask.shape[1])]
2827
+
2797
2828
  if return_data:
2798
2829
  S3[j3] = None
2799
2830
  S3P[j3] = None
@@ -3447,9 +3478,9 @@ class funct(FOC.FoCUS):
3447
3478
  M2_dic[j2], axis=1
3448
3479
  ) # [Nbatch, Npix_j3, Norient3]
3449
3480
  M2_dic[j2] = self.ud_grade_2(
3450
- M2, axis=1
3481
+ M2_smooth, axis=1
3451
3482
  ) # [Nbatch, Npix_j3, Norient3]
3452
-
3483
+
3453
3484
  ### Mask
3454
3485
  vmask = self.ud_grade_2(vmask, axis=1)
3455
3486
 
@@ -3541,1278 +3572,232 @@ class funct(FOC.FoCUS):
3541
3572
  use_1D=self.use_1D,
3542
3573
  )
3543
3574
 
3544
- def eval_new(
3545
- self,
3546
- image1,
3547
- image2=None,
3548
- mask=None,
3549
- norm=None,
3550
- calc_var=False,
3551
- cmat=None,
3552
- cmat2=None,
3553
- Jmax=None,
3554
- out_nside=None,
3555
- edge=True
3575
+ def clean_norm(self):
3576
+ self.P1_dic = None
3577
+ self.P2_dic = None
3578
+ return
3579
+
3580
+ def _compute_S3(
3581
+ self,
3582
+ j2,
3583
+ j3,
3584
+ conv,
3585
+ vmask,
3586
+ M_dic,
3587
+ MconvPsi_dic,
3588
+ calc_var=False,
3589
+ return_data=False,
3590
+ cmat2=None,
3556
3591
  ):
3557
3592
  """
3558
- Calculates the scattering correlations for a batch of images. Mean are done over pixels.
3559
- mean of modulus:
3560
- S1 = <|I * Psi_j3|>
3561
- Normalization : take the log
3562
- power spectrum:
3563
- S2 = <|I * Psi_j3|^2>
3564
- Normalization : take the log
3565
- orig. x modulus:
3566
- S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
3567
- Normalization : divide by (S2_j2 * S2_j3)^0.5
3568
- modulus x modulus:
3569
- S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
3570
- Normalization : divide by (S2_j1 * S2_j2)^0.5
3593
+ Compute the S3 coefficients (auto or cross)
3594
+ S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
3571
3595
  Parameters
3572
3596
  ----------
3573
- image1: tensor
3574
- Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
3575
- image2: tensor
3576
- Second image. If not None, we compute cross-scattering covariance coefficients.
3577
- mask:
3578
- norm: None or str
3579
- If None no normalization is applied, if 'auto' normalize by the reference S2,
3580
- if 'self' normalize by the current S2.
3581
3597
  Returns
3582
3598
  -------
3583
- S1, S2, S3, S4 normalized
3599
+ cs3, ss3: real and imag parts of S3 coeff
3584
3600
  """
3585
- return_data = self.return_data
3586
- NORIENT=self.NORIENT
3587
- # Check input consistency
3588
- if image2 is not None:
3589
- if list(image1.shape) != list(image2.shape):
3590
- print(
3591
- "The two input image should have the same size to eval Scattering Covariance"
3592
- )
3593
- return None
3594
- if mask is not None:
3595
- if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
3596
- print(
3597
- "The LAST COLUMN of the mask should have the same size ",
3598
- mask.shape,
3599
- "than the input image ",
3600
- image1.shape,
3601
- "to eval Scattering Covariance",
3602
- )
3603
- return None
3604
- if self.use_2D and len(image1.shape) < 2:
3605
- print(
3606
- "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
3601
+ ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3602
+ # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
3603
+ MconvPsi = self.convol(
3604
+ M_dic[j2], axis=1
3605
+ ) # [Nbatch, Npix_j3, Norient3, Norient2]
3606
+ if cmat2 is not None:
3607
+ tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
3608
+ MconvPsi = self.backend.bk_reduce_sum(
3609
+ self.backend.bk_reshape(
3610
+ cmat2[j3][j2] * tmp2,
3611
+ [
3612
+ tmp2.shape[0],
3613
+ cmat2[j3].shape[1],
3614
+ self.NORIENT,
3615
+ self.NORIENT,
3616
+ self.NORIENT,
3617
+ ],
3618
+ ),
3619
+ 3,
3607
3620
  )
3608
- return None
3609
-
3610
- ### AUTO OR CROSS
3611
- cross = False
3612
- if image2 is not None:
3613
- cross = True
3614
3621
 
3615
- ### PARAMETERS
3616
- axis = 1
3617
- # determine jmax and nside corresponding to the input map
3618
- im_shape = image1.shape
3619
- if self.use_2D:
3620
- if len(image1.shape) == 2:
3621
- nside = np.min([im_shape[0], im_shape[1]])
3622
- npix = im_shape[0] * im_shape[1] # Number of pixels
3623
- x1 = im_shape[0]
3624
- x2 = im_shape[1]
3625
- else:
3626
- nside = np.min([im_shape[1], im_shape[2]])
3627
- npix = im_shape[1] * im_shape[2] # Number of pixels
3628
- x1 = im_shape[1]
3629
- x2 = im_shape[2]
3630
- J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
3631
- elif self.use_1D:
3632
- if len(image1.shape) == 2:
3633
- npix = int(im_shape[1]) # Number of pixels
3634
- else:
3635
- npix = int(im_shape[0]) # Number of pixels
3622
+ # Store it so we can use it in S4 computation
3623
+ MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
3636
3624
 
3637
- nside = int(npix)
3625
+ ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
3626
+ # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
3627
+ # cconv, sconv are [Nbatch, Npix_j3, Norient3]
3628
+ if self.use_1D:
3629
+ s3 = conv * self.backend.bk_conjugate(MconvPsi)
3630
+ else:
3631
+ s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
3632
+ MconvPsi
3633
+ ) # [Nbatch, Npix_j3, Norient3, Norient2]
3638
3634
 
3639
- J = int(np.log(nside) / np.log(2)) # Number of j scales
3635
+ ### Apply the mask [Nmask, Npix_j3] and sum over pixels
3636
+ if return_data:
3637
+ return s3
3640
3638
  else:
3641
- if len(image1.shape) == 2:
3642
- npix = int(im_shape[1]) # Number of pixels
3639
+ if calc_var:
3640
+ s3, vs3 = self.masked_mean(
3641
+ s3, vmask, axis=1, rank=j2, calc_var=True
3642
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3643
+ return s3, vs3
3643
3644
  else:
3644
- npix = int(im_shape[0]) # Number of pixels
3645
+ s3 = self.masked_mean(
3646
+ s3, vmask, axis=1, rank=j2
3647
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3648
+ return s3
3645
3649
 
3646
- nside = int(np.sqrt(npix // 12))
3650
+ def _compute_S4(
3651
+ self,
3652
+ j1,
3653
+ j2,
3654
+ vmask,
3655
+ M1convPsi_dic,
3656
+ M2convPsi_dic=None,
3657
+ calc_var=False,
3658
+ return_data=False,
3659
+ ):
3660
+ #### Simplify notations
3661
+ M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
3647
3662
 
3648
- J = int(np.log(nside) / np.log(2)) # Number of j scales
3649
-
3650
- if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
3651
- J-=1
3652
- if Jmax is None:
3653
- Jmax = J # Number of steps for the loop on scales
3654
- if Jmax>J:
3655
- print('==========\n\n')
3656
- print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
3657
- print('\n\n==========')
3658
-
3663
+ # Auto or Cross coefficients
3664
+ if M2convPsi_dic is None: # Auto
3665
+ M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
3666
+ else: # Cross
3667
+ M2 = M2convPsi_dic[j2]
3659
3668
 
3660
- ### LOCAL VARIABLES (IMAGES and MASK)
3661
- if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
3662
- I1 = self.backend.bk_cast(
3663
- self.backend.bk_expand_dims(image1, 0)
3664
- ) # Local image1 [Nbatch, Npix]
3665
- if cross:
3666
- I2 = self.backend.bk_cast(
3667
- self.backend.bk_expand_dims(image2, 0)
3668
- ) # Local image2 [Nbatch, Npix]
3669
+ ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
3670
+ # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
3671
+ if self.use_1D:
3672
+ s4 = M1 * self.backend.bk_conjugate(M2)
3669
3673
  else:
3670
- I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
3671
- if cross:
3672
- I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
3674
+ s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
3675
+ self.backend.bk_expand_dims(M2, -1)
3676
+ ) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
3673
3677
 
3674
- if mask is None:
3675
- if self.use_2D:
3676
- vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
3677
- else:
3678
- vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
3678
+ ### Apply the mask and sum over pixels
3679
+ if return_data:
3680
+ return s4
3679
3681
  else:
3680
- vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
3682
+ if calc_var:
3683
+ s4, vs4 = self.masked_mean(
3684
+ s4, vmask, axis=1, rank=j2, calc_var=True
3685
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3686
+ return s4, vs4
3687
+ else:
3688
+ s4 = self.masked_mean(
3689
+ s4, vmask, axis=1, rank=j2
3690
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3691
+ return s4
3681
3692
 
3682
- if self.KERNELSZ > 3 and not self.use_2D:
3683
- # if the kernel size is bigger than 3 increase the binning before smoothing
3684
- if self.use_2D:
3685
- vmask = self.up_grade(
3686
- vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
3693
+ def computer_filter(self, M, N, J, L):
3694
+ """
3695
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
3696
+ Done by Sihao Cheng and Rudy Morel.
3697
+ """
3698
+
3699
+ filter = np.zeros([J, L, M, N], dtype="complex64")
3700
+
3701
+ slant = 4.0 / L
3702
+
3703
+ for j in range(J):
3704
+
3705
+ for ell in range(L):
3706
+
3707
+ theta = (int(L - L / 2 - 1) - ell) * np.pi / L
3708
+ sigma = 0.8 * 2**j
3709
+ xi = 3.0 / 4.0 * np.pi / 2**j
3710
+
3711
+ R = np.array(
3712
+ [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
3713
+ np.float64,
3687
3714
  )
3688
- I1 = self.up_grade(
3689
- I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
3715
+ R_inv = np.array(
3716
+ [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
3717
+ np.float64,
3690
3718
  )
3691
- if cross:
3692
- I2 = self.up_grade(
3693
- I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
3694
- )
3695
- elif self.use_1D:
3696
- vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
3697
- I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
3698
- if cross:
3699
- I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
3700
- else:
3701
- I1 = self.up_grade(I1, nside * 2, axis=axis)
3702
- vmask = self.up_grade(vmask, nside * 2, axis=1)
3703
- if cross:
3704
- I2 = self.up_grade(I2, nside * 2, axis=axis)
3705
-
3706
- if self.KERNELSZ > 5 and not self.use_2D:
3707
- # if the kernel size is bigger than 3 increase the binning before smoothing
3708
- if self.use_2D:
3709
- vmask = self.up_grade(
3710
- vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
3711
- )
3712
- I1 = self.up_grade(
3713
- I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
3714
- )
3715
- if cross:
3716
- I2 = self.up_grade(
3717
- I2,
3718
- I2.shape[axis] * 2,
3719
- axis=axis,
3720
- nouty=I2.shape[axis + 1] * 2,
3721
- )
3722
- elif self.use_1D:
3723
- vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
3724
- I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
3725
- if cross:
3726
- I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
3727
- else:
3728
- I1 = self.up_grade(I1, nside * 4, axis=axis)
3729
- vmask = self.up_grade(vmask, nside * 4, axis=1)
3730
- if cross:
3731
- I2 = self.up_grade(I2, nside * 4, axis=axis)
3732
-
3733
- # Normalize the masks because they have different pixel numbers
3734
- # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
3735
-
3736
- ### INITIALIZATION
3737
- # Coefficients
3738
- if return_data:
3739
- S1 = {}
3740
- S2 = {}
3741
- S3 = {}
3742
- S3P = {}
3743
- S4 = {}
3744
- else:
3745
- result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3746
- dtype=self.backend.backend.float32,
3747
- device=self.backend.torch_device)
3748
- vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3749
- dtype=self.backend.backend.float32,
3750
- device=self.backend.torch_device)
3751
- S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3752
- S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3753
- S3 = []
3754
- S4 = []
3755
- S3P = []
3756
- VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3757
- VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3758
- VS3 = []
3759
- VS3P = []
3760
- VS4 = []
3761
-
3762
- off_S2 = -2
3763
- off_S3 = -3
3764
- off_S4 = -4
3765
- if self.use_1D:
3766
- off_S2 = -1
3767
- off_S3 = -1
3768
- off_S4 = -1
3769
-
3770
- # S2 for normalization
3771
- cond_init_P1_dic = (norm == "self") or (
3772
- (norm == "auto") and (self.P1_dic is None)
3773
- )
3774
- if norm is None:
3775
- pass
3776
- elif cond_init_P1_dic:
3777
- P1_dic = {}
3778
- if cross:
3779
- P2_dic = {}
3780
- elif (norm == "auto") and (self.P1_dic is not None):
3781
- P1_dic = self.P1_dic
3782
- if cross:
3783
- P2_dic = self.P2_dic
3784
-
3785
- if return_data:
3786
- s0 = I1
3787
- if out_nside is not None:
3788
- s0 = self.backend.bk_reduce_mean(
3789
- self.backend.bk_reshape(
3790
- s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
3791
- ),
3792
- 2,
3793
- )
3794
- else:
3795
- if not cross:
3796
- s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
3797
- else:
3798
- s0, l_vs0 = self.masked_mean(
3799
- self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
3800
- )
3801
- #vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
3802
- #s0 = self.backend.bk_concat([s0, l_vs0], 1)
3803
- result[:,:,0]=s0
3804
- result[:,:,1]=l_vs0
3805
- vresult[:,:,0]=l_vs0
3806
- vresult[:,:,1]=l_vs0
3807
- #### COMPUTE S1, S2, S3 and S4
3808
- nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
3809
-
3810
- # a remettre comme avant
3811
- M1_dic={}
3812
- M2_dic={}
3813
-
3814
- for j3 in range(Jmax):
3815
-
3816
- if edge:
3817
- if self.mask_mask is None:
3818
- self.mask_mask={}
3819
- if self.use_2D:
3820
- if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
3821
- mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
3822
- mask_mask[0,
3823
- self.KERNELSZ//2:-self.KERNELSZ//2+1,
3824
- self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
3825
- self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
3826
- vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
3827
- #print(self.KERNELSZ//2,vmask,mask_mask)
3828
-
3829
- if self.use_1D:
3830
- if (vmask.shape[1]) not in self.mask_mask:
3831
- mask_mask=np.zeros([1,vmask.shape[1]])
3832
- mask_mask[0,
3833
- self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
3834
- self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
3835
- vmask=vmask*self.mask_mask[(vmask.shape[1])]
3836
-
3837
- if return_data:
3838
- S3[j3] = None
3839
- S3P[j3] = None
3840
-
3841
- if S4 is None:
3842
- S4 = {}
3843
- S4[j3] = None
3844
-
3845
- ####### S1 and S2
3846
- ### Make the convolution I1 * Psi_j3
3847
- conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
3848
- if cmat is not None:
3849
- tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
3850
- conv1 = self.backend.bk_reduce_sum(
3851
- self.backend.bk_reshape(
3852
- cmat[j3] * tmp2,
3853
- [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
3854
- ),
3855
- 2,
3856
- )
3857
-
3858
- ### Take the module M1 = |I1 * Psi_j3|
3859
- M1_square = conv1 * self.backend.bk_conjugate(
3860
- conv1
3861
- ) # [Nbatch, Npix_j3, Norient3]
3862
-
3863
- M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
3864
-
3865
- # Store M1_j3 in a dictionary
3866
- M1_dic[j3] = M1
3867
-
3868
- if not cross: # Auto
3869
- M1_square = self.backend.bk_real(M1_square)
3870
-
3871
- ### S2_auto = < M1^2 >_pix
3872
- # Apply the mask [Nmask, Npix_j3] and average over pixels
3873
- if return_data:
3874
- s2 = M1_square
3875
- else:
3876
- if calc_var:
3877
- s2, vs2 = self.masked_mean(
3878
- M1_square, vmask, axis=1, rank=j3, calc_var=True
3879
- )
3880
- #s2=self.backend.bk_flatten(self.backend.bk_real(s2))
3881
- #vs2=self.backend.bk_flatten(vs2)
3882
- else:
3883
- s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
3884
-
3885
- if cond_init_P1_dic:
3886
- # We fill P1_dic with S2 for normalisation of S3 and S4
3887
- P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
3888
-
3889
- # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
3890
- if return_data:
3891
- if S2 is None:
3892
- S2 = {}
3893
- if out_nside is not None and out_nside < nside_j3:
3894
- s2 = self.backend.bk_reduce_mean(
3895
- self.backend.bk_reshape(
3896
- s2,
3897
- [
3898
- s2.shape[0],
3899
- 12 * out_nside**2,
3900
- (nside_j3 // out_nside) ** 2,
3901
- s2.shape[2],
3902
- ],
3903
- ),
3904
- 2,
3905
- )
3906
- S2[j3] = s2
3907
- else:
3908
- if norm == "auto": # Normalize S2
3909
- s2 /= P1_dic[j3]
3910
- """
3911
- S2.append(
3912
- self.backend.bk_expand_dims(s2, off_S2)
3913
- ) # Add a dimension for NS2
3914
- if calc_var:
3915
- VS2.append(
3916
- self.backend.bk_expand_dims(vs2, off_S2)
3917
- ) # Add a dimension for NS2
3918
- """
3919
- #print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
3920
- result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
3921
- if calc_var:
3922
- vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
3923
- #### S1_auto computation
3924
- ### Image 1 : S1 = < M1 >_pix
3925
- # Apply the mask [Nmask, Npix_j3] and average over pixels
3926
- if return_data:
3927
- s1 = M1
3928
- else:
3929
- if calc_var:
3930
- s1, vs1 = self.masked_mean(
3931
- M1, vmask, axis=1, rank=j3, calc_var=True
3932
- ) # [Nbatch, Nmask, Norient3]
3933
- #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
3934
- #vs1=self.backend.bk_flatten(vs1)
3935
- else:
3936
- s1 = self.masked_mean(
3937
- M1, vmask, axis=1, rank=j3
3938
- ) # [Nbatch, Nmask, Norient3]
3939
- #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
3940
-
3941
- if return_data:
3942
- if out_nside is not None and out_nside < nside_j3:
3943
- s1 = self.backend.bk_reduce_mean(
3944
- self.backend.bk_reshape(
3945
- s1,
3946
- [
3947
- s1.shape[0],
3948
- 12 * out_nside**2,
3949
- (nside_j3 // out_nside) ** 2,
3950
- s1.shape[2],
3951
- ],
3952
- ),
3953
- 2,
3954
- )
3955
- S1[j3] = s1
3956
- else:
3957
- ### Normalize S1
3958
- if norm is not None:
3959
- self.div_norm(s1, (P1_dic[j3]) ** 0.5)
3960
- result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
3961
- if calc_var:
3962
- vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
3963
- """
3964
- ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
3965
- S1.append(
3966
- self.backend.bk_expand_dims(s1, off_S2)
3967
- ) # Add a dimension for NS1
3968
- if calc_var:
3969
- VS1.append(
3970
- self.backend.bk_expand_dims(vs1, off_S2)
3971
- ) # Add a dimension for NS1
3972
- """
3973
-
3974
- else: # Cross
3975
- ### Make the convolution I2 * Psi_j3
3976
- conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
3977
- if cmat is not None:
3978
- tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
3979
- conv2 = self.backend.bk_reduce_sum(
3980
- self.backend.bk_reshape(
3981
- cmat[j3] * tmp2,
3982
- [
3983
- tmp2.shape[0],
3984
- cmat[j3].shape[0],
3985
- self.NORIENT,
3986
- self.NORIENT,
3987
- ],
3988
- ),
3989
- 2,
3990
- )
3991
- ### Take the module M2 = |I2 * Psi_j3|
3992
- M2_square = conv2 * self.backend.bk_conjugate(
3993
- conv2
3994
- ) # [Nbatch, Npix_j3, Norient3]
3995
- M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
3996
- # Store M2_j3 in a dictionary
3997
- M2_dic[j3] = M2
3998
-
3999
- ### S2_auto = < M2^2 >_pix
4000
- # Not returned, only for normalization
4001
- if cond_init_P1_dic:
4002
- # Apply the mask [Nmask, Npix_j3] and average over pixels
4003
- if return_data:
4004
- p1 = M1_square
4005
- p2 = M2_square
4006
- else:
4007
- if calc_var:
4008
- p1, vp1 = self.masked_mean(
4009
- M1_square, vmask, axis=1, rank=j3, calc_var=True
4010
- ) # [Nbatch, Nmask, Norient3]
4011
- p2, vp2 = self.masked_mean(
4012
- M2_square, vmask, axis=1, rank=j3, calc_var=True
4013
- ) # [Nbatch, Nmask, Norient3]
4014
- else:
4015
- p1 = self.masked_mean(
4016
- M1_square, vmask, axis=1, rank=j3
4017
- ) # [Nbatch, Nmask, Norient3]
4018
- p2 = self.masked_mean(
4019
- M2_square, vmask, axis=1, rank=j3
4020
- ) # [Nbatch, Nmask, Norient3]
4021
- # We fill P1_dic with S2 for normalisation of S3 and S4
4022
- P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
4023
- P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
4024
-
4025
- ### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
4026
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
4027
- s2 = conv1 * self.backend.bk_conjugate(conv2)
4028
- MX = self.backend.bk_L1(s2)
4029
- # Apply the mask [Nmask, Npix_j3] and average over pixels
4030
- if return_data:
4031
- s2 = s2
4032
- else:
4033
- if calc_var:
4034
- s2, vs2 = self.masked_mean(
4035
- s2, vmask, axis=1, rank=j3, calc_var=True
4036
- )
4037
- else:
4038
- s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
4039
-
4040
- if return_data:
4041
- if out_nside is not None and out_nside < nside_j3:
4042
- s2 = self.backend.bk_reduce_mean(
4043
- self.backend.bk_reshape(
4044
- s2,
4045
- [
4046
- s2.shape[0],
4047
- 12 * out_nside**2,
4048
- (nside_j3 // out_nside) ** 2,
4049
- s2.shape[2],
4050
- ],
4051
- ),
4052
- 2,
4053
- )
4054
- S2[j3] = s2
4055
- else:
4056
- ### Normalize S2_cross
4057
- if norm == "auto":
4058
- s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
4059
-
4060
- ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
4061
- s2 = self.backend.bk_real(s2)
4062
-
4063
- S2.append(
4064
- self.backend.bk_expand_dims(s2, off_S2)
4065
- ) # Add a dimension for NS2
4066
- if calc_var:
4067
- VS2.append(
4068
- self.backend.bk_expand_dims(vs2, off_S2)
4069
- ) # Add a dimension for NS2
4070
-
4071
- #### S1_auto computation
4072
- ### Image 1 : S1 = < M1 >_pix
4073
- # Apply the mask [Nmask, Npix_j3] and average over pixels
4074
- if return_data:
4075
- s1 = MX
4076
- else:
4077
- if calc_var:
4078
- s1, vs1 = self.masked_mean(
4079
- MX, vmask, axis=1, rank=j3, calc_var=True
4080
- ) # [Nbatch, Nmask, Norient3]
4081
- else:
4082
- s1 = self.masked_mean(
4083
- MX, vmask, axis=1, rank=j3
4084
- ) # [Nbatch, Nmask, Norient3]
4085
- if return_data:
4086
- if out_nside is not None and out_nside < nside_j3:
4087
- s1 = self.backend.bk_reduce_mean(
4088
- self.backend.bk_reshape(
4089
- s1,
4090
- [
4091
- s1.shape[0],
4092
- 12 * out_nside**2,
4093
- (nside_j3 // out_nside) ** 2,
4094
- s1.shape[2],
4095
- ],
4096
- ),
4097
- 2,
4098
- )
4099
- S1[j3] = s1
4100
- else:
4101
- ### Normalize S1
4102
- if norm is not None:
4103
- self.div_norm(s1, (P1_dic[j3]) ** 0.5)
4104
- ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
4105
- S1.append(
4106
- self.backend.bk_expand_dims(s1, off_S2)
4107
- ) # Add a dimension for NS1
4108
- if calc_var:
4109
- VS1.append(
4110
- self.backend.bk_expand_dims(vs1, off_S2)
4111
- ) # Add a dimension for NS1
4112
-
4113
- # Initialize dictionaries for |I1*Psi_j| * Psi_j3
4114
- M1convPsi_dic = {}
4115
- if cross:
4116
- # Initialize dictionaries for |I2*Psi_j| * Psi_j3
4117
- M2convPsi_dic = {}
4118
-
4119
- ###### S3
4120
- nside_j2 = nside_j3
4121
- for j2 in range(0,-1): # j3 + 1): # j2 <= j3
4122
- if return_data:
4123
- if S4[j3] is None:
4124
- S4[j3] = {}
4125
- S4[j3][j2] = None
4126
-
4127
- ### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
4128
- if not cross:
4129
- if calc_var:
4130
- s3, vs3 = self._compute_S3(
4131
- j2,
4132
- j3,
4133
- conv1,
4134
- vmask,
4135
- M1_dic,
4136
- M1convPsi_dic,
4137
- calc_var=True,
4138
- cmat2=cmat2,
4139
- ) # [Nbatch, Nmask, Norient3, Norient2]
4140
- else:
4141
- s3 = self._compute_S3(
4142
- j2,
4143
- j3,
4144
- conv1,
4145
- vmask,
4146
- M1_dic,
4147
- M1convPsi_dic,
4148
- return_data=return_data,
4149
- cmat2=cmat2,
4150
- ) # [Nbatch, Nmask, Norient3, Norient2]
4151
-
4152
- if return_data:
4153
- if S3[j3] is None:
4154
- S3[j3] = {}
4155
- if out_nside is not None and out_nside < nside_j2:
4156
- s3 = self.backend.bk_reduce_mean(
4157
- self.backend.bk_reshape(
4158
- s3,
4159
- [
4160
- s3.shape[0],
4161
- 12 * out_nside**2,
4162
- (nside_j2 // out_nside) ** 2,
4163
- s3.shape[2],
4164
- s3.shape[3],
4165
- ],
4166
- ),
4167
- 2,
4168
- )
4169
- S3[j3][j2] = s3
4170
- else:
4171
- ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
4172
- if norm is not None:
4173
- self.div_norm(
4174
- s3,
4175
- (
4176
- self.backend.bk_expand_dims(P1_dic[j2], off_S2)
4177
- * self.backend.bk_expand_dims(P1_dic[j3], -1)
4178
- )
4179
- ** 0.5,
4180
- ) # [Nbatch, Nmask, Norient3, Norient2]
4181
-
4182
- ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
4183
-
4184
- # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
4185
- # s3.shape[2]*s3.shape[3]]))
4186
- S3.append(
4187
- self.backend.bk_expand_dims(s3, off_S3)
4188
- ) # Add a dimension for NS3
4189
- if calc_var:
4190
- VS3.append(
4191
- self.backend.bk_expand_dims(vs3, off_S3)
4192
- ) # Add a dimension for NS3
4193
- # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
4194
- # s3.shape[2]*s3.shape[3]]))
4195
-
4196
- ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
4197
- ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
4198
- else:
4199
- if calc_var:
4200
- s3, vs3 = self._compute_S3(
4201
- j2,
4202
- j3,
4203
- conv1,
4204
- vmask,
4205
- M2_dic,
4206
- M2convPsi_dic,
4207
- calc_var=True,
4208
- cmat2=cmat2,
4209
- )
4210
- s3p, vs3p = self._compute_S3(
4211
- j2,
4212
- j3,
4213
- conv2,
4214
- vmask,
4215
- M1_dic,
4216
- M1convPsi_dic,
4217
- calc_var=True,
4218
- cmat2=cmat2,
4219
- )
4220
- else:
4221
- s3 = self._compute_S3(
4222
- j2,
4223
- j3,
4224
- conv1,
4225
- vmask,
4226
- M2_dic,
4227
- M2convPsi_dic,
4228
- return_data=return_data,
4229
- cmat2=cmat2,
4230
- )
4231
- s3p = self._compute_S3(
4232
- j2,
4233
- j3,
4234
- conv2,
4235
- vmask,
4236
- M1_dic,
4237
- M1convPsi_dic,
4238
- return_data=return_data,
4239
- cmat2=cmat2,
4240
- )
4241
-
4242
- if return_data:
4243
- if S3[j3] is None:
4244
- S3[j3] = {}
4245
- S3P[j3] = {}
4246
- if out_nside is not None and out_nside < nside_j2:
4247
- s3 = self.backend.bk_reduce_mean(
4248
- self.backend.bk_reshape(
4249
- s3,
4250
- [
4251
- s3.shape[0],
4252
- 12 * out_nside**2,
4253
- (nside_j2 // out_nside) ** 2,
4254
- s3.shape[2],
4255
- s3.shape[3],
4256
- ],
4257
- ),
4258
- 2,
4259
- )
4260
- s3p = self.backend.bk_reduce_mean(
4261
- self.backend.bk_reshape(
4262
- s3p,
4263
- [
4264
- s3.shape[0],
4265
- 12 * out_nside**2,
4266
- (nside_j2 // out_nside) ** 2,
4267
- s3.shape[2],
4268
- s3.shape[3],
4269
- ],
4270
- ),
4271
- 2,
4272
- )
4273
- S3[j3][j2] = s3
4274
- S3P[j3][j2] = s3p
4275
- else:
4276
- ### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
4277
- if norm is not None:
4278
- self.div_norm(
4279
- s3,
4280
- (
4281
- self.backend.bk_expand_dims(P2_dic[j2], off_S2)
4282
- * self.backend.bk_expand_dims(P1_dic[j3], -1)
4283
- )
4284
- ** 0.5,
4285
- ) # [Nbatch, Nmask, Norient3, Norient2]
4286
- self.div_norm(
4287
- s3p,
4288
- (
4289
- self.backend.bk_expand_dims(P1_dic[j2], off_S2)
4290
- * self.backend.bk_expand_dims(P2_dic[j3], -1)
4291
- )
4292
- ** 0.5,
4293
- ) # [Nbatch, Nmask, Norient3, Norient2]
4294
-
4295
- ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
4296
-
4297
- # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
4298
- # s3.shape[2]*s3.shape[3]]))
4299
- S3.append(
4300
- self.backend.bk_expand_dims(s3, off_S3)
4301
- ) # Add a dimension for NS3
4302
- if calc_var:
4303
- VS3.append(
4304
- self.backend.bk_expand_dims(vs3, off_S3)
4305
- ) # Add a dimension for NS3
4306
-
4307
- # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
4308
- # s3.shape[2]*s3.shape[3]]))
4309
-
4310
- # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
4311
- # s3.shape[2]*s3.shape[3]]))
4312
- S3P.append(
4313
- self.backend.bk_expand_dims(s3p, off_S3)
4314
- ) # Add a dimension for NS3
4315
- if calc_var:
4316
- VS3P.append(
4317
- self.backend.bk_expand_dims(vs3p, off_S3)
4318
- ) # Add a dimension for NS3
4319
- # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
4320
- # s3.shape[2]*s3.shape[3]]))
4321
-
4322
- ##### S4
4323
- nside_j1 = nside_j2
4324
- for j1 in range(0, j2 + 1): # j1 <= j2
4325
- ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
4326
- if not cross:
4327
- if calc_var:
4328
- s4, vs4 = self._compute_S4(
4329
- j1,
4330
- j2,
4331
- vmask,
4332
- M1convPsi_dic,
4333
- M2convPsi_dic=None,
4334
- calc_var=True,
4335
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4336
- else:
4337
- s4 = self._compute_S4(
4338
- j1,
4339
- j2,
4340
- vmask,
4341
- M1convPsi_dic,
4342
- M2convPsi_dic=None,
4343
- return_data=return_data,
4344
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4345
-
4346
- if return_data:
4347
- if S4[j3][j2] is None:
4348
- S4[j3][j2] = {}
4349
- if out_nside is not None and out_nside < nside_j1:
4350
- s4 = self.backend.bk_reduce_mean(
4351
- self.backend.bk_reshape(
4352
- s4,
4353
- [
4354
- s4.shape[0],
4355
- 12 * out_nside**2,
4356
- (nside_j1 // out_nside) ** 2,
4357
- s4.shape[2],
4358
- s4.shape[3],
4359
- s4.shape[4],
4360
- ],
4361
- ),
4362
- 2,
4363
- )
4364
- S4[j3][j2][j1] = s4
4365
- else:
4366
- ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
4367
- if norm is not None:
4368
- self.div_norm(
4369
- s4,
4370
- (
4371
- self.backend.bk_expand_dims(
4372
- self.backend.bk_expand_dims(
4373
- P1_dic[j1], off_S2
4374
- ),
4375
- off_S2,
4376
- )
4377
- * self.backend.bk_expand_dims(
4378
- self.backend.bk_expand_dims(
4379
- P1_dic[j2], off_S2
4380
- ),
4381
- -1,
4382
- )
4383
- )
4384
- ** 0.5,
4385
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4386
- ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
4387
-
4388
- # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
4389
- # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4390
- S4.append(
4391
- self.backend.bk_expand_dims(s4, off_S4)
4392
- ) # Add a dimension for NS4
4393
- if calc_var:
4394
- # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
4395
- # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4396
- VS4.append(
4397
- self.backend.bk_expand_dims(vs4, off_S4)
4398
- ) # Add a dimension for NS4
4399
-
4400
- ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
4401
- else:
4402
- if calc_var:
4403
- s4, vs4 = self._compute_S4(
4404
- j1,
4405
- j2,
4406
- vmask,
4407
- M1convPsi_dic,
4408
- M2convPsi_dic=M2convPsi_dic,
4409
- calc_var=True,
4410
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4411
- else:
4412
- s4 = self._compute_S4(
4413
- j1,
4414
- j2,
4415
- vmask,
4416
- M1convPsi_dic,
4417
- M2convPsi_dic=M2convPsi_dic,
4418
- return_data=return_data,
4419
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4420
-
4421
- if return_data:
4422
- if S4[j3][j2] is None:
4423
- S4[j3][j2] = {}
4424
- if out_nside is not None and out_nside < nside_j1:
4425
- s4 = self.backend.bk_reduce_mean(
4426
- self.backend.bk_reshape(
4427
- s4,
4428
- [
4429
- s4.shape[0],
4430
- 12 * out_nside**2,
4431
- (nside_j1 // out_nside) ** 2,
4432
- s4.shape[2],
4433
- s4.shape[3],
4434
- s4.shape[4],
4435
- ],
4436
- ),
4437
- 2,
4438
- )
4439
- S4[j3][j2][j1] = s4
4440
- else:
4441
- ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
4442
- if norm is not None:
4443
- self.div_norm(
4444
- s4,
4445
- (
4446
- self.backend.bk_expand_dims(
4447
- self.backend.bk_expand_dims(
4448
- P1_dic[j1], off_S2
4449
- ),
4450
- off_S2,
4451
- )
4452
- * self.backend.bk_expand_dims(
4453
- self.backend.bk_expand_dims(
4454
- P2_dic[j2], off_S2
4455
- ),
4456
- -1,
4457
- )
4458
- )
4459
- ** 0.5,
4460
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4461
- ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
4462
- # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
4463
- # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4464
- S4.append(
4465
- self.backend.bk_expand_dims(s4, off_S4)
4466
- ) # Add a dimension for NS4
4467
- if calc_var:
4468
-
4469
- # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
4470
- # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4471
- VS4.append(
4472
- self.backend.bk_expand_dims(vs4, off_S4)
4473
- ) # Add a dimension for NS4
4474
-
4475
- nside_j1 = nside_j1 // 2
4476
- nside_j2 = nside_j2 // 2
4477
-
4478
- ###### Reshape for next iteration on j3
4479
- ### Image I1,
4480
- # downscale the I1 [Nbatch, Npix_j3]
4481
- if j3 != Jmax - 1:
4482
- I1 = self.smooth(I1, axis=1)
4483
- I1 = self.ud_grade_2(I1, axis=1)
4484
-
4485
- ### Image I2
4486
- if cross:
4487
- I2 = self.smooth(I2, axis=1)
4488
- I2 = self.ud_grade_2(I2, axis=1)
4489
-
4490
- ### Modules
4491
- for j2 in range(0, j3 + 1): # j2 =< j3
4492
- ### Dictionary M1_dic[j2]
4493
- M1_smooth = self.smooth(
4494
- M1_dic[j2], axis=1
4495
- ) # [Nbatch, Npix_j3, Norient3]
4496
- M1_dic[j2] = self.ud_grade_2(
4497
- M1_smooth, axis=1
4498
- ) # [Nbatch, Npix_j3, Norient3]
4499
-
4500
- ### Dictionary M2_dic[j2]
4501
- if cross:
4502
- M2_smooth = self.smooth(
4503
- M2_dic[j2], axis=1
4504
- ) # [Nbatch, Npix_j3, Norient3]
4505
- M2_dic[j2] = self.ud_grade_2(
4506
- M2, axis=1
4507
- ) # [Nbatch, Npix_j3, Norient3]
4508
-
4509
- ### Mask
4510
- vmask = self.ud_grade_2(vmask, axis=1)
4511
-
4512
- if self.mask_thres is not None:
4513
- vmask = self.backend.bk_threshold(vmask, self.mask_thres)
4514
-
4515
- ### NSIDE_j3
4516
- nside_j3 = nside_j3 // 2
4517
-
4518
- ### Store P1_dic and P2_dic in self
4519
- if (norm == "auto") and (self.P1_dic is None):
4520
- self.P1_dic = P1_dic
4521
- if cross:
4522
- self.P2_dic = P2_dic
4523
- """
4524
- Sout=[s0]+S1+S2+S3+S4
4525
-
4526
- if cross:
4527
- Sout=Sout+S3P
4528
- if calc_var:
4529
- SVout=[vs0]+VS1+VS2+VS3+VS4
4530
- if cross:
4531
- VSout=VSout+VS3P
4532
- return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
4533
-
4534
- return self.backend.bk_concat(Sout, 2)
4535
- """
4536
- if calc_var:
4537
- return result,vresult
4538
- else:
4539
- return result
4540
- if calc_var:
4541
- for k in S1:
4542
- print(k.shape,k.dtype)
4543
- for k in S2:
4544
- print(k.shape,k.dtype)
4545
- print(s0.shape,s0.dtype)
4546
- return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
4547
- else:
4548
- return self.backend.bk_concat([s0]+S1+S2,axis=1)
4549
-
4550
- if not return_data:
4551
- S1 = self.backend.bk_concat(S1, 2)
4552
- S2 = self.backend.bk_concat(S2, 2)
4553
- S3 = self.backend.bk_concat(S3, 2)
4554
- S4 = self.backend.bk_concat(S4, 2)
4555
- if cross:
4556
- S3P = self.backend.bk_concat(S3P, 2)
4557
- if calc_var:
4558
- VS1 = self.backend.bk_concat(VS1, 2)
4559
- VS2 = self.backend.bk_concat(VS2, 2)
4560
- VS3 = self.backend.bk_concat(VS3, 2)
4561
- VS4 = self.backend.bk_concat(VS4, 2)
4562
- if cross:
4563
- VS3P = self.backend.bk_concat(VS3P, 2)
4564
- if calc_var:
4565
- if not cross:
4566
- return scat_cov(
4567
- s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
4568
- ), scat_cov(
4569
- vs0,
4570
- VS2,
4571
- VS3,
4572
- VS4,
4573
- s1=VS1,
4574
- backend=self.backend,
4575
- use_1D=self.use_1D,
4576
- )
4577
- else:
4578
- return scat_cov(
4579
- s0,
4580
- S2,
4581
- S3,
4582
- S4,
4583
- s1=S1,
4584
- s3p=S3P,
4585
- backend=self.backend,
4586
- use_1D=self.use_1D,
4587
- ), scat_cov(
4588
- vs0,
4589
- VS2,
4590
- VS3,
4591
- VS4,
4592
- s1=VS1,
4593
- s3p=VS3P,
4594
- backend=self.backend,
4595
- use_1D=self.use_1D,
4596
- )
4597
- else:
4598
- if not cross:
4599
- return scat_cov(
4600
- s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
4601
- )
4602
- else:
4603
- return scat_cov(
4604
- s0,
4605
- S2,
4606
- S3,
4607
- S4,
4608
- s1=S1,
4609
- s3p=S3P,
4610
- backend=self.backend,
4611
- use_1D=self.use_1D,
4612
- )
4613
- def clean_norm(self):
4614
- self.P1_dic = None
4615
- self.P2_dic = None
4616
- return
4617
-
4618
- def _compute_S3(
4619
- self,
4620
- j2,
4621
- j3,
4622
- conv,
4623
- vmask,
4624
- M_dic,
4625
- MconvPsi_dic,
4626
- calc_var=False,
4627
- return_data=False,
4628
- cmat2=None,
4629
- ):
4630
- """
4631
- Compute the S3 coefficients (auto or cross)
4632
- S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
4633
- Parameters
4634
- ----------
4635
- Returns
4636
- -------
4637
- cs3, ss3: real and imag parts of S3 coeff
4638
- """
4639
- ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
4640
- # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
4641
- MconvPsi = self.convol(
4642
- M_dic[j2], axis=1
4643
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
4644
- if cmat2 is not None:
4645
- tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
4646
- MconvPsi = self.backend.bk_reduce_sum(
4647
- self.backend.bk_reshape(
4648
- cmat2[j3][j2] * tmp2,
4649
- [
4650
- tmp2.shape[0],
4651
- cmat2[j3].shape[1],
4652
- self.NORIENT,
4653
- self.NORIENT,
4654
- self.NORIENT,
4655
- ],
4656
- ),
4657
- 3,
4658
- )
4659
-
4660
- # Store it so we can use it in S4 computation
4661
- MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
4662
-
4663
- ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
4664
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
4665
- # cconv, sconv are [Nbatch, Npix_j3, Norient3]
4666
- if self.use_1D:
4667
- s3 = conv * self.backend.bk_conjugate(MconvPsi)
4668
- else:
4669
- s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
4670
- MconvPsi
4671
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
4672
-
4673
- ### Apply the mask [Nmask, Npix_j3] and sum over pixels
4674
- if return_data:
4675
- return s3
4676
- else:
4677
- if calc_var:
4678
- s3, vs3 = self.masked_mean(
4679
- s3, vmask, axis=1, rank=j2, calc_var=True
4680
- ) # [Nbatch, Nmask, Norient3, Norient2]
4681
- return s3, vs3
4682
- else:
4683
- s3 = self.masked_mean(
4684
- s3, vmask, axis=1, rank=j2
4685
- ) # [Nbatch, Nmask, Norient3, Norient2]
4686
- return s3
4687
-
4688
- def _compute_S4(
4689
- self,
4690
- j1,
4691
- j2,
4692
- vmask,
4693
- M1convPsi_dic,
4694
- M2convPsi_dic=None,
4695
- calc_var=False,
4696
- return_data=False,
4697
- ):
4698
- #### Simplify notations
4699
- M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
4700
-
4701
- # Auto or Cross coefficients
4702
- if M2convPsi_dic is None: # Auto
4703
- M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
4704
- else: # Cross
4705
- M2 = M2convPsi_dic[j2]
4706
-
4707
- ### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
4708
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
4709
- if self.use_1D:
4710
- s4 = M1 * self.backend.bk_conjugate(M2)
4711
- else:
4712
- s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
4713
- self.backend.bk_expand_dims(M2, -1)
4714
- ) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
4715
-
4716
- ### Apply the mask and sum over pixels
4717
- if return_data:
4718
- return s4
4719
- else:
4720
- if calc_var:
4721
- s4, vs4 = self.masked_mean(
4722
- s4, vmask, axis=1, rank=j2, calc_var=True
4723
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4724
- return s4, vs4
4725
- else:
4726
- s4 = self.masked_mean(
4727
- s4, vmask, axis=1, rank=j2
4728
- ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4729
- return s4
4730
-
4731
- def computer_filter(self,M,N,J,L):
4732
- '''
4733
- This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4734
- Done by Sihao Cheng and Rudy Morel.
4735
- '''
4736
-
4737
- filter = np.zeros([J, L, M, N],dtype='complex64')
4738
-
4739
- slant=4.0 / L
4740
-
4741
- for j in range(J):
4742
-
4743
- for l in range(L):
4744
-
4745
- theta = (int(L-L/2-1)-l) * np.pi / L
4746
- sigma = 0.8 * 2**j
4747
- xi = 3.0 / 4.0 * np.pi /2**j
4748
-
4749
- R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
4750
- R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
4751
3719
  D = np.array([[1, 0], [0, slant * slant]])
4752
- curv = np.matmul(R, np.matmul(D, R_inv)) / ( 2 * sigma * sigma)
4753
-
3720
+ curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
3721
+
4754
3722
  gab = np.zeros((M, N), np.complex128)
4755
- xx = np.empty((2,2, M, N))
4756
- yy = np.empty((2,2, M, N))
4757
-
3723
+ xx = np.empty((2, 2, M, N))
3724
+ yy = np.empty((2, 2, M, N))
3725
+
4758
3726
  for ii, ex in enumerate([-1, 0]):
4759
3727
  for jj, ey in enumerate([-1, 0]):
4760
- xx[ii,jj], yy[ii,jj] = np.mgrid[
4761
- ex * M : M + ex * M,
4762
- ey * N : N + ey * N]
4763
-
4764
- arg = -(curv[0, 0] * xx * xx + (curv[0, 1] + curv[1, 0]) * xx * yy + curv[1, 1] * yy * yy)
4765
- argi = arg + 1.j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
4766
-
4767
- gabi = np.exp(argi).sum((0,1))
4768
- gab = np.exp(arg).sum((0,1))
4769
-
3728
+ xx[ii, jj], yy[ii, jj] = np.mgrid[
3729
+ ex * M : M + ex * M, ey * N : N + ey * N
3730
+ ]
3731
+
3732
+ arg = -(
3733
+ curv[0, 0] * xx * xx
3734
+ + (curv[0, 1] + curv[1, 0]) * xx * yy
3735
+ + curv[1, 1] * yy * yy
3736
+ )
3737
+ argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
3738
+
3739
+ gabi = np.exp(argi).sum((0, 1))
3740
+ gab = np.exp(arg).sum((0, 1))
3741
+
4770
3742
  norm_factor = 2 * np.pi * sigma * sigma / slant
4771
-
3743
+
4772
3744
  gab = gab / norm_factor
4773
-
3745
+
4774
3746
  gabi = gabi / norm_factor
4775
3747
 
4776
3748
  K = gabi.sum() / gab.sum()
4777
3749
 
4778
3750
  # Apply the Gaussian
4779
- filter[j, l] = np.fft.fft2(gabi-K*gab)
4780
- filter[j,l,0,0]=0.0
4781
-
3751
+ filter[j, ell] = np.fft.fft2(gabi - K * gab)
3752
+ filter[j, ell, 0, 0] = 0.0
3753
+
4782
3754
  return self.backend.bk_cast(filter)
4783
-
3755
+
4784
3756
  # ------------------------------------------------------------------------------------------
4785
3757
  #
4786
- # utility functions
3758
+ # utility functions
4787
3759
  #
4788
3760
  # ------------------------------------------------------------------------------------------
4789
- def cut_high_k_off(self,data_f, dx, dy):
4790
- '''
3761
+ def cut_high_k_off(self, data_f, dx, dy):
3762
+ """
4791
3763
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4792
3764
  Done by Sihao Cheng and Rudy Morel.
4793
- '''
4794
-
4795
- if self.backend.BACKEND=='torch':
4796
- if_xodd = (data_f.shape[-2]%2==1)
4797
- if_yodd = (data_f.shape[-1]%2==1)
3765
+ """
3766
+
3767
+ if self.backend.BACKEND == "torch":
3768
+ if_xodd = data_f.shape[-2] % 2 == 1
3769
+ if_yodd = data_f.shape[-1] % 2 == 1
4798
3770
  result = self.backend.backend.cat(
4799
- (self.backend.backend.cat(
4800
- ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4801
- ), -2),
4802
- self.backend.backend.cat(
4803
- ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4804
- ), -2)
4805
- ),-1)
3771
+ (
3772
+ self.backend.backend.cat(
3773
+ (
3774
+ data_f[..., : dx + if_xodd, : dy + if_yodd],
3775
+ data_f[..., -dx:, : dy + if_yodd],
3776
+ ),
3777
+ -2,
3778
+ ),
3779
+ self.backend.backend.cat(
3780
+ (data_f[..., : dx + if_xodd, -dy:], data_f[..., -dx:, -dy:]), -2
3781
+ ),
3782
+ ),
3783
+ -1,
3784
+ )
4806
3785
  return result
4807
3786
  else:
4808
3787
  # Check if the last two dimensions are odd
4809
- if_xodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-2] % 2 == 1, self.backend.backend.int32)
4810
- if_yodd = self.backend.backend.cast(self.backend.backend.shape(data_f)[-1] % 2 == 1, self.backend.backend.int32)
3788
+ if_xodd = self.backend.backend.cast(
3789
+ self.backend.backend.shape(data_f)[-2] % 2 == 1,
3790
+ self.backend.backend.int32,
3791
+ )
3792
+ if_yodd = self.backend.backend.cast(
3793
+ self.backend.backend.shape(data_f)[-1] % 2 == 1,
3794
+ self.backend.backend.int32,
3795
+ )
4811
3796
 
4812
3797
  # Extract four regions
4813
- top_left = data_f[..., :dx+if_xodd, :dy+if_yodd]
4814
- top_right = data_f[..., -dx:, :dy+if_yodd]
4815
- bottom_left = data_f[..., :dx+if_xodd, -dy:]
3798
+ top_left = data_f[..., : dx + if_xodd, : dy + if_yodd]
3799
+ top_right = data_f[..., -dx:, : dy + if_yodd]
3800
+ bottom_left = data_f[..., : dx + if_xodd, -dy:]
4816
3801
  bottom_right = data_f[..., -dx:, -dy:]
4817
3802
 
4818
3803
  # Concatenate along the last two dimensions
@@ -4827,70 +3812,74 @@ class funct(FOC.FoCUS):
4827
3812
  # utility functions for computing scattering coef and covariance
4828
3813
  #
4829
3814
  # ---------------------------------------------------------------------------
4830
-
4831
- def get_dxdy(self, j,M,N):
4832
- '''
3815
+
3816
+ def get_dxdy(self, j, M, N):
3817
+ """
4833
3818
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4834
3819
  Done by Sihao Cheng and Rudy Morel.
4835
- '''
4836
- dx = int(max( 8, min( np.ceil(M/2**j), M//2 ) ))
4837
- dy = int(max( 8, min( np.ceil(N/2**j), N//2 ) ))
3820
+ """
3821
+ dx = int(max(8, min(np.ceil(M / 2**j), M // 2)))
3822
+ dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
4838
3823
  return dx, dy
4839
-
4840
-
4841
3824
 
4842
- def get_edge_masks(self,M, N, J, d0=1):
4843
- '''
3825
+ def get_edge_masks(self, M, N, J, d0=1):
3826
+ """
4844
3827
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4845
3828
  Done by Sihao Cheng and Rudy Morel.
4846
- '''
3829
+ """
4847
3830
  edge_masks = np.empty((J, M, N))
4848
- X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing='ij')
3831
+ X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
4849
3832
  for j in range(J):
4850
- edge_dx = min(M//4, 2**j*d0)
4851
- edge_dy = min(N//4, 2**j*d0)
4852
- edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
4853
- edge_masks = edge_masks[:,None,:,:]
4854
- edge_masks = edge_masks / edge_masks.mean((-2,-1))[:,:,None,None]
3833
+ edge_dx = min(M // 4, 2**j * d0)
3834
+ edge_dy = min(N // 4, 2**j * d0)
3835
+ edge_masks[j] = (
3836
+ (X >= edge_dx)
3837
+ * (X <= M - edge_dx)
3838
+ * (Y >= edge_dy)
3839
+ * (Y <= N - edge_dy)
3840
+ )
3841
+ edge_masks = edge_masks[:, None, :, :]
3842
+ edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
4855
3843
  return self.backend.bk_cast(edge_masks)
4856
-
3844
+
4857
3845
  # ---------------------------------------------------------------------------
4858
3846
  #
4859
3847
  # scattering cov
4860
3848
  #
4861
3849
  # ---------------------------------------------------------------------------
4862
3850
  def scattering_cov(
4863
- self, data,
3851
+ self,
3852
+ data,
4864
3853
  data2=None,
4865
3854
  Jmax=None,
4866
- if_large_batch=False,
4867
- S4_criteria=None,
4868
- use_ref=False,
4869
- normalization='S2',
3855
+ if_large_batch=False,
3856
+ S4_criteria=None,
3857
+ use_ref=False,
3858
+ normalization="S2",
4870
3859
  edge=False,
4871
- pseudo_coef=1,
4872
- get_variance=False,
3860
+ pseudo_coef=1,
3861
+ get_variance=False,
4873
3862
  ref_sigma=None,
4874
- iso_ang=False
3863
+ iso_ang=False,
4875
3864
  ):
4876
- '''
3865
+ """
4877
3866
  Calculates the scattering correlations for a batch of images, including:
4878
-
3867
+
4879
3868
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
4880
3869
  Done by Sihao Cheng and Rudy Morel.
4881
-
4882
- orig. x orig.:
3870
+
3871
+ orig. x orig.:
4883
3872
  P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
4884
- orig. x modulus:
3873
+ orig. x modulus:
4885
3874
  C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
4886
3875
  when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
4887
3876
  when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
4888
- modulus x modulus:
3877
+ modulus x modulus:
4889
3878
  C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
4890
3879
  C11 = C11_pre_norm / factor
4891
3880
  when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
4892
3881
  when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
4893
- modulus x modulus (auto):
3882
+ modulus x modulus (auto):
4894
3883
  P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
4895
3884
  Parameters
4896
3885
  ----------
@@ -4900,7 +3889,7 @@ class funct(FOC.FoCUS):
4900
3889
  It is recommended to use "False" unless one meets a memory issue
4901
3890
  S4_criteria : str or None (=None)
4902
3891
  Only C11 coefficients that satisfy this criteria will be computed.
4903
- Any expressions of j1, j2, and j3 that can be evaluated as a Bool
3892
+ Any expressions of j1, j2, and j3 that can be evaluated as a Bool
4904
3893
  is accepted. The default "None" corresponds to "j1 <= j2 <= j3".
4905
3894
  use_ref : Bool (=False)
4906
3895
  When normalizing, whether or not to use the normalization factor
@@ -4914,7 +3903,7 @@ class funct(FOC.FoCUS):
4914
3903
  If true, the edge region with a width of roughly the size of the largest
4915
3904
  wavelet involved is excluded when taking the global average to obtain
4916
3905
  the scattering coefficients.
4917
-
3906
+
4918
3907
  Returns
4919
3908
  -------
4920
3909
  'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
@@ -4932,29 +3921,34 @@ class funct(FOC.FoCUS):
4932
3921
  j1 <= j3 are set to np.nan and not computed.
4933
3922
  'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
4934
3923
  'P11' averaged over l1 while keeping l2-l1 constant.
4935
- '''
3924
+ """
4936
3925
  if S4_criteria is None:
4937
- S4_criteria = 'j2>=j1'
4938
-
3926
+ S4_criteria = "j2>=j1"
3927
+
3928
+ if self.all_bk_type == "float32":
3929
+ C_ONE = np.complex64(1.0)
3930
+ else:
3931
+ C_ONE = np.complex128(1.0)
3932
+
4939
3933
  # determine jmax and nside corresponding to the input map
4940
3934
  im_shape = data.shape
4941
3935
  if self.use_2D:
4942
3936
  if len(data.shape) == 2:
4943
3937
  nside = np.min([im_shape[0], im_shape[1]])
4944
- M,N = im_shape[0],im_shape[1]
4945
- N_image = 1
3938
+ M, N = im_shape[0], im_shape[1]
3939
+ N_image = 1
4946
3940
  N_image2 = 1
4947
3941
  else:
4948
3942
  nside = np.min([im_shape[1], im_shape[2]])
4949
- M,N = im_shape[1],im_shape[2]
3943
+ M, N = im_shape[1], im_shape[2]
4950
3944
  N_image = data.shape[0]
4951
3945
  if data2 is not None:
4952
3946
  N_image2 = data2.shape[0]
4953
- J = int(np.log(nside) / np.log(2))-1 # Number of j scales
3947
+ J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4954
3948
  elif self.use_1D:
4955
3949
  if len(data.shape) == 2:
4956
3950
  npix = int(im_shape[1]) # Number of pixels
4957
- N_image = 1
3951
+ N_image = 1
4958
3952
  N_image2 = 1
4959
3953
  else:
4960
3954
  npix = int(im_shape[0]) # Number of pixels
@@ -4964,12 +3958,12 @@ class funct(FOC.FoCUS):
4964
3958
 
4965
3959
  nside = int(npix)
4966
3960
 
4967
- J = int(np.log(nside) / np.log(2))-1 # Number of j scales
3961
+ J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
4968
3962
  else:
4969
3963
  if len(data.shape) == 2:
4970
3964
  npix = int(im_shape[1]) # Number of pixels
4971
- N_image = 1
4972
- N_image2 = 1
3965
+ N_image = 1
3966
+ N_image2 = 1
4973
3967
  else:
4974
3968
  npix = int(im_shape[0]) # Number of pixels
4975
3969
  N_image = data.shape[0]
@@ -4979,978 +3973,1715 @@ class funct(FOC.FoCUS):
4979
3973
  nside = int(np.sqrt(npix // 12))
4980
3974
 
4981
3975
  J = int(np.log(nside) / np.log(2)) # Number of j scales
4982
-
3976
+
4983
3977
  if Jmax is None:
4984
3978
  Jmax = J # Number of steps for the loop on scales
4985
- if Jmax>J:
4986
- print('==========\n\n')
4987
- print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
4988
- print('\n\n==========')
4989
-
4990
- L=self.NORIENT
4991
- norm_factor_S3=1.0
4992
-
4993
- if self.backend.BACKEND=='torch':
4994
- if (M,N,J,L) not in self.filters_set:
4995
- self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
4996
-
4997
- filters_set = self.filters_set[(M,N,J,L)]
4998
-
4999
- #weight = self.weight
3979
+ if Jmax > J:
3980
+ print("==========\n\n")
3981
+ print(
3982
+ "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
3983
+ )
3984
+ print("\n\n==========")
3985
+
3986
+ L = self.NORIENT
3987
+ norm_factor_S3 = 1.0
3988
+
3989
+ if self.backend.BACKEND == "torch":
3990
+ if (M, N, J, L) not in self.filters_set:
3991
+ self.filters_set[(M, N, J, L)] = self.computer_filter(
3992
+ M, N, J, L
3993
+ ) # self.computer_filter(M,N,J,L)
3994
+
3995
+ filters_set = self.filters_set[(M, N, J, L)]
3996
+
3997
+ # weight = self.weight
5000
3998
  if use_ref:
5001
- if normalization=='S2':
3999
+ if normalization == "S2":
5002
4000
  ref_S2 = self.ref_scattering_cov_S2
5003
- else:
5004
- ref_P11 = self.ref_scattering_cov['P11']
4001
+ else:
4002
+ ref_P11 = self.ref_scattering_cov["P11"]
5005
4003
 
5006
4004
  # convert numpy array input into self.backend.bk_ tensors
5007
4005
  data = self.backend.bk_cast(data)
5008
- data_f = self.backend.bk_fftn(data, dim=(-2,-1))
4006
+ data_f = self.backend.bk_fftn(data, dim=(-2, -1))
5009
4007
  if data2 is not None:
5010
4008
  data2 = self.backend.bk_cast(data2)
5011
- data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
5012
-
5013
- # initialize tensors for scattering coefficients
5014
- S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5015
- S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5016
-
5017
- Ndata_S3 = J*(J+1)//2
5018
- Ndata_S4 = J*(J+1)*(J+2)//6
5019
- J_S4={}
5020
-
5021
- S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4009
+ data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
4010
+
4011
+ # initialize tensors for scattering coefficients
4012
+ S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4013
+ S1 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4014
+
4015
+ Ndata_S3 = J * (J + 1) // 2
4016
+ Ndata_S4 = J * (J + 1) * (J + 2) // 6
4017
+ J_S4 = {}
4018
+
4019
+ S3 = self.backend.bk_zeros((N_image, Ndata_S3, L, L), dtype=data_f.dtype)
5022
4020
  if data2 is not None:
5023
- S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5024
- S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5025
- S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5026
-
4021
+ S3p = self.backend.bk_zeros(
4022
+ (N_image, Ndata_S3, L, L), dtype=data_f.dtype
4023
+ )
4024
+ S4_pre_norm = self.backend.bk_zeros(
4025
+ (N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
4026
+ )
4027
+ S4 = self.backend.bk_zeros((N_image, Ndata_S4, L, L, L), dtype=data_f.dtype)
4028
+
5027
4029
  # variance
5028
4030
  if get_variance:
5029
- S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5030
- S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
5031
- S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4031
+ S2_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4032
+ S1_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
4033
+ S3_sigma = self.backend.bk_zeros(
4034
+ (N_image, Ndata_S3, L, L), dtype=data_f.dtype
4035
+ )
5032
4036
  if data2 is not None:
5033
- S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5034
- S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5035
-
4037
+ S3p_sigma = self.backend.bk_zeros(
4038
+ (N_image, Ndata_S3, L, L), dtype=data_f.dtype
4039
+ )
4040
+ S4_sigma = self.backend.bk_zeros(
4041
+ (N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
4042
+ )
4043
+
5036
4044
  if iso_ang:
5037
- S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5038
- S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4045
+ S3_iso = self.backend.bk_zeros(
4046
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4047
+ )
4048
+ S4_iso = self.backend.bk_zeros(
4049
+ (N_image, Ndata_S4, L, L), dtype=data_f.dtype
4050
+ )
5039
4051
  if get_variance:
5040
- S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5041
- S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4052
+ S3_sigma_iso = self.backend.bk_zeros(
4053
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4054
+ )
4055
+ S4_sigma_iso = self.backend.bk_zeros(
4056
+ (N_image, Ndata_S4, L, L), dtype=data_f.dtype
4057
+ )
5042
4058
  if data2 is not None:
5043
- S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
4059
+ S3p_iso = self.backend.bk_zeros(
4060
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4061
+ )
5044
4062
  if get_variance:
5045
- S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5046
-
4063
+ S3p_sigma_iso = self.backend.bk_zeros(
4064
+ (N_image, Ndata_S3, L), dtype=data_f.dtype
4065
+ )
4066
+
5047
4067
  #
5048
- if edge:
5049
- if (M,N,J) not in self.edge_masks:
5050
- self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5051
- edge_mask=self.edge_masks[(M,N,J)]
5052
- else:
4068
+ if edge:
4069
+ if (M, N, J) not in self.edge_masks:
4070
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4071
+ edge_mask = self.edge_masks[(M, N, J)]
4072
+ else:
5053
4073
  edge_mask = 1
5054
-
4074
+
5055
4075
  # calculate scattering fields
5056
4076
  if data2 is None:
5057
4077
  if self.use_2D:
5058
4078
  if len(data.shape) == 2:
5059
4079
  I1 = self.backend.bk_ifftn(
5060
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4080
+ data_f[None, None, None, :, :]
4081
+ * filters_set[None, :J, :, :, :],
4082
+ dim=(-2, -1),
5061
4083
  ).abs()
5062
4084
  else:
5063
4085
  I1 = self.backend.bk_ifftn(
5064
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4086
+ data_f[:, None, None, :, :]
4087
+ * filters_set[None, :J, :, :, :],
4088
+ dim=(-2, -1),
5065
4089
  ).abs()
5066
4090
  elif self.use_1D:
5067
4091
  if len(data.shape) == 1:
5068
4092
  I1 = self.backend.bk_ifftn(
5069
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4093
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4094
+ dim=(-1),
5070
4095
  ).abs()
5071
4096
  else:
5072
4097
  I1 = self.backend.bk_ifftn(
5073
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4098
+ data_f[:, None, None, :] * filters_set[None, :J, :, :],
4099
+ dim=(-1),
5074
4100
  ).abs()
5075
4101
  else:
5076
- print('todo')
5077
-
5078
- S2 = (I1**2 * edge_mask).mean((-2,-1))
5079
- S1 = (I1 * edge_mask).mean((-2,-1))
4102
+ print("todo")
4103
+
4104
+ S2 = (I1**2 * edge_mask).mean((-2, -1))
4105
+ S1 = (I1 * edge_mask).mean((-2, -1))
5080
4106
 
5081
4107
  if get_variance:
5082
- S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5083
- S1_sigma = (I1 * edge_mask).std((-2,-1))
5084
-
4108
+ S2_sigma = (I1**2 * edge_mask).std((-2, -1))
4109
+ S1_sigma = (I1 * edge_mask).std((-2, -1))
4110
+
5085
4111
  else:
5086
4112
  if self.use_2D:
5087
4113
  if len(data.shape) == 2:
5088
4114
  I1 = self.backend.bk_ifftn(
5089
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4115
+ data_f[None, None, None, :, :]
4116
+ * filters_set[None, :J, :, :, :],
4117
+ dim=(-2, -1),
5090
4118
  )
5091
4119
  I2 = self.backend.bk_ifftn(
5092
- data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4120
+ data2_f[None, None, None, :, :]
4121
+ * filters_set[None, :J, :, :, :],
4122
+ dim=(-2, -1),
5093
4123
  )
5094
4124
  else:
5095
4125
  I1 = self.backend.bk_ifftn(
5096
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4126
+ data_f[:, None, None, :, :]
4127
+ * filters_set[None, :J, :, :, :],
4128
+ dim=(-2, -1),
5097
4129
  )
5098
4130
  I2 = self.backend.bk_ifftn(
5099
- data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4131
+ data2_f[:, None, None, :, :]
4132
+ * filters_set[None, :J, :, :, :],
4133
+ dim=(-2, -1),
5100
4134
  )
5101
4135
  elif self.use_1D:
5102
4136
  if len(data.shape) == 1:
5103
4137
  I1 = self.backend.bk_ifftn(
5104
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4138
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4139
+ dim=(-1),
5105
4140
  )
5106
4141
  I2 = self.backend.bk_ifftn(
5107
- data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4142
+ data2_f[None, None, None, :] * filters_set[None, :J, :, :],
4143
+ dim=(-1),
5108
4144
  )
5109
4145
  else:
5110
4146
  I1 = self.backend.bk_ifftn(
5111
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4147
+ data_f[:, None, None, :] * filters_set[None, :J, :, :],
4148
+ dim=(-1),
5112
4149
  )
5113
4150
  I2 = self.backend.bk_ifftn(
5114
- data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4151
+ data2_f[:, None, None, :] * filters_set[None, :J, :, :],
4152
+ dim=(-1),
5115
4153
  )
5116
4154
  else:
5117
- print('todo')
5118
-
5119
- I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5120
-
5121
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4155
+ print("todo")
4156
+
4157
+ I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
4158
+
4159
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5122
4160
  if get_variance:
5123
- S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5124
-
5125
- I1=self.backend.bk_L1(I1)
5126
-
5127
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4161
+ S2_sigma = self.backend.bk_reduce_std(
4162
+ (I1 * edge_mask), axis=(-2, -1)
4163
+ )
4164
+
4165
+ I1 = self.backend.bk_L1(I1)
4166
+
4167
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5128
4168
 
5129
4169
  if get_variance:
5130
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5131
-
5132
- I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5133
-
4170
+ S1_sigma = self.backend.bk_reduce_std(
4171
+ (I1 * edge_mask), axis=(-2, -1)
4172
+ )
4173
+
4174
+ I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4175
+
5134
4176
  if pseudo_coef != 1:
5135
4177
  I1 = I1**pseudo_coef
5136
-
5137
- Ndata_S3=0
5138
- Ndata_S4=0
5139
-
4178
+
4179
+ Ndata_S3 = 0
4180
+ Ndata_S4 = 0
4181
+
5140
4182
  # calculate the covariance and correlations of the scattering fields
5141
4183
  # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5142
- for j3 in range(0,J):
5143
- J_S4[j3]=Ndata_S4
5144
-
5145
- dx3, dy3 = self.get_dxdy(j3,M,N)
5146
- I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
4184
+ for j3 in range(0, J):
4185
+ J_S4[j3] = Ndata_S4
4186
+
4187
+ dx3, dy3 = self.get_dxdy(j3, M, N)
4188
+ I1_f_small = self.cut_high_k_off(
4189
+ I1_f[:, : j3 + 1], dx3, dy3
4190
+ ) # Nimage, J, L, x, y
5147
4191
  data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5148
4192
  if data2 is not None:
5149
4193
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5150
4194
  if edge:
5151
- I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5152
- data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
4195
+ I1_small = self.backend.bk_ifftn(
4196
+ I1_f_small, dim=(-2, -1), norm="ortho"
4197
+ )
4198
+ data_small = self.backend.bk_ifftn(
4199
+ data_f_small, dim=(-2, -1), norm="ortho"
4200
+ )
5153
4201
  if data2 is not None:
5154
- data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
5155
-
5156
- wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
4202
+ data2_small = self.backend.bk_ifftn(
4203
+ data2_f_small, dim=(-2, -1), norm="ortho"
4204
+ )
4205
+
4206
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5157
4207
  _, M3, N3 = wavelet_f3.shape
5158
4208
  wavelet_f3_squared = wavelet_f3**2
5159
- edge_dx = min(4, int(2**j3*dx3*2/M))
5160
- edge_dy = min(4, int(2**j3*dy3*2/N))
5161
-
4209
+ edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4210
+ edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
4211
+
5162
4212
  # a normalization change due to the cutoff of frequency space
5163
- fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5164
- for j2 in range(0,j3+1):
5165
- I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5166
- I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
4213
+ fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
4214
+ for j2 in range(0, j3 + 1):
4215
+ I1_f2_wf3_small = I1_f_small[:, j2].view(
4216
+ N_image, L, 1, M3, N3
4217
+ ) * wavelet_f3.view(1, 1, L, M3, N3)
4218
+ I1_f2_wf3_2_small = I1_f_small[:, j2].view(
4219
+ N_image, L, 1, M3, N3
4220
+ ) * wavelet_f3_squared.view(1, 1, L, M3, N3)
5167
4221
  if edge:
5168
- I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5169
- I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
4222
+ I12_w3_small = self.backend.bk_ifftn(
4223
+ I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
4224
+ )
4225
+ I12_w3_2_small = self.backend.bk_ifftn(
4226
+ I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
4227
+ )
5170
4228
  if use_ref:
5171
- if normalization=='P11':
5172
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5173
- if normalization=='S2':
5174
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
4229
+ if normalization == "P11":
4230
+ norm_factor_S3 = (
4231
+ ref_S2[:, None, j3, :]
4232
+ * ref_P11[:, j2, j3, :, :] ** pseudo_coef
4233
+ ) ** 0.5
4234
+ if normalization == "S2":
4235
+ norm_factor_S3 = (
4236
+ ref_S2[:, None, j3, :]
4237
+ * ref_S2[:, j2, :, None] ** pseudo_coef
4238
+ ) ** 0.5
5175
4239
  else:
5176
- if normalization=='P11':
4240
+ if normalization == "P11":
5177
4241
  # [N_image,l2,l3,x,y]
5178
- P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5179
- norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5180
- if normalization=='S2':
5181
- norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
4242
+ P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
4243
+ (-2, -1)
4244
+ ) * fft_factor
4245
+ norm_factor_S3 = (
4246
+ S2[:, None, j3, :] * P11_temp**pseudo_coef
4247
+ ) ** 0.5
4248
+ if normalization == "S2":
4249
+ norm_factor_S3 = (
4250
+ S2[:, None, j3, :] * S2[:, j2, :, None] ** pseudo_coef
4251
+ ) ** 0.5
5182
4252
 
5183
4253
  if not edge:
5184
- S3[:,Ndata_S3,:,:] = (
5185
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5186
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5187
-
4254
+ S3[:, Ndata_S3, :, :] = (
4255
+ (
4256
+ data_f_small.view(N_image, 1, 1, M3, N3)
4257
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4258
+ ).mean((-2, -1))
4259
+ * fft_factor
4260
+ / norm_factor_S3
4261
+ )
4262
+
5188
4263
  if get_variance:
5189
- S3_sigma[:,Ndata_S3,:,:] = (
5190
- data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5191
- ).std((-2,-1)) * fft_factor / norm_factor_S3
4264
+ S3_sigma[:, Ndata_S3, :, :] = (
4265
+ (
4266
+ data_f_small.view(N_image, 1, 1, M3, N3)
4267
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4268
+ ).std((-2, -1))
4269
+ * fft_factor
4270
+ / norm_factor_S3
4271
+ )
5192
4272
  else:
5193
- S3[:,Ndata_S3,:,:] = (
5194
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5195
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
4273
+ S3[:, Ndata_S3, :, :] = (
4274
+ (
4275
+ data_small.view(N_image, 1, 1, M3, N3)
4276
+ * self.backend.bk_conjugate(I12_w3_small)
4277
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy].mean(
4278
+ (-2, -1)
4279
+ )
4280
+ * fft_factor
4281
+ / norm_factor_S3
4282
+ )
5196
4283
  if get_variance:
5197
- S3_sigma[:,Ndata_S3,:,:] = (
5198
- data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5199
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
4284
+ S3_sigma[:, Ndata_S3, :, :] = (
4285
+ (
4286
+ data_small.view(N_image, 1, 1, M3, N3)
4287
+ * self.backend.bk_conjugate(I12_w3_small)
4288
+ )[
4289
+ ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4290
+ ].std(
4291
+ (-2, -1)
4292
+ )
4293
+ * fft_factor
4294
+ / norm_factor_S3
4295
+ )
5200
4296
  if data2 is not None:
5201
4297
  if not edge:
5202
- S3p[:,Ndata_S3,:,:] = (
5203
- data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5204
- ).mean((-2,-1)) * fft_factor / norm_factor_S3
5205
-
4298
+ S3p[:, Ndata_S3, :, :] = (
4299
+ (
4300
+ data2_f_small.view(N_image2, 1, 1, M3, N3)
4301
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4302
+ ).mean((-2, -1))
4303
+ * fft_factor
4304
+ / norm_factor_S3
4305
+ )
4306
+
5206
4307
  if get_variance:
5207
- S3p_sigma[:,Ndata_S3,:,:] = (
5208
- data2_f_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5209
- ).std((-2,-1)) * fft_factor / norm_factor_S3
4308
+ S3p_sigma[:, Ndata_S3, :, :] = (
4309
+ (
4310
+ data2_f_small.view(N_image2, 1, 1, M3, N3)
4311
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
4312
+ ).std((-2, -1))
4313
+ * fft_factor
4314
+ / norm_factor_S3
4315
+ )
5210
4316
  else:
5211
- S3p[:,Ndata_S3,:,:] = (
5212
- data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5213
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
4317
+ S3p[:, Ndata_S3, :, :] = (
4318
+ (
4319
+ data2_small.view(N_image2, 1, 1, M3, N3)
4320
+ * self.backend.bk_conjugate(I12_w3_small)
4321
+ )[
4322
+ ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4323
+ ].mean(
4324
+ (-2, -1)
4325
+ )
4326
+ * fft_factor
4327
+ / norm_factor_S3
4328
+ )
5214
4329
  if get_variance:
5215
- S3p_sigma[:,Ndata_S3,:,:] = (
5216
- data2_small.view(N_image2,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5217
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5218
- Ndata_S3+=1
4330
+ S3p_sigma[:, Ndata_S3, :, :] = (
4331
+ (
4332
+ data2_small.view(N_image2, 1, 1, M3, N3)
4333
+ * self.backend.bk_conjugate(I12_w3_small)
4334
+ )[
4335
+ ...,
4336
+ edge_dx : M3 - edge_dx,
4337
+ edge_dy : N3 - edge_dy,
4338
+ ].std(
4339
+ (-2, -1)
4340
+ )
4341
+ * fft_factor
4342
+ / norm_factor_S3
4343
+ )
4344
+ Ndata_S3 += 1
5219
4345
  if j2 <= j3:
5220
- beg_n=Ndata_S4
5221
- for j1 in range(0, j2+1):
4346
+ beg_n = Ndata_S4
4347
+ for j1 in range(0, j2 + 1):
5222
4348
  if eval(S4_criteria):
5223
4349
  if not edge:
5224
4350
  if not if_large_batch:
5225
4351
  # [N_image,l1,l2,l3,x,y]
5226
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5227
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5228
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5229
- ).mean((-2,-1)) * fft_factor
4352
+ S4_pre_norm[:, Ndata_S4, :, :, :] = (
4353
+ I1_f_small[:, j1].view(
4354
+ N_image, L, 1, 1, M3, N3
4355
+ )
4356
+ * self.backend.bk_conjugate(
4357
+ I1_f2_wf3_2_small.view(
4358
+ N_image, 1, L, L, M3, N3
4359
+ )
4360
+ )
4361
+ ).mean((-2, -1)) * fft_factor
5230
4362
  if get_variance:
5231
- S4_sigma[:,Ndata_S4,:,:,:] = (
5232
- I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5233
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5234
- ).std((-2,-1)) * fft_factor
4363
+ S4_sigma[:, Ndata_S4, :, :, :] = (
4364
+ I1_f_small[:, j1].view(
4365
+ N_image, L, 1, 1, M3, N3
4366
+ )
4367
+ * self.backend.bk_conjugate(
4368
+ I1_f2_wf3_2_small.view(
4369
+ N_image, 1, L, L, M3, N3
4370
+ )
4371
+ )
4372
+ ).std((-2, -1)) * fft_factor
5235
4373
  else:
5236
4374
  for l1 in range(L):
5237
4375
  # [N_image,l2,l3,x,y]
5238
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5239
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5240
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5241
- ).mean((-2,-1)) * fft_factor
4376
+ S4_pre_norm[:, Ndata_S4, l1, :, :] = (
4377
+ I1_f_small[:, j1, l1].view(
4378
+ N_image, 1, 1, M3, N3
4379
+ )
4380
+ * self.backend.bk_conjugate(
4381
+ I1_f2_wf3_2_small.view(
4382
+ N_image, L, L, M3, N3
4383
+ )
4384
+ )
4385
+ ).mean((-2, -1)) * fft_factor
5242
4386
  if get_variance:
5243
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5244
- I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5245
- self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5246
- ).std((-2,-1)) * fft_factor
4387
+ S4_sigma[:, Ndata_S4, l1, :, :] = (
4388
+ I1_f_small[:, j1, l1].view(
4389
+ N_image, 1, 1, M3, N3
4390
+ )
4391
+ * self.backend.bk_conjugate(
4392
+ I1_f2_wf3_2_small.view(
4393
+ N_image, L, L, M3, N3
4394
+ )
4395
+ )
4396
+ ).std((-2, -1)) * fft_factor
5247
4397
  else:
5248
4398
  if not if_large_batch:
5249
4399
  # [N_image,l1,l2,l3,x,y]
5250
- S4_pre_norm[:,Ndata_S4,:,:,:] = (
5251
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5252
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
4400
+ S4_pre_norm[:, Ndata_S4, :, :, :] = (
4401
+ I1_small[:, j1].view(
4402
+ N_image, L, 1, 1, M3, N3
4403
+ )
4404
+ * self.backend.bk_conjugate(
4405
+ I12_w3_2_small.view(
4406
+ N_image, 1, L, L, M3, N3
4407
+ )
5253
4408
  )
5254
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
4409
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
4410
+ (-2, -1)
4411
+ ) * fft_factor
5255
4412
  if get_variance:
5256
- S4_sigma[:,Ndata_S4,:,:,:] = (
5257
- I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5258
- I12_w3_2_small.view(N_image,1,L,L,M3,N3)
4413
+ S4_sigma[:, Ndata_S4, :, :, :] = (
4414
+ I1_small[:, j1].view(
4415
+ N_image, L, 1, 1, M3, N3
5259
4416
  )
5260
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
4417
+ * self.backend.bk_conjugate(
4418
+ I12_w3_2_small.view(
4419
+ N_image, 1, L, L, M3, N3
4420
+ )
4421
+ )
4422
+ )[
4423
+ ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4424
+ ].std(
4425
+ (-2, -1)
4426
+ ) * fft_factor
5261
4427
  else:
5262
4428
  for l1 in range(L):
5263
- # [N_image,l2,l3,x,y]
5264
- S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5265
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5266
- I12_w3_2_small.view(N_image,L,L,M3,N3)
4429
+ # [N_image,l2,l3,x,y]
4430
+ S4_pre_norm[:, Ndata_S4, l1, :, :] = (
4431
+ I1_small[:, j1].view(
4432
+ N_image, 1, 1, M3, N3
4433
+ )
4434
+ * self.backend.bk_conjugate(
4435
+ I12_w3_2_small.view(
4436
+ N_image, L, L, M3, N3
4437
+ )
5267
4438
  )
5268
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
4439
+ )[
4440
+ ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4441
+ ].mean(
4442
+ (-2, -1)
4443
+ ) * fft_factor
5269
4444
  if get_variance:
5270
- S4_sigma[:,Ndata_S4,l1,:,:] = (
5271
- I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5272
- I12_w3_2_small.view(N_image,L,L,M3,N3)
4445
+ S4_sigma[:, Ndata_S4, l1, :, :] = (
4446
+ I1_small[:, j1].view(
4447
+ N_image, 1, 1, M3, N3
4448
+ )
4449
+ * self.backend.bk_conjugate(
4450
+ I12_w3_2_small.view(
4451
+ N_image, L, L, M3, N3
4452
+ )
5273
4453
  )
5274
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5275
-
5276
- Ndata_S4+=1
5277
-
5278
- if normalization=='S2':
5279
- if use_ref:
5280
- P = ((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5281
- else:
5282
- P = ((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5283
-
5284
- S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/P
5285
-
4454
+ )[
4455
+ ...,
4456
+ edge_dx:-edge_dx,
4457
+ edge_dy:-edge_dy,
4458
+ ].mean(
4459
+ (-2, -1)
4460
+ ) * fft_factor
4461
+
4462
+ Ndata_S4 += 1
4463
+
4464
+ if normalization == "S2":
4465
+ if use_ref:
4466
+ P = (
4467
+ ref_S2[:, j3 : j3 + 1, :, None, None]
4468
+ * ref_S2[:, j2 : j2 + 1, None, :, None]
4469
+ ) ** (0.5 * pseudo_coef)
4470
+ else:
4471
+ P = (
4472
+ S2[:, j3 : j3 + 1, :, None, None]
4473
+ * S2[:, j2 : j2 + 1, None, :, None]
4474
+ ) ** (0.5 * pseudo_coef)
4475
+
4476
+ S4[:, beg_n:Ndata_S4, :, :, :] = (
4477
+ S4_pre_norm[:, beg_n:Ndata_S4, :, :, :].clone() / P
4478
+ )
4479
+
5286
4480
  if get_variance:
5287
- S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:]/P
4481
+ S4_sigma[:, beg_n:Ndata_S4, :, :, :] = (
4482
+ S4_sigma[:, beg_n:Ndata_S4, :, :, :] / P
4483
+ )
5288
4484
  else:
5289
- S4=S4_pre_norm
5290
-
4485
+ S4 = S4_pre_norm
4486
+
5291
4487
  # average over l1 to obtain simple isotropic statistics
5292
4488
  if iso_ang:
5293
4489
  S2_iso = S2.mean(-1)
5294
4490
  S1_iso = S1.mean(-1)
5295
4491
  for l1 in range(L):
5296
4492
  for l2 in range(L):
5297
- S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
4493
+ S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
5298
4494
  if data2 is not None:
5299
- S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
4495
+ S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
5300
4496
  for l3 in range(L):
5301
- S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5302
- S3_iso /= L; S4_iso /= L
4497
+ S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[
4498
+ ..., l1, l2, l3
4499
+ ]
4500
+ S3_iso /= L
4501
+ S4_iso /= L
5303
4502
  if data2 is not None:
5304
4503
  S3p_iso /= L
5305
-
4504
+
5306
4505
  if get_variance:
5307
4506
  S2_sigma_iso = S2_sigma.mean(-1)
5308
4507
  S1_sigma_iso = S1_sigma.mean(-1)
5309
4508
  for l1 in range(L):
5310
4509
  for l2 in range(L):
5311
- S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
4510
+ S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
5312
4511
  if data2 is not None:
5313
- S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
4512
+ S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[
4513
+ ..., l1, l2
4514
+ ]
5314
4515
  for l3 in range(L):
5315
- S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5316
- S3_sigma_iso /= L; S4_sigma_iso /= L
4516
+ S4_sigma_iso[
4517
+ ..., (l2 - l1) % L, (l3 - l1) % L
4518
+ ] += S4_sigma[..., l1, l2, l3]
4519
+ S3_sigma_iso /= L
4520
+ S4_sigma_iso /= L
5317
4521
  if data2 is not None:
5318
4522
  S3p_sigma_iso /= L
5319
-
5320
- mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5321
- std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5322
-
4523
+
4524
+ mean_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
4525
+ std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
4526
+
5323
4527
  if data2 is None:
5324
- mean_data[:,0]=data.mean((-2,-1))
5325
- std_data[:,0]=data.std((-2,-1))
4528
+ mean_data[:, 0] = data.mean((-2, -1))
4529
+ std_data[:, 0] = data.std((-2, -1))
5326
4530
  else:
5327
- mean_data[:,0]=(data2*data).mean((-2,-1))
5328
- std_data[:,0]=(data2*data).std((-2,-1))
5329
-
4531
+ mean_data[:, 0] = (data2 * data).mean((-2, -1))
4532
+ std_data[:, 0] = (data2 * data).std((-2, -1))
4533
+
5330
4534
  if get_variance:
5331
- ref_sigma={}
4535
+ ref_sigma = {}
5332
4536
  if iso_ang:
5333
- ref_sigma['std_data']=std_data
5334
- ref_sigma['S1_sigma']=S1_sigma_iso
5335
- ref_sigma['S2_sigma']=S2_sigma_iso
5336
- ref_sigma['S3_sigma']=S3_sigma_iso
4537
+ ref_sigma["std_data"] = std_data
4538
+ ref_sigma["S1_sigma"] = S1_sigma_iso
4539
+ ref_sigma["S2_sigma"] = S2_sigma_iso
4540
+ ref_sigma["S3_sigma"] = S3_sigma_iso
5337
4541
  if data2 is not None:
5338
- ref_sigma['S3p_sigma']=S3p_sigma_iso
5339
- ref_sigma['S4_sigma']=S4_sigma_iso
4542
+ ref_sigma["S3p_sigma"] = S3p_sigma_iso
4543
+ ref_sigma["S4_sigma"] = S4_sigma_iso
5340
4544
  else:
5341
- ref_sigma['std_data']=std_data
5342
- ref_sigma['S1_sigma']=S1_sigma
5343
- ref_sigma['S2_sigma']=S2_sigma
5344
- ref_sigma['S3_sigma']=S3_sigma
4545
+ ref_sigma["std_data"] = std_data
4546
+ ref_sigma["S1_sigma"] = S1_sigma
4547
+ ref_sigma["S2_sigma"] = S2_sigma
4548
+ ref_sigma["S3_sigma"] = S3_sigma
5345
4549
  if data2 is not None:
5346
- ref_sigma['S3p_sigma']=S3p_sigma
5347
- ref_sigma['S4_sigma']=S4_sigma
5348
-
4550
+ ref_sigma["S3p_sigma"] = S3p_sigma
4551
+ ref_sigma["S4_sigma"] = S4_sigma
4552
+
5349
4553
  if data2 is None:
5350
4554
  if iso_ang:
5351
4555
  if ref_sigma is not None:
5352
- for_synthesis = self.backend.backend.cat((
5353
- mean_data/ref_sigma['std_data'],
5354
- std_data/ref_sigma['std_data'],
5355
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5356
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5357
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5358
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5359
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5360
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5361
- ),dim=-1)
4556
+ for_synthesis = self.backend.backend.cat(
4557
+ (
4558
+ mean_data / ref_sigma["std_data"],
4559
+ std_data / ref_sigma["std_data"],
4560
+ (S2_iso / ref_sigma["S2_sigma"])
4561
+ .reshape((N_image, -1))
4562
+ .log(),
4563
+ (S1_iso / ref_sigma["S1_sigma"])
4564
+ .reshape((N_image, -1))
4565
+ .log(),
4566
+ (S3_iso / ref_sigma["S3_sigma"])
4567
+ .reshape((N_image, -1))
4568
+ .real,
4569
+ (S3_iso / ref_sigma["S3_sigma"])
4570
+ .reshape((N_image, -1))
4571
+ .imag,
4572
+ (S4_iso / ref_sigma["S4_sigma"])
4573
+ .reshape((N_image, -1))
4574
+ .real,
4575
+ (S4_iso / ref_sigma["S4_sigma"])
4576
+ .reshape((N_image, -1))
4577
+ .imag,
4578
+ ),
4579
+ dim=-1,
4580
+ )
5362
4581
  else:
5363
- for_synthesis = self.backend.backend.cat((
5364
- mean_data/std_data,
5365
- std_data,
5366
- S2_iso.reshape((N_image, -1)).log(),
5367
- S1_iso.reshape((N_image, -1)).log(),
5368
- S3_iso.reshape((N_image, -1)).real,
5369
- S3_iso.reshape((N_image, -1)).imag,
5370
- S4_iso.reshape((N_image, -1)).real,
5371
- S4_iso.reshape((N_image, -1)).imag,
5372
- ),dim=-1)
4582
+ for_synthesis = self.backend.backend.cat(
4583
+ (
4584
+ mean_data / std_data,
4585
+ std_data,
4586
+ S2_iso.reshape((N_image, -1)).log(),
4587
+ S1_iso.reshape((N_image, -1)).log(),
4588
+ S3_iso.reshape((N_image, -1)).real,
4589
+ S3_iso.reshape((N_image, -1)).imag,
4590
+ S4_iso.reshape((N_image, -1)).real,
4591
+ S4_iso.reshape((N_image, -1)).imag,
4592
+ ),
4593
+ dim=-1,
4594
+ )
5373
4595
  else:
5374
4596
  if ref_sigma is not None:
5375
- for_synthesis = self.backend.backend.cat((
5376
- mean_data/ref_sigma['std_data'],
5377
- std_data/ref_sigma['std_data'],
5378
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5379
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5380
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5381
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5382
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5383
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5384
- ),dim=-1)
4597
+ for_synthesis = self.backend.backend.cat(
4598
+ (
4599
+ mean_data / ref_sigma["std_data"],
4600
+ std_data / ref_sigma["std_data"],
4601
+ (S2 / ref_sigma["S2_sigma"])
4602
+ .reshape((N_image, -1))
4603
+ .log(),
4604
+ (S1 / ref_sigma["S1_sigma"])
4605
+ .reshape((N_image, -1))
4606
+ .log(),
4607
+ (S3 / ref_sigma["S3_sigma"])
4608
+ .reshape((N_image, -1))
4609
+ .real,
4610
+ (S3 / ref_sigma["S3_sigma"])
4611
+ .reshape((N_image, -1))
4612
+ .imag,
4613
+ (S4 / ref_sigma["S4_sigma"])
4614
+ .reshape((N_image, -1))
4615
+ .real,
4616
+ (S4 / ref_sigma["S4_sigma"])
4617
+ .reshape((N_image, -1))
4618
+ .imag,
4619
+ ),
4620
+ dim=-1,
4621
+ )
5385
4622
  else:
5386
- for_synthesis = self.backend.backend.cat((
5387
- mean_data/std_data,
5388
- std_data,
5389
- S2.reshape((N_image, -1)).log(),
5390
- S1.reshape((N_image, -1)).log(),
5391
- S3.reshape((N_image, -1)).real,
5392
- S3.reshape((N_image, -1)).imag,
5393
- S4.reshape((N_image, -1)).real,
5394
- S4.reshape((N_image, -1)).imag,
5395
- ),dim=-1)
4623
+ for_synthesis = self.backend.backend.cat(
4624
+ (
4625
+ mean_data / std_data,
4626
+ std_data,
4627
+ S2.reshape((N_image, -1)).log(),
4628
+ S1.reshape((N_image, -1)).log(),
4629
+ S3.reshape((N_image, -1)).real,
4630
+ S3.reshape((N_image, -1)).imag,
4631
+ S4.reshape((N_image, -1)).real,
4632
+ S4.reshape((N_image, -1)).imag,
4633
+ ),
4634
+ dim=-1,
4635
+ )
5396
4636
  else:
5397
4637
  if iso_ang:
5398
4638
  if ref_sigma is not None:
5399
- for_synthesis = self.backend.backend.cat((
5400
- mean_data/ref_sigma['std_data'],
5401
- std_data/ref_sigma['std_data'],
5402
- (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5403
- (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5404
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5405
- (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5406
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5407
- (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5408
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5409
- (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5410
- ),dim=-1)
4639
+ for_synthesis = self.backend.backend.cat(
4640
+ (
4641
+ mean_data / ref_sigma["std_data"],
4642
+ std_data / ref_sigma["std_data"],
4643
+ (S2_iso / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
4644
+ (S1_iso / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
4645
+ (S3_iso / ref_sigma["S3_sigma"])
4646
+ .reshape((N_image, -1))
4647
+ .real,
4648
+ (S3_iso / ref_sigma["S3_sigma"])
4649
+ .reshape((N_image, -1))
4650
+ .imag,
4651
+ (S3p_iso / ref_sigma["S3p_sigma"])
4652
+ .reshape((N_image, -1))
4653
+ .real,
4654
+ (S3p_iso / ref_sigma["S3p_sigma"])
4655
+ .reshape((N_image, -1))
4656
+ .imag,
4657
+ (S4_iso / ref_sigma["S4_sigma"])
4658
+ .reshape((N_image, -1))
4659
+ .real,
4660
+ (S4_iso / ref_sigma["S4_sigma"])
4661
+ .reshape((N_image, -1))
4662
+ .imag,
4663
+ ),
4664
+ dim=-1,
4665
+ )
5411
4666
  else:
5412
- for_synthesis = self.backend.backend.cat((
5413
- mean_data/std_data,
5414
- std_data,
5415
- S2_iso.reshape((N_image, -1)),
5416
- S1_iso.reshape((N_image, -1)),
5417
- S3_iso.reshape((N_image, -1)).real,
5418
- S3_iso.reshape((N_image, -1)).imag,
5419
- S3p_iso.reshape((N_image, -1)).real,
5420
- S3p_iso.reshape((N_image, -1)).imag,
5421
- S4_iso.reshape((N_image, -1)).real,
5422
- S4_iso.reshape((N_image, -1)).imag,
5423
- ),dim=-1)
4667
+ for_synthesis = self.backend.backend.cat(
4668
+ (
4669
+ mean_data / std_data,
4670
+ std_data,
4671
+ S2_iso.reshape((N_image, -1)),
4672
+ S1_iso.reshape((N_image, -1)),
4673
+ S3_iso.reshape((N_image, -1)).real,
4674
+ S3_iso.reshape((N_image, -1)).imag,
4675
+ S3p_iso.reshape((N_image, -1)).real,
4676
+ S3p_iso.reshape((N_image, -1)).imag,
4677
+ S4_iso.reshape((N_image, -1)).real,
4678
+ S4_iso.reshape((N_image, -1)).imag,
4679
+ ),
4680
+ dim=-1,
4681
+ )
5424
4682
  else:
5425
4683
  if ref_sigma is not None:
5426
- for_synthesis = self.backend.backend.cat((
5427
- mean_data/ref_sigma['std_data'],
5428
- std_data/ref_sigma['std_data'],
5429
- (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
5430
- (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
5431
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5432
- (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5433
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
5434
- (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
5435
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5436
- (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5437
- ),dim=-1)
4684
+ for_synthesis = self.backend.backend.cat(
4685
+ (
4686
+ mean_data / ref_sigma["std_data"],
4687
+ std_data / ref_sigma["std_data"],
4688
+ (S2 / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
4689
+ (S1 / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
4690
+ (S3 / ref_sigma["S3_sigma"])
4691
+ .reshape((N_image, -1))
4692
+ .real,
4693
+ (S3 / ref_sigma["S3_sigma"])
4694
+ .reshape((N_image, -1))
4695
+ .imag,
4696
+ (S3p / ref_sigma["S3p_sigma"])
4697
+ .reshape((N_image, -1))
4698
+ .real,
4699
+ (S3p / ref_sigma["S3p_sigma"])
4700
+ .reshape((N_image, -1))
4701
+ .imag,
4702
+ (S4 / ref_sigma["S4_sigma"])
4703
+ .reshape((N_image, -1))
4704
+ .real,
4705
+ (S4 / ref_sigma["S4_sigma"])
4706
+ .reshape((N_image, -1))
4707
+ .imag,
4708
+ ),
4709
+ dim=-1,
4710
+ )
5438
4711
  else:
5439
- for_synthesis = self.backend.backend.cat((
5440
- mean_data/std_data,
5441
- std_data,
5442
- S2.reshape((N_image, -1)),
5443
- S1.reshape((N_image, -1)),
5444
- S3.reshape((N_image, -1)).real,
5445
- S3.reshape((N_image, -1)).imag,
5446
- S3p.reshape((N_image, -1)).real,
5447
- S3p.reshape((N_image, -1)).imag,
5448
- S4.reshape((N_image, -1)).real,
5449
- S4.reshape((N_image, -1)).imag,
5450
- ),dim=-1)
5451
-
5452
- if not use_ref:
5453
- self.ref_scattering_cov_S2=S2
5454
-
4712
+ for_synthesis = self.backend.backend.cat(
4713
+ (
4714
+ mean_data / std_data,
4715
+ std_data,
4716
+ S2.reshape((N_image, -1)),
4717
+ S1.reshape((N_image, -1)),
4718
+ S3.reshape((N_image, -1)).real,
4719
+ S3.reshape((N_image, -1)).imag,
4720
+ S3p.reshape((N_image, -1)).real,
4721
+ S3p.reshape((N_image, -1)).imag,
4722
+ S4.reshape((N_image, -1)).real,
4723
+ S4.reshape((N_image, -1)).imag,
4724
+ ),
4725
+ dim=-1,
4726
+ )
4727
+
4728
+ if not use_ref:
4729
+ self.ref_scattering_cov_S2 = S2
4730
+
5455
4731
  if get_variance:
5456
- return for_synthesis,ref_sigma
5457
-
4732
+ return for_synthesis, ref_sigma
4733
+
5458
4734
  return for_synthesis
5459
-
5460
- if (M,N,J,L) not in self.filters_set:
5461
- self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
5462
-
5463
- filters_set = self.filters_set[(M,N,J,L)]
5464
-
5465
- #weight = self.weight
4735
+
4736
+ if (M, N, J, L) not in self.filters_set:
4737
+ self.filters_set[(M, N, J, L)] = self.computer_filter(
4738
+ M, N, J, L
4739
+ ) # self.computer_filter(M,N,J,L)
4740
+
4741
+ filters_set = self.filters_set[(M, N, J, L)]
4742
+
4743
+ # weight = self.weight
5466
4744
  if use_ref:
5467
- if normalization=='S2':
4745
+ if normalization == "S2":
5468
4746
  ref_S2 = self.ref_scattering_cov_S2
5469
- else:
5470
- ref_P11 = self.ref_scattering_cov['P11']
4747
+ else:
4748
+ ref_P11 = self.ref_scattering_cov["P11"]
5471
4749
 
5472
4750
  # convert numpy array input into self.backend.bk_ tensors
5473
4751
  data = self.backend.bk_cast(data)
5474
- data_f = self.backend.bk_fftn(data, dim=(-2,-1))
4752
+ data_f = self.backend.bk_fftn(data, dim=(-2, -1))
5475
4753
  if data2 is not None:
5476
4754
  data2 = self.backend.bk_cast(data2)
5477
- data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
5478
-
4755
+ data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
4756
+
5479
4757
  # initialize tensors for scattering coefficients
5480
-
5481
- Ndata_S3 = J*(J+1)//2
5482
- Ndata_S4 = J*(J+1)*(J+2)//6
5483
- J_S4={}
5484
-
4758
+
4759
+ Ndata_S3 = J * (J + 1) // 2
4760
+ Ndata_S4 = J * (J + 1) * (J + 2) // 6
4761
+ J_S4 = {}
4762
+
5485
4763
  S3 = []
5486
4764
  if data2 is not None:
5487
4765
  S3p = []
5488
- S4_pre_norm = []
5489
- S4 = []
5490
-
4766
+ S4_pre_norm = []
4767
+ S4 = []
4768
+
5491
4769
  # variance
5492
4770
  if get_variance:
5493
- S3_sigma = []
4771
+ S3_sigma = []
5494
4772
  if data2 is not None:
5495
4773
  S3p_sigma = []
5496
- S4_sigma = []
5497
-
4774
+ S4_sigma = []
4775
+
5498
4776
  if iso_ang:
5499
4777
  S3_iso = []
5500
4778
  if data2 is not None:
5501
4779
  S3p_iso = []
5502
-
4780
+
5503
4781
  S4_iso = []
5504
4782
  if get_variance:
5505
4783
  S3_sigma_iso = []
5506
4784
  if data2 is not None:
5507
- S3p_sigma_iso = []
5508
- S4_sigma_iso = []
5509
-
4785
+ S3p_sigma_iso = []
4786
+ S4_sigma_iso = []
4787
+
5510
4788
  #
5511
- if edge:
5512
- if (M,N,J) not in self.edge_masks:
5513
- self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5514
- edge_mask = self.edge_masks[(M,N,J)]
5515
- else:
4789
+ if edge:
4790
+ if (M, N, J) not in self.edge_masks:
4791
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4792
+ edge_mask = self.edge_masks[(M, N, J)]
4793
+ else:
5516
4794
  edge_mask = 1
5517
-
4795
+
5518
4796
  # calculate scattering fields
5519
4797
  if data2 is None:
5520
4798
  if self.use_2D:
5521
4799
  if len(data.shape) == 2:
5522
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5523
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5524
- ))
4800
+ I1 = self.backend.bk_abs(
4801
+ self.backend.bk_ifftn(
4802
+ data_f[None, None, None, :, :]
4803
+ * filters_set[None, :J, :, :, :],
4804
+ dim=(-2, -1),
4805
+ )
4806
+ )
5525
4807
  else:
5526
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5527
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5528
- ))
4808
+ I1 = self.backend.bk_abs(
4809
+ self.backend.bk_ifftn(
4810
+ data_f[:, None, None, :, :]
4811
+ * filters_set[None, :J, :, :, :],
4812
+ dim=(-2, -1),
4813
+ )
4814
+ )
5529
4815
  elif self.use_1D:
5530
4816
  if len(data.shape) == 1:
5531
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5532
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5533
- ))
4817
+ I1 = self.backend.bk_abs(
4818
+ self.backend.bk_ifftn(
4819
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4820
+ dim=(-1),
4821
+ )
4822
+ )
5534
4823
  else:
5535
- I1 = self.backend.bk_abs(self.backend.bk_ifftn(
5536
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5537
- ))
4824
+ I1 = self.backend.bk_abs(
4825
+ self.backend.bk_ifftn(
4826
+ data_f[:, None, None, :] * filters_set[None, :J, :, :],
4827
+ dim=(-1),
4828
+ )
4829
+ )
5538
4830
  else:
5539
- print('todo')
5540
-
5541
- S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask),axis=(-2,-1))
5542
- S1 = self.backend.bk_reduce_mean(I1 * edge_mask,axis=(-2,-1))
4831
+ print("todo")
4832
+
4833
+ S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=(-2, -1))
4834
+ S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=(-2, -1))
5543
4835
 
5544
4836
  if get_variance:
5545
- S2_sigma = self.backend.bk_reduce_std((I1**2 * edge_mask),axis=(-2,-1))
5546
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5547
-
5548
- I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5549
-
5550
- else:
4837
+ S2_sigma = self.backend.bk_reduce_std(
4838
+ (I1**2 * edge_mask), axis=(-2, -1)
4839
+ )
4840
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
4841
+
4842
+ I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4843
+
4844
+ else:
5551
4845
  if self.use_2D:
5552
4846
  if len(data.shape) == 2:
5553
4847
  I1 = self.backend.bk_ifftn(
5554
- data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4848
+ data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
4849
+ dim=(-2, -1),
5555
4850
  )
5556
4851
  I2 = self.backend.bk_ifftn(
5557
- data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4852
+ data2_f[None, None, None, :, :]
4853
+ * filters_set[None, :J, :, :, :],
4854
+ dim=(-2, -1),
5558
4855
  )
5559
4856
  else:
5560
4857
  I1 = self.backend.bk_ifftn(
5561
- data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4858
+ data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
4859
+ dim=(-2, -1),
5562
4860
  )
5563
4861
  I2 = self.backend.bk_ifftn(
5564
- data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
4862
+ data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
4863
+ dim=(-2, -1),
5565
4864
  )
5566
4865
  elif self.use_1D:
5567
4866
  if len(data.shape) == 1:
5568
4867
  I1 = self.backend.bk_ifftn(
5569
- data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4868
+ data_f[None, None, None, :] * filters_set[None, :J, :, :],
4869
+ dim=(-1),
5570
4870
  )
5571
4871
  I2 = self.backend.bk_ifftn(
5572
- data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4872
+ data2_f[None, None, None, :] * filters_set[None, :J, :, :],
4873
+ dim=(-1),
5573
4874
  )
5574
4875
  else:
5575
4876
  I1 = self.backend.bk_ifftn(
5576
- data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4877
+ data_f[:, None, None, :] * filters_set[None, :J, :, :], dim=(-1)
5577
4878
  )
5578
4879
  I2 = self.backend.bk_ifftn(
5579
- data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
4880
+ data2_f[:, None, None, :] * filters_set[None, :J, :, :],
4881
+ dim=(-1),
5580
4882
  )
5581
4883
  else:
5582
- print('todo')
5583
-
5584
- I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5585
-
5586
- S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4884
+ print("todo")
4885
+
4886
+ I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
4887
+
4888
+ S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5587
4889
  if get_variance:
5588
- S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5589
-
5590
- I1=self.backend.bk_L1(I1)
5591
-
5592
- S1 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2,-1))
4890
+ S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
4891
+
4892
+ I1 = self.backend.bk_L1(I1)
4893
+
4894
+ S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
5593
4895
 
5594
4896
  if get_variance:
5595
- S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2,-1))
5596
-
5597
- I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5598
-
4897
+ S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
4898
+
4899
+ I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
4900
+
5599
4901
  if pseudo_coef != 1:
5600
4902
  I1 = I1**pseudo_coef
5601
-
5602
- Ndata_S3=0
5603
- Ndata_S4=0
5604
-
4903
+
4904
+ Ndata_S3 = 0
4905
+ Ndata_S4 = 0
4906
+
5605
4907
  # calculate the covariance and correlations of the scattering fields
5606
4908
  # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5607
- for j3 in range(0,J):
5608
- J_S4[j3]=Ndata_S4
5609
-
5610
- dx3, dy3 = self.get_dxdy(j3,M,N)
5611
- I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
4909
+ for j3 in range(0, J):
4910
+ J_S4[j3] = Ndata_S4
4911
+
4912
+ dx3, dy3 = self.get_dxdy(j3, M, N)
4913
+ I1_f_small = self.cut_high_k_off(
4914
+ I1_f[:, : j3 + 1], dx3, dy3
4915
+ ) # Nimage, J, L, x, y
5612
4916
  data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5613
4917
  if data2 is not None:
5614
4918
  data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5615
4919
  if edge:
5616
- I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5617
- data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
4920
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2, -1), norm="ortho")
4921
+ data_small = self.backend.bk_ifftn(
4922
+ data_f_small, dim=(-2, -1), norm="ortho"
4923
+ )
5618
4924
  if data2 is not None:
5619
- data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
5620
- wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
4925
+ data2_small = self.backend.bk_ifftn(
4926
+ data2_f_small, dim=(-2, -1), norm="ortho"
4927
+ )
4928
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5621
4929
  _, M3, N3 = wavelet_f3.shape
5622
4930
  wavelet_f3_squared = wavelet_f3**2
5623
- edge_dx = min(4, int(2**j3*dx3*2/M))
5624
- edge_dy = min(4, int(2**j3*dy3*2/N))
4931
+ edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4932
+ edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
5625
4933
  # a normalization change due to the cutoff of frequency space
5626
- fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5627
- for j2 in range(0,j3+1):
5628
- #I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5629
- #I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5630
- I1_f2_wf3_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3,[1,1,1,L,M3,N3])
5631
- I1_f2_wf3_2_small = self.backend.bk_reshape(I1_f_small[:,j2],[N_image,1,L,1,M3,N3]) * self.backend.bk_reshape(wavelet_f3_squared,[1,1,1,L,M3,N3])
4934
+ if self.all_bk_type == "float32":
4935
+ fft_factor = np.complex64(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
4936
+ else:
4937
+ fft_factor = np.complex128(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
4938
+ for j2 in range(0, j3 + 1):
4939
+ # I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
4940
+ # I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
4941
+ I1_f2_wf3_small = self.backend.bk_reshape(
4942
+ I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
4943
+ ) * self.backend.bk_reshape(wavelet_f3, [1, 1, 1, L, M3, N3])
4944
+ I1_f2_wf3_2_small = self.backend.bk_reshape(
4945
+ I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
4946
+ ) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
5632
4947
  if edge:
5633
- I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5634
- I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
4948
+ I12_w3_small = self.backend.bk_ifftn(
4949
+ I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
4950
+ )
4951
+ I12_w3_2_small = self.backend.bk_ifftn(
4952
+ I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
4953
+ )
5635
4954
  if use_ref:
5636
- if normalization=='P11':
5637
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5638
- elif normalization=='S2':
5639
- norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
4955
+ if normalization == "P11":
4956
+ norm_factor_S3 = (
4957
+ ref_S2[:, None, j3, :]
4958
+ * ref_P11[:, j2, j3, :, :] ** pseudo_coef
4959
+ ) ** 0.5
4960
+ norm_factor_S3 = self.backend.bk_complex(
4961
+ norm_factor_S3, 0 * norm_factor_S3
4962
+ )
4963
+ elif normalization == "S2":
4964
+ norm_factor_S3 = (
4965
+ ref_S2[:, None, j3, :]
4966
+ * ref_S2[:, j2, :, None] ** pseudo_coef
4967
+ ) ** 0.5
4968
+ norm_factor_S3 = self.backend.bk_complex(
4969
+ norm_factor_S3, 0 * norm_factor_S3
4970
+ )
5640
4971
  else:
5641
- norm_factor_S3 = 1.0
4972
+ norm_factor_S3 = C_ONE
5642
4973
  else:
5643
- if normalization=='P11':
4974
+ if normalization == "P11":
5644
4975
  # [N_image,l2,l3,x,y]
5645
- P11_temp = self.backend.bk_reduce_mean((I1_f2_wf3_small.abs()**2),axis=(-2,-1)) * fft_factor
5646
- norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5647
- elif normalization=='S2':
5648
- norm_factor_S3 = (S2[:,None,j3,None,:] * S2[:,None,j2,:,None]**pseudo_coef)**0.5
4976
+ P11_temp = (
4977
+ self.backend.bk_reduce_mean(
4978
+ (I1_f2_wf3_small.abs() ** 2), axis=(-2, -1)
4979
+ )
4980
+ * fft_factor
4981
+ )
4982
+ norm_factor_S3 = (
4983
+ S2[:, None, j3, :] * P11_temp**pseudo_coef
4984
+ ) ** 0.5
4985
+ norm_factor_S3 = self.backend.bk_complex(
4986
+ norm_factor_S3, 0 * norm_factor_S3
4987
+ )
4988
+ elif normalization == "S2":
4989
+ norm_factor_S3 = (
4990
+ S2[:, None, j3, None, :]
4991
+ * S2[:, None, j2, :, None] ** pseudo_coef
4992
+ ) ** 0.5
4993
+ norm_factor_S3 = self.backend.bk_complex(
4994
+ norm_factor_S3, 0 * norm_factor_S3
4995
+ )
5649
4996
  else:
5650
- norm_factor_S3 = 1.0
5651
-
5652
- norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
4997
+ norm_factor_S3 = C_ONE
5653
4998
 
5654
4999
  if not edge:
5655
- S3.append(self.backend.bk_reduce_mean(
5656
- self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5657
- ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5000
+ S3.append(
5001
+ self.backend.bk_reduce_mean(
5002
+ self.backend.bk_reshape(
5003
+ data_f_small, [N_image, 1, 1, 1, M3, N3]
5004
+ )
5005
+ * self.backend.bk_conjugate(I1_f2_wf3_small),
5006
+ axis=(-2, -1),
5007
+ )
5008
+ * fft_factor
5009
+ / norm_factor_S3
5010
+ )
5658
5011
  if get_variance:
5659
- S3_sigma.append(self.backend.bk_reduce_std(
5660
- self.backend.bk_reshape(data_f_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5661
- ,axis=(-2,-1)) * fft_factor / norm_factor_S3)
5012
+ S3_sigma.append(
5013
+ self.backend.bk_reduce_std(
5014
+ self.backend.bk_reshape(
5015
+ data_f_small, [N_image, 1, 1, 1, M3, N3]
5016
+ )
5017
+ * self.backend.bk_conjugate(I1_f2_wf3_small),
5018
+ axis=(-2, -1),
5019
+ )
5020
+ * fft_factor
5021
+ / norm_factor_S3
5022
+ )
5662
5023
  else:
5663
- S3.append(self.backend.bk_reduce_mean(
5664
- (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5665
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5024
+ S3.append(
5025
+ self.backend.bk_reduce_mean(
5026
+ (
5027
+ self.backend.bk_reshape(
5028
+ data_small, [N_image, 1, 1, 1, M3, N3]
5029
+ )
5030
+ * self.backend.bk_conjugate(I12_w3_small)
5031
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5032
+ axis=(-2, -1),
5033
+ )
5034
+ * fft_factor
5035
+ / norm_factor_S3
5036
+ )
5666
5037
  if get_variance:
5667
- S3_sigma.apend(self.backend.bk_reduce_std(
5668
- (self.backend.bk_reshape(data_small,[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5669
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5038
+ S3_sigma.apend(
5039
+ self.backend.bk_reduce_std(
5040
+ (
5041
+ self.backend.bk_reshape(
5042
+ data_small, [N_image, 1, 1, 1, M3, N3]
5043
+ )
5044
+ * self.backend.bk_conjugate(I12_w3_small)
5045
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5046
+ axis=(-2, -1),
5047
+ )
5048
+ * fft_factor
5049
+ / norm_factor_S3
5050
+ )
5670
5051
  if data2 is not None:
5671
5052
  if not edge:
5672
- S3p.append(self.backend.bk_reduce_mean(
5673
- (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5674
- ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5675
-
5053
+ S3p.append(
5054
+ self.backend.bk_reduce_mean(
5055
+ (
5056
+ self.backend.bk_reshape(
5057
+ data2_f_small, [N_image2, 1, 1, 1, M3, N3]
5058
+ )
5059
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
5060
+ ),
5061
+ axis=(-2, -1),
5062
+ )
5063
+ * fft_factor
5064
+ / norm_factor_S3
5065
+ )
5066
+
5676
5067
  if get_variance:
5677
- S3p_sigma.append(self.backend.bk_reduce_std(
5678
- (self.backend.bk_reshape(data2_f_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I1_f2_wf3_small)
5679
- ),axis=(-2,-1)) * fft_factor / norm_factor_S3)
5068
+ S3p_sigma.append(
5069
+ self.backend.bk_reduce_std(
5070
+ (
5071
+ self.backend.bk_reshape(
5072
+ data2_f_small, [N_image2, 1, 1, 1, M3, N3]
5073
+ )
5074
+ * self.backend.bk_conjugate(I1_f2_wf3_small)
5075
+ ),
5076
+ axis=(-2, -1),
5077
+ )
5078
+ * fft_factor
5079
+ / norm_factor_S3
5080
+ )
5680
5081
  else:
5681
-
5682
- S3p.append(self.backend.bk_reduce_mean(
5683
- (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5684
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5082
+
5083
+ S3p.append(
5084
+ self.backend.bk_reduce_mean(
5085
+ (
5086
+ self.backend.bk_reshape(
5087
+ data2_small, [N_image2, 1, 1, 1, M3, N3]
5088
+ )
5089
+ * self.backend.bk_conjugate(I12_w3_small)
5090
+ )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
5091
+ axis=(-2, -1),
5092
+ )
5093
+ * fft_factor
5094
+ / norm_factor_S3
5095
+ )
5685
5096
  if get_variance:
5686
- S3p_sigma.append(self.backend.bk_reduce_std(
5687
- (self.backend.bk_reshape(data2_small,[N_image2,1,1,1,M3,N3]) * self.backend.bk_conjugate(I12_w3_small)
5688
- )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy],axis=(-2,-1)) * fft_factor / norm_factor_S3)
5689
-
5097
+ S3p_sigma.append(
5098
+ self.backend.bk_reduce_std(
5099
+ (
5100
+ self.backend.bk_reshape(
5101
+ data2_small, [N_image2, 1, 1, 1, M3, N3]
5102
+ )
5103
+ * self.backend.bk_conjugate(I12_w3_small)
5104
+ )[
5105
+ ...,
5106
+ edge_dx : M3 - edge_dx,
5107
+ edge_dy : N3 - edge_dy,
5108
+ ],
5109
+ axis=(-2, -1),
5110
+ )
5111
+ * fft_factor
5112
+ / norm_factor_S3
5113
+ )
5114
+
5690
5115
  if j2 <= j3:
5691
- if normalization=='S2':
5692
- if use_ref:
5693
- P = 1/((ref_S2[:,j3:j3+1,:,None,None] * ref_S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef))
5694
- else:
5695
- P = 1/(((S2[:,j3:j3+1,:,None,None] * S2[:,j2:j2+1,None,:,None] )**(0.5*pseudo_coef)))
5696
- P=self.backend.bk_complex(P,0.0*P)
5116
+ if normalization == "S2":
5117
+ if use_ref:
5118
+ P = 1 / (
5119
+ (
5120
+ ref_S2[:, j3 : j3 + 1, :, None, None]
5121
+ * ref_S2[:, j2 : j2 + 1, None, :, None]
5122
+ )
5123
+ ** (0.5 * pseudo_coef)
5124
+ )
5125
+ else:
5126
+ P = 1 / (
5127
+ (
5128
+ S2[:, j3 : j3 + 1, :, None, None]
5129
+ * S2[:, j2 : j2 + 1, None, :, None]
5130
+ )
5131
+ ** (0.5 * pseudo_coef)
5132
+ )
5133
+ P = self.backend.bk_complex(P, 0.0 * P)
5697
5134
  else:
5698
- P=self.backend.bk_complex(1.0,0.0)
5699
-
5700
- for j1 in range(0, j2+1):
5701
- if not edge:
5702
- if not if_large_batch:
5703
- # [N_image,l1,l2,l3,x,y]
5704
- S4.append(self.backend.bk_reduce_mean(
5705
- (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5706
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5707
- ),axis=(-2,-1)) * fft_factor*P)
5708
- if get_variance:
5709
- S4_sigma.append(self.backend.bk_reduce_std(
5710
- (self.backend.bk_reshape(I1_f_small[:,j1],[N_image,1,L,1,1,M3,N3]) *
5711
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,1,L,L,M3,N3]))
5712
- ),axis=(-2,-1)) * fft_factor*P)
5713
- else:
5714
- for l1 in range(L):
5715
- # [N_image,l2,l3,x,y]
5716
- S4.append(self.backend.bk_reduce_mean(
5717
- (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5718
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5719
- ),axis=(-2,-1)) * fft_factor*P)
5720
- if get_variance:
5721
- S4_sigma.append(self.backend.bk_reduce_std(
5722
- (self.backend.bk_reshape(I1_f_small[:,j1,l1],[N_image,1,1,1,M3,N3]) *
5723
- self.backend.bk_conjugate(self.backend.bk_reshape(I1_f2_wf3_2_small,[N_image,1,L,L,M3,N3]))
5724
- ),axis=(-2,-1)) * fft_factor*P)
5135
+ P = C_ONE
5136
+
5137
+ for j1 in range(0, j2 + 1):
5138
+ if not edge:
5139
+ if not if_large_batch:
5140
+ # [N_image,l1,l2,l3,x,y]
5141
+ S4.append(
5142
+ self.backend.bk_reduce_mean(
5143
+ (
5144
+ self.backend.bk_reshape(
5145
+ I1_f_small[:, j1],
5146
+ [N_image, 1, L, 1, 1, M3, N3],
5147
+ )
5148
+ * self.backend.bk_conjugate(
5149
+ self.backend.bk_reshape(
5150
+ I1_f2_wf3_2_small,
5151
+ [N_image, 1, 1, L, L, M3, N3],
5152
+ )
5153
+ )
5154
+ ),
5155
+ axis=(-2, -1),
5156
+ )
5157
+ * fft_factor
5158
+ * P
5159
+ )
5160
+ if get_variance:
5161
+ S4_sigma.append(
5162
+ self.backend.bk_reduce_std(
5163
+ (
5164
+ self.backend.bk_reshape(
5165
+ I1_f_small[:, j1],
5166
+ [N_image, 1, L, 1, 1, M3, N3],
5167
+ )
5168
+ * self.backend.bk_conjugate(
5169
+ self.backend.bk_reshape(
5170
+ I1_f2_wf3_2_small,
5171
+ [N_image, 1, 1, L, L, M3, N3],
5172
+ )
5173
+ )
5174
+ ),
5175
+ axis=(-2, -1),
5176
+ )
5177
+ * fft_factor
5178
+ * P
5179
+ )
5725
5180
  else:
5726
- if not if_large_batch:
5727
- # [N_image,l1,l2,l3,x,y]
5728
- S4.append(self.backend.bk_reduce_mean(
5729
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5730
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5181
+ for l1 in range(L):
5182
+ # [N_image,l2,l3,x,y]
5183
+ S4.append(
5184
+ self.backend.bk_reduce_mean(
5185
+ (
5186
+ self.backend.bk_reshape(
5187
+ I1_f_small[:, j1, l1],
5188
+ [N_image, 1, 1, 1, M3, N3],
5189
+ )
5190
+ * self.backend.bk_conjugate(
5191
+ self.backend.bk_reshape(
5192
+ I1_f2_wf3_2_small,
5193
+ [N_image, 1, L, L, M3, N3],
5194
+ )
5195
+ )
5196
+ ),
5197
+ axis=(-2, -1),
5731
5198
  )
5732
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5199
+ * fft_factor
5200
+ * P
5201
+ )
5733
5202
  if get_variance:
5734
- S4_sigma.append(self.backend.bk_reduce_std(
5735
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,L,1,1,M3,N3]) * self.backend.bk_conjugate(
5736
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,1,L,L,M3,N3])
5203
+ S4_sigma.append(
5204
+ self.backend.bk_reduce_std(
5205
+ (
5206
+ self.backend.bk_reshape(
5207
+ I1_f_small[:, j1, l1],
5208
+ [N_image, 1, 1, 1, M3, N3],
5209
+ )
5210
+ * self.backend.bk_conjugate(
5211
+ self.backend.bk_reshape(
5212
+ I1_f2_wf3_2_small,
5213
+ [N_image, 1, L, L, M3, N3],
5214
+ )
5215
+ )
5216
+ ),
5217
+ axis=(-2, -1),
5737
5218
  )
5738
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5739
- else:
5740
- for l1 in range(L):
5741
- # [N_image,l2,l3,x,y]
5742
- S4.append(self.backend.bk_reduce_mean(
5743
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5744
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5219
+ * fft_factor
5220
+ * P
5221
+ )
5222
+ else:
5223
+ if not if_large_batch:
5224
+ # [N_image,l1,l2,l3,x,y]
5225
+ S4.append(
5226
+ self.backend.bk_reduce_mean(
5227
+ (
5228
+ self.backend.bk_reshape(
5229
+ I1_small[:, j1],
5230
+ [N_image, 1, L, 1, 1, M3, N3],
5745
5231
  )
5746
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5747
- if get_variance:
5748
- S4_sigma.append(self.backend.bk_reduce_std(
5749
- (self.backend.bk_reshape(I1_small[:,j1],[N_image,1,1,1,M3,N3]) * self.backend.bk_conjugate(
5750
- self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
5232
+ * self.backend.bk_conjugate(
5233
+ self.backend.bk_reshape(
5234
+ I12_w3_2_small,
5235
+ [N_image, 1, 1, L, L, M3, N3],
5751
5236
  )
5752
- )[...,edge_dx:-edge_dx, edge_dy:-edge_dy],axis=(-2,-1)) * fft_factor*P)
5753
-
5754
- S3=self.backend.bk_concat(S3,axis=1)
5755
- S4=self.backend.bk_concat(S4,axis=1)
5756
-
5237
+ )
5238
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5239
+ axis=(-2, -1),
5240
+ )
5241
+ * fft_factor
5242
+ * P
5243
+ )
5244
+ if get_variance:
5245
+ S4_sigma.append(
5246
+ self.backend.bk_reduce_std(
5247
+ (
5248
+ self.backend.bk_reshape(
5249
+ I1_small[:, j1],
5250
+ [N_image, 1, L, 1, 1, M3, N3],
5251
+ )
5252
+ * self.backend.bk_conjugate(
5253
+ self.backend.bk_reshape(
5254
+ I12_w3_2_small,
5255
+ [N_image, 1, 1, L, L, M3, N3],
5256
+ )
5257
+ )
5258
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5259
+ axis=(-2, -1),
5260
+ )
5261
+ * fft_factor
5262
+ * P
5263
+ )
5264
+ else:
5265
+ for l1 in range(L):
5266
+ # [N_image,l2,l3,x,y]
5267
+ S4.append(
5268
+ self.backend.bk_reduce_mean(
5269
+ (
5270
+ self.backend.bk_reshape(
5271
+ I1_small[:, j1],
5272
+ [N_image, 1, 1, 1, M3, N3],
5273
+ )
5274
+ * self.backend.bk_conjugate(
5275
+ self.backend.bk_reshape(
5276
+ I12_w3_2_small,
5277
+ [N_image, 1, L, L, M3, N3],
5278
+ )
5279
+ )
5280
+ )[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
5281
+ axis=(-2, -1),
5282
+ )
5283
+ * fft_factor
5284
+ * P
5285
+ )
5286
+ if get_variance:
5287
+ S4_sigma.append(
5288
+ self.backend.bk_reduce_std(
5289
+ (
5290
+ self.backend.bk_reshape(
5291
+ I1_small[:, j1],
5292
+ [N_image, 1, 1, 1, M3, N3],
5293
+ )
5294
+ * self.backend.bk_conjugate(
5295
+ self.backend.bk_reshape(
5296
+ I12_w3_2_small,
5297
+ [N_image, 1, L, L, M3, N3],
5298
+ )
5299
+ )
5300
+ )[
5301
+ ...,
5302
+ edge_dx:-edge_dx,
5303
+ edge_dy:-edge_dy,
5304
+ ],
5305
+ axis=(-2, -1),
5306
+ )
5307
+ * fft_factor
5308
+ * P
5309
+ )
5310
+
5311
+ S3 = self.backend.bk_concat(S3, axis=1)
5312
+ S4 = self.backend.bk_concat(S4, axis=1)
5313
+
5757
5314
  if get_variance:
5758
- S3_sigma=self.backend.bk_concat(S3_sigma,axis=1)
5759
- S4_sigma=self.backend.bk_concat(S4_sigma,axis=1)
5760
-
5315
+ S3_sigma = self.backend.bk_concat(S3_sigma, axis=1)
5316
+ S4_sigma = self.backend.bk_concat(S4_sigma, axis=1)
5317
+
5761
5318
  if data2 is not None:
5762
- S3p=self.backend.bk_concat(S3p,axis=1)
5319
+ S3p = self.backend.bk_concat(S3p, axis=1)
5763
5320
  if get_variance:
5764
- S3p_sigma=self.backend.bk_concat(S3p_sigma,axis=1)
5765
-
5321
+ S3p_sigma = self.backend.bk_concat(S3p_sigma, axis=1)
5322
+
5766
5323
  # average over l1 to obtain simple isotropic statistics
5767
5324
  if iso_ang:
5768
- S2_iso = self.backend.bk_reduce_mean(S2,axis=(-1))
5769
- S1_iso = self.backend.bk_reduce_mean(S1,axis=(-1))
5325
+ S2_iso = self.backend.bk_reduce_mean(S2, axis=(-1))
5326
+ S1_iso = self.backend.bk_reduce_mean(S1, axis=(-1))
5770
5327
  for l1 in range(L):
5771
5328
  for l2 in range(L):
5772
- S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
5329
+ S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
5773
5330
  if data2 is not None:
5774
- S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
5331
+ S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
5775
5332
  for l3 in range(L):
5776
- S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5777
- S3_iso /= L; S4_iso /= L
5333
+ S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[..., l1, l2, l3]
5334
+ S3_iso /= L
5335
+ S4_iso /= L
5778
5336
  if data2 is not None:
5779
5337
  S3p_iso /= L
5780
-
5338
+
5781
5339
  if get_variance:
5782
- S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma,axis=(-1))
5783
- S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma,axis=(-1))
5340
+ S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma, axis=(-1))
5341
+ S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma, axis=(-1))
5784
5342
  for l1 in range(L):
5785
5343
  for l2 in range(L):
5786
- S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
5344
+ S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
5787
5345
  if data2 is not None:
5788
- S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
5346
+ S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[..., l1, l2]
5789
5347
  for l3 in range(L):
5790
- S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5791
- S3_sigma_iso /= L; S4_sigma_iso /= L
5348
+ S4_sigma_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4_sigma[
5349
+ ..., l1, l2, l3
5350
+ ]
5351
+ S3_sigma_iso /= L
5352
+ S4_sigma_iso /= L
5792
5353
  if data2 is not None:
5793
5354
  S3p_sigma_iso /= L
5794
-
5355
+
5795
5356
  if data2 is None:
5796
- mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data,axis=(-2,-1)),[N_image,1])
5797
- std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data,axis=(-2,-1)),[N_image,1])
5357
+ mean_data = self.backend.bk_reshape(
5358
+ self.backend.bk_reduce_mean(data, axis=(-2, -1)), [N_image, 1]
5359
+ )
5360
+ std_data = self.backend.bk_reshape(
5361
+ self.backend.bk_reduce_std(data, axis=(-2, -1)), [N_image, 1]
5362
+ )
5798
5363
  else:
5799
- mean_data=self.backend.bk_reshape(self.backend.bk_reduce_mean(data*data2,axis=(-2,-1)),[N_image,1])
5800
- std_data=self.backend.bk_reshape(self.backend.bk_reduce_std(data*data2,axis=(-2,-1)),[N_image,1])
5801
-
5364
+ mean_data = self.backend.bk_reshape(
5365
+ self.backend.bk_reduce_mean(data * data2, axis=(-2, -1)), [N_image, 1]
5366
+ )
5367
+ std_data = self.backend.bk_reshape(
5368
+ self.backend.bk_reduce_std(data * data2, axis=(-2, -1)), [N_image, 1]
5369
+ )
5370
+
5802
5371
  if get_variance:
5803
- ref_sigma={}
5372
+ ref_sigma = {}
5804
5373
  if iso_ang:
5805
- ref_sigma['std_data']=std_data
5806
- ref_sigma['S1_sigma']=S1_sigma_iso
5807
- ref_sigma['S2_sigma']=S2_sigma_iso
5808
- ref_sigma['S3_sigma']=S3_sigma_iso
5809
- ref_sigma['S4_sigma']=S4_sigma_iso
5374
+ ref_sigma["std_data"] = std_data
5375
+ ref_sigma["S1_sigma"] = S1_sigma_iso
5376
+ ref_sigma["S2_sigma"] = S2_sigma_iso
5377
+ ref_sigma["S3_sigma"] = S3_sigma_iso
5378
+ ref_sigma["S4_sigma"] = S4_sigma_iso
5810
5379
  if data2 is not None:
5811
- ref_sigma['S3p_sigma']=S3p_sigma_iso
5380
+ ref_sigma["S3p_sigma"] = S3p_sigma_iso
5812
5381
  else:
5813
- ref_sigma['std_data']=std_data
5814
- ref_sigma['S1_sigma']=S1_sigma
5815
- ref_sigma['S2_sigma']=S2_sigma
5816
- ref_sigma['S3_sigma']=S3_sigma
5817
- ref_sigma['S4_sigma']=S4_sigma
5382
+ ref_sigma["std_data"] = std_data
5383
+ ref_sigma["S1_sigma"] = S1_sigma
5384
+ ref_sigma["S2_sigma"] = S2_sigma
5385
+ ref_sigma["S3_sigma"] = S3_sigma
5386
+ ref_sigma["S4_sigma"] = S4_sigma
5818
5387
  if data2 is not None:
5819
- ref_sigma['S3p_sigma']=S3_sigma
5820
-
5388
+ ref_sigma["S3p_sigma"] = S3_sigma
5389
+
5821
5390
  if data2 is None:
5822
5391
  if iso_ang:
5823
5392
  if ref_sigma is not None:
5824
- for_synthesis = self.backend.bk_concat((
5825
- mean_data/ref_sigma['std_data'],
5826
- std_data/ref_sigma['std_data'],
5827
- self.backend.bk_reshape(self.backend.bk_log(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5828
- self.backend.bk_reshape(self.backend.bk_log(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5829
- self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5830
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5831
- self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5832
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5833
- ),axis=-1)
5393
+ for_synthesis = self.backend.bk_concat(
5394
+ (
5395
+ mean_data / ref_sigma["std_data"],
5396
+ std_data / ref_sigma["std_data"],
5397
+ self.backend.bk_reshape(
5398
+ self.backend.bk_log(S2_iso / ref_sigma["S2_sigma"]),
5399
+ [N_image, -1],
5400
+ ),
5401
+ self.backend.bk_reshape(
5402
+ self.backend.bk_log(S1_iso / ref_sigma["S1_sigma"]),
5403
+ [N_image, -1],
5404
+ ),
5405
+ self.backend.bk_reshape(
5406
+ self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
5407
+ [N_image, -1],
5408
+ ),
5409
+ self.backend.bk_reshape(
5410
+ self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
5411
+ [N_image, -1],
5412
+ ),
5413
+ self.backend.bk_reshape(
5414
+ self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
5415
+ [N_image, -1],
5416
+ ),
5417
+ self.backend.bk_reshape(
5418
+ self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
5419
+ [N_image, -1],
5420
+ ),
5421
+ ),
5422
+ axis=-1,
5423
+ )
5834
5424
  else:
5835
- for_synthesis = self.backend.bk_concat((
5836
- mean_data/std_data,
5837
- std_data,
5838
- self.backend.bk_reshape(self.backend.bk_log(S2_iso),[N_image, -1]),
5839
- self.backend.bk_reshape(self.backend.bk_log(S1_iso),[N_image, -1]),
5840
- self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5841
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5842
- self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5843
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5844
- ),axis=-1)
5425
+ for_synthesis = self.backend.bk_concat(
5426
+ (
5427
+ mean_data / std_data,
5428
+ std_data,
5429
+ self.backend.bk_reshape(
5430
+ self.backend.bk_log(S2_iso), [N_image, -1]
5431
+ ),
5432
+ self.backend.bk_reshape(
5433
+ self.backend.bk_log(S1_iso), [N_image, -1]
5434
+ ),
5435
+ self.backend.bk_reshape(
5436
+ self.backend.bk_real(S3_iso), [N_image, -1]
5437
+ ),
5438
+ self.backend.bk_reshape(
5439
+ self.backend.bk_imag(S3_iso), [N_image, -1]
5440
+ ),
5441
+ self.backend.bk_reshape(
5442
+ self.backend.bk_real(S4_iso), [N_image, -1]
5443
+ ),
5444
+ self.backend.bk_reshape(
5445
+ self.backend.bk_imag(S4_iso), [N_image, -1]
5446
+ ),
5447
+ ),
5448
+ axis=-1,
5449
+ )
5845
5450
  else:
5846
5451
  if ref_sigma is not None:
5847
- for_synthesis = self.backend.bk_concat((
5848
- mean_data/ref_sigma['std_data'],
5849
- std_data/ref_sigma['std_data'],
5850
- self.backend.bk_reshape(self.backend.bk_log(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5851
- self.backend.bk_reshape(self.backend.bk_log(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5852
- self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5853
- self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5854
- self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5855
- self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5856
- ),axis=-1)
5452
+ for_synthesis = self.backend.bk_concat(
5453
+ (
5454
+ mean_data / ref_sigma["std_data"],
5455
+ std_data / ref_sigma["std_data"],
5456
+ self.backend.bk_reshape(
5457
+ self.backend.bk_log(S2 / ref_sigma["S2_sigma"]),
5458
+ [N_image, -1],
5459
+ ),
5460
+ self.backend.bk_reshape(
5461
+ self.backend.bk_log(S1 / ref_sigma["S1_sigma"]),
5462
+ [N_image, -1],
5463
+ ),
5464
+ self.backend.bk_reshape(
5465
+ self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
5466
+ [N_image, -1],
5467
+ ),
5468
+ self.backend.bk_reshape(
5469
+ self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
5470
+ [N_image, -1],
5471
+ ),
5472
+ self.backend.bk_reshape(
5473
+ self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
5474
+ [N_image, -1],
5475
+ ),
5476
+ self.backend.bk_reshape(
5477
+ self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
5478
+ [N_image, -1],
5479
+ ),
5480
+ ),
5481
+ axis=-1,
5482
+ )
5857
5483
  else:
5858
- for_synthesis = self.backend.bk_concat((
5859
- mean_data/std_data,
5484
+ for_synthesis = self.backend.bk_concat(
5485
+ (
5486
+ mean_data / std_data,
5860
5487
  std_data,
5861
- self.backend.bk_reshape(self.backend.bk_log(S2),[N_image, -1]),
5862
- self.backend.bk_reshape(self.backend.bk_log(S1),[N_image, -1]),
5863
- self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5864
- self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5865
- self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5866
- self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5867
- ),axis=-1)
5488
+ self.backend.bk_reshape(
5489
+ self.backend.bk_log(S2), [N_image, -1]
5490
+ ),
5491
+ self.backend.bk_reshape(
5492
+ self.backend.bk_log(S1), [N_image, -1]
5493
+ ),
5494
+ self.backend.bk_reshape(
5495
+ self.backend.bk_real(S3), [N_image, -1]
5496
+ ),
5497
+ self.backend.bk_reshape(
5498
+ self.backend.bk_imag(S3), [N_image, -1]
5499
+ ),
5500
+ self.backend.bk_reshape(
5501
+ self.backend.bk_real(S4), [N_image, -1]
5502
+ ),
5503
+ self.backend.bk_reshape(
5504
+ self.backend.bk_imag(S4), [N_image, -1]
5505
+ ),
5506
+ ),
5507
+ axis=-1,
5508
+ )
5868
5509
  else:
5869
5510
  if iso_ang:
5870
5511
  if ref_sigma is not None:
5871
- for_synthesis = self.backend.backend.cat((
5872
- mean_data/ref_sigma['std_data'],
5873
- std_data/ref_sigma['std_data'],
5874
- self.backend.bk_reshape(self.backend.bk_real(S2_iso/ref_sigma['S2_sigma']),[N_image, -1]),
5875
- self.backend.bk_reshape(self.backend.bk_real(S1_iso/ref_sigma['S1_sigma']),[N_image, -1]),
5876
- self.backend.bk_reshape(self.backend.bk_real(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5877
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso/ref_sigma['S3_sigma']),[N_image, -1]),
5878
- self.backend.bk_reshape(self.backend.bk_real(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5879
- self.backend.bk_reshape(self.backend.bk_imag(S3p_iso/ref_sigma['S3p_sigma']),[N_image, -1]),
5880
- self.backend.bk_reshape(self.backend.bk_real(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5881
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso/ref_sigma['S4_sigma']),[N_image, -1]),
5882
- ),axis=-1)
5512
+ for_synthesis = self.backend.backend.cat(
5513
+ (
5514
+ mean_data / ref_sigma["std_data"],
5515
+ std_data / ref_sigma["std_data"],
5516
+ self.backend.bk_reshape(
5517
+ self.backend.bk_real(S2_iso / ref_sigma["S2_sigma"]),
5518
+ [N_image, -1],
5519
+ ),
5520
+ self.backend.bk_reshape(
5521
+ self.backend.bk_real(S1_iso / ref_sigma["S1_sigma"]),
5522
+ [N_image, -1],
5523
+ ),
5524
+ self.backend.bk_reshape(
5525
+ self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
5526
+ [N_image, -1],
5527
+ ),
5528
+ self.backend.bk_reshape(
5529
+ self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
5530
+ [N_image, -1],
5531
+ ),
5532
+ self.backend.bk_reshape(
5533
+ self.backend.bk_real(S3p_iso / ref_sigma["S3p_sigma"]),
5534
+ [N_image, -1],
5535
+ ),
5536
+ self.backend.bk_reshape(
5537
+ self.backend.bk_imag(S3p_iso / ref_sigma["S3p_sigma"]),
5538
+ [N_image, -1],
5539
+ ),
5540
+ self.backend.bk_reshape(
5541
+ self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
5542
+ [N_image, -1],
5543
+ ),
5544
+ self.backend.bk_reshape(
5545
+ self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
5546
+ [N_image, -1],
5547
+ ),
5548
+ ),
5549
+ axis=-1,
5550
+ )
5883
5551
  else:
5884
- for_synthesis = self.backend.backend.cat((
5885
- mean_data/std_data,
5886
- std_data,
5887
- self.backend.bk_reshape(self.backend.bk_real(S2_iso),[N_image, -1]),
5888
- self.backend.bk_reshape(self.backend.bk_real(S1_iso),[N_image, -1]),
5889
- self.backend.bk_reshape(self.backend.bk_real(S3_iso),[N_image, -1]),
5890
- self.backend.bk_reshape(self.backend.bk_imag(S3_iso),[N_image, -1]),
5891
- self.backend.bk_reshape(self.backend.bk_real(S3p_iso),[N_image, -1]),
5892
- self.backend.bk_reshape(self.backend.bk_imag(S3p_iso),[N_image, -1]),
5893
- self.backend.bk_reshape(self.backend.bk_real(S4_iso),[N_image, -1]),
5894
- self.backend.bk_reshape(self.backend.bk_imag(S4_iso),[N_image, -1]),
5895
- ),axis=-1)
5552
+ for_synthesis = self.backend.backend.cat(
5553
+ (
5554
+ mean_data / std_data,
5555
+ std_data,
5556
+ self.backend.bk_reshape(
5557
+ self.backend.bk_real(S2_iso), [N_image, -1]
5558
+ ),
5559
+ self.backend.bk_reshape(
5560
+ self.backend.bk_real(S1_iso), [N_image, -1]
5561
+ ),
5562
+ self.backend.bk_reshape(
5563
+ self.backend.bk_real(S3_iso), [N_image, -1]
5564
+ ),
5565
+ self.backend.bk_reshape(
5566
+ self.backend.bk_imag(S3_iso), [N_image, -1]
5567
+ ),
5568
+ self.backend.bk_reshape(
5569
+ self.backend.bk_real(S3p_iso), [N_image, -1]
5570
+ ),
5571
+ self.backend.bk_reshape(
5572
+ self.backend.bk_imag(S3p_iso), [N_image, -1]
5573
+ ),
5574
+ self.backend.bk_reshape(
5575
+ self.backend.bk_real(S4_iso), [N_image, -1]
5576
+ ),
5577
+ self.backend.bk_reshape(
5578
+ self.backend.bk_imag(S4_iso), [N_image, -1]
5579
+ ),
5580
+ ),
5581
+ axis=-1,
5582
+ )
5896
5583
  else:
5897
5584
  if ref_sigma is not None:
5898
- for_synthesis = self.backend.backend.cat((
5899
- mean_data/ref_sigma['std_data'],
5900
- std_data/ref_sigma['std_data'],
5901
- self.backend.bk_reshape(self.backend.bk_real(S2/ref_sigma['S2_sigma']),[N_image, -1]),
5902
- self.backend.bk_reshape(self.backend.bk_real(S1/ref_sigma['S1_sigma']),[N_image, -1]),
5903
- self.backend.bk_reshape(self.backend.bk_real(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5904
- self.backend.bk_reshape(self.backend.bk_imag(S3/ref_sigma['S3_sigma']),[N_image, -1]),
5905
- self.backend.bk_reshape(self.backend.bk_real(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5906
- self.backend.bk_reshape(self.backend.bk_imag(S3p/ref_sigma['S3p_sigma']),[N_image, -1]),
5907
- self.backend.bk_reshape(self.backend.bk_real(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5908
- self.backend.bk_reshape(self.backend.bk_imag(S4/ref_sigma['S4_sigma']),[N_image, -1]),
5909
- ),axis=-1)
5585
+ for_synthesis = self.backend.backend.cat(
5586
+ (
5587
+ mean_data / ref_sigma["std_data"],
5588
+ std_data / ref_sigma["std_data"],
5589
+ self.backend.bk_reshape(
5590
+ self.backend.bk_real(S2 / ref_sigma["S2_sigma"]),
5591
+ [N_image, -1],
5592
+ ),
5593
+ self.backend.bk_reshape(
5594
+ self.backend.bk_real(S1 / ref_sigma["S1_sigma"]),
5595
+ [N_image, -1],
5596
+ ),
5597
+ self.backend.bk_reshape(
5598
+ self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
5599
+ [N_image, -1],
5600
+ ),
5601
+ self.backend.bk_reshape(
5602
+ self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
5603
+ [N_image, -1],
5604
+ ),
5605
+ self.backend.bk_reshape(
5606
+ self.backend.bk_real(S3p / ref_sigma["S3p_sigma"]),
5607
+ [N_image, -1],
5608
+ ),
5609
+ self.backend.bk_reshape(
5610
+ self.backend.bk_imag(S3p / ref_sigma["S3p_sigma"]),
5611
+ [N_image, -1],
5612
+ ),
5613
+ self.backend.bk_reshape(
5614
+ self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
5615
+ [N_image, -1],
5616
+ ),
5617
+ self.backend.bk_reshape(
5618
+ self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
5619
+ [N_image, -1],
5620
+ ),
5621
+ ),
5622
+ axis=-1,
5623
+ )
5910
5624
  else:
5911
- for_synthesis = self.backend.bk_concat((
5912
- mean_data/std_data,
5625
+ for_synthesis = self.backend.bk_concat(
5626
+ (
5627
+ mean_data / std_data,
5913
5628
  std_data,
5914
- self.backend.bk_reshape(self.backend.bk_real(S2),[N_image, -1]),
5915
- self.backend.bk_reshape(self.backend.bk_real(S1),[N_image, -1]),
5916
- self.backend.bk_reshape(self.backend.bk_real(S3),[N_image, -1]),
5917
- self.backend.bk_reshape(self.backend.bk_imag(S3),[N_image, -1]),
5918
- self.backend.bk_reshape(self.backend.bk_real(S3p),[N_image, -1]),
5919
- self.backend.bk_reshape(self.backend.bk_imag(S3p),[N_image, -1]),
5920
- self.backend.bk_reshape(self.backend.bk_real(S4),[N_image, -1]),
5921
- self.backend.bk_reshape(self.backend.bk_imag(S4),[N_image, -1])
5922
- ),axis=-1)
5923
-
5924
- if not use_ref:
5925
- self.ref_scattering_cov_S2=S2
5926
-
5629
+ self.backend.bk_reshape(
5630
+ self.backend.bk_real(S2), [N_image, -1]
5631
+ ),
5632
+ self.backend.bk_reshape(
5633
+ self.backend.bk_real(S1), [N_image, -1]
5634
+ ),
5635
+ self.backend.bk_reshape(
5636
+ self.backend.bk_real(S3), [N_image, -1]
5637
+ ),
5638
+ self.backend.bk_reshape(
5639
+ self.backend.bk_imag(S3), [N_image, -1]
5640
+ ),
5641
+ self.backend.bk_reshape(
5642
+ self.backend.bk_real(S3p), [N_image, -1]
5643
+ ),
5644
+ self.backend.bk_reshape(
5645
+ self.backend.bk_imag(S3p), [N_image, -1]
5646
+ ),
5647
+ self.backend.bk_reshape(
5648
+ self.backend.bk_real(S4), [N_image, -1]
5649
+ ),
5650
+ self.backend.bk_reshape(
5651
+ self.backend.bk_imag(S4), [N_image, -1]
5652
+ ),
5653
+ ),
5654
+ axis=-1,
5655
+ )
5656
+
5657
+ if not use_ref:
5658
+ self.ref_scattering_cov_S2 = S2
5659
+
5927
5660
  if get_variance:
5928
- return for_synthesis,ref_sigma
5929
-
5661
+ return for_synthesis, ref_sigma
5662
+
5930
5663
  return for_synthesis
5931
-
5932
-
5933
- def to_gaussian(self,x):
5934
- from scipy.stats import norm
5664
+
5665
+ def to_gaussian(self, x):
5935
5666
  from scipy.interpolate import interp1d
5667
+ from scipy.stats import norm
5936
5668
 
5937
- idx=np.argsort(x.flatten())
5669
+ idx = np.argsort(x.flatten())
5938
5670
  p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
5939
- im_target=x.flatten()
5671
+ im_target = x.flatten()
5940
5672
  im_target[idx] = norm.ppf(p)
5941
-
5673
+
5942
5674
  # Interpolation cubique
5943
- self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind='cubic')
5944
- self.val_min=im_target[idx[0]]
5945
- self.val_max=im_target[idx[-1]]
5675
+ self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
5676
+ self.val_min = im_target[idx[0]]
5677
+ self.val_max = im_target[idx[-1]]
5946
5678
  return im_target.reshape(x.shape)
5947
5679
 
5680
+ def from_gaussian(self, x):
5948
5681
 
5949
- def from_gaussian(self,x):
5950
-
5951
- x=self.backend.bk_clip_by_value(x,self.val_min,self.val_max)
5682
+ x = self.backend.bk_clip_by_value(x, self.val_min, self.val_max)
5952
5683
  return self.f_gaussian(self.backend.to_numpy(x))
5953
-
5684
+
5954
5685
  def square(self, x):
5955
5686
  if isinstance(x, scat_cov):
5956
5687
  if x.S1 is None:
@@ -6302,89 +6033,133 @@ class funct(FOC.FoCUS):
6302
6033
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
6303
6034
  )
6304
6035
 
6305
- def synthesis(self,
6306
- image_target,
6307
- nstep=4,
6308
- seed=1234,
6309
- Jmax=None,
6310
- edge=False,
6311
- to_gaussian=True,
6312
- use_variance=False,
6313
- synthesised_N=1,
6314
- input_image=None,
6315
- iso_ang=False,
6316
- EVAL_FREQUENCY=100,
6317
- NUM_EPOCHS = 300):
6318
-
6319
- import foscat.Synthesis as synthe
6036
+ def synthesis(
6037
+ self,
6038
+ image_target,
6039
+ nstep=4,
6040
+ seed=1234,
6041
+ Jmax=None,
6042
+ edge=False,
6043
+ to_gaussian=True,
6044
+ use_variance=False,
6045
+ synthesised_N=1,
6046
+ input_image=None,
6047
+ grd_mask=None,
6048
+ iso_ang=False,
6049
+ EVAL_FREQUENCY=100,
6050
+ NUM_EPOCHS=300,
6051
+ ):
6052
+
6320
6053
  import time
6321
6054
 
6322
- def The_loss(u,scat_operator,args):
6323
- ref = args[0]
6055
+ import foscat.Synthesis as synthe
6056
+
6057
+ def The_loss(u, scat_operator, args):
6058
+ ref = args[0]
6324
6059
  sref = args[1]
6325
- use_v= args[2]
6326
-
6060
+ use_v = args[2]
6061
+
6327
6062
  # compute scattering covariance of the current synthetised map called u
6328
6063
  if use_v:
6329
- learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,ref_sigma=sref,use_ref=True,iso_ang=iso_ang))
6064
+ learn = scat_operator.reduce_mean_batch(
6065
+ scat_operator.scattering_cov(
6066
+ u,
6067
+ edge=edge,
6068
+ Jmax=Jmax,
6069
+ ref_sigma=sref,
6070
+ use_ref=True,
6071
+ iso_ang=iso_ang,
6072
+ )
6073
+ )
6330
6074
  else:
6331
- learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,use_ref=True,iso_ang=iso_ang))
6332
-
6075
+ learn = scat_operator.reduce_mean_batch(
6076
+ scat_operator.scattering_cov(
6077
+ u, edge=edge, Jmax=Jmax, use_ref=True, iso_ang=iso_ang
6078
+ )
6079
+ )
6080
+
6333
6081
  # make the difference withe the reference coordinates
6334
- loss=scat_operator.backend.bk_reduce_mean(scat_operator.backend.bk_square((learn-ref)))
6082
+ loss = scat_operator.backend.bk_reduce_mean(
6083
+ scat_operator.backend.bk_square(learn - ref)
6084
+ )
6335
6085
  return loss
6336
6086
 
6337
6087
  if to_gaussian:
6338
6088
  # Change the data histogram to gaussian distribution
6339
- im_target=self.to_gaussian(image_target)
6089
+ im_target = self.to_gaussian(image_target)
6340
6090
  else:
6341
- im_target=image_target
6342
-
6343
- axis=len(im_target.shape)-1
6091
+ im_target = image_target
6092
+
6093
+ axis = len(im_target.shape) - 1
6344
6094
  if self.use_2D:
6345
- axis-=1
6346
- if axis==0:
6347
- im_target=self.backend.bk_expand_dims(im_target,0)
6095
+ axis -= 1
6096
+ if axis == 0:
6097
+ im_target = self.backend.bk_expand_dims(im_target, 0)
6348
6098
 
6349
6099
  # compute the number of possible steps
6350
6100
  if self.use_2D:
6351
- jmax=int(np.min([np.log(im_target.shape[1]),np.log(im_target.shape[2])])/np.log(2))
6101
+ jmax = int(
6102
+ np.min([np.log(im_target.shape[1]), np.log(im_target.shape[2])])
6103
+ / np.log(2)
6104
+ )
6352
6105
  elif self.use_1D:
6353
- jmax=int(np.log(im_target.shape[1])/np.log(2))
6106
+ jmax = int(np.log(im_target.shape[1]) / np.log(2))
6354
6107
  else:
6355
- jmax=int((np.log(im_target.shape[1]//12)/np.log(2))/2)
6356
- nside=2**jmax
6108
+ jmax = int((np.log(im_target.shape[1] // 12) / np.log(2)) / 2)
6109
+ nside = 2**jmax
6357
6110
 
6358
- if nstep>jmax-1:
6359
- nstep=jmax-1
6111
+ if nstep > jmax - 1:
6112
+ nstep = jmax - 1
6360
6113
 
6361
- t1=time.time()
6362
- tmp={}
6363
- tmp[nstep-1]=im_target
6364
- for l in range(nstep-2,-1,-1):
6365
- tmp[l]=self.ud_grade_2(tmp[l+1],axis=1)
6114
+ t1 = time.time()
6115
+ tmp = {}
6116
+
6117
+ l_grd_mask={}
6118
+
6119
+ tmp[nstep - 1] = self.backend.bk_cast(im_target)
6120
+ if grd_mask is not None:
6121
+ l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
6122
+ else:
6123
+ l_grd_mask[nstep - 1] = None
6366
6124
 
6367
- if not self.use_2D and not self.use_1D:
6368
- l_nside=nside//(2**(nstep-1))
6125
+ for ell in range(nstep - 2, -1, -1):
6126
+ tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
6127
+ if grd_mask is not None:
6128
+ l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
6129
+ else:
6130
+ l_grd_mask[ell] = None
6369
6131
 
6132
+
6133
+ if not self.use_2D and not self.use_1D:
6134
+ l_nside = nside // (2 ** (nstep - 1))
6135
+
6370
6136
  for k in range(nstep):
6371
- if k==0:
6137
+ if k == 0:
6372
6138
  if input_image is None:
6373
6139
  np.random.seed(seed)
6374
6140
  if self.use_2D:
6375
- imap=np.random.randn(synthesised_N,
6376
- tmp[k].shape[1],
6377
- tmp[k].shape[2])
6141
+ imap = self.backend.bk_cast(np.random.randn(
6142
+ synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
6143
+ ))
6378
6144
  else:
6379
- imap=np.random.randn(synthesised_N,
6380
- tmp[k].shape[1])
6145
+ imap = self.backend.bk_cast(np.random.randn(synthesised_N, tmp[k].shape[1]))
6381
6146
  else:
6382
6147
  if self.use_2D:
6383
- imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6384
- [synthesised_N,tmp[k].shape[1],tmp[k].shape[2]])
6148
+ imap = self.backend.bk_reshape(
6149
+ self.backend.bk_tile(
6150
+ self.backend.bk_cast(input_image.flatten()),
6151
+ synthesised_N,
6152
+ ),
6153
+ [synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
6154
+ )
6385
6155
  else:
6386
- imap=self.backend.bk_reshape(self.backend.bk_tile(self.backend.bk_cast(input_image.flatten()),synthesised_N),
6387
- [synthesised_N,tmp[k].shape[1]])
6156
+ imap = self.backend.bk_reshape(
6157
+ self.backend.bk_tile(
6158
+ self.backend.bk_cast(input_image.flatten()),
6159
+ synthesised_N,
6160
+ ),
6161
+ [synthesised_N, tmp[k].shape[1]],
6162
+ )
6388
6163
  else:
6389
6164
  # Increase the resolution between each step
6390
6165
  if self.use_2D:
@@ -6395,48 +6170,49 @@ class funct(FOC.FoCUS):
6395
6170
  imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
6396
6171
  else:
6397
6172
  imap = self.up_grade(omap, l_nside, axis=1)
6398
-
6173
+
6174
+ if grd_mask is not None:
6175
+ imap=imap*l_grd_mask[k]+tmp[k]*(1-l_grd_mask[k])
6176
+
6399
6177
  # compute the coefficients for the target image
6400
6178
  if use_variance:
6401
- ref,sref=self.scattering_cov(tmp[k],get_variance=True,edge=edge,Jmax=Jmax,iso_ang=iso_ang)
6179
+ ref, sref = self.scattering_cov(
6180
+ tmp[k], get_variance=True, edge=edge, Jmax=Jmax, iso_ang=iso_ang
6181
+ )
6402
6182
  else:
6403
- ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
6404
- sref=ref
6405
-
6183
+ ref = self.scattering_cov(tmp[k], edge=edge, Jmax=Jmax, iso_ang=iso_ang)
6184
+ sref = ref
6185
+
6406
6186
  # compute the mean of the population does nothing if only one map is given
6407
- ref=self.reduce_mean_batch(ref)
6408
-
6187
+ ref = self.reduce_mean_batch(ref)
6188
+
6409
6189
  # define a loss to minimize
6410
- loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
6411
-
6190
+ loss = synthe.Loss(The_loss, self, ref, sref, use_variance)
6191
+
6412
6192
  sy = synthe.Synthesis([loss])
6413
6193
 
6414
6194
  # initialize the synthesised map
6415
6195
  if self.use_2D:
6416
- print('Synthesis scale [ %d x %d ]'%(imap.shape[1],imap.shape[2]))
6196
+ print("Synthesis scale [ %d x %d ]" % (imap.shape[1], imap.shape[2]))
6417
6197
  elif self.use_1D:
6418
- print('Synthesis scale [ %d ]'%(imap.shape[1]))
6198
+ print("Synthesis scale [ %d ]" % (imap.shape[1]))
6419
6199
  else:
6420
- print('Synthesis scale nside=%d'%(l_nside))
6421
- l_nside*=2
6422
-
6200
+ print("Synthesis scale nside=%d" % (l_nside))
6201
+ l_nside *= 2
6202
+
6423
6203
  # do the minimization
6424
- omap=sy.run(imap,
6425
- EVAL_FREQUENCY=EVAL_FREQUENCY,
6426
- NUM_EPOCHS = NUM_EPOCHS)
6427
-
6428
-
6204
+ omap = sy.run(imap, EVAL_FREQUENCY=EVAL_FREQUENCY, NUM_EPOCHS=NUM_EPOCHS,grd_mask=l_grd_mask[k])
6429
6205
 
6430
- t2=time.time()
6431
- print('Total computation %.2fs'%(t2-t1))
6206
+ t2 = time.time()
6207
+ print("Total computation %.2fs" % (t2 - t1))
6432
6208
 
6433
6209
  if to_gaussian:
6434
- omap=self.from_gaussian(omap)
6210
+ omap = self.from_gaussian(omap)
6435
6211
 
6436
- if axis==0 and synthesised_N==1:
6212
+ if axis == 0 and synthesised_N == 1:
6437
6213
  return omap[0]
6438
6214
  else:
6439
6215
  return omap
6440
-
6441
- def to_numpy(self,x):
6216
+
6217
+ def to_numpy(self, x):
6442
6218
  return self.backend.to_numpy(x)