foscat 2025.5.2__py3-none-any.whl → 2025.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +11 -12
- foscat/CNN.py +31 -30
- foscat/FoCUS.py +784 -780
- foscat/GCNN.py +48 -150
- foscat/Softmax.py +1 -0
- foscat/alm.py +2 -2
- foscat/heal_NN.py +451 -0
- foscat/scat_cov.py +186 -155
- {foscat-2025.5.2.dist-info → foscat-2025.6.3.dist-info}/METADATA +1 -1
- {foscat-2025.5.2.dist-info → foscat-2025.6.3.dist-info}/RECORD +13 -12
- {foscat-2025.5.2.dist-info → foscat-2025.6.3.dist-info}/WHEEL +0 -0
- {foscat-2025.5.2.dist-info → foscat-2025.6.3.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.5.2.dist-info → foscat-2025.6.3.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -92,12 +92,7 @@ class scat_cov:
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
def conv2complex(self, val):
|
|
95
|
-
if (
|
|
96
|
-
val.dtype == "complex64"
|
|
97
|
-
or val.dtype == "complex128"
|
|
98
|
-
or val.dtype == "torch.complex64"
|
|
99
|
-
or val.dtype == "torch.complex128"
|
|
100
|
-
):
|
|
95
|
+
if self.backend.bk_is_complex(val):
|
|
101
96
|
return val
|
|
102
97
|
else:
|
|
103
98
|
return self.backend.bk_complex(val, 0 * val)
|
|
@@ -107,7 +102,7 @@ class scat_cov:
|
|
|
107
102
|
def flatten(self):
|
|
108
103
|
tmp = [
|
|
109
104
|
self.conv2complex(
|
|
110
|
-
self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]])
|
|
105
|
+
self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
|
|
111
106
|
)
|
|
112
107
|
]
|
|
113
108
|
if self.use_1D:
|
|
@@ -177,21 +172,9 @@ class scat_cov:
|
|
|
177
172
|
],
|
|
178
173
|
)
|
|
179
174
|
),
|
|
180
|
-
self.
|
|
181
|
-
self.S3,
|
|
182
|
-
[
|
|
183
|
-
self.S3.shape[0],
|
|
184
|
-
self.S3.shape[1]
|
|
185
|
-
* self.S3.shape[2]
|
|
186
|
-
* self.S3.shape[3]
|
|
187
|
-
* self.S3.shape[4],
|
|
188
|
-
],
|
|
189
|
-
),
|
|
190
|
-
]
|
|
191
|
-
if self.S3P is not None:
|
|
192
|
-
tmp = tmp + [
|
|
175
|
+
self.conv2complex(
|
|
193
176
|
self.backend.bk_reshape(
|
|
194
|
-
self.
|
|
177
|
+
self.S3,
|
|
195
178
|
[
|
|
196
179
|
self.S3.shape[0],
|
|
197
180
|
self.S3.shape[1]
|
|
@@ -200,22 +183,39 @@ class scat_cov:
|
|
|
200
183
|
* self.S3.shape[4],
|
|
201
184
|
],
|
|
202
185
|
)
|
|
186
|
+
),
|
|
187
|
+
]
|
|
188
|
+
if self.S3P is not None:
|
|
189
|
+
tmp = tmp + [
|
|
190
|
+
self.conv2complex(
|
|
191
|
+
self.backend.bk_reshape(
|
|
192
|
+
self.S3P,
|
|
193
|
+
[
|
|
194
|
+
self.S3.shape[0],
|
|
195
|
+
self.S3.shape[1]
|
|
196
|
+
* self.S3.shape[2]
|
|
197
|
+
* self.S3.shape[3]
|
|
198
|
+
* self.S3.shape[4],
|
|
199
|
+
],
|
|
200
|
+
)
|
|
201
|
+
)
|
|
203
202
|
]
|
|
204
203
|
|
|
205
204
|
tmp = tmp + [
|
|
206
|
-
self.
|
|
207
|
-
self.
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
205
|
+
self.conv2complex(
|
|
206
|
+
self.backend.bk_reshape(
|
|
207
|
+
self.S4,
|
|
208
|
+
[
|
|
209
|
+
self.S4.shape[0],
|
|
210
|
+
self.S4.shape[1]
|
|
211
|
+
* self.S4.shape[2]
|
|
212
|
+
* self.S4.shape[3]
|
|
213
|
+
* self.S4.shape[4]
|
|
214
|
+
* self.S4.shape[5],
|
|
215
|
+
],
|
|
216
|
+
)
|
|
216
217
|
)
|
|
217
218
|
]
|
|
218
|
-
|
|
219
219
|
return self.backend.bk_concat(tmp, 1)
|
|
220
220
|
|
|
221
221
|
# ---------------------------------------------−---------
|
|
@@ -2354,9 +2354,9 @@ class funct(FOC.FoCUS):
|
|
|
2354
2354
|
tmpi2 = image2
|
|
2355
2355
|
if upscale:
|
|
2356
2356
|
l_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2357
|
-
tmp = self.up_grade(tmp, l_nside * 2
|
|
2357
|
+
tmp = self.up_grade(tmp, l_nside * 2)
|
|
2358
2358
|
if image2 is not None:
|
|
2359
|
-
tmpi2 = self.up_grade(tmpi2, l_nside * 2
|
|
2359
|
+
tmpi2 = self.up_grade(tmpi2, l_nside * 2)
|
|
2360
2360
|
l_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2361
2361
|
nscale = int(np.log(l_nside) / np.log(2))
|
|
2362
2362
|
cmat = {}
|
|
@@ -2367,12 +2367,12 @@ class funct(FOC.FoCUS):
|
|
|
2367
2367
|
if image2 is not None:
|
|
2368
2368
|
sim = self.backend.bk_real(
|
|
2369
2369
|
self.backend.bk_L1(
|
|
2370
|
-
self.convol(tmp
|
|
2371
|
-
* self.backend.bk_conjugate(self.convol(tmpi2
|
|
2370
|
+
self.convol(tmp)
|
|
2371
|
+
* self.backend.bk_conjugate(self.convol(tmpi2))
|
|
2372
2372
|
)
|
|
2373
2373
|
)
|
|
2374
2374
|
else:
|
|
2375
|
-
sim = self.backend.bk_abs(self.convol(tmp
|
|
2375
|
+
sim = self.backend.bk_abs(self.convol(tmp))
|
|
2376
2376
|
|
|
2377
2377
|
# instead of difference between "opposite" channels use weighted average
|
|
2378
2378
|
# of cosine and sine contributions using all channels
|
|
@@ -2442,12 +2442,12 @@ class funct(FOC.FoCUS):
|
|
|
2442
2442
|
|
|
2443
2443
|
for k2 in range(k + 1):
|
|
2444
2444
|
|
|
2445
|
-
tmp2 = self.backend.
|
|
2445
|
+
tmp2 = self.backend.bk_expand_dims(sim,-2)
|
|
2446
2446
|
|
|
2447
2447
|
sim2 = self.backend.bk_reduce_sum(
|
|
2448
2448
|
self.backend.bk_reshape(
|
|
2449
2449
|
self.backend.bk_cast(
|
|
2450
|
-
mat.reshape(1, self.NORIENT
|
|
2450
|
+
mat.reshape(1, self.NORIENT, self.NORIENT, mat.shape[1])
|
|
2451
2451
|
)
|
|
2452
2452
|
* tmp2,
|
|
2453
2453
|
[sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[1]],
|
|
@@ -2455,14 +2455,14 @@ class funct(FOC.FoCUS):
|
|
|
2455
2455
|
1,
|
|
2456
2456
|
)
|
|
2457
2457
|
|
|
2458
|
-
sim2 = self.backend.bk_abs(self.convol(sim2
|
|
2458
|
+
sim2 = self.backend.bk_abs(self.convol(sim2))
|
|
2459
2459
|
|
|
2460
2460
|
angles = self.backend.bk_reshape(angles, [1, self.NORIENT, 1, 1])
|
|
2461
2461
|
weighted_cos2 = self.backend.bk_reduce_mean(
|
|
2462
|
-
sim2 * self.backend.bk_cos(angles), axis
|
|
2462
|
+
sim2 * self.backend.bk_cos(angles), axis=-3
|
|
2463
2463
|
)
|
|
2464
2464
|
weighted_sin2 = self.backend.bk_reduce_mean(
|
|
2465
|
-
sim2 * self.backend.bk_sin(angles), axis
|
|
2465
|
+
sim2 * self.backend.bk_sin(angles), axis=-3
|
|
2466
2466
|
)
|
|
2467
2467
|
|
|
2468
2468
|
cc2 = weighted_cos2[0]
|
|
@@ -2471,13 +2471,13 @@ class funct(FOC.FoCUS):
|
|
|
2471
2471
|
if smooth_scale > 0:
|
|
2472
2472
|
for m in range(smooth_scale):
|
|
2473
2473
|
if cc2.shape[1] > 12:
|
|
2474
|
-
cc2, _ = self.ud_grade_2(self.smooth(cc2
|
|
2475
|
-
ss2, _ = self.ud_grade_2(self.smooth(ss2
|
|
2474
|
+
cc2, _ = self.ud_grade_2(self.smooth(cc2))
|
|
2475
|
+
ss2, _ = self.ud_grade_2(self.smooth(ss2))
|
|
2476
2476
|
|
|
2477
2477
|
if cc2.shape[1] != sim.shape[2]:
|
|
2478
2478
|
ll_nside = int(np.sqrt(sim.shape[2] // 12))
|
|
2479
|
-
cc2 = self.up_grade(cc2, ll_nside
|
|
2480
|
-
ss2 = self.up_grade(ss2, ll_nside
|
|
2479
|
+
cc2 = self.up_grade(cc2, ll_nside)
|
|
2480
|
+
ss2 = self.up_grade(ss2, ll_nside)
|
|
2481
2481
|
|
|
2482
2482
|
if self.BACKEND == "numpy":
|
|
2483
2483
|
phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
|
|
@@ -2509,9 +2509,9 @@ class funct(FOC.FoCUS):
|
|
|
2509
2509
|
)
|
|
2510
2510
|
|
|
2511
2511
|
if k < l_nside - 1:
|
|
2512
|
-
tmp, _ = self.ud_grade_2(tmp
|
|
2512
|
+
tmp, _ = self.ud_grade_2(tmp)
|
|
2513
2513
|
if image2 is not None:
|
|
2514
|
-
tmpi2, _ = self.ud_grade_2(
|
|
2514
|
+
tmpi2, _ = self.ud_grade_2(tmpi)
|
|
2515
2515
|
return cmat, cmat2
|
|
2516
2516
|
|
|
2517
2517
|
def div_norm(self, complex_value, float_value):
|
|
@@ -2534,6 +2534,7 @@ class funct(FOC.FoCUS):
|
|
|
2534
2534
|
edge=True,
|
|
2535
2535
|
nside=None,
|
|
2536
2536
|
cell_ids=None,
|
|
2537
|
+
spin=0
|
|
2537
2538
|
):
|
|
2538
2539
|
"""
|
|
2539
2540
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
@@ -2559,11 +2560,14 @@ class funct(FOC.FoCUS):
|
|
|
2559
2560
|
norm: None or str
|
|
2560
2561
|
If None no normalization is applied, if 'auto' normalize by the reference S2,
|
|
2561
2562
|
if 'self' normalize by the current S2.
|
|
2563
|
+
spin : Integer
|
|
2564
|
+
If different from 0 compute spinned data (U,V to Divergence/Rotational spin==1) or (Q,U to E,B spin=2).
|
|
2565
|
+
This implies that the input data is 2*12*nside^2.
|
|
2562
2566
|
Returns
|
|
2563
2567
|
-------
|
|
2564
2568
|
S1, S2, S3, S4 normalized
|
|
2565
2569
|
"""
|
|
2566
|
-
|
|
2570
|
+
|
|
2567
2571
|
return_data = self.return_data
|
|
2568
2572
|
|
|
2569
2573
|
# Check input consistency
|
|
@@ -2576,8 +2580,8 @@ class funct(FOC.FoCUS):
|
|
|
2576
2580
|
if mask is not None:
|
|
2577
2581
|
if self.use_2D:
|
|
2578
2582
|
if (
|
|
2579
|
-
image1.shape[-2] != mask.shape[1]
|
|
2580
|
-
or image1.shape[-1] != mask.shape[
|
|
2583
|
+
image1.shape[-2] != mask.shape[-1]
|
|
2584
|
+
or image1.shape[-1] != mask.shape[-1]
|
|
2581
2585
|
):
|
|
2582
2586
|
print(
|
|
2583
2587
|
"The LAST 2 COLUMNs of the mask should have the same size ",
|
|
@@ -2588,7 +2592,7 @@ class funct(FOC.FoCUS):
|
|
|
2588
2592
|
)
|
|
2589
2593
|
return None
|
|
2590
2594
|
else:
|
|
2591
|
-
if image1.shape[-1] != mask.shape[1]:
|
|
2595
|
+
if image1.shape[-1] != mask.shape[-1]:
|
|
2592
2596
|
print(
|
|
2593
2597
|
"The LAST COLUMN of the mask should have the same size ",
|
|
2594
2598
|
mask.shape,
|
|
@@ -2609,7 +2613,6 @@ class funct(FOC.FoCUS):
|
|
|
2609
2613
|
cross = True
|
|
2610
2614
|
l_nside = 2**32 # not initialize if 1D or 2D
|
|
2611
2615
|
### PARAMETERS
|
|
2612
|
-
axis = 1
|
|
2613
2616
|
# determine jmax and nside corresponding to the input map
|
|
2614
2617
|
im_shape = image1.shape
|
|
2615
2618
|
if self.use_2D:
|
|
@@ -2637,10 +2640,7 @@ class funct(FOC.FoCUS):
|
|
|
2637
2640
|
|
|
2638
2641
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2639
2642
|
else:
|
|
2640
|
-
|
|
2641
|
-
npix = int(im_shape[1]) # Number of pixels
|
|
2642
|
-
else:
|
|
2643
|
-
npix = int(im_shape[0]) # Number of pixels
|
|
2643
|
+
npix=int(im_shape[-1])
|
|
2644
2644
|
|
|
2645
2645
|
if nside is None:
|
|
2646
2646
|
nside = int(np.sqrt(npix // 12))
|
|
@@ -2659,7 +2659,7 @@ class funct(FOC.FoCUS):
|
|
|
2659
2659
|
print("\n\n==========")
|
|
2660
2660
|
|
|
2661
2661
|
### LOCAL VARIABLES (IMAGES and MASK)
|
|
2662
|
-
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
|
|
2662
|
+
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D) or (len(image1.shape) == 2 and spin>0):
|
|
2663
2663
|
I1 = self.backend.bk_cast(
|
|
2664
2664
|
self.backend.bk_expand_dims(image1, 0)
|
|
2665
2665
|
) # Local image1 [Nbatch, Npix]
|
|
@@ -2684,26 +2684,26 @@ class funct(FOC.FoCUS):
|
|
|
2684
2684
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2685
2685
|
if self.use_2D:
|
|
2686
2686
|
vmask = self.up_grade(
|
|
2687
|
-
vmask, I1.shape[
|
|
2687
|
+
vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
|
|
2688
2688
|
)
|
|
2689
2689
|
I1 = self.up_grade(
|
|
2690
|
-
I1, I1.shape[
|
|
2690
|
+
I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
|
|
2691
2691
|
)
|
|
2692
2692
|
if cross:
|
|
2693
2693
|
I2 = self.up_grade(
|
|
2694
|
-
I2, I2.shape[
|
|
2694
|
+
I2, I2.shape[-2] * 2, nouty=I2.shape[-1] * 2
|
|
2695
2695
|
)
|
|
2696
2696
|
elif self.use_1D:
|
|
2697
|
-
vmask = self.up_grade(vmask, I1.shape[
|
|
2698
|
-
I1 = self.up_grade(I1, I1.shape[
|
|
2697
|
+
vmask = self.up_grade(vmask, I1.shape[-1] * 2)
|
|
2698
|
+
I1 = self.up_grade(I1, I1.shape[-1] * 2)
|
|
2699
2699
|
if cross:
|
|
2700
|
-
I2 = self.up_grade(I2, I2.shape[
|
|
2700
|
+
I2 = self.up_grade(I2, I2.shape[-1] * 2)
|
|
2701
2701
|
nside = nside * 2
|
|
2702
2702
|
else:
|
|
2703
|
-
I1 = self.up_grade(I1, nside * 2
|
|
2704
|
-
vmask = self.up_grade(vmask, nside * 2
|
|
2703
|
+
I1 = self.up_grade(I1, nside * 2)
|
|
2704
|
+
vmask = self.up_grade(vmask, nside * 2)
|
|
2705
2705
|
if cross:
|
|
2706
|
-
I2 = self.up_grade(I2, nside * 2
|
|
2706
|
+
I2 = self.up_grade(I2, nside * 2)
|
|
2707
2707
|
|
|
2708
2708
|
nside = nside * 2
|
|
2709
2709
|
|
|
@@ -2711,29 +2711,28 @@ class funct(FOC.FoCUS):
|
|
|
2711
2711
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2712
2712
|
if self.use_2D:
|
|
2713
2713
|
vmask = self.up_grade(
|
|
2714
|
-
vmask, I1.shape[
|
|
2714
|
+
vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
|
|
2715
2715
|
)
|
|
2716
2716
|
I1 = self.up_grade(
|
|
2717
|
-
I1, I1.shape[
|
|
2717
|
+
I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2
|
|
2718
2718
|
)
|
|
2719
2719
|
if cross:
|
|
2720
2720
|
I2 = self.up_grade(
|
|
2721
2721
|
I2,
|
|
2722
|
-
I2.shape[
|
|
2723
|
-
|
|
2724
|
-
nouty=I2.shape[axis + 1] * 2,
|
|
2722
|
+
I2.shape[-2] * 2,
|
|
2723
|
+
nouty=I2.shape[-1] * 2,
|
|
2725
2724
|
)
|
|
2726
2725
|
elif self.use_1D:
|
|
2727
|
-
vmask = self.up_grade(vmask, I1.shape[
|
|
2728
|
-
I1 = self.up_grade(I1, I1.shape[
|
|
2726
|
+
vmask = self.up_grade(vmask, I1.shape[-1] * 2)
|
|
2727
|
+
I1 = self.up_grade(I1, I1.shape[-1] * 2)
|
|
2729
2728
|
if cross:
|
|
2730
|
-
I2 = self.up_grade(I2, I2.shape[
|
|
2729
|
+
I2 = self.up_grade(I2, I2.shape[-1] * 2)
|
|
2731
2730
|
nside = nside * 2
|
|
2732
2731
|
else:
|
|
2733
|
-
I1 = self.up_grade(I1, nside * 2
|
|
2734
|
-
vmask = self.up_grade(vmask, nside * 2
|
|
2732
|
+
I1 = self.up_grade(I1, nside * 2)
|
|
2733
|
+
vmask = self.up_grade(vmask, nside * 2)
|
|
2735
2734
|
if cross:
|
|
2736
|
-
I2 = self.up_grade(I2, nside * 2
|
|
2735
|
+
I2 = self.up_grade(I2, nside * 2)
|
|
2737
2736
|
nside = nside * 2
|
|
2738
2737
|
|
|
2739
2738
|
# Normalize the masks because they have different pixel numbers
|
|
@@ -2762,6 +2761,7 @@ class funct(FOC.FoCUS):
|
|
|
2762
2761
|
off_S2 = -2
|
|
2763
2762
|
off_S3 = -3
|
|
2764
2763
|
off_S4 = -4
|
|
2764
|
+
|
|
2765
2765
|
if self.use_1D:
|
|
2766
2766
|
off_S2 = -1
|
|
2767
2767
|
off_S3 = -1
|
|
@@ -2793,13 +2793,20 @@ class funct(FOC.FoCUS):
|
|
|
2793
2793
|
)
|
|
2794
2794
|
else:
|
|
2795
2795
|
if not cross:
|
|
2796
|
-
s0, l_vs0 = self.masked_mean(I1,
|
|
2796
|
+
s0, l_vs0 = self.masked_mean(I1,
|
|
2797
|
+
vmask,
|
|
2798
|
+
calc_var=True)
|
|
2797
2799
|
else:
|
|
2798
2800
|
s0, l_vs0 = self.masked_mean(
|
|
2799
|
-
self.backend.bk_L1(I1 * I2),
|
|
2800
|
-
|
|
2801
|
-
|
|
2802
|
-
|
|
2801
|
+
self.backend.bk_L1(I1 * I2),
|
|
2802
|
+
vmask,
|
|
2803
|
+
calc_var=True)
|
|
2804
|
+
|
|
2805
|
+
vs0 = self.backend.bk_concat([l_vs0, l_vs0], -1)
|
|
2806
|
+
s0 = self.backend.bk_concat([s0, l_vs0], -1)
|
|
2807
|
+
if spin>0:
|
|
2808
|
+
vs0=self.backend.bk_reshape(vs0,[vs0.shape[0],vs0.shape[1],2,vs0.shape[2]//2])
|
|
2809
|
+
s0=self.backend.bk_reshape(s0,[s0.shape[0],s0.shape[1],2,s0.shape[2]//2])
|
|
2803
2810
|
#### COMPUTE S1, S2, S3 and S4
|
|
2804
2811
|
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
2805
2812
|
|
|
@@ -2848,13 +2855,13 @@ class funct(FOC.FoCUS):
|
|
|
2848
2855
|
####### S1 and S2
|
|
2849
2856
|
### Make the convolution I1 * Psi_j3
|
|
2850
2857
|
conv1 = self.convol(
|
|
2851
|
-
I1,
|
|
2858
|
+
I1, cell_ids=cell_ids_j3, nside=nside_j3,
|
|
2859
|
+
spin=spin
|
|
2852
2860
|
) # [Nbatch, Norient3 , Npix_j3]
|
|
2853
2861
|
|
|
2854
2862
|
if cmat is not None:
|
|
2855
|
-
|
|
2856
2863
|
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-2)
|
|
2857
|
-
|
|
2864
|
+
|
|
2858
2865
|
conv1 = self.backend.bk_reduce_sum(
|
|
2859
2866
|
self.backend.bk_reshape(
|
|
2860
2867
|
cmat[j3] * tmp2,
|
|
@@ -2871,7 +2878,6 @@ class funct(FOC.FoCUS):
|
|
|
2871
2878
|
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
2872
2879
|
# Store M1_j3 in a dictionary
|
|
2873
2880
|
M1_dic[j3] = M1
|
|
2874
|
-
|
|
2875
2881
|
if not cross: # Auto
|
|
2876
2882
|
M1_square = self.backend.bk_real(M1_square)
|
|
2877
2883
|
|
|
@@ -2882,11 +2888,11 @@ class funct(FOC.FoCUS):
|
|
|
2882
2888
|
else:
|
|
2883
2889
|
if calc_var:
|
|
2884
2890
|
s2, vs2 = self.masked_mean(
|
|
2885
|
-
M1_square, vmask,
|
|
2891
|
+
M1_square, vmask, rank=j3, calc_var=True
|
|
2886
2892
|
)
|
|
2887
2893
|
else:
|
|
2888
|
-
s2 = self.masked_mean(M1_square, vmask,
|
|
2889
|
-
|
|
2894
|
+
s2 = self.masked_mean(M1_square, vmask, rank=j3)
|
|
2895
|
+
|
|
2890
2896
|
if cond_init_P1_dic:
|
|
2891
2897
|
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
2892
2898
|
P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
|
|
@@ -2929,11 +2935,11 @@ class funct(FOC.FoCUS):
|
|
|
2929
2935
|
else:
|
|
2930
2936
|
if calc_var:
|
|
2931
2937
|
s1, vs1 = self.masked_mean(
|
|
2932
|
-
M1, vmask,
|
|
2938
|
+
M1, vmask, rank=j3, calc_var=True
|
|
2933
2939
|
) # [Nbatch, Nmask, Norient3]
|
|
2934
2940
|
else:
|
|
2935
2941
|
s1 = self.masked_mean(
|
|
2936
|
-
M1, vmask,
|
|
2942
|
+
M1, vmask, rank=j3
|
|
2937
2943
|
) # [Nbatch, Nmask, Norient3]
|
|
2938
2944
|
|
|
2939
2945
|
if return_data:
|
|
@@ -2967,7 +2973,8 @@ class funct(FOC.FoCUS):
|
|
|
2967
2973
|
else: # Cross
|
|
2968
2974
|
### Make the convolution I2 * Psi_j3
|
|
2969
2975
|
conv2 = self.convol(
|
|
2970
|
-
I2,
|
|
2976
|
+
I2, cell_ids=cell_ids_j3, nside=nside_j3,
|
|
2977
|
+
spin=spin
|
|
2971
2978
|
) # [Nbatch, Npix_j3, Norient3]
|
|
2972
2979
|
if cmat is not None:
|
|
2973
2980
|
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-2)
|
|
@@ -3001,17 +3008,17 @@ class funct(FOC.FoCUS):
|
|
|
3001
3008
|
else:
|
|
3002
3009
|
if calc_var:
|
|
3003
3010
|
p1, vp1 = self.masked_mean(
|
|
3004
|
-
M1_square, vmask,
|
|
3011
|
+
M1_square, vmask, rank=j3, calc_var=True
|
|
3005
3012
|
) # [Nbatch, Nmask, Norient3]
|
|
3006
3013
|
p2, vp2 = self.masked_mean(
|
|
3007
|
-
M2_square, vmask,
|
|
3014
|
+
M2_square, vmask, rank=j3, calc_var=True
|
|
3008
3015
|
) # [Nbatch, Nmask, Norient3]
|
|
3009
3016
|
else:
|
|
3010
3017
|
p1 = self.masked_mean(
|
|
3011
|
-
M1_square, vmask,
|
|
3018
|
+
M1_square, vmask, rank=j3
|
|
3012
3019
|
) # [Nbatch, Nmask, Norient3]
|
|
3013
3020
|
p2 = self.masked_mean(
|
|
3014
|
-
M2_square, vmask,
|
|
3021
|
+
M2_square, vmask, rank=j3
|
|
3015
3022
|
) # [Nbatch, Nmask, Norient3]
|
|
3016
3023
|
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
3017
3024
|
P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
|
|
@@ -3027,10 +3034,10 @@ class funct(FOC.FoCUS):
|
|
|
3027
3034
|
else:
|
|
3028
3035
|
if calc_var:
|
|
3029
3036
|
s2, vs2 = self.masked_mean(
|
|
3030
|
-
s2, vmask,
|
|
3037
|
+
s2, vmask, rank=j3, calc_var=True
|
|
3031
3038
|
)
|
|
3032
3039
|
else:
|
|
3033
|
-
s2 = self.masked_mean(s2, vmask,
|
|
3040
|
+
s2 = self.masked_mean(s2, vmask, rank=j3)
|
|
3034
3041
|
|
|
3035
3042
|
if return_data:
|
|
3036
3043
|
if out_nside is not None and out_nside < nside_j3:
|
|
@@ -3072,11 +3079,11 @@ class funct(FOC.FoCUS):
|
|
|
3072
3079
|
else:
|
|
3073
3080
|
if calc_var:
|
|
3074
3081
|
s1, vs1 = self.masked_mean(
|
|
3075
|
-
MX, vmask,
|
|
3082
|
+
MX, vmask, rank=j3, calc_var=True
|
|
3076
3083
|
) # [Nbatch, Nmask, Norient3]
|
|
3077
3084
|
else:
|
|
3078
3085
|
s1 = self.masked_mean(
|
|
3079
|
-
MX, vmask,
|
|
3086
|
+
MX, vmask, rank=j3
|
|
3080
3087
|
) # [Nbatch, Nmask, Norient3]
|
|
3081
3088
|
if return_data:
|
|
3082
3089
|
if out_nside is not None and out_nside < nside_j3:
|
|
@@ -3482,39 +3489,39 @@ class funct(FOC.FoCUS):
|
|
|
3482
3489
|
### Image I1,
|
|
3483
3490
|
# downscale the I1 [Nbatch, Npix_j3]
|
|
3484
3491
|
if j3 != Jmax - 1:
|
|
3485
|
-
I1 = self.smooth(I1,
|
|
3492
|
+
I1 = self.smooth(I1, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3486
3493
|
I1, new_cell_ids_j3 = self.ud_grade_2(
|
|
3487
|
-
I1,
|
|
3494
|
+
I1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3488
3495
|
)
|
|
3489
3496
|
|
|
3490
3497
|
### Image I2
|
|
3491
3498
|
if cross:
|
|
3492
|
-
I2 = self.smooth(I2,
|
|
3499
|
+
I2 = self.smooth(I2, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3493
3500
|
I2, new_cell_ids_j3 = self.ud_grade_2(
|
|
3494
|
-
I2,
|
|
3501
|
+
I2, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3495
3502
|
)
|
|
3496
3503
|
|
|
3497
3504
|
### Modules
|
|
3498
3505
|
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
3499
3506
|
### Dictionary M1_dic[j2]
|
|
3500
3507
|
M1_smooth = self.smooth(
|
|
3501
|
-
M1_dic[j2],
|
|
3508
|
+
M1_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
|
|
3502
3509
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3503
3510
|
M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3504
|
-
M1_smooth,
|
|
3511
|
+
M1_smooth, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3505
3512
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3506
3513
|
|
|
3507
3514
|
### Dictionary M2_dic[j2]
|
|
3508
3515
|
if cross:
|
|
3509
3516
|
M2_smooth = self.smooth(
|
|
3510
|
-
M2_dic[j2],
|
|
3517
|
+
M2_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
|
|
3511
3518
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3512
3519
|
M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3513
|
-
M2_smooth,
|
|
3520
|
+
M2_smooth, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3514
3521
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3515
3522
|
### Mask
|
|
3516
3523
|
vmask, new_cell_ids_j3 = self.ud_grade_2(
|
|
3517
|
-
vmask,
|
|
3524
|
+
vmask, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3518
3525
|
)
|
|
3519
3526
|
|
|
3520
3527
|
if self.mask_thres is not None:
|
|
@@ -3529,33 +3536,21 @@ class funct(FOC.FoCUS):
|
|
|
3529
3536
|
self.P1_dic = P1_dic
|
|
3530
3537
|
if cross:
|
|
3531
3538
|
self.P2_dic = P2_dic
|
|
3532
|
-
|
|
3533
|
-
Sout=[s0]+S1+S2+S3+S4
|
|
3534
|
-
|
|
3535
|
-
if cross:
|
|
3536
|
-
Sout=Sout+S3P
|
|
3537
|
-
if calc_var:
|
|
3538
|
-
SVout=[vs0]+VS1+VS2+VS3+VS4
|
|
3539
|
-
if cross:
|
|
3540
|
-
VSout=VSout+VS3P
|
|
3541
|
-
return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
|
|
3542
|
-
|
|
3543
|
-
return self.backend.bk_concat(Sout, 2)
|
|
3544
|
-
"""
|
|
3539
|
+
|
|
3545
3540
|
if not return_data:
|
|
3546
|
-
S1 = self.backend.bk_concat(S1, 2)
|
|
3547
|
-
S2 = self.backend.bk_concat(S2, 2)
|
|
3548
|
-
S3 = self.backend.bk_concat(S3,
|
|
3549
|
-
S4 = self.backend.bk_concat(S4,
|
|
3541
|
+
S1 = self.backend.bk_concat(S1, -2)
|
|
3542
|
+
S2 = self.backend.bk_concat(S2, -2)
|
|
3543
|
+
S3 = self.backend.bk_concat(S3, -3)
|
|
3544
|
+
S4 = self.backend.bk_concat(S4, -4)
|
|
3550
3545
|
if cross:
|
|
3551
|
-
S3P = self.backend.bk_concat(S3P,
|
|
3546
|
+
S3P = self.backend.bk_concat(S3P, -3)
|
|
3552
3547
|
if calc_var:
|
|
3553
|
-
VS1 = self.backend.bk_concat(VS1, 2)
|
|
3554
|
-
VS2 = self.backend.bk_concat(VS2, 2)
|
|
3555
|
-
VS3 = self.backend.bk_concat(VS3,
|
|
3556
|
-
VS4 = self.backend.bk_concat(VS4,
|
|
3548
|
+
VS1 = self.backend.bk_concat(VS1, -2)
|
|
3549
|
+
VS2 = self.backend.bk_concat(VS2, -2)
|
|
3550
|
+
VS3 = self.backend.bk_concat(VS3, -3)
|
|
3551
|
+
VS4 = self.backend.bk_concat(VS4, -4)
|
|
3557
3552
|
if cross:
|
|
3558
|
-
VS3P = self.backend.bk_concat(VS3P,
|
|
3553
|
+
VS3P = self.backend.bk_concat(VS3P, -3)
|
|
3559
3554
|
if calc_var:
|
|
3560
3555
|
if not cross:
|
|
3561
3556
|
return scat_cov(
|
|
@@ -3637,7 +3632,7 @@ class funct(FOC.FoCUS):
|
|
|
3637
3632
|
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
3638
3633
|
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Norient3, Npix_j3]
|
|
3639
3634
|
MconvPsi = self.convol(
|
|
3640
|
-
M_dic[j2],
|
|
3635
|
+
M_dic[j2], cell_ids=cell_ids, nside=nside_j2
|
|
3641
3636
|
) # [Nbatch, Norient3, Norient2, Npix_j3]
|
|
3642
3637
|
|
|
3643
3638
|
if cmat2 is not None:
|
|
@@ -3674,12 +3669,12 @@ class funct(FOC.FoCUS):
|
|
|
3674
3669
|
else:
|
|
3675
3670
|
if calc_var:
|
|
3676
3671
|
s3, vs3 = self.masked_mean(
|
|
3677
|
-
s3, vmask,
|
|
3672
|
+
s3, vmask, rank=j2, calc_var=True
|
|
3678
3673
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3679
3674
|
return s3, vs3
|
|
3680
3675
|
else:
|
|
3681
3676
|
s3 = self.masked_mean(
|
|
3682
|
-
s3, vmask,
|
|
3677
|
+
s3, vmask, rank=j2
|
|
3683
3678
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3684
3679
|
return s3
|
|
3685
3680
|
|
|
@@ -3717,12 +3712,12 @@ class funct(FOC.FoCUS):
|
|
|
3717
3712
|
else:
|
|
3718
3713
|
if calc_var:
|
|
3719
3714
|
s4, vs4 = self.masked_mean(
|
|
3720
|
-
s4, vmask,
|
|
3715
|
+
s4, vmask, rank=j2, calc_var=True
|
|
3721
3716
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3722
3717
|
return s4, vs4
|
|
3723
3718
|
else:
|
|
3724
3719
|
s4 = self.masked_mean(
|
|
3725
|
-
s4, vmask,
|
|
3720
|
+
s4, vmask, rank=j2
|
|
3726
3721
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3727
3722
|
return s4
|
|
3728
3723
|
|
|
@@ -3922,20 +3917,20 @@ class funct(FOC.FoCUS):
|
|
|
3922
3917
|
#
|
|
3923
3918
|
# ---------------------------------------------------------------------------
|
|
3924
3919
|
def scattering_cov(
|
|
3925
|
-
|
|
3926
|
-
|
|
3927
|
-
|
|
3928
|
-
|
|
3929
|
-
|
|
3930
|
-
|
|
3931
|
-
|
|
3932
|
-
|
|
3933
|
-
|
|
3934
|
-
|
|
3935
|
-
|
|
3936
|
-
|
|
3937
|
-
|
|
3938
|
-
|
|
3920
|
+
self,
|
|
3921
|
+
data,
|
|
3922
|
+
data2=None,
|
|
3923
|
+
Jmax=None,
|
|
3924
|
+
if_large_batch=False,
|
|
3925
|
+
S4_criteria=None,
|
|
3926
|
+
use_ref=False,
|
|
3927
|
+
normalization="S2",
|
|
3928
|
+
edge=False,
|
|
3929
|
+
in_mask=None,
|
|
3930
|
+
pseudo_coef=1,
|
|
3931
|
+
get_variance=False,
|
|
3932
|
+
ref_sigma=None,
|
|
3933
|
+
iso_ang=False,
|
|
3939
3934
|
):
|
|
3940
3935
|
"""
|
|
3941
3936
|
Calculates the scattering correlations for a batch of images, including:
|
|
@@ -4048,7 +4043,10 @@ class funct(FOC.FoCUS):
|
|
|
4048
4043
|
if data2 is not None:
|
|
4049
4044
|
N_image2 = data2.shape[0]
|
|
4050
4045
|
|
|
4051
|
-
|
|
4046
|
+
if spin==0:
|
|
4047
|
+
nside = int(np.sqrt(npix // 12))
|
|
4048
|
+
else:
|
|
4049
|
+
nside = int(np.sqrt(npix // 24))
|
|
4052
4050
|
|
|
4053
4051
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
4054
4052
|
|
|
@@ -5781,7 +5779,9 @@ class funct(FOC.FoCUS):
|
|
|
5781
5779
|
|
|
5782
5780
|
def from_gaussian(self, x):
|
|
5783
5781
|
|
|
5784
|
-
x = self.backend.bk_clip_by_value(x,
|
|
5782
|
+
x = self.backend.bk_clip_by_value(x,
|
|
5783
|
+
self.val_min+1E-4*abs(self.val_min),
|
|
5784
|
+
self.val_max-1E-4*abs(self.val_max))
|
|
5785
5785
|
return self.f_gaussian(self.backend.to_numpy(x))
|
|
5786
5786
|
|
|
5787
5787
|
def square(self, x):
|
|
@@ -6079,7 +6079,38 @@ class funct(FOC.FoCUS):
|
|
|
6079
6079
|
return scat_cov(
|
|
6080
6080
|
s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
|
|
6081
6081
|
)
|
|
6082
|
-
|
|
6082
|
+
def calc_matrix_orientation(self,noise_map,image2=None):
|
|
6083
|
+
# Décalage circulaire par matrice de permutation
|
|
6084
|
+
def circ_shift_matrix(N,k):
|
|
6085
|
+
return np.roll(np.eye(N), shift=-k, axis=1)
|
|
6086
|
+
Norient = self.NORIENT
|
|
6087
|
+
im=self.convol(noise_map)
|
|
6088
|
+
if image2 is None:
|
|
6089
|
+
mm=np.mean(abs(self.backend.to_numpy(im)),0)
|
|
6090
|
+
else:
|
|
6091
|
+
im2=self.convol(self.backend.bk_cast(image2))
|
|
6092
|
+
mm=np.mean(self.backend.to_numpy(
|
|
6093
|
+
self.backend.bk_L1(im*self.backend.bk_conjugate(im2))).real,0)
|
|
6094
|
+
|
|
6095
|
+
Norient=mm.shape[0]
|
|
6096
|
+
xx=np.cos(np.arange(Norient)/Norient*2*np.pi)
|
|
6097
|
+
yy=np.sin(np.arange(Norient)/Norient*2*np.pi)
|
|
6098
|
+
|
|
6099
|
+
a=np.sum(mm*xx[:,None],0)
|
|
6100
|
+
b=np.sum(mm*yy[:,None],0)
|
|
6101
|
+
|
|
6102
|
+
o=np.fmod(Norient*np.arctan2(-b,a)/(2*np.pi)+Norient,Norient)
|
|
6103
|
+
xx=np.arange(Norient)
|
|
6104
|
+
alpha = o[None,:]-xx[:,None]
|
|
6105
|
+
beta = np.fmod(1+o[None,:]-xx[:,None],Norient)
|
|
6106
|
+
alpha=(1-alpha)*(alpha<1)*(alpha>0)+beta*(beta<1)*(beta>0)
|
|
6107
|
+
|
|
6108
|
+
m=np.zeros([Norient,Norient,mm.shape[1]])
|
|
6109
|
+
for k in range(Norient):
|
|
6110
|
+
m[k,:,:]=np.roll(alpha,k,0)
|
|
6111
|
+
#m=np.mean(m,0)
|
|
6112
|
+
return self.backend.bk_cast(m)
|
|
6113
|
+
|
|
6083
6114
|
def synthesis(
|
|
6084
6115
|
self,
|
|
6085
6116
|
image_target,
|