foscat 2025.5.0__py3-none-any.whl → 2025.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -2378,16 +2378,16 @@ class funct(FOC.FoCUS):
2378
2378
  # of cosine and sine contributions using all channels
2379
2379
  angles = self.backend.bk_cast(
2380
2380
  (2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
2381
- 1, 1, self.NORIENT
2381
+ 1, self.NORIENT, 1
2382
2382
  )
2383
2383
  ) # shape: (NORIENT,)
2384
2384
 
2385
2385
  # we use cosines and sines as weights for sim
2386
2386
  weighted_cos = self.backend.bk_reduce_mean(
2387
- sim * self.backend.bk_cos(angles), axis=-1
2387
+ sim * self.backend.bk_cos(angles), axis=-2
2388
2388
  )
2389
2389
  weighted_sin = self.backend.bk_reduce_mean(
2390
- sim * self.backend.bk_sin(angles), axis=-1
2390
+ sim * self.backend.bk_sin(angles), axis=-2
2391
2391
  )
2392
2392
  # For simplicity, take first element of the batch
2393
2393
  cc = weighted_cos[0]
@@ -2423,43 +2423,46 @@ class funct(FOC.FoCUS):
2423
2423
  w1 = np.sin(delta * np.pi / 2) ** 2
2424
2424
 
2425
2425
  # build rotation matrix
2426
- mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
2427
- lidx = np.arange(sim.shape[1])
2426
+ mat = np.zeros([self.NORIENT * self.NORIENT, sim.shape[2]])
2427
+ lidx = np.arange(sim.shape[2])
2428
2428
  for ell in range(self.NORIENT):
2429
2429
  # Instead of simple linear weights, we use the cosine weights w0 and w1.
2430
2430
  col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
2431
2431
  col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
2432
- mat[lidx, col0] = w0
2433
- mat[lidx, col1] = w1
2434
2432
 
2435
- cmat[k] = self.backend.bk_cast(mat.astype("complex64"))
2433
+ mat[col0, lidx] = w0
2434
+ mat[col1, lidx] = w1
2435
+
2436
+ cmat[k] = self.backend.bk_cast(mat[None, ...].astype("complex64"))
2436
2437
 
2437
2438
  # do same modifications for mat2
2438
2439
  mat2 = np.zeros(
2439
- [k + 1, sim.shape[1], self.NORIENT, self.NORIENT * self.NORIENT]
2440
+ [k + 1, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[2]]
2440
2441
  )
2441
2442
 
2442
2443
  for k2 in range(k + 1):
2443
- tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=-1)
2444
+
2445
+ tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=1)
2444
2446
 
