foscat 3.8.2__py3-none-any.whl → 3.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkBase.py +36 -35
- foscat/BkNumpy.py +53 -62
- foscat/BkTensorflow.py +87 -88
- foscat/BkTorch.py +91 -72
- foscat/FoCUS.py +72 -56
- foscat/Synthesis.py +3 -3
- foscat/alm.py +188 -170
- foscat/backend.py +84 -70
- foscat/scat_cov.py +1849 -2086
- foscat/scat_cov2D.py +146 -53
- {foscat-3.8.2.dist-info → foscat-3.9.0.dist-info}/METADATA +1 -1
- {foscat-3.8.2.dist-info → foscat-3.9.0.dist-info}/RECORD +15 -15
- {foscat-3.8.2.dist-info → foscat-3.9.0.dist-info}/WHEEL +1 -1
- {foscat-3.8.2.dist-info → foscat-3.9.0.dist-info}/LICENSE +0 -0
- {foscat-3.8.2.dist-info → foscat-3.9.0.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -92,7 +92,12 @@ class scat_cov:
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
def conv2complex(self, val):
|
|
95
|
-
if
|
|
95
|
+
if (
|
|
96
|
+
val.dtype == "complex64"
|
|
97
|
+
or val.dtype == "complex128"
|
|
98
|
+
or val.dtype == "torch.complex64"
|
|
99
|
+
or val.dtype == "torch.complex128"
|
|
100
|
+
):
|
|
96
101
|
return val
|
|
97
102
|
else:
|
|
98
103
|
return self.backend.bk_complex(val, 0 * val)
|
|
@@ -2043,9 +2048,11 @@ class scat_cov:
|
|
|
2043
2048
|
)
|
|
2044
2049
|
)
|
|
2045
2050
|
else:
|
|
2046
|
-
s3[i, j, idx[noff:], k, l_orient] =
|
|
2047
|
-
|
|
2048
|
-
|
|
2051
|
+
s3[i, j, idx[noff:], k, l_orient] = (
|
|
2052
|
+
self.backend.to_numpy(self.S3)[
|
|
2053
|
+
i, j, j2 == ij - noff, k, l_orient
|
|
2054
|
+
]
|
|
2055
|
+
)
|
|
2049
2056
|
s3[i, j, idx[:noff], k, l_orient] = (
|
|
2050
2057
|
self.add_data_from_slope(
|
|
2051
2058
|
self.backend.to_numpy(self.S3)[
|
|
@@ -2208,7 +2215,7 @@ class scat_cov:
|
|
|
2208
2215
|
|
|
2209
2216
|
|
|
2210
2217
|
class funct(FOC.FoCUS):
|
|
2211
|
-
|
|
2218
|
+
|
|
2212
2219
|
def fill(self, im, nullval=hp.UNSEEN):
|
|
2213
2220
|
if self.use_2D:
|
|
2214
2221
|
return self.fill_2d(im, nullval=nullval)
|
|
@@ -2375,13 +2382,19 @@ class funct(FOC.FoCUS):
|
|
|
2375
2382
|
|
|
2376
2383
|
# instead of difference between "opposite" channels use weighted average
|
|
2377
2384
|
# of cosine and sine contributions using all channels
|
|
2378
|
-
angles = self.backend.bk_cast(
|
|
2379
|
-
2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
|
|
2380
|
-
|
|
2385
|
+
angles = self.backend.bk_cast(
|
|
2386
|
+
(2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
|
|
2387
|
+
1, 1, self.NORIENT
|
|
2388
|
+
)
|
|
2389
|
+
) # shape: (NORIENT,)
|
|
2381
2390
|
|
|
2382
2391
|
# we use cosines and sines as weights for sim
|
|
2383
|
-
weighted_cos = self.backend.bk_reduce_mean(
|
|
2384
|
-
|
|
2392
|
+
weighted_cos = self.backend.bk_reduce_mean(
|
|
2393
|
+
sim * self.backend.bk_cos(angles), axis=-1
|
|
2394
|
+
)
|
|
2395
|
+
weighted_sin = self.backend.bk_reduce_mean(
|
|
2396
|
+
sim * self.backend.bk_sin(angles), axis=-1
|
|
2397
|
+
)
|
|
2385
2398
|
# For simplicity, take first element of the batch
|
|
2386
2399
|
cc = weighted_cos[0]
|
|
2387
2400
|
ss = weighted_sin[0]
|
|
@@ -2401,8 +2414,9 @@ class funct(FOC.FoCUS):
|
|
|
2401
2414
|
phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
|
|
2402
2415
|
else:
|
|
2403
2416
|
phase = np.fmod(
|
|
2404
|
-
np.arctan2(self.backend.to_numpy(ss),
|
|
2405
|
-
|
|
2417
|
+
np.arctan2(self.backend.to_numpy(ss), self.backend.to_numpy(cc))
|
|
2418
|
+
+ 2 * np.pi,
|
|
2419
|
+
2 * np.pi,
|
|
2406
2420
|
)
|
|
2407
2421
|
|
|
2408
2422
|
# instead of linear interpolation cosine‐based interpolation
|
|
@@ -2416,10 +2430,10 @@ class funct(FOC.FoCUS):
|
|
|
2416
2430
|
# build rotation matrix
|
|
2417
2431
|
mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
|
|
2418
2432
|
lidx = np.arange(sim.shape[1])
|
|
2419
|
-
for
|
|
2433
|
+
for ell in range(self.NORIENT):
|
|
2420
2434
|
# Instead of simple linear weights, we use the cosine weights w0 and w1.
|
|
2421
|
-
col0 = self.NORIENT * ((
|
|
2422
|
-
col1 = self.NORIENT * ((
|
|
2435
|
+
col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
|
|
2436
|
+
col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
|
|
2423
2437
|
mat[lidx, col0] = w0
|
|
2424
2438
|
mat[lidx, col1] = w1
|
|
2425
2439
|
|
|
@@ -2436,7 +2450,8 @@ class funct(FOC.FoCUS):
|
|
|
2436
2450
|
sim2 = self.backend.bk_reduce_sum(
|
|
2437
2451
|
self.backend.bk_reshape(
|
|
2438
2452
|
self.backend.bk_cast(
|
|
2439
|
-
|
|
2453
|
+
mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
|
|
2454
|
+
)
|
|
2440
2455
|
* tmp2,
|
|
2441
2456
|
[sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
|
|
2442
2457
|
),
|
|
@@ -2470,8 +2485,11 @@ class funct(FOC.FoCUS):
|
|
|
2470
2485
|
phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
|
|
2471
2486
|
else:
|
|
2472
2487
|
phase2 = np.fmod(
|
|
2473
|
-
np.arctan2(
|
|
2474
|
-
|
|
2488
|
+
np.arctan2(
|
|
2489
|
+
self.backend.to_numpy(ss2), self.backend.to_numpy(cc2)
|
|
2490
|
+
)
|
|
2491
|
+
+ 2 * np.pi,
|
|
2492
|
+
2 * np.pi,
|
|
2475
2493
|
)
|
|
2476
2494
|
|
|
2477
2495
|
phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
|
|
@@ -2482,9 +2500,11 @@ class funct(FOC.FoCUS):
|
|
|
2482
2500
|
lidx = np.arange(sim.shape[1])
|
|
2483
2501
|
|
|
2484
2502
|
for m in range(self.NORIENT):
|
|
2485
|
-
for
|
|
2486
|
-
col0 = self.NORIENT * ((
|
|
2487
|
-
col1 =
|
|
2503
|
+
for ell in range(self.NORIENT):
|
|
2504
|
+
col0 = self.NORIENT * ((ell + iph2[:, m]) % self.NORIENT) + ell
|
|
2505
|
+
col1 = (
|
|
2506
|
+
self.NORIENT * ((ell + iph2[:, m] + 1) % self.NORIENT) + ell
|
|
2507
|
+
)
|
|
2488
2508
|
mat2[k2, lidx, m, col0] = w0_2[:, m]
|
|
2489
2509
|
mat2[k2, lidx, m, col1] = w1_2[:, m]
|
|
2490
2510
|
cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
|
|
@@ -2502,17 +2522,17 @@ class funct(FOC.FoCUS):
|
|
|
2502
2522
|
)
|
|
2503
2523
|
|
|
2504
2524
|
def eval(
|
|
2505
|
-
|
|
2506
|
-
|
|
2507
|
-
|
|
2508
|
-
|
|
2509
|
-
|
|
2510
|
-
|
|
2511
|
-
|
|
2512
|
-
|
|
2513
|
-
|
|
2514
|
-
|
|
2515
|
-
|
|
2525
|
+
self,
|
|
2526
|
+
image1,
|
|
2527
|
+
image2=None,
|
|
2528
|
+
mask=None,
|
|
2529
|
+
norm=None,
|
|
2530
|
+
calc_var=False,
|
|
2531
|
+
cmat=None,
|
|
2532
|
+
cmat2=None,
|
|
2533
|
+
Jmax=None,
|
|
2534
|
+
out_nside=None,
|
|
2535
|
+
edge=True,
|
|
2516
2536
|
):
|
|
2517
2537
|
"""
|
|
2518
2538
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
@@ -2542,9 +2562,9 @@ class funct(FOC.FoCUS):
|
|
|
2542
2562
|
-------
|
|
2543
2563
|
S1, S2, S3, S4 normalized
|
|
2544
2564
|
"""
|
|
2545
|
-
|
|
2565
|
+
|
|
2546
2566
|
return_data = self.return_data
|
|
2547
|
-
|
|
2567
|
+
|
|
2548
2568
|
# Check input consistency
|
|
2549
2569
|
if image2 is not None:
|
|
2550
2570
|
if list(image1.shape) != list(image2.shape):
|
|
@@ -2554,7 +2574,10 @@ class funct(FOC.FoCUS):
|
|
|
2554
2574
|
return None
|
|
2555
2575
|
if mask is not None:
|
|
2556
2576
|
if self.use_2D:
|
|
2557
|
-
if
|
|
2577
|
+
if (
|
|
2578
|
+
image1.shape[-2] != mask.shape[1]
|
|
2579
|
+
or image1.shape[-1] != mask.shape[2]
|
|
2580
|
+
):
|
|
2558
2581
|
print(
|
|
2559
2582
|
"The LAST 2 COLUMNs of the mask should have the same size ",
|
|
2560
2583
|
mask.shape,
|
|
@@ -2564,7 +2587,7 @@ class funct(FOC.FoCUS):
|
|
|
2564
2587
|
)
|
|
2565
2588
|
return None
|
|
2566
2589
|
else:
|
|
2567
|
-
if image1.shape[-1] != mask.shape[1]:
|
|
2590
|
+
if image1.shape[-1] != mask.shape[1]:
|
|
2568
2591
|
print(
|
|
2569
2592
|
"The LAST COLUMN of the mask should have the same size ",
|
|
2570
2593
|
mask.shape,
|
|
@@ -2618,16 +2641,17 @@ class funct(FOC.FoCUS):
|
|
|
2618
2641
|
nside = int(np.sqrt(npix // 12))
|
|
2619
2642
|
|
|
2620
2643
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2621
|
-
|
|
2622
|
-
if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
|
|
2623
|
-
J-=1
|
|
2644
|
+
|
|
2645
|
+
if (self.use_2D or self.use_1D) and self.KERNELSZ > 3:
|
|
2646
|
+
J -= 1
|
|
2624
2647
|
if Jmax is None:
|
|
2625
2648
|
Jmax = J # Number of steps for the loop on scales
|
|
2626
|
-
if Jmax>J:
|
|
2627
|
-
print(
|
|
2628
|
-
print(
|
|
2629
|
-
|
|
2630
|
-
|
|
2649
|
+
if Jmax > J:
|
|
2650
|
+
print("==========\n\n")
|
|
2651
|
+
print(
|
|
2652
|
+
"The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
|
|
2653
|
+
)
|
|
2654
|
+
print("\n\n==========")
|
|
2631
2655
|
|
|
2632
2656
|
### LOCAL VARIABLES (IMAGES and MASK)
|
|
2633
2657
|
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
|
|
@@ -2770,32 +2794,37 @@ class funct(FOC.FoCUS):
|
|
|
2770
2794
|
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
2771
2795
|
|
|
2772
2796
|
# a remettre comme avant
|
|
2773
|
-
M1_dic={}
|
|
2774
|
-
M2_dic={}
|
|
2775
|
-
|
|
2797
|
+
M1_dic = {}
|
|
2798
|
+
M2_dic = {}
|
|
2799
|
+
|
|
2776
2800
|
for j3 in range(Jmax):
|
|
2777
|
-
|
|
2801
|
+
|
|
2778
2802
|
if edge:
|
|
2779
2803
|
if self.mask_mask is None:
|
|
2780
|
-
self.mask_mask={}
|
|
2804
|
+
self.mask_mask = {}
|
|
2781
2805
|
if self.use_2D:
|
|
2782
|
-
if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
|
|
2783
|
-
mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
|
|
2784
|
-
mask_mask[
|
|
2785
|
-
|
|
2786
|
-
|
|
2787
|
-
|
|
2788
|
-
|
|
2789
|
-
|
|
2790
|
-
|
|
2806
|
+
if (vmask.shape[1], vmask.shape[2]) not in self.mask_mask:
|
|
2807
|
+
mask_mask = np.zeros([1, vmask.shape[1], vmask.shape[2]])
|
|
2808
|
+
mask_mask[
|
|
2809
|
+
0,
|
|
2810
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
2811
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
2812
|
+
] = 1.0
|
|
2813
|
+
self.mask_mask[(vmask.shape[1], vmask.shape[2])] = (
|
|
2814
|
+
self.backend.bk_cast(mask_mask)
|
|
2815
|
+
)
|
|
2816
|
+
vmask = vmask * self.mask_mask[(vmask.shape[1], vmask.shape[2])]
|
|
2817
|
+
# print(self.KERNELSZ//2,vmask,mask_mask)
|
|
2818
|
+
|
|
2791
2819
|
if self.use_1D:
|
|
2792
2820
|
if (vmask.shape[1]) not in self.mask_mask:
|
|
2793
|
-
mask_mask=np.zeros([1,vmask.shape[1]])
|
|
2794
|
-
mask_mask[0,
|
|
2795
|
-
|
|
2796
|
-
|
|
2797
|
-
|
|
2798
|
-
|
|
2821
|
+
mask_mask = np.zeros([1, vmask.shape[1]])
|
|
2822
|
+
mask_mask[0, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] = 1.0
|
|
2823
|
+
self.mask_mask[(vmask.shape[1])] = self.backend.bk_cast(
|
|
2824
|
+
mask_mask
|
|
2825
|
+
)
|
|
2826
|
+
vmask = vmask * self.mask_mask[(vmask.shape[1])]
|
|
2827
|
+
|
|
2799
2828
|
if return_data:
|
|
2800
2829
|
S3[j3] = None
|
|
2801
2830
|
S3P[j3] = None
|
|
@@ -3449,9 +3478,9 @@ class funct(FOC.FoCUS):
|
|
|
3449
3478
|
M2_dic[j2], axis=1
|
|
3450
3479
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3451
3480
|
M2_dic[j2] = self.ud_grade_2(
|
|
3452
|
-
|
|
3481
|
+
M2_smooth, axis=1
|
|
3453
3482
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3454
|
-
|
|
3483
|
+
|
|
3455
3484
|
### Mask
|
|
3456
3485
|
vmask = self.ud_grade_2(vmask, axis=1)
|
|
3457
3486
|
|
|
@@ -3543,1278 +3572,232 @@ class funct(FOC.FoCUS):
|
|
|
3543
3572
|
use_1D=self.use_1D,
|
|
3544
3573
|
)
|
|
3545
3574
|
|
|
3546
|
-
def
|
|
3547
|
-
|
|
3548
|
-
|
|
3549
|
-
|
|
3550
|
-
|
|
3551
|
-
|
|
3552
|
-
|
|
3553
|
-
|
|
3554
|
-
|
|
3555
|
-
|
|
3556
|
-
|
|
3557
|
-
|
|
3575
|
+
def clean_norm(self):
|
|
3576
|
+
self.P1_dic = None
|
|
3577
|
+
self.P2_dic = None
|
|
3578
|
+
return
|
|
3579
|
+
|
|
3580
|
+
def _compute_S3(
|
|
3581
|
+
self,
|
|
3582
|
+
j2,
|
|
3583
|
+
j3,
|
|
3584
|
+
conv,
|
|
3585
|
+
vmask,
|
|
3586
|
+
M_dic,
|
|
3587
|
+
MconvPsi_dic,
|
|
3588
|
+
calc_var=False,
|
|
3589
|
+
return_data=False,
|
|
3590
|
+
cmat2=None,
|
|
3558
3591
|
):
|
|
3559
3592
|
"""
|
|
3560
|
-
|
|
3561
|
-
|
|
3562
|
-
S1 = <|I * Psi_j3|>
|
|
3563
|
-
Normalization : take the log
|
|
3564
|
-
power spectrum:
|
|
3565
|
-
S2 = <|I * Psi_j3|^2>
|
|
3566
|
-
Normalization : take the log
|
|
3567
|
-
orig. x modulus:
|
|
3568
|
-
S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
|
|
3569
|
-
Normalization : divide by (S2_j2 * S2_j3)^0.5
|
|
3570
|
-
modulus x modulus:
|
|
3571
|
-
S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
|
|
3572
|
-
Normalization : divide by (S2_j1 * S2_j2)^0.5
|
|
3593
|
+
Compute the S3 coefficients (auto or cross)
|
|
3594
|
+
S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
|
|
3573
3595
|
Parameters
|
|
3574
3596
|
----------
|
|
3575
|
-
image1: tensor
|
|
3576
|
-
Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
|
|
3577
|
-
image2: tensor
|
|
3578
|
-
Second image. If not None, we compute cross-scattering covariance coefficients.
|
|
3579
|
-
mask:
|
|
3580
|
-
norm: None or str
|
|
3581
|
-
If None no normalization is applied, if 'auto' normalize by the reference S2,
|
|
3582
|
-
if 'self' normalize by the current S2.
|
|
3583
3597
|
Returns
|
|
3584
3598
|
-------
|
|
3585
|
-
|
|
3599
|
+
cs3, ss3: real and imag parts of S3 coeff
|
|
3586
3600
|
"""
|
|
3587
|
-
|
|
3588
|
-
|
|
3589
|
-
|
|
3590
|
-
|
|
3591
|
-
|
|
3592
|
-
|
|
3593
|
-
|
|
3594
|
-
|
|
3595
|
-
|
|
3596
|
-
|
|
3597
|
-
|
|
3598
|
-
|
|
3599
|
-
|
|
3600
|
-
|
|
3601
|
-
|
|
3602
|
-
|
|
3603
|
-
|
|
3604
|
-
)
|
|
3605
|
-
|
|
3606
|
-
if self.use_2D and len(image1.shape) < 2:
|
|
3607
|
-
print(
|
|
3608
|
-
"To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
|
|
3601
|
+
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
3602
|
+
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
3603
|
+
MconvPsi = self.convol(
|
|
3604
|
+
M_dic[j2], axis=1
|
|
3605
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3606
|
+
if cmat2 is not None:
|
|
3607
|
+
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
|
|
3608
|
+
MconvPsi = self.backend.bk_reduce_sum(
|
|
3609
|
+
self.backend.bk_reshape(
|
|
3610
|
+
cmat2[j3][j2] * tmp2,
|
|
3611
|
+
[
|
|
3612
|
+
tmp2.shape[0],
|
|
3613
|
+
cmat2[j3].shape[1],
|
|
3614
|
+
self.NORIENT,
|
|
3615
|
+
self.NORIENT,
|
|
3616
|
+
self.NORIENT,
|
|
3617
|
+
],
|
|
3618
|
+
),
|
|
3619
|
+
3,
|
|
3609
3620
|
)
|
|
3610
|
-
return None
|
|
3611
|
-
|
|
3612
|
-
### AUTO OR CROSS
|
|
3613
|
-
cross = False
|
|
3614
|
-
if image2 is not None:
|
|
3615
|
-
cross = True
|
|
3616
3621
|
|
|
3617
|
-
|
|
3618
|
-
|
|
3619
|
-
# determine jmax and nside corresponding to the input map
|
|
3620
|
-
im_shape = image1.shape
|
|
3621
|
-
if self.use_2D:
|
|
3622
|
-
if len(image1.shape) == 2:
|
|
3623
|
-
nside = np.min([im_shape[0], im_shape[1]])
|
|
3624
|
-
npix = im_shape[0] * im_shape[1] # Number of pixels
|
|
3625
|
-
x1 = im_shape[0]
|
|
3626
|
-
x2 = im_shape[1]
|
|
3627
|
-
else:
|
|
3628
|
-
nside = np.min([im_shape[1], im_shape[2]])
|
|
3629
|
-
npix = im_shape[1] * im_shape[2] # Number of pixels
|
|
3630
|
-
x1 = im_shape[1]
|
|
3631
|
-
x2 = im_shape[2]
|
|
3632
|
-
J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
|
|
3633
|
-
elif self.use_1D:
|
|
3634
|
-
if len(image1.shape) == 2:
|
|
3635
|
-
npix = int(im_shape[1]) # Number of pixels
|
|
3636
|
-
else:
|
|
3637
|
-
npix = int(im_shape[0]) # Number of pixels
|
|
3622
|
+
# Store it so we can use it in S4 computation
|
|
3623
|
+
MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3638
3624
|
|
|
3639
|
-
|
|
3625
|
+
### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
|
|
3626
|
+
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
3627
|
+
# cconv, sconv are [Nbatch, Npix_j3, Norient3]
|
|
3628
|
+
if self.use_1D:
|
|
3629
|
+
s3 = conv * self.backend.bk_conjugate(MconvPsi)
|
|
3630
|
+
else:
|
|
3631
|
+
s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
|
|
3632
|
+
MconvPsi
|
|
3633
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3640
3634
|
|
|
3641
|
-
|
|
3635
|
+
### Apply the mask [Nmask, Npix_j3] and sum over pixels
|
|
3636
|
+
if return_data:
|
|
3637
|
+
return s3
|
|
3642
3638
|
else:
|
|
3643
|
-
if
|
|
3644
|
-
|
|
3639
|
+
if calc_var:
|
|
3640
|
+
s3, vs3 = self.masked_mean(
|
|
3641
|
+
s3, vmask, axis=1, rank=j2, calc_var=True
|
|
3642
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3643
|
+
return s3, vs3
|
|
3645
3644
|
else:
|
|
3646
|
-
|
|
3645
|
+
s3 = self.masked_mean(
|
|
3646
|
+
s3, vmask, axis=1, rank=j2
|
|
3647
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3648
|
+
return s3
|
|
3647
3649
|
|
|
3648
|
-
|
|
3650
|
+
def _compute_S4(
|
|
3651
|
+
self,
|
|
3652
|
+
j1,
|
|
3653
|
+
j2,
|
|
3654
|
+
vmask,
|
|
3655
|
+
M1convPsi_dic,
|
|
3656
|
+
M2convPsi_dic=None,
|
|
3657
|
+
calc_var=False,
|
|
3658
|
+
return_data=False,
|
|
3659
|
+
):
|
|
3660
|
+
#### Simplify notations
|
|
3661
|
+
M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
|
|
3649
3662
|
|
|
3650
|
-
|
|
3651
|
-
|
|
3652
|
-
|
|
3653
|
-
|
|
3654
|
-
|
|
3655
|
-
Jmax = J # Number of steps for the loop on scales
|
|
3656
|
-
if Jmax>J:
|
|
3657
|
-
print('==========\n\n')
|
|
3658
|
-
print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
|
|
3659
|
-
print('\n\n==========')
|
|
3660
|
-
|
|
3663
|
+
# Auto or Cross coefficients
|
|
3664
|
+
if M2convPsi_dic is None: # Auto
|
|
3665
|
+
M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3666
|
+
else: # Cross
|
|
3667
|
+
M2 = M2convPsi_dic[j2]
|
|
3661
3668
|
|
|
3662
|
-
###
|
|
3663
|
-
|
|
3664
|
-
|
|
3665
|
-
|
|
3666
|
-
) # Local image1 [Nbatch, Npix]
|
|
3667
|
-
if cross:
|
|
3668
|
-
I2 = self.backend.bk_cast(
|
|
3669
|
-
self.backend.bk_expand_dims(image2, 0)
|
|
3670
|
-
) # Local image2 [Nbatch, Npix]
|
|
3669
|
+
### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
|
|
3670
|
+
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
3671
|
+
if self.use_1D:
|
|
3672
|
+
s4 = M1 * self.backend.bk_conjugate(M2)
|
|
3671
3673
|
else:
|
|
3672
|
-
|
|
3673
|
-
|
|
3674
|
-
|
|
3674
|
+
s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
|
|
3675
|
+
self.backend.bk_expand_dims(M2, -1)
|
|
3676
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
|
|
3675
3677
|
|
|
3676
|
-
|
|
3677
|
-
|
|
3678
|
-
|
|
3679
|
-
else:
|
|
3680
|
-
vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
|
|
3678
|
+
### Apply the mask and sum over pixels
|
|
3679
|
+
if return_data:
|
|
3680
|
+
return s4
|
|
3681
3681
|
else:
|
|
3682
|
-
|
|
3682
|
+
if calc_var:
|
|
3683
|
+
s4, vs4 = self.masked_mean(
|
|
3684
|
+
s4, vmask, axis=1, rank=j2, calc_var=True
|
|
3685
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3686
|
+
return s4, vs4
|
|
3687
|
+
else:
|
|
3688
|
+
s4 = self.masked_mean(
|
|
3689
|
+
s4, vmask, axis=1, rank=j2
|
|
3690
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3691
|
+
return s4
|
|
3683
3692
|
|
|
3684
|
-
|
|
3685
|
-
|
|
3686
|
-
|
|
3687
|
-
|
|
3688
|
-
|
|
3693
|
+
def computer_filter(self, M, N, J, L):
|
|
3694
|
+
"""
|
|
3695
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
3696
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
3697
|
+
"""
|
|
3698
|
+
|
|
3699
|
+
filter = np.zeros([J, L, M, N], dtype="complex64")
|
|
3700
|
+
|
|
3701
|
+
slant = 4.0 / L
|
|
3702
|
+
|
|
3703
|
+
for j in range(J):
|
|
3704
|
+
|
|
3705
|
+
for ell in range(L):
|
|
3706
|
+
|
|
3707
|
+
theta = (int(L - L / 2 - 1) - ell) * np.pi / L
|
|
3708
|
+
sigma = 0.8 * 2**j
|
|
3709
|
+
xi = 3.0 / 4.0 * np.pi / 2**j
|
|
3710
|
+
|
|
3711
|
+
R = np.array(
|
|
3712
|
+
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
|
|
3713
|
+
np.float64,
|
|
3689
3714
|
)
|
|
3690
|
-
|
|
3691
|
-
|
|
3715
|
+
R_inv = np.array(
|
|
3716
|
+
[[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
|
|
3717
|
+
np.float64,
|
|
3692
3718
|
)
|
|
3693
|
-
if cross:
|
|
3694
|
-
I2 = self.up_grade(
|
|
3695
|
-
I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
|
|
3696
|
-
)
|
|
3697
|
-
elif self.use_1D:
|
|
3698
|
-
vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
|
|
3699
|
-
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
3700
|
-
if cross:
|
|
3701
|
-
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
3702
|
-
else:
|
|
3703
|
-
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
3704
|
-
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
3705
|
-
if cross:
|
|
3706
|
-
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
3707
|
-
|
|
3708
|
-
if self.KERNELSZ > 5 and not self.use_2D:
|
|
3709
|
-
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
3710
|
-
if self.use_2D:
|
|
3711
|
-
vmask = self.up_grade(
|
|
3712
|
-
vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
|
|
3713
|
-
)
|
|
3714
|
-
I1 = self.up_grade(
|
|
3715
|
-
I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
|
|
3716
|
-
)
|
|
3717
|
-
if cross:
|
|
3718
|
-
I2 = self.up_grade(
|
|
3719
|
-
I2,
|
|
3720
|
-
I2.shape[axis] * 2,
|
|
3721
|
-
axis=axis,
|
|
3722
|
-
nouty=I2.shape[axis + 1] * 2,
|
|
3723
|
-
)
|
|
3724
|
-
elif self.use_1D:
|
|
3725
|
-
vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
|
|
3726
|
-
I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
|
|
3727
|
-
if cross:
|
|
3728
|
-
I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
|
|
3729
|
-
else:
|
|
3730
|
-
I1 = self.up_grade(I1, nside * 4, axis=axis)
|
|
3731
|
-
vmask = self.up_grade(vmask, nside * 4, axis=1)
|
|
3732
|
-
if cross:
|
|
3733
|
-
I2 = self.up_grade(I2, nside * 4, axis=axis)
|
|
3734
|
-
|
|
3735
|
-
# Normalize the masks because they have different pixel numbers
|
|
3736
|
-
# vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
|
|
3737
|
-
|
|
3738
|
-
### INITIALIZATION
|
|
3739
|
-
# Coefficients
|
|
3740
|
-
if return_data:
|
|
3741
|
-
S1 = {}
|
|
3742
|
-
S2 = {}
|
|
3743
|
-
S3 = {}
|
|
3744
|
-
S3P = {}
|
|
3745
|
-
S4 = {}
|
|
3746
|
-
else:
|
|
3747
|
-
result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3748
|
-
dtype=self.backend.backend.float32,
|
|
3749
|
-
device=self.backend.torch_device)
|
|
3750
|
-
vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3751
|
-
dtype=self.backend.backend.float32,
|
|
3752
|
-
device=self.backend.torch_device)
|
|
3753
|
-
S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3754
|
-
S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3755
|
-
S3 = []
|
|
3756
|
-
S4 = []
|
|
3757
|
-
S3P = []
|
|
3758
|
-
VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3759
|
-
VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3760
|
-
VS3 = []
|
|
3761
|
-
VS3P = []
|
|
3762
|
-
VS4 = []
|
|
3763
|
-
|
|
3764
|
-
off_S2 = -2
|
|
3765
|
-
off_S3 = -3
|
|
3766
|
-
off_S4 = -4
|
|
3767
|
-
if self.use_1D:
|
|
3768
|
-
off_S2 = -1
|
|
3769
|
-
off_S3 = -1
|
|
3770
|
-
off_S4 = -1
|
|
3771
|
-
|
|
3772
|
-
# S2 for normalization
|
|
3773
|
-
cond_init_P1_dic = (norm == "self") or (
|
|
3774
|
-
(norm == "auto") and (self.P1_dic is None)
|
|
3775
|
-
)
|
|
3776
|
-
if norm is None:
|
|
3777
|
-
pass
|
|
3778
|
-
elif cond_init_P1_dic:
|
|
3779
|
-
P1_dic = {}
|
|
3780
|
-
if cross:
|
|
3781
|
-
P2_dic = {}
|
|
3782
|
-
elif (norm == "auto") and (self.P1_dic is not None):
|
|
3783
|
-
P1_dic = self.P1_dic
|
|
3784
|
-
if cross:
|
|
3785
|
-
P2_dic = self.P2_dic
|
|
3786
|
-
|
|
3787
|
-
if return_data:
|
|
3788
|
-
s0 = I1
|
|
3789
|
-
if out_nside is not None:
|
|
3790
|
-
s0 = self.backend.bk_reduce_mean(
|
|
3791
|
-
self.backend.bk_reshape(
|
|
3792
|
-
s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
|
|
3793
|
-
),
|
|
3794
|
-
2,
|
|
3795
|
-
)
|
|
3796
|
-
else:
|
|
3797
|
-
if not cross:
|
|
3798
|
-
s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
|
|
3799
|
-
else:
|
|
3800
|
-
s0, l_vs0 = self.masked_mean(
|
|
3801
|
-
self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
|
|
3802
|
-
)
|
|
3803
|
-
#vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
|
|
3804
|
-
#s0 = self.backend.bk_concat([s0, l_vs0], 1)
|
|
3805
|
-
result[:,:,0]=s0
|
|
3806
|
-
result[:,:,1]=l_vs0
|
|
3807
|
-
vresult[:,:,0]=l_vs0
|
|
3808
|
-
vresult[:,:,1]=l_vs0
|
|
3809
|
-
#### COMPUTE S1, S2, S3 and S4
|
|
3810
|
-
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
3811
|
-
|
|
3812
|
-
# a remettre comme avant
|
|
3813
|
-
M1_dic={}
|
|
3814
|
-
M2_dic={}
|
|
3815
|
-
|
|
3816
|
-
for j3 in range(Jmax):
|
|
3817
|
-
|
|
3818
|
-
if edge:
|
|
3819
|
-
if self.mask_mask is None:
|
|
3820
|
-
self.mask_mask={}
|
|
3821
|
-
if self.use_2D:
|
|
3822
|
-
if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
|
|
3823
|
-
mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
|
|
3824
|
-
mask_mask[0,
|
|
3825
|
-
self.KERNELSZ//2:-self.KERNELSZ//2+1,
|
|
3826
|
-
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
3827
|
-
self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
|
|
3828
|
-
vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
|
|
3829
|
-
#print(self.KERNELSZ//2,vmask,mask_mask)
|
|
3830
|
-
|
|
3831
|
-
if self.use_1D:
|
|
3832
|
-
if (vmask.shape[1]) not in self.mask_mask:
|
|
3833
|
-
mask_mask=np.zeros([1,vmask.shape[1]])
|
|
3834
|
-
mask_mask[0,
|
|
3835
|
-
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
3836
|
-
self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
|
|
3837
|
-
vmask=vmask*self.mask_mask[(vmask.shape[1])]
|
|
3838
|
-
|
|
3839
|
-
if return_data:
|
|
3840
|
-
S3[j3] = None
|
|
3841
|
-
S3P[j3] = None
|
|
3842
|
-
|
|
3843
|
-
if S4 is None:
|
|
3844
|
-
S4 = {}
|
|
3845
|
-
S4[j3] = None
|
|
3846
|
-
|
|
3847
|
-
####### S1 and S2
|
|
3848
|
-
### Make the convolution I1 * Psi_j3
|
|
3849
|
-
conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
3850
|
-
if cmat is not None:
|
|
3851
|
-
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
|
|
3852
|
-
conv1 = self.backend.bk_reduce_sum(
|
|
3853
|
-
self.backend.bk_reshape(
|
|
3854
|
-
cmat[j3] * tmp2,
|
|
3855
|
-
[tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
|
|
3856
|
-
),
|
|
3857
|
-
2,
|
|
3858
|
-
)
|
|
3859
|
-
|
|
3860
|
-
### Take the module M1 = |I1 * Psi_j3|
|
|
3861
|
-
M1_square = conv1 * self.backend.bk_conjugate(
|
|
3862
|
-
conv1
|
|
3863
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
3864
|
-
|
|
3865
|
-
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
3866
|
-
|
|
3867
|
-
# Store M1_j3 in a dictionary
|
|
3868
|
-
M1_dic[j3] = M1
|
|
3869
|
-
|
|
3870
|
-
if not cross: # Auto
|
|
3871
|
-
M1_square = self.backend.bk_real(M1_square)
|
|
3872
|
-
|
|
3873
|
-
### S2_auto = < M1^2 >_pix
|
|
3874
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
3875
|
-
if return_data:
|
|
3876
|
-
s2 = M1_square
|
|
3877
|
-
else:
|
|
3878
|
-
if calc_var:
|
|
3879
|
-
s2, vs2 = self.masked_mean(
|
|
3880
|
-
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
3881
|
-
)
|
|
3882
|
-
#s2=self.backend.bk_flatten(self.backend.bk_real(s2))
|
|
3883
|
-
#vs2=self.backend.bk_flatten(vs2)
|
|
3884
|
-
else:
|
|
3885
|
-
s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
|
|
3886
|
-
|
|
3887
|
-
if cond_init_P1_dic:
|
|
3888
|
-
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
3889
|
-
P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
|
|
3890
|
-
|
|
3891
|
-
# We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
|
|
3892
|
-
if return_data:
|
|
3893
|
-
if S2 is None:
|
|
3894
|
-
S2 = {}
|
|
3895
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
3896
|
-
s2 = self.backend.bk_reduce_mean(
|
|
3897
|
-
self.backend.bk_reshape(
|
|
3898
|
-
s2,
|
|
3899
|
-
[
|
|
3900
|
-
s2.shape[0],
|
|
3901
|
-
12 * out_nside**2,
|
|
3902
|
-
(nside_j3 // out_nside) ** 2,
|
|
3903
|
-
s2.shape[2],
|
|
3904
|
-
],
|
|
3905
|
-
),
|
|
3906
|
-
2,
|
|
3907
|
-
)
|
|
3908
|
-
S2[j3] = s2
|
|
3909
|
-
else:
|
|
3910
|
-
if norm == "auto": # Normalize S2
|
|
3911
|
-
s2 /= P1_dic[j3]
|
|
3912
|
-
"""
|
|
3913
|
-
S2.append(
|
|
3914
|
-
self.backend.bk_expand_dims(s2, off_S2)
|
|
3915
|
-
) # Add a dimension for NS2
|
|
3916
|
-
if calc_var:
|
|
3917
|
-
VS2.append(
|
|
3918
|
-
self.backend.bk_expand_dims(vs2, off_S2)
|
|
3919
|
-
) # Add a dimension for NS2
|
|
3920
|
-
"""
|
|
3921
|
-
#print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
|
|
3922
|
-
result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
|
|
3923
|
-
if calc_var:
|
|
3924
|
-
vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
|
|
3925
|
-
#### S1_auto computation
|
|
3926
|
-
### Image 1 : S1 = < M1 >_pix
|
|
3927
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
3928
|
-
if return_data:
|
|
3929
|
-
s1 = M1
|
|
3930
|
-
else:
|
|
3931
|
-
if calc_var:
|
|
3932
|
-
s1, vs1 = self.masked_mean(
|
|
3933
|
-
M1, vmask, axis=1, rank=j3, calc_var=True
|
|
3934
|
-
) # [Nbatch, Nmask, Norient3]
|
|
3935
|
-
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
3936
|
-
#vs1=self.backend.bk_flatten(vs1)
|
|
3937
|
-
else:
|
|
3938
|
-
s1 = self.masked_mean(
|
|
3939
|
-
M1, vmask, axis=1, rank=j3
|
|
3940
|
-
) # [Nbatch, Nmask, Norient3]
|
|
3941
|
-
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
3942
|
-
|
|
3943
|
-
if return_data:
|
|
3944
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
3945
|
-
s1 = self.backend.bk_reduce_mean(
|
|
3946
|
-
self.backend.bk_reshape(
|
|
3947
|
-
s1,
|
|
3948
|
-
[
|
|
3949
|
-
s1.shape[0],
|
|
3950
|
-
12 * out_nside**2,
|
|
3951
|
-
(nside_j3 // out_nside) ** 2,
|
|
3952
|
-
s1.shape[2],
|
|
3953
|
-
],
|
|
3954
|
-
),
|
|
3955
|
-
2,
|
|
3956
|
-
)
|
|
3957
|
-
S1[j3] = s1
|
|
3958
|
-
else:
|
|
3959
|
-
### Normalize S1
|
|
3960
|
-
if norm is not None:
|
|
3961
|
-
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
3962
|
-
result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
|
|
3963
|
-
if calc_var:
|
|
3964
|
-
vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
|
|
3965
|
-
"""
|
|
3966
|
-
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
3967
|
-
S1.append(
|
|
3968
|
-
self.backend.bk_expand_dims(s1, off_S2)
|
|
3969
|
-
) # Add a dimension for NS1
|
|
3970
|
-
if calc_var:
|
|
3971
|
-
VS1.append(
|
|
3972
|
-
self.backend.bk_expand_dims(vs1, off_S2)
|
|
3973
|
-
) # Add a dimension for NS1
|
|
3974
|
-
"""
|
|
3975
|
-
|
|
3976
|
-
else: # Cross
|
|
3977
|
-
### Make the convolution I2 * Psi_j3
|
|
3978
|
-
conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
3979
|
-
if cmat is not None:
|
|
3980
|
-
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
|
|
3981
|
-
conv2 = self.backend.bk_reduce_sum(
|
|
3982
|
-
self.backend.bk_reshape(
|
|
3983
|
-
cmat[j3] * tmp2,
|
|
3984
|
-
[
|
|
3985
|
-
tmp2.shape[0],
|
|
3986
|
-
cmat[j3].shape[0],
|
|
3987
|
-
self.NORIENT,
|
|
3988
|
-
self.NORIENT,
|
|
3989
|
-
],
|
|
3990
|
-
),
|
|
3991
|
-
2,
|
|
3992
|
-
)
|
|
3993
|
-
### Take the module M2 = |I2 * Psi_j3|
|
|
3994
|
-
M2_square = conv2 * self.backend.bk_conjugate(
|
|
3995
|
-
conv2
|
|
3996
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
3997
|
-
M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
|
|
3998
|
-
# Store M2_j3 in a dictionary
|
|
3999
|
-
M2_dic[j3] = M2
|
|
4000
|
-
|
|
4001
|
-
### S2_auto = < M2^2 >_pix
|
|
4002
|
-
# Not returned, only for normalization
|
|
4003
|
-
if cond_init_P1_dic:
|
|
4004
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4005
|
-
if return_data:
|
|
4006
|
-
p1 = M1_square
|
|
4007
|
-
p2 = M2_square
|
|
4008
|
-
else:
|
|
4009
|
-
if calc_var:
|
|
4010
|
-
p1, vp1 = self.masked_mean(
|
|
4011
|
-
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
4012
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4013
|
-
p2, vp2 = self.masked_mean(
|
|
4014
|
-
M2_square, vmask, axis=1, rank=j3, calc_var=True
|
|
4015
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4016
|
-
else:
|
|
4017
|
-
p1 = self.masked_mean(
|
|
4018
|
-
M1_square, vmask, axis=1, rank=j3
|
|
4019
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4020
|
-
p2 = self.masked_mean(
|
|
4021
|
-
M2_square, vmask, axis=1, rank=j3
|
|
4022
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4023
|
-
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
4024
|
-
P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
|
|
4025
|
-
P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
|
|
4026
|
-
|
|
4027
|
-
### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
|
|
4028
|
-
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4029
|
-
s2 = conv1 * self.backend.bk_conjugate(conv2)
|
|
4030
|
-
MX = self.backend.bk_L1(s2)
|
|
4031
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4032
|
-
if return_data:
|
|
4033
|
-
s2 = s2
|
|
4034
|
-
else:
|
|
4035
|
-
if calc_var:
|
|
4036
|
-
s2, vs2 = self.masked_mean(
|
|
4037
|
-
s2, vmask, axis=1, rank=j3, calc_var=True
|
|
4038
|
-
)
|
|
4039
|
-
else:
|
|
4040
|
-
s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
|
|
4041
|
-
|
|
4042
|
-
if return_data:
|
|
4043
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
4044
|
-
s2 = self.backend.bk_reduce_mean(
|
|
4045
|
-
self.backend.bk_reshape(
|
|
4046
|
-
s2,
|
|
4047
|
-
[
|
|
4048
|
-
s2.shape[0],
|
|
4049
|
-
12 * out_nside**2,
|
|
4050
|
-
(nside_j3 // out_nside) ** 2,
|
|
4051
|
-
s2.shape[2],
|
|
4052
|
-
],
|
|
4053
|
-
),
|
|
4054
|
-
2,
|
|
4055
|
-
)
|
|
4056
|
-
S2[j3] = s2
|
|
4057
|
-
else:
|
|
4058
|
-
### Normalize S2_cross
|
|
4059
|
-
if norm == "auto":
|
|
4060
|
-
s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
|
|
4061
|
-
|
|
4062
|
-
### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
|
|
4063
|
-
s2 = self.backend.bk_real(s2)
|
|
4064
|
-
|
|
4065
|
-
S2.append(
|
|
4066
|
-
self.backend.bk_expand_dims(s2, off_S2)
|
|
4067
|
-
) # Add a dimension for NS2
|
|
4068
|
-
if calc_var:
|
|
4069
|
-
VS2.append(
|
|
4070
|
-
self.backend.bk_expand_dims(vs2, off_S2)
|
|
4071
|
-
) # Add a dimension for NS2
|
|
4072
|
-
|
|
4073
|
-
#### S1_auto computation
|
|
4074
|
-
### Image 1 : S1 = < M1 >_pix
|
|
4075
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4076
|
-
if return_data:
|
|
4077
|
-
s1 = MX
|
|
4078
|
-
else:
|
|
4079
|
-
if calc_var:
|
|
4080
|
-
s1, vs1 = self.masked_mean(
|
|
4081
|
-
MX, vmask, axis=1, rank=j3, calc_var=True
|
|
4082
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4083
|
-
else:
|
|
4084
|
-
s1 = self.masked_mean(
|
|
4085
|
-
MX, vmask, axis=1, rank=j3
|
|
4086
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4087
|
-
if return_data:
|
|
4088
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
4089
|
-
s1 = self.backend.bk_reduce_mean(
|
|
4090
|
-
self.backend.bk_reshape(
|
|
4091
|
-
s1,
|
|
4092
|
-
[
|
|
4093
|
-
s1.shape[0],
|
|
4094
|
-
12 * out_nside**2,
|
|
4095
|
-
(nside_j3 // out_nside) ** 2,
|
|
4096
|
-
s1.shape[2],
|
|
4097
|
-
],
|
|
4098
|
-
),
|
|
4099
|
-
2,
|
|
4100
|
-
)
|
|
4101
|
-
S1[j3] = s1
|
|
4102
|
-
else:
|
|
4103
|
-
### Normalize S1
|
|
4104
|
-
if norm is not None:
|
|
4105
|
-
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
4106
|
-
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
4107
|
-
S1.append(
|
|
4108
|
-
self.backend.bk_expand_dims(s1, off_S2)
|
|
4109
|
-
) # Add a dimension for NS1
|
|
4110
|
-
if calc_var:
|
|
4111
|
-
VS1.append(
|
|
4112
|
-
self.backend.bk_expand_dims(vs1, off_S2)
|
|
4113
|
-
) # Add a dimension for NS1
|
|
4114
|
-
|
|
4115
|
-
# Initialize dictionaries for |I1*Psi_j| * Psi_j3
|
|
4116
|
-
M1convPsi_dic = {}
|
|
4117
|
-
if cross:
|
|
4118
|
-
# Initialize dictionaries for |I2*Psi_j| * Psi_j3
|
|
4119
|
-
M2convPsi_dic = {}
|
|
4120
|
-
|
|
4121
|
-
###### S3
|
|
4122
|
-
nside_j2 = nside_j3
|
|
4123
|
-
for j2 in range(0,-1): # j3 + 1): # j2 <= j3
|
|
4124
|
-
if return_data:
|
|
4125
|
-
if S4[j3] is None:
|
|
4126
|
-
S4[j3] = {}
|
|
4127
|
-
S4[j3][j2] = None
|
|
4128
|
-
|
|
4129
|
-
### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
4130
|
-
if not cross:
|
|
4131
|
-
if calc_var:
|
|
4132
|
-
s3, vs3 = self._compute_S3(
|
|
4133
|
-
j2,
|
|
4134
|
-
j3,
|
|
4135
|
-
conv1,
|
|
4136
|
-
vmask,
|
|
4137
|
-
M1_dic,
|
|
4138
|
-
M1convPsi_dic,
|
|
4139
|
-
calc_var=True,
|
|
4140
|
-
cmat2=cmat2,
|
|
4141
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4142
|
-
else:
|
|
4143
|
-
s3 = self._compute_S3(
|
|
4144
|
-
j2,
|
|
4145
|
-
j3,
|
|
4146
|
-
conv1,
|
|
4147
|
-
vmask,
|
|
4148
|
-
M1_dic,
|
|
4149
|
-
M1convPsi_dic,
|
|
4150
|
-
return_data=return_data,
|
|
4151
|
-
cmat2=cmat2,
|
|
4152
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4153
|
-
|
|
4154
|
-
if return_data:
|
|
4155
|
-
if S3[j3] is None:
|
|
4156
|
-
S3[j3] = {}
|
|
4157
|
-
if out_nside is not None and out_nside < nside_j2:
|
|
4158
|
-
s3 = self.backend.bk_reduce_mean(
|
|
4159
|
-
self.backend.bk_reshape(
|
|
4160
|
-
s3,
|
|
4161
|
-
[
|
|
4162
|
-
s3.shape[0],
|
|
4163
|
-
12 * out_nside**2,
|
|
4164
|
-
(nside_j2 // out_nside) ** 2,
|
|
4165
|
-
s3.shape[2],
|
|
4166
|
-
s3.shape[3],
|
|
4167
|
-
],
|
|
4168
|
-
),
|
|
4169
|
-
2,
|
|
4170
|
-
)
|
|
4171
|
-
S3[j3][j2] = s3
|
|
4172
|
-
else:
|
|
4173
|
-
### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4174
|
-
if norm is not None:
|
|
4175
|
-
self.div_norm(
|
|
4176
|
-
s3,
|
|
4177
|
-
(
|
|
4178
|
-
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
4179
|
-
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
4180
|
-
)
|
|
4181
|
-
** 0.5,
|
|
4182
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4183
|
-
|
|
4184
|
-
### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
4185
|
-
|
|
4186
|
-
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
4187
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4188
|
-
S3.append(
|
|
4189
|
-
self.backend.bk_expand_dims(s3, off_S3)
|
|
4190
|
-
) # Add a dimension for NS3
|
|
4191
|
-
if calc_var:
|
|
4192
|
-
VS3.append(
|
|
4193
|
-
self.backend.bk_expand_dims(vs3, off_S3)
|
|
4194
|
-
) # Add a dimension for NS3
|
|
4195
|
-
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
4196
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4197
|
-
|
|
4198
|
-
### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
|
|
4199
|
-
### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
4200
|
-
else:
|
|
4201
|
-
if calc_var:
|
|
4202
|
-
s3, vs3 = self._compute_S3(
|
|
4203
|
-
j2,
|
|
4204
|
-
j3,
|
|
4205
|
-
conv1,
|
|
4206
|
-
vmask,
|
|
4207
|
-
M2_dic,
|
|
4208
|
-
M2convPsi_dic,
|
|
4209
|
-
calc_var=True,
|
|
4210
|
-
cmat2=cmat2,
|
|
4211
|
-
)
|
|
4212
|
-
s3p, vs3p = self._compute_S3(
|
|
4213
|
-
j2,
|
|
4214
|
-
j3,
|
|
4215
|
-
conv2,
|
|
4216
|
-
vmask,
|
|
4217
|
-
M1_dic,
|
|
4218
|
-
M1convPsi_dic,
|
|
4219
|
-
calc_var=True,
|
|
4220
|
-
cmat2=cmat2,
|
|
4221
|
-
)
|
|
4222
|
-
else:
|
|
4223
|
-
s3 = self._compute_S3(
|
|
4224
|
-
j2,
|
|
4225
|
-
j3,
|
|
4226
|
-
conv1,
|
|
4227
|
-
vmask,
|
|
4228
|
-
M2_dic,
|
|
4229
|
-
M2convPsi_dic,
|
|
4230
|
-
return_data=return_data,
|
|
4231
|
-
cmat2=cmat2,
|
|
4232
|
-
)
|
|
4233
|
-
s3p = self._compute_S3(
|
|
4234
|
-
j2,
|
|
4235
|
-
j3,
|
|
4236
|
-
conv2,
|
|
4237
|
-
vmask,
|
|
4238
|
-
M1_dic,
|
|
4239
|
-
M1convPsi_dic,
|
|
4240
|
-
return_data=return_data,
|
|
4241
|
-
cmat2=cmat2,
|
|
4242
|
-
)
|
|
4243
|
-
|
|
4244
|
-
if return_data:
|
|
4245
|
-
if S3[j3] is None:
|
|
4246
|
-
S3[j3] = {}
|
|
4247
|
-
S3P[j3] = {}
|
|
4248
|
-
if out_nside is not None and out_nside < nside_j2:
|
|
4249
|
-
s3 = self.backend.bk_reduce_mean(
|
|
4250
|
-
self.backend.bk_reshape(
|
|
4251
|
-
s3,
|
|
4252
|
-
[
|
|
4253
|
-
s3.shape[0],
|
|
4254
|
-
12 * out_nside**2,
|
|
4255
|
-
(nside_j2 // out_nside) ** 2,
|
|
4256
|
-
s3.shape[2],
|
|
4257
|
-
s3.shape[3],
|
|
4258
|
-
],
|
|
4259
|
-
),
|
|
4260
|
-
2,
|
|
4261
|
-
)
|
|
4262
|
-
s3p = self.backend.bk_reduce_mean(
|
|
4263
|
-
self.backend.bk_reshape(
|
|
4264
|
-
s3p,
|
|
4265
|
-
[
|
|
4266
|
-
s3.shape[0],
|
|
4267
|
-
12 * out_nside**2,
|
|
4268
|
-
(nside_j2 // out_nside) ** 2,
|
|
4269
|
-
s3.shape[2],
|
|
4270
|
-
s3.shape[3],
|
|
4271
|
-
],
|
|
4272
|
-
),
|
|
4273
|
-
2,
|
|
4274
|
-
)
|
|
4275
|
-
S3[j3][j2] = s3
|
|
4276
|
-
S3P[j3][j2] = s3p
|
|
4277
|
-
else:
|
|
4278
|
-
### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
|
|
4279
|
-
if norm is not None:
|
|
4280
|
-
self.div_norm(
|
|
4281
|
-
s3,
|
|
4282
|
-
(
|
|
4283
|
-
self.backend.bk_expand_dims(P2_dic[j2], off_S2)
|
|
4284
|
-
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
4285
|
-
)
|
|
4286
|
-
** 0.5,
|
|
4287
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4288
|
-
self.div_norm(
|
|
4289
|
-
s3p,
|
|
4290
|
-
(
|
|
4291
|
-
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
4292
|
-
* self.backend.bk_expand_dims(P2_dic[j3], -1)
|
|
4293
|
-
)
|
|
4294
|
-
** 0.5,
|
|
4295
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4296
|
-
|
|
4297
|
-
### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
4298
|
-
|
|
4299
|
-
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
4300
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4301
|
-
S3.append(
|
|
4302
|
-
self.backend.bk_expand_dims(s3, off_S3)
|
|
4303
|
-
) # Add a dimension for NS3
|
|
4304
|
-
if calc_var:
|
|
4305
|
-
VS3.append(
|
|
4306
|
-
self.backend.bk_expand_dims(vs3, off_S3)
|
|
4307
|
-
) # Add a dimension for NS3
|
|
4308
|
-
|
|
4309
|
-
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
4310
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4311
|
-
|
|
4312
|
-
# S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
|
|
4313
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4314
|
-
S3P.append(
|
|
4315
|
-
self.backend.bk_expand_dims(s3p, off_S3)
|
|
4316
|
-
) # Add a dimension for NS3
|
|
4317
|
-
if calc_var:
|
|
4318
|
-
VS3P.append(
|
|
4319
|
-
self.backend.bk_expand_dims(vs3p, off_S3)
|
|
4320
|
-
) # Add a dimension for NS3
|
|
4321
|
-
# VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
|
|
4322
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4323
|
-
|
|
4324
|
-
##### S4
|
|
4325
|
-
nside_j1 = nside_j2
|
|
4326
|
-
for j1 in range(0, j2 + 1): # j1 <= j2
|
|
4327
|
-
### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
|
|
4328
|
-
if not cross:
|
|
4329
|
-
if calc_var:
|
|
4330
|
-
s4, vs4 = self._compute_S4(
|
|
4331
|
-
j1,
|
|
4332
|
-
j2,
|
|
4333
|
-
vmask,
|
|
4334
|
-
M1convPsi_dic,
|
|
4335
|
-
M2convPsi_dic=None,
|
|
4336
|
-
calc_var=True,
|
|
4337
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4338
|
-
else:
|
|
4339
|
-
s4 = self._compute_S4(
|
|
4340
|
-
j1,
|
|
4341
|
-
j2,
|
|
4342
|
-
vmask,
|
|
4343
|
-
M1convPsi_dic,
|
|
4344
|
-
M2convPsi_dic=None,
|
|
4345
|
-
return_data=return_data,
|
|
4346
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4347
|
-
|
|
4348
|
-
if return_data:
|
|
4349
|
-
if S4[j3][j2] is None:
|
|
4350
|
-
S4[j3][j2] = {}
|
|
4351
|
-
if out_nside is not None and out_nside < nside_j1:
|
|
4352
|
-
s4 = self.backend.bk_reduce_mean(
|
|
4353
|
-
self.backend.bk_reshape(
|
|
4354
|
-
s4,
|
|
4355
|
-
[
|
|
4356
|
-
s4.shape[0],
|
|
4357
|
-
12 * out_nside**2,
|
|
4358
|
-
(nside_j1 // out_nside) ** 2,
|
|
4359
|
-
s4.shape[2],
|
|
4360
|
-
s4.shape[3],
|
|
4361
|
-
s4.shape[4],
|
|
4362
|
-
],
|
|
4363
|
-
),
|
|
4364
|
-
2,
|
|
4365
|
-
)
|
|
4366
|
-
S4[j3][j2][j1] = s4
|
|
4367
|
-
else:
|
|
4368
|
-
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4369
|
-
if norm is not None:
|
|
4370
|
-
self.div_norm(
|
|
4371
|
-
s4,
|
|
4372
|
-
(
|
|
4373
|
-
self.backend.bk_expand_dims(
|
|
4374
|
-
self.backend.bk_expand_dims(
|
|
4375
|
-
P1_dic[j1], off_S2
|
|
4376
|
-
),
|
|
4377
|
-
off_S2,
|
|
4378
|
-
)
|
|
4379
|
-
* self.backend.bk_expand_dims(
|
|
4380
|
-
self.backend.bk_expand_dims(
|
|
4381
|
-
P1_dic[j2], off_S2
|
|
4382
|
-
),
|
|
4383
|
-
-1,
|
|
4384
|
-
)
|
|
4385
|
-
)
|
|
4386
|
-
** 0.5,
|
|
4387
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4388
|
-
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
4389
|
-
|
|
4390
|
-
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
4391
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4392
|
-
S4.append(
|
|
4393
|
-
self.backend.bk_expand_dims(s4, off_S4)
|
|
4394
|
-
) # Add a dimension for NS4
|
|
4395
|
-
if calc_var:
|
|
4396
|
-
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
4397
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4398
|
-
VS4.append(
|
|
4399
|
-
self.backend.bk_expand_dims(vs4, off_S4)
|
|
4400
|
-
) # Add a dimension for NS4
|
|
4401
|
-
|
|
4402
|
-
### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
|
|
4403
|
-
else:
|
|
4404
|
-
if calc_var:
|
|
4405
|
-
s4, vs4 = self._compute_S4(
|
|
4406
|
-
j1,
|
|
4407
|
-
j2,
|
|
4408
|
-
vmask,
|
|
4409
|
-
M1convPsi_dic,
|
|
4410
|
-
M2convPsi_dic=M2convPsi_dic,
|
|
4411
|
-
calc_var=True,
|
|
4412
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4413
|
-
else:
|
|
4414
|
-
s4 = self._compute_S4(
|
|
4415
|
-
j1,
|
|
4416
|
-
j2,
|
|
4417
|
-
vmask,
|
|
4418
|
-
M1convPsi_dic,
|
|
4419
|
-
M2convPsi_dic=M2convPsi_dic,
|
|
4420
|
-
return_data=return_data,
|
|
4421
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4422
|
-
|
|
4423
|
-
if return_data:
|
|
4424
|
-
if S4[j3][j2] is None:
|
|
4425
|
-
S4[j3][j2] = {}
|
|
4426
|
-
if out_nside is not None and out_nside < nside_j1:
|
|
4427
|
-
s4 = self.backend.bk_reduce_mean(
|
|
4428
|
-
self.backend.bk_reshape(
|
|
4429
|
-
s4,
|
|
4430
|
-
[
|
|
4431
|
-
s4.shape[0],
|
|
4432
|
-
12 * out_nside**2,
|
|
4433
|
-
(nside_j1 // out_nside) ** 2,
|
|
4434
|
-
s4.shape[2],
|
|
4435
|
-
s4.shape[3],
|
|
4436
|
-
s4.shape[4],
|
|
4437
|
-
],
|
|
4438
|
-
),
|
|
4439
|
-
2,
|
|
4440
|
-
)
|
|
4441
|
-
S4[j3][j2][j1] = s4
|
|
4442
|
-
else:
|
|
4443
|
-
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4444
|
-
if norm is not None:
|
|
4445
|
-
self.div_norm(
|
|
4446
|
-
s4,
|
|
4447
|
-
(
|
|
4448
|
-
self.backend.bk_expand_dims(
|
|
4449
|
-
self.backend.bk_expand_dims(
|
|
4450
|
-
P1_dic[j1], off_S2
|
|
4451
|
-
),
|
|
4452
|
-
off_S2,
|
|
4453
|
-
)
|
|
4454
|
-
* self.backend.bk_expand_dims(
|
|
4455
|
-
self.backend.bk_expand_dims(
|
|
4456
|
-
P2_dic[j2], off_S2
|
|
4457
|
-
),
|
|
4458
|
-
-1,
|
|
4459
|
-
)
|
|
4460
|
-
)
|
|
4461
|
-
** 0.5,
|
|
4462
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4463
|
-
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
4464
|
-
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
4465
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4466
|
-
S4.append(
|
|
4467
|
-
self.backend.bk_expand_dims(s4, off_S4)
|
|
4468
|
-
) # Add a dimension for NS4
|
|
4469
|
-
if calc_var:
|
|
4470
|
-
|
|
4471
|
-
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
4472
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4473
|
-
VS4.append(
|
|
4474
|
-
self.backend.bk_expand_dims(vs4, off_S4)
|
|
4475
|
-
) # Add a dimension for NS4
|
|
4476
|
-
|
|
4477
|
-
nside_j1 = nside_j1 // 2
|
|
4478
|
-
nside_j2 = nside_j2 // 2
|
|
4479
|
-
|
|
4480
|
-
###### Reshape for next iteration on j3
|
|
4481
|
-
### Image I1,
|
|
4482
|
-
# downscale the I1 [Nbatch, Npix_j3]
|
|
4483
|
-
if j3 != Jmax - 1:
|
|
4484
|
-
I1 = self.smooth(I1, axis=1)
|
|
4485
|
-
I1 = self.ud_grade_2(I1, axis=1)
|
|
4486
|
-
|
|
4487
|
-
### Image I2
|
|
4488
|
-
if cross:
|
|
4489
|
-
I2 = self.smooth(I2, axis=1)
|
|
4490
|
-
I2 = self.ud_grade_2(I2, axis=1)
|
|
4491
|
-
|
|
4492
|
-
### Modules
|
|
4493
|
-
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
4494
|
-
### Dictionary M1_dic[j2]
|
|
4495
|
-
M1_smooth = self.smooth(
|
|
4496
|
-
M1_dic[j2], axis=1
|
|
4497
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4498
|
-
M1_dic[j2] = self.ud_grade_2(
|
|
4499
|
-
M1_smooth, axis=1
|
|
4500
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4501
|
-
|
|
4502
|
-
### Dictionary M2_dic[j2]
|
|
4503
|
-
if cross:
|
|
4504
|
-
M2_smooth = self.smooth(
|
|
4505
|
-
M2_dic[j2], axis=1
|
|
4506
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4507
|
-
M2_dic[j2] = self.ud_grade_2(
|
|
4508
|
-
M2, axis=1
|
|
4509
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4510
|
-
|
|
4511
|
-
### Mask
|
|
4512
|
-
vmask = self.ud_grade_2(vmask, axis=1)
|
|
4513
|
-
|
|
4514
|
-
if self.mask_thres is not None:
|
|
4515
|
-
vmask = self.backend.bk_threshold(vmask, self.mask_thres)
|
|
4516
|
-
|
|
4517
|
-
### NSIDE_j3
|
|
4518
|
-
nside_j3 = nside_j3 // 2
|
|
4519
|
-
|
|
4520
|
-
### Store P1_dic and P2_dic in self
|
|
4521
|
-
if (norm == "auto") and (self.P1_dic is None):
|
|
4522
|
-
self.P1_dic = P1_dic
|
|
4523
|
-
if cross:
|
|
4524
|
-
self.P2_dic = P2_dic
|
|
4525
|
-
"""
|
|
4526
|
-
Sout=[s0]+S1+S2+S3+S4
|
|
4527
|
-
|
|
4528
|
-
if cross:
|
|
4529
|
-
Sout=Sout+S3P
|
|
4530
|
-
if calc_var:
|
|
4531
|
-
SVout=[vs0]+VS1+VS2+VS3+VS4
|
|
4532
|
-
if cross:
|
|
4533
|
-
VSout=VSout+VS3P
|
|
4534
|
-
return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
|
|
4535
|
-
|
|
4536
|
-
return self.backend.bk_concat(Sout, 2)
|
|
4537
|
-
"""
|
|
4538
|
-
if calc_var:
|
|
4539
|
-
return result,vresult
|
|
4540
|
-
else:
|
|
4541
|
-
return result
|
|
4542
|
-
if calc_var:
|
|
4543
|
-
for k in S1:
|
|
4544
|
-
print(k.shape,k.dtype)
|
|
4545
|
-
for k in S2:
|
|
4546
|
-
print(k.shape,k.dtype)
|
|
4547
|
-
print(s0.shape,s0.dtype)
|
|
4548
|
-
return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
|
|
4549
|
-
else:
|
|
4550
|
-
return self.backend.bk_concat([s0]+S1+S2,axis=1)
|
|
4551
|
-
|
|
4552
|
-
if not return_data:
|
|
4553
|
-
S1 = self.backend.bk_concat(S1, 2)
|
|
4554
|
-
S2 = self.backend.bk_concat(S2, 2)
|
|
4555
|
-
S3 = self.backend.bk_concat(S3, 2)
|
|
4556
|
-
S4 = self.backend.bk_concat(S4, 2)
|
|
4557
|
-
if cross:
|
|
4558
|
-
S3P = self.backend.bk_concat(S3P, 2)
|
|
4559
|
-
if calc_var:
|
|
4560
|
-
VS1 = self.backend.bk_concat(VS1, 2)
|
|
4561
|
-
VS2 = self.backend.bk_concat(VS2, 2)
|
|
4562
|
-
VS3 = self.backend.bk_concat(VS3, 2)
|
|
4563
|
-
VS4 = self.backend.bk_concat(VS4, 2)
|
|
4564
|
-
if cross:
|
|
4565
|
-
VS3P = self.backend.bk_concat(VS3P, 2)
|
|
4566
|
-
if calc_var:
|
|
4567
|
-
if not cross:
|
|
4568
|
-
return scat_cov(
|
|
4569
|
-
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
4570
|
-
), scat_cov(
|
|
4571
|
-
vs0,
|
|
4572
|
-
VS2,
|
|
4573
|
-
VS3,
|
|
4574
|
-
VS4,
|
|
4575
|
-
s1=VS1,
|
|
4576
|
-
backend=self.backend,
|
|
4577
|
-
use_1D=self.use_1D,
|
|
4578
|
-
)
|
|
4579
|
-
else:
|
|
4580
|
-
return scat_cov(
|
|
4581
|
-
s0,
|
|
4582
|
-
S2,
|
|
4583
|
-
S3,
|
|
4584
|
-
S4,
|
|
4585
|
-
s1=S1,
|
|
4586
|
-
s3p=S3P,
|
|
4587
|
-
backend=self.backend,
|
|
4588
|
-
use_1D=self.use_1D,
|
|
4589
|
-
), scat_cov(
|
|
4590
|
-
vs0,
|
|
4591
|
-
VS2,
|
|
4592
|
-
VS3,
|
|
4593
|
-
VS4,
|
|
4594
|
-
s1=VS1,
|
|
4595
|
-
s3p=VS3P,
|
|
4596
|
-
backend=self.backend,
|
|
4597
|
-
use_1D=self.use_1D,
|
|
4598
|
-
)
|
|
4599
|
-
else:
|
|
4600
|
-
if not cross:
|
|
4601
|
-
return scat_cov(
|
|
4602
|
-
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
4603
|
-
)
|
|
4604
|
-
else:
|
|
4605
|
-
return scat_cov(
|
|
4606
|
-
s0,
|
|
4607
|
-
S2,
|
|
4608
|
-
S3,
|
|
4609
|
-
S4,
|
|
4610
|
-
s1=S1,
|
|
4611
|
-
s3p=S3P,
|
|
4612
|
-
backend=self.backend,
|
|
4613
|
-
use_1D=self.use_1D,
|
|
4614
|
-
)
|
|
4615
|
-
def clean_norm(self):
|
|
4616
|
-
self.P1_dic = None
|
|
4617
|
-
self.P2_dic = None
|
|
4618
|
-
return
|
|
4619
|
-
|
|
4620
|
-
def _compute_S3(
|
|
4621
|
-
self,
|
|
4622
|
-
j2,
|
|
4623
|
-
j3,
|
|
4624
|
-
conv,
|
|
4625
|
-
vmask,
|
|
4626
|
-
M_dic,
|
|
4627
|
-
MconvPsi_dic,
|
|
4628
|
-
calc_var=False,
|
|
4629
|
-
return_data=False,
|
|
4630
|
-
cmat2=None,
|
|
4631
|
-
):
|
|
4632
|
-
"""
|
|
4633
|
-
Compute the S3 coefficients (auto or cross)
|
|
4634
|
-
S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
|
|
4635
|
-
Parameters
|
|
4636
|
-
----------
|
|
4637
|
-
Returns
|
|
4638
|
-
-------
|
|
4639
|
-
cs3, ss3: real and imag parts of S3 coeff
|
|
4640
|
-
"""
|
|
4641
|
-
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
4642
|
-
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
4643
|
-
MconvPsi = self.convol(
|
|
4644
|
-
M_dic[j2], axis=1
|
|
4645
|
-
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4646
|
-
if cmat2 is not None:
|
|
4647
|
-
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
|
|
4648
|
-
MconvPsi = self.backend.bk_reduce_sum(
|
|
4649
|
-
self.backend.bk_reshape(
|
|
4650
|
-
cmat2[j3][j2] * tmp2,
|
|
4651
|
-
[
|
|
4652
|
-
tmp2.shape[0],
|
|
4653
|
-
cmat2[j3].shape[1],
|
|
4654
|
-
self.NORIENT,
|
|
4655
|
-
self.NORIENT,
|
|
4656
|
-
self.NORIENT,
|
|
4657
|
-
],
|
|
4658
|
-
),
|
|
4659
|
-
3,
|
|
4660
|
-
)
|
|
4661
|
-
|
|
4662
|
-
# Store it so we can use it in S4 computation
|
|
4663
|
-
MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4664
|
-
|
|
4665
|
-
### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
|
|
4666
|
-
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4667
|
-
# cconv, sconv are [Nbatch, Npix_j3, Norient3]
|
|
4668
|
-
if self.use_1D:
|
|
4669
|
-
s3 = conv * self.backend.bk_conjugate(MconvPsi)
|
|
4670
|
-
else:
|
|
4671
|
-
s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
|
|
4672
|
-
MconvPsi
|
|
4673
|
-
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4674
|
-
|
|
4675
|
-
### Apply the mask [Nmask, Npix_j3] and sum over pixels
|
|
4676
|
-
if return_data:
|
|
4677
|
-
return s3
|
|
4678
|
-
else:
|
|
4679
|
-
if calc_var:
|
|
4680
|
-
s3, vs3 = self.masked_mean(
|
|
4681
|
-
s3, vmask, axis=1, rank=j2, calc_var=True
|
|
4682
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4683
|
-
return s3, vs3
|
|
4684
|
-
else:
|
|
4685
|
-
s3 = self.masked_mean(
|
|
4686
|
-
s3, vmask, axis=1, rank=j2
|
|
4687
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4688
|
-
return s3
|
|
4689
|
-
|
|
4690
|
-
def _compute_S4(
|
|
4691
|
-
self,
|
|
4692
|
-
j1,
|
|
4693
|
-
j2,
|
|
4694
|
-
vmask,
|
|
4695
|
-
M1convPsi_dic,
|
|
4696
|
-
M2convPsi_dic=None,
|
|
4697
|
-
calc_var=False,
|
|
4698
|
-
return_data=False,
|
|
4699
|
-
):
|
|
4700
|
-
#### Simplify notations
|
|
4701
|
-
M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
|
|
4702
|
-
|
|
4703
|
-
# Auto or Cross coefficients
|
|
4704
|
-
if M2convPsi_dic is None: # Auto
|
|
4705
|
-
M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4706
|
-
else: # Cross
|
|
4707
|
-
M2 = M2convPsi_dic[j2]
|
|
4708
|
-
|
|
4709
|
-
### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
|
|
4710
|
-
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4711
|
-
if self.use_1D:
|
|
4712
|
-
s4 = M1 * self.backend.bk_conjugate(M2)
|
|
4713
|
-
else:
|
|
4714
|
-
s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
|
|
4715
|
-
self.backend.bk_expand_dims(M2, -1)
|
|
4716
|
-
) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
|
|
4717
|
-
|
|
4718
|
-
### Apply the mask and sum over pixels
|
|
4719
|
-
if return_data:
|
|
4720
|
-
return s4
|
|
4721
|
-
else:
|
|
4722
|
-
if calc_var:
|
|
4723
|
-
s4, vs4 = self.masked_mean(
|
|
4724
|
-
s4, vmask, axis=1, rank=j2, calc_var=True
|
|
4725
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4726
|
-
return s4, vs4
|
|
4727
|
-
else:
|
|
4728
|
-
s4 = self.masked_mean(
|
|
4729
|
-
s4, vmask, axis=1, rank=j2
|
|
4730
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4731
|
-
return s4
|
|
4732
|
-
|
|
4733
|
-
def computer_filter(self,M,N,J,L):
|
|
4734
|
-
'''
|
|
4735
|
-
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4736
|
-
Done by Sihao Cheng and Rudy Morel.
|
|
4737
|
-
'''
|
|
4738
|
-
|
|
4739
|
-
filter = np.zeros([J, L, M, N],dtype='complex64')
|
|
4740
|
-
|
|
4741
|
-
slant=4.0 / L
|
|
4742
|
-
|
|
4743
|
-
for j in range(J):
|
|
4744
|
-
|
|
4745
|
-
for l in range(L):
|
|
4746
|
-
|
|
4747
|
-
theta = (int(L-L/2-1)-l) * np.pi / L
|
|
4748
|
-
sigma = 0.8 * 2**j
|
|
4749
|
-
xi = 3.0 / 4.0 * np.pi /2**j
|
|
4750
|
-
|
|
4751
|
-
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
|
|
4752
|
-
R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
|
|
4753
3719
|
D = np.array([[1, 0], [0, slant * slant]])
|
|
4754
|
-
curv = np.matmul(R, np.matmul(D, R_inv)) / (
|
|
4755
|
-
|
|
3720
|
+
curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
|
|
3721
|
+
|
|
4756
3722
|
gab = np.zeros((M, N), np.complex128)
|
|
4757
|
-
xx = np.empty((2,2, M, N))
|
|
4758
|
-
yy = np.empty((2,2, M, N))
|
|
4759
|
-
|
|
3723
|
+
xx = np.empty((2, 2, M, N))
|
|
3724
|
+
yy = np.empty((2, 2, M, N))
|
|
3725
|
+
|
|
4760
3726
|
for ii, ex in enumerate([-1, 0]):
|
|
4761
3727
|
for jj, ey in enumerate([-1, 0]):
|
|
4762
|
-
xx[ii,jj], yy[ii,jj] = np.mgrid[
|
|
4763
|
-
ex * M : M + ex * M,
|
|
4764
|
-
|
|
4765
|
-
|
|
4766
|
-
arg = -(
|
|
4767
|
-
|
|
4768
|
-
|
|
4769
|
-
|
|
4770
|
-
|
|
4771
|
-
|
|
3728
|
+
xx[ii, jj], yy[ii, jj] = np.mgrid[
|
|
3729
|
+
ex * M : M + ex * M, ey * N : N + ey * N
|
|
3730
|
+
]
|
|
3731
|
+
|
|
3732
|
+
arg = -(
|
|
3733
|
+
curv[0, 0] * xx * xx
|
|
3734
|
+
+ (curv[0, 1] + curv[1, 0]) * xx * yy
|
|
3735
|
+
+ curv[1, 1] * yy * yy
|
|
3736
|
+
)
|
|
3737
|
+
argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
|
|
3738
|
+
|
|
3739
|
+
gabi = np.exp(argi).sum((0, 1))
|
|
3740
|
+
gab = np.exp(arg).sum((0, 1))
|
|
3741
|
+
|
|
4772
3742
|
norm_factor = 2 * np.pi * sigma * sigma / slant
|
|
4773
|
-
|
|
3743
|
+
|
|
4774
3744
|
gab = gab / norm_factor
|
|
4775
|
-
|
|
3745
|
+
|
|
4776
3746
|
gabi = gabi / norm_factor
|
|
4777
3747
|
|
|
4778
3748
|
K = gabi.sum() / gab.sum()
|
|
4779
3749
|
|
|
4780
3750
|
# Apply the Gaussian
|
|
4781
|
-
filter[j,
|
|
4782
|
-
filter[j,
|
|
4783
|
-
|
|
3751
|
+
filter[j, ell] = np.fft.fft2(gabi - K * gab)
|
|
3752
|
+
filter[j, ell, 0, 0] = 0.0
|
|
3753
|
+
|
|
4784
3754
|
return self.backend.bk_cast(filter)
|
|
4785
|
-
|
|
3755
|
+
|
|
4786
3756
|
# ------------------------------------------------------------------------------------------
|
|
4787
3757
|
#
|
|
4788
|
-
# utility functions
|
|
3758
|
+
# utility functions
|
|
4789
3759
|
#
|
|
4790
3760
|
# ------------------------------------------------------------------------------------------
|
|
4791
|
-
def cut_high_k_off(self,data_f, dx, dy):
|
|
4792
|
-
|
|
3761
|
+
def cut_high_k_off(self, data_f, dx, dy):
|
|
3762
|
+
"""
|
|
4793
3763
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4794
3764
|
Done by Sihao Cheng and Rudy Morel.
|
|
4795
|
-
|
|
4796
|
-
|
|
4797
|
-
if self.backend.BACKEND==
|
|
4798
|
-
if_xodd =
|
|
4799
|
-
if_yodd =
|
|
3765
|
+
"""
|
|
3766
|
+
|
|
3767
|
+
if self.backend.BACKEND == "torch":
|
|
3768
|
+
if_xodd = data_f.shape[-2] % 2 == 1
|
|
3769
|
+
if_yodd = data_f.shape[-1] % 2 == 1
|
|
4800
3770
|
result = self.backend.backend.cat(
|
|
4801
|
-
(
|
|
4802
|
-
(
|
|
4803
|
-
|
|
4804
|
-
|
|
4805
|
-
|
|
4806
|
-
|
|
4807
|
-
|
|
3771
|
+
(
|
|
3772
|
+
self.backend.backend.cat(
|
|
3773
|
+
(
|
|
3774
|
+
data_f[..., : dx + if_xodd, : dy + if_yodd],
|
|
3775
|
+
data_f[..., -dx:, : dy + if_yodd],
|
|
3776
|
+
),
|
|
3777
|
+
-2,
|
|
3778
|
+
),
|
|
3779
|
+
self.backend.backend.cat(
|
|
3780
|
+
(data_f[..., : dx + if_xodd, -dy:], data_f[..., -dx:, -dy:]), -2
|
|
3781
|
+
),
|
|
3782
|
+
),
|
|
3783
|
+
-1,
|
|
3784
|
+
)
|
|
4808
3785
|
return result
|
|
4809
3786
|
else:
|
|
4810
3787
|
# Check if the last two dimensions are odd
|
|
4811
|
-
if_xodd = self.backend.backend.cast(
|
|
4812
|
-
|
|
3788
|
+
if_xodd = self.backend.backend.cast(
|
|
3789
|
+
self.backend.backend.shape(data_f)[-2] % 2 == 1,
|
|
3790
|
+
self.backend.backend.int32,
|
|
3791
|
+
)
|
|
3792
|
+
if_yodd = self.backend.backend.cast(
|
|
3793
|
+
self.backend.backend.shape(data_f)[-1] % 2 == 1,
|
|
3794
|
+
self.backend.backend.int32,
|
|
3795
|
+
)
|
|
4813
3796
|
|
|
4814
3797
|
# Extract four regions
|
|
4815
|
-
top_left = data_f[..., :dx+if_xodd, :dy+if_yodd]
|
|
4816
|
-
top_right = data_f[..., -dx:, :dy+if_yodd]
|
|
4817
|
-
bottom_left = data_f[..., :dx+if_xodd, -dy:]
|
|
3798
|
+
top_left = data_f[..., : dx + if_xodd, : dy + if_yodd]
|
|
3799
|
+
top_right = data_f[..., -dx:, : dy + if_yodd]
|
|
3800
|
+
bottom_left = data_f[..., : dx + if_xodd, -dy:]
|
|
4818
3801
|
bottom_right = data_f[..., -dx:, -dy:]
|
|
4819
3802
|
|
|
4820
3803
|
# Concatenate along the last two dimensions
|
|
@@ -4829,70 +3812,74 @@ class funct(FOC.FoCUS):
|
|
|
4829
3812
|
# utility functions for computing scattering coef and covariance
|
|
4830
3813
|
#
|
|
4831
3814
|
# ---------------------------------------------------------------------------
|
|
4832
|
-
|
|
4833
|
-
def get_dxdy(self, j,M,N):
|
|
4834
|
-
|
|
3815
|
+
|
|
3816
|
+
def get_dxdy(self, j, M, N):
|
|
3817
|
+
"""
|
|
4835
3818
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4836
3819
|
Done by Sihao Cheng and Rudy Morel.
|
|
4837
|
-
|
|
4838
|
-
dx = int(max(
|
|
4839
|
-
dy = int(max(
|
|
3820
|
+
"""
|
|
3821
|
+
dx = int(max(8, min(np.ceil(M / 2**j), M // 2)))
|
|
3822
|
+
dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
|
|
4840
3823
|
return dx, dy
|
|
4841
|
-
|
|
4842
|
-
|
|
4843
3824
|
|
|
4844
|
-
def get_edge_masks(self,M, N, J, d0=1):
|
|
4845
|
-
|
|
3825
|
+
def get_edge_masks(self, M, N, J, d0=1):
|
|
3826
|
+
"""
|
|
4846
3827
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4847
3828
|
Done by Sihao Cheng and Rudy Morel.
|
|
4848
|
-
|
|
3829
|
+
"""
|
|
4849
3830
|
edge_masks = np.empty((J, M, N))
|
|
4850
|
-
X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing=
|
|
3831
|
+
X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
|
|
4851
3832
|
for j in range(J):
|
|
4852
|
-
edge_dx = min(M//4, 2**j*d0)
|
|
4853
|
-
edge_dy = min(N//4, 2**j*d0)
|
|
4854
|
-
edge_masks[j] = (
|
|
4855
|
-
|
|
4856
|
-
|
|
3833
|
+
edge_dx = min(M // 4, 2**j * d0)
|
|
3834
|
+
edge_dy = min(N // 4, 2**j * d0)
|
|
3835
|
+
edge_masks[j] = (
|
|
3836
|
+
(X >= edge_dx)
|
|
3837
|
+
* (X <= M - edge_dx)
|
|
3838
|
+
* (Y >= edge_dy)
|
|
3839
|
+
* (Y <= N - edge_dy)
|
|
3840
|
+
)
|
|
3841
|
+
edge_masks = edge_masks[:, None, :, :]
|
|
3842
|
+
edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
|
|
4857
3843
|
return self.backend.bk_cast(edge_masks)
|
|
4858
|
-
|
|
3844
|
+
|
|
4859
3845
|
# ---------------------------------------------------------------------------
|
|
4860
3846
|
#
|
|
4861
3847
|
# scattering cov
|
|
4862
3848
|
#
|
|
4863
3849
|
# ---------------------------------------------------------------------------
|
|
4864
3850
|
def scattering_cov(
|
|
4865
|
-
self,
|
|
3851
|
+
self,
|
|
3852
|
+
data,
|
|
4866
3853
|
data2=None,
|
|
4867
3854
|
Jmax=None,
|
|
4868
|
-
if_large_batch=False,
|
|
4869
|
-
S4_criteria=None,
|
|
4870
|
-
use_ref=False,
|
|
4871
|
-
normalization=
|
|
3855
|
+
if_large_batch=False,
|
|
3856
|
+
S4_criteria=None,
|
|
3857
|
+
use_ref=False,
|
|
3858
|
+
normalization="S2",
|
|
4872
3859
|
edge=False,
|
|
4873
|
-
pseudo_coef=1,
|
|
4874
|
-
get_variance=False,
|
|
3860
|
+
pseudo_coef=1,
|
|
3861
|
+
get_variance=False,
|
|
4875
3862
|
ref_sigma=None,
|
|
4876
|
-
iso_ang=False
|
|
3863
|
+
iso_ang=False,
|
|
4877
3864
|
):
|
|
4878
|
-
|
|
3865
|
+
"""
|
|
4879
3866
|
Calculates the scattering correlations for a batch of images, including:
|
|
4880
|
-
|
|
3867
|
+
|
|
4881
3868
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4882
3869
|
Done by Sihao Cheng and Rudy Morel.
|
|
4883
|
-
|
|
4884
|
-
orig. x orig.:
|
|
3870
|
+
|
|
3871
|
+
orig. x orig.:
|
|
4885
3872
|
P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
|
|
4886
|
-
orig. x modulus:
|
|
3873
|
+
orig. x modulus:
|
|
4887
3874
|
C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
|
|
4888
3875
|
when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
|
|
4889
3876
|
when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
|
|
4890
|
-
modulus x modulus:
|
|
3877
|
+
modulus x modulus:
|
|
4891
3878
|
C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
|
|
4892
3879
|
C11 = C11_pre_norm / factor
|
|
4893
3880
|
when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
|
|
4894
3881
|
when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
|
|
4895
|
-
modulus x modulus (auto):
|
|
3882
|
+
modulus x modulus (auto):
|
|
4896
3883
|
P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
|
|
4897
3884
|
Parameters
|
|
4898
3885
|
----------
|
|
@@ -4902,7 +3889,7 @@ class funct(FOC.FoCUS):
|
|
|
4902
3889
|
It is recommended to use "False" unless one meets a memory issue
|
|
4903
3890
|
C11_criteria : str or None (=None)
|
|
4904
3891
|
Only C11 coefficients that satisfy this criteria will be computed.
|
|
4905
|
-
Any expressions of j1, j2, and j3 that can be evaluated as a Bool
|
|
3892
|
+
Any expressions of j1, j2, and j3 that can be evaluated as a Bool
|
|
4906
3893
|
is accepted.The default "None" corresponds to "j1 <= j2 <= j3".
|
|
4907
3894
|
use_ref : Bool (=False)
|
|
4908
3895
|
When normalizing, whether or not to use the normalization factor
|
|
@@ -4916,7 +3903,7 @@ class funct(FOC.FoCUS):
|
|
|
4916
3903
|
If true, the edge region with a width of rougly the size of the largest
|
|
4917
3904
|
wavelet involved is excluded when taking the global average to obtain
|
|
4918
3905
|
the scattering coefficients.
|
|
4919
|
-
|
|
3906
|
+
|
|
4920
3907
|
Returns
|
|
4921
3908
|
-------
|
|
4922
3909
|
'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
|
|
@@ -4934,34 +3921,34 @@ class funct(FOC.FoCUS):
|
|
|
4934
3921
|
j1 <= j3 are set to np.nan and not computed.
|
|
4935
3922
|
'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
|
|
4936
3923
|
'P11' averaged over l1 while keeping l2-l1 constant.
|
|
4937
|
-
|
|
3924
|
+
"""
|
|
4938
3925
|
if S4_criteria is None:
|
|
4939
|
-
S4_criteria =
|
|
4940
|
-
|
|
3926
|
+
S4_criteria = "j2>=j1"
|
|
3927
|
+
|
|
4941
3928
|
if self.all_bk_type == "float32":
|
|
4942
|
-
C_ONE=np.complex64(1.0)
|
|
3929
|
+
C_ONE = np.complex64(1.0)
|
|
4943
3930
|
else:
|
|
4944
|
-
C_ONE=np.complex128(1.0)
|
|
4945
|
-
|
|
3931
|
+
C_ONE = np.complex128(1.0)
|
|
3932
|
+
|
|
4946
3933
|
# determine jmax and nside corresponding to the input map
|
|
4947
3934
|
im_shape = data.shape
|
|
4948
3935
|
if self.use_2D:
|
|
4949
3936
|
if len(data.shape) == 2:
|
|
4950
3937
|
nside = np.min([im_shape[0], im_shape[1]])
|
|
4951
|
-
M,N = im_shape[0],im_shape[1]
|
|
4952
|
-
N_image =
|
|
3938
|
+
M, N = im_shape[0], im_shape[1]
|
|
3939
|
+
N_image = 1
|
|
4953
3940
|
N_image2 = 1
|
|
4954
3941
|
else:
|
|
4955
3942
|
nside = np.min([im_shape[1], im_shape[2]])
|
|
4956
|
-
M,N = im_shape[1],im_shape[2]
|
|
3943
|
+
M, N = im_shape[1], im_shape[2]
|
|
4957
3944
|
N_image = data.shape[0]
|
|
4958
3945
|
if data2 is not None:
|
|
4959
3946
|
N_image2 = data2.shape[0]
|
|
4960
|
-
J = int(np.log(nside) / np.log(2))-1 # Number of j scales
|
|
3947
|
+
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
|
|
4961
3948
|
elif self.use_1D:
|
|
4962
3949
|
if len(data.shape) == 2:
|
|
4963
3950
|
npix = int(im_shape[1]) # Number of pixels
|
|
4964
|
-
N_image =
|
|
3951
|
+
N_image = 1
|
|
4965
3952
|
N_image2 = 1
|
|
4966
3953
|
else:
|
|
4967
3954
|
npix = int(im_shape[0]) # Number of pixels
|
|
@@ -4971,12 +3958,12 @@ class funct(FOC.FoCUS):
|
|
|
4971
3958
|
|
|
4972
3959
|
nside = int(npix)
|
|
4973
3960
|
|
|
4974
|
-
J = int(np.log(nside) / np.log(2))-1 # Number of j scales
|
|
3961
|
+
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
|
|
4975
3962
|
else:
|
|
4976
3963
|
if len(data.shape) == 2:
|
|
4977
3964
|
npix = int(im_shape[1]) # Number of pixels
|
|
4978
|
-
N_image =
|
|
4979
|
-
N_image2 =
|
|
3965
|
+
N_image = 1
|
|
3966
|
+
N_image2 = 1
|
|
4980
3967
|
else:
|
|
4981
3968
|
npix = int(im_shape[0]) # Number of pixels
|
|
4982
3969
|
N_image = data.shape[0]
|
|
@@ -4986,984 +3973,1715 @@ class funct(FOC.FoCUS):
|
|
|
4986
3973
|
nside = int(np.sqrt(npix // 12))
|
|
4987
3974
|
|
|
4988
3975
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
4989
|
-
|
|
3976
|
+
|
|
4990
3977
|
if Jmax is None:
|
|
4991
3978
|
Jmax = J # Number of steps for the loop on scales
|
|
4992
|
-
if Jmax>J:
|
|
4993
|
-
print(
|
|
4994
|
-
print(
|
|
4995
|
-
|
|
4996
|
-
|
|
4997
|
-
|
|
4998
|
-
|
|
4999
|
-
|
|
5000
|
-
|
|
5001
|
-
|
|
5002
|
-
|
|
5003
|
-
|
|
5004
|
-
|
|
5005
|
-
|
|
5006
|
-
|
|
3979
|
+
if Jmax > J:
|
|
3980
|
+
print("==========\n\n")
|
|
3981
|
+
print(
|
|
3982
|
+
"The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
|
|
3983
|
+
)
|
|
3984
|
+
print("\n\n==========")
|
|
3985
|
+
|
|
3986
|
+
L = self.NORIENT
|
|
3987
|
+
norm_factor_S3 = 1.0
|
|
3988
|
+
|
|
3989
|
+
if self.backend.BACKEND == "torch":
|
|
3990
|
+
if (M, N, J, L) not in self.filters_set:
|
|
3991
|
+
self.filters_set[(M, N, J, L)] = self.computer_filter(
|
|
3992
|
+
M, N, J, L
|
|
3993
|
+
) # self.computer_filter(M,N,J,L)
|
|
3994
|
+
|
|
3995
|
+
filters_set = self.filters_set[(M, N, J, L)]
|
|
3996
|
+
|
|
3997
|
+
# weight = self.weight
|
|
5007
3998
|
if use_ref:
|
|
5008
|
-
if normalization==
|
|
3999
|
+
if normalization == "S2":
|
|
5009
4000
|
ref_S2 = self.ref_scattering_cov_S2
|
|
5010
|
-
else:
|
|
5011
|
-
ref_P11 = self.ref_scattering_cov[
|
|
4001
|
+
else:
|
|
4002
|
+
ref_P11 = self.ref_scattering_cov["P11"]
|
|
5012
4003
|
|
|
5013
4004
|
# convert numpy array input into self.backend.bk_ tensors
|
|
5014
4005
|
data = self.backend.bk_cast(data)
|
|
5015
|
-
data_f = self.backend.bk_fftn(data, dim=(-2
|
|
4006
|
+
data_f = self.backend.bk_fftn(data, dim=(-2, -1))
|
|
5016
4007
|
if data2 is not None:
|
|
5017
4008
|
data2 = self.backend.bk_cast(data2)
|
|
5018
|
-
data2_f = self.backend.bk_fftn(data2, dim=(-2
|
|
5019
|
-
|
|
4009
|
+
data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
|
|
4010
|
+
|
|
5020
4011
|
# initialize tensors for scattering coefficients
|
|
5021
|
-
S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5022
|
-
S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5023
|
-
|
|
5024
|
-
Ndata_S3 = J*(J+1)//2
|
|
5025
|
-
Ndata_S4 = J*(J+1)*(J+2)//6
|
|
5026
|
-
J_S4={}
|
|
5027
|
-
|
|
5028
|
-
S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
|
|
4012
|
+
S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4013
|
+
S1 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4014
|
+
|
|
4015
|
+
Ndata_S3 = J * (J + 1) // 2
|
|
4016
|
+
Ndata_S4 = J * (J + 1) * (J + 2) // 6
|
|
4017
|
+
J_S4 = {}
|
|
4018
|
+
|
|
4019
|
+
S3 = self.backend.bk_zeros((N_image, Ndata_S3, L, L), dtype=data_f.dtype)
|
|
5029
4020
|
if data2 is not None:
|
|
5030
|
-
S3p = self.backend.bk_zeros(
|
|
5031
|
-
|
|
5032
|
-
|
|
5033
|
-
|
|
4021
|
+
S3p = self.backend.bk_zeros(
|
|
4022
|
+
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
|
|
4023
|
+
)
|
|
4024
|
+
S4_pre_norm = self.backend.bk_zeros(
|
|
4025
|
+
(N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
|
|
4026
|
+
)
|
|
4027
|
+
S4 = self.backend.bk_zeros((N_image, Ndata_S4, L, L, L), dtype=data_f.dtype)
|
|
4028
|
+
|
|
5034
4029
|
# variance
|
|
5035
4030
|
if get_variance:
|
|
5036
|
-
S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5037
|
-
S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5038
|
-
S3_sigma = self.backend.bk_zeros(
|
|
4031
|
+
S2_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4032
|
+
S1_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4033
|
+
S3_sigma = self.backend.bk_zeros(
|
|
4034
|
+
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
|
|
4035
|
+
)
|
|
5039
4036
|
if data2 is not None:
|
|
5040
|
-
S3p_sigma = self.backend.bk_zeros(
|
|
5041
|
-
|
|
5042
|
-
|
|
4037
|
+
S3p_sigma = self.backend.bk_zeros(
|
|
4038
|
+
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
|
|
4039
|
+
)
|
|
4040
|
+
S4_sigma = self.backend.bk_zeros(
|
|
4041
|
+
(N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
|
|
4042
|
+
)
|
|
4043
|
+
|
|
5043
4044
|
if iso_ang:
|
|
5044
|
-
S3_iso = self.backend.bk_zeros(
|
|
5045
|
-
|
|
4045
|
+
S3_iso = self.backend.bk_zeros(
|
|
4046
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4047
|
+
)
|
|
4048
|
+
S4_iso = self.backend.bk_zeros(
|
|
4049
|
+
(N_image, Ndata_S4, L, L), dtype=data_f.dtype
|
|
4050
|
+
)
|
|
5046
4051
|
if get_variance:
|
|
5047
|
-
S3_sigma_iso = self.backend.bk_zeros(
|
|
5048
|
-
|
|
4052
|
+
S3_sigma_iso = self.backend.bk_zeros(
|
|
4053
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4054
|
+
)
|
|
4055
|
+
S4_sigma_iso = self.backend.bk_zeros(
|
|
4056
|
+
(N_image, Ndata_S4, L, L), dtype=data_f.dtype
|
|
4057
|
+
)
|
|
5049
4058
|
if data2 is not None:
|
|
5050
|
-
S3p_iso = self.backend.bk_zeros(
|
|
4059
|
+
S3p_iso = self.backend.bk_zeros(
|
|
4060
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4061
|
+
)
|
|
5051
4062
|
if get_variance:
|
|
5052
|
-
S3p_sigma_iso = self.backend.bk_zeros(
|
|
5053
|
-
|
|
4063
|
+
S3p_sigma_iso = self.backend.bk_zeros(
|
|
4064
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4065
|
+
)
|
|
4066
|
+
|
|
5054
4067
|
#
|
|
5055
|
-
if edge:
|
|
5056
|
-
if (M,N,J) not in self.edge_masks:
|
|
5057
|
-
self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
|
|
5058
|
-
edge_mask=self.edge_masks[(M,N,J)]
|
|
5059
|
-
else:
|
|
4068
|
+
if edge:
|
|
4069
|
+
if (M, N, J) not in self.edge_masks:
|
|
4070
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
|
|
4071
|
+
edge_mask = self.edge_masks[(M, N, J)]
|
|
4072
|
+
else:
|
|
5060
4073
|
edge_mask = 1
|
|
5061
|
-
|
|
4074
|
+
|
|
5062
4075
|
# calculate scattering fields
|
|
5063
4076
|
if data2 is None:
|
|
5064
4077
|
if self.use_2D:
|
|
5065
4078
|
if len(data.shape) == 2:
|
|
5066
4079
|
I1 = self.backend.bk_ifftn(
|
|
5067
|
-
data_f[None,None,None
|
|
4080
|
+
data_f[None, None, None, :, :]
|
|
4081
|
+
* filters_set[None, :J, :, :, :],
|
|
4082
|
+
dim=(-2, -1),
|
|
5068
4083
|
).abs()
|
|
5069
4084
|
else:
|
|
5070
4085
|
I1 = self.backend.bk_ifftn(
|
|
5071
|
-
data_f[:,None,None
|
|
4086
|
+
data_f[:, None, None, :, :]
|
|
4087
|
+
* filters_set[None, :J, :, :, :],
|
|
4088
|
+
dim=(-2, -1),
|
|
5072
4089
|
).abs()
|
|
5073
4090
|
elif self.use_1D:
|
|
5074
4091
|
if len(data.shape) == 1:
|
|
5075
4092
|
I1 = self.backend.bk_ifftn(
|
|
5076
|
-
data_f[None,None,None
|
|
4093
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4094
|
+
dim=(-1),
|
|
5077
4095
|
).abs()
|
|
5078
4096
|
else:
|
|
5079
4097
|
I1 = self.backend.bk_ifftn(
|
|
5080
|
-
data_f[:,None,None
|
|
4098
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4099
|
+
dim=(-1),
|
|
5081
4100
|
).abs()
|
|
5082
4101
|
else:
|
|
5083
|
-
print(
|
|
5084
|
-
|
|
5085
|
-
S2 = (I1**2 * edge_mask).mean((-2
|
|
5086
|
-
S1
|
|
4102
|
+
print("todo")
|
|
4103
|
+
|
|
4104
|
+
S2 = (I1**2 * edge_mask).mean((-2, -1))
|
|
4105
|
+
S1 = (I1 * edge_mask).mean((-2, -1))
|
|
5087
4106
|
|
|
5088
4107
|
if get_variance:
|
|
5089
|
-
S2_sigma = (I1**2 * edge_mask).std((-2
|
|
5090
|
-
S1_sigma
|
|
5091
|
-
|
|
4108
|
+
S2_sigma = (I1**2 * edge_mask).std((-2, -1))
|
|
4109
|
+
S1_sigma = (I1 * edge_mask).std((-2, -1))
|
|
4110
|
+
|
|
5092
4111
|
else:
|
|
5093
4112
|
if self.use_2D:
|
|
5094
4113
|
if len(data.shape) == 2:
|
|
5095
4114
|
I1 = self.backend.bk_ifftn(
|
|
5096
|
-
data_f[None,None,None
|
|
4115
|
+
data_f[None, None, None, :, :]
|
|
4116
|
+
* filters_set[None, :J, :, :, :],
|
|
4117
|
+
dim=(-2, -1),
|
|
5097
4118
|
)
|
|
5098
4119
|
I2 = self.backend.bk_ifftn(
|
|
5099
|
-
data2_f[None,None,None
|
|
4120
|
+
data2_f[None, None, None, :, :]
|
|
4121
|
+
* filters_set[None, :J, :, :, :],
|
|
4122
|
+
dim=(-2, -1),
|
|
5100
4123
|
)
|
|
5101
4124
|
else:
|
|
5102
4125
|
I1 = self.backend.bk_ifftn(
|
|
5103
|
-
data_f[:,None,None
|
|
4126
|
+
data_f[:, None, None, :, :]
|
|
4127
|
+
* filters_set[None, :J, :, :, :],
|
|
4128
|
+
dim=(-2, -1),
|
|
5104
4129
|
)
|
|
5105
4130
|
I2 = self.backend.bk_ifftn(
|
|
5106
|
-
data2_f[:,None,None
|
|
4131
|
+
data2_f[:, None, None, :, :]
|
|
4132
|
+
* filters_set[None, :J, :, :, :],
|
|
4133
|
+
dim=(-2, -1),
|
|
5107
4134
|
)
|
|
5108
4135
|
elif self.use_1D:
|
|
5109
4136
|
if len(data.shape) == 1:
|
|
5110
4137
|
I1 = self.backend.bk_ifftn(
|
|
5111
|
-
data_f[None,None,None
|
|
4138
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4139
|
+
dim=(-1),
|
|
5112
4140
|
)
|
|
5113
4141
|
I2 = self.backend.bk_ifftn(
|
|
5114
|
-
data2_f[None,None,None
|
|
4142
|
+
data2_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4143
|
+
dim=(-1),
|
|
5115
4144
|
)
|
|
5116
4145
|
else:
|
|
5117
4146
|
I1 = self.backend.bk_ifftn(
|
|
5118
|
-
data_f[:,None,None
|
|
4147
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4148
|
+
dim=(-1),
|
|
5119
4149
|
)
|
|
5120
4150
|
I2 = self.backend.bk_ifftn(
|
|
5121
|
-
data2_f[:,None,None
|
|
4151
|
+
data2_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4152
|
+
dim=(-1),
|
|
5122
4153
|
)
|
|
5123
4154
|
else:
|
|
5124
|
-
print(
|
|
5125
|
-
|
|
5126
|
-
I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
|
|
5127
|
-
|
|
5128
|
-
S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2
|
|
4155
|
+
print("todo")
|
|
4156
|
+
|
|
4157
|
+
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
|
|
4158
|
+
|
|
4159
|
+
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5129
4160
|
if get_variance:
|
|
5130
|
-
S2_sigma = self.backend.bk_reduce_std(
|
|
5131
|
-
|
|
5132
|
-
|
|
5133
|
-
|
|
5134
|
-
|
|
4161
|
+
S2_sigma = self.backend.bk_reduce_std(
|
|
4162
|
+
(I1 * edge_mask), axis=(-2, -1)
|
|
4163
|
+
)
|
|
4164
|
+
|
|
4165
|
+
I1 = self.backend.bk_L1(I1)
|
|
4166
|
+
|
|
4167
|
+
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5135
4168
|
|
|
5136
4169
|
if get_variance:
|
|
5137
|
-
S1_sigma
|
|
5138
|
-
|
|
5139
|
-
|
|
5140
|
-
|
|
4170
|
+
S1_sigma = self.backend.bk_reduce_std(
|
|
4171
|
+
(I1 * edge_mask), axis=(-2, -1)
|
|
4172
|
+
)
|
|
4173
|
+
|
|
4174
|
+
I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
|
|
4175
|
+
|
|
5141
4176
|
if pseudo_coef != 1:
|
|
5142
4177
|
I1 = I1**pseudo_coef
|
|
5143
|
-
|
|
5144
|
-
Ndata_S3=0
|
|
5145
|
-
Ndata_S4=0
|
|
5146
|
-
|
|
4178
|
+
|
|
4179
|
+
Ndata_S3 = 0
|
|
4180
|
+
Ndata_S4 = 0
|
|
4181
|
+
|
|
5147
4182
|
# calculate the covariance and correlations of the scattering fields
|
|
5148
4183
|
# only use the low-k Fourier coefs when calculating large-j scattering coefs.
|
|
5149
|
-
for j3 in range(0,J):
|
|
5150
|
-
J_S4[j3]=Ndata_S4
|
|
5151
|
-
|
|
5152
|
-
dx3, dy3 = self.get_dxdy(j3,M,N)
|
|
5153
|
-
I1_f_small = self.cut_high_k_off(
|
|
4184
|
+
for j3 in range(0, J):
|
|
4185
|
+
J_S4[j3] = Ndata_S4
|
|
4186
|
+
|
|
4187
|
+
dx3, dy3 = self.get_dxdy(j3, M, N)
|
|
4188
|
+
I1_f_small = self.cut_high_k_off(
|
|
4189
|
+
I1_f[:, : j3 + 1], dx3, dy3
|
|
4190
|
+
) # Nimage, J, L, x, y
|
|
5154
4191
|
data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
|
|
5155
4192
|
if data2 is not None:
|
|
5156
4193
|
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
|
|
5157
4194
|
if edge:
|
|
5158
|
-
I1_small = self.backend.bk_ifftn(
|
|
5159
|
-
|
|
4195
|
+
I1_small = self.backend.bk_ifftn(
|
|
4196
|
+
I1_f_small, dim=(-2, -1), norm="ortho"
|
|
4197
|
+
)
|
|
4198
|
+
data_small = self.backend.bk_ifftn(
|
|
4199
|
+
data_f_small, dim=(-2, -1), norm="ortho"
|
|
4200
|
+
)
|
|
5160
4201
|
if data2 is not None:
|
|
5161
|
-
data2_small = self.backend.bk_ifftn(
|
|
5162
|
-
|
|
5163
|
-
|
|
4202
|
+
data2_small = self.backend.bk_ifftn(
|
|
4203
|
+
data2_f_small, dim=(-2, -1), norm="ortho"
|
|
4204
|
+
)
|
|
4205
|
+
|
|
4206
|
+
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
5164
4207
|
_, M3, N3 = wavelet_f3.shape
|
|
5165
4208
|
wavelet_f3_squared = wavelet_f3**2
|
|
5166
|
-
edge_dx = min(4, int(2**j3*dx3*2/M))
|
|
5167
|
-
edge_dy = min(4, int(2**j3*dy3*2/N))
|
|
5168
|
-
|
|
4209
|
+
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
|
|
4210
|
+
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
|
|
4211
|
+
|
|
5169
4212
|
# a normalization change due to the cutoff of frequency space
|
|
5170
|
-
fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
|
|
5171
|
-
for j2 in range(0,j3+1):
|
|
5172
|
-
I1_f2_wf3_small = I1_f_small[:,j2].view(
|
|
5173
|
-
|
|
4213
|
+
fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
|
|
4214
|
+
for j2 in range(0, j3 + 1):
|
|
4215
|
+
I1_f2_wf3_small = I1_f_small[:, j2].view(
|
|
4216
|
+
N_image, L, 1, M3, N3
|
|
4217
|
+
) * wavelet_f3.view(1, 1, L, M3, N3)
|
|
4218
|
+
I1_f2_wf3_2_small = I1_f_small[:, j2].view(
|
|
4219
|
+
N_image, L, 1, M3, N3
|
|
4220
|
+
) * wavelet_f3_squared.view(1, 1, L, M3, N3)
|
|
5174
4221
|
if edge:
|
|
5175
|
-
I12_w3_small = self.backend.bk_ifftn(
|
|
5176
|
-
|
|
4222
|
+
I12_w3_small = self.backend.bk_ifftn(
|
|
4223
|
+
I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
|
|
4224
|
+
)
|
|
4225
|
+
I12_w3_2_small = self.backend.bk_ifftn(
|
|
4226
|
+
I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
|
|
4227
|
+
)
|
|
5177
4228
|
if use_ref:
|
|
5178
|
-
if normalization==
|
|
5179
|
-
norm_factor_S3 = (
|
|
5180
|
-
|
|
5181
|
-
|
|
4229
|
+
if normalization == "P11":
|
|
4230
|
+
norm_factor_S3 = (
|
|
4231
|
+
ref_S2[:, None, j3, :]
|
|
4232
|
+
* ref_P11[:, j2, j3, :, :] ** pseudo_coef
|
|
4233
|
+
) ** 0.5
|
|
4234
|
+
if normalization == "S2":
|
|
4235
|
+
norm_factor_S3 = (
|
|
4236
|
+
ref_S2[:, None, j3, :]
|
|
4237
|
+
* ref_S2[:, j2, :, None] ** pseudo_coef
|
|
4238
|
+
) ** 0.5
|
|
5182
4239
|
else:
|
|
5183
|
-
if normalization==
|
|
4240
|
+
if normalization == "P11":
|
|
5184
4241
|
# [N_image,l2,l3,x,y]
|
|
5185
|
-
P11_temp = (I1_f2_wf3_small.abs()**2).mean(
|
|
5186
|
-
|
|
5187
|
-
|
|
5188
|
-
norm_factor_S3 = (
|
|
4242
|
+
P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
|
|
4243
|
+
(-2, -1)
|
|
4244
|
+
) * fft_factor
|
|
4245
|
+
norm_factor_S3 = (
|
|
4246
|
+
S2[:, None, j3, :] * P11_temp**pseudo_coef
|
|
4247
|
+
) ** 0.5
|
|
4248
|
+
if normalization == "S2":
|
|
4249
|
+
norm_factor_S3 = (
|
|
4250
|
+
S2[:, None, j3, :] * S2[:, j2, :, None] ** pseudo_coef
|
|
4251
|
+
) ** 0.5
|
|
5189
4252
|
|
|
5190
4253
|
if not edge:
|
|
5191
|
-
S3[:,Ndata_S3
|
|
5192
|
-
|
|
5193
|
-
|
|
5194
|
-
|
|
4254
|
+
S3[:, Ndata_S3, :, :] = (
|
|
4255
|
+
(
|
|
4256
|
+
data_f_small.view(N_image, 1, 1, M3, N3)
|
|
4257
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4258
|
+
).mean((-2, -1))
|
|
4259
|
+
* fft_factor
|
|
4260
|
+
/ norm_factor_S3
|
|
4261
|
+
)
|
|
4262
|
+
|
|
5195
4263
|
if get_variance:
|
|
5196
|
-
S3_sigma[:,Ndata_S3
|
|
5197
|
-
|
|
5198
|
-
|
|
4264
|
+
S3_sigma[:, Ndata_S3, :, :] = (
|
|
4265
|
+
(
|
|
4266
|
+
data_f_small.view(N_image, 1, 1, M3, N3)
|
|
4267
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4268
|
+
).std((-2, -1))
|
|
4269
|
+
* fft_factor
|
|
4270
|
+
/ norm_factor_S3
|
|
4271
|
+
)
|
|
5199
4272
|
else:
|
|
5200
|
-
S3[:,Ndata_S3
|
|
5201
|
-
|
|
5202
|
-
|
|
4273
|
+
S3[:, Ndata_S3, :, :] = (
|
|
4274
|
+
(
|
|
4275
|
+
data_small.view(N_image, 1, 1, M3, N3)
|
|
4276
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4277
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy].mean(
|
|
4278
|
+
(-2, -1)
|
|
4279
|
+
)
|
|
4280
|
+
* fft_factor
|
|
4281
|
+
/ norm_factor_S3
|
|
4282
|
+
)
|
|
5203
4283
|
if get_variance:
|
|
5204
|
-
S3_sigma[:,Ndata_S3
|
|
5205
|
-
|
|
5206
|
-
|
|
4284
|
+
S3_sigma[:, Ndata_S3, :, :] = (
|
|
4285
|
+
(
|
|
4286
|
+
data_small.view(N_image, 1, 1, M3, N3)
|
|
4287
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4288
|
+
)[
|
|
4289
|
+
..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
|
|
4290
|
+
].std(
|
|
4291
|
+
(-2, -1)
|
|
4292
|
+
)
|
|
4293
|
+
* fft_factor
|
|
4294
|
+
/ norm_factor_S3
|
|
4295
|
+
)
|
|
5207
4296
|
if data2 is not None:
|
|
5208
4297
|
if not edge:
|
|
5209
|
-
S3p[:,Ndata_S3
|
|
5210
|
-
|
|
5211
|
-
|
|
5212
|
-
|
|
4298
|
+
S3p[:, Ndata_S3, :, :] = (
|
|
4299
|
+
(
|
|
4300
|
+
data2_f_small.view(N_image2, 1, 1, M3, N3)
|
|
4301
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4302
|
+
).mean((-2, -1))
|
|
4303
|
+
* fft_factor
|
|
4304
|
+
/ norm_factor_S3
|
|
4305
|
+
)
|
|
4306
|
+
|
|
5213
4307
|
if get_variance:
|
|
5214
|
-
S3p_sigma[:,Ndata_S3
|
|
5215
|
-
|
|
5216
|
-
|
|
4308
|
+
S3p_sigma[:, Ndata_S3, :, :] = (
|
|
4309
|
+
(
|
|
4310
|
+
data2_f_small.view(N_image2, 1, 1, M3, N3)
|
|
4311
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4312
|
+
).std((-2, -1))
|
|
4313
|
+
* fft_factor
|
|
4314
|
+
/ norm_factor_S3
|
|
4315
|
+
)
|
|
5217
4316
|
else:
|
|
5218
|
-
S3p[:,Ndata_S3
|
|
5219
|
-
|
|
5220
|
-
|
|
4317
|
+
S3p[:, Ndata_S3, :, :] = (
|
|
4318
|
+
(
|
|
4319
|
+
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4320
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4321
|
+
)[
|
|
4322
|
+
..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
|
|
4323
|
+
].mean(
|
|
4324
|
+
(-2, -1)
|
|
4325
|
+
)
|
|
4326
|
+
* fft_factor
|
|
4327
|
+
/ norm_factor_S3
|
|
4328
|
+
)
|
|
5221
4329
|
if get_variance:
|
|
5222
|
-
S3p_sigma[:,Ndata_S3
|
|
5223
|
-
|
|
5224
|
-
|
|
5225
|
-
|
|
4330
|
+
S3p_sigma[:, Ndata_S3, :, :] = (
|
|
4331
|
+
(
|
|
4332
|
+
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4333
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4334
|
+
)[
|
|
4335
|
+
...,
|
|
4336
|
+
edge_dx : M3 - edge_dx,
|
|
4337
|
+
edge_dy : N3 - edge_dy,
|
|
4338
|
+
].std(
|
|
4339
|
+
(-2, -1)
|
|
4340
|
+
)
|
|
4341
|
+
* fft_factor
|
|
4342
|
+
/ norm_factor_S3
|
|
4343
|
+
)
|
|
4344
|
+
Ndata_S3 += 1
|
|
5226
4345
|
if j2 <= j3:
|
|
5227
|
-
beg_n=Ndata_S4
|
|
5228
|
-
for j1 in range(0, j2+1):
|
|
4346
|
+
beg_n = Ndata_S4
|
|
4347
|
+
for j1 in range(0, j2 + 1):
|
|
5229
4348
|
if eval(S4_criteria):
|
|
5230
4349
|
if not edge:
|
|
5231
4350
|
if not if_large_batch:
|
|
5232
4351
|
# [N_image,l1,l2,l3,x,y]
|
|
5233
|
-
S4_pre_norm[:,Ndata_S4
|
|
5234
|
-
I1_f_small[:,j1].view(
|
|
5235
|
-
|
|
5236
|
-
|
|
4352
|
+
S4_pre_norm[:, Ndata_S4, :, :, :] = (
|
|
4353
|
+
I1_f_small[:, j1].view(
|
|
4354
|
+
N_image, L, 1, 1, M3, N3
|
|
4355
|
+
)
|
|
4356
|
+
* self.backend.bk_conjugate(
|
|
4357
|
+
I1_f2_wf3_2_small.view(
|
|
4358
|
+
N_image, 1, L, L, M3, N3
|
|
4359
|
+
)
|
|
4360
|
+
)
|
|
4361
|
+
).mean((-2, -1)) * fft_factor
|
|
5237
4362
|
if get_variance:
|
|
5238
|
-
S4_sigma[:,Ndata_S4
|
|
5239
|
-
I1_f_small[:,j1].view(
|
|
5240
|
-
|
|
5241
|
-
|
|
4363
|
+
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4364
|
+
I1_f_small[:, j1].view(
|
|
4365
|
+
N_image, L, 1, 1, M3, N3
|
|
4366
|
+
)
|
|
4367
|
+
* self.backend.bk_conjugate(
|
|
4368
|
+
I1_f2_wf3_2_small.view(
|
|
4369
|
+
N_image, 1, L, L, M3, N3
|
|
4370
|
+
)
|
|
4371
|
+
)
|
|
4372
|
+
).std((-2, -1)) * fft_factor
|
|
5242
4373
|
else:
|
|
5243
4374
|
for l1 in range(L):
|
|
5244
4375
|
# [N_image,l2,l3,x,y]
|
|
5245
|
-
S4_pre_norm[:,Ndata_S4,l1
|
|
5246
|
-
I1_f_small[:,j1,l1].view(
|
|
5247
|
-
|
|
5248
|
-
|
|
4376
|
+
S4_pre_norm[:, Ndata_S4, l1, :, :] = (
|
|
4377
|
+
I1_f_small[:, j1, l1].view(
|
|
4378
|
+
N_image, 1, 1, M3, N3
|
|
4379
|
+
)
|
|
4380
|
+
* self.backend.bk_conjugate(
|
|
4381
|
+
I1_f2_wf3_2_small.view(
|
|
4382
|
+
N_image, L, L, M3, N3
|
|
4383
|
+
)
|
|
4384
|
+
)
|
|
4385
|
+
).mean((-2, -1)) * fft_factor
|
|
5249
4386
|
if get_variance:
|
|
5250
|
-
S4_sigma[:,Ndata_S4,l1
|
|
5251
|
-
I1_f_small[:,j1,l1].view(
|
|
5252
|
-
|
|
5253
|
-
|
|
4387
|
+
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4388
|
+
I1_f_small[:, j1, l1].view(
|
|
4389
|
+
N_image, 1, 1, M3, N3
|
|
4390
|
+
)
|
|
4391
|
+
* self.backend.bk_conjugate(
|
|
4392
|
+
I1_f2_wf3_2_small.view(
|
|
4393
|
+
N_image, L, L, M3, N3
|
|
4394
|
+
)
|
|
4395
|
+
)
|
|
4396
|
+
).std((-2, -1)) * fft_factor
|
|
5254
4397
|
else:
|
|
5255
4398
|
if not if_large_batch:
|
|
5256
4399
|
# [N_image,l1,l2,l3,x,y]
|
|
5257
|
-
S4_pre_norm[:,Ndata_S4
|
|
5258
|
-
I1_small[:,j1].view(
|
|
5259
|
-
|
|
4400
|
+
S4_pre_norm[:, Ndata_S4, :, :, :] = (
|
|
4401
|
+
I1_small[:, j1].view(
|
|
4402
|
+
N_image, L, 1, 1, M3, N3
|
|
4403
|
+
)
|
|
4404
|
+
* self.backend.bk_conjugate(
|
|
4405
|
+
I12_w3_2_small.view(
|
|
4406
|
+
N_image, 1, L, L, M3, N3
|
|
4407
|
+
)
|
|
5260
4408
|
)
|
|
5261
|
-
)[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
|
|
4409
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
|
|
4410
|
+
(-2, -1)
|
|
4411
|
+
) * fft_factor
|
|
5262
4412
|
if get_variance:
|
|
5263
|
-
S4_sigma[:,Ndata_S4
|
|
5264
|
-
I1_small[:,j1].view(
|
|
5265
|
-
|
|
4413
|
+
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4414
|
+
I1_small[:, j1].view(
|
|
4415
|
+
N_image, L, 1, 1, M3, N3
|
|
4416
|
+
)
|
|
4417
|
+
* self.backend.bk_conjugate(
|
|
4418
|
+
I12_w3_2_small.view(
|
|
4419
|
+
N_image, 1, L, L, M3, N3
|
|
4420
|
+
)
|
|
5266
4421
|
)
|
|
5267
|
-
)[
|
|
4422
|
+
)[
|
|
4423
|
+
..., edge_dx:-edge_dx, edge_dy:-edge_dy
|
|
4424
|
+
].std(
|
|
4425
|
+
(-2, -1)
|
|
4426
|
+
) * fft_factor
|
|
5268
4427
|
else:
|
|
5269
4428
|
for l1 in range(L):
|
|
5270
|
-
|
|
5271
|
-
S4_pre_norm[:,Ndata_S4,l1
|
|
5272
|
-
I1_small[:,j1].view(
|
|
5273
|
-
|
|
4429
|
+
# [N_image,l2,l3,x,y]
|
|
4430
|
+
S4_pre_norm[:, Ndata_S4, l1, :, :] = (
|
|
4431
|
+
I1_small[:, j1].view(
|
|
4432
|
+
N_image, 1, 1, M3, N3
|
|
5274
4433
|
)
|
|
5275
|
-
|
|
4434
|
+
* self.backend.bk_conjugate(
|
|
4435
|
+
I12_w3_2_small.view(
|
|
4436
|
+
N_image, L, L, M3, N3
|
|
4437
|
+
)
|
|
4438
|
+
)
|
|
4439
|
+
)[
|
|
4440
|
+
..., edge_dx:-edge_dx, edge_dy:-edge_dy
|
|
4441
|
+
].mean(
|
|
4442
|
+
(-2, -1)
|
|
4443
|
+
) * fft_factor
|
|
5276
4444
|
if get_variance:
|
|
5277
|
-
S4_sigma[:,Ndata_S4,l1
|
|
5278
|
-
I1_small[:,j1].view(
|
|
5279
|
-
|
|
4445
|
+
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4446
|
+
I1_small[:, j1].view(
|
|
4447
|
+
N_image, 1, 1, M3, N3
|
|
5280
4448
|
)
|
|
5281
|
-
|
|
5282
|
-
|
|
5283
|
-
|
|
5284
|
-
|
|
5285
|
-
|
|
5286
|
-
|
|
5287
|
-
|
|
5288
|
-
|
|
5289
|
-
|
|
5290
|
-
|
|
5291
|
-
|
|
5292
|
-
|
|
4449
|
+
* self.backend.bk_conjugate(
|
|
4450
|
+
I12_w3_2_small.view(
|
|
4451
|
+
N_image, L, L, M3, N3
|
|
4452
|
+
)
|
|
4453
|
+
)
|
|
4454
|
+
)[
|
|
4455
|
+
...,
|
|
4456
|
+
edge_dx:-edge_dx,
|
|
4457
|
+
edge_dy:-edge_dy,
|
|
4458
|
+
].mean(
|
|
4459
|
+
(-2, -1)
|
|
4460
|
+
) * fft_factor
|
|
4461
|
+
|
|
4462
|
+
Ndata_S4 += 1
|
|
4463
|
+
|
|
4464
|
+
if normalization == "S2":
|
|
4465
|
+
if use_ref:
|
|
4466
|
+
P = (
|
|
4467
|
+
ref_S2[:, j3 : j3 + 1, :, None, None]
|
|
4468
|
+
* ref_S2[:, j2 : j2 + 1, None, :, None]
|
|
4469
|
+
) ** (0.5 * pseudo_coef)
|
|
4470
|
+
else:
|
|
4471
|
+
P = (
|
|
4472
|
+
S2[:, j3 : j3 + 1, :, None, None]
|
|
4473
|
+
* S2[:, j2 : j2 + 1, None, :, None]
|
|
4474
|
+
) ** (0.5 * pseudo_coef)
|
|
4475
|
+
|
|
4476
|
+
S4[:, beg_n:Ndata_S4, :, :, :] = (
|
|
4477
|
+
S4_pre_norm[:, beg_n:Ndata_S4, :, :, :].clone() / P
|
|
4478
|
+
)
|
|
4479
|
+
|
|
5293
4480
|
if get_variance:
|
|
5294
|
-
S4_sigma[:,beg_n:Ndata_S4
|
|
4481
|
+
S4_sigma[:, beg_n:Ndata_S4, :, :, :] = (
|
|
4482
|
+
S4_sigma[:, beg_n:Ndata_S4, :, :, :] / P
|
|
4483
|
+
)
|
|
5295
4484
|
else:
|
|
5296
|
-
S4=S4_pre_norm
|
|
5297
|
-
|
|
4485
|
+
S4 = S4_pre_norm
|
|
4486
|
+
|
|
5298
4487
|
# average over l1 to obtain simple isotropic statistics
|
|
5299
4488
|
if iso_ang:
|
|
5300
4489
|
S2_iso = S2.mean(-1)
|
|
5301
4490
|
S1_iso = S1.mean(-1)
|
|
5302
4491
|
for l1 in range(L):
|
|
5303
4492
|
for l2 in range(L):
|
|
5304
|
-
S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
|
|
4493
|
+
S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
|
|
5305
4494
|
if data2 is not None:
|
|
5306
|
-
S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
|
|
4495
|
+
S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
|
|
5307
4496
|
for l3 in range(L):
|
|
5308
|
-
S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[
|
|
5309
|
-
|
|
4497
|
+
S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[
|
|
4498
|
+
..., l1, l2, l3
|
|
4499
|
+
]
|
|
4500
|
+
S3_iso /= L
|
|
4501
|
+
S4_iso /= L
|
|
5310
4502
|
if data2 is not None:
|
|
5311
4503
|
S3p_iso /= L
|
|
5312
|
-
|
|
4504
|
+
|
|
5313
4505
|
if get_variance:
|
|
5314
4506
|
S2_sigma_iso = S2_sigma.mean(-1)
|
|
5315
4507
|
S1_sigma_iso = S1_sigma.mean(-1)
|
|
5316
4508
|
for l1 in range(L):
|
|
5317
4509
|
for l2 in range(L):
|
|
5318
|
-
S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
|
|
4510
|
+
S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
|
|
5319
4511
|
if data2 is not None:
|
|
5320
|
-
S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[
|
|
4512
|
+
S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[
|
|
4513
|
+
..., l1, l2
|
|
4514
|
+
]
|
|
5321
4515
|
for l3 in range(L):
|
|
5322
|
-
S4_sigma_iso[
|
|
5323
|
-
|
|
4516
|
+
S4_sigma_iso[
|
|
4517
|
+
..., (l2 - l1) % L, (l3 - l1) % L
|
|
4518
|
+
] += S4_sigma[..., l1, l2, l3]
|
|
4519
|
+
S3_sigma_iso /= L
|
|
4520
|
+
S4_sigma_iso /= L
|
|
5324
4521
|
if data2 is not None:
|
|
5325
4522
|
S3p_sigma_iso /= L
|
|
5326
|
-
|
|
5327
|
-
mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
|
|
5328
|
-
std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
|
|
5329
|
-
|
|
4523
|
+
|
|
4524
|
+
mean_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
|
|
4525
|
+
std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
|
|
4526
|
+
|
|
5330
4527
|
if data2 is None:
|
|
5331
|
-
mean_data[:,0]=data.mean((-2
|
|
5332
|
-
std_data[:,0]=data.std((-2
|
|
4528
|
+
mean_data[:, 0] = data.mean((-2, -1))
|
|
4529
|
+
std_data[:, 0] = data.std((-2, -1))
|
|
5333
4530
|
else:
|
|
5334
|
-
mean_data[:,0]=(data2*data).mean((-2
|
|
5335
|
-
std_data[:,0]=(data2*data).std((-2
|
|
5336
|
-
|
|
4531
|
+
mean_data[:, 0] = (data2 * data).mean((-2, -1))
|
|
4532
|
+
std_data[:, 0] = (data2 * data).std((-2, -1))
|
|
4533
|
+
|
|
5337
4534
|
if get_variance:
|
|
5338
|
-
ref_sigma={}
|
|
4535
|
+
ref_sigma = {}
|
|
5339
4536
|
if iso_ang:
|
|
5340
|
-
ref_sigma[
|
|
5341
|
-
ref_sigma[
|
|
5342
|
-
ref_sigma[
|
|
5343
|
-
ref_sigma[
|
|
4537
|
+
ref_sigma["std_data"] = std_data
|
|
4538
|
+
ref_sigma["S1_sigma"] = S1_sigma_iso
|
|
4539
|
+
ref_sigma["S2_sigma"] = S2_sigma_iso
|
|
4540
|
+
ref_sigma["S3_sigma"] = S3_sigma_iso
|
|
5344
4541
|
if data2 is not None:
|
|
5345
|
-
ref_sigma[
|
|
5346
|
-
ref_sigma[
|
|
4542
|
+
ref_sigma["S3p_sigma"] = S3p_sigma_iso
|
|
4543
|
+
ref_sigma["S4_sigma"] = S4_sigma_iso
|
|
5347
4544
|
else:
|
|
5348
|
-
ref_sigma[
|
|
5349
|
-
ref_sigma[
|
|
5350
|
-
ref_sigma[
|
|
5351
|
-
ref_sigma[
|
|
4545
|
+
ref_sigma["std_data"] = std_data
|
|
4546
|
+
ref_sigma["S1_sigma"] = S1_sigma
|
|
4547
|
+
ref_sigma["S2_sigma"] = S2_sigma
|
|
4548
|
+
ref_sigma["S3_sigma"] = S3_sigma
|
|
5352
4549
|
if data2 is not None:
|
|
5353
|
-
ref_sigma[
|
|
5354
|
-
ref_sigma[
|
|
5355
|
-
|
|
4550
|
+
ref_sigma["S3p_sigma"] = S3p_sigma
|
|
4551
|
+
ref_sigma["S4_sigma"] = S4_sigma
|
|
4552
|
+
|
|
5356
4553
|
if data2 is None:
|
|
5357
4554
|
if iso_ang:
|
|
5358
4555
|
if ref_sigma is not None:
|
|
5359
|
-
for_synthesis = self.backend.backend.cat(
|
|
5360
|
-
|
|
5361
|
-
|
|
5362
|
-
|
|
5363
|
-
|
|
5364
|
-
|
|
5365
|
-
|
|
5366
|
-
|
|
5367
|
-
|
|
5368
|
-
|
|
4556
|
+
for_synthesis = self.backend.backend.cat(
|
|
4557
|
+
(
|
|
4558
|
+
mean_data / ref_sigma["std_data"],
|
|
4559
|
+
std_data / ref_sigma["std_data"],
|
|
4560
|
+
(S2_iso / ref_sigma["S2_sigma"])
|
|
4561
|
+
.reshape((N_image, -1))
|
|
4562
|
+
.log(),
|
|
4563
|
+
(S1_iso / ref_sigma["S1_sigma"])
|
|
4564
|
+
.reshape((N_image, -1))
|
|
4565
|
+
.log(),
|
|
4566
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4567
|
+
.reshape((N_image, -1))
|
|
4568
|
+
.real,
|
|
4569
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4570
|
+
.reshape((N_image, -1))
|
|
4571
|
+
.imag,
|
|
4572
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4573
|
+
.reshape((N_image, -1))
|
|
4574
|
+
.real,
|
|
4575
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4576
|
+
.reshape((N_image, -1))
|
|
4577
|
+
.imag,
|
|
4578
|
+
),
|
|
4579
|
+
dim=-1,
|
|
4580
|
+
)
|
|
5369
4581
|
else:
|
|
5370
|
-
for_synthesis = self.backend.backend.cat(
|
|
5371
|
-
|
|
5372
|
-
|
|
5373
|
-
|
|
5374
|
-
|
|
5375
|
-
|
|
5376
|
-
|
|
5377
|
-
|
|
5378
|
-
|
|
5379
|
-
|
|
4582
|
+
for_synthesis = self.backend.backend.cat(
|
|
4583
|
+
(
|
|
4584
|
+
mean_data / std_data,
|
|
4585
|
+
std_data,
|
|
4586
|
+
S2_iso.reshape((N_image, -1)).log(),
|
|
4587
|
+
S1_iso.reshape((N_image, -1)).log(),
|
|
4588
|
+
S3_iso.reshape((N_image, -1)).real,
|
|
4589
|
+
S3_iso.reshape((N_image, -1)).imag,
|
|
4590
|
+
S4_iso.reshape((N_image, -1)).real,
|
|
4591
|
+
S4_iso.reshape((N_image, -1)).imag,
|
|
4592
|
+
),
|
|
4593
|
+
dim=-1,
|
|
4594
|
+
)
|
|
5380
4595
|
else:
|
|
5381
4596
|
if ref_sigma is not None:
|
|
5382
|
-
for_synthesis = self.backend.backend.cat(
|
|
5383
|
-
|
|
5384
|
-
|
|
5385
|
-
|
|
5386
|
-
|
|
5387
|
-
|
|
5388
|
-
|
|
5389
|
-
|
|
5390
|
-
|
|
5391
|
-
|
|
4597
|
+
for_synthesis = self.backend.backend.cat(
|
|
4598
|
+
(
|
|
4599
|
+
mean_data / ref_sigma["std_data"],
|
|
4600
|
+
std_data / ref_sigma["std_data"],
|
|
4601
|
+
(S2 / ref_sigma["S2_sigma"])
|
|
4602
|
+
.reshape((N_image, -1))
|
|
4603
|
+
.log(),
|
|
4604
|
+
(S1 / ref_sigma["S1_sigma"])
|
|
4605
|
+
.reshape((N_image, -1))
|
|
4606
|
+
.log(),
|
|
4607
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4608
|
+
.reshape((N_image, -1))
|
|
4609
|
+
.real,
|
|
4610
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4611
|
+
.reshape((N_image, -1))
|
|
4612
|
+
.imag,
|
|
4613
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4614
|
+
.reshape((N_image, -1))
|
|
4615
|
+
.real,
|
|
4616
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4617
|
+
.reshape((N_image, -1))
|
|
4618
|
+
.imag,
|
|
4619
|
+
),
|
|
4620
|
+
dim=-1,
|
|
4621
|
+
)
|
|
5392
4622
|
else:
|
|
5393
|
-
for_synthesis = self.backend.backend.cat(
|
|
5394
|
-
|
|
5395
|
-
|
|
5396
|
-
|
|
5397
|
-
|
|
5398
|
-
|
|
5399
|
-
|
|
5400
|
-
|
|
5401
|
-
|
|
5402
|
-
|
|
4623
|
+
for_synthesis = self.backend.backend.cat(
|
|
4624
|
+
(
|
|
4625
|
+
mean_data / std_data,
|
|
4626
|
+
std_data,
|
|
4627
|
+
S2.reshape((N_image, -1)).log(),
|
|
4628
|
+
S1.reshape((N_image, -1)).log(),
|
|
4629
|
+
S3.reshape((N_image, -1)).real,
|
|
4630
|
+
S3.reshape((N_image, -1)).imag,
|
|
4631
|
+
S4.reshape((N_image, -1)).real,
|
|
4632
|
+
S4.reshape((N_image, -1)).imag,
|
|
4633
|
+
),
|
|
4634
|
+
dim=-1,
|
|
4635
|
+
)
|
|
5403
4636
|
else:
|
|
5404
4637
|
if iso_ang:
|
|
5405
4638
|
if ref_sigma is not None:
|
|
5406
|
-
for_synthesis = self.backend.backend.cat(
|
|
5407
|
-
|
|
5408
|
-
|
|
5409
|
-
|
|
5410
|
-
|
|
5411
|
-
|
|
5412
|
-
|
|
5413
|
-
|
|
5414
|
-
|
|
5415
|
-
|
|
5416
|
-
|
|
5417
|
-
|
|
4639
|
+
for_synthesis = self.backend.backend.cat(
|
|
4640
|
+
(
|
|
4641
|
+
mean_data / ref_sigma["std_data"],
|
|
4642
|
+
std_data / ref_sigma["std_data"],
|
|
4643
|
+
(S2_iso / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
|
|
4644
|
+
(S1_iso / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
|
|
4645
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4646
|
+
.reshape((N_image, -1))
|
|
4647
|
+
.real,
|
|
4648
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4649
|
+
.reshape((N_image, -1))
|
|
4650
|
+
.imag,
|
|
4651
|
+
(S3p_iso / ref_sigma["S3p_sigma"])
|
|
4652
|
+
.reshape((N_image, -1))
|
|
4653
|
+
.real,
|
|
4654
|
+
(S3p_iso / ref_sigma["S3p_sigma"])
|
|
4655
|
+
.reshape((N_image, -1))
|
|
4656
|
+
.imag,
|
|
4657
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4658
|
+
.reshape((N_image, -1))
|
|
4659
|
+
.real,
|
|
4660
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4661
|
+
.reshape((N_image, -1))
|
|
4662
|
+
.imag,
|
|
4663
|
+
),
|
|
4664
|
+
dim=-1,
|
|
4665
|
+
)
|
|
5418
4666
|
else:
|
|
5419
|
-
for_synthesis = self.backend.backend.cat(
|
|
5420
|
-
|
|
5421
|
-
|
|
5422
|
-
|
|
5423
|
-
|
|
5424
|
-
|
|
5425
|
-
|
|
5426
|
-
|
|
5427
|
-
|
|
5428
|
-
|
|
5429
|
-
|
|
5430
|
-
|
|
4667
|
+
for_synthesis = self.backend.backend.cat(
|
|
4668
|
+
(
|
|
4669
|
+
mean_data / std_data,
|
|
4670
|
+
std_data,
|
|
4671
|
+
S2_iso.reshape((N_image, -1)),
|
|
4672
|
+
S1_iso.reshape((N_image, -1)),
|
|
4673
|
+
S3_iso.reshape((N_image, -1)).real,
|
|
4674
|
+
S3_iso.reshape((N_image, -1)).imag,
|
|
4675
|
+
S3p_iso.reshape((N_image, -1)).real,
|
|
4676
|
+
S3p_iso.reshape((N_image, -1)).imag,
|
|
4677
|
+
S4_iso.reshape((N_image, -1)).real,
|
|
4678
|
+
S4_iso.reshape((N_image, -1)).imag,
|
|
4679
|
+
),
|
|
4680
|
+
dim=-1,
|
|
4681
|
+
)
|
|
5431
4682
|
else:
|
|
5432
4683
|
if ref_sigma is not None:
|
|
5433
|
-
for_synthesis = self.backend.backend.cat(
|
|
5434
|
-
|
|
5435
|
-
|
|
5436
|
-
|
|
5437
|
-
|
|
5438
|
-
|
|
5439
|
-
|
|
5440
|
-
|
|
5441
|
-
|
|
5442
|
-
|
|
5443
|
-
|
|
5444
|
-
|
|
4684
|
+
for_synthesis = self.backend.backend.cat(
|
|
4685
|
+
(
|
|
4686
|
+
mean_data / ref_sigma["std_data"],
|
|
4687
|
+
std_data / ref_sigma["std_data"],
|
|
4688
|
+
(S2 / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
|
|
4689
|
+
(S1 / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
|
|
4690
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4691
|
+
.reshape((N_image, -1))
|
|
4692
|
+
.real,
|
|
4693
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4694
|
+
.reshape((N_image, -1))
|
|
4695
|
+
.imag,
|
|
4696
|
+
(S3p / ref_sigma["S3p_sigma"])
|
|
4697
|
+
.reshape((N_image, -1))
|
|
4698
|
+
.real,
|
|
4699
|
+
(S3p / ref_sigma["S3p_sigma"])
|
|
4700
|
+
.reshape((N_image, -1))
|
|
4701
|
+
.imag,
|
|
4702
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4703
|
+
.reshape((N_image, -1))
|
|
4704
|
+
.real,
|
|
4705
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4706
|
+
.reshape((N_image, -1))
|
|
4707
|
+
.imag,
|
|
4708
|
+
),
|
|
4709
|
+
dim=-1,
|
|
4710
|
+
)
|
|
5445
4711
|
else:
|
|
5446
|
-
for_synthesis = self.backend.backend.cat(
|
|
5447
|
-
|
|
5448
|
-
|
|
5449
|
-
|
|
5450
|
-
|
|
5451
|
-
|
|
5452
|
-
|
|
5453
|
-
|
|
5454
|
-
|
|
5455
|
-
|
|
5456
|
-
|
|
5457
|
-
|
|
5458
|
-
|
|
5459
|
-
|
|
5460
|
-
|
|
5461
|
-
|
|
4712
|
+
for_synthesis = self.backend.backend.cat(
|
|
4713
|
+
(
|
|
4714
|
+
mean_data / std_data,
|
|
4715
|
+
std_data,
|
|
4716
|
+
S2.reshape((N_image, -1)),
|
|
4717
|
+
S1.reshape((N_image, -1)),
|
|
4718
|
+
S3.reshape((N_image, -1)).real,
|
|
4719
|
+
S3.reshape((N_image, -1)).imag,
|
|
4720
|
+
S3p.reshape((N_image, -1)).real,
|
|
4721
|
+
S3p.reshape((N_image, -1)).imag,
|
|
4722
|
+
S4.reshape((N_image, -1)).real,
|
|
4723
|
+
S4.reshape((N_image, -1)).imag,
|
|
4724
|
+
),
|
|
4725
|
+
dim=-1,
|
|
4726
|
+
)
|
|
4727
|
+
|
|
4728
|
+
if not use_ref:
|
|
4729
|
+
self.ref_scattering_cov_S2 = S2
|
|
4730
|
+
|
|
5462
4731
|
if get_variance:
|
|
5463
|
-
return for_synthesis,ref_sigma
|
|
5464
|
-
|
|
4732
|
+
return for_synthesis, ref_sigma
|
|
4733
|
+
|
|
5465
4734
|
return for_synthesis
|
|
5466
|
-
|
|
5467
|
-
if (M,N,J,L) not in self.filters_set:
|
|
5468
|
-
self.filters_set[(M,N,J,L)] = self.computer_filter(
|
|
5469
|
-
|
|
5470
|
-
|
|
5471
|
-
|
|
5472
|
-
|
|
4735
|
+
|
|
4736
|
+
if (M, N, J, L) not in self.filters_set:
|
|
4737
|
+
self.filters_set[(M, N, J, L)] = self.computer_filter(
|
|
4738
|
+
M, N, J, L
|
|
4739
|
+
) # self.computer_filter(M,N,J,L)
|
|
4740
|
+
|
|
4741
|
+
filters_set = self.filters_set[(M, N, J, L)]
|
|
4742
|
+
|
|
4743
|
+
# weight = self.weight
|
|
5473
4744
|
if use_ref:
|
|
5474
|
-
if normalization==
|
|
4745
|
+
if normalization == "S2":
|
|
5475
4746
|
ref_S2 = self.ref_scattering_cov_S2
|
|
5476
|
-
else:
|
|
5477
|
-
ref_P11 = self.ref_scattering_cov[
|
|
4747
|
+
else:
|
|
4748
|
+
ref_P11 = self.ref_scattering_cov["P11"]
|
|
5478
4749
|
|
|
5479
4750
|
# convert numpy array input into self.backend.bk_ tensors
|
|
5480
4751
|
data = self.backend.bk_cast(data)
|
|
5481
|
-
data_f = self.backend.bk_fftn(data, dim=(-2
|
|
4752
|
+
data_f = self.backend.bk_fftn(data, dim=(-2, -1))
|
|
5482
4753
|
if data2 is not None:
|
|
5483
4754
|
data2 = self.backend.bk_cast(data2)
|
|
5484
|
-
data2_f = self.backend.bk_fftn(data2, dim=(-2
|
|
5485
|
-
|
|
4755
|
+
data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
|
|
4756
|
+
|
|
5486
4757
|
# initialize tensors for scattering coefficients
|
|
5487
|
-
|
|
5488
|
-
Ndata_S3 = J*(J+1)//2
|
|
5489
|
-
Ndata_S4 = J*(J+1)*(J+2)//6
|
|
5490
|
-
J_S4={}
|
|
5491
|
-
|
|
4758
|
+
|
|
4759
|
+
Ndata_S3 = J * (J + 1) // 2
|
|
4760
|
+
Ndata_S4 = J * (J + 1) * (J + 2) // 6
|
|
4761
|
+
J_S4 = {}
|
|
4762
|
+
|
|
5492
4763
|
S3 = []
|
|
5493
4764
|
if data2 is not None:
|
|
5494
4765
|
S3p = []
|
|
5495
|
-
S4_pre_norm = []
|
|
5496
|
-
S4 = []
|
|
5497
|
-
|
|
4766
|
+
S4_pre_norm = []
|
|
4767
|
+
S4 = []
|
|
4768
|
+
|
|
5498
4769
|
# variance
|
|
5499
4770
|
if get_variance:
|
|
5500
|
-
S3_sigma = []
|
|
4771
|
+
S3_sigma = []
|
|
5501
4772
|
if data2 is not None:
|
|
5502
4773
|
S3p_sigma = []
|
|
5503
|
-
S4_sigma = []
|
|
5504
|
-
|
|
4774
|
+
S4_sigma = []
|
|
4775
|
+
|
|
5505
4776
|
if iso_ang:
|
|
5506
4777
|
S3_iso = []
|
|
5507
4778
|
if data2 is not None:
|
|
5508
4779
|
S3p_iso = []
|
|
5509
|
-
|
|
4780
|
+
|
|
5510
4781
|
S4_iso = []
|
|
5511
4782
|
if get_variance:
|
|
5512
4783
|
S3_sigma_iso = []
|
|
5513
4784
|
if data2 is not None:
|
|
5514
|
-
S3p_sigma_iso = []
|
|
5515
|
-
S4_sigma_iso = []
|
|
5516
|
-
|
|
4785
|
+
S3p_sigma_iso = []
|
|
4786
|
+
S4_sigma_iso = []
|
|
4787
|
+
|
|
5517
4788
|
#
|
|
5518
|
-
if edge:
|
|
5519
|
-
if (M,N,J) not in self.edge_masks:
|
|
5520
|
-
self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
|
|
5521
|
-
edge_mask = self.edge_masks[(M,N,J)]
|
|
5522
|
-
else:
|
|
4789
|
+
if edge:
|
|
4790
|
+
if (M, N, J) not in self.edge_masks:
|
|
4791
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
|
|
4792
|
+
edge_mask = self.edge_masks[(M, N, J)]
|
|
4793
|
+
else:
|
|
5523
4794
|
edge_mask = 1
|
|
5524
|
-
|
|
4795
|
+
|
|
5525
4796
|
# calculate scattering fields
|
|
5526
4797
|
if data2 is None:
|
|
5527
4798
|
if self.use_2D:
|
|
5528
4799
|
if len(data.shape) == 2:
|
|
5529
|
-
I1 = self.backend.bk_abs(
|
|
5530
|
-
|
|
5531
|
-
|
|
4800
|
+
I1 = self.backend.bk_abs(
|
|
4801
|
+
self.backend.bk_ifftn(
|
|
4802
|
+
data_f[None, None, None, :, :]
|
|
4803
|
+
* filters_set[None, :J, :, :, :],
|
|
4804
|
+
dim=(-2, -1),
|
|
4805
|
+
)
|
|
4806
|
+
)
|
|
5532
4807
|
else:
|
|
5533
|
-
I1 = self.backend.bk_abs(
|
|
5534
|
-
|
|
5535
|
-
|
|
4808
|
+
I1 = self.backend.bk_abs(
|
|
4809
|
+
self.backend.bk_ifftn(
|
|
4810
|
+
data_f[:, None, None, :, :]
|
|
4811
|
+
* filters_set[None, :J, :, :, :],
|
|
4812
|
+
dim=(-2, -1),
|
|
4813
|
+
)
|
|
4814
|
+
)
|
|
5536
4815
|
elif self.use_1D:
|
|
5537
4816
|
if len(data.shape) == 1:
|
|
5538
|
-
I1 = self.backend.bk_abs(
|
|
5539
|
-
|
|
5540
|
-
|
|
4817
|
+
I1 = self.backend.bk_abs(
|
|
4818
|
+
self.backend.bk_ifftn(
|
|
4819
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4820
|
+
dim=(-1),
|
|
4821
|
+
)
|
|
4822
|
+
)
|
|
5541
4823
|
else:
|
|
5542
|
-
I1 = self.backend.bk_abs(
|
|
5543
|
-
|
|
5544
|
-
|
|
4824
|
+
I1 = self.backend.bk_abs(
|
|
4825
|
+
self.backend.bk_ifftn(
|
|
4826
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4827
|
+
dim=(-1),
|
|
4828
|
+
)
|
|
4829
|
+
)
|
|
5545
4830
|
else:
|
|
5546
|
-
print(
|
|
5547
|
-
|
|
5548
|
-
S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask),axis=(-2
|
|
5549
|
-
S1 = self.backend.bk_reduce_mean(I1 * edge_mask,axis=(-2
|
|
4831
|
+
print("todo")
|
|
4832
|
+
|
|
4833
|
+
S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=(-2, -1))
|
|
4834
|
+
S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=(-2, -1))
|
|
5550
4835
|
|
|
5551
4836
|
if get_variance:
|
|
5552
|
-
S2_sigma = self.backend.bk_reduce_std(
|
|
5553
|
-
|
|
5554
|
-
|
|
5555
|
-
|
|
5556
|
-
|
|
5557
|
-
|
|
4837
|
+
S2_sigma = self.backend.bk_reduce_std(
|
|
4838
|
+
(I1**2 * edge_mask), axis=(-2, -1)
|
|
4839
|
+
)
|
|
4840
|
+
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
|
|
4841
|
+
|
|
4842
|
+
I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
|
|
4843
|
+
|
|
4844
|
+
else:
|
|
5558
4845
|
if self.use_2D:
|
|
5559
4846
|
if len(data.shape) == 2:
|
|
5560
4847
|
I1 = self.backend.bk_ifftn(
|
|
5561
|
-
data_f[None,None,None
|
|
4848
|
+
data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
4849
|
+
dim=(-2, -1),
|
|
5562
4850
|
)
|
|
5563
4851
|
I2 = self.backend.bk_ifftn(
|
|
5564
|
-
data2_f[None,None,None
|
|
4852
|
+
data2_f[None, None, None, :, :]
|
|
4853
|
+
* filters_set[None, :J, :, :, :],
|
|
4854
|
+
dim=(-2, -1),
|
|
5565
4855
|
)
|
|
5566
4856
|
else:
|
|
5567
4857
|
I1 = self.backend.bk_ifftn(
|
|
5568
|
-
data_f[:,None,None
|
|
4858
|
+
data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
4859
|
+
dim=(-2, -1),
|
|
5569
4860
|
)
|
|
5570
4861
|
I2 = self.backend.bk_ifftn(
|
|
5571
|
-
data2_f[:,None,None
|
|
4862
|
+
data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
4863
|
+
dim=(-2, -1),
|
|
5572
4864
|
)
|
|
5573
4865
|
elif self.use_1D:
|
|
5574
4866
|
if len(data.shape) == 1:
|
|
5575
4867
|
I1 = self.backend.bk_ifftn(
|
|
5576
|
-
data_f[None,None,None
|
|
4868
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4869
|
+
dim=(-1),
|
|
5577
4870
|
)
|
|
5578
4871
|
I2 = self.backend.bk_ifftn(
|
|
5579
|
-
data2_f[None,None,None
|
|
4872
|
+
data2_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4873
|
+
dim=(-1),
|
|
5580
4874
|
)
|
|
5581
4875
|
else:
|
|
5582
4876
|
I1 = self.backend.bk_ifftn(
|
|
5583
|
-
data_f[:,None,None
|
|
4877
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :], dim=(-1)
|
|
5584
4878
|
)
|
|
5585
4879
|
I2 = self.backend.bk_ifftn(
|
|
5586
|
-
data2_f[:,None,None
|
|
4880
|
+
data2_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4881
|
+
dim=(-1),
|
|
5587
4882
|
)
|
|
5588
4883
|
else:
|
|
5589
|
-
print(
|
|
5590
|
-
|
|
5591
|
-
I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
|
|
5592
|
-
|
|
5593
|
-
S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2
|
|
4884
|
+
print("todo")
|
|
4885
|
+
|
|
4886
|
+
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
|
|
4887
|
+
|
|
4888
|
+
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5594
4889
|
if get_variance:
|
|
5595
|
-
S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2
|
|
5596
|
-
|
|
5597
|
-
I1=self.backend.bk_L1(I1)
|
|
5598
|
-
|
|
5599
|
-
S1
|
|
4890
|
+
S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
|
|
4891
|
+
|
|
4892
|
+
I1 = self.backend.bk_L1(I1)
|
|
4893
|
+
|
|
4894
|
+
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5600
4895
|
|
|
5601
4896
|
if get_variance:
|
|
5602
|
-
S1_sigma
|
|
5603
|
-
|
|
5604
|
-
I1_f= self.backend.bk_fftn(I1, dim=(-2
|
|
5605
|
-
|
|
4897
|
+
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
|
|
4898
|
+
|
|
4899
|
+
I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
|
|
4900
|
+
|
|
5606
4901
|
if pseudo_coef != 1:
|
|
5607
4902
|
I1 = I1**pseudo_coef
|
|
5608
|
-
|
|
5609
|
-
Ndata_S3=0
|
|
5610
|
-
Ndata_S4=0
|
|
5611
|
-
|
|
4903
|
+
|
|
4904
|
+
Ndata_S3 = 0
|
|
4905
|
+
Ndata_S4 = 0
|
|
4906
|
+
|
|
5612
4907
|
# calculate the covariance and correlations of the scattering fields
|
|
5613
4908
|
# only use the low-k Fourier coefs when calculating large-j scattering coefs.
|
|
5614
|
-
for j3 in range(0,J):
|
|
5615
|
-
J_S4[j3]=Ndata_S4
|
|
5616
|
-
|
|
5617
|
-
dx3, dy3 = self.get_dxdy(j3,M,N)
|
|
5618
|
-
I1_f_small = self.cut_high_k_off(
|
|
4909
|
+
for j3 in range(0, J):
|
|
4910
|
+
J_S4[j3] = Ndata_S4
|
|
4911
|
+
|
|
4912
|
+
dx3, dy3 = self.get_dxdy(j3, M, N)
|
|
4913
|
+
I1_f_small = self.cut_high_k_off(
|
|
4914
|
+
I1_f[:, : j3 + 1], dx3, dy3
|
|
4915
|
+
) # Nimage, J, L, x, y
|
|
5619
4916
|
data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
|
|
5620
4917
|
if data2 is not None:
|
|
5621
4918
|
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
|
|
5622
4919
|
if edge:
|
|
5623
|
-
I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2
|
|
5624
|
-
data_small = self.backend.bk_ifftn(
|
|
4920
|
+
I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2, -1), norm="ortho")
|
|
4921
|
+
data_small = self.backend.bk_ifftn(
|
|
4922
|
+
data_f_small, dim=(-2, -1), norm="ortho"
|
|
4923
|
+
)
|
|
5625
4924
|
if data2 is not None:
|
|
5626
|
-
data2_small = self.backend.bk_ifftn(
|
|
5627
|
-
|
|
4925
|
+
data2_small = self.backend.bk_ifftn(
|
|
4926
|
+
data2_f_small, dim=(-2, -1), norm="ortho"
|
|
4927
|
+
)
|
|
4928
|
+
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
5628
4929
|
_, M3, N3 = wavelet_f3.shape
|
|
5629
4930
|
wavelet_f3_squared = wavelet_f3**2
|
|
5630
|
-
edge_dx = min(4, int(2**j3*dx3*2/M))
|
|
5631
|
-
edge_dy = min(4, int(2**j3*dy3*2/N))
|
|
4931
|
+
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
|
|
4932
|
+
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
|
|
5632
4933
|
# a normalization change due to the cutoff of frequency space
|
|
5633
4934
|
if self.all_bk_type == "float32":
|
|
5634
|
-
fft_factor = np.complex64(1 /(M3*N3) * (M3*N3/M/N)**2)
|
|
4935
|
+
fft_factor = np.complex64(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
|
|
5635
4936
|
else:
|
|
5636
|
-
fft_factor = np.complex128(1 /(M3*N3) * (M3*N3/M/N)**2)
|
|
5637
|
-
for j2 in range(0,j3+1):
|
|
5638
|
-
#I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
|
|
5639
|
-
#I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
|
|
5640
|
-
I1_f2_wf3_small = self.backend.bk_reshape(
|
|
5641
|
-
|
|
4937
|
+
fft_factor = np.complex128(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
|
|
4938
|
+
for j2 in range(0, j3 + 1):
|
|
4939
|
+
# I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
|
|
4940
|
+
# I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
|
|
4941
|
+
I1_f2_wf3_small = self.backend.bk_reshape(
|
|
4942
|
+
I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
|
|
4943
|
+
) * self.backend.bk_reshape(wavelet_f3, [1, 1, 1, L, M3, N3])
|
|
4944
|
+
I1_f2_wf3_2_small = self.backend.bk_reshape(
|
|
4945
|
+
I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
|
|
4946
|
+
) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
|
|
5642
4947
|
if edge:
|
|
5643
|
-
I12_w3_small = self.backend.bk_ifftn(
|
|
5644
|
-
|
|
4948
|
+
I12_w3_small = self.backend.bk_ifftn(
|
|
4949
|
+
I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
|
|
4950
|
+
)
|
|
4951
|
+
I12_w3_2_small = self.backend.bk_ifftn(
|
|
4952
|
+
I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
|
|
4953
|
+
)
|
|
5645
4954
|
if use_ref:
|
|
5646
|
-
if normalization==
|
|
5647
|
-
norm_factor_S3 = (
|
|
5648
|
-
|
|
5649
|
-
|
|
5650
|
-
|
|
5651
|
-
norm_factor_S3 = self.backend.bk_complex(
|
|
4955
|
+
if normalization == "P11":
|
|
4956
|
+
norm_factor_S3 = (
|
|
4957
|
+
ref_S2[:, None, j3, :]
|
|
4958
|
+
* ref_P11[:, j2, j3, :, :] ** pseudo_coef
|
|
4959
|
+
) ** 0.5
|
|
4960
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4961
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4962
|
+
)
|
|
4963
|
+
elif normalization == "S2":
|
|
4964
|
+
norm_factor_S3 = (
|
|
4965
|
+
ref_S2[:, None, j3, :]
|
|
4966
|
+
* ref_S2[:, j2, :, None] ** pseudo_coef
|
|
4967
|
+
) ** 0.5
|
|
4968
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4969
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4970
|
+
)
|
|
5652
4971
|
else:
|
|
5653
4972
|
norm_factor_S3 = C_ONE
|
|
5654
4973
|
else:
|
|
5655
|
-
if normalization==
|
|
4974
|
+
if normalization == "P11":
|
|
5656
4975
|
# [N_image,l2,l3,x,y]
|
|
5657
|
-
P11_temp =
|
|
5658
|
-
|
|
5659
|
-
|
|
5660
|
-
|
|
5661
|
-
|
|
5662
|
-
|
|
4976
|
+
P11_temp = (
|
|
4977
|
+
self.backend.bk_reduce_mean(
|
|
4978
|
+
(I1_f2_wf3_small.abs() ** 2), axis=(-2, -1)
|
|
4979
|
+
)
|
|
4980
|
+
* fft_factor
|
|
4981
|
+
)
|
|
4982
|
+
norm_factor_S3 = (
|
|
4983
|
+
S2[:, None, j3, :] * P11_temp**pseudo_coef
|
|
4984
|
+
) ** 0.5
|
|
4985
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4986
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4987
|
+
)
|
|
4988
|
+
elif normalization == "S2":
|
|
4989
|
+
norm_factor_S3 = (
|
|
4990
|
+
S2[:, None, j3, None, :]
|
|
4991
|
+
* S2[:, None, j2, :, None] ** pseudo_coef
|
|
4992
|
+
) ** 0.5
|
|
4993
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4994
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4995
|
+
)
|
|
5663
4996
|
else:
|
|
5664
4997
|
norm_factor_S3 = C_ONE
|
|
5665
|
-
|
|
5666
4998
|
|
|
5667
4999
|
if not edge:
|
|
5668
|
-
S3.append(
|
|
5669
|
-
self.backend.
|
|
5670
|
-
|
|
5000
|
+
S3.append(
|
|
5001
|
+
self.backend.bk_reduce_mean(
|
|
5002
|
+
self.backend.bk_reshape(
|
|
5003
|
+
data_f_small, [N_image, 1, 1, 1, M3, N3]
|
|
5004
|
+
)
|
|
5005
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small),
|
|
5006
|
+
axis=(-2, -1),
|
|
5007
|
+
)
|
|
5008
|
+
* fft_factor
|
|
5009
|
+
/ norm_factor_S3
|
|
5010
|
+
)
|
|
5671
5011
|
if get_variance:
|
|
5672
|
-
S3_sigma.append(
|
|
5673
|
-
self.backend.
|
|
5674
|
-
|
|
5012
|
+
S3_sigma.append(
|
|
5013
|
+
self.backend.bk_reduce_std(
|
|
5014
|
+
self.backend.bk_reshape(
|
|
5015
|
+
data_f_small, [N_image, 1, 1, 1, M3, N3]
|
|
5016
|
+
)
|
|
5017
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small),
|
|
5018
|
+
axis=(-2, -1),
|
|
5019
|
+
)
|
|
5020
|
+
* fft_factor
|
|
5021
|
+
/ norm_factor_S3
|
|
5022
|
+
)
|
|
5675
5023
|
else:
|
|
5676
|
-
S3.append(
|
|
5677
|
-
|
|
5678
|
-
|
|
5024
|
+
S3.append(
|
|
5025
|
+
self.backend.bk_reduce_mean(
|
|
5026
|
+
(
|
|
5027
|
+
self.backend.bk_reshape(
|
|
5028
|
+
data_small, [N_image, 1, 1, 1, M3, N3]
|
|
5029
|
+
)
|
|
5030
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5031
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5032
|
+
axis=(-2, -1),
|
|
5033
|
+
)
|
|
5034
|
+
* fft_factor
|
|
5035
|
+
/ norm_factor_S3
|
|
5036
|
+
)
|
|
5679
5037
|
if get_variance:
|
|
5680
|
-
S3_sigma.apend(
|
|
5681
|
-
|
|
5682
|
-
|
|
5038
|
+
S3_sigma.apend(
|
|
5039
|
+
self.backend.bk_reduce_std(
|
|
5040
|
+
(
|
|
5041
|
+
self.backend.bk_reshape(
|
|
5042
|
+
data_small, [N_image, 1, 1, 1, M3, N3]
|
|
5043
|
+
)
|
|
5044
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5045
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5046
|
+
axis=(-2, -1),
|
|
5047
|
+
)
|
|
5048
|
+
* fft_factor
|
|
5049
|
+
/ norm_factor_S3
|
|
5050
|
+
)
|
|
5683
5051
|
if data2 is not None:
|
|
5684
5052
|
if not edge:
|
|
5685
|
-
S3p.append(
|
|
5686
|
-
|
|
5687
|
-
|
|
5688
|
-
|
|
5053
|
+
S3p.append(
|
|
5054
|
+
self.backend.bk_reduce_mean(
|
|
5055
|
+
(
|
|
5056
|
+
self.backend.bk_reshape(
|
|
5057
|
+
data2_f_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5058
|
+
)
|
|
5059
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5060
|
+
),
|
|
5061
|
+
axis=(-2, -1),
|
|
5062
|
+
)
|
|
5063
|
+
* fft_factor
|
|
5064
|
+
/ norm_factor_S3
|
|
5065
|
+
)
|
|
5066
|
+
|
|
5689
5067
|
if get_variance:
|
|
5690
|
-
S3p_sigma.append(
|
|
5691
|
-
|
|
5692
|
-
|
|
5068
|
+
S3p_sigma.append(
|
|
5069
|
+
self.backend.bk_reduce_std(
|
|
5070
|
+
(
|
|
5071
|
+
self.backend.bk_reshape(
|
|
5072
|
+
data2_f_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5073
|
+
)
|
|
5074
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5075
|
+
),
|
|
5076
|
+
axis=(-2, -1),
|
|
5077
|
+
)
|
|
5078
|
+
* fft_factor
|
|
5079
|
+
/ norm_factor_S3
|
|
5080
|
+
)
|
|
5693
5081
|
else:
|
|
5694
|
-
|
|
5695
|
-
S3p.append(
|
|
5696
|
-
|
|
5697
|
-
|
|
5082
|
+
|
|
5083
|
+
S3p.append(
|
|
5084
|
+
self.backend.bk_reduce_mean(
|
|
5085
|
+
(
|
|
5086
|
+
self.backend.bk_reshape(
|
|
5087
|
+
data2_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5088
|
+
)
|
|
5089
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5090
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5091
|
+
axis=(-2, -1),
|
|
5092
|
+
)
|
|
5093
|
+
* fft_factor
|
|
5094
|
+
/ norm_factor_S3
|
|
5095
|
+
)
|
|
5698
5096
|
if get_variance:
|
|
5699
|
-
S3p_sigma.append(
|
|
5700
|
-
|
|
5701
|
-
|
|
5702
|
-
|
|
5097
|
+
S3p_sigma.append(
|
|
5098
|
+
self.backend.bk_reduce_std(
|
|
5099
|
+
(
|
|
5100
|
+
self.backend.bk_reshape(
|
|
5101
|
+
data2_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5102
|
+
)
|
|
5103
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5104
|
+
)[
|
|
5105
|
+
...,
|
|
5106
|
+
edge_dx : M3 - edge_dx,
|
|
5107
|
+
edge_dy : N3 - edge_dy,
|
|
5108
|
+
],
|
|
5109
|
+
axis=(-2, -1),
|
|
5110
|
+
)
|
|
5111
|
+
* fft_factor
|
|
5112
|
+
/ norm_factor_S3
|
|
5113
|
+
)
|
|
5114
|
+
|
|
5703
5115
|
if j2 <= j3:
|
|
5704
|
-
if normalization==
|
|
5705
|
-
if use_ref:
|
|
5706
|
-
P = 1/
|
|
5707
|
-
|
|
5708
|
-
|
|
5709
|
-
|
|
5116
|
+
if normalization == "S2":
|
|
5117
|
+
if use_ref:
|
|
5118
|
+
P = 1 / (
|
|
5119
|
+
(
|
|
5120
|
+
ref_S2[:, j3 : j3 + 1, :, None, None]
|
|
5121
|
+
* ref_S2[:, j2 : j2 + 1, None, :, None]
|
|
5122
|
+
)
|
|
5123
|
+
** (0.5 * pseudo_coef)
|
|
5124
|
+
)
|
|
5125
|
+
else:
|
|
5126
|
+
P = 1 / (
|
|
5127
|
+
(
|
|
5128
|
+
S2[:, j3 : j3 + 1, :, None, None]
|
|
5129
|
+
* S2[:, j2 : j2 + 1, None, :, None]
|
|
5130
|
+
)
|
|
5131
|
+
** (0.5 * pseudo_coef)
|
|
5132
|
+
)
|
|
5133
|
+
P = self.backend.bk_complex(P, 0.0 * P)
|
|
5710
5134
|
else:
|
|
5711
|
-
P=C_ONE
|
|
5712
|
-
|
|
5713
|
-
for j1 in range(0, j2+1):
|
|
5714
|
-
|
|
5715
|
-
|
|
5716
|
-
|
|
5717
|
-
|
|
5718
|
-
|
|
5719
|
-
|
|
5720
|
-
|
|
5721
|
-
|
|
5722
|
-
|
|
5723
|
-
|
|
5724
|
-
self.backend.bk_conjugate(
|
|
5725
|
-
|
|
5726
|
-
|
|
5727
|
-
|
|
5728
|
-
|
|
5729
|
-
|
|
5730
|
-
|
|
5731
|
-
|
|
5732
|
-
|
|
5733
|
-
|
|
5734
|
-
|
|
5735
|
-
|
|
5736
|
-
|
|
5737
|
-
|
|
5135
|
+
P = C_ONE
|
|
5136
|
+
|
|
5137
|
+
for j1 in range(0, j2 + 1):
|
|
5138
|
+
if not edge:
|
|
5139
|
+
if not if_large_batch:
|
|
5140
|
+
# [N_image,l1,l2,l3,x,y]
|
|
5141
|
+
S4.append(
|
|
5142
|
+
self.backend.bk_reduce_mean(
|
|
5143
|
+
(
|
|
5144
|
+
self.backend.bk_reshape(
|
|
5145
|
+
I1_f_small[:, j1],
|
|
5146
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5147
|
+
)
|
|
5148
|
+
* self.backend.bk_conjugate(
|
|
5149
|
+
self.backend.bk_reshape(
|
|
5150
|
+
I1_f2_wf3_2_small,
|
|
5151
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5152
|
+
)
|
|
5153
|
+
)
|
|
5154
|
+
),
|
|
5155
|
+
axis=(-2, -1),
|
|
5156
|
+
)
|
|
5157
|
+
* fft_factor
|
|
5158
|
+
* P
|
|
5159
|
+
)
|
|
5160
|
+
if get_variance:
|
|
5161
|
+
S4_sigma.append(
|
|
5162
|
+
self.backend.bk_reduce_std(
|
|
5163
|
+
(
|
|
5164
|
+
self.backend.bk_reshape(
|
|
5165
|
+
I1_f_small[:, j1],
|
|
5166
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5167
|
+
)
|
|
5168
|
+
* self.backend.bk_conjugate(
|
|
5169
|
+
self.backend.bk_reshape(
|
|
5170
|
+
I1_f2_wf3_2_small,
|
|
5171
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5172
|
+
)
|
|
5173
|
+
)
|
|
5174
|
+
),
|
|
5175
|
+
axis=(-2, -1),
|
|
5176
|
+
)
|
|
5177
|
+
* fft_factor
|
|
5178
|
+
* P
|
|
5179
|
+
)
|
|
5738
5180
|
else:
|
|
5739
|
-
|
|
5740
|
-
# [N_image,
|
|
5741
|
-
S4.append(
|
|
5742
|
-
|
|
5743
|
-
|
|
5181
|
+
for l1 in range(L):
|
|
5182
|
+
# [N_image,l2,l3,x,y]
|
|
5183
|
+
S4.append(
|
|
5184
|
+
self.backend.bk_reduce_mean(
|
|
5185
|
+
(
|
|
5186
|
+
self.backend.bk_reshape(
|
|
5187
|
+
I1_f_small[:, j1, l1],
|
|
5188
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5189
|
+
)
|
|
5190
|
+
* self.backend.bk_conjugate(
|
|
5191
|
+
self.backend.bk_reshape(
|
|
5192
|
+
I1_f2_wf3_2_small,
|
|
5193
|
+
[N_image, 1, L, L, M3, N3],
|
|
5194
|
+
)
|
|
5195
|
+
)
|
|
5196
|
+
),
|
|
5197
|
+
axis=(-2, -1),
|
|
5744
5198
|
)
|
|
5745
|
-
|
|
5199
|
+
* fft_factor
|
|
5200
|
+
* P
|
|
5201
|
+
)
|
|
5746
5202
|
if get_variance:
|
|
5747
|
-
S4_sigma.append(
|
|
5748
|
-
|
|
5749
|
-
|
|
5203
|
+
S4_sigma.append(
|
|
5204
|
+
self.backend.bk_reduce_std(
|
|
5205
|
+
(
|
|
5206
|
+
self.backend.bk_reshape(
|
|
5207
|
+
I1_f_small[:, j1, l1],
|
|
5208
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5209
|
+
)
|
|
5210
|
+
* self.backend.bk_conjugate(
|
|
5211
|
+
self.backend.bk_reshape(
|
|
5212
|
+
I1_f2_wf3_2_small,
|
|
5213
|
+
[N_image, 1, L, L, M3, N3],
|
|
5214
|
+
)
|
|
5215
|
+
)
|
|
5216
|
+
),
|
|
5217
|
+
axis=(-2, -1),
|
|
5750
5218
|
)
|
|
5751
|
-
|
|
5752
|
-
|
|
5753
|
-
|
|
5754
|
-
|
|
5755
|
-
|
|
5756
|
-
|
|
5757
|
-
|
|
5219
|
+
* fft_factor
|
|
5220
|
+
* P
|
|
5221
|
+
)
|
|
5222
|
+
else:
|
|
5223
|
+
if not if_large_batch:
|
|
5224
|
+
# [N_image,l1,l2,l3,x,y]
|
|
5225
|
+
S4.append(
|
|
5226
|
+
self.backend.bk_reduce_mean(
|
|
5227
|
+
(
|
|
5228
|
+
self.backend.bk_reshape(
|
|
5229
|
+
I1_small[:, j1],
|
|
5230
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5758
5231
|
)
|
|
5759
|
-
|
|
5760
|
-
|
|
5761
|
-
|
|
5762
|
-
|
|
5763
|
-
self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
|
|
5232
|
+
* self.backend.bk_conjugate(
|
|
5233
|
+
self.backend.bk_reshape(
|
|
5234
|
+
I12_w3_2_small,
|
|
5235
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5764
5236
|
)
|
|
5765
|
-
)
|
|
5766
|
-
|
|
5767
|
-
|
|
5768
|
-
|
|
5769
|
-
|
|
5237
|
+
)
|
|
5238
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5239
|
+
axis=(-2, -1),
|
|
5240
|
+
)
|
|
5241
|
+
* fft_factor
|
|
5242
|
+
* P
|
|
5243
|
+
)
|
|
5244
|
+
if get_variance:
|
|
5245
|
+
S4_sigma.append(
|
|
5246
|
+
self.backend.bk_reduce_std(
|
|
5247
|
+
(
|
|
5248
|
+
self.backend.bk_reshape(
|
|
5249
|
+
I1_small[:, j1],
|
|
5250
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5251
|
+
)
|
|
5252
|
+
* self.backend.bk_conjugate(
|
|
5253
|
+
self.backend.bk_reshape(
|
|
5254
|
+
I12_w3_2_small,
|
|
5255
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5256
|
+
)
|
|
5257
|
+
)
|
|
5258
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5259
|
+
axis=(-2, -1),
|
|
5260
|
+
)
|
|
5261
|
+
* fft_factor
|
|
5262
|
+
* P
|
|
5263
|
+
)
|
|
5264
|
+
else:
|
|
5265
|
+
for l1 in range(L):
|
|
5266
|
+
# [N_image,l2,l3,x,y]
|
|
5267
|
+
S4.append(
|
|
5268
|
+
self.backend.bk_reduce_mean(
|
|
5269
|
+
(
|
|
5270
|
+
self.backend.bk_reshape(
|
|
5271
|
+
I1_small[:, j1],
|
|
5272
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5273
|
+
)
|
|
5274
|
+
* self.backend.bk_conjugate(
|
|
5275
|
+
self.backend.bk_reshape(
|
|
5276
|
+
I12_w3_2_small,
|
|
5277
|
+
[N_image, 1, L, L, M3, N3],
|
|
5278
|
+
)
|
|
5279
|
+
)
|
|
5280
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5281
|
+
axis=(-2, -1),
|
|
5282
|
+
)
|
|
5283
|
+
* fft_factor
|
|
5284
|
+
* P
|
|
5285
|
+
)
|
|
5286
|
+
if get_variance:
|
|
5287
|
+
S4_sigma.append(
|
|
5288
|
+
self.backend.bk_reduce_std(
|
|
5289
|
+
(
|
|
5290
|
+
self.backend.bk_reshape(
|
|
5291
|
+
I1_small[:, j1],
|
|
5292
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5293
|
+
)
|
|
5294
|
+
* self.backend.bk_conjugate(
|
|
5295
|
+
self.backend.bk_reshape(
|
|
5296
|
+
I12_w3_2_small,
|
|
5297
|
+
[N_image, 1, L, L, M3, N3],
|
|
5298
|
+
)
|
|
5299
|
+
)
|
|
5300
|
+
)[
|
|
5301
|
+
...,
|
|
5302
|
+
edge_dx:-edge_dx,
|
|
5303
|
+
edge_dy:-edge_dy,
|
|
5304
|
+
],
|
|
5305
|
+
axis=(-2, -1),
|
|
5306
|
+
)
|
|
5307
|
+
* fft_factor
|
|
5308
|
+
* P
|
|
5309
|
+
)
|
|
5310
|
+
|
|
5311
|
+
S3 = self.backend.bk_concat(S3, axis=1)
|
|
5312
|
+
S4 = self.backend.bk_concat(S4, axis=1)
|
|
5313
|
+
|
|
5770
5314
|
if get_variance:
|
|
5771
|
-
S3_sigma=self.backend.bk_concat(S3_sigma,axis=1)
|
|
5772
|
-
S4_sigma=self.backend.bk_concat(S4_sigma,axis=1)
|
|
5773
|
-
|
|
5315
|
+
S3_sigma = self.backend.bk_concat(S3_sigma, axis=1)
|
|
5316
|
+
S4_sigma = self.backend.bk_concat(S4_sigma, axis=1)
|
|
5317
|
+
|
|
5774
5318
|
if data2 is not None:
|
|
5775
|
-
S3p=self.backend.bk_concat(S3p,axis=1)
|
|
5319
|
+
S3p = self.backend.bk_concat(S3p, axis=1)
|
|
5776
5320
|
if get_variance:
|
|
5777
|
-
S3p_sigma=self.backend.bk_concat(S3p_sigma,axis=1)
|
|
5778
|
-
|
|
5321
|
+
S3p_sigma = self.backend.bk_concat(S3p_sigma, axis=1)
|
|
5322
|
+
|
|
5779
5323
|
# average over l1 to obtain simple isotropic statistics
|
|
5780
5324
|
if iso_ang:
|
|
5781
|
-
S2_iso = self.backend.bk_reduce_mean(S2,axis=(-1))
|
|
5782
|
-
S1_iso = self.backend.bk_reduce_mean(S1,axis=(-1))
|
|
5325
|
+
S2_iso = self.backend.bk_reduce_mean(S2, axis=(-1))
|
|
5326
|
+
S1_iso = self.backend.bk_reduce_mean(S1, axis=(-1))
|
|
5783
5327
|
for l1 in range(L):
|
|
5784
5328
|
for l2 in range(L):
|
|
5785
|
-
S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
|
|
5329
|
+
S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
|
|
5786
5330
|
if data2 is not None:
|
|
5787
|
-
S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
|
|
5331
|
+
S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
|
|
5788
5332
|
for l3 in range(L):
|
|
5789
|
-
S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
|
|
5790
|
-
S3_iso /= L
|
|
5333
|
+
S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[..., l1, l2, l3]
|
|
5334
|
+
S3_iso /= L
|
|
5335
|
+
S4_iso /= L
|
|
5791
5336
|
if data2 is not None:
|
|
5792
5337
|
S3p_iso /= L
|
|
5793
|
-
|
|
5338
|
+
|
|
5794
5339
|
if get_variance:
|
|
5795
|
-
S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma,axis=(-1))
|
|
5796
|
-
S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma,axis=(-1))
|
|
5340
|
+
S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma, axis=(-1))
|
|
5341
|
+
S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma, axis=(-1))
|
|
5797
5342
|
for l1 in range(L):
|
|
5798
5343
|
for l2 in range(L):
|
|
5799
|
-
S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
|
|
5344
|
+
S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
|
|
5800
5345
|
if data2 is not None:
|
|
5801
|
-
S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
|
|
5346
|
+
S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[..., l1, l2]
|
|
5802
5347
|
for l3 in range(L):
|
|
5803
|
-
S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[
|
|
5804
|
-
|
|
5348
|
+
S4_sigma_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4_sigma[
|
|
5349
|
+
..., l1, l2, l3
|
|
5350
|
+
]
|
|
5351
|
+
S3_sigma_iso /= L
|
|
5352
|
+
S4_sigma_iso /= L
|
|
5805
5353
|
if data2 is not None:
|
|
5806
5354
|
S3p_sigma_iso /= L
|
|
5807
|
-
|
|
5355
|
+
|
|
5808
5356
|
if data2 is None:
|
|
5809
|
-
mean_data=self.backend.bk_reshape(
|
|
5810
|
-
|
|
5357
|
+
mean_data = self.backend.bk_reshape(
|
|
5358
|
+
self.backend.bk_reduce_mean(data, axis=(-2, -1)), [N_image, 1]
|
|
5359
|
+
)
|
|
5360
|
+
std_data = self.backend.bk_reshape(
|
|
5361
|
+
self.backend.bk_reduce_std(data, axis=(-2, -1)), [N_image, 1]
|
|
5362
|
+
)
|
|
5811
5363
|
else:
|
|
5812
|
-
mean_data=self.backend.bk_reshape(
|
|
5813
|
-
|
|
5814
|
-
|
|
5364
|
+
mean_data = self.backend.bk_reshape(
|
|
5365
|
+
self.backend.bk_reduce_mean(data * data2, axis=(-2, -1)), [N_image, 1]
|
|
5366
|
+
)
|
|
5367
|
+
std_data = self.backend.bk_reshape(
|
|
5368
|
+
self.backend.bk_reduce_std(data * data2, axis=(-2, -1)), [N_image, 1]
|
|
5369
|
+
)
|
|
5370
|
+
|
|
5815
5371
|
if get_variance:
|
|
5816
|
-
ref_sigma={}
|
|
5372
|
+
ref_sigma = {}
|
|
5817
5373
|
if iso_ang:
|
|
5818
|
-
ref_sigma[
|
|
5819
|
-
ref_sigma[
|
|
5820
|
-
ref_sigma[
|
|
5821
|
-
ref_sigma[
|
|
5822
|
-
ref_sigma[
|
|
5374
|
+
ref_sigma["std_data"] = std_data
|
|
5375
|
+
ref_sigma["S1_sigma"] = S1_sigma_iso
|
|
5376
|
+
ref_sigma["S2_sigma"] = S2_sigma_iso
|
|
5377
|
+
ref_sigma["S3_sigma"] = S3_sigma_iso
|
|
5378
|
+
ref_sigma["S4_sigma"] = S4_sigma_iso
|
|
5823
5379
|
if data2 is not None:
|
|
5824
|
-
ref_sigma[
|
|
5380
|
+
ref_sigma["S3p_sigma"] = S3p_sigma_iso
|
|
5825
5381
|
else:
|
|
5826
|
-
ref_sigma[
|
|
5827
|
-
ref_sigma[
|
|
5828
|
-
ref_sigma[
|
|
5829
|
-
ref_sigma[
|
|
5830
|
-
ref_sigma[
|
|
5382
|
+
ref_sigma["std_data"] = std_data
|
|
5383
|
+
ref_sigma["S1_sigma"] = S1_sigma
|
|
5384
|
+
ref_sigma["S2_sigma"] = S2_sigma
|
|
5385
|
+
ref_sigma["S3_sigma"] = S3_sigma
|
|
5386
|
+
ref_sigma["S4_sigma"] = S4_sigma
|
|
5831
5387
|
if data2 is not None:
|
|
5832
|
-
ref_sigma[
|
|
5833
|
-
|
|
5388
|
+
ref_sigma["S3p_sigma"] = S3_sigma
|
|
5389
|
+
|
|
5834
5390
|
if data2 is None:
|
|
5835
5391
|
if iso_ang:
|
|
5836
5392
|
if ref_sigma is not None:
|
|
5837
|
-
for_synthesis = self.backend.bk_concat(
|
|
5838
|
-
|
|
5839
|
-
|
|
5840
|
-
|
|
5841
|
-
|
|
5842
|
-
|
|
5843
|
-
|
|
5844
|
-
|
|
5845
|
-
|
|
5846
|
-
|
|
5393
|
+
for_synthesis = self.backend.bk_concat(
|
|
5394
|
+
(
|
|
5395
|
+
mean_data / ref_sigma["std_data"],
|
|
5396
|
+
std_data / ref_sigma["std_data"],
|
|
5397
|
+
self.backend.bk_reshape(
|
|
5398
|
+
self.backend.bk_log(S2_iso / ref_sigma["S2_sigma"]),
|
|
5399
|
+
[N_image, -1],
|
|
5400
|
+
),
|
|
5401
|
+
self.backend.bk_reshape(
|
|
5402
|
+
self.backend.bk_log(S1_iso / ref_sigma["S1_sigma"]),
|
|
5403
|
+
[N_image, -1],
|
|
5404
|
+
),
|
|
5405
|
+
self.backend.bk_reshape(
|
|
5406
|
+
self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
|
|
5407
|
+
[N_image, -1],
|
|
5408
|
+
),
|
|
5409
|
+
self.backend.bk_reshape(
|
|
5410
|
+
self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
|
|
5411
|
+
[N_image, -1],
|
|
5412
|
+
),
|
|
5413
|
+
self.backend.bk_reshape(
|
|
5414
|
+
self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
|
|
5415
|
+
[N_image, -1],
|
|
5416
|
+
),
|
|
5417
|
+
self.backend.bk_reshape(
|
|
5418
|
+
self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
|
|
5419
|
+
[N_image, -1],
|
|
5420
|
+
),
|
|
5421
|
+
),
|
|
5422
|
+
axis=-1,
|
|
5423
|
+
)
|
|
5847
5424
|
else:
|
|
5848
|
-
for_synthesis = self.backend.bk_concat(
|
|
5849
|
-
|
|
5850
|
-
|
|
5851
|
-
|
|
5852
|
-
|
|
5853
|
-
|
|
5854
|
-
|
|
5855
|
-
|
|
5856
|
-
|
|
5857
|
-
|
|
5425
|
+
for_synthesis = self.backend.bk_concat(
|
|
5426
|
+
(
|
|
5427
|
+
mean_data / std_data,
|
|
5428
|
+
std_data,
|
|
5429
|
+
self.backend.bk_reshape(
|
|
5430
|
+
self.backend.bk_log(S2_iso), [N_image, -1]
|
|
5431
|
+
),
|
|
5432
|
+
self.backend.bk_reshape(
|
|
5433
|
+
self.backend.bk_log(S1_iso), [N_image, -1]
|
|
5434
|
+
),
|
|
5435
|
+
self.backend.bk_reshape(
|
|
5436
|
+
self.backend.bk_real(S3_iso), [N_image, -1]
|
|
5437
|
+
),
|
|
5438
|
+
self.backend.bk_reshape(
|
|
5439
|
+
self.backend.bk_imag(S3_iso), [N_image, -1]
|
|
5440
|
+
),
|
|
5441
|
+
self.backend.bk_reshape(
|
|
5442
|
+
self.backend.bk_real(S4_iso), [N_image, -1]
|
|
5443
|
+
),
|
|
5444
|
+
self.backend.bk_reshape(
|
|
5445
|
+
self.backend.bk_imag(S4_iso), [N_image, -1]
|
|
5446
|
+
),
|
|
5447
|
+
),
|
|
5448
|
+
axis=-1,
|
|
5449
|
+
)
|
|
5858
5450
|
else:
|
|
5859
5451
|
if ref_sigma is not None:
|
|
5860
|
-
for_synthesis = self.backend.bk_concat(
|
|
5861
|
-
|
|
5862
|
-
|
|
5863
|
-
|
|
5864
|
-
|
|
5865
|
-
|
|
5866
|
-
|
|
5867
|
-
|
|
5868
|
-
|
|
5869
|
-
|
|
5452
|
+
for_synthesis = self.backend.bk_concat(
|
|
5453
|
+
(
|
|
5454
|
+
mean_data / ref_sigma["std_data"],
|
|
5455
|
+
std_data / ref_sigma["std_data"],
|
|
5456
|
+
self.backend.bk_reshape(
|
|
5457
|
+
self.backend.bk_log(S2 / ref_sigma["S2_sigma"]),
|
|
5458
|
+
[N_image, -1],
|
|
5459
|
+
),
|
|
5460
|
+
self.backend.bk_reshape(
|
|
5461
|
+
self.backend.bk_log(S1 / ref_sigma["S1_sigma"]),
|
|
5462
|
+
[N_image, -1],
|
|
5463
|
+
),
|
|
5464
|
+
self.backend.bk_reshape(
|
|
5465
|
+
self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
|
|
5466
|
+
[N_image, -1],
|
|
5467
|
+
),
|
|
5468
|
+
self.backend.bk_reshape(
|
|
5469
|
+
self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
|
|
5470
|
+
[N_image, -1],
|
|
5471
|
+
),
|
|
5472
|
+
self.backend.bk_reshape(
|
|
5473
|
+
self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
|
|
5474
|
+
[N_image, -1],
|
|
5475
|
+
),
|
|
5476
|
+
self.backend.bk_reshape(
|
|
5477
|
+
self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
|
|
5478
|
+
[N_image, -1],
|
|
5479
|
+
),
|
|
5480
|
+
),
|
|
5481
|
+
axis=-1,
|
|
5482
|
+
)
|
|
5870
5483
|
else:
|
|
5871
|
-
for_synthesis = self.backend.bk_concat(
|
|
5872
|
-
|
|
5484
|
+
for_synthesis = self.backend.bk_concat(
|
|
5485
|
+
(
|
|
5486
|
+
mean_data / std_data,
|
|
5873
5487
|
std_data,
|
|
5874
|
-
self.backend.bk_reshape(
|
|
5875
|
-
|
|
5876
|
-
|
|
5877
|
-
self.backend.bk_reshape(
|
|
5878
|
-
|
|
5879
|
-
|
|
5880
|
-
|
|
5488
|
+
self.backend.bk_reshape(
|
|
5489
|
+
self.backend.bk_log(S2), [N_image, -1]
|
|
5490
|
+
),
|
|
5491
|
+
self.backend.bk_reshape(
|
|
5492
|
+
self.backend.bk_log(S1), [N_image, -1]
|
|
5493
|
+
),
|
|
5494
|
+
self.backend.bk_reshape(
|
|
5495
|
+
self.backend.bk_real(S3), [N_image, -1]
|
|
5496
|
+
),
|
|
5497
|
+
self.backend.bk_reshape(
|
|
5498
|
+
self.backend.bk_imag(S3), [N_image, -1]
|
|
5499
|
+
),
|
|
5500
|
+
self.backend.bk_reshape(
|
|
5501
|
+
self.backend.bk_real(S4), [N_image, -1]
|
|
5502
|
+
),
|
|
5503
|
+
self.backend.bk_reshape(
|
|
5504
|
+
self.backend.bk_imag(S4), [N_image, -1]
|
|
5505
|
+
),
|
|
5506
|
+
),
|
|
5507
|
+
axis=-1,
|
|
5508
|
+
)
|
|
5881
5509
|
else:
|
|
5882
5510
|
if iso_ang:
|
|
5883
5511
|
if ref_sigma is not None:
|
|
5884
|
-
for_synthesis = self.backend.backend.cat(
|
|
5885
|
-
|
|
5886
|
-
|
|
5887
|
-
|
|
5888
|
-
|
|
5889
|
-
|
|
5890
|
-
|
|
5891
|
-
|
|
5892
|
-
|
|
5893
|
-
|
|
5894
|
-
|
|
5895
|
-
|
|
5512
|
+
for_synthesis = self.backend.backend.cat(
|
|
5513
|
+
(
|
|
5514
|
+
mean_data / ref_sigma["std_data"],
|
|
5515
|
+
std_data / ref_sigma["std_data"],
|
|
5516
|
+
self.backend.bk_reshape(
|
|
5517
|
+
self.backend.bk_real(S2_iso / ref_sigma["S2_sigma"]),
|
|
5518
|
+
[N_image, -1],
|
|
5519
|
+
),
|
|
5520
|
+
self.backend.bk_reshape(
|
|
5521
|
+
self.backend.bk_real(S1_iso / ref_sigma["S1_sigma"]),
|
|
5522
|
+
[N_image, -1],
|
|
5523
|
+
),
|
|
5524
|
+
self.backend.bk_reshape(
|
|
5525
|
+
self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
|
|
5526
|
+
[N_image, -1],
|
|
5527
|
+
),
|
|
5528
|
+
self.backend.bk_reshape(
|
|
5529
|
+
self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
|
|
5530
|
+
[N_image, -1],
|
|
5531
|
+
),
|
|
5532
|
+
self.backend.bk_reshape(
|
|
5533
|
+
self.backend.bk_real(S3p_iso / ref_sigma["S3p_sigma"]),
|
|
5534
|
+
[N_image, -1],
|
|
5535
|
+
),
|
|
5536
|
+
self.backend.bk_reshape(
|
|
5537
|
+
self.backend.bk_imag(S3p_iso / ref_sigma["S3p_sigma"]),
|
|
5538
|
+
[N_image, -1],
|
|
5539
|
+
),
|
|
5540
|
+
self.backend.bk_reshape(
|
|
5541
|
+
self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
|
|
5542
|
+
[N_image, -1],
|
|
5543
|
+
),
|
|
5544
|
+
self.backend.bk_reshape(
|
|
5545
|
+
self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
|
|
5546
|
+
[N_image, -1],
|
|
5547
|
+
),
|
|
5548
|
+
),
|
|
5549
|
+
axis=-1,
|
|
5550
|
+
)
|
|
5896
5551
|
else:
|
|
5897
|
-
for_synthesis = self.backend.backend.cat(
|
|
5898
|
-
|
|
5899
|
-
|
|
5900
|
-
|
|
5901
|
-
|
|
5902
|
-
|
|
5903
|
-
|
|
5904
|
-
|
|
5905
|
-
|
|
5906
|
-
|
|
5907
|
-
|
|
5908
|
-
|
|
5552
|
+
for_synthesis = self.backend.backend.cat(
|
|
5553
|
+
(
|
|
5554
|
+
mean_data / std_data,
|
|
5555
|
+
std_data,
|
|
5556
|
+
self.backend.bk_reshape(
|
|
5557
|
+
self.backend.bk_real(S2_iso), [N_image, -1]
|
|
5558
|
+
),
|
|
5559
|
+
self.backend.bk_reshape(
|
|
5560
|
+
self.backend.bk_real(S1_iso), [N_image, -1]
|
|
5561
|
+
),
|
|
5562
|
+
self.backend.bk_reshape(
|
|
5563
|
+
self.backend.bk_real(S3_iso), [N_image, -1]
|
|
5564
|
+
),
|
|
5565
|
+
self.backend.bk_reshape(
|
|
5566
|
+
self.backend.bk_imag(S3_iso), [N_image, -1]
|
|
5567
|
+
),
|
|
5568
|
+
self.backend.bk_reshape(
|
|
5569
|
+
self.backend.bk_real(S3p_iso), [N_image, -1]
|
|
5570
|
+
),
|
|
5571
|
+
self.backend.bk_reshape(
|
|
5572
|
+
self.backend.bk_imag(S3p_iso), [N_image, -1]
|
|
5573
|
+
),
|
|
5574
|
+
self.backend.bk_reshape(
|
|
5575
|
+
self.backend.bk_real(S4_iso), [N_image, -1]
|
|
5576
|
+
),
|
|
5577
|
+
self.backend.bk_reshape(
|
|
5578
|
+
self.backend.bk_imag(S4_iso), [N_image, -1]
|
|
5579
|
+
),
|
|
5580
|
+
),
|
|
5581
|
+
axis=-1,
|
|
5582
|
+
)
|
|
5909
5583
|
else:
|
|
5910
5584
|
if ref_sigma is not None:
|
|
5911
|
-
for_synthesis = self.backend.backend.cat(
|
|
5912
|
-
|
|
5913
|
-
|
|
5914
|
-
|
|
5915
|
-
|
|
5916
|
-
|
|
5917
|
-
|
|
5918
|
-
|
|
5919
|
-
|
|
5920
|
-
|
|
5921
|
-
|
|
5922
|
-
|
|
5585
|
+
for_synthesis = self.backend.backend.cat(
|
|
5586
|
+
(
|
|
5587
|
+
mean_data / ref_sigma["std_data"],
|
|
5588
|
+
std_data / ref_sigma["std_data"],
|
|
5589
|
+
self.backend.bk_reshape(
|
|
5590
|
+
self.backend.bk_real(S2 / ref_sigma["S2_sigma"]),
|
|
5591
|
+
[N_image, -1],
|
|
5592
|
+
),
|
|
5593
|
+
self.backend.bk_reshape(
|
|
5594
|
+
self.backend.bk_real(S1 / ref_sigma["S1_sigma"]),
|
|
5595
|
+
[N_image, -1],
|
|
5596
|
+
),
|
|
5597
|
+
self.backend.bk_reshape(
|
|
5598
|
+
self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
|
|
5599
|
+
[N_image, -1],
|
|
5600
|
+
),
|
|
5601
|
+
self.backend.bk_reshape(
|
|
5602
|
+
self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
|
|
5603
|
+
[N_image, -1],
|
|
5604
|
+
),
|
|
5605
|
+
self.backend.bk_reshape(
|
|
5606
|
+
self.backend.bk_real(S3p / ref_sigma["S3p_sigma"]),
|
|
5607
|
+
[N_image, -1],
|
|
5608
|
+
),
|
|
5609
|
+
self.backend.bk_reshape(
|
|
5610
|
+
self.backend.bk_imag(S3p / ref_sigma["S3p_sigma"]),
|
|
5611
|
+
[N_image, -1],
|
|
5612
|
+
),
|
|
5613
|
+
self.backend.bk_reshape(
|
|
5614
|
+
self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
|
|
5615
|
+
[N_image, -1],
|
|
5616
|
+
),
|
|
5617
|
+
self.backend.bk_reshape(
|
|
5618
|
+
self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
|
|
5619
|
+
[N_image, -1],
|
|
5620
|
+
),
|
|
5621
|
+
),
|
|
5622
|
+
axis=-1,
|
|
5623
|
+
)
|
|
5923
5624
|
else:
|
|
5924
|
-
for_synthesis = self.backend.bk_concat(
|
|
5925
|
-
|
|
5625
|
+
for_synthesis = self.backend.bk_concat(
|
|
5626
|
+
(
|
|
5627
|
+
mean_data / std_data,
|
|
5926
5628
|
std_data,
|
|
5927
|
-
self.backend.bk_reshape(
|
|
5928
|
-
|
|
5929
|
-
|
|
5930
|
-
self.backend.bk_reshape(
|
|
5931
|
-
|
|
5932
|
-
|
|
5933
|
-
self.backend.bk_reshape(
|
|
5934
|
-
|
|
5935
|
-
|
|
5936
|
-
|
|
5937
|
-
|
|
5938
|
-
|
|
5939
|
-
|
|
5629
|
+
self.backend.bk_reshape(
|
|
5630
|
+
self.backend.bk_real(S2), [N_image, -1]
|
|
5631
|
+
),
|
|
5632
|
+
self.backend.bk_reshape(
|
|
5633
|
+
self.backend.bk_real(S1), [N_image, -1]
|
|
5634
|
+
),
|
|
5635
|
+
self.backend.bk_reshape(
|
|
5636
|
+
self.backend.bk_real(S3), [N_image, -1]
|
|
5637
|
+
),
|
|
5638
|
+
self.backend.bk_reshape(
|
|
5639
|
+
self.backend.bk_imag(S3), [N_image, -1]
|
|
5640
|
+
),
|
|
5641
|
+
self.backend.bk_reshape(
|
|
5642
|
+
self.backend.bk_real(S3p), [N_image, -1]
|
|
5643
|
+
),
|
|
5644
|
+
self.backend.bk_reshape(
|
|
5645
|
+
self.backend.bk_imag(S3p), [N_image, -1]
|
|
5646
|
+
),
|
|
5647
|
+
self.backend.bk_reshape(
|
|
5648
|
+
self.backend.bk_real(S4), [N_image, -1]
|
|
5649
|
+
),
|
|
5650
|
+
self.backend.bk_reshape(
|
|
5651
|
+
self.backend.bk_imag(S4), [N_image, -1]
|
|
5652
|
+
),
|
|
5653
|
+
),
|
|
5654
|
+
axis=-1,
|
|
5655
|
+
)
|
|
5656
|
+
|
|
5657
|
+
if not use_ref:
|
|
5658
|
+
self.ref_scattering_cov_S2 = S2
|
|
5659
|
+
|
|
5940
5660
|
if get_variance:
|
|
5941
|
-
return for_synthesis,ref_sigma
|
|
5942
|
-
|
|
5661
|
+
return for_synthesis, ref_sigma
|
|
5662
|
+
|
|
5943
5663
|
return for_synthesis
|
|
5944
|
-
|
|
5945
|
-
|
|
5946
|
-
def to_gaussian(self,x):
|
|
5947
|
-
from scipy.stats import norm
|
|
5664
|
+
|
|
5665
|
+
def to_gaussian(self, x):
|
|
5948
5666
|
from scipy.interpolate import interp1d
|
|
5667
|
+
from scipy.stats import norm
|
|
5949
5668
|
|
|
5950
|
-
idx=np.argsort(x.flatten())
|
|
5669
|
+
idx = np.argsort(x.flatten())
|
|
5951
5670
|
p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
|
|
5952
|
-
im_target=x.flatten()
|
|
5671
|
+
im_target = x.flatten()
|
|
5953
5672
|
im_target[idx] = norm.ppf(p)
|
|
5954
|
-
|
|
5673
|
+
|
|
5955
5674
|
# Interpolation cubique
|
|
5956
|
-
self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind=
|
|
5957
|
-
self.val_min=im_target[idx[0]]
|
|
5958
|
-
self.val_max=im_target[idx[-1]]
|
|
5675
|
+
self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
|
|
5676
|
+
self.val_min = im_target[idx[0]]
|
|
5677
|
+
self.val_max = im_target[idx[-1]]
|
|
5959
5678
|
return im_target.reshape(x.shape)
|
|
5960
5679
|
|
|
5680
|
+
def from_gaussian(self, x):
|
|
5961
5681
|
|
|
5962
|
-
|
|
5963
|
-
|
|
5964
|
-
x=self.backend.bk_clip_by_value(x,self.val_min,self.val_max)
|
|
5682
|
+
x = self.backend.bk_clip_by_value(x, self.val_min, self.val_max)
|
|
5965
5683
|
return self.f_gaussian(self.backend.to_numpy(x))
|
|
5966
|
-
|
|
5684
|
+
|
|
5967
5685
|
def square(self, x):
|
|
5968
5686
|
if isinstance(x, scat_cov):
|
|
5969
5687
|
if x.S1 is None:
|
|
@@ -6315,89 +6033,133 @@ class funct(FOC.FoCUS):
|
|
|
6315
6033
|
s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
|
|
6316
6034
|
)
|
|
6317
6035
|
|
|
6318
|
-
def synthesis(
|
|
6319
|
-
|
|
6320
|
-
|
|
6321
|
-
|
|
6322
|
-
|
|
6323
|
-
|
|
6324
|
-
|
|
6325
|
-
|
|
6326
|
-
|
|
6327
|
-
|
|
6328
|
-
|
|
6329
|
-
|
|
6330
|
-
|
|
6331
|
-
|
|
6332
|
-
|
|
6036
|
+
def synthesis(
|
|
6037
|
+
self,
|
|
6038
|
+
image_target,
|
|
6039
|
+
nstep=4,
|
|
6040
|
+
seed=1234,
|
|
6041
|
+
Jmax=None,
|
|
6042
|
+
edge=False,
|
|
6043
|
+
to_gaussian=True,
|
|
6044
|
+
use_variance=False,
|
|
6045
|
+
synthesised_N=1,
|
|
6046
|
+
input_image=None,
|
|
6047
|
+
grd_mask=None,
|
|
6048
|
+
iso_ang=False,
|
|
6049
|
+
EVAL_FREQUENCY=100,
|
|
6050
|
+
NUM_EPOCHS=300,
|
|
6051
|
+
):
|
|
6052
|
+
|
|
6333
6053
|
import time
|
|
6334
6054
|
|
|
6335
|
-
|
|
6336
|
-
|
|
6055
|
+
import foscat.Synthesis as synthe
|
|
6056
|
+
|
|
6057
|
+
def The_loss(u, scat_operator, args):
|
|
6058
|
+
ref = args[0]
|
|
6337
6059
|
sref = args[1]
|
|
6338
|
-
use_v= args[2]
|
|
6339
|
-
|
|
6060
|
+
use_v = args[2]
|
|
6061
|
+
|
|
6340
6062
|
# compute scattering covariance of the current synthetised map called u
|
|
6341
6063
|
if use_v:
|
|
6342
|
-
learn=scat_operator.reduce_mean_batch(
|
|
6064
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6065
|
+
scat_operator.scattering_cov(
|
|
6066
|
+
u,
|
|
6067
|
+
edge=edge,
|
|
6068
|
+
Jmax=Jmax,
|
|
6069
|
+
ref_sigma=sref,
|
|
6070
|
+
use_ref=True,
|
|
6071
|
+
iso_ang=iso_ang,
|
|
6072
|
+
)
|
|
6073
|
+
)
|
|
6343
6074
|
else:
|
|
6344
|
-
learn=scat_operator.reduce_mean_batch(
|
|
6345
|
-
|
|
6075
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6076
|
+
scat_operator.scattering_cov(
|
|
6077
|
+
u, edge=edge, Jmax=Jmax, use_ref=True, iso_ang=iso_ang
|
|
6078
|
+
)
|
|
6079
|
+
)
|
|
6080
|
+
|
|
6346
6081
|
# make the difference withe the reference coordinates
|
|
6347
|
-
loss=scat_operator.backend.bk_reduce_mean(
|
|
6082
|
+
loss = scat_operator.backend.bk_reduce_mean(
|
|
6083
|
+
scat_operator.backend.bk_square(learn - ref)
|
|
6084
|
+
)
|
|
6348
6085
|
return loss
|
|
6349
6086
|
|
|
6350
6087
|
if to_gaussian:
|
|
6351
6088
|
# Change the data histogram to gaussian distribution
|
|
6352
|
-
im_target=self.to_gaussian(image_target)
|
|
6089
|
+
im_target = self.to_gaussian(image_target)
|
|
6353
6090
|
else:
|
|
6354
|
-
im_target=image_target
|
|
6355
|
-
|
|
6356
|
-
axis=len(im_target.shape)-1
|
|
6091
|
+
im_target = image_target
|
|
6092
|
+
|
|
6093
|
+
axis = len(im_target.shape) - 1
|
|
6357
6094
|
if self.use_2D:
|
|
6358
|
-
axis-=1
|
|
6359
|
-
if axis==0:
|
|
6360
|
-
im_target=self.backend.bk_expand_dims(im_target,0)
|
|
6095
|
+
axis -= 1
|
|
6096
|
+
if axis == 0:
|
|
6097
|
+
im_target = self.backend.bk_expand_dims(im_target, 0)
|
|
6361
6098
|
|
|
6362
6099
|
# compute the number of possible steps
|
|
6363
6100
|
if self.use_2D:
|
|
6364
|
-
jmax=int(
|
|
6101
|
+
jmax = int(
|
|
6102
|
+
np.min([np.log(im_target.shape[1]), np.log(im_target.shape[2])])
|
|
6103
|
+
/ np.log(2)
|
|
6104
|
+
)
|
|
6365
6105
|
elif self.use_1D:
|
|
6366
|
-
jmax=int(np.log(im_target.shape[1])/np.log(2))
|
|
6106
|
+
jmax = int(np.log(im_target.shape[1]) / np.log(2))
|
|
6367
6107
|
else:
|
|
6368
|
-
jmax=int((np.log(im_target.shape[1]//12)/np.log(2))/2)
|
|
6369
|
-
nside=2**jmax
|
|
6108
|
+
jmax = int((np.log(im_target.shape[1] // 12) / np.log(2)) / 2)
|
|
6109
|
+
nside = 2**jmax
|
|
6370
6110
|
|
|
6371
|
-
if nstep>jmax-1:
|
|
6372
|
-
nstep=jmax-1
|
|
6111
|
+
if nstep > jmax - 1:
|
|
6112
|
+
nstep = jmax - 1
|
|
6373
6113
|
|
|
6374
|
-
t1=time.time()
|
|
6375
|
-
tmp={}
|
|
6376
|
-
|
|
6377
|
-
|
|
6378
|
-
|
|
6114
|
+
t1 = time.time()
|
|
6115
|
+
tmp = {}
|
|
6116
|
+
|
|
6117
|
+
l_grd_mask={}
|
|
6118
|
+
|
|
6119
|
+
tmp[nstep - 1] = self.backend.bk_cast(im_target)
|
|
6120
|
+
if grd_mask is not None:
|
|
6121
|
+
l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
|
|
6122
|
+
else:
|
|
6123
|
+
l_grd_mask[nstep - 1] = None
|
|
6379
6124
|
|
|
6380
|
-
|
|
6381
|
-
|
|
6125
|
+
for ell in range(nstep - 2, -1, -1):
|
|
6126
|
+
tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
|
|
6127
|
+
if grd_mask is not None:
|
|
6128
|
+
l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
|
|
6129
|
+
else:
|
|
6130
|
+
l_grd_mask[ell] = None
|
|
6382
6131
|
|
|
6132
|
+
|
|
6133
|
+
if not self.use_2D and not self.use_1D:
|
|
6134
|
+
l_nside = nside // (2 ** (nstep - 1))
|
|
6135
|
+
|
|
6383
6136
|
for k in range(nstep):
|
|
6384
|
-
if k==0:
|
|
6137
|
+
if k == 0:
|
|
6385
6138
|
if input_image is None:
|
|
6386
6139
|
np.random.seed(seed)
|
|
6387
6140
|
if self.use_2D:
|
|
6388
|
-
imap=np.random.randn(
|
|
6389
|
-
|
|
6390
|
-
|
|
6141
|
+
imap = self.backend.bk_cast(np.random.randn(
|
|
6142
|
+
synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
|
|
6143
|
+
))
|
|
6391
6144
|
else:
|
|
6392
|
-
imap=np.random.randn(synthesised_N,
|
|
6393
|
-
tmp[k].shape[1])
|
|
6145
|
+
imap = self.backend.bk_cast(np.random.randn(synthesised_N, tmp[k].shape[1]))
|
|
6394
6146
|
else:
|
|
6395
6147
|
if self.use_2D:
|
|
6396
|
-
imap=self.backend.bk_reshape(
|
|
6397
|
-
|
|
6148
|
+
imap = self.backend.bk_reshape(
|
|
6149
|
+
self.backend.bk_tile(
|
|
6150
|
+
self.backend.bk_cast(input_image.flatten()),
|
|
6151
|
+
synthesised_N,
|
|
6152
|
+
),
|
|
6153
|
+
[synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
|
|
6154
|
+
)
|
|
6398
6155
|
else:
|
|
6399
|
-
imap=self.backend.bk_reshape(
|
|
6400
|
-
|
|
6156
|
+
imap = self.backend.bk_reshape(
|
|
6157
|
+
self.backend.bk_tile(
|
|
6158
|
+
self.backend.bk_cast(input_image.flatten()),
|
|
6159
|
+
synthesised_N,
|
|
6160
|
+
),
|
|
6161
|
+
[synthesised_N, tmp[k].shape[1]],
|
|
6162
|
+
)
|
|
6401
6163
|
else:
|
|
6402
6164
|
# Increase the resolution between each step
|
|
6403
6165
|
if self.use_2D:
|
|
@@ -6408,48 +6170,49 @@ class funct(FOC.FoCUS):
|
|
|
6408
6170
|
imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
|
|
6409
6171
|
else:
|
|
6410
6172
|
imap = self.up_grade(omap, l_nside, axis=1)
|
|
6411
|
-
|
|
6173
|
+
|
|
6174
|
+
if grd_mask is not None:
|
|
6175
|
+
imap=imap*l_grd_mask[k]+tmp[k]*(1-l_grd_mask[k])
|
|
6176
|
+
|
|
6412
6177
|
# compute the coefficients for the target image
|
|
6413
6178
|
if use_variance:
|
|
6414
|
-
ref,sref=self.scattering_cov(
|
|
6179
|
+
ref, sref = self.scattering_cov(
|
|
6180
|
+
tmp[k], get_variance=True, edge=edge, Jmax=Jmax, iso_ang=iso_ang
|
|
6181
|
+
)
|
|
6415
6182
|
else:
|
|
6416
|
-
ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
|
|
6417
|
-
sref=ref
|
|
6418
|
-
|
|
6183
|
+
ref = self.scattering_cov(tmp[k], edge=edge, Jmax=Jmax, iso_ang=iso_ang)
|
|
6184
|
+
sref = ref
|
|
6185
|
+
|
|
6419
6186
|
# compute the mean of the population does nothing if only one map is given
|
|
6420
|
-
ref=self.reduce_mean_batch(ref)
|
|
6421
|
-
|
|
6187
|
+
ref = self.reduce_mean_batch(ref)
|
|
6188
|
+
|
|
6422
6189
|
# define a loss to minimize
|
|
6423
|
-
loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
|
|
6424
|
-
|
|
6190
|
+
loss = synthe.Loss(The_loss, self, ref, sref, use_variance)
|
|
6191
|
+
|
|
6425
6192
|
sy = synthe.Synthesis([loss])
|
|
6426
6193
|
|
|
6427
6194
|
# initialize the synthesised map
|
|
6428
6195
|
if self.use_2D:
|
|
6429
|
-
print(
|
|
6196
|
+
print("Synthesis scale [ %d x %d ]" % (imap.shape[1], imap.shape[2]))
|
|
6430
6197
|
elif self.use_1D:
|
|
6431
|
-
print(
|
|
6198
|
+
print("Synthesis scale [ %d ]" % (imap.shape[1]))
|
|
6432
6199
|
else:
|
|
6433
|
-
print(
|
|
6434
|
-
l_nside*=2
|
|
6435
|
-
|
|
6200
|
+
print("Synthesis scale nside=%d" % (l_nside))
|
|
6201
|
+
l_nside *= 2
|
|
6202
|
+
|
|
6436
6203
|
# do the minimization
|
|
6437
|
-
omap=sy.run(imap,
|
|
6438
|
-
EVAL_FREQUENCY=EVAL_FREQUENCY,
|
|
6439
|
-
NUM_EPOCHS = NUM_EPOCHS)
|
|
6440
|
-
|
|
6441
|
-
|
|
6204
|
+
omap = sy.run(imap, EVAL_FREQUENCY=EVAL_FREQUENCY, NUM_EPOCHS=NUM_EPOCHS,grd_mask=l_grd_mask[k])
|
|
6442
6205
|
|
|
6443
|
-
t2=time.time()
|
|
6444
|
-
print(
|
|
6206
|
+
t2 = time.time()
|
|
6207
|
+
print("Total computation %.2fs" % (t2 - t1))
|
|
6445
6208
|
|
|
6446
6209
|
if to_gaussian:
|
|
6447
|
-
omap=self.from_gaussian(omap)
|
|
6210
|
+
omap = self.from_gaussian(omap)
|
|
6448
6211
|
|
|
6449
|
-
if axis==0 and synthesised_N==1:
|
|
6212
|
+
if axis == 0 and synthesised_N == 1:
|
|
6450
6213
|
return omap[0]
|
|
6451
6214
|
else:
|
|
6452
6215
|
return omap
|
|
6453
|
-
|
|
6454
|
-
def to_numpy(self,x):
|
|
6216
|
+
|
|
6217
|
+
def to_numpy(self, x):
|
|
6455
6218
|
return self.backend.to_numpy(x)
|