foscat 2025.6.1__py3-none-any.whl → 2025.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -92,12 +92,7 @@ class scat_cov:
92
92
  )
93
93
 
94
94
  def conv2complex(self, val):
95
- if (
96
- val.dtype == "complex64"
97
- or val.dtype == "complex128"
98
- or val.dtype == "torch.complex64"
99
- or val.dtype == "torch.complex128"
100
- ):
95
+ if self.backend.bk_is_complex(val):
101
96
  return val
102
97
  else:
103
98
  return self.backend.bk_complex(val, 0 * val)
@@ -107,7 +102,7 @@ class scat_cov:
107
102
  def flatten(self):
108
103
  tmp = [
109
104
  self.conv2complex(
110
- self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]])
105
+ self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
111
106
  )
112
107
  ]
113
108
  if self.use_1D:
@@ -177,21 +172,9 @@ class scat_cov:
177
172
  ],
178
173
  )
179
174
  ),
180
- self.backend.bk_reshape(
181
- self.S3,
182
- [
183
- self.S3.shape[0],
184
- self.S3.shape[1]
185
- * self.S3.shape[2]
186
- * self.S3.shape[3]
187
- * self.S3.shape[4],
188
- ],
189
- ),
190
- ]
191
- if self.S3P is not None:
192
- tmp = tmp + [
175
+ self.conv2complex(
193
176
  self.backend.bk_reshape(
194
- self.S3P,
177
+ self.S3,
195
178
  [
196
179
  self.S3.shape[0],
197
180
  self.S3.shape[1]
@@ -200,22 +183,39 @@ class scat_cov:
200
183
  * self.S3.shape[4],
201
184
  ],
202
185
  )
186
+ ),
187
+ ]
188
+ if self.S3P is not None:
189
+ tmp = tmp + [
190
+ self.conv2complex(
191
+ self.backend.bk_reshape(
192
+ self.S3P,
193
+ [
194
+ self.S3.shape[0],
195
+ self.S3.shape[1]
196
+ * self.S3.shape[2]
197
+ * self.S3.shape[3]
198
+ * self.S3.shape[4],
199
+ ],
200
+ )
201
+ )
203
202
  ]
204
203
 