2445
2447
  sim2 = self.backend.bk_reduce_sum(
2446
2448
  self.backend.bk_reshape(
2447
2449
  self.backend.bk_cast(
2448
- mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
2450
+ mat.reshape(1, self.NORIENT * self.NORIENT, mat.shape[1])
2449
2451
  )
2450
2452
  * tmp2,
2451
- [sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
2453
+ [sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[1]],
2452
2454
  ),
2453
- 2,
2455
+ 1,
2454
2456
  )
2455
2457
 
2456
- sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
2458
+ sim2 = self.backend.bk_abs(self.convol(sim2, axis=-1))
2457
2459
 
2460
+ angles = self.backend.bk_reshape(angles, [1, self.NORIENT, 1, 1])
2458
2461
  weighted_cos2 = self.backend.bk_reduce_mean(
2459
- sim2 * self.backend.bk_cos(angles), axis=-1
2462
+ sim2 * self.backend.bk_cos(angles), axis=1
2460
2463
  )
2461
2464
  weighted_sin2 = self.backend.bk_reduce_mean(
2462
- sim2 * self.backend.bk_sin(angles), axis=-1
2465
+ sim2 * self.backend.bk_sin(angles), axis=1
2463
2466
  )
2464
2467
 
2465
2468
  cc2 = weighted_cos2[0]
@@ -2467,14 +2470,14 @@ class funct(FOC.FoCUS):
2467
2470
 
2468
2471
  if smooth_scale > 0:
2469
2472
  for m in range(smooth_scale):
2470
- if cc2.shape[0] > 12:
2471
- cc2, _ = self.ud_grade_2(self.smooth(cc2))
2472
- ss2, _ = self.ud_grade_2(self.smooth(ss2))
2473
+ if cc2.shape[1] > 12:
2474
+ cc2, _ = self.ud_grade_2(self.smooth(cc2, axis=1), axis=1)
2475
+ ss2, _ = self.ud_grade_2(self.smooth(ss2, axis=1), axis=1)
2473
2476
 
2474
- if cc2.shape[0] != sim.shape[1]:
2475
- ll_nside = int(np.sqrt(sim.shape[1] // 12))
2476
- cc2 = self.up_grade(cc2, ll_nside)
2477
- ss2 = self.up_grade(ss2, ll_nside)
2477
+ if cc2.shape[1] != sim.shape[2]:
2478
+ ll_nside = int(np.sqrt(sim.shape[2] // 12))
2479
+ cc2 = self.up_grade(cc2, ll_nside, axis=1)
2480
+ ss2 = self.up_grade(ss2, ll_nside, axis=1)
2478
2481
 
2479
2482
  if self.BACKEND == "numpy":
2480
2483
  phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
@@ -2492,17 +2495,18 @@ class funct(FOC.FoCUS):
2492
2495
  delta2 = phase2_scaled - iph2
2493
2496
  w0_2 = np.cos(delta2 * np.pi / 2) ** 2
2494
2497
  w1_2 = np.sin(delta2 * np.pi / 2) ** 2
2495
- lidx = np.arange(sim.shape[1])
2498
+ lidx = np.arange(sim.shape[2])
2496
2499
 
2497
2500
  for m in range(self.NORIENT):
2498
2501
  for ell in range(self.NORIENT):
2499
- col0 = self.NORIENT * ((ell + iph2[:, m]) % self.NORIENT) + ell
2500
- col1 = (
2501
- self.NORIENT * ((ell + iph2[:, m] + 1) % self.NORIENT) + ell
2502
- )
2503
- mat2[k2, lidx, m, col0] = w0_2[:, m]
2504
- mat2[k2, lidx, m, col1] = w1_2[:, m]
2505
- cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2502
+ col0 = self.NORIENT * ((ell + iph2[m]) % self.NORIENT) + ell
2503
+ col1 = self.NORIENT * ((ell + iph2[m] + 1) % self.NORIENT) + ell
2504
+ mat2[k2, col0, m, lidx] = w0_2[m, lidx]
2505
+ mat2[k2, col1, m, lidx] = w1_2[m, lidx]
2506
+
2507
+ cmat2[k] = self.backend.bk_cast(
2508
+ mat2[0 : k + 1, None, ...].astype("complex64")
2509
+ )
2506
2510
 
2507
2511
  if k < l_nside - 1:
2508
2512
  tmp, _ = self.ud_grade_2(tmp, axis=1)
@@ -2620,6 +2624,9 @@ class funct(FOC.FoCUS):
2620
2624
  x1 = im_shape[1]
2621
2625
  x2 = im_shape[2]
2622
2626
  J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
2627
+ if J == 0:
2628
+ print("Use of too small 2D domain does not work J_max=", J)
2629
+ return None
2623
2630
  elif self.use_1D:
2624
2631
  if len(image1.shape) == 2:
2625
2632
  npix = int(im_shape[1]) # Number of pixels
@@ -2842,21 +2849,25 @@ class funct(FOC.FoCUS):
2842
2849
  ### Make the convolution I1 * Psi_j3
2843
2850
  conv1 = self.convol(
2844
2851
  I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2845
- ) # [Nbatch, Npix_j3, Norient3]
2852
+ ) # [Nbatch, Norient3 , Npix_j3]
2853
+
2846
2854
  if cmat is not None:
2847
- tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2855
+
2856
+ tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-2)
2857
+
2848
2858
  conv1 = self.backend.bk_reduce_sum(
2849
2859
  self.backend.bk_reshape(
2850
2860
  cmat[j3] * tmp2,
2851
- [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
2861
+ [tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j3].shape[2]],
2852
2862
  ),
2853
- 2,
2863
+ 1,
2854
2864
  )
2855
2865
 
2856
2866
  ### Take the module M1 = |I1 * Psi_j3|
2857
2867
  M1_square = conv1 * self.backend.bk_conjugate(
2858
2868
  conv1
2859
- ) # [Nbatch, Npix_j3, Norient3]
2869
+ ) # [Nbatch, Norient3, Npix_j3]
2870
+
2860
2871
  M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
2861
2872
  # Store M1_j3 in a dictionary
2862
2873
  M1_dic[j3] = M1
@@ -2871,10 +2882,10 @@ class funct(FOC.FoCUS):
2871
2882
  else:
2872
2883
  if calc_var:
2873
2884
  s2, vs2 = self.masked_mean(
2874
- M1_square, vmask, axis=1, rank=j3, calc_var=True
2885
+ M1_square, vmask, axis=2, rank=j3, calc_var=True
2875
2886
  )
2876
2887
  else:
2877
- s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
2888
+ s2 = self.masked_mean(M1_square, vmask, axis=2, rank=j3)
2878
2889
 
2879
2890
  if cond_init_P1_dic:
2880
2891
  # We fill P1_dic with S2 for normalisation of S3 and S4
@@ -2890,9 +2901,9 @@ class funct(FOC.FoCUS):
2890
2901
  s2,
2891
2902
  [
2892
2903
  s2.shape[0],
2904
+ s2.shape[2],
2893
2905
  12 * out_nside**2,
2894
2906
  (nside_j3 // out_nside) ** 2,
2895
- s2.shape[2],
2896
2907
  ],
2897
2908
  ),
2898
2909
  2,
@@ -2918,11 +2929,11 @@ class funct(FOC.FoCUS):
2918
2929
  else:
2919
2930
  if calc_var:
2920
2931
  s1, vs1 = self.masked_mean(
2921
- M1, vmask, axis=1, rank=j3, calc_var=True
2932
+ M1, vmask, axis=2, rank=j3, calc_var=True
2922
2933
  ) # [Nbatch, Nmask, Norient3]
2923
2934
  else:
2924
2935
  s1 = self.masked_mean(
2925
- M1, vmask, axis=1, rank=j3
2936
+ M1, vmask, axis=2, rank=j3
2926
2937
  ) # [Nbatch, Nmask, Norient3]
2927
2938
 
2928
2939
  if return_data:
@@ -2932,9 +2943,9 @@ class funct(FOC.FoCUS):
2932
2943
  s1,
2933
2944
  [
2934
2945
  s1.shape[0],
2946
+ s1.shape[2],
2935
2947
  12 * out_nside**2,
2936
2948
  (nside_j3 // out_nside) ** 2,
2937
- s1.shape[2],
2938
2949
  ],
2939
2950
  ),
2940
2951
  2,
@@ -2956,21 +2967,21 @@ class funct(FOC.FoCUS):
2956
2967
  else: # Cross
2957
2968
  ### Make the convolution I2 * Psi_j3
2958
2969
  conv2 = self.convol(
2959
- I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2970
+ I2, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
2960
2971
  ) # [Nbatch, Npix_j3, Norient3]
2961
2972
  if cmat is not None:
2962
- tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2973
+ tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-2)
2963
2974
  conv2 = self.backend.bk_reduce_sum(
2964
2975
  self.backend.bk_reshape(
2965
2976
  cmat[j3] * tmp2,
2966
2977
  [
2967
2978
  tmp2.shape[0],
2968
- cmat[j3].shape[0],
2969
2979
  self.NORIENT,
2970
2980
  self.NORIENT,
2981
+ cmat[j3].shape[2],
2971
2982
  ],
2972
2983
  ),
2973
- 2,
2984
+ 1,
2974
2985
  )
2975
2986
  ### Take the module M2 = |I2 * Psi_j3|
2976
2987
  M2_square = conv2 * self.backend.bk_conjugate(
@@ -2990,17 +3001,17 @@ class funct(FOC.FoCUS):
2990
3001
  else:
2991
3002
  if calc_var:
2992
3003
  p1, vp1 = self.masked_mean(
2993
- M1_square, vmask, axis=1, rank=j3, calc_var=True
3004
+ M1_square, vmask, axis=2, rank=j3, calc_var=True
2994
3005
  ) # [Nbatch, Nmask, Norient3]
2995
3006
  p2, vp2 = self.masked_mean(
2996
- M2_square, vmask, axis=1, rank=j3, calc_var=True
3007
+ M2_square, vmask, axis=2, rank=j3, calc_var=True
2997
3008
  ) # [Nbatch, Nmask, Norient3]
2998
3009
  else:
2999
3010
  p1 = self.masked_mean(
3000
- M1_square, vmask, axis=1, rank=j3
3011
+ M1_square, vmask, axis=2, rank=j3
3001
3012
  ) # [Nbatch, Nmask, Norient3]
3002
3013
  p2 = self.masked_mean(
3003
- M2_square, vmask, axis=1, rank=j3
3014
+ M2_square, vmask, axis=2, rank=j3
3004
3015
  ) # [Nbatch, Nmask, Norient3]
3005
3016
  # We fill P1_dic with S2 for normalisation of S3 and S4
3006
3017
  P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
@@ -3016,10 +3027,10 @@ class funct(FOC.FoCUS):
3016
3027
  else:
3017
3028
  if calc_var:
3018
3029
  s2, vs2 = self.masked_mean(
3019
- s2, vmask, axis=1, rank=j3, calc_var=True
3030
+ s2, vmask, axis=2, rank=j3, calc_var=True
3020
3031
  )
3021
3032
  else:
3022
- s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
3033
+ s2 = self.masked_mean(s2, vmask, axis=2, rank=j3)
3023
3034
 
3024
3035
  if return_data:
3025
3036
  if out_nside is not None and out_nside < nside_j3:
@@ -3028,9 +3039,9 @@ class funct(FOC.FoCUS):
3028
3039
  s2,
3029
3040
  [
3030
3041
  s2.shape[0],
3042
+ s2.shape[2],
3031
3043
  12 * out_nside**2,
3032
3044
  (nside_j3 // out_nside) ** 2,
3033
- s2.shape[2],
3034
3045
  ],
3035
3046
  ),
3036
3047
  2,
@@ -3040,7 +3051,7 @@ class funct(FOC.FoCUS):
3040
3051
 
3041
3052
  ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
3042
3053
  s2 = self.backend.bk_real(s2)
3043
-
3054
+
3044
3055
  ### Normalize S2_cross
3045
3056
  if norm == "auto":
3046
3057
  s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
@@ -3061,11 +3072,11 @@ class funct(FOC.FoCUS):
3061
3072
  else:
3062
3073
  if calc_var:
3063
3074
  s1, vs1 = self.masked_mean(
3064
- MX, vmask, axis=1, rank=j3, calc_var=True
3075
+ MX, vmask, axis=2, rank=j3, calc_var=True
3065
3076
  ) # [Nbatch, Nmask, Norient3]
3066
3077
  else:
3067
3078
  s1 = self.masked_mean(
3068
- MX, vmask, axis=1, rank=j3
3079
+ MX, vmask, axis=2, rank=j3
3069
3080
  ) # [Nbatch, Nmask, Norient3]
3070
3081
  if return_data:
3071
3082
  if out_nside is not None and out_nside < nside_j3:
@@ -3074,9 +3085,9 @@ class funct(FOC.FoCUS):
3074
3085
  s1,
3075
3086
  [
3076
3087
  s1.shape[0],
3088
+ s1.shape[2],
3077
3089
  12 * out_nside**2,
3078
3090
  (nside_j3 // out_nside) ** 2,
3079
- s1.shape[2],
3080
3091
  ],
3081
3092
  ),
3082
3093
  2,
@@ -3487,19 +3498,19 @@ class funct(FOC.FoCUS):
3487
3498
  for j2 in range(0, j3 + 1): # j2 =< j3
3488
3499
  ### Dictionary M1_dic[j2]
3489
3500
  M1_smooth = self.smooth(
3490
- M1_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3501
+ M1_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3491
3502
  ) # [Nbatch, Npix_j3, Norient3]
3492
3503
  M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3493
- M1_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3504
+ M1_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3494
3505
  ) # [Nbatch, Npix_j3, Norient3]
3495
3506
 
3496
3507
  ### Dictionary M2_dic[j2]
3497
3508
  if cross:
3498
3509
  M2_smooth = self.smooth(
3499
- M2_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3510
+ M2_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3500
3511
  ) # [Nbatch, Npix_j3, Norient3]
3501
3512
  M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3502
- M2_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3513
+ M2_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3503
3514
  ) # [Nbatch, Npix_j3, Norient3]
