foscat 3.6.1__py3-none-any.whl → 3.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -33,9 +33,7 @@ testwarn = 0
33
33
 
34
34
 
35
35
  class scat_cov:
36
- def __init__(
37
- self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False
38
- ):
36
+ def __init__(self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False):
39
37
  self.S0 = s0
40
38
  self.S2 = s2
41
39
  self.S3 = s3
@@ -54,17 +52,17 @@ class scat_cov:
54
52
  if self.S1 is None:
55
53
  s1 = None
56
54
  else:
57
- s1 = self.S1.numpy()
55
+ s1 = self.backend.to_numpy(self.S1)
58
56
  if self.S3P is None:
59
57
  s3p = None
60
58
  else:
61
- s3p = self.S3P.numpy()
59
+ s3p = self.backend.to_numpy(self.S3P)
62
60
 
63
61
  return scat_cov(
64
- (self.S0.numpy()),
65
- (self.S2.numpy()),
66
- (self.S3.numpy()),
67
- (self.S4.numpy()),
62
+ self.backend.to_numpy(self.S0),
63
+ self.backend.to_numpy(self.S2),
64
+ self.backend.to_numpy(self.S3),
65
+ self.backend.to_numpy(self.S4),
68
66
  s1=s1,
69
67
  s3p=s3p,
70
68
  backend=self.backend,
@@ -94,7 +92,7 @@ class scat_cov:
94
92
  )
95
93
 
96
94
  def conv2complex(self, val):
97
- if val.dtype == "complex64" or val.dtype == "complex128":
95
+ if val.dtype == "complex64" in val.dtype or "complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
98
96
  return val
99
97
  else:
100
98
  return self.backend.bk_complex(val, 0 * val)
@@ -1531,7 +1529,7 @@ class scat_cov:
1531
1529
  if isinstance(x, np.ndarray):
1532
1530
  return x
1533
1531
  else:
1534
- return x.numpy()
1532
+ return self.backend.to_numpy(x)
1535
1533
  else:
1536
1534
  return None
1537
1535
 
@@ -1974,7 +1972,7 @@ class scat_cov:
1974
1972
  if self.BACKEND == "numpy":
1975
1973
  s2[:, :, noff:, :] = self.S2
1976
1974
  else:
1977
- s2[:, :, noff:, :] = self.S2.numpy()
1975
+ s2[:, :, noff:, :] = self.backend.to_numpy(self.S2)
1978
1976
  for i in range(self.S2.shape[0]):
1979
1977
  for j in range(self.S2.shape[1]):
1980
1978
  for k in range(self.S2.shape[3]):
@@ -1986,7 +1984,7 @@ class scat_cov:
1986
1984
  if self.BACKEND == "numpy":
1987
1985
  s1[:, :, noff:, :] = self.S1
1988
1986
  else:
1989
- s1[:, :, noff:, :] = self.S1.numpy()
1987
+ s1[:, :, noff:, :] = self.backend.to_numpy(self.S1)
1990
1988
  for i in range(self.S1.shape[0]):
1991
1989
  for j in range(self.S1.shape[1]):
1992
1990
  for k in range(self.S1.shape[3]):
@@ -2045,12 +2043,12 @@ class scat_cov:
2045
2043
  )
2046
2044
  )
2047
2045
  else:
