foscat 3.6.1__py3-none-any.whl → 3.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -33,9 +33,7 @@ testwarn = 0
 
 
 class scat_cov:
-    def __init__(
-        self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False
-    ):
+    def __init__(self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False):
         self.S0 = s0
         self.S2 = s2
         self.S3 = s3
@@ -54,17 +52,17 @@ class scat_cov:
         if self.S1 is None:
             s1 = None
         else:
-            s1 = self.S1.numpy()
+            s1 = self.backend.to_numpy(self.S1)
         if self.S3P is None:
             s3p = None
         else:
-            s3p = self.S3P.numpy()
+            s3p = self.backend.to_numpy(self.S3P)
 
         return scat_cov(
-            (self.S0.numpy()),
-            (self.S2.numpy()),
-            (self.S3.numpy()),
-            (self.S4.numpy()),
+            self.backend.to_numpy(self.S0),
+            self.backend.to_numpy(self.S2),
+            self.backend.to_numpy(self.S3),
+            self.backend.to_numpy(self.S4),
             s1=s1,
             s3p=s3p,
             backend=self.backend,
@@ -94,7 +92,7 @@ class scat_cov:
         )
 
     def conv2complex(self, val):
-        if val.dtype == "complex64" or val.dtype == "complex128":
+        if val.dtype == "complex64" or val.dtype == "complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128":
             return val
         else:
             return self.backend.bk_complex(val, 0 * val)
@@ -1531,7 +1529,7 @@ class scat_cov:
             if isinstance(x, np.ndarray):
                 return x
             else:
-                return x.numpy()
+                return self.backend.to_numpy(x)
         else:
             return None
 
@@ -1974,7 +1972,7 @@ class scat_cov:
         if self.BACKEND == "numpy":
             s2[:, :, noff:, :] = self.S2
         else:
-            s2[:, :, noff:, :] = self.S2.numpy()
+            s2[:, :, noff:, :] = self.backend.to_numpy(self.S2)
         for i in range(self.S2.shape[0]):
             for j in range(self.S2.shape[1]):
                 for k in range(self.S2.shape[3]):
@@ -1986,7 +1984,7 @@ class scat_cov:
         if self.BACKEND == "numpy":
             s1[:, :, noff:, :] = self.S1
         else:
-            s1[:, :, noff:, :] = self.S1.numpy()
+            s1[:, :, noff:, :] = self.backend.to_numpy(self.S1)
         for i in range(self.S1.shape[0]):
             for j in range(self.S1.shape[1]):
                 for k in range(self.S1.shape[3]):
@@ -2045,12 +2043,12 @@ class scat_cov:
                                 )
                             )
                         else:
-                            s3[i, j, idx[noff:], k, l_orient] = self.S3.numpy()[
+                            s3[i, j, idx[noff:], k, l_orient] = self.backend.to_numpy(self.S3)[
                                 i, j, j2 == ij - noff, k, l_orient
                             ]
                             s3[i, j, idx[:noff], k, l_orient] = (
                                 self.add_data_from_slope(
-                                    self.S3.numpy()[
+                                    self.backend.to_numpy(self.S3)[
                                         i, j, j2 == ij - noff, k, l_orient
                                     ],
                                     noff,
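
The hunks above replace every direct `.numpy()` call with `self.backend.to_numpy(...)`, so conversion to NumPy no longer assumes the tensor type of one particular backend. A minimal sketch of what such a dispatcher can look like (hypothetical helper for illustration, not foscat's actual implementation):

    import numpy as np

    def to_numpy(x):
        # NumPy arrays pass through untouched
        if isinstance(x, np.ndarray):
            return x
        # torch.Tensor: detach from the graph, move to host, then convert
        if hasattr(x, "detach"):
            return x.detach().cpu().numpy()
        # other array-likes (e.g. TensorFlow tensors) support np.asarray
        return np.asarray(x)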
@@ -2358,13 +2356,13 @@ class funct(FOC.FoCUS):
             tmp = self.up_grade(tmp, l_nside * 2, axis=1)
             if image2 is not None:
                 tmpi2 = self.up_grade(tmpi2, l_nside * 2, axis=1)
-
         l_nside = int(np.sqrt(tmp.shape[1] // 12))
         nscale = int(np.log(l_nside) / np.log(2))
         cmat = {}
         cmat2 = {}
+
+        # Loop over scales
         for k in range(nscale):
-            sim = self.backend.bk_abs(self.convol(tmp, axis=1))
             if image2 is not None:
                 sim = self.backend.bk_real(
                     self.backend.bk_L1(
@@ -2375,17 +2373,31 @@ class funct(FOC.FoCUS):
             else:
                 sim = self.backend.bk_abs(self.convol(tmp, axis=1))
 
-            cc = self.backend.bk_reduce_mean(sim[:, :, 0] - sim[:, :, 2], 0)
-            ss = self.backend.bk_reduce_mean(sim[:, :, 1] - sim[:, :, 3], 0)
-            for m in range(smooth_scale):
-                if cc.shape[0] > 12:
-                    cc = self.ud_grade_2(self.smooth(cc))
-                    ss = self.ud_grade_2(self.smooth(ss))
+            # instead of difference between "opposite" channels use weighted average
+            # of cosine and sine contributions using all channels
+            angles = (
+                2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
+            )  # shape: (NORIENT,)
+            angles = angles.reshape(1, 1, self.NORIENT)
+
+            # we use cosines and sines as weights for sim
+            weighted_cos = self.backend.bk_reduce_mean(sim * np.cos(angles), axis=-1)
+            weighted_sin = self.backend.bk_reduce_mean(sim * np.sin(angles), axis=-1)
+            # For simplicity, take first element of the batch
+            cc = weighted_cos[0]
+            ss = weighted_sin[0]
+
+            if smooth_scale > 0:
+                for m in range(smooth_scale):
+                    if cc.shape[0] > 12:
+                        cc = self.ud_grade_2(self.smooth(cc))
+                        ss = self.ud_grade_2(self.smooth(ss))
             if cc.shape[0] != tmp.shape[0]:
                 ll_nside = int(np.sqrt(tmp.shape[1] // 12))
                 cc = self.up_grade(cc, ll_nside)
                 ss = self.up_grade(ss, ll_nside)
 
+            # compute local phase from weighted cos and sin (same as before)
             if self.BACKEND == "numpy":
                 phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
             else:
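
This hunk replaces the old two-channel differences (channel 0 minus 2, channel 1 minus 3, which only made sense for exactly four orientations) with cosine/sine-weighted averages over all NORIENT channels, from which a local orientation phase is then derived. A toy NumPy check of the idea, assuming equally spaced orientation bins:

    import numpy as np

    NORIENT = 4
    angles = 2 * np.pi * np.arange(NORIENT) / NORIENT   # orientation bins
    sim = 1.0 + 0.5 * np.cos(angles - np.pi / 3)        # toy responses peaked at pi/3
    cc = np.mean(sim * np.cos(angles))                  # weighted cosine term
    ss = np.mean(sim * np.sin(angles))                  # weighted sine term
    phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
    print(phase)                                        # ~ pi/3, the injected direction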
@@ -2393,87 +2405,87 @@ class funct(FOC.FoCUS):
                     np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
                 )
 
-            iph = (4 * phase / (2 * np.pi)).astype("int")
-            alpha = 4 * phase / (2 * np.pi) - iph
-            mat = np.zeros([sim.shape[1], 4 * 4])
+            # instead of linear interpolation cosine-based interpolation
+            phase_scaled = self.NORIENT * phase / (2 * np.pi)
+            iph = np.floor(phase_scaled).astype("int")  # lower bin index
+            delta = phase_scaled - iph  # fractional part in [0,1)
+            # interpolation weights
+            w0 = np.cos(delta * np.pi / 2) ** 2
+            w1 = np.sin(delta * np.pi / 2) ** 2
+
+            # build rotation matrix
+            mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
             lidx = np.arange(sim.shape[1])
-            for l_orient in range(4):
-                mat[lidx, 4 * ((l_orient + iph) % 4) + l_orient] = 1.0 - alpha
-                mat[lidx, 4 * ((l_orient + iph + 1) % 4) + l_orient] = alpha
+            for l in range(self.NORIENT):
+                # Instead of simple linear weights, we use the cosine weights w0 and w1.
+                col0 = self.NORIENT * ((l + iph) % self.NORIENT) + l
+                col1 = self.NORIENT * ((l + iph + 1) % self.NORIENT) + l
+                mat[lidx, col0] = w0
+                mat[lidx, col1] = w1
 
             cmat[k] = self.backend.bk_cast(mat.astype("complex64"))
 
-            mat2 = np.zeros([k + 1, sim.shape[1], 4, 4 * 4])
+            # do same modifications for mat2
+            mat2 = np.zeros(
+                [k + 1, sim.shape[1], self.NORIENT, self.NORIENT * self.NORIENT]
+            )
 
             for k2 in range(k + 1):
-                tmp2 = self.backend.bk_repeat(sim, 4, axis=-1)
+                tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=-1)
+
                 sim2 = self.backend.bk_reduce_sum(
                     self.backend.bk_reshape(
-                        mat.reshape(1, mat.shape[0], 16) * tmp2,
-                        [sim.shape[0], cmat[k].shape[0], 4, 4],
+                        mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
+                        * tmp2,
+                        [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
                     ),
                     2,
                 )
+
                 sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
 
-                cc = self.smooth(
-                    self.backend.bk_reduce_mean(sim2[:, :, 0] - sim2[:, :, 2], 0)
+                weighted_cos2 = self.backend.bk_reduce_mean(
+                    sim2 * np.cos(angles), axis=-1
                 )
-                ss = self.smooth(
-                    self.backend.bk_reduce_mean(sim2[:, :, 1] - sim2[:, :, 3], 0)
+                weighted_sin2 = self.backend.bk_reduce_mean(
+                    sim2 * np.sin(angles), axis=-1
                 )
-                for m in range(smooth_scale):
-                    if cc.shape[0] > 12:
-                        cc = self.ud_grade_2(self.smooth(cc))
-                        ss = self.ud_grade_2(self.smooth(ss))
-                if cc.shape[0] != sim.shape[1]:
+
+                cc2 = weighted_cos2[0]
+                ss2 = weighted_sin2[0]
+
+                if smooth_scale > 0:
+                    for m in range(smooth_scale):
+                        if cc2.shape[0] > 12:
+                            cc2 = self.ud_grade_2(self.smooth(cc2))
+                            ss2 = self.ud_grade_2(self.smooth(ss2))
+
+                if cc2.shape[0] != sim.shape[1]:
                     ll_nside = int(np.sqrt(sim.shape[1] // 12))
-                    cc = self.up_grade(cc, ll_nside)
-                    ss = self.up_grade(ss, ll_nside)
+                    cc2 = self.up_grade(cc2, ll_nside)
+                    ss2 = self.up_grade(ss2, ll_nside)
 
                 if self.BACKEND == "numpy":
-                    phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
+                    phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
                 else:
-                    phase = np.fmod(
-                        np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
+                    phase2 = np.fmod(
+                        np.arctan2(ss2.numpy(), cc2.numpy()) + 2 * np.pi, 2 * np.pi
                     )
-                """
-                for k in range(4):
-                    hp.mollview(np.fmod(phase+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,1+k))
-                plt.show()
-                return None
-                """
-                iph = (4 * phase / (2 * np.pi)).astype("int")
-                alpha = 4 * phase / (2 * np.pi) - iph
+
+                phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
+                iph2 = np.floor(phase2_scaled).astype("int")
+                delta2 = phase2_scaled - iph2
+                w0_2 = np.cos(delta2 * np.pi / 2) ** 2
+                w1_2 = np.sin(delta2 * np.pi / 2) ** 2
                 lidx = np.arange(sim.shape[1])
-                for m in range(4):
-                    for l_orient in range(4):
-                        mat2[
-                            k2, lidx, m, 4 * ((l_orient + iph[:, m]) % 4) + l_orient
-                        ] = (1.0 - alpha[:, m])
-                        mat2[
-                            k2, lidx, m, 4 * ((l_orient + iph[:, m] + 1) % 4) + l_orient
-                        ] = alpha[:, m]
-
-            cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
-            """
-            tmp=self.backend.bk_repeat(sim[0],4,axis=1)
-            sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat*tmp,[12*nside**2,4,4]),1)
-
-            cc2=(sim2[:,0]-sim2[:,2])
-            ss2=(sim2[:,1]-sim2[:,3])
-            phase2=np.fmod(np.arctan2(ss2.numpy(),cc2.numpy())+2*np.pi,2*np.pi)
-
-            plt.figure()
-            hp.mollview(phase,cmap='jet',nest=True,hold=False,sub=(2,2,1))
-            hp.mollview(np.fmod(phase2+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,2))
-            plt.figure()
-            for k in range(4):
-                hp.mollview((sim[0,:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,1+k),min=-10,max=10)
-                hp.mollview((sim2[:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,5+k),min=-10,max=10)
-
-            plt.show()
-            """
+
+                for m in range(self.NORIENT):
+                    for l in range(self.NORIENT):
+                        col0 = self.NORIENT * ((l + iph2[:, m]) % self.NORIENT) + l
+                        col1 = self.NORIENT * ((l + iph2[:, m] + 1) % self.NORIENT) + l
+                        mat2[k2, lidx, m, col0] = w0_2[:, m]
+                        mat2[k2, lidx, m, col1] = w1_2[:, m]
+                cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
 
             if k < l_nside - 1:
                 tmp = self.ud_grade_2(tmp, axis=1)
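
The linear weights (1 - alpha, alpha) of the old interpolation are replaced by cos^2/sin^2 weights, which still sum to one but have zero slope at the bin edges, so the interpolated rotation matrix varies smoothly as the phase crosses an orientation-bin boundary. A short NumPy illustration of the weight change:

    import numpy as np

    NORIENT = 4
    phase = np.linspace(0.0, 2.0 * np.pi, 8, endpoint=False)
    t = NORIENT * phase / (2.0 * np.pi)
    iph = np.floor(t).astype(int)          # lower orientation bin
    delta = t - iph                        # fractional part in [0, 1)
    w0 = np.cos(delta * np.pi / 2) ** 2    # weight of bin iph
    w1 = np.sin(delta * np.pi / 2) ** 2    # weight of bin iph + 1
    assert np.allclose(w0 + w1, 1.0)       # still a partition of unity
    # d(w0)/d(delta) = -(pi/2) * sin(pi * delta) vanishes at delta = 0 and 1,
    # unlike the linear weights (1 - delta, delta)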
@@ -2493,11 +2505,12 @@ class funct(FOC.FoCUS):
         image2=None,
         mask=None,
         norm=None,
-        Auto=True,
         calc_var=False,
         cmat=None,
         cmat2=None,
-        out_nside=None
+        Jmax=None,
+        out_nside=None,
+        edge=True
     ):
         """
         Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2523,9 +2536,6 @@ class funct(FOC.FoCUS):
         norm: None or str
             If None no normalization is applied, if 'auto' normalize by the reference S2,
             if 'self' normalize by the current S2.
-        all_cross: False or True
-            If False compute all the coefficient even the Imaginary part,
-            If True return only the terms computable in the auto case.
         Returns
         -------
         S1, S2, S3, S4 normalized
@@ -2539,15 +2549,26 @@ class funct(FOC.FoCUS):
             )
             return None
         if mask is not None:
-            if list(image1.shape) != list(mask.shape)[1:]:
-                print(
-                    "The LAST COLUMN of the mask should have the same size ",
-                    mask.shape,
-                    "than the input image ",
-                    image1.shape,
-                    "to eval Scattering Covariance",
-                )
-                return None
+            if self.use_2D:
+                if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
+                    print(
+                        "The LAST 2 COLUMNs of the mask should have the same size ",
+                        mask.shape,
+                        "than the input image ",
+                        image1.shape,
+                        "to eval Scattering Covariance",
+                    )
+                    return None
+            else:
+                if image1.shape[-1] != mask.shape[1]:
+                    print(
+                        "The LAST COLUMN of the mask should have the same size ",
+                        mask.shape,
+                        "than the input image ",
+                        image1.shape,
+                        "to eval Scattering Covariance",
+                    )
+                    return None
         if self.use_2D and len(image1.shape) < 2:
             print(
                 "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
@@ -2558,9 +2579,6 @@ class funct(FOC.FoCUS):
         cross = False
         if image2 is not None:
             cross = True
-            all_cross = Auto
-        else:
-            all_cross = False
 
         ### PARAMETERS
         axis = 1
@@ -2596,8 +2614,16 @@ class funct(FOC.FoCUS):
             nside = int(np.sqrt(npix // 12))
 
             J = int(np.log(nside) / np.log(2))  # Number of j scales
-
-        Jmax = J - self.OSTEP  # Number of steps for the loop on scales
+
+        if (self.use_2D or self.use_1D) and self.KERNELSZ > 3:
+            J -= 1
+        if Jmax is None:
+            Jmax = J  # Number of steps for the loop on scales
+        if Jmax > J:
+            print('==========\n\n')
+            print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
+            print('\n\n==========')
+
 
         ### LOCAL VARIABLES (IMAGES and MASK)
         if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
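
`Jmax` is now a user-facing argument: when left as None it defaults to the number of scales J that the map actually supports, and a larger request only triggers the printed warning. A small sketch of the default computation for a HEALPix input (toy numbers):

    import numpy as np

    npix = 12 * 256 ** 2                  # nside = 256 HEALPix map
    nside = int(np.sqrt(npix // 12))      # 256
    J = int(np.log(nside) / np.log(2))    # 8 available scales
    Jmax = None
    if Jmax is None:
        Jmax = J                          # the default chosen by eval()
    assert Jmax <= J                      # beyond J, eval() prints its size warning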
@@ -2621,7 +2647,7 @@ class funct(FOC.FoCUS):
         else:
             vmask = self.backend.bk_cast(mask)  # [Nmask, Npix]
 
-        if self.KERNELSZ > 3:
+        if self.KERNELSZ > 3 and not self.use_2D:
             # if the kernel size is bigger than 3 increase the binning before smoothing
             if self.use_2D:
                 vmask = self.up_grade(
@@ -2645,7 +2671,7 @@ class funct(FOC.FoCUS):
             if cross:
                 I2 = self.up_grade(I2, nside * 2, axis=axis)
 
-        if self.KERNELSZ > 5:
+        if self.KERNELSZ > 5 and not self.use_2D:
             # if the kernel size is bigger than 3 increase the binning before smoothing
             if self.use_2D:
                 vmask = self.up_grade(
@@ -2677,7 +2703,23 @@ class funct(FOC.FoCUS):
 
         ### INITIALIZATION
         # Coefficients
-        S1, S2, S3, S4, S3P = None, None, None, None, None
+        if return_data:
+            S1 = {}
+            S2 = {}
+            S3 = {}
+            S3P = {}
+            S4 = {}
+        else:
+            S1 = []
+            S2 = []
+            S3 = []
+            S4 = []
+            S3P = []
+            VS1 = []
+            VS2 = []
+            VS3 = []
+            VS3P = []
+            VS4 = []
 
         off_S2 = -2
         off_S3 = -3
@@ -2687,11 +2729,6 @@ class funct(FOC.FoCUS):
             off_S3 = -1
             off_S4 = -1
 
-        # Dictionaries for S3 computation
-        M1_dic = {}  # M stands for Module M1 = |I1 * Psi|
-        if cross:
-            M2_dic = {}
-
         # S2 for normalization
         cond_init_P1_dic = (norm == "self") or (
             (norm == "auto") and (self.P1_dic is None)
@@ -2710,7 +2747,12 @@ class funct(FOC.FoCUS):
         if return_data:
             s0 = I1
             if out_nside is not None:
-                s0 = self.backend.bk_reduce_mean(self.backend.bk_reshape(s0,[s0.shape[0],12*out_nside**2,(nside//out_nside)**2]),2)
+                s0 = self.backend.bk_reduce_mean(
+                    self.backend.bk_reshape(
+                        s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
+                    ),
+                    2,
+                )
         else:
             if not cross:
                 s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
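
The `out_nside` reduction works because the maps are in HEALPix NESTED ordering, where the (nside // out_nside)**2 sub-pixels of each coarse pixel are contiguous in memory, so a reshape followed by a mean over that axis is an exact degrade. A minimal NumPy sketch (toy map, not the backend call):

    import numpy as np

    nside, out_nside = 8, 4
    m = np.arange(12 * nside ** 2, dtype=float).reshape(1, -1)   # [Nbatch, Npix], NESTED
    coarse = m.reshape(1, 12 * out_nside ** 2, (nside // out_nside) ** 2).mean(axis=2)
    print(coarse.shape)   # (1, 192) = [Nbatch, 12 * out_nside**2]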
@@ -2720,17 +2762,38 @@ class funct(FOC.FoCUS):
                 )
                 vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
                 s0 = self.backend.bk_concat([s0, l_vs0], 1)
-
         #### COMPUTE S1, S2, S3 and S4
         nside_j3 = nside  # NSIDE start (nside_j3 = nside / 2^j3)
+
+        # to be put back as before
+        M1_dic = {}
+        M2_dic = {}
+
         for j3 in range(Jmax):
+
+            if edge:
+                if self.mask_mask is None:
+                    self.mask_mask = {}
+                if self.use_2D:
+                    if (vmask.shape[1], vmask.shape[2]) not in self.mask_mask:
+                        mask_mask = np.zeros([1, vmask.shape[1], vmask.shape[2]])
+                        mask_mask[0,
+                                  self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
+                                  self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] = 1.0
+                        self.mask_mask[(vmask.shape[1], vmask.shape[2])] = self.backend.bk_cast(mask_mask)
+                    vmask = vmask * self.mask_mask[(vmask.shape[1], vmask.shape[2])]
+                    # print(self.KERNELSZ//2,vmask,mask_mask)
+
+                if self.use_1D:
+                    if (vmask.shape[1]) not in self.mask_mask:
+                        mask_mask = np.zeros([1, vmask.shape[1]])
+                        mask_mask[0,
+                                  self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] = 1.0
+                        self.mask_mask[(vmask.shape[1])] = self.backend.bk_cast(mask_mask)
+                    vmask = vmask * self.mask_mask[(vmask.shape[1])]
+
             if return_data:
-                if S3 is None:
-                    S3 = {}
                 S3[j3] = None
-
-                if S3P is None:
-                    S3P = {}
                 S3P[j3] = None
 
                 if S4 is None:
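
With `edge=True`, the mask is multiplied by a cached border mask that zeroes every pixel whose convolution kernel would overlap the map edge, so 2D and 1D transforms ignore pixels with incomplete kernel support. A standalone sketch of the border mask built above (same slicing, toy sizes):

    import numpy as np

    KERNELSZ, ny, nx = 5, 16, 16
    mask_mask = np.zeros([1, ny, nx])
    # keep only pixels whose KERNELSZ x KERNELSZ neighbourhood fits in the map;
    # -KERNELSZ // 2 + 1 equals -(KERNELSZ // 2) for odd kernel sizes
    mask_mask[0,
              KERNELSZ // 2 : -KERNELSZ // 2 + 1,
              KERNELSZ // 2 : -KERNELSZ // 2 + 1] = 1.0
    print(mask_mask[0].sum())   # (16 - 4) ** 2 = 144 interior pixels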
@@ -2740,12 +2803,12 @@ class funct(FOC.FoCUS):
             ####### S1 and S2
             ### Make the convolution I1 * Psi_j3
             conv1 = self.convol(I1, axis=1)  # [Nbatch, Npix_j3, Norient3]
-
             if cmat is not None:
-                tmp2 = self.backend.bk_repeat(conv1, 4, axis=-1)
+                tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
                 conv1 = self.backend.bk_reduce_sum(
                     self.backend.bk_reshape(
-                        cmat[j3] * tmp2, [tmp2.shape[0], cmat[j3].shape[0], 4, 4]
+                        cmat[j3] * tmp2,
+                        [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
                     ),
                     2,
                 )
@@ -2781,33 +2844,31 @@ class funct(FOC.FoCUS):
                 if return_data:
                     if S2 is None:
                         S2 = {}
-                    if out_nside is not None and out_nside<nside_j3:
+                    if out_nside is not None and out_nside < nside_j3:
                         s2 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(s2,[s2.shape[0],
-                                                        12*out_nside**2,
-                                                        (nside_j3//out_nside)**2,
-                                                        s2.shape[2]]),2)
+                            self.backend.bk_reshape(
+                                s2,
+                                [
+                                    s2.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s2.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
                     S2[j3] = s2
                 else:
                     if norm == "auto":  # Normalize S2
                         s2 /= P1_dic[j3]
-                    if S2 is None:
-                        S2 = self.backend.bk_expand_dims(
-                            s2, off_S2
+
+                    S2.append(
+                        self.backend.bk_expand_dims(s2, off_S2)
+                    )  # Add a dimension for NS2
+                    if calc_var:
+                        VS2.append(
+                            self.backend.bk_expand_dims(vs2, off_S2)
                         )  # Add a dimension for NS2
-                        if calc_var:
-                            VS2 = self.backend.bk_expand_dims(
-                                vs2, off_S2
-                            )  # Add a dimension for NS2
-                    else:
-                        S2 = self.backend.bk_concat(
-                            [S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
-                        )
-                        if calc_var:
-                            VS2 = self.backend.bk_concat(
-                                [VS2, self.backend.bk_expand_dims(vs2, off_S2)],
-                                axis=2,
-                            )
 
                 #### S1_auto computation
                 ### Image 1 : S1 = < M1 >_pix
@@ -2825,45 +2886,47 @@ class funct(FOC.FoCUS):
                     )  # [Nbatch, Nmask, Norient3]
 
                 if return_data:
-                    if S1 is None:
-                        S1 = {}
-                    if out_nside is not None and out_nside<nside_j3:
+                    if out_nside is not None and out_nside < nside_j3:
                         s1 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(s1,[s1.shape[0],
-                                                        12*out_nside**2,
-                                                        (nside_j3//out_nside)**2,
-                                                        s1.shape[2]]),2)
+                            self.backend.bk_reshape(
+                                s1,
+                                [
+                                    s1.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s1.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
                     S1[j3] = s1
                 else:
                     ### Normalize S1
                     if norm is not None:
                         self.div_norm(s1, (P1_dic[j3]) ** 0.5)
                     ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
-                    if S1 is None:
-                        S1 = self.backend.bk_expand_dims(
-                            s1, off_S2
+                    S1.append(
+                        self.backend.bk_expand_dims(s1, off_S2)
+                    )  # Add a dimension for NS1
+                    if calc_var:
+                        VS1.append(
+                            self.backend.bk_expand_dims(vs1, off_S2)
                         )  # Add a dimension for NS1
-                        if calc_var:
-                            VS1 = self.backend.bk_expand_dims(
-                                vs1, off_S2
-                            )  # Add a dimension for NS1
-                    else:
-                        S1 = self.backend.bk_concat(
-                            [S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
-                        )
-                        if calc_var:
-                            VS1 = self.backend.bk_concat(
-                                [VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
-                            )
 
             else:  # Cross
                 ### Make the convolution I2 * Psi_j3
                 conv2 = self.convol(I2, axis=1)  # [Nbatch, Npix_j3, Norient3]
                 if cmat is not None:
-                    tmp2 = self.backend.bk_repeat(conv2, 4, axis=-1)
+                    tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
                     conv2 = self.backend.bk_reduce_sum(
                         self.backend.bk_reshape(
-                            cmat[j3] * tmp2, [tmp2.shape[0], cmat[j3].shape[0], 4, 4]
+                            cmat[j3] * tmp2,
+                            [
+                                tmp2.shape[0],
+                                cmat[j3].shape[0],
+                                self.NORIENT,
+                                self.NORIENT,
+                            ],
                         ),
                         2,
                     )
@@ -2917,14 +2980,19 @@ class funct(FOC.FoCUS):
                         s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
 
                 if return_data:
-                    if S2 is None:
-                        S2 = {}
-                    if out_nside is not None and out_nside<nside_j3:
+                    if out_nside is not None and out_nside < nside_j3:
                         s2 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(s2,[s2.shape[0],
-                                                        12*out_nside**2,
-                                                        (nside_j3//out_nside)**2,
-                                                        s2.shape[2]]),2)
+                            self.backend.bk_reshape(
+                                s2,
+                                [
+                                    s2.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s2.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
                     S2[j3] = s2
                 else:
                     ### Normalize S2_cross
@@ -2932,26 +3000,15 @@ class funct(FOC.FoCUS):
                         s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
 
                     ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
-                    if not all_cross:
-                        s2 = self.backend.bk_real(s2)
+                    s2 = self.backend.bk_real(s2)
 
-                    if S2 is None:
-                        S2 = self.backend.bk_expand_dims(
-                            s2, off_S2
+                    S2.append(
+                        self.backend.bk_expand_dims(s2, off_S2)
+                    )  # Add a dimension for NS2
+                    if calc_var:
+                        VS2.append(
+                            self.backend.bk_expand_dims(vs2, off_S2)
                         )  # Add a dimension for NS2
-                        if calc_var:
-                            VS2 = self.backend.bk_expand_dims(
-                                vs2, off_S2
-                            )  # Add a dimension for NS2
-                    else:
-                        S2 = self.backend.bk_concat(
-                            [S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
-                        )
-                        if calc_var:
-                            VS2 = self.backend.bk_concat(
-                                [VS2, self.backend.bk_expand_dims(vs2, off_S2)],
-                                axis=2,
-                            )
 
                 #### S1_auto computation
                 ### Image 1 : S1 = < M1 >_pix
@@ -2968,36 +3025,32 @@ class funct(FOC.FoCUS):
                         MX, vmask, axis=1, rank=j3
                     )  # [Nbatch, Nmask, Norient3]
                 if return_data:
-                    if S1 is None:
-                        S1 = {}
-                    if out_nside is not None and out_nside<nside_j3:
+                    if out_nside is not None and out_nside < nside_j3:
                         s1 = self.backend.bk_reduce_mean(
-                            self.backend.bk_reshape(s1,[s1.shape[0],
-                                                        12*out_nside**2,
-                                                        (nside_j3//out_nside)**2,
-                                                        s1.shape[2]]),2)
+                            self.backend.bk_reshape(
+                                s1,
+                                [
+                                    s1.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s1.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
                     S1[j3] = s1
                 else:
                     ### Normalize S1
                     if norm is not None:
                         self.div_norm(s1, (P1_dic[j3]) ** 0.5)
                     ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
-                    if S1 is None:
-                        S1 = self.backend.bk_expand_dims(
-                            s1, off_S2
+                    S1.append(
+                        self.backend.bk_expand_dims(s1, off_S2)
+                    )  # Add a dimension for NS1
+                    if calc_var:
+                        VS1.append(
+                            self.backend.bk_expand_dims(vs1, off_S2)
                         )  # Add a dimension for NS1
-                        if calc_var:
-                            VS1 = self.backend.bk_expand_dims(
-                                vs1, off_S2
-                            )  # Add a dimension for NS1
-                    else:
-                        S1 = self.backend.bk_concat(
-                            [S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
-                        )
-                        if calc_var:
-                            VS1 = self.backend.bk_concat(
-                                [VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
-                            )
 
             # Initialize dictionaries for |I1*Psi_j| * Psi_j3
             M1convPsi_dic = {}
@@ -3006,7 +3059,7 @@ class funct(FOC.FoCUS):
                 M2convPsi_dic = {}
 
             ###### S3
-            nside_j2=nside_j3
+            nside_j2 = nside_j3
             for j2 in range(0, j3 + 1):  # j2 <= j3
                 if return_data:
                     if S4[j3] is None:
@@ -3041,13 +3094,20 @@ class funct(FOC.FoCUS):
                     if return_data:
                         if S3[j3] is None:
                             S3[j3] = {}
-                        if out_nside is not None and out_nside<nside_j2:
+                        if out_nside is not None and out_nside < nside_j2:
                             s3 = self.backend.bk_reduce_mean(
-                                self.backend.bk_reshape(s3,[s3.shape[0],
-                                                            12*out_nside**2,
-                                                            (nside_j2//out_nside)**2,
-                                                            s3.shape[2],
-                                                            s3.shape[3]]),2)
+                                self.backend.bk_reshape(
+                                    s3,
+                                    [
+                                        s3.shape[0],
+                                        12 * out_nside**2,
+                                        (nside_j2 // out_nside) ** 2,
+                                        s3.shape[2],
+                                        s3.shape[3],
+                                    ],
+                                ),
+                                2,
+                            )
                         S3[j3][j2] = s3
                     else:
                         ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3062,23 +3122,18 @@ class funct(FOC.FoCUS):
3062
3122
  ) # [Nbatch, Nmask, Norient3, Norient2]
3063
3123
 
3064
3124
  ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3065
- if S3 is None:
3066
- S3 = self.backend.bk_expand_dims(
3067
- s3, off_S3
3068
- ) # Add a dimension for NS3
3069
- if calc_var:
3070
- VS3 = self.backend.bk_expand_dims(
3071
- vs3, off_S3
3072
- ) # Add a dimension for NS3
3073
- else:
3074
- S3 = self.backend.bk_concat(
3075
- [S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
3125
+
3126
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3127
+ # s3.shape[2]*s3.shape[3]]))
3128
+ S3.append(
3129
+ self.backend.bk_expand_dims(s3, off_S3)
3130
+ ) # Add a dimension for NS3
3131
+ if calc_var:
3132
+ VS3.append(
3133
+ self.backend.bk_expand_dims(vs3, off_S3)
3076
3134
  ) # Add a dimension for NS3
3077
- if calc_var:
3078
- VS3 = self.backend.bk_concat(
3079
- [VS3, self.backend.bk_expand_dims(vs3, off_S3)],
3080
- axis=2,
3081
- ) # Add a dimension for NS3
3135
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3136
+ # s3.shape[2]*s3.shape[3]]))
3082
3137
 
3083
3138
  ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
3084
3139
  ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
@@ -3130,19 +3185,33 @@ class funct(FOC.FoCUS):
3130
3185
  if S3[j3] is None:
3131
3186
  S3[j3] = {}
3132
3187
  S3P[j3] = {}
3133
- if out_nside is not None and out_nside<nside_j2:
3188
+ if out_nside is not None and out_nside < nside_j2:
3134
3189
  s3 = self.backend.bk_reduce_mean(
3135
- self.backend.bk_reshape(s3,[s3.shape[0],
3136
- 12*out_nside**2,
3137
- (nside_j2//out_nside)**2,
3138
- s3.shape[2],
3139
- s3.shape[3]]),2)
3190
+ self.backend.bk_reshape(
3191
+ s3,
3192
+ [
3193
+ s3.shape[0],
3194
+ 12 * out_nside**2,
3195
+ (nside_j2 // out_nside) ** 2,
3196
+ s3.shape[2],
3197
+ s3.shape[3],
3198
+ ],
3199
+ ),
3200
+ 2,
3201
+ )
3140
3202
  s3p = self.backend.bk_reduce_mean(
3141
- self.backend.bk_reshape(s3p,[s3.shape[0],
3142
- 12*out_nside**2,
3143
- (nside_j2//out_nside)**2,
3144
- s3.shape[2],
3145
- s3.shape[3]]),2)
3203
+ self.backend.bk_reshape(
3204
+ s3p,
3205
+ [
3206
+ s3.shape[0],
3207
+ 12 * out_nside**2,
3208
+ (nside_j2 // out_nside) ** 2,
3209
+ s3.shape[2],
3210
+ s3.shape[3],
3211
+ ],
3212
+ ),
3213
+ 2,
3214
+ )
3146
3215
  S3[j3][j2] = s3
3147
3216
  S3P[j3][j2] = s3p
3148
3217
  else:
@@ -3166,43 +3235,34 @@ class funct(FOC.FoCUS):
3166
3235
  ) # [Nbatch, Nmask, Norient3, Norient2]
3167
3236
 
3168
3237
  ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3169
- if S3 is None:
3170
- S3 = self.backend.bk_expand_dims(
3171
- s3, off_S3
3172
- ) # Add a dimension for NS3
3173
- if calc_var:
3174
- VS3 = self.backend.bk_expand_dims(
3175
- vs3, off_S3
3176
- ) # Add a dimension for NS3
3177
- else:
3178
- S3 = self.backend.bk_concat(
3179
- [S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
3180
- ) # Add a dimension for NS3
3181
- if calc_var:
3182
- VS3 = self.backend.bk_concat(
3183
- [VS3, self.backend.bk_expand_dims(vs3, off_S3)],
3184
- axis=2,
3185
- ) # Add a dimension for NS3
3186
- if S3P is None:
3187
- S3P = self.backend.bk_expand_dims(
3188
- s3p, off_S3
3238
+
3239
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3240
+ # s3.shape[2]*s3.shape[3]]))
3241
+ S3.append(
3242
+ self.backend.bk_expand_dims(s3, off_S3)
3243
+ ) # Add a dimension for NS3
3244
+ if calc_var:
3245
+ VS3.append(
3246
+ self.backend.bk_expand_dims(vs3, off_S3)
3189
3247
  ) # Add a dimension for NS3
3190
- if calc_var:
3191
- VS3P = self.backend.bk_expand_dims(
3192
- vs3p, off_S3
3193
- ) # Add a dimension for NS3
3194
- else:
3195
- S3P = self.backend.bk_concat(
3196
- [S3P, self.backend.bk_expand_dims(s3p, off_S3)], axis=2
3248
+
3249
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3250
+ # s3.shape[2]*s3.shape[3]]))
3251
+
3252
+ # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
3253
+ # s3.shape[2]*s3.shape[3]]))
3254
+ S3P.append(
3255
+ self.backend.bk_expand_dims(s3p, off_S3)
3256
+ ) # Add a dimension for NS3
3257
+ if calc_var:
3258
+ VS3P.append(
3259
+ self.backend.bk_expand_dims(vs3p, off_S3)
3197
3260
  ) # Add a dimension for NS3
3198
- if calc_var:
3199
- VS3P = self.backend.bk_concat(
3200
- [VS3P, self.backend.bk_expand_dims(vs3p, off_S3)],
3201
- axis=2,
3202
- ) # Add a dimension for NS3
3261
+ # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
3262
+ # s3.shape[2]*s3.shape[3]]))
3203
3263
 
3204
3264
  ##### S4
3205
- nside_j1=nside_j2
3265
+ nside_j1 = nside_j2
3206
3266
  for j1 in range(0, j2 + 1): # j1 <= j2
3207
3267
  ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
3208
3268
  if not cross:
@@ -3228,14 +3288,21 @@ class funct(FOC.FoCUS):
3228
3288
  if return_data:
3229
3289
  if S4[j3][j2] is None:
3230
3290
  S4[j3][j2] = {}
3231
- if out_nside is not None and out_nside<nside_j1:
3291
+ if out_nside is not None and out_nside < nside_j1:
3232
3292
  s4 = self.backend.bk_reduce_mean(
3233
- self.backend.bk_reshape(s4,[s4.shape[0],
3234
- 12*out_nside**2,
3235
- (nside_j1//out_nside)**2,
3236
- s4.shape[2],
3237
- s4.shape[3],
3238
- s4.shape[4]]),2)
3293
+ self.backend.bk_reshape(
3294
+ s4,
3295
+ [
3296
+ s4.shape[0],
3297
+ 12 * out_nside**2,
3298
+ (nside_j1 // out_nside) ** 2,
3299
+ s4.shape[2],
3300
+ s4.shape[3],
3301
+ s4.shape[4],
3302
+ ],
3303
+ ),
3304
+ 2,
3305
+ )
3239
3306
  S4[j3][j2][j1] = s4
3240
3307
  else:
3241
3308
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3259,27 +3326,18 @@ class funct(FOC.FoCUS):
3259
3326
  ** 0.5,
3260
3327
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3261
3328
  ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3262
- if S4 is None:
3263
- S4 = self.backend.bk_expand_dims(
3264
- s4, off_S4
3265
- ) # Add a dimension for NS4
3266
- if calc_var:
3267
- VS4 = self.backend.bk_expand_dims(
3268
- vs4, off_S4
3269
- ) # Add a dimension for NS4
3270
- else:
3271
- S4 = self.backend.bk_concat(
3272
- [S4, self.backend.bk_expand_dims(s4, off_S4)],
3273
- axis=2,
3329
+
3330
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3331
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3332
+ S4.append(
3333
+ self.backend.bk_expand_dims(s4, off_S4)
3334
+ ) # Add a dimension for NS4
3335
+ if calc_var:
3336
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3337
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3338
+ VS4.append(
3339
+ self.backend.bk_expand_dims(vs4, off_S4)
3274
3340
  ) # Add a dimension for NS4
3275
- if calc_var:
3276
- VS4 = self.backend.bk_concat(
3277
- [
3278
- VS4,
3279
- self.backend.bk_expand_dims(vs4, off_S4),
3280
- ],
3281
- axis=2,
3282
- ) # Add a dimension for NS4
3283
3341
 
3284
3342
  ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
3285
3343
  else:
@@ -3305,14 +3363,21 @@ class funct(FOC.FoCUS):
3305
3363
  if return_data:
3306
3364
  if S4[j3][j2] is None:
3307
3365
  S4[j3][j2] = {}
3308
- if out_nside is not None and out_nside<nside_j1:
3366
+ if out_nside is not None and out_nside < nside_j1:
3309
3367
  s4 = self.backend.bk_reduce_mean(
3310
- self.backend.bk_reshape(s4,[s4.shape[0],
3311
- 12*out_nside**2,
3312
- (nside_j1//out_nside)**2,
3313
- s4.shape[2],
3314
- s4.shape[3],
3315
- s4.shape[4]]),2)
3368
+ self.backend.bk_reshape(
3369
+ s4,
3370
+ [
3371
+ s4.shape[0],
3372
+ 12 * out_nside**2,
3373
+ (nside_j1 // out_nside) ** 2,
3374
+ s4.shape[2],
3375
+ s4.shape[3],
3376
+ s4.shape[4],
3377
+ ],
3378
+ ),
3379
+ 2,
3380
+ )
3316
3381
  S4[j3][j2][j1] = s4
3317
3382
  else:
3318
3383
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3336,41 +3401,33 @@ class funct(FOC.FoCUS):
3336
3401
  ** 0.5,
3337
3402
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3338
3403
  ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3339
- if S4 is None:
3340
- S4 = self.backend.bk_expand_dims(
3341
- s4, off_S4
3342
- ) # Add a dimension for NS4
3343
- if calc_var:
3344
- VS4 = self.backend.bk_expand_dims(
3345
- vs4, off_S4
3346
- ) # Add a dimension for NS4
3347
- else:
3348
- S4 = self.backend.bk_concat(
3349
- [S4, self.backend.bk_expand_dims(s4, off_S4)],
3350
- axis=2,
3404
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3405
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3406
+ S4.append(
3407
+ self.backend.bk_expand_dims(s4, off_S4)
3408
+ ) # Add a dimension for NS4
3409
+ if calc_var:
3410
+
3411
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3412
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3413
+ VS4.append(
3414
+ self.backend.bk_expand_dims(vs4, off_S4)
3351
3415
  ) # Add a dimension for NS4
3352
- if calc_var:
3353
- VS4 = self.backend.bk_concat(
3354
- [
3355
- VS4,
3356
- self.backend.bk_expand_dims(vs4, off_S4),
3357
- ],
3358
- axis=2,
3359
- ) # Add a dimension for NS4
3360
- nside_j1=nside_j1 // 2
3361
- nside_j2=nside_j2 // 2
3362
-
3416
+
3417
+ nside_j1 = nside_j1 // 2
3418
+ nside_j2 = nside_j2 // 2
3419
+
3363
3420
  ###### Reshape for next iteration on j3
3364
3421
  ### Image I1,
3365
3422
  # downscale the I1 [Nbatch, Npix_j3]
3366
3423
  if j3 != Jmax - 1:
3367
- I1_smooth = self.smooth(I1, axis=1)
3368
- I1 = self.ud_grade_2(I1_smooth, axis=1)
3424
+ I1 = self.smooth(I1, axis=1)
3425
+ I1 = self.ud_grade_2(I1, axis=1)
3369
3426
 
3370
3427
  ### Image I2
3371
3428
  if cross:
3372
- I2_smooth = self.smooth(I2, axis=1)
3373
- I2 = self.ud_grade_2(I2_smooth, axis=1)
3429
+ I2 = self.smooth(I2, axis=1)
3430
+ I2 = self.ud_grade_2(I2, axis=1)
3374
3431
 
3375
3432
  ### Modules
3376
3433
  for j2 in range(0, j3 + 1): # j2 =< j3
@@ -3388,8 +3445,9 @@ class funct(FOC.FoCUS):
3388
3445
  M2_dic[j2], axis=1
3389
3446
  ) # [Nbatch, Npix_j3, Norient3]
3390
3447
  M2_dic[j2] = self.ud_grade_2(
3391
- M2_smooth, axis=1
3448
+ M2, axis=1
3392
3449
  ) # [Nbatch, Npix_j3, Norient3]
3450
+
3393
3451
  ### Mask
3394
3452
  vmask = self.ud_grade_2(vmask, axis=1)
3395
3453
 
@@ -3404,7 +3462,33 @@ class funct(FOC.FoCUS):
3404
3462
  self.P1_dic = P1_dic
3405
3463
  if cross:
3406
3464
  self.P2_dic = P2_dic
3465
+ """
3466
+ Sout=[s0]+S1+S2+S3+S4
3467
+
3468
+ if cross:
3469
+ Sout=Sout+S3P
3470
+ if calc_var:
3471
+ SVout=[vs0]+VS1+VS2+VS3+VS4
3472
+ if cross:
3473
+ VSout=VSout+VS3P
3474
+ return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
3407
3475
 
3476
+ return self.backend.bk_concat(Sout, 2)
3477
+ """
3478
+ if not return_data:
3479
+ S1 = self.backend.bk_concat(S1, 2)
3480
+ S2 = self.backend.bk_concat(S2, 2)
3481
+ S3 = self.backend.bk_concat(S3, 2)
3482
+ S4 = self.backend.bk_concat(S4, 2)
3483
+ if cross:
3484
+ S3P = self.backend.bk_concat(S3P, 2)
3485
+ if calc_var:
3486
+ VS1 = self.backend.bk_concat(VS1, 2)
3487
+ VS2 = self.backend.bk_concat(VS2, 2)
3488
+ VS3 = self.backend.bk_concat(VS3, 2)
3489
+ VS4 = self.backend.bk_concat(VS4, 2)
3490
+ if cross:
3491
+ VS3P = self.backend.bk_concat(VS3P, 2)
3408
3492
  if calc_var:
3409
3493
  if not cross:
3410
3494
  return scat_cov(
@@ -3455,70 +3539,1146 @@ class funct(FOC.FoCUS):
3455
3539
  use_1D=self.use_1D,
3456
3540
  )
3457
3541
 
3458
- def clean_norm(self):
3459
- self.P1_dic = None
3460
- self.P2_dic = None
3461
- return
3462
-
3463
- def _compute_S3(
3464
- self,
3465
- j2,
3466
- j3,
3467
- conv,
3468
- vmask,
3469
- M_dic,
3470
- MconvPsi_dic,
3471
- calc_var=False,
3472
- return_data=False,
3473
- cmat2=None,
3542
+ def eval_new(
3543
+ self,
3544
+ image1,
3545
+ image2=None,
3546
+ mask=None,
3547
+ norm=None,
3548
+ calc_var=False,
3549
+ cmat=None,
3550
+ cmat2=None,
3551
+ Jmax=None,
3552
+ out_nside=None,
3553
+ edge=True
3474
3554
  ):
3475
3555
  """
3476
- Compute the S3 coefficients (auto or cross)
3477
- S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
3556
+ Calculates the scattering correlations for a batch of images. Mean are done over pixels.
3557
+ mean of modulus:
3558
+ S1 = <|I * Psi_j3|>
3559
+ Normalization : take the log
3560
+ power spectrum:
3561
+ S2 = <|I * Psi_j3|^2>
3562
+ Normalization : take the log
3563
+ orig. x modulus:
3564
+ S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
3565
+ Normalization : divide by (S2_j2 * S2_j3)^0.5
3566
+ modulus x modulus:
3567
+ S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
3568
+ Normalization : divide by (S2_j1 * S2_j2)^0.5
3478
3569
  Parameters
3479
3570
  ----------
3571
+ image1: tensor
3572
+ Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
3573
+ image2: tensor
3574
+ Second image. If not None, we compute cross-scattering covariance coefficients.
3575
+ mask:
3576
+ norm: None or str
3577
+ If None no normalization is applied, if 'auto' normalize by the reference S2,
3578
+ if 'self' normalize by the current S2.
3480
3579
  Returns
3481
3580
  -------
3482
- cs3, ss3: real and imag parts of S3 coeff
3581
+ S1, S2, S3, S4 normalized
3483
3582
  """
3484
- ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3485
- # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
3486
- MconvPsi = self.convol(
3487
- M_dic[j2], axis=1
3488
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
3489
- if cmat2 is not None:
3490
- tmp2 = self.backend.bk_repeat(MconvPsi, 4, axis=-1)
3491
- MconvPsi = self.backend.bk_reduce_sum(
3492
- self.backend.bk_reshape(
3493
- cmat2[j3][j2] * tmp2, [tmp2.shape[0], cmat2[j3].shape[1], 4, 4, 4]
3494
- ),
3495
- 3,
3583
+ return_data = self.return_data
3584
+ NORIENT=self.NORIENT
3585
+ # Check input consistency
3586
+ if image2 is not None:
3587
+ if list(image1.shape) != list(image2.shape):
3588
+ print(
3589
+ "The two input image should have the same size to eval Scattering Covariance"
3590
+ )
3591
+ return None
3592
+ if mask is not None:
3593
+ if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
3594
+ print(
3595
+ "The LAST COLUMN of the mask should have the same size ",
3596
+ mask.shape,
3597
+ "than the input image ",
3598
+ image1.shape,
3599
+ "to eval Scattering Covariance",
3600
+ )
3601
+ return None
3602
+ if self.use_2D and len(image1.shape) < 2:
3603
+ print(
3604
+ "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
3496
3605
  )
3606
+ return None
3497
3607
 
3498
- # Store it so we can use it in S4 computation
3499
- MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
3500
-
3501
- ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
3502
- # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
3503
- # cconv, sconv are [Nbatch, Npix_j3, Norient3]
3504
- if self.use_1D:
3505
- s3 = conv * self.backend.bk_conjugate(MconvPsi)
3506
- else:
3507
- s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
3508
- MconvPsi
3509
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
3608
+ ### AUTO OR CROSS
3609
+ cross = False
3610
+ if image2 is not None:
3611
+ cross = True
3510
3612
 
3511
- ### Apply the mask [Nmask, Npix_j3] and sum over pixels
3512
- if return_data:
3513
- return s3
3514
- else:
3515
- if calc_var:
3516
- s3, vs3 = self.masked_mean(
3517
- s3, vmask, axis=1, rank=j2, calc_var=True
3518
- ) # [Nbatch, Nmask, Norient3, Norient2]
3519
- return s3, vs3
3613
+ ### PARAMETERS
3614
+ axis = 1
3615
+ # determine jmax and nside corresponding to the input map
3616
+ im_shape = image1.shape
3617
+ if self.use_2D:
3618
+ if len(image1.shape) == 2:
3619
+ nside = np.min([im_shape[0], im_shape[1]])
3620
+ npix = im_shape[0] * im_shape[1] # Number of pixels
3621
+ x1 = im_shape[0]
3622
+ x2 = im_shape[1]
3520
3623
  else:
3521
- s3 = self.masked_mean(
3624
+ nside = np.min([im_shape[1], im_shape[2]])
3625
+ npix = im_shape[1] * im_shape[2] # Number of pixels
3626
+ x1 = im_shape[1]
3627
+ x2 = im_shape[2]
3628
+ J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
3629
+ elif self.use_1D:
3630
+ if len(image1.shape) == 2:
3631
+ npix = int(im_shape[1]) # Number of pixels
3632
+ else:
3633
+ npix = int(im_shape[0]) # Number of pixels
3634
+
3635
+ nside = int(npix)
3636
+
3637
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
3638
+ else:
3639
+ if len(image1.shape) == 2:
3640
+ npix = int(im_shape[1]) # Number of pixels
3641
+ else:
3642
+ npix = int(im_shape[0]) # Number of pixels
3643
+
3644
+ nside = int(np.sqrt(npix // 12))
3645
+
3646
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
3647
+
3648
+ if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
3649
+ J-=1
3650
+ if Jmax is None:
3651
+ Jmax = J # Number of steps for the loop on scales
3652
+ if Jmax>J:
3653
+ print('==========\n\n')
3654
+ print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
3655
+ print('\n\n==========')
3656
+
3657
+
3658
+ ### LOCAL VARIABLES (IMAGES and MASK)
3659
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
3660
+ I1 = self.backend.bk_cast(
3661
+ self.backend.bk_expand_dims(image1, 0)
3662
+ ) # Local image1 [Nbatch, Npix]
3663
+ if cross:
3664
+ I2 = self.backend.bk_cast(
3665
+ self.backend.bk_expand_dims(image2, 0)
3666
+ ) # Local image2 [Nbatch, Npix]
3667
+ else:
3668
+ I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
3669
+ if cross:
3670
+ I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
3671
+
3672
+ if mask is None:
3673
+ if self.use_2D:
3674
+ vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
3675
+ else:
3676
+ vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
3677
+ else:
3678
+ vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
3679
+
3680
+ if self.KERNELSZ > 3 and not self.use_2D:
3681
+ # if the kernel size is bigger than 3 increase the binning before smoothing
3682
+ if self.use_2D:
3683
+ vmask = self.up_grade(
3684
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
3685
+ )
3686
+ I1 = self.up_grade(
3687
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
3688
+ )
3689
+ if cross:
3690
+ I2 = self.up_grade(
3691
+ I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
3692
+ )
3693
+ elif self.use_1D:
3694
+ vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
3695
+ I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
3696
+ if cross:
3697
+ I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
3698
+ else:
3699
+ I1 = self.up_grade(I1, nside * 2, axis=axis)
3700
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
3701
+ if cross:
3702
+ I2 = self.up_grade(I2, nside * 2, axis=axis)
3703
+
3704
+ if self.KERNELSZ > 5 and not self.use_2D:
3705
+ # if the kernel size is bigger than 3 increase the binning before smoothing
3706
+ if self.use_2D:
3707
+ vmask = self.up_grade(
3708
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
3709
+ )
3710
+ I1 = self.up_grade(
3711
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
3712
+ )
3713
+ if cross:
3714
+ I2 = self.up_grade(
3715
+ I2,
3716
+ I2.shape[axis] * 2,
3717
+ axis=axis,
3718
+ nouty=I2.shape[axis + 1] * 2,
3719
+ )
3720
+ elif self.use_1D:
3721
+ vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
3722
+ I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
3723
+ if cross:
3724
+ I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
3725
+ else:
3726
+ I1 = self.up_grade(I1, nside * 4, axis=axis)
3727
+ vmask = self.up_grade(vmask, nside * 4, axis=1)
3728
+ if cross:
3729
+ I2 = self.up_grade(I2, nside * 4, axis=axis)
3730
+
3731
+ # Normalize the masks because they have different pixel numbers
3732
+ # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
3733
+
3734
+ ### INITIALIZATION
3735
+ # Coefficients
3736
+ if return_data:
3737
+ S1 = {}
3738
+ S2 = {}
3739
+ S3 = {}
3740
+ S3P = {}
3741
+ S4 = {}
3742
+ else:
3743
+ result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3744
+ dtype=self.backend.backend.float32,
3745
+ device=self.backend.torch_device)
3746
+ vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3747
+ dtype=self.backend.backend.float32,
3748
+ device=self.backend.torch_device)
3749
+ S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3750
+ S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3751
+ S3 = []
3752
+ S4 = []
3753
+ S3P = []
3754
+ VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3755
+ VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3756
+ VS3 = []
3757
+ VS3P = []
3758
+ VS4 = []
3759
+
3760
+ off_S2 = -2
3761
+ off_S3 = -3
3762
+ off_S4 = -4
3763
+ if self.use_1D:
3764
+ off_S2 = -1
3765
+ off_S3 = -1
3766
+ off_S4 = -1
3767
+
3768
+ # S2 for normalization
3769
+ cond_init_P1_dic = (norm == "self") or (
3770
+ (norm == "auto") and (self.P1_dic is None)
3771
+ )
3772
+ if norm is None:
3773
+ pass
3774
+ elif cond_init_P1_dic:
3775
+ P1_dic = {}
3776
+ if cross:
3777
+ P2_dic = {}
3778
+ elif (norm == "auto") and (self.P1_dic is not None):
3779
+ P1_dic = self.P1_dic
3780
+ if cross:
3781
+ P2_dic = self.P2_dic
3782
+
3783
+ if return_data:
3784
+ s0 = I1
3785
+ if out_nside is not None:
3786
+ s0 = self.backend.bk_reduce_mean(
3787
+ self.backend.bk_reshape(
3788
+ s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
3789
+ ),
3790
+ 2,
3791
+ )
3792
+ else:
3793
+ if not cross:
3794
+ s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
3795
+ else:
3796
+ s0, l_vs0 = self.masked_mean(
3797
+ self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
3798
+ )
3799
+ #vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
3800
+ #s0 = self.backend.bk_concat([s0, l_vs0], 1)
3801
+ result[:,:,0]=s0
3802
+ result[:,:,1]=l_vs0
3803
+ vresult[:,:,0]=l_vs0
3804
+ vresult[:,:,1]=l_vs0
3805
+ #### COMPUTE S1, S2, S3 and S4
3806
+ nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
3807
+
3808
+ # a remettre comme avant
3809
+ M1_dic={}
3810
+
3811
+ M2_dic={}
3812
+
3813
+ for j3 in range(Jmax):
3814
+
3815
+ if edge:
3816
+ if self.mask_mask is None:
3817
+ self.mask_mask={}
3818
+ if self.use_2D:
3819
+ if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
3820
+ mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
3821
+ mask_mask[0,
3822
+ self.KERNELSZ//2:-self.KERNELSZ//2+1,
3823
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
3824
+ self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
3825
+ vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
3826
+ #print(self.KERNELSZ//2,vmask,mask_mask)
3827
+
3828
+ if self.use_1D:
3829
+ if (vmask.shape[1]) not in self.mask_mask:
3830
+ mask_mask=np.zeros([1,vmask.shape[1]])
3831
+ mask_mask[0,
3832
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
3833
+ self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
3834
+ vmask=vmask*self.mask_mask[(vmask.shape[1])]
3835
+
3836
+ if return_data:
3837
+ S3[j3] = None
3838
+ S3P[j3] = None
3839
+
3840
+ if S4 is None:
3841
+ S4 = {}
3842
+ S4[j3] = None
3843
+
3844
+ ####### S1 and S2
3845
+ ### Make the convolution I1 * Psi_j3
3846
+ conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
3847
+ if cmat is not None:
3848
+ tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
3849
+ conv1 = self.backend.bk_reduce_sum(
3850
+ self.backend.bk_reshape(
3851
+ cmat[j3] * tmp2,
3852
+ [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
3853
+ ),
3854
+ 2,
3855
+ )
3856
+
3857
+ ### Take the module M1 = |I1 * Psi_j3|
3858
+ M1_square = conv1 * self.backend.bk_conjugate(
3859
+ conv1
3860
+ ) # [Nbatch, Npix_j3, Norient3]
3861
+
3862
+ M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
3863
+
3864
+ # Store M1_j3 in a dictionary
3865
+ M1_dic[j3] = M1
3866
+
3867
+ if not cross: # Auto
3868
+ M1_square = self.backend.bk_real(M1_square)
3869
+
3870
+ ### S2_auto = < M1^2 >_pix
3871
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
3872
+ if return_data:
3873
+ s2 = M1_square
3874
+ else:
3875
+ if calc_var:
3876
+ s2, vs2 = self.masked_mean(
3877
+ M1_square, vmask, axis=1, rank=j3, calc_var=True
3878
+ )
3879
+ #s2=self.backend.bk_flatten(self.backend.bk_real(s2))
3880
+ #vs2=self.backend.bk_flatten(vs2)
3881
+ else:
3882
+ s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
3883
+
3884
+ if cond_init_P1_dic:
3885
+ # We fill P1_dic with S2 for normalisation of S3 and S4
3886
+ P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
3887
+
3888
+ # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
3889
+ if return_data:
3890
+ if S2 is None:
3891
+ S2 = {}
3892
+ if out_nside is not None and out_nside < nside_j3:
3893
+ s2 = self.backend.bk_reduce_mean(
3894
+ self.backend.bk_reshape(
3895
+ s2,
3896
+ [
3897
+ s2.shape[0],
3898
+ 12 * out_nside**2,
3899
+ (nside_j3 // out_nside) ** 2,
3900
+ s2.shape[2],
3901
+ ],
3902
+ ),
3903
+ 2,
3904
+ )
3905
+ S2[j3] = s2
3906
+ else:
3907
+ if norm == "auto": # Normalize S2
3908
+ s2 /= P1_dic[j3]
3909
+ """
3910
+ S2.append(
3911
+ self.backend.bk_expand_dims(s2, off_S2)
3912
+ ) # Add a dimension for NS2
3913
+ if calc_var:
3914
+ VS2.append(
3915
+ self.backend.bk_expand_dims(vs2, off_S2)
3916
+ ) # Add a dimension for NS2
3917
+ """
3918
+ #print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
3919
+ result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
3920
+ if calc_var:
3921
+ vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
3922
+ #### S1_auto computation
3923
+ ### Image 1 : S1 = < M1 >_pix
3924
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
3925
+ if return_data:
3926
+ s1 = M1
3927
+ else:
3928
+ if calc_var:
3929
+ s1, vs1 = self.masked_mean(
3930
+ M1, vmask, axis=1, rank=j3, calc_var=True
3931
+ ) # [Nbatch, Nmask, Norient3]
3932
+ #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
3933
+ #vs1=self.backend.bk_flatten(vs1)
3934
+ else:
3935
+ s1 = self.masked_mean(
3936
+ M1, vmask, axis=1, rank=j3
3937
+ ) # [Nbatch, Nmask, Norient3]
3938
+ #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
3939
+
3940
+ if return_data:
3941
+ if out_nside is not None and out_nside < nside_j3:
3942
+ s1 = self.backend.bk_reduce_mean(
3943
+ self.backend.bk_reshape(
3944
+ s1,
3945
+ [
3946
+ s1.shape[0],
3947
+ 12 * out_nside**2,
3948
+ (nside_j3 // out_nside) ** 2,
3949
+ s1.shape[2],
3950
+ ],
3951
+ ),
3952
+ 2,
3953
+ )
3954
+ S1[j3] = s1
3955
+ else:
3956
+ ### Normalize S1
3957
+ if norm is not None:
3958
+ self.div_norm(s1, (P1_dic[j3]) ** 0.5)
3959
+ result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
3960
+ if calc_var:
3961
+ vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
3962
+ """
3963
+ ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
3964
+ S1.append(
3965
+ self.backend.bk_expand_dims(s1, off_S2)
3966
+ ) # Add a dimension for NS1
3967
+ if calc_var:
3968
+ VS1.append(
3969
+ self.backend.bk_expand_dims(vs1, off_S2)
3970
+ ) # Add a dimension for NS1
3971
+ """
3972
+
3973
+ else: # Cross
3974
+ ### Make the convolution I2 * Psi_j3
3975
+ conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
3976
+ if cmat is not None:
3977
+ tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
3978
+ conv2 = self.backend.bk_reduce_sum(
3979
+ self.backend.bk_reshape(
3980
+ cmat[j3] * tmp2,
3981
+ [
3982
+ tmp2.shape[0],
3983
+ cmat[j3].shape[0],
3984
+ self.NORIENT,
3985
+ self.NORIENT,
3986
+ ],
3987
+ ),
3988
+ 2,
3989
+ )
3990
+ ### Take the module M2 = |I2 * Psi_j3|
3991
+ M2_square = conv2 * self.backend.bk_conjugate(
3992
+ conv2
3993
+ ) # [Nbatch, Npix_j3, Norient3]
3994
+ M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
3995
+ # Store M2_j3 in a dictionary
3996
+ M2_dic[j3] = M2
3997
+
3998
+ ### S2_auto = < M2^2 >_pix
3999
+ # Not returned, only for normalization
4000
+ if cond_init_P1_dic:
4001
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
4002
+ if return_data:
4003
+ p1 = M1_square
4004
+ p2 = M2_square
4005
+ else:
4006
+ if calc_var:
4007
+ p1, vp1 = self.masked_mean(
4008
+ M1_square, vmask, axis=1, rank=j3, calc_var=True
4009
+ ) # [Nbatch, Nmask, Norient3]
4010
+ p2, vp2 = self.masked_mean(
4011
+ M2_square, vmask, axis=1, rank=j3, calc_var=True
4012
+ ) # [Nbatch, Nmask, Norient3]
4013
+ else:
4014
+ p1 = self.masked_mean(
4015
+ M1_square, vmask, axis=1, rank=j3
4016
+ ) # [Nbatch, Nmask, Norient3]
4017
+ p2 = self.masked_mean(
4018
+ M2_square, vmask, axis=1, rank=j3
4019
+ ) # [Nbatch, Nmask, Norient3]
4020
+ # We fill P1_dic with S2 for normalisation of S3 and S4
4021
+ P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
4022
+ P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
4023
+
4024
+ ### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
4025
+ # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
4026
+ s2 = conv1 * self.backend.bk_conjugate(conv2)
4027
+ MX = self.backend.bk_L1(s2)
4028
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
4029
+ if return_data:
4030
+ s2 = s2
4031
+ else:
4032
+ if calc_var:
4033
+ s2, vs2 = self.masked_mean(
4034
+ s2, vmask, axis=1, rank=j3, calc_var=True
4035
+ )
4036
+ else:
4037
+ s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
4038
+
4039
+ if return_data:
4040
+ if out_nside is not None and out_nside < nside_j3:
4041
+ s2 = self.backend.bk_reduce_mean(
4042
+ self.backend.bk_reshape(
4043
+ s2,
4044
+ [
4045
+ s2.shape[0],
4046
+ 12 * out_nside**2,
4047
+ (nside_j3 // out_nside) ** 2,
4048
+ s2.shape[2],
4049
+ ],
4050
+ ),
4051
+ 2,
4052
+ )
4053
+ S2[j3] = s2
4054
+ else:
4055
+ ### Normalize S2_cross
4056
+ if norm == "auto":
4057
+ s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
4058
+
4059
+ ### Store the real part of S2_cross [Nbatch, Nmask, NS2, Norient3]
4060
+ s2 = self.backend.bk_real(s2)
4061
+
4062
+ S2.append(
4063
+ self.backend.bk_expand_dims(s2, off_S2)
4064
+ ) # Add a dimension for NS2
4065
+ if calc_var:
4066
+ VS2.append(
4067
+ self.backend.bk_expand_dims(vs2, off_S2)
4068
+ ) # Add a dimension for NS2
4069
+
4070
+ #### S1_cross computation
4071
+ ### S1 = < L1[(I1 * Psi_j3)(I2 * Psi_j3)^*] >_pix
4072
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
4073
+ if return_data:
4074
+ s1 = MX
4075
+ else:
4076
+ if calc_var:
4077
+ s1, vs1 = self.masked_mean(
4078
+ MX, vmask, axis=1, rank=j3, calc_var=True
4079
+ ) # [Nbatch, Nmask, Norient3]
4080
+ else:
4081
+ s1 = self.masked_mean(
4082
+ MX, vmask, axis=1, rank=j3
4083
+ ) # [Nbatch, Nmask, Norient3]
4084
+ if return_data:
4085
+ if out_nside is not None and out_nside < nside_j3:
4086
+ s1 = self.backend.bk_reduce_mean(
4087
+ self.backend.bk_reshape(
4088
+ s1,
4089
+ [
4090
+ s1.shape[0],
4091
+ 12 * out_nside**2,
4092
+ (nside_j3 // out_nside) ** 2,
4093
+ s1.shape[2],
4094
+ ],
4095
+ ),
4096
+ 2,
4097
+ )
4098
+ S1[j3] = s1
4099
+ else:
4100
+ ### Normalize S1
4101
+ if norm is not None:
4102
+ self.div_norm(s1, (P1_dic[j3]) ** 0.5)
4103
+ ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
4104
+ S1.append(
4105
+ self.backend.bk_expand_dims(s1, off_S2)
4106
+ ) # Add a dimension for NS1
4107
+ if calc_var:
4108
+ VS1.append(
4109
+ self.backend.bk_expand_dims(vs1, off_S2)
4110
+ ) # Add a dimension for NS1
4111
+
4112
+ # Initialize dictionaries for |I1*Psi_j| * Psi_j3
4113
+ M1convPsi_dic = {}
4114
+ if cross:
4115
+ # Initialize dictionaries for |I2*Psi_j| * Psi_j3
4116
+ M2convPsi_dic = {}
4117
+
4118
+ ###### S3
4119
+ nside_j2 = nside_j3
4120
+ for j2 in range(0, -1):  # empty loop: S3/S4 disabled in this path (original bound: j3 + 1, i.e. j2 <= j3)
4121
+ if return_data:
4122
+ if S4[j3] is None:
4123
+ S4[j3] = {}
4124
+ S4[j3][j2] = None
4125
+
4126
+ ### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
4127
+ if not cross:
4128
+ if calc_var:
4129
+ s3, vs3 = self._compute_S3(
4130
+ j2,
4131
+ j3,
4132
+ conv1,
4133
+ vmask,
4134
+ M1_dic,
4135
+ M1convPsi_dic,
4136
+ calc_var=True,
4137
+ cmat2=cmat2,
4138
+ ) # [Nbatch, Nmask, Norient3, Norient2]
4139
+ else:
4140
+ s3 = self._compute_S3(
4141
+ j2,
4142
+ j3,
4143
+ conv1,
4144
+ vmask,
4145
+ M1_dic,
4146
+ M1convPsi_dic,
4147
+ return_data=return_data,
4148
+ cmat2=cmat2,
4149
+ ) # [Nbatch, Nmask, Norient3, Norient2]
4150
+
4151
+ if return_data:
4152
+ if S3[j3] is None:
4153
+ S3[j3] = {}
4154
+ if out_nside is not None and out_nside < nside_j2:
4155
+ s3 = self.backend.bk_reduce_mean(
4156
+ self.backend.bk_reshape(
4157
+ s3,
4158
+ [
4159
+ s3.shape[0],
4160
+ 12 * out_nside**2,
4161
+ (nside_j2 // out_nside) ** 2,
4162
+ s3.shape[2],
4163
+ s3.shape[3],
4164
+ ],
4165
+ ),
4166
+ 2,
4167
+ )
4168
+ S3[j3][j2] = s3
4169
+ else:
4170
+ ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
4171
+ if norm is not None:
4172
+ self.div_norm(
4173
+ s3,
4174
+ (
4175
+ self.backend.bk_expand_dims(P1_dic[j2], off_S2)
4176
+ * self.backend.bk_expand_dims(P1_dic[j3], -1)
4177
+ )
4178
+ ** 0.5,
4179
+ ) # [Nbatch, Nmask, Norient3, Norient2]
4180
+
4181
+ ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
4182
+
4183
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
4184
+ # s3.shape[2]*s3.shape[3]]))
4185
+ S3.append(
4186
+ self.backend.bk_expand_dims(s3, off_S3)
4187
+ ) # Add a dimension for NS3
4188
+ if calc_var:
4189
+ VS3.append(
4190
+ self.backend.bk_expand_dims(vs3, off_S3)
4191
+ ) # Add a dimension for NS3
4192
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
4193
+ # s3.shape[2]*s3.shape[3]]))
4194
+
4195
+ ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
4196
+ ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
4197
+ else:
4198
+ if calc_var:
4199
+ s3, vs3 = self._compute_S3(
4200
+ j2,
4201
+ j3,
4202
+ conv1,
4203
+ vmask,
4204
+ M2_dic,
4205
+ M2convPsi_dic,
4206
+ calc_var=True,
4207
+ cmat2=cmat2,
4208
+ )
4209
+ s3p, vs3p = self._compute_S3(
4210
+ j2,
4211
+ j3,
4212
+ conv2,
4213
+ vmask,
4214
+ M1_dic,
4215
+ M1convPsi_dic,
4216
+ calc_var=True,
4217
+ cmat2=cmat2,
4218
+ )
4219
+ else:
4220
+ s3 = self._compute_S3(
4221
+ j2,
4222
+ j3,
4223
+ conv1,
4224
+ vmask,
4225
+ M2_dic,
4226
+ M2convPsi_dic,
4227
+ return_data=return_data,
4228
+ cmat2=cmat2,
4229
+ )
4230
+ s3p = self._compute_S3(
4231
+ j2,
4232
+ j3,
4233
+ conv2,
4234
+ vmask,
4235
+ M1_dic,
4236
+ M1convPsi_dic,
4237
+ return_data=return_data,
4238
+ cmat2=cmat2,
4239
+ )
4240
+
4241
+ if return_data:
4242
+ if S3[j3] is None:
4243
+ S3[j3] = {}
4244
+ S3P[j3] = {}
4245
+ if out_nside is not None and out_nside < nside_j2:
4246
+ s3 = self.backend.bk_reduce_mean(
4247
+ self.backend.bk_reshape(
4248
+ s3,
4249
+ [
4250
+ s3.shape[0],
4251
+ 12 * out_nside**2,
4252
+ (nside_j2 // out_nside) ** 2,
4253
+ s3.shape[2],
4254
+ s3.shape[3],
4255
+ ],
4256
+ ),
4257
+ 2,
4258
+ )
4259
+ s3p = self.backend.bk_reduce_mean(
4260
+ self.backend.bk_reshape(
4261
+ s3p,
4262
+ [
4263
+ s3.shape[0],
4264
+ 12 * out_nside**2,
4265
+ (nside_j2 // out_nside) ** 2,
4266
+ s3.shape[2],
4267
+ s3.shape[3],
4268
+ ],
4269
+ ),
4270
+ 2,
4271
+ )
4272
+ S3[j3][j2] = s3
4273
+ S3P[j3][j2] = s3p
4274
+ else:
4275
+ ### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
4276
+ if norm is not None:
4277
+ self.div_norm(
4278
+ s3,
4279
+ (
4280
+ self.backend.bk_expand_dims(P2_dic[j2], off_S2)
4281
+ * self.backend.bk_expand_dims(P1_dic[j3], -1)
4282
+ )
4283
+ ** 0.5,
4284
+ ) # [Nbatch, Nmask, Norient3, Norient2]
4285
+ self.div_norm(
4286
+ s3p,
4287
+ (
4288
+ self.backend.bk_expand_dims(P1_dic[j2], off_S2)
4289
+ * self.backend.bk_expand_dims(P2_dic[j3], -1)
4290
+ )
4291
+ ** 0.5,
4292
+ ) # [Nbatch, Nmask, Norient3, Norient2]
4293
+
4294
+ ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
4295
+
4296
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
4297
+ # s3.shape[2]*s3.shape[3]]))
4298
+ S3.append(
4299
+ self.backend.bk_expand_dims(s3, off_S3)
4300
+ ) # Add a dimension for NS3
4301
+ if calc_var:
4302
+ VS3.append(
4303
+ self.backend.bk_expand_dims(vs3, off_S3)
4304
+ ) # Add a dimension for NS3
4305
+
4306
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
4307
+ # s3.shape[2]*s3.shape[3]]))
4308
+
4309
+ # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
4310
+ # s3.shape[2]*s3.shape[3]]))
4311
+ S3P.append(
4312
+ self.backend.bk_expand_dims(s3p, off_S3)
4313
+ ) # Add a dimension for NS3
4314
+ if calc_var:
4315
+ VS3P.append(
4316
+ self.backend.bk_expand_dims(vs3p, off_S3)
4317
+ ) # Add a dimension for NS3
4318
+ # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
4319
+ # s3.shape[2]*s3.shape[3]]))
4320
+
4321
+ ##### S4
4322
+ nside_j1 = nside_j2
4323
+ for j1 in range(0, j2 + 1): # j1 <= j2
4324
+ ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
4325
+ if not cross:
4326
+ if calc_var:
4327
+ s4, vs4 = self._compute_S4(
4328
+ j1,
4329
+ j2,
4330
+ vmask,
4331
+ M1convPsi_dic,
4332
+ M2convPsi_dic=None,
4333
+ calc_var=True,
4334
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4335
+ else:
4336
+ s4 = self._compute_S4(
4337
+ j1,
4338
+ j2,
4339
+ vmask,
4340
+ M1convPsi_dic,
4341
+ M2convPsi_dic=None,
4342
+ return_data=return_data,
4343
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4344
+
4345
+ if return_data:
4346
+ if S4[j3][j2] is None:
4347
+ S4[j3][j2] = {}
4348
+ if out_nside is not None and out_nside < nside_j1:
4349
+ s4 = self.backend.bk_reduce_mean(
4350
+ self.backend.bk_reshape(
4351
+ s4,
4352
+ [
4353
+ s4.shape[0],
4354
+ 12 * out_nside**2,
4355
+ (nside_j1 // out_nside) ** 2,
4356
+ s4.shape[2],
4357
+ s4.shape[3],
4358
+ s4.shape[4],
4359
+ ],
4360
+ ),
4361
+ 2,
4362
+ )
4363
+ S4[j3][j2][j1] = s4
4364
+ else:
4365
+ ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
4366
+ if norm is not None:
4367
+ self.div_norm(
4368
+ s4,
4369
+ (
4370
+ self.backend.bk_expand_dims(
4371
+ self.backend.bk_expand_dims(
4372
+ P1_dic[j1], off_S2
4373
+ ),
4374
+ off_S2,
4375
+ )
4376
+ * self.backend.bk_expand_dims(
4377
+ self.backend.bk_expand_dims(
4378
+ P1_dic[j2], off_S2
4379
+ ),
4380
+ -1,
4381
+ )
4382
+ )
4383
+ ** 0.5,
4384
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4385
+ ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
4386
+
4387
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
4388
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4389
+ S4.append(
4390
+ self.backend.bk_expand_dims(s4, off_S4)
4391
+ ) # Add a dimension for NS4
4392
+ if calc_var:
4393
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
4394
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4395
+ VS4.append(
4396
+ self.backend.bk_expand_dims(vs4, off_S4)
4397
+ ) # Add a dimension for NS4
4398
+
4399
+ ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
4400
+ else:
4401
+ if calc_var:
4402
+ s4, vs4 = self._compute_S4(
4403
+ j1,
4404
+ j2,
4405
+ vmask,
4406
+ M1convPsi_dic,
4407
+ M2convPsi_dic=M2convPsi_dic,
4408
+ calc_var=True,
4409
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4410
+ else:
4411
+ s4 = self._compute_S4(
4412
+ j1,
4413
+ j2,
4414
+ vmask,
4415
+ M1convPsi_dic,
4416
+ M2convPsi_dic=M2convPsi_dic,
4417
+ return_data=return_data,
4418
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4419
+
4420
+ if return_data:
4421
+ if S4[j3][j2] is None:
4422
+ S4[j3][j2] = {}
4423
+ if out_nside is not None and out_nside < nside_j1:
4424
+ s4 = self.backend.bk_reduce_mean(
4425
+ self.backend.bk_reshape(
4426
+ s4,
4427
+ [
4428
+ s4.shape[0],
4429
+ 12 * out_nside**2,
4430
+ (nside_j1 // out_nside) ** 2,
4431
+ s4.shape[2],
4432
+ s4.shape[3],
4433
+ s4.shape[4],
4434
+ ],
4435
+ ),
4436
+ 2,
4437
+ )
4438
+ S4[j3][j2][j1] = s4
4439
+ else:
4440
+ ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
4441
+ if norm is not None:
4442
+ self.div_norm(
4443
+ s4,
4444
+ (
4445
+ self.backend.bk_expand_dims(
4446
+ self.backend.bk_expand_dims(
4447
+ P1_dic[j1], off_S2
4448
+ ),
4449
+ off_S2,
4450
+ )
4451
+ * self.backend.bk_expand_dims(
4452
+ self.backend.bk_expand_dims(
4453
+ P2_dic[j2], off_S2
4454
+ ),
4455
+ -1,
4456
+ )
4457
+ )
4458
+ ** 0.5,
4459
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
4460
+ ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
4461
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
4462
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4463
+ S4.append(
4464
+ self.backend.bk_expand_dims(s4, off_S4)
4465
+ ) # Add a dimension for NS4
4466
+ if calc_var:
4467
+
4468
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
4469
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
4470
+ VS4.append(
4471
+ self.backend.bk_expand_dims(vs4, off_S4)
4472
+ ) # Add a dimension for NS4
4473
+
4474
+ nside_j1 = nside_j1 // 2
4475
+ nside_j2 = nside_j2 // 2
4476
+
4477
+ ###### Reshape for next iteration on j3
4478
+ ### Image I1,
4479
+ # downscale the I1 [Nbatch, Npix_j3]
4480
+ if j3 != Jmax - 1:
4481
+ I1 = self.smooth(I1, axis=1)
4482
+ I1 = self.ud_grade_2(I1, axis=1)
4483
+
4484
+ ### Image I2
4485
+ if cross:
4486
+ I2 = self.smooth(I2, axis=1)
4487
+ I2 = self.ud_grade_2(I2, axis=1)
4488
+
4489
+ ### Modules
4490
+ for j2 in range(0, j3 + 1): # j2 =< j3
4491
+ ### Dictionary M1_dic[j2]
4492
+ M1_smooth = self.smooth(
4493
+ M1_dic[j2], axis=1
4494
+ ) # [Nbatch, Npix_j3, Norient3]
4495
+ M1_dic[j2] = self.ud_grade_2(
4496
+ M1_smooth, axis=1
4497
+ ) # [Nbatch, Npix_j3, Norient3]
4498
+
4499
+ ### Dictionary M2_dic[j2]
4500
+ if cross:
4501
+ M2_smooth = self.smooth(
4502
+ M2_dic[j2], axis=1
4503
+ ) # [Nbatch, Npix_j3, Norient3]
4504
+ M2_dic[j2] = self.ud_grade_2(
4505
+ M2_smooth, axis=1
4506
+ ) # [Nbatch, Npix_j3, Norient3]
4507
+
4508
+ ### Mask
4509
+ vmask = self.ud_grade_2(vmask, axis=1)
4510
+
4511
+ if self.mask_thres is not None:
4512
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
4513
+
4514
+ ### NSIDE_j3
4515
+ nside_j3 = nside_j3 // 2
4516
+
4517
+ ### Store P1_dic and P2_dic in self
4518
+ if (norm == "auto") and (self.P1_dic is None):
4519
+ self.P1_dic = P1_dic
4520
+ if cross:
4521
+ self.P2_dic = P2_dic
4522
+ """
4523
+ Sout=[s0]+S1+S2+S3+S4
4524
+
4525
+ if cross:
4526
+ Sout=Sout+S3P
4527
+ if calc_var:
4528
+ VSout=[vs0]+VS1+VS2+VS3+VS4
4529
+ if cross:
4530
+ VSout=VSout+VS3P
4531
+ return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
4532
+
4533
+ return self.backend.bk_concat(Sout, 2)
4534
+ """
4535
+ if calc_var:
4536
+ return result, vresult
4537
+ else:
4538
+ return result
4539
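+ # NOTE: everything below is unreachable debug leftover (the function
+ # has already returned result above).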
+ if calc_var:
4540
+ for k in S1:
4541
+ print(k.shape,k.dtype)
4542
+ for k in S2:
4543
+ print(k.shape,k.dtype)
4544
+ print(s0.shape,s0.dtype)
4545
+ return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
4546
+ else:
4547
+ return self.backend.bk_concat([s0]+S1+S2,axis=1)
4548
+ if not return_data:
4549
+ S1 = self.backend.bk_concat(S1, 2)
4550
+ S2 = self.backend.bk_concat(S2, 2)
4551
+ S3 = self.backend.bk_concat(S3, 2)
4552
+ S4 = self.backend.bk_concat(S4, 2)
4553
+ if cross:
4554
+ S3P = self.backend.bk_concat(S3P, 2)
4555
+ if calc_var:
4556
+ VS1 = self.backend.bk_concat(VS1, 2)
4557
+ VS2 = self.backend.bk_concat(VS2, 2)
4558
+ VS3 = self.backend.bk_concat(VS3, 2)
4559
+ VS4 = self.backend.bk_concat(VS4, 2)
4560
+ if cross:
4561
+ VS3P = self.backend.bk_concat(VS3P, 2)
4562
+ if calc_var:
4563
+ if not cross:
4564
+ return scat_cov(
4565
+ s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
4566
+ ), scat_cov(
4567
+ vs0,
4568
+ VS2,
4569
+ VS3,
4570
+ VS4,
4571
+ s1=VS1,
4572
+ backend=self.backend,
4573
+ use_1D=self.use_1D,
4574
+ )
4575
+ else:
4576
+ return scat_cov(
4577
+ s0,
4578
+ S2,
4579
+ S3,
4580
+ S4,
4581
+ s1=S1,
4582
+ s3p=S3P,
4583
+ backend=self.backend,
4584
+ use_1D=self.use_1D,
4585
+ ), scat_cov(
4586
+ vs0,
4587
+ VS2,
4588
+ VS3,
4589
+ VS4,
4590
+ s1=VS1,
4591
+ s3p=VS3P,
4592
+ backend=self.backend,
4593
+ use_1D=self.use_1D,
4594
+ )
4595
+ else:
4596
+ if not cross:
4597
+ return scat_cov(
4598
+ s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
4599
+ )
4600
+ else:
4601
+ return scat_cov(
4602
+ s0,
4603
+ S2,
4604
+ S3,
4605
+ S4,
4606
+ s1=S1,
4607
+ s3p=S3P,
4608
+ backend=self.backend,
4609
+ use_1D=self.use_1D,
4610
+ )
4611
+ def clean_norm(self):
4612
+ self.P1_dic = None
4613
+ self.P2_dic = None
4614
+ return
4615
+
4616
+ def _compute_S3(
4617
+ self,
4618
+ j2,
4619
+ j3,
4620
+ conv,
4621
+ vmask,
4622
+ M_dic,
4623
+ MconvPsi_dic,
4624
+ calc_var=False,
4625
+ return_data=False,
4626
+ cmat2=None,
4627
+ ):
4628
+ """
4629
+ Compute the S3 coefficients (auto or cross)
4630
+ S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
4631
+ Parameters
4632
+ ----------
4633
+ Returns
4634
+ -------
4635
+ s3 (and vs3 when calc_var=True): S3 coefficients [Nbatch, Nmask, Norient3, Norient2]
4636
+ """
4637
+ ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
4638
+ # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
4639
+ MconvPsi = self.convol(
4640
+ M_dic[j2], axis=1
4641
+ ) # [Nbatch, Npix_j3, Norient3, Norient2]
4642
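+ # Optional second-order orientation mixing: cmat2[j3][j2] recombines
+ # the NORIENT output orientations of M_dic[j2] * Psi_j3.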
+ if cmat2 is not None:
4643
+ tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
4644
+ MconvPsi = self.backend.bk_reduce_sum(
4645
+ self.backend.bk_reshape(
4646
+ cmat2[j3][j2] * tmp2,
4647
+ [
4648
+ tmp2.shape[0],
4649
+ cmat2[j3].shape[1],
4650
+ self.NORIENT,
4651
+ self.NORIENT,
4652
+ self.NORIENT,
4653
+ ],
4654
+ ),
4655
+ 3,
4656
+ )
4657
+
4658
+ # Store it so we can use it in S4 computation
4659
+ MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
4660
+
4661
+ ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
4662
+ # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
4663
+ # cconv, sconv are [Nbatch, Npix_j3, Norient3]
4664
+ if self.use_1D:
4665
+ s3 = conv * self.backend.bk_conjugate(MconvPsi)
4666
+ else:
4667
+ s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
4668
+ MconvPsi
4669
+ ) # [Nbatch, Npix_j3, Norient3, Norient2]
4670
+
4671
+ ### Apply the mask [Nmask, Npix_j3] and sum over pixels
4672
+ if return_data:
4673
+ return s3
4674
+ else:
4675
+ if calc_var:
4676
+ s3, vs3 = self.masked_mean(
4677
+ s3, vmask, axis=1, rank=j2, calc_var=True
4678
+ ) # [Nbatch, Nmask, Norient3, Norient2]
4679
+ return s3, vs3
4680
+ else:
4681
+ s3 = self.masked_mean(
3522
4682
  s3, vmask, axis=1, rank=j2
3523
4683
  ) # [Nbatch, Nmask, Norient3, Norient2]
3524
4684
  return s3
@@ -3565,7 +4725,591 @@ class funct(FOC.FoCUS):
3565
4725
  s4, vmask, axis=1, rank=j2
3566
4726
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3567
4727
  return s4
4728
+
4729
+ def computer_filter(self,M,N,J,L):
4730
+ '''
4731
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4732
+ written by Sihao Cheng and Rudy Morel.
4733
+ '''
4734
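+ # Build a J x L bank of Morlet-like wavelets directly in Fourier space:
+ # each filter is an oriented, anisotropic Gabor (gabi) minus K times its
+ # Gaussian envelope (gab), which cancels the DC response; the DC bin is
+ # also explicitly zeroed below.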
+
4735
+ filter = np.zeros([J, L, M, N],dtype='complex64')
4736
+
4737
+ slant=4.0 / L
4738
+
4739
+ for j in range(J):
4740
+
4741
+ for l in range(L):
4742
+
4743
+ theta = (int(L-L/2-1)-l) * np.pi / L
4744
+ sigma = 0.8 * 2**j
4745
+ xi = 3.0 / 4.0 * np.pi /2**j
4746
+
4747
+ R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
4748
+ R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
4749
+ D = np.array([[1, 0], [0, slant * slant]])
4750
+ curv = np.matmul(R, np.matmul(D, R_inv)) / ( 2 * sigma * sigma)
4751
+
4752
+ gab = np.zeros((M, N), np.complex128)
4753
+ xx = np.empty((2,2, M, N))
4754
+ yy = np.empty((2,2, M, N))
4755
+
4756
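+ # Periodize the envelope: evaluate it on the base grid and on grids
+ # shifted by -M / -N, then sum, so the filter wraps correctly on the
+ # torus before taking the FFT.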
+ for ii, ex in enumerate([-1, 0]):
4757
+ for jj, ey in enumerate([-1, 0]):
4758
+ xx[ii,jj], yy[ii,jj] = np.mgrid[
4759
+ ex * M : M + ex * M,
4760
+ ey * N : N + ey * N]
4761
+
4762
+ arg = -(curv[0, 0] * xx * xx + (curv[0, 1] + curv[1, 0]) * xx * yy + curv[1, 1] * yy * yy)
4763
+ argi = arg + 1.j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
4764
+
4765
+ gabi = np.exp(argi).sum((0,1))
4766
+ gab = np.exp(arg).sum((0,1))
4767
+
4768
+ norm_factor = 2 * np.pi * sigma * sigma / slant
4769
+
4770
+ gab = gab / norm_factor
4771
+
4772
+ gabi = gabi / norm_factor
4773
+
4774
+ K = gabi.sum() / gab.sum()
4775
+
4776
+ # Apply the Gaussian
4777
+ filter[j, l] = np.fft.fft2(gabi-K*gab)
4778
+ filter[j,l,0,0]=0.0
4779
+
4780
+ return self.backend.bk_cast(filter)
4781
+
4782
+ # ------------------------------------------------------------------------------------------
4783
+ #
4784
+ # utility functions
4785
+ #
4786
+ # ------------------------------------------------------------------------------------------
4787
+ def cut_high_k_off(self,data_f, dx, dy):
4788
+ '''
4789
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4790
+ written by Sihao Cheng and Rudy Morel.
4791
+ '''
4792
+ if_xodd = (data_f.shape[-2]%2==1)
4793
+ if_yodd = (data_f.shape[-1]%2==1)
4794
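+ # Keep only the four low-frequency corners of the FFT grid (the FFT
+ # layout stores low |k| at the corners): this crops the spectrum to
+ # roughly (2*dx) x (2*dy) samples, i.e. a Fourier-space downsampling.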
+ result = self.backend.backend.cat(
4795
+ (self.backend.backend.cat(
4796
+ ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4797
+ ), -2),
4798
+ self.backend.backend.cat(
4799
+ ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4800
+ ), -2)
4801
+ ),-1)
4802
+ return result
4803
+ # ---------------------------------------------------------------------------
4804
+ #
4805
+ # utility functions for computing scattering coef and covariance
4806
+ #
4807
+ # ---------------------------------------------------------------------------
4808
+
4809
+ def get_dxdy(self, j,M,N):
4810
+ '''
4811
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4812
+ written by Sihao Cheng and Rudy Morel.
4813
+ '''
4814
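+ # Half-width of the Fourier crop used at scale j: shrink as 2^-j while
+ # keeping at least 8 samples and at most half the grid in each direction.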
+ dx = int(max( 8, min( np.ceil(M/2**j), M//2 ) ))
4815
+ dy = int(max( 8, min( np.ceil(N/2**j), N//2 ) ))
4816
+ return dx, dy
4817
+
4818
+
4819
+
4820
+ def get_edge_masks(self,M, N, J, d0=1):
4821
+ '''
4822
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4823
+ written by Sihao Cheng and Rudy Morel.
4824
+ '''
4825
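+ # One binary mask per scale j, zeroing a border of width ~2^j * d0
+ # (capped at a quarter of the image) so edge-corrupted pixels are
+ # excluded from the spatial averages.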
+ edge_masks = self.backend.backend.empty((J, M, N))
4826
+ X, Y = self.backend.backend.meshgrid(self.backend.backend.arange(M), self.backend.backend.arange(N), indexing='ij')
4827
+ for j in range(J):
4828
+ edge_dx = min(M//4, 2**j*d0)
4829
+ edge_dy = min(N//4, 2**j*d0)
4830
+ edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
4831
+ return edge_masks.to(self.backend.torch_device)
4832
+
4833
+ # ---------------------------------------------------------------------------
4834
+ #
4835
+ # scattering cov
4836
+ #
4837
+ # ---------------------------------------------------------------------------
4838
+ def scattering_cov(
4839
+ self, data, Jmax=None,
4840
+ if_large_batch=False,
4841
+ S4_criteria=None,
4842
+ use_ref=False,
4843
+ normalization='S2',
4844
+ edge=False,
4845
+ pseudo_coef=1,
4846
+ get_variance=False,
4847
+ ref_sigma=None,
4848
+ iso_ang=False
4849
+ ):
4850
+ '''
4851
+ Calculates the scattering correlations for a batch of images, including:
4852
+
4853
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4854
+ written by Sihao Cheng and Rudy Morel.
4855
+
4856
+ orig. x orig.:
4857
+ P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
4858
+ orig. x modulus:
4859
+ C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
4860
+ when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
4861
+ when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
4862
+ modulus x modulus:
4863
+ C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
4864
+ C11 = C11_pre_norm / factor
4865
+ when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
4866
+ when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
4867
+ modulus x modulus (auto):
4868
+ P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
4869
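+ (In this implementation the P00, C01 and C11 terms above are stored as
+ S2, S3 and S4 respectively.)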
+ Parameters
4870
+ ----------
4871
+ data : numpy array or torch tensor
4872
+ image set, with size [N_image, x-sidelength, y-sidelength]
4873
+ if_large_batch : Bool (=False)
4874
+ It is recommended to keep "False" unless one runs into memory issues.
4875
+ S4_criteria : str or None (=None)
4876
+ Only S4 coefficients that satisfy this criterion will be computed.
4877
+ Any expression of j1, j2, and j3 that can be evaluated as a Bool
4878
+ is accepted. The default "None" corresponds to "j2 >= j1".
4879
+ use_ref : Bool (=False)
4880
+ When normalizing, whether or not to use the normalization factor
4881
+ computed from a reference field. For just computing the statistics,
4882
+ the default is False. However, for synthesis, setting it to "True" will
4883
+ stabilize the optimization process.
4884
+ normalization : str 'S2' or 'P11' (='S2')
4885
+ Whether 'S2' or 'P11' is used as the normalization factor for S3
4886
+ and S4.
4887
+ edge : Bool (=False)
4888
+ If true, the edge region with a width of roughly the size of the largest
4889
+ wavelet involved is excluded when taking the global average to obtain
4890
+ the scattering coefficients.
4891
+
4892
+ Returns
4893
+ -------
4894
+ for_synthesis : torch tensor with size [N_image, N_coef]
+ a single flattened coefficient vector per image, concatenating the mean
+ and std of the input map, log(S2), log(S1), and the real and imaginary
+ parts of S3 and S4 (their isotropic averages over l1 when iso_ang=True).
+ ref_sigma : dict, only returned when get_variance=True
+ per-coefficient standard deviations ('S1_sigma', 'S2_sigma', 'S3_sigma',
+ 'S4_sigma') and 'std_data', in the form expected by the ref_sigma
+ argument when normalizing a later synthesis run.
4909
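+ Example (a minimal sketch; assumes `op` is a `funct` instance
+ configured for 2D input):
+
+ >>> coeffs = op.scattering_cov(images)  # [N_image, N_coef]
+ >>> coeffs, sigma = op.scattering_cov(images, get_variance=True)
+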
+ '''
4910
+ if S4_criteria is None:
4911
+ S4_criteria = 'j2>=j1'
4912
+
4913
+ # determine jmax and nside corresponding to the input map
4914
+ im_shape = data.shape
4915
+ if self.use_2D:
4916
+ if len(data.shape) == 2:
4917
+ nside = np.min([im_shape[0], im_shape[1]])
4918
+ M,N = im_shape[0],im_shape[1]
4919
+ N_image = 1
4920
+ else:
4921
+ nside = np.min([im_shape[1], im_shape[2]])
4922
+ M,N = im_shape[1],im_shape[2]
4923
+ N_image = data.shape[0]
4924
+ J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4925
+ elif self.use_1D:
4926
+ if len(data.shape) == 2:
4927
+ npix = int(im_shape[1]) # Number of pixels
4928
+ N_image = data.shape[0]
4929
+ else:
4930
+ npix = int(im_shape[0]) # Number of pixels
4931
+ N_image = 1
4932
+
4933
+ nside = int(npix)
4934
+
4935
+ J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4936
+ else:
4937
+ if len(data.shape) == 2:
4938
+ npix = int(im_shape[1]) # Number of pixels
4939
+ N_image = data.shape[0]
4940
+ else:
4941
+ npix = int(im_shape[0]) # Number of pixels
4942
+ N_image = 1
4943
+
4944
+ nside = int(np.sqrt(npix // 12))
4945
+
4946
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
4947
+
4948
+ if Jmax is None:
4949
+ Jmax = J # Number of steps for the loop on scales
4950
+ if Jmax>J:
4951
+ print('==========\n\n')
4952
+ print('The requested Jmax is larger than the maximum number of scales allowed by the data size, which may cause problems while computing the scattering transform.')
4953
+ print('\n\n==========')
4954
+
4955
+ L=self.NORIENT
4956
+
4957
+ if (M,N,J,L) not in self.filters_set:
4958
+ self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L)
4959
+
4960
+ filters_set = self.filters_set[(M,N,J,L)]
4961
+
4962
+ #weight = self.weight
4963
+ if use_ref:
4964
+ if normalization=='S2':
4965
+ ref_S2 = self.ref_scattering_cov_S2
4966
+ else:
4967
+ ref_P11 = self.ref_scattering_cov['P11']
4968
+
4969
+ # convert numpy array input into self.backend.bk_ tensors
4970
+ data = self.backend.bk_cast(data)
4971
+ data_f = self.backend.bk_fftn(data, dim=(-2,-1))
4972
+
4973
+ # initialize tensors for scattering coefficients
4974
+ S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4975
+ S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4976
+
4977
+ Ndata_S3 = J*(J+1)//2
4978
+ Ndata_S4 = J*(J+1)*(J+2)//6
4979
+ J_S4={}
4980
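+ # Ndata_S3 counts the (j2 <= j3) pairs and Ndata_S4 the (j1 <= j2 <= j3)
+ # triplets; J_S4[j3] records the first S4 slot used at scale j3.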
+
4981
+ S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4982
+ S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4983
+ S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4984
+
4985
+ # variance
4986
+ if get_variance:
4987
+ S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4988
+ S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4989
+ S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4990
+ S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4991
+
4992
+ if iso_ang:
4993
+ S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
4994
+ S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4995
+ if get_variance:
4996
+ S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
4997
+ S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4998
+
4999
+ # calculate scattering fields
5000
+ if self.use_2D:
5001
+ if len(data.shape) == 2:
5002
+ I1 = self.backend.bk_ifftn(
5003
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5004
+ ).abs()
5005
+ else:
5006
+ I1 = self.backend.bk_ifftn(
5007
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5008
+ ).abs()
5009
+ elif self.use_1D:
5010
+ if len(data.shape) == 1:
5011
+ I1 = self.backend.bk_ifftn(
5012
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5013
+ ).abs()
5014
+ else:
5015
+ I1 = self.backend.bk_ifftn(
5016
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5017
+ ).abs()
5018
+ else:
5019
+ raise NotImplementedError('scattering_cov is not yet implemented for HEALPix maps; use use_2D or use_1D')
5020
+
5021
+ I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5022
+
5023
+ #
5024
+ if edge:
5025
+ if (M,N,J) not in self.edge_masks:
5026
+ self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5027
+ edge_mask = self.edge_masks[(M,N,J)][:,None,:,:]
5028
+ edge_mask = edge_mask / edge_mask.mean((-2,-1))[:,:,None,None]
5029
+ else:
5030
+ edge_mask = 1
5031
+ S2 = (I1**2 * edge_mask).mean((-2,-1))
5032
+ S1 = (I1 * edge_mask).mean((-2,-1))
5033
+
5034
+ if get_variance:
5035
+ S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5036
+ S1_sigma = (I1 * edge_mask).std((-2,-1))
5037
+
5038
+ if pseudo_coef != 1:
5039
+ I1 = I1**pseudo_coef
5040
+
5041
+ Ndata_S3=0
5042
+ Ndata_S4=0
5043
+
5044
+ # calculate the covariance and correlations of the scattering fields
5045
+ # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5046
+ for j3 in range(0,J):
5047
+ J_S4[j3]=Ndata_S4
5048
+
5049
+ dx3, dy3 = self.get_dxdy(j3,M,N)
5050
+ I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
5051
+ data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5052
+ if edge:
5053
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5054
+ data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
5055
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5056
+ _, M3, N3 = wavelet_f3.shape
5057
+ wavelet_f3_squared = wavelet_f3**2
5058
+ edge_dx = min(4, int(2**j3*dx3*2/M))
5059
+ edge_dy = min(4, int(2**j3*dy3*2/N))
5060
+ # a normalization change due to the cutoff of frequency space
5061
+ fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5062
+ for j2 in range(0,j3+1):
5063
+ I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5064
+ I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5065
+ if edge:
5066
+ I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5067
+ I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
5068
+ if use_ref:
5069
+ if normalization=='P11':
5070
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5071
+ if normalization=='S2':
5072
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5073
+ else:
5074
+ if normalization=='P11':
5075
+ # [N_image,l2,l3,x,y]
5076
+ P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5077
+ norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5078
+ if normalization=='S2':
5079
+ norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
5080
+
5081
+ if not edge:
5082
+ S3[:,Ndata_S3,:,:] = (
5083
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5084
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5085
+
5086
+ if get_variance:
5087
+ S3_sigma[:,Ndata_S3,:,:] = (
5088
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5089
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5090
+ else:
5091
+
5092
+ S3[:,Ndata_S3,:,:] = (
5093
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5094
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5095
+ if get_variance:
5096
+ S3_sigma[:,Ndata_S3,:,:] = (
5097
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5098
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5099
+ Ndata_S3+=1
5100
+ if j2 <= j3:
5101
+ beg_n=Ndata_S4
5102
+ for j1 in range(0, j2+1):
5103
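+ # S4_criteria is a plain Python expression evaluated with the
+ # loop indices j1, j2, j3 in scope (e.g. 'j2>=j1').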
+ if eval(S4_criteria):
5104
+ if not edge:
5105
+ if not if_large_batch:
5106
+ # [N_image,l1,l2,l3,x,y]
5107
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5108
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5109
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5110
+ ).mean((-2,-1)) * fft_factor
5111
+ if get_variance:
5112
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5113
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5114
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5115
+ ).std((-2,-1)) * fft_factor
5116
+ else:
5117
+ for l1 in range(L):
5118
+ # [N_image,l2,l3,x,y]
5119
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5120
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5121
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5122
+ ).mean((-2,-1)) * fft_factor
5123
+ if get_variance:
5124
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5125
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5126
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5127
+ ).std((-2,-1)) * fft_factor
5128
+ else:
5129
+ if not if_large_batch:
5130
+ # [N_image,l1,l2,l3,x,y]
5131
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5132
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5133
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5134
+ )
5135
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5136
+ if get_variance:
5137
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5138
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5139
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5140
+ )
5141
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
5142
+ else:
5143
+ for l1 in range(L):
5144
+ # [N_image,l2,l3,x,y]
5145
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5146
+ I1_small[:,j1,l1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5147
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5148
+ )
5149
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5150
+ if get_variance:
5151
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5152
+ I1_small[:,j1,l1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5153
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5154
+ )
5155
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
5156
+
5157
+ Ndata_S4+=1
5158
+
5159
+ if normalization=='S2':
5160
+ if use_ref:
5161
+ P = (ref_S2[:,j3,:,None,None] * ref_S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
5162
+ else:
5163
+ P = (S2[:,j3,:,None,None] * S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
5164
+
5165
+ S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:]/P
5166
+
5167
+ if get_variance:
5168
+ S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:] / P
5169
+
5170
+ """
5171
+ # define P11 from diagonals of S4
5172
+ for j1 in range(J):
5173
+ for l1 in range(L):
5174
+ P11[:,j1,:,l1,:] = S4_pre_norm[:,j1,j1,:,l1,l1,:].real
5175
+
5176
+
5177
+ if normalization=='S4':
5178
+ if use_ref:
5179
+ P = ref_P11
5180
+ else:
5181
+ P = P11
5182
+ #.view(N_image,J,1,J,L,1,L) * .view(N_image,1,J,J,1,L,L)
5183
+ S4 = S4_pre_norm / (
5184
+ P[:,:,None,:,:,None,:] * P[:,None,:,:,None,:,:]
5185
+ )**(0.5*pseudo_coef)
5186
+
5187
+
5188
+
5189
+
5190
+ # get a single, flattened data vector for_synthesis
5191
+ select_and_index = self.get_scattering_index(J, L, normalization, S4_criteria)
5192
+ index_for_synthesis = select_and_index['index_for_synthesis']
5193
+ index_for_synthesis_iso = select_and_index['index_for_synthesis_iso']
5194
+ """
5195
+ # average over l1 to obtain simple isotropic statistics
5196
+ if iso_ang:
5197
+ S2_iso = S2.mean(-1)
5198
+ S1_iso = S1.mean(-1)
5199
+ for l1 in range(L):
5200
+ for l2 in range(L):
5201
+ S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
5202
+ for l3 in range(L):
5203
+ S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5204
+ S3_iso /= L; S4_iso /= L
5205
+
5206
+ if get_variance:
5207
+ S2_sigma_iso = S2_sigma.mean(-1)
5208
+ S1_sigma_iso = S1_sigma.mean(-1)
5209
+ for l1 in range(L):
5210
+ for l2 in range(L):
5211
+ S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
5212
+ for l3 in range(L):
5213
+ S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5214
+ S3_sigma_iso /= L; S4_sigma_iso /= L
5215
+
5216
+ mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5217
+ std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5218
+ mean_data[:,0]=data.mean((-2,-1))
5219
+ std_data[:,0]=data.std((-2,-1))
5220
+
5221
+ if get_variance:
5222
+ ref_sigma={}
5223
+ if iso_ang:
5224
+ ref_sigma['std_data']=std_data
5225
+ ref_sigma['S1_sigma']=S1_sigma_iso
5226
+ ref_sigma['S2_sigma']=S2_sigma_iso
5227
+ ref_sigma['S3_sigma']=S3_sigma_iso
5228
+ ref_sigma['S4_sigma']=S4_sigma_iso
5229
+ else:
5230
+ ref_sigma['std_data']=std_data
5231
+ ref_sigma['S1_sigma']=S1_sigma
5232
+ ref_sigma['S2_sigma']=S2_sigma
5233
+ ref_sigma['S3_sigma']=S3_sigma
5234
+ ref_sigma['S4_sigma']=S4_sigma
5235
+
5236
+ if iso_ang:
5237
+ if ref_sigma is not None:
5238
+ for_synthesis = self.backend.backend.cat((
5239
+ mean_data/ref_sigma['std_data'],
5240
+ std_data/ref_sigma['std_data'],
5241
+ (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5242
+ (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5243
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5244
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5245
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5246
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5247
+ ),dim=-1)
5248
+ else:
5249
+ for_synthesis = self.backend.backend.cat((
5250
+ mean_data/std_data,
5251
+ std_data,
5252
+ S2_iso.reshape((N_image, -1)).log(),
5253
+ S1_iso.reshape((N_image, -1)).log(),
5254
+ S3_iso.reshape((N_image, -1)).real,
5255
+ S3_iso.reshape((N_image, -1)).imag,
5256
+ S4_iso.reshape((N_image, -1)).real,
5257
+ S4_iso.reshape((N_image, -1)).imag,
5258
+ ),dim=-1)
5259
+ else:
5260
+ if ref_sigma is not None:
5261
+ for_synthesis = self.backend.backend.cat((
5262
+ mean_data/ref_sigma['std_data'],
5263
+ std_data/ref_sigma['std_data'],
5264
+ (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5265
+ (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5266
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5267
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5268
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5269
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5270
+ ),dim=-1)
5271
+ else:
5272
+ for_synthesis = self.backend.backend.cat((
5273
+ mean_data/std_data,
5274
+ std_data,
5275
+ S2.reshape((N_image, -1)).log(),
5276
+ S1.reshape((N_image, -1)).log(),
5277
+ S3.reshape((N_image, -1)).real,
5278
+ S3.reshape((N_image, -1)).imag,
5279
+ S4.reshape((N_image, -1)).real,
5280
+ S4.reshape((N_image, -1)).imag,
5281
+ ),dim=-1)
5282
+
5283
+ if not use_ref:
5284
+ self.ref_scattering_cov_S2=S2
5285
+
5286
+ if get_variance:
5287
+ return for_synthesis,ref_sigma
5288
+
5289
+ return for_synthesis
5290
+
5291
+
5292
+ def to_gaussian(self,x):
5293
+ from scipy.stats import norm
5294
+ from scipy.interpolate import interp1d
5295
+
5296
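+ # Rank-based Gaussianization (quantile transform): each value is replaced
+ # by the normal quantile of its empirical rank; f_gaussian keeps a cubic
+ # spline mapping Gaussianized values back to the original histogram.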
+ idx=np.argsort(x.flatten())
5297
+ p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
5298
+ im_target=x.flatten()
5299
+ im_target[idx] = norm.ppf(p)
5300
+
5301
+ # Cubic interpolation
5302
+ self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind='cubic')
5303
+ self.val_min=im_target[idx[0]]
5304
+ self.val_max=im_target[idx[-1]]
5305
+ return im_target.reshape(x.shape)
3568
5306
 
5307
+
5308
+ def from_gaussian(self,x):
5309
+
5310
+ x=self.backend.bk_clip_by_value(x,self.val_min,self.val_max)
5311
+ return self.f_gaussian(self.backend.to_numpy(x))
5312
+
3569
5313
  def square(self, x):
3570
5314
  if isinstance(x, scat_cov):
3571
5315
  if x.S1 is None:
@@ -3615,42 +5359,47 @@ class funct(FOC.FoCUS):
3615
5359
  return self.backend.bk_abs(self.backend.bk_sqrt(x))
3616
5360
 
3617
5361
  def reduce_mean(self, x):
3618
-
5362
+
3619
5363
  if isinstance(x, scat_cov):
3620
- result = self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0)) + \
3621
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2)) + \
3622
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3)) + \
3623
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S4))
3624
-
3625
- N = self.backend.bk_size(x.S0)+self.backend.bk_size(x.S2)+ \
3626
- self.backend.bk_size(x.S3)+self.backend.bk_size(x.S4)
3627
-
5364
+ result = (
5365
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
5366
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
5367
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3))
5368
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S4))
5369
+ )
5370
+
5371
+ N = (
5372
+ self.backend.bk_size(x.S0)
5373
+ + self.backend.bk_size(x.S2)
5374
+ + self.backend.bk_size(x.S3)
5375
+ + self.backend.bk_size(x.S4)
5376
+ )
5377
+
3628
5378
  if x.S1 is not None:
3629
- result = result+self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
5379
+ result = result + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
3630
5380
  N = N + self.backend.bk_size(x.S1)
3631
5381
  if x.S3P is not None:
3632
- result = result+self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3P))
5382
+ result = result + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3P))
3633
5383
  N = N + self.backend.bk_size(x.S3P)
3634
- return result/self.backend.bk_cast(N)
5384
+ return result / self.backend.bk_cast(N)
3635
5385
  else:
3636
5386
  return self.backend.bk_reduce_mean(x, axis=0)
3637
-
3638
5387
 
3639
5388
  def reduce_mean_batch(self, x):
3640
-
5389
+
3641
5390
  if isinstance(x, scat_cov):
3642
-
3643
- sS0=self.backend.bk_reduce_mean(x.S0, axis=0)
3644
- sS2=self.backend.bk_reduce_mean(x.S2, axis=0)
3645
- sS3=self.backend.bk_reduce_mean(x.S3, axis=0)
3646
- sS4=self.backend.bk_reduce_mean(x.S4, axis=0)
3647
- sS1=None
3648
- sS3P=None
5391
+
5392
+ sS0 = self.backend.bk_reduce_mean(x.S0, axis=0)
5393
+ sS2 = self.backend.bk_reduce_mean(x.S2, axis=0)
5394
+ sS3 = self.backend.bk_reduce_mean(x.S3, axis=0)
5395
+ sS4 = self.backend.bk_reduce_mean(x.S4, axis=0)
5396
+ sS1 = None
5397
+ sS3P = None
3649
5398
  if x.S1 is not None:
3650
5399
  sS1 = self.backend.bk_reduce_mean(x.S1, axis=0)
3651
5400
  if x.S3P is not None:
3652
5401
  sS3P = self.backend.bk_reduce_mean(x.S3P, axis=0)
3653
-
5402
+
3654
5403
  result = scat_cov(
3655
5404
  sS0,
3656
5405
  sS2,
@@ -3664,22 +5413,22 @@ class funct(FOC.FoCUS):
3664
5413
  return result
3665
5414
  else:
3666
5415
  return self.backend.bk_reduce_mean(x, axis=0)
3667
-
5416
+
3668
5417
  def reduce_sum_batch(self, x):
3669
-
5418
+
3670
5419
  if isinstance(x, scat_cov):
3671
-
3672
- sS0=self.backend.bk_reduce_sum(x.S0, axis=0)
3673
- sS2=self.backend.bk_reduce_sum(x.S2, axis=0)
3674
- sS3=self.backend.bk_reduce_sum(x.S3, axis=0)
3675
- sS4=self.backend.bk_reduce_sum(x.S4, axis=0)
3676
- sS1=None
3677
- sS3P=None
5420
+
5421
+ sS0 = self.backend.bk_reduce_sum(x.S0, axis=0)
5422
+ sS2 = self.backend.bk_reduce_sum(x.S2, axis=0)
5423
+ sS3 = self.backend.bk_reduce_sum(x.S3, axis=0)
5424
+ sS4 = self.backend.bk_reduce_sum(x.S4, axis=0)
5425
+ sS1 = None
5426
+ sS3P = None
3678
5427
  if x.S1 is not None:
3679
5428
  sS1 = self.backend.bk_reduce_sum(x.S1, axis=0)
3680
5429
  if x.S3P is not None:
3681
5430
  sS3P = self.backend.bk_reduce_sum(x.S3P, axis=0)
3682
-
5431
+
3683
5432
  result = scat_cov(
3684
5433
  sS0,
3685
5434
  sS2,
@@ -3693,7 +5442,7 @@ class funct(FOC.FoCUS):
3693
5442
  return result
3694
5443
  else:
3695
5444
  return self.backend.bk_reduce_mean(x, axis=0)
3696
-
5445
+
3697
5446
  def reduce_distance(self, x, y, sigma=None):
3698
5447
 
3699
5448
  if isinstance(x, scat_cov):
@@ -3729,11 +5478,13 @@ class funct(FOC.FoCUS):
3729
5478
  return result
3730
5479
  else:
3731
5480
  if sigma is None:
3732
- tmp=x-y
5481
+ tmp = x - y
3733
5482
  else:
3734
- tmp=(x-y)/sigma
5483
+ tmp = (x - y) / sigma
3735
5484
  # do abs in case of complex values
3736
- return self.backend.bk_abs(self.backend.bk_reduce_mean(self.backend.bk_square(tmp)))
5485
+ return self.backend.bk_abs(
5486
+ self.backend.bk_reduce_mean(self.backend.bk_square(tmp))
5487
+ )
3737
5488
 
3738
5489
  def reduce_sum(self, x):
3739
5490
 
@@ -3909,3 +5660,133 @@ class funct(FOC.FoCUS):
3909
5660
  return scat_cov(
3910
5661
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
3911
5662
  )
5663
+
5664
+ def synthesis(self,
5665
+ image_target,
5666
+ nstep=4,
5667
+ seed=1234,
5668
+ Jmax=None,
5669
+ edge=False,
5670
+ to_gaussian=True,
5671
+ use_variance=False,
5672
+ synthesised_N=1,
5673
+ iso_ang=False,
5674
+ EVAL_FREQUENCY=100,
5675
+ NUM_EPOCHS = 300):
5676
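+ '''
+ Coarse-to-fine synthesis of maps whose scattering covariance matches
+ image_target: a random map is optimized at the coarsest scale, then
+ upsampled and refined at each of the nstep scales. A minimal usage
+ sketch (assuming `op` is this operator):
+
+ >>> new_map = op.synthesis(target_map, nstep=4, NUM_EPOCHS=300)
+
+ Returns the synthesised map(s); when synthesised_N == 1 the single
+ map is returned without the leading batch axis.
+ '''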
+
5677
+ import foscat.Synthesis as synthe
5678
+ import time
5679
+
5680
+ def The_loss(u,scat_operator,args):
5681
+ ref = args[0]
5682
+ sref = args[1]
5683
+ use_v= args[2]
5684
+
5685
+ # compute the scattering covariance of the current synthesised map u
5686
+ if use_v:
5687
+ learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,ref_sigma=sref,use_ref=True,iso_ang=iso_ang))
5688
+ else:
5689
+ learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,use_ref=True,iso_ang=iso_ang))
5690
+
5691
+ # compute the difference with the reference coefficients
5692
+ loss=scat_operator.backend.bk_reduce_mean(scat_operator.backend.bk_square((learn-ref)))
5693
+ return loss
5694
+
5695
+ if to_gaussian:
5696
+ # Map the data histogram onto a Gaussian distribution
5697
+ im_target=self.to_gaussian(image_target)
5698
+ else:
5699
+ im_target=image_target
5700
+
5701
+ axis=len(im_target.shape)-1
5702
+ if self.use_2D:
5703
+ axis-=1
5704
+ if axis==0:
5705
+ im_target=self.backend.bk_expand_dims(im_target,0)
5706
+
5707
+ # compute the number of possible steps
5708
+ if self.use_2D:
5709
+ jmax=int(np.min([np.log(im_target.shape[1]),np.log(im_target.shape[2])])/np.log(2))
5710
+ elif self.use_1D:
5711
+ jmax=int(np.log(im_target.shape[1])/np.log(2))
5712
+ else:
5713
+ jmax=int((np.log(im_target.shape[1]//12)/np.log(2))/2)
5714
+ nside=2**jmax
5715
+
5716
+ if nstep>jmax-1:
5717
+ nstep=jmax-1
5718
+
5719
+ t1=time.time()
5720
+ tmp={}
5721
+ tmp[nstep-1]=im_target
5722
+ for l in range(nstep-2,-1,-1):
5723
+ tmp[l]=self.ud_grade_2(tmp[l+1],axis=1)
5724
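+ # tmp[k] holds the target degraded to the resolution of step k
+ # (tmp[nstep-1] is the full-resolution target).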
+
5725
+ if not self.use_2D and not self.use_1D:
5726
+ l_nside=nside//(2**(nstep-1))
5727
+
5728
+ for k in range(nstep):
5729
+ if k==0:
5730
+ np.random.seed(seed)
5731
+ if self.use_2D:
5732
+ imap=np.random.randn(synthesised_N,
5733
+ tmp[k].shape[1],
5734
+ tmp[k].shape[2])
5735
+ else:
5736
+ imap=np.random.randn(synthesised_N,
5737
+ tmp[k].shape[1])
5738
+ else:
5739
+ # Increase the resolution between each step
5740
+ if self.use_2D:
5741
+ imap = self.up_grade(
5742
+ omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
5743
+ )
5744
+ elif self.use_1D:
5745
+ imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
5746
+ else:
5747
+ imap = self.up_grade(omap, l_nside, axis=1)
5748
+
5749
+ # compute the coefficients for the target image
5750
+ if use_variance:
5751
+ ref,sref=self.scattering_cov(tmp[k],get_variance=True,edge=edge,Jmax=Jmax,iso_ang=iso_ang)
5752
+ else:
5753
+ ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
5754
+ sref=ref
5755
+
5756
+ # compute the mean over the population; does nothing if only one map is given
5757
+ ref=self.reduce_mean_batch(ref)
5758
+
5759
+ # define a loss to minimize
5760
+ loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
5761
+
5762
+ sy = synthe.Synthesis([loss])
5763
+
5764
+ # initialize the synthesised map
5765
+ if self.use_2D:
5766
+ print('Synthesis scale [ %d x %d ]'%(imap.shape[1],imap.shape[2]))
5767
+ elif self.use_1D:
5768
+ print('Synthesis scale [ %d ]'%(imap.shape[1]))
5769
+ else:
5770
+ print('Synthesis scale nside=%d'%(l_nside))
5771
+ l_nside*=2
5772
+
5773
+ # do the minimization
5774
+ omap=sy.run(imap,
5775
+ EVAL_FREQUENCY=EVAL_FREQUENCY,
5776
+ NUM_EPOCHS = NUM_EPOCHS)
5777
+
5778
+
5779
+
5780
+ t2=time.time()
5781
+ print('Total computation %.2fs'%(t2-t1))
5782
+
5783
+ if to_gaussian:
5784
+ omap=self.from_gaussian(omap)
5785
+
5786
+ if axis==0 and synthesised_N==1:
5787
+ return omap[0]
5788
+ else:
5789
+ return omap
5790
+
5791
+ def to_numpy(self,x):
5792
+ return self.backend.to_numpy(x)