foscat 2025.3.0__py3-none-any.whl → 2025.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -2378,16 +2378,16 @@ class funct(FOC.FoCUS):
2378
2378
  # of cosine and sine contributions using all channels
2379
2379
  angles = self.backend.bk_cast(
2380
2380
  (2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
2381
- 1, 1, self.NORIENT
2381
+ 1, self.NORIENT, 1
2382
2382
  )
2383
2383
  ) # shape: (NORIENT,)
2384
2384
 
2385
2385
  # we use cosines and sines as weights for sim
2386
2386
  weighted_cos = self.backend.bk_reduce_mean(
2387
- sim * self.backend.bk_cos(angles), axis=-1
2387
+ sim * self.backend.bk_cos(angles), axis=-2
2388
2388
  )
2389
2389
  weighted_sin = self.backend.bk_reduce_mean(
2390
- sim * self.backend.bk_sin(angles), axis=-1
2390
+ sim * self.backend.bk_sin(angles), axis=-2
2391
2391
  )
2392
2392
  # For simplicity, take first element of the batch
2393
2393
  cc = weighted_cos[0]
@@ -2423,43 +2423,46 @@ class funct(FOC.FoCUS):
2423
2423
  w1 = np.sin(delta * np.pi / 2) ** 2
2424
2424
 
2425
2425
  # build rotation matrix
2426
- mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
2427
- lidx = np.arange(sim.shape[1])
2426
+ mat = np.zeros([self.NORIENT * self.NORIENT, sim.shape[2]])
2427
+ lidx = np.arange(sim.shape[2])
2428
2428
  for ell in range(self.NORIENT):
2429
2429
  # Instead of simple linear weights, we use the cosine weights w0 and w1.
2430
2430
  col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
2431
2431
  col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
2432
- mat[lidx, col0] = w0
2433
- mat[lidx, col1] = w1
2434
2432
 
2435
- cmat[k] = self.backend.bk_cast(mat.astype("complex64"))
2433
+ mat[col0, lidx] = w0
2434
+ mat[col1, lidx] = w1
2435
+
2436
+ cmat[k] = self.backend.bk_cast(mat[None, ...].astype("complex64"))
2436
2437
 
2437
2438
  # do same modifications for mat2
2438
2439
  mat2 = np.zeros(
2439
- [k + 1, sim.shape[1], self.NORIENT, self.NORIENT * self.NORIENT]
2440
+ [k + 1, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[2]]
2440
2441
  )
2441
2442
 
2442
2443
  for k2 in range(k + 1):
2443
- tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=-1)
2444
+
2445
+ tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=1)
2444
2446
 