2048
- s3[i, j, idx[noff:], k, l_orient] = self.S3.numpy()[
2046
+ s3[i, j, idx[noff:], k, l_orient] = self.backend.to_numpy(self.S3)[
2049
2047
  i, j, j2 == ij - noff, k, l_orient
2050
2048
  ]
2051
2049
  s3[i, j, idx[:noff], k, l_orient] = (
2052
2050
  self.add_data_from_slope(
2053
- self.S3.numpy()[
2051
+ self.backend.to_numpy(self.S3)[
2054
2052
  i, j, j2 == ij - noff, k, l_orient
2055
2053
  ],
2056
2054
  noff,
@@ -2358,13 +2356,13 @@ class funct(FOC.FoCUS):
2358
2356
  tmp = self.up_grade(tmp, l_nside * 2, axis=1)
2359
2357
  if image2 is not None:
2360
2358
  tmpi2 = self.up_grade(tmpi2, l_nside * 2, axis=1)
2361
-
2362
2359
  l_nside = int(np.sqrt(tmp.shape[1] // 12))
2363
2360
  nscale = int(np.log(l_nside) / np.log(2))
2364
2361
  cmat = {}
2365
2362
  cmat2 = {}
2363
+
2364
+ # Loop over scales
2366
2365
  for k in range(nscale):
2367
- sim = self.backend.bk_abs(self.convol(tmp, axis=1))
2368
2366
  if image2 is not None:
2369
2367
  sim = self.backend.bk_real(
2370
2368
  self.backend.bk_L1(
@@ -2375,17 +2373,31 @@ class funct(FOC.FoCUS):
2375
2373
  else:
2376
2374
  sim = self.backend.bk_abs(self.convol(tmp, axis=1))
2377
2375
 
2378
- cc = self.backend.bk_reduce_mean(sim[:, :, 0] - sim[:, :, 2], 0)
2379
- ss = self.backend.bk_reduce_mean(sim[:, :, 1] - sim[:, :, 3], 0)
2380
- for m in range(smooth_scale):
2381
- if cc.shape[0] > 12:
2382
- cc = self.ud_grade_2(self.smooth(cc))
2383
- ss = self.ud_grade_2(self.smooth(ss))
2376
+ # instead of difference between "opposite" channels use weighted average
2377
+ # of cosine and sine contributions using all channels
2378
+ angles = (
2379
+ 2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
2380
+ ) # shape: (NORIENT,)
2381
+ angles = angles.reshape(1, 1, self.NORIENT)
2382
+
2383
+ # we use cosines and sines as weights for sim
2384
+ weighted_cos = self.backend.bk_reduce_mean(sim * np.cos(angles), axis=-1)
2385
+ weighted_sin = self.backend.bk_reduce_mean(sim * np.sin(angles), axis=-1)
2386
+ # For simplicity, take first element of the batch
2387
+ cc = weighted_cos[0]
2388
+ ss = weighted_sin[0]
2389
+
2390
+ if smooth_scale > 0:
2391
+ for m in range(smooth_scale):
2392
+ if cc.shape[0] > 12:
2393
+ cc = self.ud_grade_2(self.smooth(cc))
2394
+ ss = self.ud_grade_2(self.smooth(ss))
2384
2395
  if cc.shape[0] != tmp.shape[0]:
2385
2396
  ll_nside = int(np.sqrt(tmp.shape[1] // 12))
2386
2397
  cc = self.up_grade(cc, ll_nside)
2387
2398
  ss = self.up_grade(ss, ll_nside)
2388
2399
 
2400
+ # compute local phase from weighted cos and sin (same as before)
2389
2401
  if self.BACKEND == "numpy":
2390
2402
  phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
2391
2403
  else:
@@ -2393,87 +2405,87 @@ class funct(FOC.FoCUS):
2393
2405
  np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
2394
2406
  )
2395
2407
 
2396
- iph = (4 * phase / (2 * np.pi)).astype("int")
2397
- alpha = 4 * phase / (2 * np.pi) - iph
2398
- mat = np.zeros([sim.shape[1], 4 * 4])
2408
+ # instead of linear interpolation cosine‐based interpolation
2409
+ phase_scaled = self.NORIENT * phase / (2 * np.pi)
2410
+ iph = np.floor(phase_scaled).astype("int") # lower bin index
2411
+ delta = phase_scaled - iph # fractional part in [0,1)
2412
+ # interpolation weights
2413
+ w0 = np.cos(delta * np.pi / 2) ** 2
2414
+ w1 = np.sin(delta * np.pi / 2) ** 2
2415
+
2416
+ # build rotation matrix
2417
+ mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
2399
2418
  lidx = np.arange(sim.shape[1])
2400
- for l_orient in range(4):
2401
- mat[lidx, 4 * ((l_orient + iph) % 4) + l_orient] = 1.0 - alpha
2402
- mat[lidx, 4 * ((l_orient + iph + 1) % 4) + l_orient] = alpha
2419
+ for l in range(self.NORIENT):
2420
+ # Instead of simple linear weights, we use the cosine weights w0 and w1.
2421
+ col0 = self.NORIENT * ((l + iph) % self.NORIENT) + l
2422
+ col1 = self.NORIENT * ((l + iph + 1) % self.NORIENT) + l
2423
+ mat[lidx, col0] = w0
2424
+ mat[lidx, col1] = w1
2403
2425
 
2404
2426
  cmat[k] = self.backend.bk_cast(mat.astype("complex64"))
2405
2427
 
2406
- mat2 = np.zeros([k + 1, sim.shape[1], 4, 4 * 4])
2428
+ # do same modifications for mat2
2429
+ mat2 = np.zeros(
2430
+ [k + 1, sim.shape[1], self.NORIENT, self.NORIENT * self.NORIENT]
2431
+ )
2407
2432
 
2408
2433
  for k2 in range(k + 1):
2409
- tmp2 = self.backend.bk_repeat(sim, 4, axis=-1)
2434
+ tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=-1)
2435
+
2410
2436
  sim2 = self.backend.bk_reduce_sum(
2411
2437
  self.backend.bk_reshape(
2412
- mat.reshape(1, mat.shape[0], 16) * tmp2,
2413
- [sim.shape[0], cmat[k].shape[0], 4, 4],
2438
+ mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
2439
+ * tmp2,
2440
+ [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
2414
2441
  ),
2415
2442
  2,
2416
2443
  )
2444
+
2417
2445
  sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
2418
2446
 
2419
- cc = self.smooth(
2420
- self.backend.bk_reduce_mean(sim2[:, :, 0] - sim2[:, :, 2], 0)
2447
+ weighted_cos2 = self.backend.bk_reduce_mean(
2448
+ sim2 * np.cos(angles), axis=-1
2421
2449
  )
2422
- ss = self.smooth(
2423
- self.backend.bk_reduce_mean(sim2[:, :, 1] - sim2[:, :, 3], 0)
2450
+ weighted_sin2 = self.backend.bk_reduce_mean(
2451
+ sim2 * np.sin(angles), axis=-1
2424
2452
  )
2425
- for m in range(smooth_scale):
2426
- if cc.shape[0] > 12:
2427
- cc = self.ud_grade_2(self.smooth(cc))
2428
- ss = self.ud_grade_2(self.smooth(ss))
2429
- if cc.shape[0] != sim.shape[1]:
2453
+
2454
+ cc2 = weighted_cos2[0]
2455
+ ss2 = weighted_sin2[0]
2456
+
2457
+ if smooth_scale > 0:
2458
+ for m in range(smooth_scale):
2459
+ if cc2.shape[0] > 12:
2460
+ cc2 = self.ud_grade_2(self.smooth(cc2))
2461
+ ss2 = self.ud_grade_2(self.smooth(ss2))
2462
+
2463
+ if cc2.shape[0] != sim.shape[1]:
2430
2464
  ll_nside = int(np.sqrt(sim.shape[1] // 12))
2431
- cc = self.up_grade(cc, ll_nside)
2432
- ss = self.up_grade(ss, ll_nside)
2465
+ cc2 = self.up_grade(cc2, ll_nside)
2466
+ ss2 = self.up_grade(ss2, ll_nside)
2433
2467
 
2434
2468
  if self.BACKEND == "numpy":
2435
- phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
2469
+ phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
2436
2470
  else:
2437
- phase = np.fmod(
2438
- np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
2471
+ phase2 = np.fmod(
2472
+ np.arctan2(ss2.numpy(), cc2.numpy()) + 2 * np.pi, 2 * np.pi
2439
2473
  )
2440
- """
2441
- for k in range(4):
2442
- hp.mollview(np.fmod(phase+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,1+k))
2443
- plt.show()
2444
- return None
2445
- """
2446
- iph = (4 * phase / (2 * np.pi)).astype("int")
2447
- alpha = 4 * phase / (2 * np.pi) - iph
2474
+
2475
+ phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
2476
+ iph2 = np.floor(phase2_scaled).astype("int")
2477
+ delta2 = phase2_scaled - iph2
2478
+ w0_2 = np.cos(delta2 * np.pi / 2) ** 2
2479
+ w1_2 = np.sin(delta2 * np.pi / 2) ** 2
2448
2480
  lidx = np.arange(sim.shape[1])
2449
- for m in range(4):
2450
- for l_orient in range(4):
2451
- mat2[
2452
- k2, lidx, m, 4 * ((l_orient + iph[:, m]) % 4) + l_orient
2453
- ] = (1.0 - alpha[:, m])
2454
- mat2[
2455
- k2, lidx, m, 4 * ((l_orient + iph[:, m] + 1) % 4) + l_orient
2456
- ] = alpha[:, m]
2457
-
2458
- cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2459
- """
2460
- tmp=self.backend.bk_repeat(sim[0],4,axis=1)
2461
- sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat*tmp,[12*nside**2,4,4]),1)
2462
-
2463
- cc2=(sim2[:,0]-sim2[:,2])
2464
- ss2=(sim2[:,1]-sim2[:,3])
2465
- phase2=np.fmod(np.arctan2(ss2.numpy(),cc2.numpy())+2*np.pi,2*np.pi)
2466
-
2467
- plt.figure()
2468
- hp.mollview(phase,cmap='jet',nest=True,hold=False,sub=(2,2,1))
2469
- hp.mollview(np.fmod(phase2+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,2))
2470
- plt.figure()
2471
- for k in range(4):
2472
- hp.mollview((sim[0,:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,1+k),min=-10,max=10)
2473
- hp.mollview((sim2[:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,5+k),min=-10,max=10)
2474
-
2475
- plt.show()
2476
- """
2481
+
2482
+ for m in range(self.NORIENT):
2483
+ for l in range(self.NORIENT):
2484
+ col0 = self.NORIENT * ((l + iph2[:, m]) % self.NORIENT) + l
2485
+ col1 = self.NORIENT * ((l + iph2[:, m] + 1) % self.NORIENT) + l
2486
+ mat2[k2, lidx, m, col0] = w0_2[:, m]
2487
+ mat2[k2, lidx, m, col1] = w1_2[:, m]
2488
+ cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2477
2489
 
2478
2490
  if k < l_nside - 1:
2479
2491
  tmp = self.ud_grade_2(tmp, axis=1)
@@ -2493,11 +2505,12 @@ class funct(FOC.FoCUS):
2493
2505
  image2=None,
2494
2506
  mask=None,
2495
2507
  norm=None,
2496
- Auto=True,
2497
2508
  calc_var=False,
2498
2509
  cmat=None,
2499
2510
  cmat2=None,
2500
- out_nside=None
2511
+ Jmax=None,
2512
+ out_nside=None,
2513
+ edge=True
2501
2514
  ):
2502
2515
  """
2503
2516
  Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2523,9 +2536,6 @@ class funct(FOC.FoCUS):
2523
2536
  norm: None or str
2524
2537
  If None no normalization is applied, if 'auto' normalize by the reference S2,
2525
2538
  if 'self' normalize by the current S2.
2526
- all_cross: False or True
2527
- If False compute all the coefficient even the Imaginary part,
2528
- If True return only the terms computable in the auto case.
2529
2539
  Returns
2530
2540
  -------
2531
2541
  S1, S2, S3, S4 normalized
@@ -2539,7 +2549,7 @@ class funct(FOC.FoCUS):
2539
2549
  )
2540
2550
  return None
2541
2551
  if mask is not None:
2542
- if list(image1.shape) != list(mask.shape)[1:]:
2552
+ if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
2543
2553
  print(
2544
2554
  "The LAST COLUMN of the mask should have the same size ",
2545
2555
  mask.shape,
@@ -2558,9 +2568,6 @@ class funct(FOC.FoCUS):
2558
2568
  cross = False
2559
2569
  if image2 is not None:
2560
2570
  cross = True
2561
- all_cross = Auto
2562
- else:
2563
- all_cross = False
2564
2571
 
2565
2572
  ### PARAMETERS
2566
2573
  axis = 1
@@ -2596,8 +2603,16 @@ class funct(FOC.FoCUS):
2596
2603
  nside = int(np.sqrt(npix // 12))
2597
2604
 
2598
2605
  J = int(np.log(nside) / np.log(2)) # Number of j scales
2599
-
2600
- Jmax = J - self.OSTEP # Number of steps for the loop on scales
2606
+
2607
+ if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
2608
+ J-=1
2609
+ if Jmax is None:
2610
+ Jmax = J # Number of steps for the loop on scales
2611
+ if Jmax>J:
2612
+ print('==========\n\n')
2613
+ print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
2614
+ print('\n\n==========')
2615
+
2601
2616
 
2602
2617
  ### LOCAL VARIABLES (IMAGES and MASK)
2603
2618
  if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
@@ -2621,7 +2636,7 @@ class funct(FOC.FoCUS):
2621
2636
  else:
2622
2637
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
2623
2638
 
2624
- if self.KERNELSZ > 3:
2639
+ if self.KERNELSZ > 3 and not self.use_2D:
2625
2640
  # if the kernel size is bigger than 3 increase the binning before smoothing
2626
2641
  if self.use_2D:
2627
2642
  vmask = self.up_grade(
@@ -2645,7 +2660,7 @@ class funct(FOC.FoCUS):
2645
2660
  if cross:
2646
2661
  I2 = self.up_grade(I2, nside * 2, axis=axis)
2647
2662
 
2648
- if self.KERNELSZ > 5:
2663
+ if self.KERNELSZ > 5 and not self.use_2D:
2649
2664
  # if the kernel size is bigger than 3 increase the binning before smoothing
2650
2665
  if self.use_2D:
2651
2666
  vmask = self.up_grade(
@@ -2677,7 +2692,23 @@ class funct(FOC.FoCUS):
2677
2692
 
2678
2693
  ### INITIALIZATION
2679
2694
  # Coefficients
2680
- S1, S2, S3, S4, S3P = None, None, None, None, None
2695
+ if return_data:
2696
+ S1 = {}
2697
+ S2 = {}
2698
+ S3 = {}
2699
+ S3P = {}
2700
+ S4 = {}
2701
+ else:
2702
+ S1 = []
2703
+ S2 = []
2704
+ S3 = []
2705
+ S4 = []
2706
+ S3P = []
2707
+ VS1 = []
2708
+ VS2 = []
2709
+ VS3 = []
2710
+ VS3P = []
2711
+ VS4 = []
2681
2712
 
2682
2713
  off_S2 = -2
2683
2714
  off_S3 = -3
@@ -2687,11 +2718,6 @@ class funct(FOC.FoCUS):
2687
2718
  off_S3 = -1
2688
2719
  off_S4 = -1
2689
2720
 
2690
- # Dictionaries for S3 computation
2691
- M1_dic = {} # M stands for Module M1 = |I1 * Psi|
2692
- if cross:
2693
- M2_dic = {}
2694
-
2695
2721
  # S2 for normalization
2696
2722
  cond_init_P1_dic = (norm == "self") or (
2697
2723
  (norm == "auto") and (self.P1_dic is None)
@@ -2710,7 +2736,12 @@ class funct(FOC.FoCUS):
2710
2736
  if return_data:
2711
2737
  s0 = I1
2712
2738
  if out_nside is not None:
2713
- s0 = self.backend.bk_reduce_mean(self.backend.bk_reshape(s0,[s0.shape[0],12*out_nside**2,(nside//out_nside)**2]),2)
2739
+ s0 = self.backend.bk_reduce_mean(
2740
+ self.backend.bk_reshape(
2741
+ s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
2742
+ ),
2743
+ 2,
2744
+ )
2714
2745
  else:
2715
2746
  if not cross:
2716
2747
  s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
@@ -2720,17 +2751,37 @@ class funct(FOC.FoCUS):
2720
2751
  )
2721
2752
  vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2722
2753
  s0 = self.backend.bk_concat([s0, l_vs0], 1)
2723
-
2724
2754
  #### COMPUTE S1, S2, S3 and S4
2725
2755
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2756
+
2757
+ # a remettre comme avant
2758
+ M1_dic={}
2759
+
2726
2760
  for j3 in range(Jmax):
2761
+
2762
+ if edge:
2763
+ if self.mask_mask is None:
2764
+ self.mask_mask={}
2765
+ if self.use_2D:
2766
+ if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
2767
+ mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
2768
+ mask_mask[0,
2769
+ self.KERNELSZ//2:-self.KERNELSZ//2+1,
2770
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2771
+ self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
2772
+ vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
2773
+ #print(self.KERNELSZ//2,vmask,mask_mask)
2774
+
2775
+ if self.use_1D:
2776
+ if (vmask.shape[1]) not in self.mask_mask:
2777
+ mask_mask=np.zeros([1,vmask.shape[1]])
2778
+ mask_mask[0,
2779
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2780
+ self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
2781
+ vmask=vmask*self.mask_mask[(vmask.shape[1])]
2782
+
2727
2783
  if return_data:
2728
- if S3 is None:
2729
- S3 = {}
2730
2784
  S3[j3] = None
2731
-
2732
- if S3P is None:
2733
- S3P = {}
2734
2785
  S3P[j3] = None
2735
2786
 
2736
2787
  if S4 is None:
@@ -2740,12 +2791,12 @@ class funct(FOC.FoCUS):
2740
2791
  ####### S1 and S2
2741
2792
  ### Make the convolution I1 * Psi_j3
2742
2793
  conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
2743
-
2744
2794
  if cmat is not None:
2745
- tmp2 = self.backend.bk_repeat(conv1, 4, axis=-1)
2795
+ tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2746
2796
  conv1 = self.backend.bk_reduce_sum(
2747
2797
  self.backend.bk_reshape(
2748
- cmat[j3] * tmp2, [tmp2.shape[0], cmat[j3].shape[0], 4, 4]
2798
+ cmat[j3] * tmp2,
2799
+ [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
2749
2800
  ),
2750
2801
  2,
2751
2802
  )
@@ -2781,33 +2832,31 @@ class funct(FOC.FoCUS):
2781
2832
  if return_data:
2782
2833
  if S2 is None:
2783
2834
  S2 = {}
2784
- if out_nside is not None and out_nside<nside_j3:
2835
+ if out_nside is not None and out_nside < nside_j3:
2785
2836
  s2 = self.backend.bk_reduce_mean(
2786
- self.backend.bk_reshape(s2,[s2.shape[0],
2787
- 12*out_nside**2,
2788
- (nside_j3//out_nside)**2,
2789
- s2.shape[2]]),2)
2837
+ self.backend.bk_reshape(
2838
+ s2,
2839
+ [
2840
+ s2.shape[0],
2841
+ 12 * out_nside**2,
2842
+ (nside_j3 // out_nside) ** 2,
2843
+ s2.shape[2],
2844
+ ],
2845
+ ),
2846
+ 2,
2847
+ )
2790
2848
  S2[j3] = s2
2791
2849
  else:
2792
2850
  if norm == "auto": # Normalize S2
2793
2851
  s2 /= P1_dic[j3]
2794
- if S2 is None:
2795
- S2 = self.backend.bk_expand_dims(
2796
- s2, off_S2
2852
+
2853
+ S2.append(
2854
+ self.backend.bk_expand_dims(s2, off_S2)
2855
+ ) # Add a dimension for NS2
2856
+ if calc_var:
2857
+ VS2.append(
2858
+ self.backend.bk_expand_dims(vs2, off_S2)
2797
2859
  ) # Add a dimension for NS2
2798
- if calc_var:
2799
- VS2 = self.backend.bk_expand_dims(
2800
- vs2, off_S2
2801
- ) # Add a dimension for NS2
2802
- else:
2803
- S2 = self.backend.bk_concat(
2804
- [S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
2805
- )
2806
- if calc_var:
2807
- VS2 = self.backend.bk_concat(
2808
- [VS2, self.backend.bk_expand_dims(vs2, off_S2)],
2809
- axis=2,
2810
- )
2811
2860
 
2812
2861
  #### S1_auto computation
2813
2862
  ### Image 1 : S1 = < M1 >_pix
@@ -2825,45 +2874,47 @@ class funct(FOC.FoCUS):
2825
2874
  ) # [Nbatch, Nmask, Norient3]
2826
2875
 
2827
2876
  if return_data:
2828
- if S1 is None:
2829
- S1 = {}
2830
- if out_nside is not None and out_nside<nside_j3:
2877
+ if out_nside is not None and out_nside < nside_j3:
2831
2878
  s1 = self.backend.bk_reduce_mean(
2832
- self.backend.bk_reshape(s1,[s1.shape[0],
2833
- 12*out_nside**2,
2834
- (nside_j3//out_nside)**2,
2835
- s1.shape[2]]),2)
2879
+ self.backend.bk_reshape(
2880
+ s1,
2881
+ [
2882
+ s1.shape[0],
2883
+ 12 * out_nside**2,
2884
+ (nside_j3 // out_nside) ** 2,
2885
+ s1.shape[2],
2886
+ ],
2887
+ ),
2888
+ 2,
2889
+ )
2836
2890
  S1[j3] = s1
2837
2891
  else:
2838
2892
  ### Normalize S1
2839
2893
  if norm is not None:
2840
2894
  self.div_norm(s1, (P1_dic[j3]) ** 0.5)
2841
2895
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2842
- if S1 is None:
2843
- S1 = self.backend.bk_expand_dims(
2844
- s1, off_S2
2896
+ S1.append(
2897
+ self.backend.bk_expand_dims(s1, off_S2)
2898
+ ) # Add a dimension for NS1
2899
+ if calc_var:
2900
+ VS1.append(
2901
+ self.backend.bk_expand_dims(vs1, off_S2)
2845
2902
  ) # Add a dimension for NS1
2846
- if calc_var:
2847
- VS1 = self.backend.bk_expand_dims(
2848
- vs1, off_S2
2849
- ) # Add a dimension for NS1
2850
- else:
2851
- S1 = self.backend.bk_concat(
2852
- [S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
2853
- )
2854
- if calc_var:
2855
- VS1 = self.backend.bk_concat(
2856
- [VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
2857
- )
2858
2903
 
2859
2904
  else: # Cross
2860
2905
  ### Make the convolution I2 * Psi_j3
2861
2906
  conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
2862
2907
  if cmat is not None:
2863
- tmp2 = self.backend.bk_repeat(conv2, 4, axis=-1)
2908
+ tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2864
2909
  conv2 = self.backend.bk_reduce_sum(
2865
2910
  self.backend.bk_reshape(
2866
- cmat[j3] * tmp2, [tmp2.shape[0], cmat[j3].shape[0], 4, 4]
2911
+ cmat[j3] * tmp2,
2912
+ [
2913
+ tmp2.shape[0],
2914
+ cmat[j3].shape[0],
2915
+ self.NORIENT,
2916
+ self.NORIENT,
2917
+ ],
2867
2918
  ),
2868
2919
  2,
2869
2920
  )
@@ -2917,14 +2968,19 @@ class funct(FOC.FoCUS):
2917
2968
  s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
2918
2969
 
2919
2970
  if return_data:
2920
- if S2 is None:
2921
- S2 = {}
2922
- if out_nside is not None and out_nside<nside_j3:
2971
+ if out_nside is not None and out_nside < nside_j3:
2923
2972
  s2 = self.backend.bk_reduce_mean(
2924
- self.backend.bk_reshape(s2,[s2.shape[0],
2925
- 12*out_nside**2,
2926
- (nside_j3//out_nside)**2,
2927
- s2.shape[2]]),2)
2973
+ self.backend.bk_reshape(
2974
+ s2,
2975
+ [
2976
+ s2.shape[0],
2977
+ 12 * out_nside**2,
2978
+ (nside_j3 // out_nside) ** 2,
2979
+ s2.shape[2],
2980
+ ],
2981
+ ),
2982
+ 2,
2983
+ )
2928
2984
  S2[j3] = s2
2929
2985
  else:
2930
2986
  ### Normalize S2_cross
@@ -2932,26 +2988,15 @@ class funct(FOC.FoCUS):
2932
2988
  s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
2933
2989
 
2934
2990
  ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
2935
- if not all_cross:
2936
- s2 = self.backend.bk_real(s2)
2991
+ s2 = self.backend.bk_real(s2)
2937
2992
 
2938
- if S2 is None:
2939
- S2 = self.backend.bk_expand_dims(
2940
- s2, off_S2
2993
+ S2.append(
2994
+ self.backend.bk_expand_dims(s2, off_S2)
2995
+ ) # Add a dimension for NS2
2996
+ if calc_var:
2997
+ VS2.append(
2998
+ self.backend.bk_expand_dims(vs2, off_S2)
2941
2999
  ) # Add a dimension for NS2
2942
- if calc_var:
2943
- VS2 = self.backend.bk_expand_dims(
2944
- vs2, off_S2
2945
- ) # Add a dimension for NS2
2946
- else:
2947
- S2 = self.backend.bk_concat(
2948
- [S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
2949
- )
2950
- if calc_var:
2951
- VS2 = self.backend.bk_concat(
2952
- [VS2, self.backend.bk_expand_dims(vs2, off_S2)],
2953
- axis=2,
2954
- )
2955
3000
 
2956
3001
  #### S1_auto computation
2957
3002
  ### Image 1 : S1 = < M1 >_pix
@@ -2968,36 +3013,32 @@ class funct(FOC.FoCUS):
2968
3013
  MX, vmask, axis=1, rank=j3
2969
3014
  ) # [Nbatch, Nmask, Norient3]
2970
3015
  if return_data:
2971
- if S1 is None:
2972
- S1 = {}
2973
- if out_nside is not None and out_nside<nside_j3:
3016
+ if out_nside is not None and out_nside < nside_j3:
2974
3017
  s1 = self.backend.bk_reduce_mean(
2975
- self.backend.bk_reshape(s1,[s1.shape[0],
2976
- 12*out_nside**2,
2977
- (nside_j3//out_nside)**2,
2978
- s1.shape[2]]),2)
3018
+ self.backend.bk_reshape(
3019
+ s1,
3020
+ [
3021
+ s1.shape[0],
3022
+ 12 * out_nside**2,
3023
+ (nside_j3 // out_nside) ** 2,
3024
+ s1.shape[2],
3025
+ ],
3026
+ ),
3027
+ 2,
3028
+ )
2979
3029
  S1[j3] = s1
2980
3030
  else:
2981
3031
  ### Normalize S1
2982
3032
  if norm is not None:
2983
3033
  self.div_norm(s1, (P1_dic[j3]) ** 0.5)
2984
3034
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2985
- if S1 is None:
2986
- S1 = self.backend.bk_expand_dims(
2987
- s1, off_S2
3035
+ S1.append(
3036
+ self.backend.bk_expand_dims(s1, off_S2)
3037
+ ) # Add a dimension for NS1
3038
+ if calc_var:
3039
+ VS1.append(
3040
+ self.backend.bk_expand_dims(vs1, off_S2)
2988
3041
  ) # Add a dimension for NS1
2989
- if calc_var:
2990
- VS1 = self.backend.bk_expand_dims(
2991
- vs1, off_S2
2992
- ) # Add a dimension for NS1
2993
- else:
2994
- S1 = self.backend.bk_concat(
2995
- [S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
2996
- )
2997
- if calc_var:
2998
- VS1 = self.backend.bk_concat(
2999
- [VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
3000
- )
3001
3042
 
3002
3043
  # Initialize dictionaries for |I1*Psi_j| * Psi_j3
3003
3044
  M1convPsi_dic = {}
@@ -3006,7 +3047,7 @@ class funct(FOC.FoCUS):
3006
3047
  M2convPsi_dic = {}
3007
3048
 
3008
3049
  ###### S3
3009
- nside_j2=nside_j3
3050
+ nside_j2 = nside_j3
3010
3051
  for j2 in range(0, j3 + 1): # j2 <= j3
3011
3052
  if return_data:
3012
3053
  if S4[j3] is None:
@@ -3041,13 +3082,20 @@ class funct(FOC.FoCUS):
3041
3082
  if return_data:
3042
3083
  if S3[j3] is None:
3043
3084
  S3[j3] = {}
3044
- if out_nside is not None and out_nside<nside_j2:
3085
+ if out_nside is not None and out_nside < nside_j2:
3045
3086
  s3 = self.backend.bk_reduce_mean(
3046
- self.backend.bk_reshape(s3,[s3.shape[0],
3047
- 12*out_nside**2,
3048
- (nside_j2//out_nside)**2,
3049
- s3.shape[2],
3050
- s3.shape[3]]),2)
3087
+ self.backend.bk_reshape(
3088
+ s3,
3089
+ [
3090
+ s3.shape[0],
3091
+ 12 * out_nside**2,
3092
+ (nside_j2 // out_nside) ** 2,
3093
+ s3.shape[2],
3094
+ s3.shape[3],
3095
+ ],
3096
+ ),
3097
+ 2,
3098
+ )
3051
3099
  S3[j3][j2] = s3
3052
3100
  else:
3053
3101
  ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3062,23 +3110,18 @@ class funct(FOC.FoCUS):
3062
3110
  ) # [Nbatch, Nmask, Norient3, Norient2]
3063
3111
 
3064
3112
  ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3065
- if S3 is None:
3066
- S3 = self.backend.bk_expand_dims(
3067
- s3, off_S3
3068
- ) # Add a dimension for NS3
3069
- if calc_var:
3070
- VS3 = self.backend.bk_expand_dims(
3071
- vs3, off_S3
3072
- ) # Add a dimension for NS3
3073
- else:
3074
- S3 = self.backend.bk_concat(
3075
- [S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
3113
+
3114
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3115
+ # s3.shape[2]*s3.shape[3]]))
3116
+ S3.append(
3117
+ self.backend.bk_expand_dims(s3, off_S3)
3118
+ ) # Add a dimension for NS3
3119
+ if calc_var:
3120
+ VS3.append(
3121
+ self.backend.bk_expand_dims(vs3, off_S3)
3076
3122
  ) # Add a dimension for NS3
3077
- if calc_var:
3078
- VS3 = self.backend.bk_concat(
3079
- [VS3, self.backend.bk_expand_dims(vs3, off_S3)],
3080
- axis=2,
3081
- ) # Add a dimension for NS3
3123
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3124
+ # s3.shape[2]*s3.shape[3]]))
3082
3125
 
3083
3126
  ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
3084
3127
  ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
@@ -3130,19 +3173,33 @@ class funct(FOC.FoCUS):
3130
3173
  if S3[j3] is None:
3131
3174
  S3[j3] = {}
3132
3175
  S3P[j3] = {}
3133
- if out_nside is not None and out_nside<nside_j2:
3176
+ if out_nside is not None and out_nside < nside_j2:
3134
3177
  s3 = self.backend.bk_reduce_mean(
3135
- self.backend.bk_reshape(s3,[s3.shape[0],
3136
- 12*out_nside**2,
3137
- (nside_j2//out_nside)**2,
3138
- s3.shape[2],
3139
- s3.shape[3]]),2)
3178
+ self.backend.bk_reshape(
3179
+ s3,
3180
+ [
3181
+ s3.shape[0],
3182
+ 12 * out_nside**2,
3183
+ (nside_j2 // out_nside) ** 2,
3184
+ s3.shape[2],
3185
+ s3.shape[3],
3186
+ ],
3187
+ ),
3188
+ 2,
3189
+ )
3140
3190
  s3p = self.backend.bk_reduce_mean(
3141
- self.backend.bk_reshape(s3p,[s3.shape[0],
3142
- 12*out_nside**2,
3143
- (nside_j2//out_nside)**2,
3144
- s3.shape[2],
3145
- s3.shape[3]]),2)
3191
+ self.backend.bk_reshape(
3192
+ s3p,
3193
+ [
3194
+ s3.shape[0],
3195
+ 12 * out_nside**2,
3196
+ (nside_j2 // out_nside) ** 2,
3197
+ s3.shape[2],
3198
+ s3.shape[3],
3199
+ ],
3200
+ ),
3201
+ 2,
3202
+ )
3146
3203
  S3[j3][j2] = s3
3147
3204
  S3P[j3][j2] = s3p
3148
3205
  else:
@@ -3166,43 +3223,34 @@ class funct(FOC.FoCUS):
3166
3223
  ) # [Nbatch, Nmask, Norient3, Norient2]
3167
3224
 
3168
3225
  ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3169
- if S3 is None:
3170
- S3 = self.backend.bk_expand_dims(
3171
- s3, off_S3
3172
- ) # Add a dimension for NS3
3173
- if calc_var:
3174
- VS3 = self.backend.bk_expand_dims(
3175
- vs3, off_S3
3176
- ) # Add a dimension for NS3
3177
- else:
3178
- S3 = self.backend.bk_concat(
3179
- [S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
3180
- ) # Add a dimension for NS3
3181
- if calc_var:
3182
- VS3 = self.backend.bk_concat(
3183
- [VS3, self.backend.bk_expand_dims(vs3, off_S3)],
3184
- axis=2,
3185
- ) # Add a dimension for NS3
3186
- if S3P is None:
3187
- S3P = self.backend.bk_expand_dims(
3188
- s3p, off_S3
3226
+
3227
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3228
+ # s3.shape[2]*s3.shape[3]]))
3229
+ S3.append(
3230
+ self.backend.bk_expand_dims(s3, off_S3)
3231
+ ) # Add a dimension for NS3
3232
+ if calc_var:
3233
+ VS3.append(
3234
+ self.backend.bk_expand_dims(vs3, off_S3)
3189
3235
  ) # Add a dimension for NS3
3190
- if calc_var:
3191
- VS3P = self.backend.bk_expand_dims(
3192
- vs3p, off_S3
3193
- ) # Add a dimension for NS3
3194
- else:
3195
- S3P = self.backend.bk_concat(
3196
- [S3P, self.backend.bk_expand_dims(s3p, off_S3)], axis=2
3236
+
3237
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3238
+ # s3.shape[2]*s3.shape[3]]))
3239
+
3240
+ # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
3241
+ # s3.shape[2]*s3.shape[3]]))
3242
+ S3P.append(
3243
+ self.backend.bk_expand_dims(s3p, off_S3)
3244
+ ) # Add a dimension for NS3
3245
+ if calc_var:
3246
+ VS3P.append(
3247
+ self.backend.bk_expand_dims(vs3p, off_S3)
3197
3248
  ) # Add a dimension for NS3
3198
- if calc_var:
3199
- VS3P = self.backend.bk_concat(
3200
- [VS3P, self.backend.bk_expand_dims(vs3p, off_S3)],
3201
- axis=2,
3202
- ) # Add a dimension for NS3
3249
+ # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
3250
+ # s3.shape[2]*s3.shape[3]]))
3203
3251
 
3204
3252
  ##### S4
3205
- nside_j1=nside_j2
3253
+ nside_j1 = nside_j2
3206
3254
  for j1 in range(0, j2 + 1): # j1 <= j2
3207
3255
  ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
3208
3256
  if not cross:
@@ -3228,14 +3276,21 @@ class funct(FOC.FoCUS):
3228
3276
  if return_data:
3229
3277
  if S4[j3][j2] is None:
3230
3278
  S4[j3][j2] = {}
3231
- if out_nside is not None and out_nside<nside_j1:
3279
+ if out_nside is not None and out_nside < nside_j1:
3232
3280
  s4 = self.backend.bk_reduce_mean(
3233
- self.backend.bk_reshape(s4,[s4.shape[0],
3234
- 12*out_nside**2,
3235
- (nside_j1//out_nside)**2,
3236
- s4.shape[2],
3237
- s4.shape[3],
3238
- s4.shape[4]]),2)
3281
+ self.backend.bk_reshape(
3282
+ s4,
3283
+ [
3284
+ s4.shape[0],
3285
+ 12 * out_nside**2,
3286
+ (nside_j1 // out_nside) ** 2,
3287
+ s4.shape[2],
3288
+ s4.shape[3],
3289
+ s4.shape[4],
3290
+ ],
3291
+ ),
3292
+ 2,
3293
+ )
3239
3294
  S4[j3][j2][j1] = s4
3240
3295
  else:
3241
3296
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3259,27 +3314,18 @@ class funct(FOC.FoCUS):
3259
3314
  ** 0.5,
3260
3315
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3261
3316
  ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3262
- if S4 is None:
3263
- S4 = self.backend.bk_expand_dims(
3264
- s4, off_S4
3265
- ) # Add a dimension for NS4
3266
- if calc_var:
3267
- VS4 = self.backend.bk_expand_dims(
3268
- vs4, off_S4
3269
- ) # Add a dimension for NS4
3270
- else:
3271
- S4 = self.backend.bk_concat(
3272
- [S4, self.backend.bk_expand_dims(s4, off_S4)],
3273
- axis=2,
3317
+
3318
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3319
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3320
+ S4.append(
3321
+ self.backend.bk_expand_dims(s4, off_S4)
3322
+ ) # Add a dimension for NS4
3323
+ if calc_var:
3324
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3325
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3326
+ VS4.append(
3327
+ self.backend.bk_expand_dims(vs4, off_S4)
3274
3328
  ) # Add a dimension for NS4
3275
- if calc_var:
3276
- VS4 = self.backend.bk_concat(
3277
- [
3278
- VS4,
3279
- self.backend.bk_expand_dims(vs4, off_S4),
3280
- ],
3281
- axis=2,
3282
- ) # Add a dimension for NS4
3283
3329
 
3284
3330
  ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
3285
3331
  else:
@@ -3305,14 +3351,21 @@ class funct(FOC.FoCUS):
3305
3351
  if return_data:
3306
3352
  if S4[j3][j2] is None:
3307
3353
  S4[j3][j2] = {}
3308
- if out_nside is not None and out_nside<nside_j1:
3354
+ if out_nside is not None and out_nside < nside_j1:
3309
3355
  s4 = self.backend.bk_reduce_mean(
3310
- self.backend.bk_reshape(s4,[s4.shape[0],
3311
- 12*out_nside**2,
3312
- (nside_j1//out_nside)**2,
3313
- s4.shape[2],
3314
- s4.shape[3],
3315
- s4.shape[4]]),2)
3356
+ self.backend.bk_reshape(
3357
+ s4,
3358
+ [
3359
+ s4.shape[0],
3360
+ 12 * out_nside**2,
3361
+ (nside_j1 // out_nside) ** 2,
3362
+ s4.shape[2],
3363
+ s4.shape[3],
3364
+ s4.shape[4],
3365
+ ],
3366
+ ),
3367
+ 2,
3368
+ )
3316
3369
  S4[j3][j2][j1] = s4
3317
3370
  else:
3318
3371
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3336,41 +3389,33 @@ class funct(FOC.FoCUS):
3336
3389
  ** 0.5,
3337
3390
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3338
3391
  ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3339
- if S4 is None:
3340
- S4 = self.backend.bk_expand_dims(
3341
- s4, off_S4
3342
- ) # Add a dimension for NS4
3343
- if calc_var:
3344
- VS4 = self.backend.bk_expand_dims(
3345
- vs4, off_S4
3346
- ) # Add a dimension for NS4
3347
- else:
3348
- S4 = self.backend.bk_concat(
3349
- [S4, self.backend.bk_expand_dims(s4, off_S4)],
3350
- axis=2,
3392
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3393
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3394
+ S4.append(
3395
+ self.backend.bk_expand_dims(s4, off_S4)
3396
+ ) # Add a dimension for NS4
3397
+ if calc_var:
3398
+
3399
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3400
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3401
+ VS4.append(
3402
+ self.backend.bk_expand_dims(vs4, off_S4)
3351
3403
  ) # Add a dimension for NS4
3352
- if calc_var:
3353
- VS4 = self.backend.bk_concat(
3354
- [
3355
- VS4,
3356
- self.backend.bk_expand_dims(vs4, off_S4),
3357
- ],
3358
- axis=2,
3359
- ) # Add a dimension for NS4
3360
- nside_j1=nside_j1 // 2
3361
- nside_j2=nside_j2 // 2
3362
-
3404
+
3405
+ nside_j1 = nside_j1 // 2
3406
+ nside_j2 = nside_j2 // 2
3407
+
3363
3408
  ###### Reshape for next iteration on j3
3364
3409
  ### Image I1,
3365
3410
  # downscale the I1 [Nbatch, Npix_j3]
3366
3411
  if j3 != Jmax - 1:
3367
- I1_smooth = self.smooth(I1, axis=1)
3368
- I1 = self.ud_grade_2(I1_smooth, axis=1)
3412
+ I1 = self.smooth(I1, axis=1)
3413
+ I1 = self.ud_grade_2(I1, axis=1)
3369
3414
 
3370
3415
  ### Image I2
3371
3416
  if cross:
3372
- I2_smooth = self.smooth(I2, axis=1)
3373
- I2 = self.ud_grade_2(I2_smooth, axis=1)
3417
+ I2 = self.smooth(I2, axis=1)
3418
+ I2 = self.ud_grade_2(I2, axis=1)
3374
3419
 
3375
3420
  ### Modules
3376
3421
  for j2 in range(0, j3 + 1): # j2 =< j3
@@ -3388,8 +3433,9 @@ class funct(FOC.FoCUS):
3388
3433
  M2_dic[j2], axis=1
3389
3434
  ) # [Nbatch, Npix_j3, Norient3]
3390
3435
  M2_dic[j2] = self.ud_grade_2(
3391
- M2_smooth, axis=1
3436
+ M2, axis=1
3392
3437
  ) # [Nbatch, Npix_j3, Norient3]
3438
+
3393
3439
  ### Mask
3394
3440
  vmask = self.ud_grade_2(vmask, axis=1)
3395
3441
 
@@ -3404,7 +3450,33 @@ class funct(FOC.FoCUS):
3404
3450
  self.P1_dic = P1_dic
3405
3451
  if cross:
3406
3452
  self.P2_dic = P2_dic
3453
+ """
3454
+ Sout=[s0]+S1+S2+S3+S4
3407
3455
 
3456
+ if cross:
3457
+ Sout=Sout+S3P
3458
+ if calc_var:
3459
+ SVout=[vs0]+VS1+VS2+VS3+VS4
3460
+ if cross:
3461
+ VSout=VSout+VS3P
3462
+ return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
3463
+
3464
+ return self.backend.bk_concat(Sout, 2)
3465
+ """
3466
+ if not return_data:
3467
+ S1 = self.backend.bk_concat(S1, 2)
3468
+ S2 = self.backend.bk_concat(S2, 2)
3469
+ S3 = self.backend.bk_concat(S3, 2)
3470
+ S4 = self.backend.bk_concat(S4, 2)
3471
+ if cross:
3472
+ S3P = self.backend.bk_concat(S3P, 2)
3473
+ if calc_var:
3474
+ VS1 = self.backend.bk_concat(VS1, 2)
3475
+ VS2 = self.backend.bk_concat(VS2, 2)
3476
+ VS3 = self.backend.bk_concat(VS3, 2)
3477
+ VS4 = self.backend.bk_concat(VS4, 2)
3478
+ if cross:
3479
+ VS3P = self.backend.bk_concat(VS3P, 2)
3408
3480
  if calc_var:
3409
3481
  if not cross:
3410
3482
  return scat_cov(
@@ -3487,10 +3559,17 @@ class funct(FOC.FoCUS):
3487
3559
  M_dic[j2], axis=1
3488
3560
  ) # [Nbatch, Npix_j3, Norient3, Norient2]
3489
3561
  if cmat2 is not None:
3490
- tmp2 = self.backend.bk_repeat(MconvPsi, 4, axis=-1)
3562
+ tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
3491
3563
  MconvPsi = self.backend.bk_reduce_sum(
3492
3564
  self.backend.bk_reshape(
3493
- cmat2[j3][j2] * tmp2, [tmp2.shape[0], cmat2[j3].shape[1], 4, 4, 4]
3565
+ cmat2[j3][j2] * tmp2,
3566
+ [
3567
+ tmp2.shape[0],
3568
+ cmat2[j3].shape[1],
3569
+ self.NORIENT,
3570
+ self.NORIENT,
3571
+ self.NORIENT,
3572
+ ],
3494
3573
  ),
3495
3574
  3,
3496
3575
  )
@@ -3566,6 +3645,27 @@ class funct(FOC.FoCUS):
3566
3645
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3567
3646
  return s4
3568
3647
 
3648
+ def to_gaussian(self,x):
3649
+ from scipy.stats import norm
3650
+ from scipy.interpolate import interp1d
3651
+
3652
+ idx=np.argsort(x.flatten())
3653
+ p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
3654
+ im_target=x.flatten()
3655
+ im_target[idx] = norm.ppf(p)
3656
+
3657
+ # Interpolation cubique
3658
+ self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind='cubic')
3659
+ self.val_min=im_target[idx[0]]
3660
+ self.val_max=im_target[idx[-1]]
3661
+ return im_target.reshape(x.shape)
3662
+
3663
+
3664
+ def from_gaussian(self,x):
3665
+
3666
+ x=self.backend.bk_clip_by_value(x,self.val_min,self.val_max)
3667
+ return self.f_gaussian(self.backend.to_numpy(x))
3668
+
3569
3669
  def square(self, x):
3570
3670
  if isinstance(x, scat_cov):
3571
3671
  if x.S1 is None:
@@ -3615,42 +3715,47 @@ class funct(FOC.FoCUS):
3615
3715
  return self.backend.bk_abs(self.backend.bk_sqrt(x))
3616
3716
 
3617
3717
  def reduce_mean(self, x):
3618
-
3718
+
3619
3719
  if isinstance(x, scat_cov):
3620
- result = self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0)) + \
3621
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2)) + \
3622
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3)) + \
3623
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S4))
3624
-
3625
- N = self.backend.bk_size(x.S0)+self.backend.bk_size(x.S2)+ \
3626
- self.backend.bk_size(x.S3)+self.backend.bk_size(x.S4)
3627
-
3720
+ result = (
3721
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
3722
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
3723
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3))
3724
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S4))
3725
+ )
3726
+
3727
+ N = (
3728
+ self.backend.bk_size(x.S0)
3729
+ + self.backend.bk_size(x.S2)
3730
+ + self.backend.bk_size(x.S3)
3731
+ + self.backend.bk_size(x.S4)
3732
+ )
3733
+
3628
3734
  if x.S1 is not None:
3629
- result = result+self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
3735
+ result = result + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
3630
3736
  N = N + self.backend.bk_size(x.S1)
3631
3737
  if x.S3P is not None:
3632
- result = result+self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3P))
3738
+ result = result + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3P))
3633
3739
  N = N + self.backend.bk_size(x.S3P)
3634
- return result/self.backend.bk_cast(N)
3740
+ return result / self.backend.bk_cast(N)
3635
3741
  else:
3636
3742
  return self.backend.bk_reduce_mean(x, axis=0)
3637
-
3638
3743
 
3639
3744
  def reduce_mean_batch(self, x):
3640
-
3745
+
3641
3746
  if isinstance(x, scat_cov):
3642
-
3643
- sS0=self.backend.bk_reduce_mean(x.S0, axis=0)
3644
- sS2=self.backend.bk_reduce_mean(x.S2, axis=0)
3645
- sS3=self.backend.bk_reduce_mean(x.S3, axis=0)
3646
- sS4=self.backend.bk_reduce_mean(x.S4, axis=0)
3647
- sS1=None
3648
- sS3P=None
3747
+
3748
+ sS0 = self.backend.bk_reduce_mean(x.S0, axis=0)
3749
+ sS2 = self.backend.bk_reduce_mean(x.S2, axis=0)
3750
+ sS3 = self.backend.bk_reduce_mean(x.S3, axis=0)
3751
+ sS4 = self.backend.bk_reduce_mean(x.S4, axis=0)
3752
+ sS1 = None
3753
+ sS3P = None
3649
3754
  if x.S1 is not None:
3650
3755
  sS1 = self.backend.bk_reduce_mean(x.S1, axis=0)
3651
3756
  if x.S3P is not None:
3652
3757
  sS3P = self.backend.bk_reduce_mean(x.S3P, axis=0)
3653
-
3758
+
3654
3759
  result = scat_cov(
3655
3760
  sS0,
3656
3761
  sS2,
@@ -3664,22 +3769,22 @@ class funct(FOC.FoCUS):
3664
3769
  return result
3665
3770
  else:
3666
3771
  return self.backend.bk_reduce_mean(x, axis=0)
3667
-
3772
+
3668
3773
  def reduce_sum_batch(self, x):
3669
-
3774
+
3670
3775
  if isinstance(x, scat_cov):
3671
-
3672
- sS0=self.backend.bk_reduce_sum(x.S0, axis=0)
3673
- sS2=self.backend.bk_reduce_sum(x.S2, axis=0)
3674
- sS3=self.backend.bk_reduce_sum(x.S3, axis=0)
3675
- sS4=self.backend.bk_reduce_sum(x.S4, axis=0)
3676
- sS1=None
3677
- sS3P=None
3776
+
3777
+ sS0 = self.backend.bk_reduce_sum(x.S0, axis=0)
3778
+ sS2 = self.backend.bk_reduce_sum(x.S2, axis=0)
3779
+ sS3 = self.backend.bk_reduce_sum(x.S3, axis=0)
3780
+ sS4 = self.backend.bk_reduce_sum(x.S4, axis=0)
3781
+ sS1 = None
3782
+ sS3P = None
3678
3783
  if x.S1 is not None:
3679
3784
  sS1 = self.backend.bk_reduce_sum(x.S1, axis=0)
3680
3785
  if x.S3P is not None:
3681
3786
  sS3P = self.backend.bk_reduce_sum(x.S3P, axis=0)
3682
-
3787
+
3683
3788
  result = scat_cov(
3684
3789
  sS0,
3685
3790
  sS2,
@@ -3693,7 +3798,7 @@ class funct(FOC.FoCUS):
3693
3798
  return result
3694
3799
  else:
3695
3800
  return self.backend.bk_reduce_mean(x, axis=0)
3696
-
3801
+
3697
3802
  def reduce_distance(self, x, y, sigma=None):
3698
3803
 
3699
3804
  if isinstance(x, scat_cov):
@@ -3729,11 +3834,13 @@ class funct(FOC.FoCUS):
3729
3834
  return result
3730
3835
  else:
3731
3836
  if sigma is None:
3732
- tmp=x-y
3837
+ tmp = x - y
3733
3838
  else:
3734
- tmp=(x-y)/sigma
3839
+ tmp = (x - y) / sigma
3735
3840
  # do abs in case of complex values
3736
- return self.backend.bk_abs(self.backend.bk_reduce_mean(self.backend.bk_square(tmp)))
3841
+ return self.backend.bk_abs(
3842
+ self.backend.bk_reduce_mean(self.backend.bk_square(tmp))
3843
+ )
3737
3844
 
3738
3845
  def reduce_sum(self, x):
3739
3846
 
@@ -3909,3 +4016,122 @@ class funct(FOC.FoCUS):
3909
4016
  return scat_cov(
3910
4017
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
3911
4018
  )
4019
+
4020
+ def synthesis(self,
4021
+ image_target,
4022
+ nstep=4,
4023
+ seed=1234,
4024
+ edge=True,
4025
+ to_gaussian=True,
4026
+ EVAL_FREQUENCY=100,
4027
+ NUM_EPOCHS = 300):
4028
+
4029
+ import foscat.Synthesis as synthe
4030
+ import time
4031
+
4032
+ def The_loss(u,scat_operator,args):
4033
+ ref = args[0]
4034
+ sref = args[1]
4035
+
4036
+ # compute scattering covariance of the current synthetised map called u
4037
+ learn=scat_operator.reduce_mean_batch(scat_operator.eval(u,edge=edge))
4038
+
4039
+ # make the difference withe the reference coordinates
4040
+ loss=scat_operator.reduce_distance(learn,ref,sigma=sref)
4041
+
4042
+ return loss
4043
+
4044
+ if to_gaussian:
4045
+ # Change the data histogram to gaussian distribution
4046
+ im_target=self.to_gaussian(image_target)
4047
+ else:
4048
+ im_target=image_target
4049
+
4050
+ axis=len(im_target.shape)-1
4051
+ if self.use_2D:
4052
+ axis-=1
4053
+ if axis==0:
4054
+ im_target=self.backend.bk_expand_dims(im_target,0)
4055
+
4056
+ # compute the number of possible steps
4057
+ if self.use_2D:
4058
+ jmax=int(np.min([np.log(im_target.shape[1]),np.log(im_target.shape[2])])/np.log(2))
4059
+ elif self.use_1D:
4060
+ jmax=int(np.log(im_target.shape[1])/np.log(2))
4061
+ else:
4062
+ jmax=int((np.log(im_target.shape[1]//12)/np.log(2))/2)
4063
+ nside=2**jmax
4064
+
4065
+ if nstep>jmax-1:
4066
+ nstep=jmax-1
4067
+
4068
+ t1=time.time()
4069
+ tmp={}
4070
+ tmp[nstep-1]=im_target
4071
+ for l in range(nstep-2,-1,-1):
4072
+ tmp[l]=self.ud_grade_2(tmp[l+1],axis=1)
4073
+
4074
+ if not self.use_2D and not self.use_1D:
4075
+ l_nside=nside//(2**(nstep-1))
4076
+
4077
+ for k in range(nstep):
4078
+ if k==0:
4079
+ np.random.seed(seed)
4080
+ if self.use_2D:
4081
+ imap=np.random.randn(tmp[k].shape[0],
4082
+ tmp[k].shape[1],
4083
+ tmp[k].shape[2])
4084
+ else:
4085
+ imap=np.random.randn(tmp[k].shape[0],
4086
+ tmp[k].shape[1])
4087
+ else:
4088
+ axis=1
4089
+ # if the kernel size is bigger than 3 increase the binning before smoothing
4090
+ if self.use_2D:
4091
+ imap = self.up_grade(
4092
+ omap, imap.shape[axis] * 2, axis=1, nouty=imap.shape[axis + 1] * 2
4093
+ )
4094
+ elif self.use_1D:
4095
+ imap = self.up_grade(omap, imap.shape[axis] * 2, axis=1)
4096
+ else:
4097
+ imap = self.up_grade(omap, l_nside, axis=axis)
4098
+
4099
+ # compute the coefficients for the target image
4100
+ ref,sref=self.eval(tmp[k],calc_var=True,edge=edge)
4101
+
4102
+ # compute the mean of the population does nothing if only one map is given
4103
+ ref=self.reduce_mean_batch(ref)
4104
+ sref=self.reduce_mean_batch(sref)
4105
+
4106
+ # define a loss to minimize
4107
+ loss=synthe.Loss(The_loss,self,ref,sref)
4108
+
4109
+ sy = synthe.Synthesis([loss])
4110
+
4111
+ # initialize the synthesised map
4112
+ if self.use_2D:
4113
+ print('Synthesis scale [ %d x %d ]'%(imap.shape[1],imap.shape[2]))
4114
+ elif self.use_1D:
4115
+ print('Synthesis scale [ %d ]'%(imap.shape[1]))
4116
+ else:
4117
+ print('Synthesis scale nside=%d'%(l_nside))
4118
+ l_nside*=2
4119
+
4120
+ # do the minimization
4121
+ omap=sy.run(imap,
4122
+ EVAL_FREQUENCY=EVAL_FREQUENCY,
4123
+ NUM_EPOCHS = NUM_EPOCHS)
4124
+
4125
+ t2=time.time()
4126
+ print('Total computation %.2fs'%(t2-t1))
4127
+
4128
+ if to_gaussian:
4129
+ omap=self.from_gaussian(omap)
4130
+
4131
+ if axis==0:
4132
+ return omap[0]
4133
+ else:
4134
+ return omap
4135
+
4136
+ def to_numpy(self,x):
4137
+ return self.backend.to_numpy(x)