foscat 2025.6.3__py3-none-any.whl → 2025.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -44,6 +44,12 @@ class scat_cov:
44
44
  self.idx1 = None
45
45
  self.idx2 = None
46
46
  self.use_1D = use_1D
47
+ self.numel = self.backend.bk_len(s0)+ \
48
+ self.backend.bk_len(s1)+ \
49
+ self.backend.bk_len(s2)+ \
50
+ self.backend.bk_len(s3)+ \
51
+ self.backend.bk_len(s4)+ \
52
+ self.backend.bk_len(s3p)
47
53
 
48
54
  def numpy(self):
49
55
  if self.BACKEND == "numpy":
@@ -1589,20 +1595,20 @@ class scat_cov:
1589
1595
  def mean(self):
1590
1596
  if self.S1 is not None: # Auto
1591
1597
  return (
1592
- abs(self.get_np(self.S0)).mean()
1593
- + abs(self.get_np(self.S1)).mean()
1594
- + abs(self.get_np(self.S3)).mean()
1595
- + abs(self.get_np(self.S4)).mean()
1596
- + abs(self.get_np(self.S2)).mean()
1597
- ) / 4
1598
+ abs(self.get_np(self.S0)).sum()
1599
+ + abs(self.get_np(self.S1)).sum()
1600
+ + abs(self.get_np(self.S3)).sum()
1601
+ + abs(self.get_np(self.S4)).sum()
1602
+ + abs(self.get_np(self.S2)).sum()
1603
+ ) / self.numel
1598
1604
  else: # Cross
1599
1605
  return (
1600
- abs(self.get_np(self.S0)).mean()
1601
- + abs(self.get_np(self.S3)).mean()
1602
- + abs(self.get_np(self.S3P)).mean()
1603
- + abs(self.get_np(self.S4)).mean()
1604
- + abs(self.get_np(self.S2)).mean()
1605
- ) / 4
1606
+ abs(self.get_np(self.S0)).sum()
1607
+ + abs(self.get_np(self.S3)).sum()
1608
+ + abs(self.get_np(self.S3P)).sum()
1609
+ + abs(self.get_np(self.S4)).sum()
1610
+ + abs(self.get_np(self.S2)).sum()
1611
+ ) / self.numel
1606
1612
 
1607
1613
  def initdx(self, norient):
1608
1614
  idx1 = np.zeros([norient * norient], dtype="int")
@@ -2348,16 +2354,16 @@ class funct(FOC.FoCUS):
2348
2354
  )
2349
2355
 
2350
2356
  # compute local direction to make the statistical analysis more efficient
2351
- def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0):
2357
+ def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0,spin=0):
2352
2358
  tmp = im
2353
2359
  if image2 is not None:
2354
2360
  tmpi2 = image2
2355
2361
  if upscale:
2356
- l_nside = int(np.sqrt(tmp.shape[1] // 12))
2362
+ l_nside = int(np.sqrt(tmp.shape[-1] // 12))
2357
2363
  tmp = self.up_grade(tmp, l_nside * 2)
2358
2364
  if image2 is not None:
2359
2365
  tmpi2 = self.up_grade(tmpi2, l_nside * 2)
2360
- l_nside = int(np.sqrt(tmp.shape[1] // 12))
2366
+ l_nside = int(np.sqrt(tmp.shape[-1] // 12))
2361
2367
  nscale = int(np.log(l_nside) / np.log(2))
2362
2368
  cmat = {}
2363
2369
  cmat2 = {}
@@ -2367,20 +2373,23 @@ class funct(FOC.FoCUS):
2367
2373
  if image2 is not None:
2368
2374
  sim = self.backend.bk_real(
2369
2375
  self.backend.bk_L1(
2370
- self.convol(tmp)
2371
- * self.backend.bk_conjugate(self.convol(tmpi2))
2376
+ self.convol(tmp,spin=spin)
2377
+ * self.backend.bk_conjugate(self.convol(tmpi2,spin=spin))
2372
2378
  )
2373
2379
  )
2374
2380
  else:
2375
- sim = self.backend.bk_abs(self.convol(tmp))
2381
+ sim = self.backend.bk_abs(self.convol(tmp,spin=spin))
2376
2382
 
2377
2383
  # instead of difference between "opposite" channels use weighted average
2378
2384
  # of cosine and sine contributions using all channels
2379
- angles = self.backend.bk_cast(
2380
- (2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
2381
- 1, self.NORIENT, 1
2382
- )
2383
- ) # shape: (NORIENT,)
2385
+ if spin==0:
2386
+ angles = self.backend.bk_cast(
2387
+ (2 * np.pi * np.arange(self.NORIENT)
2388
+ / self.NORIENT).reshape(1,self.NORIENT,1)) # shape: (NORIENT,)
2389
+ else:
2390
+ angles = self.backend.bk_cast(
2391
+ (2 * np.pi * np.arange(self.NORIENT)
2392
+ / self.NORIENT).reshape(1,1,self.NORIENT,1)) # shape: (NORIENT,)
2384
2393
 
2385
2394
  # we use cosines and sines as weights for sim
2386
2395
  weighted_cos = self.backend.bk_reduce_mean(
@@ -2399,8 +2408,8 @@ class funct(FOC.FoCUS):
2399
2408
  cc, _ = self.ud_grade_2(self.smooth(cc))
2400
2409
  ss, _ = self.ud_grade_2(self.smooth(ss))
2401
2410
 
2402
- if cc.shape[0] != tmp.shape[0]:
2403
- ll_nside = int(np.sqrt(tmp.shape[1] // 12))
2411
+ if cc.shape[-1] != tmp.shape[-1]:
2412
+ ll_nside = int(np.sqrt(tmp.shape[-1] // 12))
2404
2413
  cc = self.up_grade(cc, ll_nside)
2405
2414
  ss = self.up_grade(ss, ll_nside)
2406
2415
 
@@ -2423,40 +2432,64 @@ class funct(FOC.FoCUS):
2423
2432
  w1 = np.sin(delta * np.pi / 2) ** 2
2424
2433
 
2425
2434
  # build rotation matrix
2426
- mat = np.zeros([self.NORIENT * self.NORIENT, sim.shape[2]])
2427
- lidx = np.arange(sim.shape[2])
2435
+ if spin==0:
2436
+ mat = np.zeros([self.NORIENT * self.NORIENT, sim.shape[-1]])
2437
+ else:
2438
+ mat = np.zeros([2,self.NORIENT * self.NORIENT, sim.shape[-1]])
2439
+ lidx = np.arange(sim.shape[-1])
2428
2440
  for ell in range(self.NORIENT):
2429
2441
  # Instead of simple linear weights, we use the cosine weights w0 and w1.
2430
2442
  col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
2431
2443
  col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
2432
2444
 
2433
- mat[col0, lidx] = w0
2434
- mat[col1, lidx] = w1
2445
+ if spin==0:
2446
+ mat[col0, lidx] = w0
2447
+ mat[col1, lidx] = w1
2448
+ else:
2449
+ mat[0,col0, lidx] = w0[0]
2450
+ mat[0,col1, lidx] = w1[0]
2451
+ mat[1,col0, lidx] = w0[1]
2452
+ mat[1,col1, lidx] = w1[1]
2435
2453
 
2436
2454
  cmat[k] = self.backend.bk_cast(mat[None, ...].astype("complex64"))
2437
2455
 
2438
2456
  # do same modifications for mat2
2439
- mat2 = np.zeros(
2440
- [k + 1, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[2]]
2441
- )
2457
+ if spin==0:
2458
+ mat2 = np.zeros(
2459
+ [k + 1, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[-1]]
2460
+ )
2461
+ else:
2462
+ mat2 = np.zeros(
2463
+ [k + 1, 2, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[-1]]
2464
+ )
2442
2465
 
2443
2466
  for k2 in range(k + 1):
2444
2467
 
2445
2468
  tmp2 = self.backend.bk_expand_dims(sim,-2)
2446
-
2447
- sim2 = self.backend.bk_reduce_sum(
2448
- self.backend.bk_reshape(
2449
- self.backend.bk_cast(
2450
- mat.reshape(1, self.NORIENT, self.NORIENT, mat.shape[1])
2451
- )
2452
- * tmp2,
2453
- [sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[1]],
2454
- ),
2455
- 1,
2456
- )
2469
+ if spin==0:
2470
+ sim2 = self.backend.bk_reduce_sum(
2471
+ self.backend.bk_reshape(
2472
+ self.backend.bk_cast(
2473
+ mat.reshape(1, self.NORIENT, self.NORIENT, mat.shape[-1])
2474
+ )
2475
+ * tmp2,
2476
+ [sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[-1]],
2477
+ ),
2478
+ 1,
2479
+ )
2480
+ else:
2481
+ sim2 = self.backend.bk_reduce_sum(
2482
+ self.backend.bk_reshape(
2483
+ self.backend.bk_cast(
2484
+ mat.reshape(1, 2, self.NORIENT, self.NORIENT, mat.shape[-1])
2485
+ )
2486
+ * tmp2,
2487
+ [sim.shape[0], 2, self.NORIENT, self.NORIENT, mat.shape[-1]],
2488
+ ),
2489
+ 2,
2490
+ )
2457
2491
 
2458
2492
  sim2 = self.backend.bk_abs(self.convol(sim2))
2459
-
2460
2493
  angles = self.backend.bk_reshape(angles, [1, self.NORIENT, 1, 1])
2461
2494
  weighted_cos2 = self.backend.bk_reduce_mean(
2462
2495
  sim2 * self.backend.bk_cos(angles), axis=-3
@@ -2474,8 +2507,8 @@ class funct(FOC.FoCUS):
2474
2507
  cc2, _ = self.ud_grade_2(self.smooth(cc2))
2475
2508
  ss2, _ = self.ud_grade_2(self.smooth(ss2))
2476
2509
 
2477
- if cc2.shape[1] != sim.shape[2]:
2478
- ll_nside = int(np.sqrt(sim.shape[2] // 12))
2510
+ if cc2.shape[-1] != sim.shape[-1]:
2511
+ ll_nside = int(np.sqrt(sim.shape[-1] // 12))
2479
2512
  cc2 = self.up_grade(cc2, ll_nside)
2480
2513
  ss2 = self.up_grade(ss2, ll_nside)
2481
2514
 
@@ -2495,14 +2528,23 @@ class funct(FOC.FoCUS):
2495
2528
  delta2 = phase2_scaled - iph2
2496
2529
  w0_2 = np.cos(delta2 * np.pi / 2) ** 2
2497
2530
  w1_2 = np.sin(delta2 * np.pi / 2) ** 2
2498
- lidx = np.arange(sim.shape[2])
2499
-
2500
- for m in range(self.NORIENT):
2501
- for ell in range(self.NORIENT):
2502
- col0 = self.NORIENT * ((ell + iph2[m]) % self.NORIENT) + ell
2503
- col1 = self.NORIENT * ((ell + iph2[m] + 1) % self.NORIENT) + ell
2504
- mat2[k2, col0, m, lidx] = w0_2[m, lidx]
2505
- mat2[k2, col1, m, lidx] = w1_2[m, lidx]
2531
+ lidx = np.arange(sim.shape[-1])
2532
+
2533
+ if spin==0:
2534
+ for m in range(self.NORIENT):
2535
+ for ell in range(self.NORIENT):
2536
+ col0 = self.NORIENT * ((ell + iph2[m]) % self.NORIENT) + ell
2537
+ col1 = self.NORIENT * ((ell + iph2[m] + 1) % self.NORIENT) + ell
2538
+ mat2[k2, col0, m, lidx] = w0_2[m, lidx]
2539
+ mat2[k2, col1, m, lidx] = w1_2[m, lidx]
2540
+ else:
2541
+ for sidx in range(2):
2542
+ for m in range(self.NORIENT):
2543
+ for ell in range(self.NORIENT):
2544
+ col0 = self.NORIENT * ((ell + iph2[sidx,m]) % self.NORIENT) + ell
2545
+ col1 = self.NORIENT * ((ell + iph2[sidx,m] + 1) % self.NORIENT) + ell
2546
+ mat2[k2, sidx, col0, m, lidx] = w0_2[sidx,m, lidx]
2547
+ mat2[k2, sidx, col1, m, lidx] = w1_2[sidx,m, lidx]
2506
2548
 
2507
2549
  cmat2[k] = self.backend.bk_cast(
2508
2550
  mat2[0 : k + 1, None, ...].astype("complex64")
@@ -2684,14 +2726,14 @@ class funct(FOC.FoCUS):
2684
2726
  # if the kernel size is bigger than 3 increase the binning before smoothing
2685
2727
  if self.use_2D:
2686
2728
  vmask = self.up_grade(
2687
- vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2729
+ vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2688
2730
  )
2689
2731
  I1 = self.up_grade(
2690
- I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2732
+ I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2691
2733
  )
2692
2734
  if cross:
2693
2735
  I2 = self.up_grade(
2694
- I2, I2.shape[-2] * 2, nouty=I2.shape[-1] * 2
2736
+ I2, I2.shape[-2] * 2, nouty=I2.shape[-1] * 2,axis=-2
2695
2737
  )
2696
2738
  elif self.use_1D:
2697
2739
  vmask = self.up_grade(vmask, I1.shape[-1] * 2)
@@ -2711,16 +2753,16 @@ class funct(FOC.FoCUS):
2711
2753
  # if the kernel size is bigger than 3 increase the binning before smoothing
2712
2754
  if self.use_2D:
2713
2755
  vmask = self.up_grade(
2714
- vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2756
+ vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2715
2757
  )
2716
2758
  I1 = self.up_grade(
2717
- I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2759
+ I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2718
2760
  )
2719
2761
  if cross:
2720
2762
  I2 = self.up_grade(
2721
2763
  I2,
2722
2764
  I2.shape[-2] * 2,
2723
- nouty=I2.shape[-1] * 2,
2765
+ nouty=I2.shape[-1] * 2,axis=-2
2724
2766
  )
2725
2767
  elif self.use_1D:
2726
2768
  vmask = self.up_grade(vmask, I1.shape[-1] * 2)
@@ -2862,13 +2904,22 @@ class funct(FOC.FoCUS):
2862
2904
  if cmat is not None:
2863
2905
  tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-2)
2864
2906
 
2865
- conv1 = self.backend.bk_reduce_sum(
2866
- self.backend.bk_reshape(
2867
- cmat[j3] * tmp2,
2868
- [tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j3].shape[2]],
2869
- ),
2870
- 1,
2871
- )
2907
+ if spin==0:
2908
+ conv1 = self.backend.bk_reduce_sum(
2909
+ self.backend.bk_reshape(
2910
+ cmat[j3] * tmp2,
2911
+ [tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j3].shape[2]],
2912
+ ),
2913
+ 1,
2914
+ )
2915
+ else:
2916
+ conv1 = self.backend.bk_reduce_sum(
2917
+ self.backend.bk_reshape(
2918
+ cmat[j3] * tmp2,
2919
+ [tmp2.shape[0], 2,self.NORIENT, self.NORIENT, cmat[j3].shape[3]],
2920
+ ),
2921
+ 2,
2922
+ )
2872
2923
 
2873
2924
  ### Take the module M1 = |I1 * Psi_j3|
2874
2925
  M1_square = conv1 * self.backend.bk_conjugate(
@@ -2978,18 +3029,33 @@ class funct(FOC.FoCUS):
2978
3029
  ) # [Nbatch, Npix_j3, Norient3]
2979
3030
  if cmat is not None:
2980
3031
  tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-2)
2981
- conv2 = self.backend.bk_reduce_sum(
2982
- self.backend.bk_reshape(
2983
- cmat[j3] * tmp2,
2984
- [
2985
- tmp2.shape[0],
2986
- self.NORIENT,
2987
- self.NORIENT,
2988
- cmat[j3].shape[2],
2989
- ],
2990
- ),
2991
- 1,
2992
- )
3032
+ if spin==0:
3033
+ conv2 = self.backend.bk_reduce_sum(
3034
+ self.backend.bk_reshape(
3035
+ cmat[j3] * tmp2,
3036
+ [
3037
+ tmp2.shape[0],
3038
+ self.NORIENT,
3039
+ self.NORIENT,
3040
+ cmat[j3].shape[2],
3041
+ ],
3042
+ ),
3043
+ 1,
3044
+ )
3045
+ else:
3046
+ conv2 = self.backend.bk_reduce_sum(
3047
+ self.backend.bk_reshape(
3048
+ cmat[j3] * tmp2,
3049
+ [
3050
+ tmp2.shape[0],
3051
+ 2,
3052
+ self.NORIENT,
3053
+ self.NORIENT,
3054
+ cmat[j3].shape[3],
3055
+ ],
3056
+ ),
3057
+ 2,
3058
+ )
2993
3059
  ### Take the module M2 = |I2 * Psi_j3|
2994
3060
  M2_square = conv2 * self.backend.bk_conjugate(
2995
3061
  conv2
@@ -3140,6 +3206,7 @@ class funct(FOC.FoCUS):
3140
3206
  cmat2=cmat2,
3141
3207
  cell_ids=cell_ids_j3,
3142
3208
  nside_j2=nside_j3,
3209
+ spin=spin,
3143
3210
  ) # [Nbatch, Nmask, Norient3, Norient2]
3144
3211
  else:
3145
3212
  s3 = self._compute_S3(
@@ -3153,6 +3220,7 @@ class funct(FOC.FoCUS):
3153
3220
  cmat2=cmat2,
3154
3221
  cell_ids=cell_ids_j3,
3155
3222
  nside_j2=nside_j3,
3223
+ spin=spin,
3156
3224
  ) # [Nbatch, Nmask, Norient3, Norient2]
3157
3225
 
3158
3226
  if return_data:
@@ -3214,6 +3282,7 @@ class funct(FOC.FoCUS):
3214
3282
  cmat2=cmat2,
3215
3283
  cell_ids=cell_ids_j3,
3216
3284
  nside_j2=nside_j3,
3285
+ spin=spin,
3217
3286
  )
3218
3287
  s3p, vs3p = self._compute_S3(
3219
3288
  j2,
@@ -3226,6 +3295,7 @@ class funct(FOC.FoCUS):
3226
3295
  cmat2=cmat2,
3227
3296
  cell_ids=cell_ids_j3,
3228
3297
  nside_j2=nside_j3,
3298
+ spin=spin,
3229
3299
  )
3230
3300
  else:
3231
3301
  s3p = self._compute_S3(
@@ -3239,6 +3309,7 @@ class funct(FOC.FoCUS):
3239
3309
  cmat2=cmat2,
3240
3310
  cell_ids=cell_ids_j3,
3241
3311
  nside_j2=nside_j3,
3312
+ spin=spin,
3242
3313
  )
3243
3314
  s3 = self._compute_S3(
3244
3315
  j2,
@@ -3251,6 +3322,7 @@ class funct(FOC.FoCUS):
3251
3322
  cmat2=cmat2,
3252
3323
  cell_ids=cell_ids_j3,
3253
3324
  nside_j2=nside_j3,
3325
+ spin=spin,
3254
3326
  )
3255
3327
 
3256
3328
  if return_data:
@@ -3607,18 +3679,19 @@ class funct(FOC.FoCUS):
3607
3679
  return
3608
3680
 
3609
3681
  def _compute_S3(
3610
- self,
3611
- j2,
3612
- j3,
3613
- conv,
3614
- vmask,
3615
- M_dic,
3616
- MconvPsi_dic,
3617
- calc_var=False,
3618
- return_data=False,
3619
- cmat2=None,
3620
- cell_ids=None,
3621
- nside_j2=None,
3682
+ self,
3683
+ j2,
3684
+ j3,
3685
+ conv,
3686
+ vmask,
3687
+ M_dic,
3688
+ MconvPsi_dic,
3689
+ calc_var=False,
3690
+ return_data=False,
3691
+ cmat2=None,
3692
+ cell_ids=None,
3693
+ nside_j2=None,
3694
+ spin=0,
3622
3695
  ):
3623
3696
  """
3624
3697
  Compute the S3 coefficients (auto or cross)
@@ -3637,19 +3710,35 @@ class funct(FOC.FoCUS):
3637
3710
 
3638
3711
  if cmat2 is not None:
3639
3712
  tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-3)
3640
- MconvPsi = self.backend.bk_reduce_sum(
3641
- self.backend.bk_reshape(
3642
- cmat2[j3][j2] * tmp2,
3643
- [
3644
- tmp2.shape[0],
3645
- self.NORIENT,
3646
- self.NORIENT,
3647
- self.NORIENT,
3648
- cmat2[j3][j2].shape[3],
3649
- ],
3650
- ),
3651
- 1,
3652
- )
3713
+ if spin==0:
3714
+ MconvPsi = self.backend.bk_reduce_sum(
3715
+ self.backend.bk_reshape(
3716
+ cmat2[j3][j2] * tmp2,
3717
+ [
3718
+ tmp2.shape[0],
3719
+ self.NORIENT,
3720
+ self.NORIENT,
3721
+ self.NORIENT,
3722
+ cmat2[j3][j2].shape[3],
3723
+ ],
3724
+ ),
3725
+ 1,
3726
+ )
3727
+ else:
3728
+ MconvPsi = self.backend.bk_reduce_sum(
3729
+ self.backend.bk_reshape(
3730
+ cmat2[j3][j2] * tmp2,
3731
+ [
3732
+ tmp2.shape[0],
3733
+ 2,
3734
+ self.NORIENT,
3735
+ self.NORIENT,
3736
+ self.NORIENT,
3737
+ cmat2[j3][j2].shape[4],
3738
+ ],
3739
+ ),
3740
+ 2,
3741
+ )
3653
3742
 
3654
3743
  # Store it so we can use it in S4 computation
3655
3744
  MconvPsi_dic[j2] = MconvPsi # [Nbatch, Norient3, Norient2, Npix_j3]
@@ -6181,6 +6270,28 @@ class funct(FOC.FoCUS):
6181
6270
  )
6182
6271
  return loss
6183
6272
 
6273
+ def The_lossH(u, scat_operator, args):
6274
+ ref = args[0]
6275
+ sref = args[1]
6276
+ use_v = args[2]
6277
+ ljmax = args[3]
6278
+
6279
+ learn = scat_operator.reduce_mean_batch(
6280
+ scat_operator.eval(
6281
+ u,
6282
+ Jmax=ljmax,
6283
+ norm='self'
6284
+ )
6285
+ )
6286
+
6287
+ # compute scattering covariance of the current synthetised map called u
6288
+ if use_v:
6289
+ loss = scat_operator.reduce_distance(learn,ref,sigma=sref)
6290
+ else:
6291
+ loss = scat_operator.reduce_distance(learn,ref)
6292
+
6293
+ return loss
6294
+
6184
6295
  def The_lossX(u, scat_operator, args):
6185
6296
  ref = args[0]
6186
6297
  sref = args[1]
@@ -6340,37 +6451,74 @@ class funct(FOC.FoCUS):
6340
6451
  # Increase the resolution between each step
6341
6452
  if self.use_2D:
6342
6453
  imap = self.up_grade(
6343
- omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
6454
+ omap,
6455
+ imap.shape[1] * 2,
6456
+ axis=-2,
6457
+ nouty=imap.shape[2] * 2
6344
6458
  )
6345
6459
  elif self.use_1D:
6346
- imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
6460
+ imap = self.up_grade(omap, imap.shape[1] * 2)
6347
6461
  else:
6348
- imap = self.up_grade(omap, l_nside, axis=1)
6349
-
6462
+ imap = self.up_grade(omap, l_nside)
6463
+
6350
6464
  if grd_mask is not None:
6351
6465
  imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
6352
6466
 
6353
- # compute the coefficients for the target image
6354
- if use_variance:
6355
- ref, sref = self.scattering_cov(
6356
- tmp[k],
6357
- data2=l_ref[k],
6358
- get_variance=True,
6359
- edge=l_edge,
6360
- Jmax=l_jmax[k],
6361
- in_mask=l_in_mask[k],
6362
- iso_ang=iso_ang,
6363
- )
6467
+
6468
+ if self.use_2D:
6469
+ # compute the coefficients for the target image
6470
+ if use_variance:
6471
+ ref, sref = self.scattering_cov(
6472
+ tmp[k],
6473
+ data2=l_ref[k],
6474
+ get_variance=True,
6475
+ edge=l_edge,
6476
+ Jmax=l_jmax[k],
6477
+ in_mask=l_in_mask[k],
6478
+ iso_ang=iso_ang,
6479
+ )
6480
+ else:
6481
+ ref = self.scattering_cov(
6482
+ tmp[k],
6483
+ data2=l_ref[k],
6484
+ in_mask=l_in_mask[k],
6485
+ edge=l_edge,
6486
+ Jmax=l_jmax[k],
6487
+ iso_ang=iso_ang,
6488
+ )
6489
+ sref = ref
6364
6490
  else:
6365
- ref = self.scattering_cov(
6366
- tmp[k],
6367
- data2=l_ref[k],
6368
- in_mask=l_in_mask[k],
6369
- edge=l_edge,
6370
- Jmax=l_jmax[k],
6371
- iso_ang=iso_ang,
6372
- )
6373
- sref = ref
6491
+ ref = self.eval(
6492
+ tmp[k],
6493
+ image2=l_ref[k],
6494
+ mask=l_in_mask[k],
6495
+ Jmax=l_jmax[k],
6496
+ norm='auto'
6497
+ )
6498
+
6499
+ # compute the coefficients for the target image
6500
+ if use_variance:
6501
+ ref, sref = self.eval(
6502
+ tmp[k],
6503
+ image2=l_ref[k],
6504
+ mask=l_in_mask[k],
6505
+ Jmax=l_jmax[k],
6506
+ calc_var=True,
6507
+ norm='self'
6508
+ )
6509
+ else:
6510
+ ref = self.eval(
6511
+ tmp[k],
6512
+ image2=l_ref[k],
6513
+ mask=l_in_mask[k],
6514
+ Jmax=l_jmax[k],
6515
+ norm='self'
6516
+ )
6517
+ sref = ref
6518
+
6519
+ if iso_ang:
6520
+ ref=ref.iso_mean()
6521
+ sref=sref.iso_mean()
6374
6522
 
6375
6523
  # compute the mean of the population does nothing if only one map is given
6376
6524
  ref = self.reduce_mean_batch(ref)
@@ -6379,13 +6527,21 @@ class funct(FOC.FoCUS):
6379
6527
  self.purge_edge_mask()
6380
6528
 
6381
6529
  if l_ref[k] is None:
6382
- # define a loss to minimize
6383
- loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
6530
+ if self.use_2D:
6531
+ # define a loss to minimize
6532
+ loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
6533
+ else:
6534
+ loss = synthe.Loss(The_lossH, self, ref, sref, use_variance, l_jmax[k])
6384
6535
  else:
6385
6536
  # define a loss to minimize
6386
- loss = synthe.Loss(
6387
- The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6388
- )
6537
+ if self.use_2D:
6538
+ loss = synthe.Loss(
6539
+ The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6540
+ )
6541
+ else:
6542
+ loss = synthe.Loss(
6543
+ The_lossXH, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6544
+ )
6389
6545
 
6390
6546
  if input_image is not None:
6391
6547
  # define a loss to minimize
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: foscat
3
- Version: 2025.6.3
3
+ Version: 2025.7.2
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>