2445
2447
  sim2 = self.backend.bk_reduce_sum(
2446
2448
  self.backend.bk_reshape(
2447
2449
  self.backend.bk_cast(
2448
- mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
2450
+ mat.reshape(1, self.NORIENT * self.NORIENT, mat.shape[1])
2449
2451
  )
2450
2452
  * tmp2,
2451
- [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
2453
+ [sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[1]],
2452
2454
  ),
2453
- 2,
2455
+ 1,
2454
2456
  )
2455
2457
 
2456
- sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
2458
+ sim2 = self.backend.bk_abs(self.convol(sim2, axis=-1))
2457
2459
 
2460
+ angles = self.backend.bk_reshape(angles, [1, self.NORIENT, 1, 1])
2458
2461
  weighted_cos2 = self.backend.bk_reduce_mean(
2459
- sim2 * self.backend.bk_cos(angles), axis=-1
2462
+ sim2 * self.backend.bk_cos(angles), axis=1
2460
2463
  )
2461
2464
  weighted_sin2 = self.backend.bk_reduce_mean(
2462
- sim2 * self.backend.bk_sin(angles), axis=-1
2465
+ sim2 * self.backend.bk_sin(angles), axis=1
2463
2466
  )
2464
2467
 
2465
2468
  cc2 = weighted_cos2[0]
@@ -2467,14 +2470,14 @@ class funct(FOC.FoCUS):
2467
2470
 
2468
2471
  if smooth_scale > 0:
2469
2472
  for m in range(smooth_scale):
2470
- if cc2.shape[0] > 12:
2471
- cc2, _ = self.ud_grade_2(self.smooth(cc2))
2472
- ss2, _ = self.ud_grade_2(self.smooth(ss2))
2473
+ if cc2.shape[1] > 12:
2474
+ cc2, _ = self.ud_grade_2(self.smooth(cc2, axis=1), axis=1)
2475
+ ss2, _ = self.ud_grade_2(self.smooth(ss2, axis=1), axis=1)
2473
2476
 
2474
- if cc2.shape[0] != sim.shape[1]:
2475
- ll_nside = int(np.sqrt(sim.shape[1] // 12))
2476
- cc2 = self.up_grade(cc2, ll_nside)
2477
- ss2 = self.up_grade(ss2, ll_nside)
2477
+ if cc2.shape[1] != sim.shape[2]:
2478
+ ll_nside = int(np.sqrt(sim.shape[2] // 12))
2479
+ cc2 = self.up_grade(cc2, ll_nside, axis=1)
2480
+ ss2 = self.up_grade(ss2, ll_nside, axis=1)
2478
2481
 
2479
2482
  if self.BACKEND == "numpy":
2480
2483
  phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
@@ -2492,17 +2495,18 @@ class funct(FOC.FoCUS):
2492
2495
  delta2 = phase2_scaled - iph2
2493
2496
  w0_2 = np.cos(delta2 * np.pi / 2) ** 2
2494
2497
  w1_2 = np.sin(delta2 * np.pi / 2) ** 2
2495
- lidx = np.arange(sim.shape[1])
2498
+ lidx = np.arange(sim.shape[2])
2496
2499
 
2497
2500
  for m in range(self.NORIENT):
2498
2501
  for ell in range(self.NORIENT):
2499
- col0 = self.NORIENT * ((ell + iph2[:, m]) % self.NORIENT) + ell
2500
- col1 = (
2501
- self.NORIENT * ((ell + iph2[:, m] + 1) % self.NORIENT) + ell
2502
- )
2503
- mat2[k2, lidx, m, col0] = w0_2[:, m]
2504
- mat2[k2, lidx, m, col1] = w1_2[:, m]
2505
- cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2502
+ col0 = self.NORIENT * ((ell + iph2[m]) % self.NORIENT) + ell
2503
+ col1 = self.NORIENT * ((ell + iph2[m] + 1) % self.NORIENT) + ell
2504
+ mat2[k2, col0, m, lidx] = w0_2[m, lidx]
2505
+ mat2[k2, col1, m, lidx] = w1_2[m, lidx]
2506
+
2507
+ cmat2[k] = self.backend.bk_cast(
2508
+ mat2[0 : k + 1, None, ...].astype("complex64")
2509
+ )
2506
2510
 
2507
2511
  if k < l_nside - 1:
2508
2512
  tmp, _ = self.ud_grade_2(tmp, axis=1)
@@ -2620,6 +2624,9 @@ class funct(FOC.FoCUS):
2620
2624
  x1 = im_shape[1]
2621
2625
  x2 = im_shape[2]
2622
2626
  J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
2627
+ if J == 0:
2628
+ print("Use of too small 2D domain does not work J_max=", J)
2629
+ return None
2623
2630
  elif self.use_1D:
2624
2631
  if len(image1.shape) == 2:
2625
2632
  npix = int(im_shape[1]) # Number of pixels
@@ -2842,21 +2849,25 @@ class funct(FOC.FoCUS):
2842
2849
  ### Make the convolution I1 * Psi_j3
2843
2850
  conv1 = self.convol(
2844
2851
  I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2845
- ) # [Nbatch, Npix_j3, Norient3]
2852
+ ) # [Nbatch, Norient3 , Npix_j3]
2853
+
2846
2854
  if cmat is not None:
2847
- tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2855
+
2856
+ tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-2)
2857
+
2848
2858
  conv1 = self.backend.bk_reduce_sum(
2849
2859
  self.backend.bk_reshape(
2850
2860
  cmat[j3] * tmp2,
2851
- [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
2861
+ [tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j3].shape[2]],
2852
2862
  ),
2853
- 2,
2863
+ 1,
2854
2864
  )
2855
2865
 
2856
2866
  ### Take the module M1 = |I1 * Psi_j3|
2857
2867
  M1_square = conv1 * self.backend.bk_conjugate(
2858
2868
  conv1
2859
- ) # [Nbatch, Npix_j3, Norient3]
2869
+ ) # [Nbatch, Norient3, Npix_j3]
2870
+
2860
2871
  M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
2861
2872
  # Store M1_j3 in a dictionary
2862
2873
  M1_dic[j3] = M1
@@ -2871,10 +2882,10 @@ class funct(FOC.FoCUS):
2871
2882
  else:
2872
2883
  if calc_var:
2873
2884
  s2, vs2 = self.masked_mean(
2874
- M1_square, vmask, axis=1, rank=j3, calc_var=True
2885
+ M1_square, vmask, axis=2, rank=j3, calc_var=True
2875
2886
  )
2876
2887
  else:
2877
- s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
2888
+ s2 = self.masked_mean(M1_square, vmask, axis=2, rank=j3)
2878
2889
 
2879
2890
  if cond_init_P1_dic:
2880
2891
  # We fill P1_dic with S2 for normalisation of S3 and S4
@@ -2890,9 +2901,9 @@ class funct(FOC.FoCUS):
2890
2901
  s2,
2891
2902
  [
2892
2903
  s2.shape[0],
2904
+ s2.shape[2],
2893
2905
  12 * out_nside**2,
2894
2906
  (nside_j3 // out_nside) ** 2,
2895
- s2.shape[2],
2896
2907
  ],
2897
2908
  ),
2898
2909
  2,
@@ -2918,11 +2929,11 @@ class funct(FOC.FoCUS):
2918
2929
  else:
2919
2930
  if calc_var:
2920
2931
  s1, vs1 = self.masked_mean(
2921
- M1, vmask, axis=1, rank=j3, calc_var=True
2932
+ M1, vmask, axis=2, rank=j3, calc_var=True
2922
2933
  ) # [Nbatch, Nmask, Norient3]
2923
2934
  else:
2924
2935
  s1 = self.masked_mean(
2925
- M1, vmask, axis=1, rank=j3
2936
+ M1, vmask, axis=2, rank=j3
2926
2937
  ) # [Nbatch, Nmask, Norient3]
2927
2938
 
2928
2939
  if return_data:
@@ -2932,9 +2943,9 @@ class funct(FOC.FoCUS):
2932
2943
  s1,
2933
2944
  [
2934
2945
  s1.shape[0],
2946
+ s1.shape[2],
2935
2947
  12 * out_nside**2,
2936
2948
  (nside_j3 // out_nside) ** 2,
2937
- s1.shape[2],
2938
2949
  ],
2939
2950
  ),
2940
2951
  2,
@@ -2956,21 +2967,21 @@ class funct(FOC.FoCUS):
2956
2967
  else: # Cross
2957
2968
  ### Make the convolution I2 * Psi_j3
2958
2969
  conv2 = self.convol(
2959
- I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2970
+ I2, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
2960
2971
  ) # [Nbatch, Npix_j3, Norient3]
2961
2972
  if cmat is not None:
2962
- tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2973
+ tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-2)
2963
2974
  conv2 = self.backend.bk_reduce_sum(
2964
2975
  self.backend.bk_reshape(
2965
2976
  cmat[j3] * tmp2,
2966
2977
  [
2967
2978
  tmp2.shape[0],
2968
- cmat[j3].shape[0],
2969
2979
  self.NORIENT,
2970
2980
  self.NORIENT,
2981
+ cmat[j3].shape[2],
2971
2982
  ],
2972
2983
  ),
2973
- 2,
2984
+ 1,
2974
2985
  )
2975
2986
  ### Take the module M2 = |I2 * Psi_j3|
2976
2987
  M2_square = conv2 * self.backend.bk_conjugate(
@@ -2990,17 +3001,17 @@ class funct(FOC.FoCUS):
2990
3001
  else:
2991
3002
  if calc_var:
2992
3003
  p1, vp1 = self.masked_mean(
2993
- M1_square, vmask, axis=1, rank=j3, calc_var=True
3004
+ M1_square, vmask, axis=2, rank=j3, calc_var=True
2994
3005
  ) # [Nbatch, Nmask, Norient3]
2995
3006
  p2, vp2 = self.masked_mean(
2996
- M2_square, vmask, axis=1, rank=j3, calc_var=True
3007
+ M2_square, vmask, axis=2, rank=j3, calc_var=True
2997
3008
  ) # [Nbatch, Nmask, Norient3]
2998
3009
  else:
2999
3010
  p1 = self.masked_mean(
3000
- M1_square, vmask, axis=1, rank=j3
3011
+ M1_square, vmask, axis=2, rank=j3
3001
3012
  ) # [Nbatch, Nmask, Norient3]
3002
3013
  p2 = self.masked_mean(
3003
- M2_square, vmask, axis=1, rank=j3
3014
+ M2_square, vmask, axis=2, rank=j3
3004
3015
  ) # [Nbatch, Nmask, Norient3]
3005
3016
  # We fill P1_dic with S2 for normalisation of S3 and S4
3006
3017
  P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
@@ -3016,10 +3027,10 @@ class funct(FOC.FoCUS):
3016
3027
  else:
3017
3028
  if calc_var:
3018
3029
  s2, vs2 = self.masked_mean(
3019
- s2, vmask, axis=1, rank=j3, calc_var=True
3030
+ s2, vmask, axis=2, rank=j3, calc_var=True
3020
3031
  )
3021
3032
  else:
3022
- s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
3033
+ s2 = self.masked_mean(s2, vmask, axis=2, rank=j3)
3023
3034
 
3024
3035
  if return_data:
3025
3036
  if out_nside is not None and out_nside < nside_j3:
@@ -3028,22 +3039,23 @@ class funct(FOC.FoCUS):
3028
3039
  s2,
3029
3040
  [
3030
3041
  s2.shape[0],
3042
+ s2.shape[2],
3031
3043
  12 * out_nside**2,
3032
3044
  (nside_j3 // out_nside) ** 2,
3033
- s2.shape[2],
3034
3045
  ],
3035
3046
  ),
3036
3047
  2,
3037
3048
  )
3038
3049
  S2[j3] = s2
3039
3050
  else:
3040
- ### Normalize S2_cross
3041
- if norm == "auto":
3042
- s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
3043
3051
 
3044
3052
  ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
3045
3053
  s2 = self.backend.bk_real(s2)
3046
3054
 
3055
+ ### Normalize S2_cross
3056
+ if norm == "auto":
3057
+ s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
3058
+
3047
3059
  S2.append(
3048
3060
  self.backend.bk_expand_dims(s2, off_S2)
3049
3061
  ) # Add a dimension for NS2
@@ -3060,11 +3072,11 @@ class funct(FOC.FoCUS):
3060
3072
  else:
3061
3073
  if calc_var:
3062
3074
  s1, vs1 = self.masked_mean(
3063
- MX, vmask, axis=1, rank=j3, calc_var=True
3075
+ MX, vmask, axis=2, rank=j3, calc_var=True
3064
3076
  ) # [Nbatch, Nmask, Norient3]
3065
3077
  else:
3066
3078
  s1 = self.masked_mean(
3067
- MX, vmask, axis=1, rank=j3
3079
+ MX, vmask, axis=2, rank=j3
3068
3080
  ) # [Nbatch, Nmask, Norient3]
3069
3081
  if return_data:
3070
3082
  if out_nside is not None and out_nside < nside_j3:
@@ -3073,9 +3085,9 @@ class funct(FOC.FoCUS):
3073
3085
  s1,
3074
3086
  [
3075
3087
  s1.shape[0],
3088
+ s1.shape[2],
3076
3089
  12 * out_nside**2,
3077
3090
  (nside_j3 // out_nside) ** 2,
3078
- s1.shape[2],
3079
3091
  ],
3080
3092
  ),
3081
3093
  2,
@@ -3486,19 +3498,19 @@ class funct(FOC.FoCUS):
3486
3498
  for j2 in range(0, j3 + 1): # j2 =< j3
3487
3499
  ### Dictionary M1_dic[j2]
3488
3500
  M1_smooth = self.smooth(
3489
- M1_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3501
+ M1_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3490
3502
  ) # [Nbatch, Npix_j3, Norient3]
3491
3503
  M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3492
- M1_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3504
+ M1_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3493
3505
  ) # [Nbatch, Npix_j3, Norient3]
