foscat 3.6.0__py3-none-any.whl → 3.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -33,9 +33,7 @@ testwarn = 0
33
33
 
34
34
 
35
35
  class scat_cov:
36
- def __init__(
37
- self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False
38
- ):
36
+ def __init__(self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False):
39
37
  self.S0 = s0
40
38
  self.S2 = s2
41
39
  self.S3 = s3
@@ -54,17 +52,17 @@ class scat_cov:
54
52
  if self.S1 is None:
55
53
  s1 = None
56
54
  else:
57
- s1 = self.S1.numpy()
55
+ s1 = self.backend.to_numpy(self.S1)
58
56
  if self.S3P is None:
59
57
  s3p = None
60
58
  else:
61
- s3p = self.S3P.numpy()
59
+ s3p = self.backend.to_numpy(self.S3P)
62
60
 
63
61
  return scat_cov(
64
- (self.S0.numpy()),
65
- (self.S2.numpy()),
66
- (self.S3.numpy()),
67
- (self.S4.numpy()),
62
+ self.backend.to_numpy(self.S0),
63
+ self.backend.to_numpy(self.S2),
64
+ self.backend.to_numpy(self.S3),
65
+ self.backend.to_numpy(self.S4),
68
66
  s1=s1,
69
67
  s3p=s3p,
70
68
  backend=self.backend,
@@ -94,7 +92,7 @@ class scat_cov:
94
92
  )
95
93
 
96
94
  def conv2complex(self, val):
97
- if val.dtype == "complex64" or val.dtype == "complex128":
95
+ if val.dtype == "complex64" in val.dtype or "complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
98
96
  return val
99
97
  else:
100
98
  return self.backend.bk_complex(val, 0 * val)
@@ -1531,7 +1529,7 @@ class scat_cov:
1531
1529
  if isinstance(x, np.ndarray):
1532
1530
  return x
1533
1531
  else:
1534
- return x.numpy()
1532
+ return self.backend.to_numpy(x)
1535
1533
  else:
1536
1534
  return None
1537
1535
 
@@ -1974,7 +1972,7 @@ class scat_cov:
1974
1972
  if self.BACKEND == "numpy":
1975
1973
  s2[:, :, noff:, :] = self.S2
1976
1974
  else:
1977
- s2[:, :, noff:, :] = self.S2.numpy()
1975
+ s2[:, :, noff:, :] = self.backend.to_numpy(self.S2)
1978
1976
  for i in range(self.S2.shape[0]):
1979
1977
  for j in range(self.S2.shape[1]):
1980
1978
  for k in range(self.S2.shape[3]):
@@ -1986,7 +1984,7 @@ class scat_cov:
1986
1984
  if self.BACKEND == "numpy":
1987
1985
  s1[:, :, noff:, :] = self.S1
1988
1986
  else:
1989
- s1[:, :, noff:, :] = self.S1.numpy()
1987
+ s1[:, :, noff:, :] = self.backend.to_numpy(self.S1)
1990
1988
  for i in range(self.S1.shape[0]):
1991
1989
  for j in range(self.S1.shape[1]):
1992
1990
  for k in range(self.S1.shape[3]):
@@ -2045,12 +2043,12 @@ class scat_cov:
2045
2043
  )
2046
2044
  )
2047
2045
  else:
2048
- s3[i, j, idx[noff:], k, l_orient] = self.S3.numpy()[
2046
+ s3[i, j, idx[noff:], k, l_orient] = self.backend.to_numpy(self.S3)[
2049
2047
  i, j, j2 == ij - noff, k, l_orient
2050
2048
  ]
2051
2049
  s3[i, j, idx[:noff], k, l_orient] = (
2052
2050
  self.add_data_from_slope(
2053
- self.S3.numpy()[
2051
+ self.backend.to_numpy(self.S3)[
2054
2052
  i, j, j2 == ij - noff, k, l_orient
2055
2053
  ],
2056
2054
  noff,
@@ -2358,13 +2356,13 @@ class funct(FOC.FoCUS):
2358
2356
  tmp = self.up_grade(tmp, l_nside * 2, axis=1)
2359
2357
  if image2 is not None:
2360
2358
  tmpi2 = self.up_grade(tmpi2, l_nside * 2, axis=1)
2361
-
2362
2359
  l_nside = int(np.sqrt(tmp.shape[1] // 12))
2363
2360
  nscale = int(np.log(l_nside) / np.log(2))
2364
2361
  cmat = {}
2365
2362
  cmat2 = {}
2363
+
2364
+ # Loop over scales
2366
2365
  for k in range(nscale):
2367
- sim = self.backend.bk_abs(self.convol(tmp, axis=1))
2368
2366
  if image2 is not None:
2369
2367
  sim = self.backend.bk_real(
2370
2368
  self.backend.bk_L1(
@@ -2375,17 +2373,31 @@ class funct(FOC.FoCUS):
2375
2373
  else:
2376
2374
  sim = self.backend.bk_abs(self.convol(tmp, axis=1))
2377
2375
 
2378
- cc = self.backend.bk_reduce_mean(sim[:, :, 0] - sim[:, :, 2], 0)
2379
- ss = self.backend.bk_reduce_mean(sim[:, :, 1] - sim[:, :, 3], 0)
2380
- for m in range(smooth_scale):
2381
- if cc.shape[0] > 12:
2382
- cc = self.ud_grade_2(self.smooth(cc))
2383
- ss = self.ud_grade_2(self.smooth(ss))
2376
+ # instead of difference between "opposite" channels use weighted average
2377
+ # of cosine and sine contributions using all channels
2378
+ angles = (
2379
+ 2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
2380
+ ) # shape: (NORIENT,)
2381
+ angles = angles.reshape(1, 1, self.NORIENT)
2382
+
2383
+ # we use cosines and sines as weights for sim
2384
+ weighted_cos = self.backend.bk_reduce_mean(sim * np.cos(angles), axis=-1)
2385
+ weighted_sin = self.backend.bk_reduce_mean(sim * np.sin(angles), axis=-1)
2386
+ # For simplicity, take first element of the batch
2387
+ cc = weighted_cos[0]
2388
+ ss = weighted_sin[0]
2389
+
2390
+ if smooth_scale > 0:
2391
+ for m in range(smooth_scale):
2392
+ if cc.shape[0] > 12:
2393
+ cc = self.ud_grade_2(self.smooth(cc))
2394
+ ss = self.ud_grade_2(self.smooth(ss))
2384
2395
  if cc.shape[0] != tmp.shape[0]:
2385
2396
  ll_nside = int(np.sqrt(tmp.shape[1] // 12))
2386
2397
  cc = self.up_grade(cc, ll_nside)
2387
2398
  ss = self.up_grade(ss, ll_nside)
2388
2399
 
2400
+ # compute local phase from weighted cos and sin (same as before)
2389
2401
  if self.BACKEND == "numpy":
2390
2402
  phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
2391
2403
  else:
@@ -2393,87 +2405,87 @@ class funct(FOC.FoCUS):
2393
2405
  np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
2394
2406
  )
2395
2407
 
2396
- iph = (4 * phase / (2 * np.pi)).astype("int")
2397
- alpha = 4 * phase / (2 * np.pi) - iph
2398
- mat = np.zeros([sim.shape[1], 4 * 4])
2408
+ # instead of linear interpolation cosine‐based interpolation
2409
+ phase_scaled = self.NORIENT * phase / (2 * np.pi)
2410
+ iph = np.floor(phase_scaled).astype("int") # lower bin index
2411
+ delta = phase_scaled - iph # fractional part in [0,1)
2412
+ # interpolation weights
2413
+ w0 = np.cos(delta * np.pi / 2) ** 2
2414
+ w1 = np.sin(delta * np.pi / 2) ** 2
2415
+
2416
+ # build rotation matrix
2417
+ mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
2399
2418
  lidx = np.arange(sim.shape[1])
2400
- for l_orient in range(4):
2401
- mat[lidx, 4 * ((l_orient + iph) % 4) + l_orient] = 1.0 - alpha
2402
- mat[lidx, 4 * ((l_orient + iph + 1) % 4) + l_orient] = alpha
2419
+ for l in range(self.NORIENT):
2420
+ # Instead of simple linear weights, we use the cosine weights w0 and w1.
2421
+ col0 = self.NORIENT * ((l + iph) % self.NORIENT) + l
2422
+ col1 = self.NORIENT * ((l + iph + 1) % self.NORIENT) + l
2423
+ mat[lidx, col0] = w0
2424
+ mat[lidx, col1] = w1
2403
2425
 
2404
2426
  cmat[k] = self.backend.bk_cast(mat.astype("complex64"))
2405
2427
 
2406
- mat2 = np.zeros([k + 1, sim.shape[1], 4, 4 * 4])
2428
+ # do same modifications for mat2
2429
+ mat2 = np.zeros(
2430
+ [k + 1, sim.shape[1], self.NORIENT, self.NORIENT * self.NORIENT]
2431
+ )
2407
2432
 
2408
2433
  for k2 in range(k + 1):
2409
- tmp2 = self.backend.bk_repeat(sim, 4, axis=-1)
2434
+ tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=-1)
2435
+
2410
2436
  sim2 = self.backend.bk_reduce_sum(
2411
2437
  self.backend.bk_reshape(
2412
- mat.reshape(1, mat.shape[0], 16) * tmp2,
2413
- [sim.shape[0], cmat[k].shape[0], 4, 4],
2438
+ mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
2439
+ * tmp2,
2440
+ [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
2414
2441
  ),
2415
2442
  2,
2416
2443
  )
2444
+
2417
2445
  sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
2418
2446
 
2419
- cc = self.smooth(
2420
- self.backend.bk_reduce_mean(sim2[:, :, 0] - sim2[:, :, 2], 0)
2447
+ weighted_cos2 = self.backend.bk_reduce_mean(
2448
+ sim2 * np.cos(angles), axis=-1
2421
2449
  )
2422
- ss = self.smooth(
2423
- self.backend.bk_reduce_mean(sim2[:, :, 1] - sim2[:, :, 3], 0)
2450
+ weighted_sin2 = self.backend.bk_reduce_mean(
2451
+ sim2 * np.sin(angles), axis=-1
2424
2452
  )
2425
- for m in range(smooth_scale):
2426
- if cc.shape[0] > 12:
2427
- cc = self.ud_grade_2(self.smooth(cc))
2428
- ss = self.ud_grade_2(self.smooth(ss))
2429
- if cc.shape[0] != sim.shape[1]:
2453
+
2454
+ cc2 = weighted_cos2[0]
2455
+ ss2 = weighted_sin2[0]
2456
+
2457
+ if smooth_scale > 0:
2458
+ for m in range(smooth_scale):
2459
+ if cc2.shape[0] > 12:
2460
+ cc2 = self.ud_grade_2(self.smooth(cc2))
2461
+ ss2 = self.ud_grade_2(self.smooth(ss2))
2462
+
2463
+ if cc2.shape[0] != sim.shape[1]:
2430
2464
  ll_nside = int(np.sqrt(sim.shape[1] // 12))
2431
- cc = self.up_grade(cc, ll_nside)
2432
- ss = self.up_grade(ss, ll_nside)
2465
+ cc2 = self.up_grade(cc2, ll_nside)
2466
+ ss2 = self.up_grade(ss2, ll_nside)
2433
2467
 
2434
2468
  if self.BACKEND == "numpy":
2435
- phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
2469
+ phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
2436
2470
  else:
2437
- phase = np.fmod(
2438
- np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
2471
+ phase2 = np.fmod(
2472
+ np.arctan2(ss2.numpy(), cc2.numpy()) + 2 * np.pi, 2 * np.pi
2439
2473
  )
2440
- """
2441
- for k in range(4):
2442
- hp.mollview(np.fmod(phase+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,1+k))
2443
- plt.show()
2444
- return None
2445
- """
2446
- iph = (4 * phase / (2 * np.pi)).astype("int")
2447
- alpha = 4 * phase / (2 * np.pi) - iph
2474
+
2475
+ phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
2476
+ iph2 = np.floor(phase2_scaled).astype("int")
2477
+ delta2 = phase2_scaled - iph2
2478
+ w0_2 = np.cos(delta2 * np.pi / 2) ** 2
2479
+ w1_2 = np.sin(delta2 * np.pi / 2) ** 2
2448
2480
  lidx = np.arange(sim.shape[1])
2449
- for m in range(4):
2450
- for l_orient in range(4):
2451
- mat2[
2452
- k2, lidx, m, 4 * ((l_orient + iph[:, m]) % 4) + l_orient
2453
- ] = (1.0 - alpha[:, m])
2454
- mat2[
2455
- k2, lidx, m, 4 * ((l_orient + iph[:, m] + 1) % 4) + l_orient
2456
- ] = alpha[:, m]
2457
-
2458
- cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2459
- """
2460
- tmp=self.backend.bk_repeat(sim[0],4,axis=1)
2461
- sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat*tmp,[12*nside**2,4,4]),1)
2462
-
2463
- cc2=(sim2[:,0]-sim2[:,2])
2464
- ss2=(sim2[:,1]-sim2[:,3])
2465
- phase2=np.fmod(np.arctan2(ss2.numpy(),cc2.numpy())+2*np.pi,2*np.pi)
2466
-
2467
- plt.figure()
2468
- hp.mollview(phase,cmap='jet',nest=True,hold=False,sub=(2,2,1))
2469
- hp.mollview(np.fmod(phase2+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,2))
2470
- plt.figure()
2471
- for k in range(4):
2472
- hp.mollview((sim[0,:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,1+k),min=-10,max=10)
2473
- hp.mollview((sim2[:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,5+k),min=-10,max=10)
2474
-
2475
- plt.show()
2476
- """
2481
+
2482
+ for m in range(self.NORIENT):
2483
+ for l in range(self.NORIENT):
2484
+ col0 = self.NORIENT * ((l + iph2[:, m]) % self.NORIENT) + l
2485
+ col1 = self.NORIENT * ((l + iph2[:, m] + 1) % self.NORIENT) + l
2486
+ mat2[k2, lidx, m, col0] = w0_2[:, m]
2487
+ mat2[k2, lidx, m, col1] = w1_2[:, m]
2488
+ cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2477
2489
 
2478
2490
  if k < l_nside - 1:
2479
2491
  tmp = self.ud_grade_2(tmp, axis=1)
@@ -2488,15 +2500,17 @@ class funct(FOC.FoCUS):
2488
2500
  )
2489
2501
 
2490
2502
  def eval(
2491
- self,
2492
- image1,
2493
- image2=None,
2494
- mask=None,
2495
- norm=None,
2496
- Auto=True,
2497
- calc_var=False,
2498
- cmat=None,
2499
- cmat2=None,
2503
+ self,
2504
+ image1,
2505
+ image2=None,
2506
+ mask=None,
2507
+ norm=None,
2508
+ calc_var=False,
2509
+ cmat=None,
2510
+ cmat2=None,
2511
+ Jmax=None,
2512
+ out_nside=None,
2513
+ edge=True
2500
2514
  ):
2501
2515
  """
2502
2516
  Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2522,9 +2536,6 @@ class funct(FOC.FoCUS):
2522
2536
  norm: None or str
2523
2537
  If None no normalization is applied, if 'auto' normalize by the reference S2,
2524
2538
  if 'self' normalize by the current S2.
2525
- all_cross: False or True
2526
- If False compute all the coefficient even the Imaginary part,
2527
- If True return only the terms computable in the auto case.
2528
2539
  Returns
2529
2540
  -------
2530
2541
  S1, S2, S3, S4 normalized
@@ -2538,7 +2549,7 @@ class funct(FOC.FoCUS):
2538
2549
  )
2539
2550
  return None
2540
2551
  if mask is not None:
2541
- if list(image1.shape) != list(mask.shape)[1:]:
2552
+ if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
2542
2553
  print(
2543
2554
  "The LAST COLUMN of the mask should have the same size ",
2544
2555
  mask.shape,
@@ -2557,9 +2568,6 @@ class funct(FOC.FoCUS):
2557
2568
  cross = False
2558
2569
  if image2 is not None:
2559
2570
  cross = True
2560
- all_cross = Auto
2561
- else:
2562
- all_cross = False
2563
2571
 
2564
2572
  ### PARAMETERS
2565
2573
  axis = 1
@@ -2595,8 +2603,16 @@ class funct(FOC.FoCUS):
2595
2603
  nside = int(np.sqrt(npix // 12))
2596
2604
 
2597
2605
  J = int(np.log(nside) / np.log(2)) # Number of j scales
2598
-
2599
- Jmax = J - self.OSTEP # Number of steps for the loop on scales
2606
+
2607
+ if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
2608
+ J-=1
2609
+ if Jmax is None:
2610
+ Jmax = J # Number of steps for the loop on scales
2611
+ if Jmax>J:
2612
+ print('==========\n\n')
2613
+ print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
2614
+ print('\n\n==========')
2615
+
2600
2616
 
2601
2617
  ### LOCAL VARIABLES (IMAGES and MASK)
2602
2618
  if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
@@ -2620,7 +2636,7 @@ class funct(FOC.FoCUS):
2620
2636
  else:
2621
2637
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
2622
2638
 
2623
- if self.KERNELSZ > 3:
2639
+ if self.KERNELSZ > 3 and not self.use_2D:
2624
2640
  # if the kernel size is bigger than 3 increase the binning before smoothing
2625
2641
  if self.use_2D:
2626
2642
  vmask = self.up_grade(
@@ -2644,7 +2660,7 @@ class funct(FOC.FoCUS):
2644
2660
  if cross:
2645
2661
  I2 = self.up_grade(I2, nside * 2, axis=axis)
2646
2662
 
2647
- if self.KERNELSZ > 5:
2663
+ if self.KERNELSZ > 5 and not self.use_2D:
2648
2664
  # if the kernel size is bigger than 3 increase the binning before smoothing
2649
2665
  if self.use_2D:
2650
2666
  vmask = self.up_grade(
@@ -2676,7 +2692,23 @@ class funct(FOC.FoCUS):
2676
2692
 
2677
2693
  ### INITIALIZATION
2678
2694
  # Coefficients
2679
- S1, S2, S3, S4, S3P = None, None, None, None, None
2695
+ if return_data:
2696
+ S1 = {}
2697
+ S2 = {}
2698
+ S3 = {}
2699
+ S3P = {}
2700
+ S4 = {}
2701
+ else:
2702
+ S1 = []
2703
+ S2 = []
2704
+ S3 = []
2705
+ S4 = []
2706
+ S3P = []
2707
+ VS1 = []
2708
+ VS2 = []
2709
+ VS3 = []
2710
+ VS3P = []
2711
+ VS4 = []
2680
2712
 
2681
2713
  off_S2 = -2
2682
2714
  off_S3 = -3
@@ -2686,11 +2718,6 @@ class funct(FOC.FoCUS):
2686
2718
  off_S3 = -1
2687
2719
  off_S4 = -1
2688
2720
 
2689
- # Dictionaries for S3 computation
2690
- M1_dic = {} # M stands for Module M1 = |I1 * Psi|
2691
- if cross:
2692
- M2_dic = {}
2693
-
2694
2721
  # S2 for normalization
2695
2722
  cond_init_P1_dic = (norm == "self") or (
2696
2723
  (norm == "auto") and (self.P1_dic is None)
@@ -2708,6 +2735,13 @@ class funct(FOC.FoCUS):
2708
2735
 
2709
2736
  if return_data:
2710
2737
  s0 = I1
2738
+ if out_nside is not None:
2739
+ s0 = self.backend.bk_reduce_mean(
2740
+ self.backend.bk_reshape(
2741
+ s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
2742
+ ),
2743
+ 2,
2744
+ )
2711
2745
  else:
2712
2746
  if not cross:
2713
2747
  s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
@@ -2717,17 +2751,37 @@ class funct(FOC.FoCUS):
2717
2751
  )
2718
2752
  vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2719
2753
  s0 = self.backend.bk_concat([s0, l_vs0], 1)
2720
-
2721
2754
  #### COMPUTE S1, S2, S3 and S4
2722
2755
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2756
+
2757
+ # a remettre comme avant
2758
+ M1_dic={}
2759
+
2723
2760
  for j3 in range(Jmax):
2761
+
2762
+ if edge:
2763
+ if self.mask_mask is None:
2764
+ self.mask_mask={}
2765
+ if self.use_2D:
2766
+ if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
2767
+ mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
2768
+ mask_mask[0,
2769
+ self.KERNELSZ//2:-self.KERNELSZ//2+1,
2770
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2771
+ self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
2772
+ vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
2773
+ #print(self.KERNELSZ//2,vmask,mask_mask)
2774
+
2775
+ if self.use_1D:
2776
+ if (vmask.shape[1]) not in self.mask_mask:
2777
+ mask_mask=np.zeros([1,vmask.shape[1]])
2778
+ mask_mask[0,
2779
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2780
+ self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
2781
+ vmask=vmask*self.mask_mask[(vmask.shape[1])]
2782
+
2724
2783
  if return_data:
2725
- if S3 is None:
2726
- S3 = {}
2727
2784
  S3[j3] = None
2728
-
2729
- if S3P is None:
2730
- S3P = {}
2731
2785
  S3P[j3] = None
2732
2786
 
2733
2787
  if S4 is None:
@@ -2737,12 +2791,12 @@ class funct(FOC.FoCUS):
2737
2791
  ####### S1 and S2
2738
2792
  ### Make the convolution I1 * Psi_j3
2739
2793
  conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
2740
-
2741
2794
  if cmat is not None:
2742
- tmp2 = self.backend.bk_repeat(conv1, 4, axis=-1)
2795
+ tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2743
2796
  conv1 = self.backend.bk_reduce_sum(
2744
2797
  self.backend.bk_reshape(
2745
- cmat[j3] * tmp2, [tmp2.shape[0], cmat[j3].shape[0], 4, 4]
2798
+ cmat[j3] * tmp2,
2799
+ [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
2746
2800
  ),
2747
2801
  2,
2748
2802
  )
@@ -2778,27 +2832,31 @@ class funct(FOC.FoCUS):
2778
2832
  if return_data:
2779
2833
  if S2 is None:
2780
2834
  S2 = {}
2835
+ if out_nside is not None and out_nside < nside_j3:
2836
+ s2 = self.backend.bk_reduce_mean(
2837
+ self.backend.bk_reshape(
2838
+ s2,
2839
+ [
2840
+ s2.shape[0],
2841
+ 12 * out_nside**2,
2842
+ (nside_j3 // out_nside) ** 2,
2843
+ s2.shape[2],
2844
+ ],
2845
+ ),
2846
+ 2,
2847
+ )
2781
2848
  S2[j3] = s2
2782
2849
  else:
2783
2850
  if norm == "auto": # Normalize S2
2784
2851
  s2 /= P1_dic[j3]
2785
- if S2 is None:
2786
- S2 = self.backend.bk_expand_dims(
2787
- s2, off_S2
2852
+
2853
+ S2.append(
2854
+ self.backend.bk_expand_dims(s2, off_S2)
2855
+ ) # Add a dimension for NS2
2856
+ if calc_var:
2857
+ VS2.append(
2858
+ self.backend.bk_expand_dims(vs2, off_S2)
2788
2859
  ) # Add a dimension for NS2
2789
- if calc_var:
2790
- VS2 = self.backend.bk_expand_dims(
2791
- vs2, off_S2
2792
- ) # Add a dimension for NS2
2793
- else:
2794
- S2 = self.backend.bk_concat(
2795
- [S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
2796
- )
2797
- if calc_var:
2798
- VS2 = self.backend.bk_concat(
2799
- [VS2, self.backend.bk_expand_dims(vs2, off_S2)],
2800
- axis=2,
2801
- )
2802
2860
 
2803
2861
  #### S1_auto computation
2804
2862
  ### Image 1 : S1 = < M1 >_pix
@@ -2816,39 +2874,47 @@ class funct(FOC.FoCUS):
2816
2874
  ) # [Nbatch, Nmask, Norient3]
2817
2875
 
2818
2876
  if return_data:
2819
- if S1 is None:
2820
- S1 = {}
2877
+ if out_nside is not None and out_nside < nside_j3:
2878
+ s1 = self.backend.bk_reduce_mean(
2879
+ self.backend.bk_reshape(
2880
+ s1,
2881
+ [
2882
+ s1.shape[0],
2883
+ 12 * out_nside**2,
2884
+ (nside_j3 // out_nside) ** 2,
2885
+ s1.shape[2],
2886
+ ],
2887
+ ),
2888
+ 2,
2889
+ )
2821
2890
  S1[j3] = s1
2822
2891
  else:
2823
2892
  ### Normalize S1
2824
2893
  if norm is not None:
2825
2894
  self.div_norm(s1, (P1_dic[j3]) ** 0.5)
2826
2895
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2827
- if S1 is None:
2828
- S1 = self.backend.bk_expand_dims(
2829
- s1, off_S2
2896
+ S1.append(
2897
+ self.backend.bk_expand_dims(s1, off_S2)
2898
+ ) # Add a dimension for NS1
2899
+ if calc_var:
2900
+ VS1.append(
2901
+ self.backend.bk_expand_dims(vs1, off_S2)
2830
2902
  ) # Add a dimension for NS1
2831
- if calc_var:
2832
- VS1 = self.backend.bk_expand_dims(
2833
- vs1, off_S2
2834
- ) # Add a dimension for NS1
2835
- else:
2836
- S1 = self.backend.bk_concat(
2837
- [S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
2838
- )
2839
- if calc_var:
2840
- VS1 = self.backend.bk_concat(
2841
- [VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
2842
- )
2843
2903
 
2844
2904
  else: # Cross
2845
2905
  ### Make the convolution I2 * Psi_j3
2846
2906
  conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
2847
2907
  if cmat is not None:
2848
- tmp2 = self.backend.bk_repeat(conv2, 4, axis=-1)
2908
+ tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2849
2909
  conv2 = self.backend.bk_reduce_sum(
2850
2910
  self.backend.bk_reshape(
2851
- cmat[j3] * tmp2, [tmp2.shape[0], cmat[j3].shape[0], 4, 4]
2911
+ cmat[j3] * tmp2,
2912
+ [
2913
+ tmp2.shape[0],
2914
+ cmat[j3].shape[0],
2915
+ self.NORIENT,
2916
+ self.NORIENT,
2917
+ ],
2852
2918
  ),
2853
2919
  2,
2854
2920
  )
@@ -2902,8 +2968,19 @@ class funct(FOC.FoCUS):
2902
2968
  s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
2903
2969
 
2904
2970
  if return_data:
2905
- if S2 is None:
2906
- S2 = {}
2971
+ if out_nside is not None and out_nside < nside_j3:
2972
+ s2 = self.backend.bk_reduce_mean(
2973
+ self.backend.bk_reshape(
2974
+ s2,
2975
+ [
2976
+ s2.shape[0],
2977
+ 12 * out_nside**2,
2978
+ (nside_j3 // out_nside) ** 2,
2979
+ s2.shape[2],
2980
+ ],
2981
+ ),
2982
+ 2,
2983
+ )
2907
2984
  S2[j3] = s2
2908
2985
  else:
2909
2986
  ### Normalize S2_cross
@@ -2911,26 +2988,15 @@ class funct(FOC.FoCUS):
2911
2988
  s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
2912
2989
 
2913
2990
  ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
2914
- if not all_cross:
2915
- s2 = self.backend.bk_real(s2)
2991
+ s2 = self.backend.bk_real(s2)
2916
2992
 
2917
- if S2 is None:
2918
- S2 = self.backend.bk_expand_dims(
2919
- s2, off_S2
2993
+ S2.append(
2994
+ self.backend.bk_expand_dims(s2, off_S2)
2995
+ ) # Add a dimension for NS2
2996
+ if calc_var:
2997
+ VS2.append(
2998
+ self.backend.bk_expand_dims(vs2, off_S2)
2920
2999
  ) # Add a dimension for NS2
2921
- if calc_var:
2922
- VS2 = self.backend.bk_expand_dims(
2923
- vs2, off_S2
2924
- ) # Add a dimension for NS2
2925
- else:
2926
- S2 = self.backend.bk_concat(
2927
- [S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
2928
- )
2929
- if calc_var:
2930
- VS2 = self.backend.bk_concat(
2931
- [VS2, self.backend.bk_expand_dims(vs2, off_S2)],
2932
- axis=2,
2933
- )
2934
3000
 
2935
3001
  #### S1_auto computation
2936
3002
  ### Image 1 : S1 = < M1 >_pix
@@ -2947,30 +3013,32 @@ class funct(FOC.FoCUS):
2947
3013
  MX, vmask, axis=1, rank=j3
2948
3014
  ) # [Nbatch, Nmask, Norient3]
2949
3015
  if return_data:
2950
- if S1 is None:
2951
- S1 = {}
3016
+ if out_nside is not None and out_nside < nside_j3:
3017
+ s1 = self.backend.bk_reduce_mean(
3018
+ self.backend.bk_reshape(
3019
+ s1,
3020
+ [
3021
+ s1.shape[0],
3022
+ 12 * out_nside**2,
3023
+ (nside_j3 // out_nside) ** 2,
3024
+ s1.shape[2],
3025
+ ],
3026
+ ),
3027
+ 2,
3028
+ )
2952
3029
  S1[j3] = s1
2953
3030
  else:
2954
3031
  ### Normalize S1
2955
3032
  if norm is not None:
2956
3033
  self.div_norm(s1, (P1_dic[j3]) ** 0.5)
2957
3034
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2958
- if S1 is None:
2959
- S1 = self.backend.bk_expand_dims(
2960
- s1, off_S2
3035
+ S1.append(
3036
+ self.backend.bk_expand_dims(s1, off_S2)
3037
+ ) # Add a dimension for NS1
3038
+ if calc_var:
3039
+ VS1.append(
3040
+ self.backend.bk_expand_dims(vs1, off_S2)
2961
3041
  ) # Add a dimension for NS1
2962
- if calc_var:
2963
- VS1 = self.backend.bk_expand_dims(
2964
- vs1, off_S2
2965
- ) # Add a dimension for NS1
2966
- else:
2967
- S1 = self.backend.bk_concat(
2968
- [S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
2969
- )
2970
- if calc_var:
2971
- VS1 = self.backend.bk_concat(
2972
- [VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
2973
- )
2974
3042
 
2975
3043
  # Initialize dictionaries for |I1*Psi_j| * Psi_j3
2976
3044
  M1convPsi_dic = {}
@@ -2979,6 +3047,7 @@ class funct(FOC.FoCUS):
2979
3047
  M2convPsi_dic = {}
2980
3048
 
2981
3049
  ###### S3
3050
+ nside_j2 = nside_j3
2982
3051
  for j2 in range(0, j3 + 1): # j2 <= j3
2983
3052
  if return_data:
2984
3053
  if S4[j3] is None:
@@ -3013,6 +3082,20 @@ class funct(FOC.FoCUS):
3013
3082
  if return_data:
3014
3083
  if S3[j3] is None:
3015
3084
  S3[j3] = {}
3085
+ if out_nside is not None and out_nside < nside_j2:
3086
+ s3 = self.backend.bk_reduce_mean(
3087
+ self.backend.bk_reshape(
3088
+ s3,
3089
+ [
3090
+ s3.shape[0],
3091
+ 12 * out_nside**2,
3092
+ (nside_j2 // out_nside) ** 2,
3093
+ s3.shape[2],
3094
+ s3.shape[3],
3095
+ ],
3096
+ ),
3097
+ 2,
3098
+ )
3016
3099
  S3[j3][j2] = s3
3017
3100
  else:
3018
3101
  ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3027,23 +3110,18 @@ class funct(FOC.FoCUS):
3027
3110
  ) # [Nbatch, Nmask, Norient3, Norient2]
3028
3111
 
3029
3112
  ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3030
- if S3 is None:
3031
- S3 = self.backend.bk_expand_dims(
3032
- s3, off_S3
3033
- ) # Add a dimension for NS3
3034
- if calc_var:
3035
- VS3 = self.backend.bk_expand_dims(
3036
- vs3, off_S3
3037
- ) # Add a dimension for NS3
3038
- else:
3039
- S3 = self.backend.bk_concat(
3040
- [S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
3113
+
3114
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3115
+ # s3.shape[2]*s3.shape[3]]))
3116
+ S3.append(
3117
+ self.backend.bk_expand_dims(s3, off_S3)
3118
+ ) # Add a dimension for NS3
3119
+ if calc_var:
3120
+ VS3.append(
3121
+ self.backend.bk_expand_dims(vs3, off_S3)
3041
3122
  ) # Add a dimension for NS3
3042
- if calc_var:
3043
- VS3 = self.backend.bk_concat(
3044
- [VS3, self.backend.bk_expand_dims(vs3, off_S3)],
3045
- axis=2,
3046
- ) # Add a dimension for NS3
3123
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3124
+ # s3.shape[2]*s3.shape[3]]))
3047
3125
 
3048
3126
  ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
3049
3127
  ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
@@ -3095,6 +3173,33 @@ class funct(FOC.FoCUS):
3095
3173
  if S3[j3] is None:
3096
3174
  S3[j3] = {}
3097
3175
  S3P[j3] = {}
3176
+ if out_nside is not None and out_nside < nside_j2:
3177
+ s3 = self.backend.bk_reduce_mean(
3178
+ self.backend.bk_reshape(
3179
+ s3,
3180
+ [
3181
+ s3.shape[0],
3182
+ 12 * out_nside**2,
3183
+ (nside_j2 // out_nside) ** 2,
3184
+ s3.shape[2],
3185
+ s3.shape[3],
3186
+ ],
3187
+ ),
3188
+ 2,
3189
+ )
3190
+ s3p = self.backend.bk_reduce_mean(
3191
+ self.backend.bk_reshape(
3192
+ s3p,
3193
+ [
3194
+ s3.shape[0],
3195
+ 12 * out_nside**2,
3196
+ (nside_j2 // out_nside) ** 2,
3197
+ s3.shape[2],
3198
+ s3.shape[3],
3199
+ ],
3200
+ ),
3201
+ 2,
3202
+ )
3098
3203
  S3[j3][j2] = s3
3099
3204
  S3P[j3][j2] = s3p
3100
3205
  else:
@@ -3118,42 +3223,34 @@ class funct(FOC.FoCUS):
3118
3223
  ) # [Nbatch, Nmask, Norient3, Norient2]
3119
3224
 
3120
3225
  ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3121
- if S3 is None:
3122
- S3 = self.backend.bk_expand_dims(
3123
- s3, off_S3
3124
- ) # Add a dimension for NS3
3125
- if calc_var:
3126
- VS3 = self.backend.bk_expand_dims(
3127
- vs3, off_S3
3128
- ) # Add a dimension for NS3
3129
- else:
3130
- S3 = self.backend.bk_concat(
3131
- [S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
3132
- ) # Add a dimension for NS3
3133
- if calc_var:
3134
- VS3 = self.backend.bk_concat(
3135
- [VS3, self.backend.bk_expand_dims(vs3, off_S3)],
3136
- axis=2,
3137
- ) # Add a dimension for NS3
3138
- if S3P is None:
3139
- S3P = self.backend.bk_expand_dims(
3140
- s3p, off_S3
3226
+
3227
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3228
+ # s3.shape[2]*s3.shape[3]]))
3229
+ S3.append(
3230
+ self.backend.bk_expand_dims(s3, off_S3)
3231
+ ) # Add a dimension for NS3
3232
+ if calc_var:
3233
+ VS3.append(
3234
+ self.backend.bk_expand_dims(vs3, off_S3)
3141
3235
  ) # Add a dimension for NS3
3142
- if calc_var:
3143
- VS3P = self.backend.bk_expand_dims(
3144
- vs3p, off_S3
3145
- ) # Add a dimension for NS3
3146
- else:
3147
- S3P = self.backend.bk_concat(
3148
- [S3P, self.backend.bk_expand_dims(s3p, off_S3)], axis=2
3236
+
3237
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3238
+ # s3.shape[2]*s3.shape[3]]))
3239
+
3240
+ # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
3241
+ # s3.shape[2]*s3.shape[3]]))
3242
+ S3P.append(
3243
+ self.backend.bk_expand_dims(s3p, off_S3)
3244
+ ) # Add a dimension for NS3
3245
+ if calc_var:
3246
+ VS3P.append(
3247
+ self.backend.bk_expand_dims(vs3p, off_S3)
3149
3248
  ) # Add a dimension for NS3
3150
- if calc_var:
3151
- VS3P = self.backend.bk_concat(
3152
- [VS3P, self.backend.bk_expand_dims(vs3p, off_S3)],
3153
- axis=2,
3154
- ) # Add a dimension for NS3
3249
+ # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
3250
+ # s3.shape[2]*s3.shape[3]]))
3155
3251
 
3156
3252
  ##### S4
3253
+ nside_j1 = nside_j2
3157
3254
  for j1 in range(0, j2 + 1): # j1 <= j2
3158
3255
  ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
3159
3256
  if not cross:
@@ -3179,6 +3276,21 @@ class funct(FOC.FoCUS):
3179
3276
  if return_data:
3180
3277
  if S4[j3][j2] is None:
3181
3278
  S4[j3][j2] = {}
3279
+ if out_nside is not None and out_nside < nside_j1:
3280
+ s4 = self.backend.bk_reduce_mean(
3281
+ self.backend.bk_reshape(
3282
+ s4,
3283
+ [
3284
+ s4.shape[0],
3285
+ 12 * out_nside**2,
3286
+ (nside_j1 // out_nside) ** 2,
3287
+ s4.shape[2],
3288
+ s4.shape[3],
3289
+ s4.shape[4],
3290
+ ],
3291
+ ),
3292
+ 2,
3293
+ )
3182
3294
  S4[j3][j2][j1] = s4
3183
3295
  else:
3184
3296
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3202,27 +3314,18 @@ class funct(FOC.FoCUS):
3202
3314
  ** 0.5,
3203
3315
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3204
3316
  ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3205
- if S4 is None:
3206
- S4 = self.backend.bk_expand_dims(
3207
- s4, off_S4
3208
- ) # Add a dimension for NS4
3209
- if calc_var:
3210
- VS4 = self.backend.bk_expand_dims(
3211
- vs4, off_S4
3212
- ) # Add a dimension for NS4
3213
- else:
3214
- S4 = self.backend.bk_concat(
3215
- [S4, self.backend.bk_expand_dims(s4, off_S4)],
3216
- axis=2,
3317
+
3318
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3319
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3320
+ S4.append(
3321
+ self.backend.bk_expand_dims(s4, off_S4)
3322
+ ) # Add a dimension for NS4
3323
+ if calc_var:
3324
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3325
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3326
+ VS4.append(
3327
+ self.backend.bk_expand_dims(vs4, off_S4)
3217
3328
  ) # Add a dimension for NS4
3218
- if calc_var:
3219
- VS4 = self.backend.bk_concat(
3220
- [
3221
- VS4,
3222
- self.backend.bk_expand_dims(vs4, off_S4),
3223
- ],
3224
- axis=2,
3225
- ) # Add a dimension for NS4
3226
3329
 
3227
3330
  ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
3228
3331
  else:
@@ -3248,6 +3351,21 @@ class funct(FOC.FoCUS):
3248
3351
  if return_data:
3249
3352
  if S4[j3][j2] is None:
3250
3353
  S4[j3][j2] = {}
3354
+ if out_nside is not None and out_nside < nside_j1:
3355
+ s4 = self.backend.bk_reduce_mean(
3356
+ self.backend.bk_reshape(
3357
+ s4,
3358
+ [
3359
+ s4.shape[0],
3360
+ 12 * out_nside**2,
3361
+ (nside_j1 // out_nside) ** 2,
3362
+ s4.shape[2],
3363
+ s4.shape[3],
3364
+ s4.shape[4],
3365
+ ],
3366
+ ),
3367
+ 2,
3368
+ )
3251
3369
  S4[j3][j2][j1] = s4
3252
3370
  else:
3253
3371
  ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
@@ -3271,39 +3389,33 @@ class funct(FOC.FoCUS):
3271
3389
  ** 0.5,
3272
3390
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3273
3391
  ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3274
- if S4 is None:
3275
- S4 = self.backend.bk_expand_dims(
3276
- s4, off_S4
3277
- ) # Add a dimension for NS4
3278
- if calc_var:
3279
- VS4 = self.backend.bk_expand_dims(
3280
- vs4, off_S4
3281
- ) # Add a dimension for NS4
3282
- else:
3283
- S4 = self.backend.bk_concat(
3284
- [S4, self.backend.bk_expand_dims(s4, off_S4)],
3285
- axis=2,
3392
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3393
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3394
+ S4.append(
3395
+ self.backend.bk_expand_dims(s4, off_S4)
3396
+ ) # Add a dimension for NS4
3397
+ if calc_var:
3398
+
3399
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3400
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3401
+ VS4.append(
3402
+ self.backend.bk_expand_dims(vs4, off_S4)
3286
3403
  ) # Add a dimension for NS4
3287
- if calc_var:
3288
- VS4 = self.backend.bk_concat(
3289
- [
3290
- VS4,
3291
- self.backend.bk_expand_dims(vs4, off_S4),
3292
- ],
3293
- axis=2,
3294
- ) # Add a dimension for NS4
3404
+
3405
+ nside_j1 = nside_j1 // 2
3406
+ nside_j2 = nside_j2 // 2
3295
3407
 
3296
3408
  ###### Reshape for next iteration on j3
3297
3409
  ### Image I1,
3298
3410
  # downscale the I1 [Nbatch, Npix_j3]
3299
3411
  if j3 != Jmax - 1:
3300
- I1_smooth = self.smooth(I1, axis=1)
3301
- I1 = self.ud_grade_2(I1_smooth, axis=1)
3412
+ I1 = self.smooth(I1, axis=1)
3413
+ I1 = self.ud_grade_2(I1, axis=1)
3302
3414
 
3303
3415
  ### Image I2
3304
3416
  if cross:
3305
- I2_smooth = self.smooth(I2, axis=1)
3306
- I2 = self.ud_grade_2(I2_smooth, axis=1)
3417
+ I2 = self.smooth(I2, axis=1)
3418
+ I2 = self.ud_grade_2(I2, axis=1)
3307
3419
 
3308
3420
  ### Modules
3309
3421
  for j2 in range(0, j3 + 1): # j2 =< j3
@@ -3321,8 +3433,9 @@ class funct(FOC.FoCUS):
3321
3433
  M2_dic[j2], axis=1
3322
3434
  ) # [Nbatch, Npix_j3, Norient3]
3323
3435
  M2_dic[j2] = self.ud_grade_2(
3324
- M2_smooth, axis=1
3436
+ M2, axis=1
3325
3437
  ) # [Nbatch, Npix_j3, Norient3]
3438
+
3326
3439
  ### Mask
3327
3440
  vmask = self.ud_grade_2(vmask, axis=1)
3328
3441
 
@@ -3337,7 +3450,33 @@ class funct(FOC.FoCUS):
3337
3450
  self.P1_dic = P1_dic
3338
3451
  if cross:
3339
3452
  self.P2_dic = P2_dic
3453
+ """
3454
+ Sout=[s0]+S1+S2+S3+S4
3340
3455
 
3456
+ if cross:
3457
+ Sout=Sout+S3P
3458
+ if calc_var:
3459
+ SVout=[vs0]+VS1+VS2+VS3+VS4
3460
+ if cross:
3461
+ VSout=VSout+VS3P
3462
+ return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
3463
+
3464
+ return self.backend.bk_concat(Sout, 2)
3465
+ """
3466
+ if not return_data:
3467
+ S1 = self.backend.bk_concat(S1, 2)
3468
+ S2 = self.backend.bk_concat(S2, 2)
3469
+ S3 = self.backend.bk_concat(S3, 2)
3470
+ S4 = self.backend.bk_concat(S4, 2)
3471
+ if cross:
3472
+ S3P = self.backend.bk_concat(S3P, 2)
3473
+ if calc_var:
3474
+ VS1 = self.backend.bk_concat(VS1, 2)
3475
+ VS2 = self.backend.bk_concat(VS2, 2)
3476
+ VS3 = self.backend.bk_concat(VS3, 2)
3477
+ VS4 = self.backend.bk_concat(VS4, 2)
3478
+ if cross:
3479
+ VS3P = self.backend.bk_concat(VS3P, 2)
3341
3480
  if calc_var:
3342
3481
  if not cross:
3343
3482
  return scat_cov(
@@ -3420,10 +3559,17 @@ class funct(FOC.FoCUS):
3420
3559
  M_dic[j2], axis=1
3421
3560
  ) # [Nbatch, Npix_j3, Norient3, Norient2]
3422
3561
  if cmat2 is not None:
3423
- tmp2 = self.backend.bk_repeat(MconvPsi, 4, axis=-1)
3562
+ tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
3424
3563
  MconvPsi = self.backend.bk_reduce_sum(
3425
3564
  self.backend.bk_reshape(
3426
- cmat2[j3][j2] * tmp2, [tmp2.shape[0], cmat2[j3].shape[1], 4, 4, 4]
3565
+ cmat2[j3][j2] * tmp2,
3566
+ [
3567
+ tmp2.shape[0],
3568
+ cmat2[j3].shape[1],
3569
+ self.NORIENT,
3570
+ self.NORIENT,
3571
+ self.NORIENT,
3572
+ ],
3427
3573
  ),
3428
3574
  3,
3429
3575
  )
@@ -3499,6 +3645,27 @@ class funct(FOC.FoCUS):
3499
3645
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3500
3646
  return s4
3501
3647
 
3648
+ def to_gaussian(self,x):
3649
+ from scipy.stats import norm
3650
+ from scipy.interpolate import interp1d
3651
+
3652
+ idx=np.argsort(x.flatten())
3653
+ p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
3654
+ im_target=x.flatten()
3655
+ im_target[idx] = norm.ppf(p)
3656
+
3657
+ # Interpolation cubique
3658
+ self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind='cubic')
3659
+ self.val_min=im_target[idx[0]]
3660
+ self.val_max=im_target[idx[-1]]
3661
+ return im_target.reshape(x.shape)
3662
+
3663
+
3664
+ def from_gaussian(self,x):
3665
+
3666
+ x=self.backend.bk_clip_by_value(x,self.val_min,self.val_max)
3667
+ return self.f_gaussian(self.backend.to_numpy(x))
3668
+
3502
3669
  def square(self, x):
3503
3670
  if isinstance(x, scat_cov):
3504
3671
  if x.S1 is None:
@@ -3548,42 +3715,47 @@ class funct(FOC.FoCUS):
3548
3715
  return self.backend.bk_abs(self.backend.bk_sqrt(x))
3549
3716
 
3550
3717
  def reduce_mean(self, x):
3551
-
3718
+
3552
3719
  if isinstance(x, scat_cov):
3553
- result = self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0)) + \
3554
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2)) + \
3555
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3)) + \
3556
- self.backend.bk_reduce_sum(self.backend.bk_abs(x.S4))
3557
-
3558
- N = self.backend.bk_size(x.S0)+self.backend.bk_size(x.S2)+ \
3559
- self.backend.bk_size(x.S3)+self.backend.bk_size(x.S4)
3560
-
3720
+ result = (
3721
+ self.backend.bk_reduce_sum(self.backend.bk_abs(x.S0))
3722
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S2))
3723
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3))
3724
+ + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S4))
3725
+ )
3726
+
3727
+ N = (
3728
+ self.backend.bk_size(x.S0)
3729
+ + self.backend.bk_size(x.S2)
3730
+ + self.backend.bk_size(x.S3)
3731
+ + self.backend.bk_size(x.S4)
3732
+ )
3733
+
3561
3734
  if x.S1 is not None:
3562
- result = result+self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
3735
+ result = result + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S1))
3563
3736
  N = N + self.backend.bk_size(x.S1)
3564
3737
  if x.S3P is not None:
3565
- result = result+self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3P))
3738
+ result = result + self.backend.bk_reduce_sum(self.backend.bk_abs(x.S3P))
3566
3739
  N = N + self.backend.bk_size(x.S3P)
3567
- return result/self.backend.bk_cast(N)
3740
+ return result / self.backend.bk_cast(N)
3568
3741
  else:
3569
3742
  return self.backend.bk_reduce_mean(x, axis=0)
3570
-
3571
3743
 
3572
3744
  def reduce_mean_batch(self, x):
3573
-
3745
+
3574
3746
  if isinstance(x, scat_cov):
3575
-
3576
- sS0=self.backend.bk_reduce_mean(x.S0, axis=0)
3577
- sS2=self.backend.bk_reduce_mean(x.S2, axis=0)
3578
- sS3=self.backend.bk_reduce_mean(x.S3, axis=0)
3579
- sS4=self.backend.bk_reduce_mean(x.S4, axis=0)
3580
- sS1=None
3581
- sS3P=None
3747
+
3748
+ sS0 = self.backend.bk_reduce_mean(x.S0, axis=0)
3749
+ sS2 = self.backend.bk_reduce_mean(x.S2, axis=0)
3750
+ sS3 = self.backend.bk_reduce_mean(x.S3, axis=0)
3751
+ sS4 = self.backend.bk_reduce_mean(x.S4, axis=0)
3752
+ sS1 = None
3753
+ sS3P = None
3582
3754
  if x.S1 is not None:
3583
3755
  sS1 = self.backend.bk_reduce_mean(x.S1, axis=0)
3584
3756
  if x.S3P is not None:
3585
3757
  sS3P = self.backend.bk_reduce_mean(x.S3P, axis=0)
3586
-
3758
+
3587
3759
  result = scat_cov(
3588
3760
  sS0,
3589
3761
  sS2,
@@ -3597,22 +3769,22 @@ class funct(FOC.FoCUS):
3597
3769
  return result
3598
3770
  else:
3599
3771
  return self.backend.bk_reduce_mean(x, axis=0)
3600
-
3772
+
3601
3773
  def reduce_sum_batch(self, x):
3602
-
3774
+
3603
3775
  if isinstance(x, scat_cov):
3604
-
3605
- sS0=self.backend.bk_reduce_sum(x.S0, axis=0)
3606
- sS2=self.backend.bk_reduce_sum(x.S2, axis=0)
3607
- sS3=self.backend.bk_reduce_sum(x.S3, axis=0)
3608
- sS4=self.backend.bk_reduce_sum(x.S4, axis=0)
3609
- sS1=None
3610
- sS3P=None
3776
+
3777
+ sS0 = self.backend.bk_reduce_sum(x.S0, axis=0)
3778
+ sS2 = self.backend.bk_reduce_sum(x.S2, axis=0)
3779
+ sS3 = self.backend.bk_reduce_sum(x.S3, axis=0)
3780
+ sS4 = self.backend.bk_reduce_sum(x.S4, axis=0)
3781
+ sS1 = None
3782
+ sS3P = None
3611
3783
  if x.S1 is not None:
3612
3784
  sS1 = self.backend.bk_reduce_sum(x.S1, axis=0)
3613
3785
  if x.S3P is not None:
3614
3786
  sS3P = self.backend.bk_reduce_sum(x.S3P, axis=0)
3615
-
3787
+
3616
3788
  result = scat_cov(
3617
3789
  sS0,
3618
3790
  sS2,
@@ -3626,7 +3798,7 @@ class funct(FOC.FoCUS):
3626
3798
  return result
3627
3799
  else:
3628
3800
  return self.backend.bk_reduce_mean(x, axis=0)
3629
-
3801
+
3630
3802
  def reduce_distance(self, x, y, sigma=None):
3631
3803
 
3632
3804
  if isinstance(x, scat_cov):
@@ -3662,11 +3834,13 @@ class funct(FOC.FoCUS):
3662
3834
  return result
3663
3835
  else:
3664
3836
  if sigma is None:
3665
- tmp=x-y
3837
+ tmp = x - y
3666
3838
  else:
3667
- tmp=(x-y)/sigma
3839
+ tmp = (x - y) / sigma
3668
3840
  # do abs in case of complex values
3669
- return self.backend.bk_abs(self.backend.bk_reduce_mean(self.backend.bk_square(tmp)))
3841
+ return self.backend.bk_abs(
3842
+ self.backend.bk_reduce_mean(self.backend.bk_square(tmp))
3843
+ )
3670
3844
 
3671
3845
  def reduce_sum(self, x):
3672
3846
 
@@ -3842,3 +4016,122 @@ class funct(FOC.FoCUS):
3842
4016
  return scat_cov(
3843
4017
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
3844
4018
  )
4019
+
4020
+ def synthesis(self,
4021
+ image_target,
4022
+ nstep=4,
4023
+ seed=1234,
4024
+ edge=True,
4025
+ to_gaussian=True,
4026
+ EVAL_FREQUENCY=100,
4027
+ NUM_EPOCHS = 300):
4028
+
4029
+ import foscat.Synthesis as synthe
4030
+ import time
4031
+
4032
+ def The_loss(u,scat_operator,args):
4033
+ ref = args[0]
4034
+ sref = args[1]
4035
+
4036
+ # compute scattering covariance of the current synthetised map called u
4037
+ learn=scat_operator.reduce_mean_batch(scat_operator.eval(u,edge=edge))
4038
+
4039
+ # make the difference withe the reference coordinates
4040
+ loss=scat_operator.reduce_distance(learn,ref,sigma=sref)
4041
+
4042
+ return loss
4043
+
4044
+ if to_gaussian:
4045
+ # Change the data histogram to gaussian distribution
4046
+ im_target=self.to_gaussian(image_target)
4047
+ else:
4048
+ im_target=image_target
4049
+
4050
+ axis=len(im_target.shape)-1
4051
+ if self.use_2D:
4052
+ axis-=1
4053
+ if axis==0:
4054
+ im_target=self.backend.bk_expand_dims(im_target,0)
4055
+
4056
+ # compute the number of possible steps
4057
+ if self.use_2D:
4058
+ jmax=int(np.min([np.log(im_target.shape[1]),np.log(im_target.shape[2])])/np.log(2))
4059
+ elif self.use_1D:
4060
+ jmax=int(np.log(im_target.shape[1])/np.log(2))
4061
+ else:
4062
+ jmax=int((np.log(im_target.shape[1]//12)/np.log(2))/2)
4063
+ nside=2**jmax
4064
+
4065
+ if nstep>jmax-1:
4066
+ nstep=jmax-1
4067
+
4068
+ t1=time.time()
4069
+ tmp={}
4070
+ tmp[nstep-1]=im_target
4071
+ for l in range(nstep-2,-1,-1):
4072
+ tmp[l]=self.ud_grade_2(tmp[l+1],axis=1)
4073
+
4074
+ if not self.use_2D and not self.use_1D:
4075
+ l_nside=nside//(2**(nstep-1))
4076
+
4077
+ for k in range(nstep):
4078
+ if k==0:
4079
+ np.random.seed(seed)
4080
+ if self.use_2D:
4081
+ imap=np.random.randn(tmp[k].shape[0],
4082
+ tmp[k].shape[1],
4083
+ tmp[k].shape[2])
4084
+ else:
4085
+ imap=np.random.randn(tmp[k].shape[0],
4086
+ tmp[k].shape[1])
4087
+ else:
4088
+ axis=1
4089
+ # if the kernel size is bigger than 3 increase the binning before smoothing
4090
+ if self.use_2D:
4091
+ imap = self.up_grade(
4092
+ omap, imap.shape[axis] * 2, axis=1, nouty=imap.shape[axis + 1] * 2
4093
+ )
4094
+ elif self.use_1D:
4095
+ imap = self.up_grade(omap, imap.shape[axis] * 2, axis=1)
4096
+ else:
4097
+ imap = self.up_grade(omap, l_nside, axis=axis)
4098
+
4099
+ # compute the coefficients for the target image
4100
+ ref,sref=self.eval(tmp[k],calc_var=True,edge=edge)
4101
+
4102
+ # compute the mean of the population does nothing if only one map is given
4103
+ ref=self.reduce_mean_batch(ref)
4104
+ sref=self.reduce_mean_batch(sref)
4105
+
4106
+ # define a loss to minimize
4107
+ loss=synthe.Loss(The_loss,self,ref,sref)
4108
+
4109
+ sy = synthe.Synthesis([loss])
4110
+
4111
+ # initialize the synthesised map
4112
+ if self.use_2D:
4113
+ print('Synthesis scale [ %d x %d ]'%(imap.shape[1],imap.shape[2]))
4114
+ elif self.use_1D:
4115
+ print('Synthesis scale [ %d ]'%(imap.shape[1]))
4116
+ else:
4117
+ print('Synthesis scale nside=%d'%(l_nside))
4118
+ l_nside*=2
4119
+
4120
+ # do the minimization
4121
+ omap=sy.run(imap,
4122
+ EVAL_FREQUENCY=EVAL_FREQUENCY,
4123
+ NUM_EPOCHS = NUM_EPOCHS)
4124
+
4125
+ t2=time.time()
4126
+ print('Total computation %.2fs'%(t2-t1))
4127
+
4128
+ if to_gaussian:
4129
+ omap=self.from_gaussian(omap)
4130
+
4131
+ if axis==0:
4132
+ return omap[0]
4133
+ else:
4134
+ return omap
4135
+
4136
+ def to_numpy(self,x):
4137
+ return self.backend.to_numpy(x)