205
204
  tmp = tmp + [
206
- self.backend.bk_reshape(
207
- self.S4,
208
- [
209
- self.S3.shape[0],
210
- self.S4.shape[1]
211
- * self.S4.shape[2]
212
- * self.S4.shape[3]
213
- * self.S4.shape[4]
214
- * self.S4.shape[5],
215
- ],
205
+ self.conv2complex(
206
+ self.backend.bk_reshape(
207
+ self.S4,
208
+ [
209
+ self.S4.shape[0],
210
+ self.S4.shape[1]
211
+ * self.S4.shape[2]
212
+ * self.S4.shape[3]
213
+ * self.S4.shape[4]
214
+ * self.S4.shape[5],
215
+ ],
216
+ )
216
217
  )
217
218
  ]
218
-
219
219
  return self.backend.bk_concat(tmp, 1)
220
220
 
221
221
  # ---------------------------------------------−---------
@@ -2354,9 +2354,9 @@ class funct(FOC.FoCUS):
2354
2354
  tmpi2 = image2
2355
2355
  if upscale:
2356
2356
  l_nside = int(np.sqrt(tmp.shape[1] // 12))
2357
- tmp = self.up_grade(tmp, l_nside * 2, axis=1)
2357
+ tmp = self.up_grade(tmp, l_nside * 2)
2358
2358
  if image2 is not None:
2359
- tmpi2 = self.up_grade(tmpi2, l_nside * 2, axis=1)
2359
+ tmpi2 = self.up_grade(tmpi2, l_nside * 2)
2360
2360
  l_nside = int(np.sqrt(tmp.shape[1] // 12))
2361
2361
  nscale = int(np.log(l_nside) / np.log(2))
2362
2362
  cmat = {}
@@ -2367,12 +2367,12 @@ class funct(FOC.FoCUS):
2367
2367
  if image2 is not None:
2368
2368
  sim = self.backend.bk_real(
2369
2369
  self.backend.bk_L1(
2370
- self.convol(tmp, axis=1)
2371
- * self.backend.bk_conjugate(self.convol(tmpi2, axis=1))
2370
+ self.convol(tmp)
2371
+ * self.backend.bk_conjugate(self.convol(tmpi2))
2372
2372
  )
2373
2373
  )
2374
2374
  else:
2375
- sim = self.backend.bk_abs(self.convol(tmp, axis=1))
2375
+ sim = self.backend.bk_abs(self.convol(tmp))
2376
2376
 
2377
2377
  # instead of difference between "opposite" channels use weighted average
2378
2378
  # of cosine and sine contributions using all channels
@@ -2442,12 +2442,12 @@ class funct(FOC.FoCUS):
2442
2442
 
2443
2443
  for k2 in range(k + 1):
2444
2444
 
2445
- tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=1)
2445
+ tmp2 = self.backend.bk_expand_dims(sim,-2)
2446
2446
 
2447
2447
  sim2 = self.backend.bk_reduce_sum(
2448
2448
  self.backend.bk_reshape(
2449
2449
  self.backend.bk_cast(
2450
- mat.reshape(1, self.NORIENT * self.NORIENT, mat.shape[1])
2450
+ mat.reshape(1, self.NORIENT, self.NORIENT, mat.shape[1])
2451
2451
  )
2452
2452
  * tmp2,
2453
2453
  [sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[1]],
@@ -2455,14 +2455,14 @@ class funct(FOC.FoCUS):
2455
2455
  1,
2456
2456
  )
2457
2457
 
2458
- sim2 = self.backend.bk_abs(self.convol(sim2, axis=-1))
2458
+ sim2 = self.backend.bk_abs(self.convol(sim2))
2459
2459
 
2460
2460
  angles = self.backend.bk_reshape(angles, [1, self.NORIENT, 1, 1])
2461
2461
  weighted_cos2 = self.backend.bk_reduce_mean(
2462
- sim2 * self.backend.bk_cos(angles), axis=1
2462
+ sim2 * self.backend.bk_cos(angles), axis=-3
2463
2463
  )
2464
2464
  weighted_sin2 = self.backend.bk_reduce_mean(
2465
- sim2 * self.backend.bk_sin(angles), axis=1
2465
+ sim2 * self.backend.bk_sin(angles), axis=-3
2466
2466
  )
2467
2467
 
2468
2468
  cc2 = weighted_cos2[0]
@@ -2471,13 +2471,13 @@ class funct(FOC.FoCUS):
2471
2471
  if smooth_scale > 0:
2472
2472
  for m in range(smooth_scale):
2473
2473
  if cc2.shape[1] > 12:
2474
- cc2, _ = self.ud_grade_2(self.smooth(cc2, axis=1), axis=1)
2475
- ss2, _ = self.ud_grade_2(self.smooth(ss2, axis=1), axis=1)
2474
+ cc2, _ = self.ud_grade_2(self.smooth(cc2))
2475
+ ss2, _ = self.ud_grade_2(self.smooth(ss2))
2476
2476
 
2477
2477
  if cc2.shape[1] != sim.shape[2]:
2478
2478
  ll_nside = int(np.sqrt(sim.shape[2] // 12))
2479
- cc2 = self.up_grade(cc2, ll_nside, axis=1)
2480
- ss2 = self.up_grade(ss2, ll_nside, axis=1)
2479
+ cc2 = self.up_grade(cc2, ll_nside)
2480
+ ss2 = self.up_grade(ss2, ll_nside)
2481
2481
 
2482
2482
  if self.BACKEND == "numpy":
2483
2483
  phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
@@ -2509,9 +2509,9 @@ class funct(FOC.FoCUS):
2509
2509
  )
2510
2510
 
2511
2511
  if k < l_nside - 1:
2512
- tmp, _ = self.ud_grade_2(tmp, axis=1)
2512
+ tmp, _ = self.ud_grade_2(tmp)
2513
2513
  if image2 is not None:
2514
- tmpi2, _ = self.ud_grade_2(tmpi2, axis=1)
2514
+ tmpi2, _ = self.ud_grade_2(tmpi)
2515
2515
  return cmat, cmat2
2516
2516
 
2517
2517
  def div_norm(self, complex_value, float_value):
@@ -2534,6 +2534,7 @@ class funct(FOC.FoCUS):
2534
2534
  edge=True,
2535
2535
  nside=None,
2536
2536
  cell_ids=None,
2537
+ spin=0
2537
2538
  ):
2538
2539
  """
2539
2540
  Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2559,11 +2560,14 @@ class funct(FOC.FoCUS):
2559
2560
  norm: None or str
2560
2561
  If None no normalization is applied, if 'auto' normalize by the reference S2,
2561
2562
  if 'self' normalize by the current S2.
2563
+ spin : Integer
2564
+ If different from 0 compute spinned data (U,V to Divergence/Rotational spin==1) or (Q,U to E,B spin=2).
2565
+ This implies that the input data is 2*12*nside^2.
2562
2566
  Returns
2563
2567
  -------
2564
2568
  S1, S2, S3, S4 normalized
2565
2569
  """
2566
-
2570
+
2567
2571
  return_data = self.return_data
2568
2572
 
2569
2573
  # Check input consistency
@@ -2576,8 +2580,8 @@ class funct(FOC.FoCUS):
2576
2580
  if mask is not None:
2577
2581
  if self.use_2D:
2578
2582
  if (
2579
- image1.shape[-2] != mask.shape[1]
2580
- or image1.shape[-1] != mask.shape[2]
2583
+ image1.shape[-2] != mask.shape[-1]
2584
+ or image1.shape[-1] != mask.shape[-1]
2581
2585
  ):
2582
2586
  print(
2583
2587
  "The LAST 2 COLUMNs of the mask should have the same size ",
@@ -2588,7 +2592,7 @@ class funct(FOC.FoCUS):
2588
2592
  )
2589
2593
  return None
2590
2594
  else:
2591
- if image1.shape[-1] != mask.shape[1]:
2595
+ if image1.shape[-1] != mask.shape[-1]:
2592
2596
  print(
2593
2597
  "The LAST COLUMN of the mask should have the same size ",
2594
2598
  mask.shape,
@@ -2609,7 +2613,6 @@ class funct(FOC.FoCUS):
2609
2613
  cross = True
2610
2614
  l_nside = 2**32 # not initialize if 1D or 2D
2611
2615
  ### PARAMETERS
2612
- axis = 1
2613
2616
  # determine jmax and nside corresponding to the input map
2614
2617
  im_shape = image1.shape
2615
2618
  if self.use_2D:
@@ -2637,10 +2640,7 @@ class funct(FOC.FoCUS):
2637
2640
 
2638
2641
  J = int(np.log(nside) / np.log(2)) # Number of j scales
2639
2642
  else:
2640
- if len(image1.shape) == 2:
2641
- npix = int(im_shape[1]) # Number of pixels
2642
- else:
2643
- npix = int(im_shape[0]) # Number of pixels
2643
+ npix=int(im_shape[-1])
2644
2644
 
2645
2645
  if nside is None:
2646
2646
  nside = int(np.sqrt(npix // 12))
@@ -2659,7 +2659,7 @@ class funct(FOC.FoCUS):
2659
2659
  print("\n\n==========")
2660
2660
 
2661
2661
  ### LOCAL VARIABLES (IMAGES and MASK)
2662
- if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
2662
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D) or (len(image1.shape) == 2 and spin>0):
2663
2663
  I1 = self.backend.bk_cast(
2664
2664
  self.backend.bk_expand_dims(image1, 0)
2665
2665
  ) # Local image1 [Nbatch, Npix]
@@ -2684,26 +2684,26 @@ class funct(FOC.FoCUS):
2684
2684
  # if the kernel size is bigger than 3 increase the binning before smoothing
2685
2685
  if self.use_2D:
2686
2686
  vmask = self.up_grade(
2687
- vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2687
+ vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2688
2688
  )
2689
2689
  I1 = self.up_grade(
2690
- I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2690
+ I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2691
2691
  )
2692
2692
  if cross:
2693
2693
  I2 = self.up_grade(
2694
- I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
2694
+ I2, I2.shape[-2] * 2, nouty=I2.shape[-1] * 2
2695
2695
  )
2696
2696
  elif self.use_1D:
2697
- vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2698
- I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2697
+ vmask = self.up_grade(vmask, I1.shape[-1] * 2)
2698
+ I1 = self.up_grade(I1, I1.shape[-1] * 2)
2699
2699
  if cross:
2700
- I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2700
+ I2 = self.up_grade(I2, I2.shape[-1] * 2)
2701
2701
  nside = nside * 2
2702
2702
  else:
2703
- I1 = self.up_grade(I1, nside * 2, axis=axis)
2704
- vmask = self.up_grade(vmask, nside * 2, axis=1)
2703
+ I1 = self.up_grade(I1, nside * 2)
2704
+ vmask = self.up_grade(vmask, nside * 2)
2705
2705
  if cross:
2706
- I2 = self.up_grade(I2, nside * 2, axis=axis)
2706
+ I2 = self.up_grade(I2, nside * 2)
2707
2707
 
2708
2708
  nside = nside * 2
2709
2709
 
@@ -2711,29 +2711,28 @@ class funct(FOC.FoCUS):
2711
2711
  # if the kernel size is bigger than 3 increase the binning before smoothing
2712
2712
  if self.use_2D:
2713
2713
  vmask = self.up_grade(
2714
- vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2714
+ vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2715
2715
  )
2716
2716
  I1 = self.up_grade(
2717
- I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2717
+ I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
2718
2718
  )
2719
2719
  if cross:
2720
2720
  I2 = self.up_grade(
2721
2721
  I2,
2722
- I2.shape[axis] * 2,
2723
- axis=axis,
2724
- nouty=I2.shape[axis + 1] * 2,
2722
+ I2.shape[-2] * 2,
2723
+ nouty=I2.shape[-1] * 2,
2725
2724
  )
2726
2725
  elif self.use_1D:
2727
- vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2728
- I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2726
+ vmask = self.up_grade(vmask, I1.shape[-1] * 2)
2727
+ I1 = self.up_grade(I1, I1.shape[-1] * 2)
2729
2728
  if cross:
2730
- I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2729
+ I2 = self.up_grade(I2, I2.shape[-1] * 2)
2731
2730
  nside = nside * 2
2732
2731
  else:
2733
- I1 = self.up_grade(I1, nside * 2, axis=axis)
2734
- vmask = self.up_grade(vmask, nside * 2, axis=1)
2732
+ I1 = self.up_grade(I1, nside * 2)
2733
+ vmask = self.up_grade(vmask, nside * 2)
2735
2734
  if cross:
2736
- I2 = self.up_grade(I2, nside * 2, axis=axis)
2735
+ I2 = self.up_grade(I2, nside * 2)
2737
2736
  nside = nside * 2
2738
2737
 
2739
2738
  # Normalize the masks because they have different pixel numbers
@@ -2762,6 +2761,7 @@ class funct(FOC.FoCUS):
2762
2761
  off_S2 = -2
2763
2762
  off_S3 = -3
2764
2763
  off_S4 = -4
2764
+
2765
2765
  if self.use_1D:
2766
2766
  off_S2 = -1
2767
2767
  off_S3 = -1
@@ -2793,13 +2793,20 @@ class funct(FOC.FoCUS):
2793
2793
  )
2794
2794
  else:
2795
2795
  if not cross:
2796
- s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
2796
+ s0, l_vs0 = self.masked_mean(I1,
2797
+ vmask,
2798
+ calc_var=True)
2797
2799
  else:
2798
2800
  s0, l_vs0 = self.masked_mean(
2799
- self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
2800
- )
2801
- vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2802
- s0 = self.backend.bk_concat([s0, l_vs0], 1)
2801
+ self.backend.bk_L1(I1 * I2),
2802
+ vmask,
2803
+ calc_var=True)
2804
+
2805
+ vs0 = self.backend.bk_concat([l_vs0, l_vs0], -1)
2806
+ s0 = self.backend.bk_concat([s0, l_vs0], -1)
2807
+ if spin>0:
2808
+ vs0=self.backend.bk_reshape(vs0,[vs0.shape[0],vs0.shape[1],2,vs0.shape[2]//2])
2809
+ s0=self.backend.bk_reshape(s0,[s0.shape[0],s0.shape[1],2,s0.shape[2]//2])
2803
2810
  #### COMPUTE S1, S2, S3 and S4
2804
2811
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2805
2812
 
@@ -2848,13 +2855,13 @@ class funct(FOC.FoCUS):
2848
2855
  ####### S1 and S2
2849
2856
  ### Make the convolution I1 * Psi_j3
2850
2857
  conv1 = self.convol(
2851
- I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2858
+ I1, cell_ids=cell_ids_j3, nside=nside_j3,
2859
+ spin=spin
2852
2860
  ) # [Nbatch, Norient3 , Npix_j3]
2853
2861
 
2854
2862
  if cmat is not None:
2855
-
2856
2863
  tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-2)
2857
-
2864
+
2858
2865
  conv1 = self.backend.bk_reduce_sum(
2859
2866
  self.backend.bk_reshape(
2860
2867
  cmat[j3] * tmp2,
@@ -2871,7 +2878,6 @@ class funct(FOC.FoCUS):
2871
2878
  M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
2872
2879
  # Store M1_j3 in a dictionary
2873
2880
  M1_dic[j3] = M1
2874
-
2875
2881
  if not cross: # Auto
2876
2882
  M1_square = self.backend.bk_real(M1_square)
2877
2883
 
@@ -2882,11 +2888,11 @@ class funct(FOC.FoCUS):
2882
2888
  else:
2883
2889
  if calc_var:
2884
2890
  s2, vs2 = self.masked_mean(
2885
- M1_square, vmask, axis=2, rank=j3, calc_var=True
2891
+ M1_square, vmask, rank=j3, calc_var=True
2886
2892
  )
2887
2893
  else:
2888
- s2 = self.masked_mean(M1_square, vmask, axis=2, rank=j3)
2889
-
2894
+ s2 = self.masked_mean(M1_square, vmask, rank=j3)
2895
+
2890
2896
  if cond_init_P1_dic:
2891
2897
  # We fill P1_dic with S2 for normalisation of S3 and S4
2892
2898
  P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
@@ -2929,11 +2935,11 @@ class funct(FOC.FoCUS):
2929
2935
  else:
2930
2936
  if calc_var:
2931
2937
  s1, vs1 = self.masked_mean(
2932
- M1, vmask, axis=2, rank=j3, calc_var=True
2938
+ M1, vmask, rank=j3, calc_var=True
2933
2939
  ) # [Nbatch, Nmask, Norient3]
2934
2940
  else:
2935
2941
  s1 = self.masked_mean(
2936
- M1, vmask, axis=2, rank=j3
2942
+ M1, vmask, rank=j3
2937
2943
  ) # [Nbatch, Nmask, Norient3]
2938
2944
 
2939
2945
  if return_data:
@@ -2967,7 +2973,8 @@ class funct(FOC.FoCUS):
2967
2973
  else: # Cross
2968
2974
  ### Make the convolution I2 * Psi_j3
2969
2975
  conv2 = self.convol(
2970
- I2, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
2976
+ I2, cell_ids=cell_ids_j3, nside=nside_j3,
2977
+ spin=spin
2971
2978
  ) # [Nbatch, Npix_j3, Norient3]
2972
2979
  if cmat is not None:
2973
2980
  tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-2)
@@ -3001,17 +3008,17 @@ class funct(FOC.FoCUS):
3001
3008
  else:
3002
3009
  if calc_var:
3003
3010
  p1, vp1 = self.masked_mean(
3004
- M1_square, vmask, axis=2, rank=j3, calc_var=True
3011
+ M1_square, vmask, rank=j3, calc_var=True
3005
3012
  ) # [Nbatch, Nmask, Norient3]
3006
3013
  p2, vp2 = self.masked_mean(
3007
- M2_square, vmask, axis=2, rank=j3, calc_var=True
3014
+ M2_square, vmask, rank=j3, calc_var=True
3008
3015
  ) # [Nbatch, Nmask, Norient3]
3009
3016
  else:
3010
3017
  p1 = self.masked_mean(
3011
- M1_square, vmask, axis=2, rank=j3
3018
+ M1_square, vmask, rank=j3
3012
3019
  ) # [Nbatch, Nmask, Norient3]
3013
3020
  p2 = self.masked_mean(
3014
- M2_square, vmask, axis=2, rank=j3
3021
+ M2_square, vmask, rank=j3
3015
3022
  ) # [Nbatch, Nmask, Norient3]
3016
3023
  # We fill P1_dic with S2 for normalisation of S3 and S4
3017
3024
  P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
@@ -3027,10 +3034,10 @@ class funct(FOC.FoCUS):
3027
3034
  else:
3028
3035
  if calc_var:
3029
3036
  s2, vs2 = self.masked_mean(
3030
- s2, vmask, axis=2, rank=j3, calc_var=True
3037
+ s2, vmask, rank=j3, calc_var=True
3031
3038
  )
3032
3039
  else:
3033
- s2 = self.masked_mean(s2, vmask, axis=2, rank=j3)
3040
+ s2 = self.masked_mean(s2, vmask, rank=j3)
3034
3041
 
3035
3042
  if return_data:
3036
3043
  if out_nside is not None and out_nside < nside_j3:
@@ -3072,11 +3079,11 @@ class funct(FOC.FoCUS):
3072
3079
  else:
3073
3080
  if calc_var:
3074
3081
  s1, vs1 = self.masked_mean(
3075
- MX, vmask, axis=2, rank=j3, calc_var=True
3082
+ MX, vmask, rank=j3, calc_var=True
3076
3083
  ) # [Nbatch, Nmask, Norient3]
3077
3084
  else:
3078
3085
  s1 = self.masked_mean(
3079
- MX, vmask, axis=2, rank=j3
3086
+ MX, vmask, rank=j3
3080
3087
  ) # [Nbatch, Nmask, Norient3]
3081
3088
  if return_data:
3082
3089
  if out_nside is not None and out_nside < nside_j3:
@@ -3482,39 +3489,39 @@ class funct(FOC.FoCUS):
3482
3489
  ### Image I1,
3483
3490
  # downscale the I1 [Nbatch, Npix_j3]
3484
3491
  if j3 != Jmax - 1:
3485
- I1 = self.smooth(I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3492
+ I1 = self.smooth(I1, cell_ids=cell_ids_j3, nside=nside_j3)
3486
3493
  I1, new_cell_ids_j3 = self.ud_grade_2(
3487
- I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3494
+ I1, cell_ids=cell_ids_j3, nside=nside_j3
3488
3495
  )
3489
3496
 
3490
3497
  ### Image I2
3491
3498
  if cross:
3492
- I2 = self.smooth(I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3499
+ I2 = self.smooth(I2, cell_ids=cell_ids_j3, nside=nside_j3)
3493
3500
  I2, new_cell_ids_j3 = self.ud_grade_2(
3494
- I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3501
+ I2, cell_ids=cell_ids_j3, nside=nside_j3
3495
3502
  )
3496
3503
 
3497
3504
  ### Modules
3498
3505
  for j2 in range(0, j3 + 1): # j2 =< j3
3499
3506
  ### Dictionary M1_dic[j2]
3500
3507
  M1_smooth = self.smooth(
3501
- M1_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3508
+ M1_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
3502
3509
  ) # [Nbatch, Npix_j3, Norient3]
3503
3510
  M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3504
- M1_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3511
+ M1_smooth, cell_ids=cell_ids_j3, nside=nside_j3
3505
3512
  ) # [Nbatch, Npix_j3, Norient3]
3506
3513
 
3507
3514
  ### Dictionary M2_dic[j2]
3508
3515
  if cross:
3509
3516
  M2_smooth = self.smooth(
3510
- M2_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3517
+ M2_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
3511
3518
  ) # [Nbatch, Npix_j3, Norient3]
3512
3519
  M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3513
- M2_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3520
+ M2_smooth, cell_ids=cell_ids_j3, nside=nside_j3
3514
3521
  ) # [Nbatch, Npix_j3, Norient3]
3515
3522
  ### Mask
3516
3523
  vmask, new_cell_ids_j3 = self.ud_grade_2(
3517
- vmask, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3524
+ vmask, cell_ids=cell_ids_j3, nside=nside_j3
3518
3525
  )
3519
3526
 
3520
3527
  if self.mask_thres is not None:
@@ -3529,33 +3536,21 @@ class funct(FOC.FoCUS):
3529
3536
  self.P1_dic = P1_dic
3530
3537
  if cross:
3531
3538
  self.P2_dic = P2_dic
3532
- """
3533
- Sout=[s0]+S1+S2+S3+S4
3534
-
3535
- if cross:
3536
- Sout=Sout+S3P
3537
- if calc_var:
3538
- SVout=[vs0]+VS1+VS2+VS3+VS4
3539
- if cross:
3540
- VSout=VSout+VS3P
3541
- return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
3542
-
3543
- return self.backend.bk_concat(Sout, 2)
3544
- """
3539
+
3545
3540
  if not return_data:
3546
- S1 = self.backend.bk_concat(S1, 2)
3547
- S2 = self.backend.bk_concat(S2, 2)
3548
- S3 = self.backend.bk_concat(S3, 2)
3549
- S4 = self.backend.bk_concat(S4, 2)
3541
+ S1 = self.backend.bk_concat(S1, -2)
3542
+ S2 = self.backend.bk_concat(S2, -2)
3543
+ S3 = self.backend.bk_concat(S3, -3)
3544
+ S4 = self.backend.bk_concat(S4, -4)
3550
3545
  if cross:
3551
- S3P = self.backend.bk_concat(S3P, 2)
3546
+ S3P = self.backend.bk_concat(S3P, -3)
3552
3547
  if calc_var:
3553
- VS1 = self.backend.bk_concat(VS1, 2)
3554
- VS2 = self.backend.bk_concat(VS2, 2)
3555
- VS3 = self.backend.bk_concat(VS3, 2)
3556
- VS4 = self.backend.bk_concat(VS4, 2)
3548
+ VS1 = self.backend.bk_concat(VS1, -2)
3549
+ VS2 = self.backend.bk_concat(VS2, -2)
3550
+ VS3 = self.backend.bk_concat(VS3, -3)
3551
+ VS4 = self.backend.bk_concat(VS4, -4)
3557
3552
  if cross:
3558
- VS3P = self.backend.bk_concat(VS3P, 2)
3553
+ VS3P = self.backend.bk_concat(VS3P, -3)
3559
3554
  if calc_var:
3560
3555
  if not cross:
3561
3556
  return scat_cov(
@@ -3637,7 +3632,7 @@ class funct(FOC.FoCUS):
3637
3632
  ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3638
3633
  # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Norient3, Npix_j3]
3639
3634
  MconvPsi = self.convol(
3640
- M_dic[j2], axis=2, cell_ids=cell_ids, nside=nside_j2
3635
+ M_dic[j2], cell_ids=cell_ids, nside=nside_j2
3641
3636
  ) # [Nbatch, Norient3, Norient2, Npix_j3]
3642
3637
 
3643
3638
  if cmat2 is not None:
@@ -3674,12 +3669,12 @@ class funct(FOC.FoCUS):
3674
3669
  else:
3675
3670
  if calc_var:
3676
3671
  s3, vs3 = self.masked_mean(
3677
- s3, vmask, axis=3, rank=j2, calc_var=True
3672
+ s3, vmask, rank=j2, calc_var=True
3678
3673
  ) # [Nbatch, Nmask, Norient3, Norient2]
3679
3674
  return s3, vs3
3680
3675
  else:
3681
3676
  s3 = self.masked_mean(
3682
- s3, vmask, axis=3, rank=j2
3677
+ s3, vmask, rank=j2
3683
3678
  ) # [Nbatch, Nmask, Norient3, Norient2]
3684
3679
  return s3
3685
3680
 
@@ -3717,12 +3712,12 @@ class funct(FOC.FoCUS):
3717
3712
  else:
3718
3713
  if calc_var:
3719
3714
  s4, vs4 = self.masked_mean(
3720
- s4, vmask, axis=4, rank=j2, calc_var=True
3715
+ s4, vmask, rank=j2, calc_var=True
3721
3716
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3722
3717
  return s4, vs4
3723
3718
  else:
3724
3719
  s4 = self.masked_mean(
3725
- s4, vmask, axis=4, rank=j2
3720
+ s4, vmask, rank=j2
3726
3721
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3727
3722
  return s4
3728
3723
 
@@ -3922,20 +3917,20 @@ class funct(FOC.FoCUS):
3922
3917
  #
3923
3918
  # ---------------------------------------------------------------------------
3924
3919
  def scattering_cov(
3925
- self,
3926
- data,
3927
- data2=None,
3928
- Jmax=None,
3929
- if_large_batch=False,
3930
- S4_criteria=None,
3931
- use_ref=False,
3932
- normalization="S2",
3933
- edge=False,
3934
- in_mask=None,
3935
- pseudo_coef=1,
3936
- get_variance=False,
3937
- ref_sigma=None,
3938
- iso_ang=False,
3920
+ self,
3921
+ data,
3922
+ data2=None,
3923
+ Jmax=None,
3924
+ if_large_batch=False,
3925
+ S4_criteria=None,
3926
+ use_ref=False,
3927
+ normalization="S2",
3928
+ edge=False,
3929
+ in_mask=None,
3930
+ pseudo_coef=1,
3931
+ get_variance=False,
3932
+ ref_sigma=None,
3933
+ iso_ang=False,
3939
3934
  ):
3940
3935
  """
3941
3936
  Calculates the scattering correlations for a batch of images, including:
@@ -4048,7 +4043,10 @@ class funct(FOC.FoCUS):
4048
4043
  if data2 is not None:
4049
4044
  N_image2 = data2.shape[0]
4050
4045
 
4051
- nside = int(np.sqrt(npix // 12))
4046
+ if spin==0:
4047
+ nside = int(np.sqrt(npix // 12))
4048
+ else:
4049
+ nside = int(np.sqrt(npix // 24))
4052
4050
 
4053
4051
  J = int(np.log(nside) / np.log(2)) # Number of j scales
4054
4052
 
@@ -5781,7 +5779,9 @@ class funct(FOC.FoCUS):
5781
5779
 
5782
5780
  def from_gaussian(self, x):
5783
5781
 
5784
- x = self.backend.bk_clip_by_value(x, self.val_min, self.val_max)
5782
+ x = self.backend.bk_clip_by_value(x,
5783
+ self.val_min+1E-4*abs(self.val_min),
5784
+ self.val_max-1E-4*abs(self.val_max))
5785
5785
  return self.f_gaussian(self.backend.to_numpy(x))
5786
5786
 
5787
5787
  def square(self, x):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: foscat
3
- Version: 2025.6.1
3
+ Version: 2025.6.3
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>