3494
3506
 
3495
3507
  ### Dictionary M2_dic[j2]
3496
3508
  if cross:
3497
3509
  M2_smooth = self.smooth(
3498
- M2_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3510
+ M2_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3499
3511
  ) # [Nbatch, Npix_j3, Norient3]
3500
3512
  M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3501
- M2_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3513
+ M2_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3502
3514
  ) # [Nbatch, Npix_j3, Norient3]
3503
3515
  ### Mask
3504
3516
  vmask, new_cell_ids_j3 = self.ud_grade_2(
@@ -3623,51 +3635,51 @@ class funct(FOC.FoCUS):
3623
3635
  cs3, ss3: real and imag parts of S3 coeff
3624
3636
  """
3625
3637
  ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3626
- # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
3638
+ # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Norient3, Npix_j3]
3627
3639
  MconvPsi = self.convol(
3628
- M_dic[j2], axis=1, cell_ids=cell_ids, nside=nside_j2
3629
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
3640
+ M_dic[j2], axis=2, cell_ids=cell_ids, nside=nside_j2
3641
+ ) # [Nbatch, Norient3, Norient2, Npix_j3]
3642
+
3630
3643
  if cmat2 is not None:
3631
- tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
3644
+ tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-3)
3632
3645
  MconvPsi = self.backend.bk_reduce_sum(
3633
3646
  self.backend.bk_reshape(
3634
3647
  cmat2[j3][j2] * tmp2,
3635
3648
  [
3636
3649
  tmp2.shape[0],
3637
- cmat2[j3].shape[1],
3638
3650
  self.NORIENT,
3639
3651
  self.NORIENT,
3640
3652
  self.NORIENT,
3653
+ cmat2[j3][j2].shape[3],
3641
3654
  ],
3642
3655
  ),
3643
- 3,
3656
+ 1,
3644
3657
  )
3645
3658
 
3646
3659
  # Store it so we can use it in S4 computation
3647
- MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
3660
+ MconvPsi_dic[j2] = MconvPsi # [Nbatch, Norient3, Norient2, Npix_j3]
3648
3661
 
3649
3662
  ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
3650
3663
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
3651
- # cconv, sconv are [Nbatch, Npix_j3, Norient3]
3664
+ # cconv, sconv are [Nbatch, Norient3, Npix_j3]
3652
3665
  if self.use_1D:
3653
3666
  s3 = conv * self.backend.bk_conjugate(MconvPsi)
3654
3667
  else:
3655
- s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
3668
+ s3 = self.backend.bk_expand_dims(conv, -3) * self.backend.bk_conjugate(
3656
3669
  MconvPsi
3657
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
3658
-
3670
+ ) # [Nbatch, Norient3, Norient2, Npix_j3]
3659
3671
  ### Apply the mask [Nmask, Npix_j3] and sum over pixels
3660
3672
  if return_data:
3661
3673
  return s3
3662
3674
  else:
3663
3675
  if calc_var:
3664
3676
  s3, vs3 = self.masked_mean(
3665
- s3, vmask, axis=1, rank=j2, calc_var=True
3677
+ s3, vmask, axis=3, rank=j2, calc_var=True
3666
3678
  ) # [Nbatch, Nmask, Norient3, Norient2]
3667
3679
  return s3, vs3
3668
3680
  else:
3669
3681
  s3 = self.masked_mean(
3670
- s3, vmask, axis=1, rank=j2
3682
+ s3, vmask, axis=3, rank=j2
3671
3683
  ) # [Nbatch, Nmask, Norient3, Norient2]
3672
3684
  return s3
3673
3685
 
@@ -3682,11 +3694,11 @@ class funct(FOC.FoCUS):
3682
3694
  return_data=False,
3683
3695
  ):
3684
3696
  #### Simplify notations
3685
- M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
3697
+ M1 = M1convPsi_dic[j1] # [Nbatch, Norient3, Norient1, Npix_j3]
3686
3698
 
3687
3699
  # Auto or Cross coefficients
3688
3700
  if M2convPsi_dic is None: # Auto
3689
- M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
3701
+ M2 = M1convPsi_dic[j2] # [Nbatch, Norient3, Norient2, Npix_j3]
3690
3702
  else: # Cross
3691
3703
  M2 = M2convPsi_dic[j2]
3692
3704
 
@@ -3695,9 +3707,9 @@ class funct(FOC.FoCUS):
3695
3707
  if self.use_1D:
3696
3708
  s4 = M1 * self.backend.bk_conjugate(M2)
3697
3709
  else:
3698
- s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
3699
- self.backend.bk_expand_dims(M2, -1)
3700
- ) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
3710
+ s4 = self.backend.bk_expand_dims(M1, -4) * self.backend.bk_conjugate(
3711
+ self.backend.bk_expand_dims(M2, -3)
3712
+ ) # [Nbatch, Norient3, Norient2, Norient1,Npix_j3]
3701
3713
 
3702
3714
  ### Apply the mask and sum over pixels
3703
3715
  if return_data:
@@ -3705,12 +3717,12 @@ class funct(FOC.FoCUS):
3705
3717
  else:
3706
3718
  if calc_var:
3707
3719
  s4, vs4 = self.masked_mean(
3708
- s4, vmask, axis=1, rank=j2, calc_var=True
3720
+ s4, vmask, axis=4, rank=j2, calc_var=True
3709
3721
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3710
3722
  return s4, vs4
3711
3723
  else:
3712
3724
  s4 = self.masked_mean(
3713
- s4, vmask, axis=1, rank=j2
3725
+ s4, vmask, axis=4, rank=j2
3714
3726
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3715
3727
  return s4
3716
3728
 
@@ -6045,14 +6057,11 @@ class funct(FOC.FoCUS):
6045
6057
  image2=None,
6046
6058
  mask=None,
6047
6059
  norm=None,
6048
- Auto=True,
6049
6060
  cmat=None,
6050
6061
  cmat2=None,
6051
6062
  ):
6052
6063
 
6053
- res = self.eval(
6054
- image1, image2=image2, mask=mask, Auto=Auto, cmat=cmat, cmat2=cmat2
6055
- )
6064
+ res = self.eval(image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2)
6056
6065
  return res.S0, res.S2, res.S1, res.S3, res.S4, res.S3P
6057
6066
 
6058
6067
  def eval_fast(
@@ -6061,12 +6070,11 @@ class funct(FOC.FoCUS):
6061
6070
  image2=None,
6062
6071
  mask=None,
6063
6072
  norm=None,
6064
- Auto=True,
6065
6073
  cmat=None,
6066
6074
  cmat2=None,
6067
6075
  ):
6068
6076
  s0, s2, s1, s3, s4, s3p = self.eval_comp_fast(
6069
- image1, image2=image2, mask=mask, Auto=Auto, cmat=cmat, cmat2=cmat2
6077
+ image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
6070
6078
  )
6071
6079
  return scat_cov(
6072
6080
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
@@ -6102,6 +6110,15 @@ class funct(FOC.FoCUS):
6102
6110
  if edge:
6103
6111
  self.purge_edge_mask()
6104
6112
 
6113
+ def The_loss_ref_image(u, scat_operator, args):
6114
+ input_image = args[0]
6115
+ mask = args[1]
6116
+
6117
+ loss = 1e-3 * scat_operator.backend.bk_reduce_mean(
6118
+ scat_operator.backend.bk_square(mask * (input_image - u))
6119
+ )
6120
+ return loss
6121
+
6105
6122
  def The_loss(u, scat_operator, args):
6106
6123
  ref = args[0]
6107
6124
  sref = args[1]
@@ -6339,7 +6356,18 @@ class funct(FOC.FoCUS):
6339
6356
  The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6340
6357
  )
6341
6358
 
6342
- sy = synthe.Synthesis([loss])
6359
+ if input_image is not None:
6360
+ # define a loss to minimize
6361
+ loss_input = synthe.Loss(
6362
+ The_loss_ref_image,
6363
+ self,
6364
+ self.backend.bk_cast(l_input_image[k]),
6365
+ self.backend.bk_cast(l_in_mask[k]),
6366
+ )
6367
+
6368
+ sy = synthe.Synthesis([loss]) # ,loss_input])
6369
+ else:
6370
+ sy = synthe.Synthesis([loss])
6343
6371
 
6344
6372
  # initialize the synthesised map
6345
6373
  if self.use_2D:
foscat/scat_cov_map.py CHANGED
@@ -29,7 +29,6 @@ class funct(scat.funct):
29
29
  image2=None,
30
30
  mask=None,
31
31
  norm=None,
32
- Auto=True,
33
32
  calc_var=False,
34
33
  out_nside=None,
35
34
  ):
@@ -38,7 +37,6 @@ class funct(scat.funct):
38
37
  image2=image2,
39
38
  mask=mask,
40
39
  norm=norm,
41
- Auto=Auto,
42
40
  calc_var=calc_var,
43
41
  out_nside=out_nside,
44
42
  )
foscat/scat_cov_map2D.py CHANGED
@@ -23,10 +23,10 @@ class funct(scat.funct):
23
23
  super().__init__(use_2D=True, return_data=True, *args, **kwargs)
24
24
 
25
25
  def eval(
26
- self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False
26
+ self, image1, image2=None, mask=None, norm=None, calc_var=False, Jmax=None
27
27
  ):
28
28
  r = super().eval(
29
- image1, image2=image2, mask=mask, norm=norm, Auto=Auto, calc_var=calc_var
29
+ image1, image2=image2, mask=mask, norm=norm, calc_var=calc_var, Jmax=Jmax
30
30
  )
31
31
  return scat_cov_map(
32
32
  r.S2, r.S0, r.S3, r.S4, S1=r.S1, S3P=r.S3P, backend=r.backend
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: foscat
3
- Version: 2025.3.0
3
+ Version: 2025.5.2
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,10 +1,10 @@
1
- foscat/BkBase.py,sha256=_iszgMdVIVEB47EBxNt5xemsdaKzsNFPStDF00M_-Ng,21281
1
+ foscat/BkBase.py,sha256=TEhfqUpIOh_bGBCyQfRCs0yjKdhgELjFpvq_QiouX5A,21514
2
2
  foscat/BkNumpy.py,sha256=zRldS_-L6A7y1zDzEPZXQntuw3Paw2zHZowhD43FHRs,10589
3
- foscat/BkTensorflow.py,sha256=N5TBacuyFB1-qGTi2kOc8zbgWzj5lVRRN47uZJpJJ10,15713
4
- foscat/BkTorch.py,sha256=011L9WCBtrRzV1jfGWCYOMSkt1IJ0PfEO82NnrIYbAc,16648
3
+ foscat/BkTensorflow.py,sha256=K2s3xYVMHqLlTyApQpeKf9dc3hbRv8EqtCA_gE1bRQA,19958
4
+ foscat/BkTorch.py,sha256=7TwJxSw_XGyJyvombHrfsT7LFm4RNHt1sDQD53Fbll0,18105
5
5
  foscat/CNN.py,sha256=j0F2a4Xf3LijhyD_WVZ6Eg_IjGuXw3ddH6Iudj1xVaw,4874
6
6
  foscat/CircSpline.py,sha256=CXi49FxF8ZoeZ17Ua8c1AZXe2B5ICEC9aCXb97atB3s,4028
7
- foscat/FoCUS.py,sha256=iCWuhQqYQ1ub3F0flO2iVuMoN7gCDd1oZ79SIH9-oww,108768
7
+ foscat/FoCUS.py,sha256=kKjuLvdCrR8VHjnxzdgYSCpe8LnNlHaDOV6Riu8UEO4,95289
8
8
  foscat/GCNN.py,sha256=5RV-FKuvqbD-k99TwiM4CttM2LMZE21WD0IK0j5Mkko,7599
9
9
  foscat/Softmax.py,sha256=aBLQauoG0q2SJYPotV6U-cxAhsJcspWHNRWdnA_nAiQ,2854
10
10
  foscat/Spline1D.py,sha256=rKzzenduaZZ-yBDJd35it6Gyrj1spqb7hoIaUgISPzY,2983
@@ -18,13 +18,13 @@ foscat/loss_backend_torch.py,sha256=k3z18Dj3SaLKK6ZIKcm7GO4U_YKYVP6LtHG1aIbxkYk,
18
18
  foscat/scat.py,sha256=qGYiBIysPt65MdmF07WWA4piVlTfA9-lFDTaicnqC2w,72822
19
19
  foscat/scat1D.py,sha256=W5Uu6wdQ4ZsFKXpof0f1OBl-1wjJmW7ruvddRWxe7uM,53726
20
20
  foscat/scat2D.py,sha256=boKj0ASqMMSy7uQLK6hPniG87m3hZGJBYBiq5v8F9IQ,532
21
- foscat/scat_cov.py,sha256=ZOFDWNC8q04N6Tvpe7RxSWlRgJ8jgsIyPJ_EJ39CXOg,258297
21
+ foscat/scat_cov.py,sha256=tpuYqyPwYdG8vjm4uwYUjajeHzxYwCto-Bky475oPXA,259203
22
22
  foscat/scat_cov1D.py,sha256=XOxsZZ5TYq8f34i2tUgIfzyaqaTDlICB3HzD2l_puro,531
23
23
  foscat/scat_cov2D.py,sha256=pAm0fKw8wyXram0TFbtw8tGcc8QPKuPXpQk0kh10r4U,7078
24
- foscat/scat_cov_map.py,sha256=Swt39-nYEaQkBzyX4EOAQBvUuYQpERzJ-uVxSWS2b-Y,2911
25
- foscat/scat_cov_map2D.py,sha256=FqF45FBcoiQbvuVsrLWUIPRUc95GsKsrnH6fKzB3GlE,2841
26
- foscat-2025.3.0.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
27
- foscat-2025.3.0.dist-info/METADATA,sha256=hfPyoLACrvhbOjv70ve--tjL5CseBXEHMFvC3CNzSr0,7215
28
- foscat-2025.3.0.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
29
- foscat-2025.3.0.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
30
- foscat-2025.3.0.dist-info/RECORD,,
24
+ foscat/scat_cov_map.py,sha256=9MzpwT2g9S3dmnjHEMK7PPLQ27oGQg2VFVsP_TDUU5E,2869
25
+ foscat/scat_cov_map2D.py,sha256=1dS4P1KHqZYkYCLA1sYpPSZulJrCTd_2eL8HFOjlcz4,2841
26
+ foscat-2025.5.2.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
27
+ foscat-2025.5.2.dist-info/METADATA,sha256=sbR5VT1kq7d1qrLd74c-CdprtKhpkC6eHQ1nE24nqtI,7215
28
+ foscat-2025.5.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
29
+ foscat-2025.5.2.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
30
+ foscat-2025.5.2.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5