foscat 2025.6.1__py3-none-any.whl → 2025.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -44,6 +44,12 @@ class scat_cov:
44
44
  self.idx1 = None
45
45
  self.idx2 = None
46
46
  self.use_1D = use_1D
47
+ self.numel = self.backend.bk_len(s0)+ \
48
+ self.backend.bk_len(s1)+ \
49
+ self.backend.bk_len(s2)+ \
50
+ self.backend.bk_len(s3)+ \
51
+ self.backend.bk_len(s4)+ \
52
+ self.backend.bk_len(s3p)
47
53
 
48
54
  def numpy(self):
49
55
  if self.BACKEND == "numpy":
@@ -92,12 +98,7 @@ class scat_cov:
92
98
  )
93
99
 
94
100
  def conv2complex(self, val):
95
- if (
96
- val.dtype == "complex64"
97
- or val.dtype == "complex128"
98
- or val.dtype == "torch.complex64"
99
- or val.dtype == "torch.complex128"
100
- ):
101
+ if self.backend.bk_is_complex(val):
101
102
  return val
102
103
  else:
103
104
  return self.backend.bk_complex(val, 0 * val)
@@ -107,7 +108,7 @@ class scat_cov:
107
108
  def flatten(self):
108
109
  tmp = [
109
110
  self.conv2complex(
110
- self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]])
111
+ self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
111
112
  )
112
113
  ]
113
114
  if self.use_1D:
@@ -177,21 +178,9 @@ class scat_cov:
177
178
  ],
178
179
  )
179
180
  ),
180
- self.backend.bk_reshape(
181
- self.S3,
182
- [
183
- self.S3.shape[0],
184
- self.S3.shape[1]
185
- * self.S3.shape[2]
186
- * self.S3.shape[3]
187
- * self.S3.shape[4],
188
- ],
189
- ),
190
- ]
191
- if self.S3P is not None:
192
- tmp = tmp + [
181
+ self.conv2complex(
193
182
  self.backend.bk_reshape(
194
- self.S3P,
183
+ self.S3,
195
184
  [
196
185
  self.S3.shape[0],
197
186
  self.S3.shape[1]
@@ -200,22 +189,39 @@ class scat_cov:
200
189
  * self.S3.shape[4],
201
190
  ],
202
191
  )
192
+ ),
193
+ ]
194
+ if self.S3P is not None:
195
+ tmp = tmp + [
196
+ self.conv2complex(
197
+ self.backend.bk_reshape(
198
+ self.S3P,
199
+ [
200
+ self.S3.shape[0],
201
+ self.S3.shape[1]
202
+ * self.S3.shape[2]
203
+ * self.S3.shape[3]
204
+ * self.S3.shape[4],
205
+ ],
206
+ )
207
+ )
203
208
  ]
204
209
 
205
210
  tmp = tmp + [
206
- self.backend.bk_reshape(
207
- self.S4,
208
- [
209
- self.S3.shape[0],
210
- self.S4.shape[1]
211
- * self.S4.shape[2]
212
- * self.S4.shape[3]
213
- * self.S4.shape[4]
214
- * self.S4.shape[5],
215
- ],
211
+ self.conv2complex(
212
+ self.backend.bk_reshape(
213
+ self.S4,
214
+ [
215
+ self.S4.shape[0],
216
+ self.S4.shape[1]
217
+ * self.S4.shape[2]
218
+ * self.S4.shape[3]
219
+ * self.S4.shape[4]
220
+ * self.S4.shape[5],
221
+ ],
222
+ )
216
223
  )
217
224
  ]
218
-
219
225
  return self.backend.bk_concat(tmp, 1)
220
226
 
221
227
  # ---------------------------------------------−---------
@@ -1589,20 +1595,20 @@ class scat_cov:
1589
1595
  def mean(self):
1590
1596
  if self.S1 is not None: # Auto
1591
1597
  return (
1592
- abs(self.get_np(self.S0)).mean()
1593
- + abs(self.get_np(self.S1)).mean()
1594
- + abs(self.get_np(self.S3)).mean()
1595
- + abs(self.get_np(self.S4)).mean()
1596
- + abs(self.get_np(self.S2)).mean()
1597
- ) / 4
1598
+ abs(self.get_np(self.S0)).sum()
1599
+ + abs(self.get_np(self.S1)).sum()
1600
+ + abs(self.get_np(self.S3)).sum()
1601
+ + abs(self.get_np(self.S4)).sum()
1602
+ + abs(self.get_np(self.S2)).sum()
1603
+ ) / self.numel
1598
1604
  else: # Cross
1599
1605
  return (
1600
- abs(self.get_np(self.S0)).mean()
1601
- + abs(self.get_np(self.S3)).mean()
1602
- + abs(self.get_np(self.S3P)).mean()
1603
- + abs(self.get_np(self.S4)).mean()
1604
- + abs(self.get_np(self.S2)).mean()
1605
- ) / 4
1606
+ abs(self.get_np(self.S0)).sum()
1607
+ + abs(self.get_np(self.S3)).sum()
1608
+ + abs(self.get_np(self.S3P)).sum()
1609
+ + abs(self.get_np(self.S4)).sum()
1610
+ + abs(self.get_np(self.S2)).sum()
1611
+ ) / self.numel
1606
1612
 
1607
1613
  def initdx(self, norient):
1608
1614
  idx1 = np.zeros([norient * norient], dtype="int")
@@ -2348,16 +2354,16 @@ class funct(FOC.FoCUS):
2348
2354
  )
2349
2355
 
2350
2356
  # compute local direction to make the statistical analysis more efficient
2351
- def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0):
2357
+ def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0,spin=0):
2352
2358
  tmp = im
2353
2359
  if image2 is not None:
2354
2360
  tmpi2 = image2
2355
2361
  if upscale:
2356
- l_nside = int(np.sqrt(tmp.shape[1] // 12))
2357
- tmp = self.up_grade(tmp, l_nside * 2, axis=1)
2362
+ l_nside = int(np.sqrt(tmp.shape[-1] // 12))
2363
+ tmp = self.up_grade(tmp, l_nside * 2)
2358
2364
  if image2 is not None:
2359
- tmpi2 = self.up_grade(tmpi2, l_nside * 2, axis=1)
2360
- l_nside = int(np.sqrt(tmp.shape[1] // 12))
2365
+ tmpi2 = self.up_grade(tmpi2, l_nside * 2)
2366
+ l_nside = int(np.sqrt(tmp.shape[-1] // 12))
2361
2367
  nscale = int(np.log(l_nside) / np.log(2))
2362
2368
  cmat = {}
2363
2369
  cmat2 = {}
@@ -2367,20 +2373,23 @@ class funct(FOC.FoCUS):
2367
2373
  if image2 is not None:
2368
2374
  sim = self.backend.bk_real(
2369
2375
  self.backend.bk_L1(
2370
- self.convol(tmp, axis=1)
2371
- * self.backend.bk_conjugate(self.convol(tmpi2, axis=1))
2376
+ self.convol(tmp,spin=spin)
2377
+ * self.backend.bk_conjugate(self.convol(tmpi2,spin=spin))
2372
2378
  )
2373
2379
  )
2374
2380
  else:
2375
- sim = self.backend.bk_abs(self.convol(tmp, axis=1))
2381
+ sim = self.backend.bk_abs(self.convol(tmp,spin=spin))
2376
2382
 
2377
2383
  # instead of difference between "opposite" channels use weighted average
2378
2384
  # of cosine and sine contributions using all channels
2379
- angles = self.backend.bk_cast(
2380
- (2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
2381
- 1, self.NORIENT, 1
2382
- )
2383
- ) # shape: (NORIENT,)
2385
+ if spin==0:
2386
+ angles = self.backend.bk_cast(
2387
+ (2 * np.pi * np.arange(self.NORIENT)
2388
+ / self.NORIENT).reshape(1,self.NORIENT,1)) # shape: (NORIENT,)
2389
+ else:
2390
+ angles = self.backend.bk_cast(
2391
+ (2 * np.pi * np.arange(self.NORIENT)
2392
+ / self.NORIENT).reshape(1,1,self.NORIENT,1)) # shape: (NORIENT,)
2384
2393
 
2385
2394
  # we use cosines and sines as weights for sim
2386
2395
  weighted_cos = self.backend.bk_reduce_mean(
@@ -2399,8 +2408,8 @@ class funct(FOC.FoCUS):
2399
2408
  cc, _ = self.ud_grade_2(self.smooth(cc))
2400
2409
  ss, _ = self.ud_grade_2(self.smooth(ss))
2401
2410
 
2402
- if cc.shape[0] != tmp.shape[0]:
2403
- ll_nside = int(np.sqrt(tmp.shape[1] // 12))
2411
+ if cc.shape[-1] != tmp.shape[-1]:
2412
+ ll_nside = int(np.sqrt(tmp.shape[-1] // 12))
2404
2413
  cc = self.up_grade(cc, ll_nside)
2405
2414
  ss = self.up_grade(ss, ll_nside)
2406
2415
 
@@ -2423,46 +2432,70 @@ class funct(FOC.FoCUS):
2423
2432
  w1 = np.sin(delta * np.pi / 2) ** 2
2424
2433
 
2425
2434
  # build rotation matrix
2426
- mat = np.zeros([self.NORIENT * self.NORIENT, sim.shape[2]])
2427
- lidx = np.arange(sim.shape[2])
2435
+ if spin==0:
2436
+ mat = np.zeros([self.NORIENT * self.NORIENT, sim.shape[-1]])
2437
+ else:
2438
+ mat = np.zeros([2,self.NORIENT * self.NORIENT, sim.shape[-1]])
2439
+ lidx = np.arange(sim.shape[-1])
2428
2440
  for ell in range(self.NORIENT):
2429
2441
  # Instead of simple linear weights, we use the cosine weights w0 and w1.
2430
2442
  col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
2431
2443
  col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
2432
2444
 
2433
- mat[col0, lidx] = w0
2434
- mat[col1, lidx] = w1
2445
+ if spin==0:
2446
+ mat[col0, lidx] = w0
2447
+ mat[col1, lidx] = w1
2448
+ else:
2449
+ mat[0,col0, lidx] = w0[0]
2450
+ mat[0,col1, lidx] = w1[0]
2451
+ mat[1,col0, lidx] = w0[1]
2452
+ mat[1,col1, lidx] = w1[1]
2435
2453
 
2436
2454
  cmat[k] = self.backend.bk_cast(mat[None, ...].astype("complex64"))
2437
2455
 
2438
2456
  # do same modifications for mat2
2439
- mat2 = np.zeros(
2440
- [k + 1, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[2]]
2441
- )
2457
+ if spin==0:
2458
+ mat2 = np.zeros(
2459
+ [k + 1, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[-1]]
2460
+ )
2461
+ else:
2462
+ mat2 = np.zeros(
2463
+ [k + 1, 2, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[-1]]
2464
+ )
2442
2465
 
2443
2466
  for k2 in range(k + 1):
2444
2467
 
2445
- tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=1)
2446
-
2447
- sim2 = self.backend.bk_reduce_sum(
2448
- self.backend.bk_reshape(
2449
- self.backend.bk_cast(
2450
- mat.reshape(1, self.NORIENT * self.NORIENT, mat.shape[1])
2451
- )
2452
- * tmp2,
2453
- [sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[1]],
2454
- ),
2455
- 1,
2456
- )
2457
-
2458
- sim2 = self.backend.bk_abs(self.convol(sim2, axis=-1))
2468
+ tmp2 = self.backend.bk_expand_dims(sim,-2)
2469
+ if spin==0:
2470
+ sim2 = self.backend.bk_reduce_sum(
2471
+ self.backend.bk_reshape(
2472
+ self.backend.bk_cast(
2473
+ mat.reshape(1, self.NORIENT, self.NORIENT, mat.shape[-1])
2474
+ )
2475
+ * tmp2,
2476
+ [sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[-1]],
2477
+ ),
2478
+ 1,
2479
+ )
2480
+ else:
2481
+ sim2 = self.backend.bk_reduce_sum(
2482
+ self.backend.bk_reshape(
2483
+ self.backend.bk_cast(
2484
+ mat.reshape(1, 2, self.NORIENT, self.NORIENT, mat.shape[-1])
2485
+ )
2486
+ * tmp2,
2487
+ [sim.shape[0], 2, self.NORIENT, self.NORIENT, mat.shape[-1]],
2488
+ ),
2489
+ 2,
2490
+ )
2459
2491
 
2492
+ sim2 = self.backend.bk_abs(self.convol(sim2))
2460
2493
  angles = self.backend.bk_reshape(angles, [1, self.NORIENT, 1, 1])
2461
2494
  weighted_cos2 = self.backend.bk_reduce_mean(
2462
- sim2 * self.backend.bk_cos(angles), axis=1
2495
+ sim2 * self.backend.bk_cos(angles), axis=-3
2463
2496
  )
2464
2497
  weighted_sin2 = self.backend.bk_reduce_mean(
2465
- sim2 * self.backend.bk_sin(angles), axis=1
2498
+ sim2 * self.backend.bk_sin(angles), axis=-3
2466
2499
  )
2467
2500
 
2468
2501
  cc2 = weighted_cos2[0]
@@ -2471,13 +2504,13 @@ class funct(FOC.FoCUS):
2471
2504
  if smooth_scale > 0:
2472
2505
  for m in range(smooth_scale):
2473
2506
  if cc2.shape[1] > 12:
2474
- cc2, _ = self.ud_grade_2(self.smooth(cc2, axis=1), axis=1)
2475
- ss2, _ = self.ud_grade_2(self.smooth(ss2, axis=1), axis=1)
2507
+ cc2, _ = self.ud_grade_2(self.smooth(cc2))
2508
+ ss2, _ = self.ud_grade_2(self.smooth(ss2))
2476
2509
 
2477
- if cc2.shape[1] != sim.shape[2]:
2478
- ll_nside = int(np.sqrt(sim.shape[2] // 12))
2479
- cc2 = self.up_grade(cc2, ll_nside, axis=1)
2480
- ss2 = self.up_grade(ss2, ll_nside, axis=1)
2510
+ if cc2.shape[-1] != sim.shape[-1]:
2511
+ ll_nside = int(np.sqrt(sim.shape[-1] // 12))
2512
+ cc2 = self.up_grade(cc2, ll_nside)
2513
+ ss2 = self.up_grade(ss2, ll_nside)
2481
2514
 
2482
2515
  if self.BACKEND == "numpy":
2483
2516
  phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
@@ -2495,23 +2528,32 @@ class funct(FOC.FoCUS):
2495
2528
  delta2 = phase2_scaled - iph2
2496
2529
  w0_2 = np.cos(delta2 * np.pi / 2) ** 2
2497
2530
  w1_2 = np.sin(delta2 * np.pi / 2) ** 2
2498
- lidx = np.arange(sim.shape[2])
2499
-
2500
- for m in range(self.NORIENT):
2501
- for ell in range(self.NORIENT):
2502
- col0 = self.NORIENT * ((ell + iph2[m]) % self.NORIENT) + ell
2503
- col1 = self.NORIENT * ((ell + iph2[m] + 1) % self.NORIENT) + ell
2504
- mat2[k2, col0, m, lidx] = w0_2[m, lidx]
2505
- mat2[k2, col1, m, lidx] = w1_2[m, lidx]
2531
+ lidx = np.arange(sim.shape[-1])
2532
+
2533
+ if spin==0:
2534
+ for m in range(self.NORIENT):
2535
+ for ell in range(self.NORIENT):
2536
+ col0 = self.NORIENT * ((ell + iph2[m]) % self.NORIENT) + ell
2537
+ col1 = self.NORIENT * ((ell + iph2[m] + 1) % self.NORIENT) + ell
2538
+ mat2[k2, col0, m, lidx] = w0_2[m, lidx]
2539
+ mat2[k2, col1, m, lidx] = w1_2[m, lidx]
2540
+ else:
2541
+ for sidx in range(2):
2542
+ for m in range(self.NORIENT):
2543
+ for ell in range(self.NORIENT):
2544
+ col0 = self.NORIENT * ((ell + iph2[sidx,m]) % self.NORIENT) + ell
2545
+ col1 = self.NORIENT * ((ell + iph2[sidx,m] + 1) % self.NORIENT) + ell
2546
+ mat2[k2, sidx, col0, m, lidx] = w0_2[sidx,m, lidx]
2547
+ mat2[k2, sidx, col1, m, lidx] = w1_2[sidx,m, lidx]
2506
2548
 
2507
2549
  cmat2[k] = self.backend.bk_cast(
2508
2550
  mat2[0 : k + 1, None, ...].astype("complex64")
2509
2551
  )
2510
2552
 
2511
2553
  if k < l_nside - 1:
2512
- tmp, _ = self.ud_grade_2(tmp, axis=1)
2554
+ tmp, _ = self.ud_grade_2(tmp)
2513
2555
  if image2 is not None:
2514
- tmpi2, _ = self.ud_grade_2(tmpi2, axis=1)
2556
+ tmpi2, _ = self.ud_grade_2(tmpi)
2515
2557
  return cmat, cmat2
2516
2558
 
2517
2559
  def div_norm(self, complex_value, float_value):
@@ -2534,6 +2576,7 @@ class funct(FOC.FoCUS):
2534
2576
  edge=True,
2535
2577
  nside=None,
2536
2578
  cell_ids=None,
2579
+ spin=0
2537
2580
  ):
2538
2581
  """
2539
2582
  Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2559,11 +2602,14 @@ class funct(FOC.FoCUS):
2559
2602
  norm: None or str
2560
2603
  If None no normalization is applied, if 'auto' normalize by the reference S2,
2561
2604
  if 'self' normalize by the current S2.
2605
+ spin : Integer
2606
+ If different from 0 compute spinned data (U,V to Divergence/Rotational spin==1) or (Q,U to E,B spin=2).
2607
+ This implies that the input data is 2*12*nside^2.
2562
2608
  Returns
2563
2609
  -------
2564
2610
  S1, S2, S3, S4 normalized
2565
2611
  """
2566
-
2612
+
2567
2613
  return_data = self.return_data
2568
2614
 
2569
2615
  # Check input consistency
@@ -2576,8 +2622,8 @@ class funct(FOC.FoCUS):
2576
2622
  if mask is not None:
2577
2623
  if self.use_2D:
2578
2624
  if (
2579
- image1.shape[-2] != mask.shape[1]
2580
- or image1.shape[-1] != mask.shape[2]
2625
+ image1.shape[-2] != mask.shape[-1]
2626
+ or image1.shape[-1] != mask.shape[-1]
2581
2627
  ):
2582
2628
  print(
2583
2629
  "The LAST 2 COLUMNs of the mask should have the same size ",
@@ -2588,7 +2634,7 @@ class funct(FOC.FoCUS):
2588
2634
  )
2589
2635
  return None
2590
2636
  else:
2591
- if image1.shape[-1] != mask.shape[1]:
2637
+ if image1.shape[-1] != mask.shape[-1]:
2592
2638
  print(
2593
2639
  "The LAST COLUMN of the mask should have the same size ",
2594
2640
  mask.shape,
@@ -2609,7 +2655,6 @@ class funct(FOC.FoCUS):
2609
2655
  cross = True
2610
2656
  l_nside = 2**32 # not initialize if 1D or 2D
2611
2657
  ### PARAMETERS
2612
- axis = 1
2613
2658
  # determine jmax and nside corresponding to the input map
2614
2659
  im_shape = image1.shape
2615
2660
  if self.use_2D:
@@ -2637,10 +2682,7 @@ class funct(FOC.FoCUS):
2637
2682
 
2638
2683
  J = int(np.log(nside) / np.log(2)) # Number of j scales
2639
2684
  else:
2640
- if len(image1.shape) == 2:
2641
- npix = int(im_shape[1]) # Number of pixels
2642
- else:
2643
- npix = int(im_shape[0]) # Number of pixels
2685
+ npix=int(im_shape[-1])
2644
2686
 
2645
2687
  if nside is None:
2646
2688
  nside = int(np.sqrt(npix // 12))
@@ -2659,7 +2701,7 @@ class funct(FOC.FoCUS):
2659
2701
  print("\n\n==========")
2660
2702
 
2661
2703
  ### LOCAL VARIABLES (IMAGES and MASK)
2662
- if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
2704
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D) or (len(image1.shape) == 2 and spin>0):
2663
2705
  I1 = self.backend.bk_cast(
2664
2706
  self.backend.bk_expand_dims(image1, 0)
2665
2707
  ) # Local image1 [Nbatch, Npix]
@@ -2684,26 +2726,26 @@ class funct(FOC.FoCUS):
2684
2726
  # if the kernel size is bigger than 3 increase the binning before smoothing
2685
2727
  if self.use_2D:
2686
2728
  vmask = self.up_grade(
2687
- vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2729
+ vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2688
2730
  )
2689
2731
  I1 = self.up_grade(
2690
- I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2732
+ I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2691
2733
  )
2692
2734
  if cross:
2693
2735
  I2 = self.up_grade(
2694
- I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
2736
+ I2, I2.shape[-2] * 2, nouty=I2.shape[-1] * 2,axis=-2
2695
2737
  )
2696
2738
  elif self.use_1D:
2697
- vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2698
- I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2739
+ vmask = self.up_grade(vmask, I1.shape[-1] * 2)
2740
+ I1 = self.up_grade(I1, I1.shape[-1] * 2)
2699
2741
  if cross:
2700
- I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2742
+ I2 = self.up_grade(I2, I2.shape[-1] * 2)
2701
2743
  nside = nside * 2
2702
2744
  else:
2703
- I1 = self.up_grade(I1, nside * 2, axis=axis)
2704
- vmask = self.up_grade(vmask, nside * 2, axis=1)
2745
+ I1 = self.up_grade(I1, nside * 2)
2746
+ vmask = self.up_grade(vmask, nside * 2)
2705
2747
  if cross:
2706
- I2 = self.up_grade(I2, nside * 2, axis=axis)
2748
+ I2 = self.up_grade(I2, nside * 2)
2707
2749
 
2708
2750
  nside = nside * 2
2709
2751
 
@@ -2711,29 +2753,28 @@ class funct(FOC.FoCUS):
2711
2753
  # if the kernel size is bigger than 3 increase the binning before smoothing
2712
2754
  if self.use_2D:
2713
2755
  vmask = self.up_grade(
2714
- vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2756
+ vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2715
2757
  )
2716
2758
  I1 = self.up_grade(
2717
- I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2759
+ I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
2718
2760
  )
2719
2761
  if cross:
2720
2762
  I2 = self.up_grade(
2721
2763
  I2,
2722
- I2.shape[axis] * 2,
2723
- axis=axis,
2724
- nouty=I2.shape[axis + 1] * 2,
2764
+ I2.shape[-2] * 2,
2765
+ nouty=I2.shape[-1] * 2,axis=-2
2725
2766
  )
2726
2767
  elif self.use_1D:
2727
- vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2728
- I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2768
+ vmask = self.up_grade(vmask, I1.shape[-1] * 2)
2769
+ I1 = self.up_grade(I1, I1.shape[-1] * 2)
2729
2770
  if cross:
2730
- I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2771
+ I2 = self.up_grade(I2, I2.shape[-1] * 2)
2731
2772
  nside = nside * 2
2732
2773
  else:
2733
- I1 = self.up_grade(I1, nside * 2, axis=axis)
2734
- vmask = self.up_grade(vmask, nside * 2, axis=1)
2774
+ I1 = self.up_grade(I1, nside * 2)
2775
+ vmask = self.up_grade(vmask, nside * 2)
2735
2776
  if cross:
2736
- I2 = self.up_grade(I2, nside * 2, axis=axis)
2777
+ I2 = self.up_grade(I2, nside * 2)
2737
2778
  nside = nside * 2
2738
2779
 
2739
2780
  # Normalize the masks because they have different pixel numbers
@@ -2762,6 +2803,7 @@ class funct(FOC.FoCUS):
2762
2803
  off_S2 = -2
2763
2804
  off_S3 = -3
2764
2805
  off_S4 = -4
2806
+
2765
2807
  if self.use_1D:
2766
2808
  off_S2 = -1
2767
2809
  off_S3 = -1
@@ -2793,13 +2835,20 @@ class funct(FOC.FoCUS):
2793
2835
  )
2794
2836
  else:
2795
2837
  if not cross:
2796
- s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
2838
+ s0, l_vs0 = self.masked_mean(I1,
2839
+ vmask,
2840
+ calc_var=True)
2797
2841
  else:
2798
2842
  s0, l_vs0 = self.masked_mean(
2799
- self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
2800
- )
2801
- vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2802
- s0 = self.backend.bk_concat([s0, l_vs0], 1)
2843
+ self.backend.bk_L1(I1 * I2),
2844
+ vmask,
2845
+ calc_var=True)
2846
+
2847
+ vs0 = self.backend.bk_concat([l_vs0, l_vs0], -1)
2848
+ s0 = self.backend.bk_concat([s0, l_vs0], -1)
2849
+ if spin>0:
2850
+ vs0=self.backend.bk_reshape(vs0,[vs0.shape[0],vs0.shape[1],2,vs0.shape[2]//2])
2851
+ s0=self.backend.bk_reshape(s0,[s0.shape[0],s0.shape[1],2,s0.shape[2]//2])
2803
2852
  #### COMPUTE S1, S2, S3 and S4
2804
2853
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2805
2854
 
@@ -2848,20 +2897,29 @@ class funct(FOC.FoCUS):
2848
2897
  ####### S1 and S2
2849
2898
  ### Make the convolution I1 * Psi_j3
2850
2899
  conv1 = self.convol(
2851
- I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2900
+ I1, cell_ids=cell_ids_j3, nside=nside_j3,
2901
+ spin=spin
2852
2902
  ) # [Nbatch, Norient3 , Npix_j3]
2853
2903
 
2854
2904
  if cmat is not None:
2855
-
2856
2905
  tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-2)
2857
-
2858
- conv1 = self.backend.bk_reduce_sum(
2859
- self.backend.bk_reshape(
2860
- cmat[j3] * tmp2,
2861
- [tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j3].shape[2]],
2862
- ),
2863
- 1,
2864
- )
2906
+
2907
+ if spin==0:
2908
+ conv1 = self.backend.bk_reduce_sum(
2909
+ self.backend.bk_reshape(
2910
+ cmat[j3] * tmp2,
2911
+ [tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j3].shape[2]],
2912
+ ),
2913
+ 1,
2914
+ )
2915
+ else:
2916
+ conv1 = self.backend.bk_reduce_sum(
2917
+ self.backend.bk_reshape(
2918
+ cmat[j3] * tmp2,
2919
+ [tmp2.shape[0], 2,self.NORIENT, self.NORIENT, cmat[j3].shape[3]],
2920
+ ),
2921
+ 2,
2922
+ )
2865
2923
 
2866
2924
  ### Take the module M1 = |I1 * Psi_j3|
2867
2925
  M1_square = conv1 * self.backend.bk_conjugate(
@@ -2871,7 +2929,6 @@ class funct(FOC.FoCUS):
2871
2929
  M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
2872
2930
  # Store M1_j3 in a dictionary
2873
2931
  M1_dic[j3] = M1
2874
-
2875
2932
  if not cross: # Auto
2876
2933
  M1_square = self.backend.bk_real(M1_square)
2877
2934
 
@@ -2882,11 +2939,11 @@ class funct(FOC.FoCUS):
2882
2939
  else:
2883
2940
  if calc_var:
2884
2941
  s2, vs2 = self.masked_mean(
2885
- M1_square, vmask, axis=2, rank=j3, calc_var=True
2942
+ M1_square, vmask, rank=j3, calc_var=True
2886
2943
  )
2887
2944
  else:
2888
- s2 = self.masked_mean(M1_square, vmask, axis=2, rank=j3)
2889
-
2945
+ s2 = self.masked_mean(M1_square, vmask, rank=j3)
2946
+
2890
2947
  if cond_init_P1_dic:
2891
2948
  # We fill P1_dic with S2 for normalisation of S3 and S4
2892
2949
  P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
@@ -2929,11 +2986,11 @@ class funct(FOC.FoCUS):
2929
2986
  else:
2930
2987
  if calc_var:
2931
2988
  s1, vs1 = self.masked_mean(
2932
- M1, vmask, axis=2, rank=j3, calc_var=True
2989
+ M1, vmask, rank=j3, calc_var=True
2933
2990
  ) # [Nbatch, Nmask, Norient3]
2934
2991
  else:
2935
2992
  s1 = self.masked_mean(
2936
- M1, vmask, axis=2, rank=j3
2993
+ M1, vmask, rank=j3
2937
2994
  ) # [Nbatch, Nmask, Norient3]
2938
2995
 
2939
2996
  if return_data:
@@ -2967,22 +3024,38 @@ class funct(FOC.FoCUS):
2967
3024
  else: # Cross
2968
3025
  ### Make the convolution I2 * Psi_j3
2969
3026
  conv2 = self.convol(
2970
- I2, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3027
+ I2, cell_ids=cell_ids_j3, nside=nside_j3,
3028
+ spin=spin
2971
3029
  ) # [Nbatch, Npix_j3, Norient3]
2972
3030
  if cmat is not None:
2973
3031
  tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-2)
2974
- conv2 = self.backend.bk_reduce_sum(
2975
- self.backend.bk_reshape(
2976
- cmat[j3] * tmp2,
2977
- [
2978
- tmp2.shape[0],
2979
- self.NORIENT,
2980
- self.NORIENT,
2981
- cmat[j3].shape[2],
2982
- ],
2983
- ),
2984
- 1,
2985
- )
3032
+ if spin==0:
3033
+ conv2 = self.backend.bk_reduce_sum(
3034
+ self.backend.bk_reshape(
3035
+ cmat[j3] * tmp2,
3036
+ [
3037
+ tmp2.shape[0],
3038
+ self.NORIENT,
3039
+ self.NORIENT,
3040
+ cmat[j3].shape[2],
3041
+ ],
3042
+ ),
3043
+ 1,
3044
+ )
3045
+ else:
3046
+ conv2 = self.backend.bk_reduce_sum(
3047
+ self.backend.bk_reshape(
3048
+ cmat[j3] * tmp2,
3049
+ [
3050
+ tmp2.shape[0],
3051
+ 2,
3052
+ self.NORIENT,
3053
+ self.NORIENT,
3054
+ cmat[j3].shape[3],
3055
+ ],
3056
+ ),
3057
+ 2,
3058
+ )
2986
3059
  ### Take the module M2 = |I2 * Psi_j3|
2987
3060
  M2_square = conv2 * self.backend.bk_conjugate(
2988
3061
  conv2
@@ -3001,17 +3074,17 @@ class funct(FOC.FoCUS):
3001
3074
  else:
3002
3075
  if calc_var:
3003
3076
  p1, vp1 = self.masked_mean(
3004
- M1_square, vmask, axis=2, rank=j3, calc_var=True
3077
+ M1_square, vmask, rank=j3, calc_var=True
3005
3078
  ) # [Nbatch, Nmask, Norient3]
3006
3079
  p2, vp2 = self.masked_mean(
3007
- M2_square, vmask, axis=2, rank=j3, calc_var=True
3080
+ M2_square, vmask, rank=j3, calc_var=True
3008
3081
  ) # [Nbatch, Nmask, Norient3]
3009
3082
  else:
3010
3083
  p1 = self.masked_mean(
3011
- M1_square, vmask, axis=2, rank=j3
3084
+ M1_square, vmask, rank=j3
3012
3085
  ) # [Nbatch, Nmask, Norient3]
3013
3086
  p2 = self.masked_mean(
3014
- M2_square, vmask, axis=2, rank=j3
3087
+ M2_square, vmask, rank=j3
3015
3088
  ) # [Nbatch, Nmask, Norient3]
3016
3089
  # We fill P1_dic with S2 for normalisation of S3 and S4
3017
3090
  P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
@@ -3027,10 +3100,10 @@ class funct(FOC.FoCUS):
3027
3100
  else:
3028
3101
  if calc_var:
3029
3102
  s2, vs2 = self.masked_mean(
3030
- s2, vmask, axis=2, rank=j3, calc_var=True
3103
+ s2, vmask, rank=j3, calc_var=True
3031
3104
  )
3032
3105
  else:
3033
- s2 = self.masked_mean(s2, vmask, axis=2, rank=j3)
3106
+ s2 = self.masked_mean(s2, vmask, rank=j3)
3034
3107
 
3035
3108
  if return_data:
3036
3109
  if out_nside is not None and out_nside < nside_j3:
@@ -3072,11 +3145,11 @@ class funct(FOC.FoCUS):
3072
3145
  else:
3073
3146
  if calc_var:
3074
3147
  s1, vs1 = self.masked_mean(
3075
- MX, vmask, axis=2, rank=j3, calc_var=True
3148
+ MX, vmask, rank=j3, calc_var=True
3076
3149
  ) # [Nbatch, Nmask, Norient3]
3077
3150
  else:
3078
3151
  s1 = self.masked_mean(
3079
- MX, vmask, axis=2, rank=j3
3152
+ MX, vmask, rank=j3
3080
3153
  ) # [Nbatch, Nmask, Norient3]
3081
3154
  if return_data:
3082
3155
  if out_nside is not None and out_nside < nside_j3:
@@ -3133,6 +3206,7 @@ class funct(FOC.FoCUS):
3133
3206
  cmat2=cmat2,
3134
3207
  cell_ids=cell_ids_j3,
3135
3208
  nside_j2=nside_j3,
3209
+ spin=spin,
3136
3210
  ) # [Nbatch, Nmask, Norient3, Norient2]
3137
3211
  else:
3138
3212
  s3 = self._compute_S3(
@@ -3146,6 +3220,7 @@ class funct(FOC.FoCUS):
3146
3220
  cmat2=cmat2,
3147
3221
  cell_ids=cell_ids_j3,
3148
3222
  nside_j2=nside_j3,
3223
+ spin=spin,
3149
3224
  ) # [Nbatch, Nmask, Norient3, Norient2]
3150
3225
 
3151
3226
  if return_data:
@@ -3207,6 +3282,7 @@ class funct(FOC.FoCUS):
3207
3282
  cmat2=cmat2,
3208
3283
  cell_ids=cell_ids_j3,
3209
3284
  nside_j2=nside_j3,
3285
+ spin=spin,
3210
3286
  )
3211
3287
  s3p, vs3p = self._compute_S3(
3212
3288
  j2,
@@ -3219,6 +3295,7 @@ class funct(FOC.FoCUS):
3219
3295
  cmat2=cmat2,
3220
3296
  cell_ids=cell_ids_j3,
3221
3297
  nside_j2=nside_j3,
3298
+ spin=spin,
3222
3299
  )
3223
3300
  else:
3224
3301
  s3p = self._compute_S3(
@@ -3232,6 +3309,7 @@ class funct(FOC.FoCUS):
3232
3309
  cmat2=cmat2,
3233
3310
  cell_ids=cell_ids_j3,
3234
3311
  nside_j2=nside_j3,
3312
+ spin=spin,
3235
3313
  )
3236
3314
  s3 = self._compute_S3(
3237
3315
  j2,
@@ -3244,6 +3322,7 @@ class funct(FOC.FoCUS):
3244
3322
  cmat2=cmat2,
3245
3323
  cell_ids=cell_ids_j3,
3246
3324
  nside_j2=nside_j3,
3325
+ spin=spin,
3247
3326
  )
3248
3327
 
3249
3328
  if return_data:
@@ -3482,39 +3561,39 @@ class funct(FOC.FoCUS):
3482
3561
  ### Image I1,
3483
3562
  # downscale the I1 [Nbatch, Npix_j3]
3484
3563
  if j3 != Jmax - 1:
3485
- I1 = self.smooth(I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3564
+ I1 = self.smooth(I1, cell_ids=cell_ids_j3, nside=nside_j3)
3486
3565
  I1, new_cell_ids_j3 = self.ud_grade_2(
3487
- I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3566
+ I1, cell_ids=cell_ids_j3, nside=nside_j3
3488
3567
  )
3489
3568
 
3490
3569
  ### Image I2
3491
3570
  if cross:
3492
- I2 = self.smooth(I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3571
+ I2 = self.smooth(I2, cell_ids=cell_ids_j3, nside=nside_j3)
3493
3572
  I2, new_cell_ids_j3 = self.ud_grade_2(
3494
- I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3573
+ I2, cell_ids=cell_ids_j3, nside=nside_j3
3495
3574
  )
3496
3575
 
3497
3576
  ### Modules
3498
3577
  for j2 in range(0, j3 + 1): # j2 =< j3
3499
3578
  ### Dictionary M1_dic[j2]
3500
3579
  M1_smooth = self.smooth(
3501
- M1_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3580
+ M1_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
3502
3581
  ) # [Nbatch, Npix_j3, Norient3]
3503
3582
  M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3504
- M1_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3583
+ M1_smooth, cell_ids=cell_ids_j3, nside=nside_j3
3505
3584
  ) # [Nbatch, Npix_j3, Norient3]
3506
3585
 
3507
3586
  ### Dictionary M2_dic[j2]
3508
3587
  if cross:
3509
3588
  M2_smooth = self.smooth(
3510
- M2_dic[j2], axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3589
+ M2_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
3511
3590
  ) # [Nbatch, Npix_j3, Norient3]
3512
3591
  M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3513
- M2_smooth, axis=2, cell_ids=cell_ids_j3, nside=nside_j3
3592
+ M2_smooth, cell_ids=cell_ids_j3, nside=nside_j3
3514
3593
  ) # [Nbatch, Npix_j3, Norient3]
3515
3594
  ### Mask
3516
3595
  vmask, new_cell_ids_j3 = self.ud_grade_2(
3517
- vmask, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3596
+ vmask, cell_ids=cell_ids_j3, nside=nside_j3
3518
3597
  )
3519
3598
 
3520
3599
  if self.mask_thres is not None:
@@ -3529,33 +3608,21 @@ class funct(FOC.FoCUS):
3529
3608
  self.P1_dic = P1_dic
3530
3609
  if cross:
3531
3610
  self.P2_dic = P2_dic
3532
- """
3533
- Sout=[s0]+S1+S2+S3+S4
3534
-
3535
- if cross:
3536
- Sout=Sout+S3P
3537
- if calc_var:
3538
- SVout=[vs0]+VS1+VS2+VS3+VS4
3539
- if cross:
3540
- VSout=VSout+VS3P
3541
- return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
3542
-
3543
- return self.backend.bk_concat(Sout, 2)
3544
- """
3611
+
3545
3612
  if not return_data:
3546
- S1 = self.backend.bk_concat(S1, 2)
3547
- S2 = self.backend.bk_concat(S2, 2)
3548
- S3 = self.backend.bk_concat(S3, 2)
3549
- S4 = self.backend.bk_concat(S4, 2)
3613
+ S1 = self.backend.bk_concat(S1, -2)
3614
+ S2 = self.backend.bk_concat(S2, -2)
3615
+ S3 = self.backend.bk_concat(S3, -3)
3616
+ S4 = self.backend.bk_concat(S4, -4)
3550
3617
  if cross:
3551
- S3P = self.backend.bk_concat(S3P, 2)
3618
+ S3P = self.backend.bk_concat(S3P, -3)
3552
3619
  if calc_var:
3553
- VS1 = self.backend.bk_concat(VS1, 2)
3554
- VS2 = self.backend.bk_concat(VS2, 2)
3555
- VS3 = self.backend.bk_concat(VS3, 2)
3556
- VS4 = self.backend.bk_concat(VS4, 2)
3620
+ VS1 = self.backend.bk_concat(VS1, -2)
3621
+ VS2 = self.backend.bk_concat(VS2, -2)
3622
+ VS3 = self.backend.bk_concat(VS3, -3)
3623
+ VS4 = self.backend.bk_concat(VS4, -4)
3557
3624
  if cross:
3558
- VS3P = self.backend.bk_concat(VS3P, 2)
3625
+ VS3P = self.backend.bk_concat(VS3P, -3)
3559
3626
  if calc_var:
3560
3627
  if not cross:
3561
3628
  return scat_cov(
@@ -3612,18 +3679,19 @@ class funct(FOC.FoCUS):
3612
3679
  return
3613
3680
 
3614
3681
  def _compute_S3(
3615
- self,
3616
- j2,
3617
- j3,
3618
- conv,
3619
- vmask,
3620
- M_dic,
3621
- MconvPsi_dic,
3622
- calc_var=False,
3623
- return_data=False,
3624
- cmat2=None,
3625
- cell_ids=None,
3626
- nside_j2=None,
3682
+ self,
3683
+ j2,
3684
+ j3,
3685
+ conv,
3686
+ vmask,
3687
+ M_dic,
3688
+ MconvPsi_dic,
3689
+ calc_var=False,
3690
+ return_data=False,
3691
+ cmat2=None,
3692
+ cell_ids=None,
3693
+ nside_j2=None,
3694
+ spin=0,
3627
3695
  ):
3628
3696
  """
3629
3697
  Compute the S3 coefficients (auto or cross)
@@ -3637,24 +3705,40 @@ class funct(FOC.FoCUS):
3637
3705
  ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3638
3706
  # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Norient3, Npix_j3]
3639
3707
  MconvPsi = self.convol(
3640
- M_dic[j2], axis=2, cell_ids=cell_ids, nside=nside_j2
3708
+ M_dic[j2], cell_ids=cell_ids, nside=nside_j2
3641
3709
  ) # [Nbatch, Norient3, Norient2, Npix_j3]
3642
3710
 
3643
3711
  if cmat2 is not None:
3644
3712
  tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-3)
3645
- MconvPsi = self.backend.bk_reduce_sum(
3646
- self.backend.bk_reshape(
3647
- cmat2[j3][j2] * tmp2,
3648
- [
3649
- tmp2.shape[0],
3650
- self.NORIENT,
3651
- self.NORIENT,
3652
- self.NORIENT,
3653
- cmat2[j3][j2].shape[3],
3654
- ],
3655
- ),
3656
- 1,
3657
- )
3713
+ if spin==0:
3714
+ MconvPsi = self.backend.bk_reduce_sum(
3715
+ self.backend.bk_reshape(
3716
+ cmat2[j3][j2] * tmp2,
3717
+ [
3718
+ tmp2.shape[0],
3719
+ self.NORIENT,
3720
+ self.NORIENT,
3721
+ self.NORIENT,
3722
+ cmat2[j3][j2].shape[3],
3723
+ ],
3724
+ ),
3725
+ 1,
3726
+ )
3727
+ else:
3728
+ MconvPsi = self.backend.bk_reduce_sum(
3729
+ self.backend.bk_reshape(
3730
+ cmat2[j3][j2] * tmp2,
3731
+ [
3732
+ tmp2.shape[0],
3733
+ 2,
3734
+ self.NORIENT,
3735
+ self.NORIENT,
3736
+ self.NORIENT,
3737
+ cmat2[j3][j2].shape[4],
3738
+ ],
3739
+ ),
3740
+ 2,
3741
+ )
3658
3742
 
3659
3743
  # Store it so we can use it in S4 computation
3660
3744
  MconvPsi_dic[j2] = MconvPsi # [Nbatch, Norient3, Norient2, Npix_j3]
@@ -3674,12 +3758,12 @@ class funct(FOC.FoCUS):
3674
3758
  else:
3675
3759
  if calc_var:
3676
3760
  s3, vs3 = self.masked_mean(
3677
- s3, vmask, axis=3, rank=j2, calc_var=True
3761
+ s3, vmask, rank=j2, calc_var=True
3678
3762
  ) # [Nbatch, Nmask, Norient3, Norient2]
3679
3763
  return s3, vs3
3680
3764
  else:
3681
3765
  s3 = self.masked_mean(
3682
- s3, vmask, axis=3, rank=j2
3766
+ s3, vmask, rank=j2
3683
3767
  ) # [Nbatch, Nmask, Norient3, Norient2]
3684
3768
  return s3
3685
3769
 
@@ -3717,12 +3801,12 @@ class funct(FOC.FoCUS):
3717
3801
  else:
3718
3802
  if calc_var:
3719
3803
  s4, vs4 = self.masked_mean(
3720
- s4, vmask, axis=4, rank=j2, calc_var=True
3804
+ s4, vmask, rank=j2, calc_var=True
3721
3805
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3722
3806
  return s4, vs4
3723
3807
  else:
3724
3808
  s4 = self.masked_mean(
3725
- s4, vmask, axis=4, rank=j2
3809
+ s4, vmask, rank=j2
3726
3810
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3727
3811
  return s4
3728
3812
 
@@ -3922,20 +4006,20 @@ class funct(FOC.FoCUS):
3922
4006
  #
3923
4007
  # ---------------------------------------------------------------------------
3924
4008
  def scattering_cov(
3925
- self,
3926
- data,
3927
- data2=None,
3928
- Jmax=None,
3929
- if_large_batch=False,
3930
- S4_criteria=None,
3931
- use_ref=False,
3932
- normalization="S2",
3933
- edge=False,
3934
- in_mask=None,
3935
- pseudo_coef=1,
3936
- get_variance=False,
3937
- ref_sigma=None,
3938
- iso_ang=False,
4009
+ self,
4010
+ data,
4011
+ data2=None,
4012
+ Jmax=None,
4013
+ if_large_batch=False,
4014
+ S4_criteria=None,
4015
+ use_ref=False,
4016
+ normalization="S2",
4017
+ edge=False,
4018
+ in_mask=None,
4019
+ pseudo_coef=1,
4020
+ get_variance=False,
4021
+ ref_sigma=None,
4022
+ iso_ang=False,
3939
4023
  ):
3940
4024
  """
3941
4025
  Calculates the scattering correlations for a batch of images, including:
@@ -4048,7 +4132,10 @@ class funct(FOC.FoCUS):
4048
4132
  if data2 is not None:
4049
4133
  N_image2 = data2.shape[0]
4050
4134
 
4051
- nside = int(np.sqrt(npix // 12))
4135
+ if spin==0:
4136
+ nside = int(np.sqrt(npix // 12))
4137
+ else:
4138
+ nside = int(np.sqrt(npix // 24))
4052
4139
 
4053
4140
  J = int(np.log(nside) / np.log(2)) # Number of j scales
4054
4141
 
@@ -5781,7 +5868,9 @@ class funct(FOC.FoCUS):
5781
5868
 
5782
5869
  def from_gaussian(self, x):
5783
5870
 
5784
- x = self.backend.bk_clip_by_value(x, self.val_min, self.val_max)
5871
+ x = self.backend.bk_clip_by_value(x,
5872
+ self.val_min+1E-4*abs(self.val_min),
5873
+ self.val_max-1E-4*abs(self.val_max))
5785
5874
  return self.f_gaussian(self.backend.to_numpy(x))
5786
5875
 
5787
5876
  def square(self, x):
@@ -6181,6 +6270,28 @@ class funct(FOC.FoCUS):
6181
6270
  )
6182
6271
  return loss
6183
6272
 
6273
+ def The_lossH(u, scat_operator, args):
6274
+ ref = args[0]
6275
+ sref = args[1]
6276
+ use_v = args[2]
6277
+ ljmax = args[3]
6278
+
6279
+ learn = scat_operator.reduce_mean_batch(
6280
+ scat_operator.eval(
6281
+ u,
6282
+ Jmax=ljmax,
6283
+ norm='self'
6284
+ )
6285
+ )
6286
+
6287
+ # compute scattering covariance of the current synthetised map called u
6288
+ if use_v:
6289
+ loss = scat_operator.reduce_distance(learn,ref,sigma=sref)
6290
+ else:
6291
+ loss = scat_operator.reduce_distance(learn,ref)
6292
+
6293
+ return loss
6294
+
6184
6295
  def The_lossX(u, scat_operator, args):
6185
6296
  ref = args[0]
6186
6297
  sref = args[1]
@@ -6340,37 +6451,74 @@ class funct(FOC.FoCUS):
6340
6451
  # Increase the resolution between each step
6341
6452
  if self.use_2D:
6342
6453
  imap = self.up_grade(
6343
- omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
6454
+ omap,
6455
+ imap.shape[1] * 2,
6456
+ axis=-2,
6457
+ nouty=imap.shape[2] * 2
6344
6458
  )
6345
6459
  elif self.use_1D:
6346
- imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
6460
+ imap = self.up_grade(omap, imap.shape[1] * 2)
6347
6461
  else:
6348
- imap = self.up_grade(omap, l_nside, axis=1)
6349
-
6462
+ imap = self.up_grade(omap, l_nside)
6463
+
6350
6464
  if grd_mask is not None:
6351
6465
  imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
6352
6466
 
6353
- # compute the coefficients for the target image
6354
- if use_variance:
6355
- ref, sref = self.scattering_cov(
6356
- tmp[k],
6357
- data2=l_ref[k],
6358
- get_variance=True,
6359
- edge=l_edge,
6360
- Jmax=l_jmax[k],
6361
- in_mask=l_in_mask[k],
6362
- iso_ang=iso_ang,
6363
- )
6467
+
6468
+ if self.use_2D:
6469
+ # compute the coefficients for the target image
6470
+ if use_variance:
6471
+ ref, sref = self.scattering_cov(
6472
+ tmp[k],
6473
+ data2=l_ref[k],
6474
+ get_variance=True,
6475
+ edge=l_edge,
6476
+ Jmax=l_jmax[k],
6477
+ in_mask=l_in_mask[k],
6478
+ iso_ang=iso_ang,
6479
+ )
6480
+ else:
6481
+ ref = self.scattering_cov(
6482
+ tmp[k],
6483
+ data2=l_ref[k],
6484
+ in_mask=l_in_mask[k],
6485
+ edge=l_edge,
6486
+ Jmax=l_jmax[k],
6487
+ iso_ang=iso_ang,
6488
+ )
6489
+ sref = ref
6364
6490
  else:
6365
- ref = self.scattering_cov(
6366
- tmp[k],
6367
- data2=l_ref[k],
6368
- in_mask=l_in_mask[k],
6369
- edge=l_edge,
6370
- Jmax=l_jmax[k],
6371
- iso_ang=iso_ang,
6372
- )
6373
- sref = ref
6491
+ ref = self.eval(
6492
+ tmp[k],
6493
+ image2=l_ref[k],
6494
+ mask=l_in_mask[k],
6495
+ Jmax=l_jmax[k],
6496
+ norm='auto'
6497
+ )
6498
+
6499
+ # compute the coefficients for the target image
6500
+ if use_variance:
6501
+ ref, sref = self.eval(
6502
+ tmp[k],
6503
+ image2=l_ref[k],
6504
+ mask=l_in_mask[k],
6505
+ Jmax=l_jmax[k],
6506
+ calc_var=True,
6507
+ norm='self'
6508
+ )
6509
+ else:
6510
+ ref = self.eval(
6511
+ tmp[k],
6512
+ image2=l_ref[k],
6513
+ mask=l_in_mask[k],
6514
+ Jmax=l_jmax[k],
6515
+ norm='self'
6516
+ )
6517
+ sref = ref
6518
+
6519
+ if iso_ang:
6520
+ ref=ref.iso_mean()
6521
+ sref=sref.iso_mean()
6374
6522
 
6375
6523
  # compute the mean of the population does nothing if only one map is given
6376
6524
  ref = self.reduce_mean_batch(ref)
@@ -6379,13 +6527,21 @@ class funct(FOC.FoCUS):
6379
6527
  self.purge_edge_mask()
6380
6528
 
6381
6529
  if l_ref[k] is None:
6382
- # define a loss to minimize
6383
- loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
6530
+ if self.use_2D:
6531
+ # define a loss to minimize
6532
+ loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
6533
+ else:
6534
+ loss = synthe.Loss(The_lossH, self, ref, sref, use_variance, l_jmax[k])
6384
6535
  else:
6385
6536
  # define a loss to minimize
6386
- loss = synthe.Loss(
6387
- The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6388
- )
6537
+ if self.use_2D:
6538
+ loss = synthe.Loss(
6539
+ The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6540
+ )
6541
+ else:
6542
+ loss = synthe.Loss(
6543
+ The_lossXH, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6544
+ )
6389
6545
 
6390
6546
  if input_image is not None:
6391
6547
  # define a loss to minimize