3504
3515
  ### Mask
3505
3516
  vmask, new_cell_ids_j3 = self.ud_grade_2(
@@ -3624,51 +3635,51 @@ class funct(FOC.FoCUS):
3624
3635
  cs3, ss3: real and imag parts of S3 coeff
3625
3636
  """
3626
3637
  ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3627
- # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
3638
+ # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Norient3, Npix_j3]
3628
3639
  MconvPsi = self.convol(
3629
- M_dic[j2], axis=1, cell_ids=cell_ids, nside=nside_j2
3630
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
3640
+ M_dic[j2], axis=2, cell_ids=cell_ids, nside=nside_j2
3641
+ ) # [Nbatch, Norient3, Norient2, Npix_j3]
3642
+
3631
3643
  if cmat2 is not None:
3632
- tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
3644
+ tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-3)
3633
3645
  MconvPsi = self.backend.bk_reduce_sum(
3634
3646
  self.backend.bk_reshape(
3635
3647
  cmat2[j3][j2] * tmp2,
3636
3648
  [
3637
3649
  tmp2.shape[0],
3638
- cmat2[j3].shape[1],
3639
3650
  self.NORIENT,
3640
3651
  self.NORIENT,
3641
3652
  self.NORIENT,
3653
+ cmat2[j3][j2].shape[3],
3642
3654
  ],
3643
3655
  ),
3644
- 3,
3656
+ 1,
3645
3657
  )
3646
3658
 
3647
3659
  # Store it so we can use it in S4 computation
3648
- MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
3660
+ MconvPsi_dic[j2] = MconvPsi # [Nbatch, Norient3, Norient2, Npix_j3]
3649
3661
 
3650
3662
  ### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
3651
3663
  # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
3652
- # cconv, sconv are [Nbatch, Npix_j3, Norient3]
3664
+ # cconv, sconv are [Nbatch, Norient3, Npix_j3]
3653
3665
  if self.use_1D:
3654
3666
  s3 = conv * self.backend.bk_conjugate(MconvPsi)
3655
3667
  else:
3656
- s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
3668
+ s3 = self.backend.bk_expand_dims(conv, -3) * self.backend.bk_conjugate(
3657
3669
  MconvPsi
3658
- ) # [Nbatch, Npix_j3, Norient3, Norient2]
3659
-
3670
+ ) # [Nbatch, Norient3, Norient2, Npix_j3]
3660
3671
  ### Apply the mask [Nmask, Npix_j3] and sum over pixels
3661
3672
  if return_data:
3662
3673
  return s3
3663
3674
  else:
3664
3675
  if calc_var:
3665
3676
  s3, vs3 = self.masked_mean(
3666
- s3, vmask, axis=1, rank=j2, calc_var=True
3677
+ s3, vmask, axis=3, rank=j2, calc_var=True
3667
3678
  ) # [Nbatch, Nmask, Norient3, Norient2]
3668
3679
  return s3, vs3
3669
3680
  else:
3670
3681
  s3 = self.masked_mean(
3671
- s3, vmask, axis=1, rank=j2
3682
+ s3, vmask, axis=3, rank=j2
3672
3683
  ) # [Nbatch, Nmask, Norient3, Norient2]
3673
3684
  return s3
3674
3685
 
@@ -3683,11 +3694,11 @@ class funct(FOC.FoCUS):
3683
3694
  return_data=False,
3684
3695
  ):
3685
3696
  #### Simplify notations
3686
- M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
3697
+ M1 = M1convPsi_dic[j1] # [Nbatch, Norient3, Norient1, Npix_j3]
3687
3698
 
3688
3699
  # Auto or Cross coefficients
3689
3700
  if M2convPsi_dic is None: # Auto
3690
- M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
3701
+ M2 = M1convPsi_dic[j2] # [Nbatch, Norient3, Norient2, Npix_j3]
3691
3702
  else: # Cross
3692
3703
  M2 = M2convPsi_dic[j2]
3693
3704
 
@@ -3696,9 +3707,9 @@ class funct(FOC.FoCUS):
3696
3707
  if self.use_1D:
3697
3708
  s4 = M1 * self.backend.bk_conjugate(M2)
3698
3709
  else:
3699
- s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
3700
- self.backend.bk_expand_dims(M2, -1)
3701
- ) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
3710
+ s4 = self.backend.bk_expand_dims(M1, -4) * self.backend.bk_conjugate(
3711
+ self.backend.bk_expand_dims(M2, -3)
3712
+ ) # [Nbatch, Norient3, Norient2, Norient1,Npix_j3]
3702
3713
 
3703
3714
  ### Apply the mask and sum over pixels
3704
3715
  if return_data:
@@ -3706,12 +3717,12 @@ class funct(FOC.FoCUS):
3706
3717
  else:
3707
3718
  if calc_var:
3708
3719
  s4, vs4 = self.masked_mean(
3709
- s4, vmask, axis=1, rank=j2, calc_var=True
3720
+ s4, vmask, axis=4, rank=j2, calc_var=True
3710
3721
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3711
3722
  return s4, vs4
3712
3723
  else:
3713
3724
  s4 = self.masked_mean(
3714
- s4, vmask, axis=1, rank=j2
3725
+ s4, vmask, axis=4, rank=j2
3715
3726
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3716
3727
  return s4
3717
3728
 
@@ -6050,9 +6061,7 @@ class funct(FOC.FoCUS):
6050
6061
  cmat2=None,
6051
6062
  ):
6052
6063
 
6053
- res = self.eval(
6054
- image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
6055
- )
6064
+ res = self.eval(image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2)
6056
6065
  return res.S0, res.S2, res.S1, res.S3, res.S4, res.S3P
6057
6066
 
6058
6067
  def eval_fast(
@@ -6065,12 +6074,43 @@ class funct(FOC.FoCUS):
6065
6074
  cmat2=None,
6066
6075
  ):
6067
6076
  s0, s2, s1, s3, s4, s3p = self.eval_comp_fast(
6068
- image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
6077
+ image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
6069
6078
  )
6070
6079
  return scat_cov(
6071
6080
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
6072
6081
  )
6073
-
6082
+ def calc_matrix_orientation(self,noise_map,image2=None):
6083
+ # Décalage circulaire par matrice de permutation
6084
+ def circ_shift_matrix(N,k):
6085
+ return np.roll(np.eye(N), shift=-k, axis=1)
6086
+ Norient = self.NORIENT
6087
+ im=self.convol(noise_map)
6088
+ if image2 is None:
6089
+ mm=np.mean(abs(self.backend.to_numpy(im)),0)
6090
+ else:
6091
+ im2=self.convol(self.backend.bk_cast(image2))
6092
+ mm=np.mean(self.backend.to_numpy(
6093
+ self.backend.bk_L1(im*self.backend.bk_conjugate(im2))).real,0)
6094
+
6095
+ Norient=mm.shape[0]
6096
+ xx=np.cos(np.arange(Norient)/Norient*2*np.pi)
6097
+ yy=np.sin(np.arange(Norient)/Norient*2*np.pi)
6098
+
6099
+ a=np.sum(mm*xx[:,None],0)
6100
+ b=np.sum(mm*yy[:,None],0)
6101
+
6102
+ o=np.fmod(Norient*np.arctan2(-b,a)/(2*np.pi)+Norient,Norient)
6103
+ xx=np.arange(Norient)
6104
+ alpha = o[None,:]-xx[:,None]
6105
+ beta = np.fmod(1+o[None,:]-xx[:,None],Norient)
6106
+ alpha=(1-alpha)*(alpha<1)*(alpha>0)+beta*(beta<1)*(beta>0)
6107
+
6108
+ m=np.zeros([Norient,Norient,mm.shape[1]])
6109
+ for k in range(Norient):
6110
+ m[k,:,:]=np.roll(alpha,k,0)
6111
+ #m=np.mean(m,0)
6112
+ return self.backend.bk_cast(m)
6113
+
6074
6114
  def synthesis(
6075
6115
  self,
6076
6116
  image_target,
@@ -6104,12 +6144,12 @@ class funct(FOC.FoCUS):
6104
6144
  def The_loss_ref_image(u, scat_operator, args):
6105
6145
  input_image = args[0]
6106
6146
  mask = args[1]
6107
-
6108
- loss = 1E-3*scat_operator.backend.bk_reduce_mean(
6109
- scat_operator.backend.bk_square(mask*(input_image - u))
6147
+
6148
+ loss = 1e-3 * scat_operator.backend.bk_reduce_mean(
6149
+ scat_operator.backend.bk_square(mask * (input_image - u))
6110
6150
  )
6111
6151
  return loss
6112
-
6152
+
6113
6153
  def The_loss(u, scat_operator, args):
6114
6154
  ref = args[0]
6115
6155
  sref = args[1]
@@ -6346,14 +6386,17 @@ class funct(FOC.FoCUS):
6346
6386
  loss = synthe.Loss(
6347
6387
  The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6348
6388
  )
6349
-
6389
+
6350
6390
  if input_image is not None:
6351
6391
  # define a loss to minimize
6352
- loss_input = synthe.Loss(The_loss_ref_image, self,
6353
- self.backend.bk_cast(l_input_image[k]),
6354
- self.backend.bk_cast(l_in_mask[k]))
6355
-
6356
- sy = synthe.Synthesis([loss]) #,loss_input])
6392
+ loss_input = synthe.Loss(
6393
+ The_loss_ref_image,
6394
+ self,
6395
+ self.backend.bk_cast(l_input_image[k]),
6396
+ self.backend.bk_cast(l_in_mask[k]),
6397
+ )
6398
+
6399
+ sy = synthe.Synthesis([loss]) # ,loss_input])
6357
6400
  else:
6358
6401
  sy = synthe.Synthesis([loss])
6359
6402
 
foscat/scat_cov_map2D.py CHANGED
@@ -23,10 +23,10 @@ class funct(scat.funct):
23
23
  super().__init__(use_2D=True, return_data=True, *args, **kwargs)
24
24
 
25
25
  def eval(
26
- self, image1, image2=None, mask=None, norm=None, Auto=True, calc_var=False
26
+ self, image1, image2=None, mask=None, norm=None, calc_var=False, Jmax=None
27
27
  ):
28
28
  r = super().eval(
29
- image1, image2=image2, mask=mask, norm=norm, Auto=Auto, calc_var=calc_var
29
+ image1, image2=image2, mask=mask, norm=norm, calc_var=calc_var, Jmax=Jmax
30
30
  )
31
31
  return scat_cov_map(
32
32
  r.S2, r.S0, r.S3, r.S4, S1=r.S1, S3P=r.S3P, backend=r.backend
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: foscat
3
- Version: 2025.5.0
3
+ Version: 2025.6.1
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -1,30 +1,31 @@
1
1
  foscat/BkBase.py,sha256=TEhfqUpIOh_bGBCyQfRCs0yjKdhgELjFpvq_QiouX5A,21514
2
2
  foscat/BkNumpy.py,sha256=zRldS_-L6A7y1zDzEPZXQntuw3Paw2zHZowhD43FHRs,10589
3
- foscat/BkTensorflow.py,sha256=N5TBacuyFB1-qGTi2kOc8zbgWzj5lVRRN47uZJpJJ10,15713
4
- foscat/BkTorch.py,sha256=011L9WCBtrRzV1jfGWCYOMSkt1IJ0PfEO82NnrIYbAc,16648
5
- foscat/CNN.py,sha256=j0F2a4Xf3LijhyD_WVZ6Eg_IjGuXw3ddH6Iudj1xVaw,4874
3
+ foscat/BkTensorflow.py,sha256=K2s3xYVMHqLlTyApQpeKf9dc3hbRv8EqtCA_gE1bRQA,19958
4
+ foscat/BkTorch.py,sha256=OUpa1ajRFxes9v_T6jK0gLlJXzn_DhQ3rlrf-yTSFcY,18126
5
+ foscat/CNN.py,sha256=gQ9V76wmcowo2BaNp5sJYcSDCVOjc18TS9cE6-qEUso,5153
6
6
  foscat/CircSpline.py,sha256=CXi49FxF8ZoeZ17Ua8c1AZXe2B5ICEC9aCXb97atB3s,4028
7
- foscat/FoCUS.py,sha256=FeSkBmjBTELZQn529SONSOuzKlYxWqRLpyTM8j-Ql_Y,108768
8
- foscat/GCNN.py,sha256=5RV-FKuvqbD-k99TwiM4CttM2LMZE21WD0IK0j5Mkko,7599
9
- foscat/Softmax.py,sha256=aBLQauoG0q2SJYPotV6U-cxAhsJcspWHNRWdnA_nAiQ,2854
7
+ foscat/FoCUS.py,sha256=Ke_h6g4Fqn62OpCkVxfr5zIDnT5hd-b1vvkm9aB3LhY,97364
8
+ foscat/GCNN.py,sha256=q7yWHCMJpP7-m3WvR3OQnp5taeYWaMxIY2hQ6SIb9gs,4487
9
+ foscat/Softmax.py,sha256=UDZGrTroYtmGEyokGUVpwNO_cgbICi9QVuRr8Yx52_k,2917
10
10
  foscat/Spline1D.py,sha256=rKzzenduaZZ-yBDJd35it6Gyrj1spqb7hoIaUgISPzY,2983
11
11
  foscat/Synthesis.py,sha256=tC5hvpam19QwDdvghVax7dA7gMgKA6ZtxQEcV9HjdC0,13824
12
12
  foscat/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
- foscat/alm.py,sha256=qZlsYj5HzV1EY9Fdzt0U8bemrZHZziaMOKZ55FU8foM,33806
13
+ foscat/alm.py,sha256=XkK4rFVRoO-oJpr74iBffKt7hdS_iJkR016IlYm10gQ,33832
14
14
  foscat/backend.py,sha256=l3aMwDyXP6jURMIvratFMGWCTcQpaR68KnUuuGDezqE,45418
15
15
  foscat/backend_tens.py,sha256=9Dp136m9frkclkwifJQLLbIpl3ETI3_txdPUZcKfuMw,1618
16
+ foscat/heal_NN.py,sha256=BXAqBEftvxVNOEtwo6xid6gFLrve8I-9jPQ1xI0HTl0,15648
16
17
  foscat/loss_backend_tens.py,sha256=dCOVN6faDtIpN3VO78HTmYP2i5fnFAf-Ddy5qVBlGrM,1783
17
18
  foscat/loss_backend_torch.py,sha256=k3z18Dj3SaLKK6ZIKcm7GO4U_YKYVP6LtHG1aIbxkYk,1627
18
19
  foscat/scat.py,sha256=qGYiBIysPt65MdmF07WWA4piVlTfA9-lFDTaicnqC2w,72822
19
20
  foscat/scat1D.py,sha256=W5Uu6wdQ4ZsFKXpof0f1OBl-1wjJmW7ruvddRWxe7uM,53726
20
21
  foscat/scat2D.py,sha256=boKj0ASqMMSy7uQLK6hPniG87m3hZGJBYBiq5v8F9IQ,532
21
- foscat/scat_cov.py,sha256=SMIhqspoe4vn6n_suNd1Npbs6eoXH7C45KupQta_h6M,258963
22
+ foscat/scat_cov.py,sha256=bdlaDoMmIDHMAQ7ZdkbtrIKr4Y-3w2xWO9e2_ZTeelk,260445
22
23
  foscat/scat_cov1D.py,sha256=XOxsZZ5TYq8f34i2tUgIfzyaqaTDlICB3HzD2l_puro,531
23
24
  foscat/scat_cov2D.py,sha256=pAm0fKw8wyXram0TFbtw8tGcc8QPKuPXpQk0kh10r4U,7078
24
25
  foscat/scat_cov_map.py,sha256=9MzpwT2g9S3dmnjHEMK7PPLQ27oGQg2VFVsP_TDUU5E,2869
25
- foscat/scat_cov_map2D.py,sha256=FqF45FBcoiQbvuVsrLWUIPRUc95GsKsrnH6fKzB3GlE,2841
26
- foscat-2025.5.0.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
27
- foscat-2025.5.0.dist-info/METADATA,sha256=lcHXnUQB5cZQUTHp8vgEUP0ZeG278dN5GIae_UqnQPs,7215
28
- foscat-2025.5.0.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
29
- foscat-2025.5.0.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
30
- foscat-2025.5.0.dist-info/RECORD,,
26
+ foscat/scat_cov_map2D.py,sha256=1dS4P1KHqZYkYCLA1sYpPSZulJrCTd_2eL8HFOjlcz4,2841
27
+ foscat-2025.6.1.dist-info/licenses/LICENSE,sha256=i0ukIr8ZUpkSY2sZaE9XZK-6vuSU5iG6IgX_3pjatP8,1505
28
+ foscat-2025.6.1.dist-info/METADATA,sha256=0l5yu57MiVAHs7Q_Ar1I0CsksYBclfPs4Uq5me-4BnI,7215
29
+ foscat-2025.6.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
30
+ foscat-2025.6.1.dist-info/top_level.txt,sha256=AGySXBBAlJgb8Tj8af6m_F-aiNg2zNTcybCUPVOKjAg,7
31
+ foscat-2025.6.1.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5