foscat 3.8.0__py3-none-any.whl → 3.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkBase.py +39 -29
- foscat/BkNumpy.py +56 -47
- foscat/BkTensorflow.py +89 -81
- foscat/BkTorch.py +95 -59
- foscat/FoCUS.py +72 -56
- foscat/Synthesis.py +3 -3
- foscat/alm.py +194 -177
- foscat/backend.py +84 -70
- foscat/scat_cov.py +1876 -2100
- foscat/scat_cov2D.py +146 -53
- {foscat-3.8.0.dist-info → foscat-3.9.0.dist-info}/METADATA +1 -1
- {foscat-3.8.0.dist-info → foscat-3.9.0.dist-info}/RECORD +15 -15
- {foscat-3.8.0.dist-info → foscat-3.9.0.dist-info}/WHEEL +1 -1
- {foscat-3.8.0.dist-info → foscat-3.9.0.dist-info}/LICENSE +0 -0
- {foscat-3.8.0.dist-info → foscat-3.9.0.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -92,7 +92,12 @@ class scat_cov:
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
def conv2complex(self, val):
|
|
95
|
-
if
|
|
95
|
+
if (
|
|
96
|
+
val.dtype == "complex64"
|
|
97
|
+
or val.dtype == "complex128"
|
|
98
|
+
or val.dtype == "torch.complex64"
|
|
99
|
+
or val.dtype == "torch.complex128"
|
|
100
|
+
):
|
|
96
101
|
return val
|
|
97
102
|
else:
|
|
98
103
|
return self.backend.bk_complex(val, 0 * val)
|
|
@@ -2043,9 +2048,11 @@ class scat_cov:
|
|
|
2043
2048
|
)
|
|
2044
2049
|
)
|
|
2045
2050
|
else:
|
|
2046
|
-
s3[i, j, idx[noff:], k, l_orient] =
|
|
2047
|
-
|
|
2048
|
-
|
|
2051
|
+
s3[i, j, idx[noff:], k, l_orient] = (
|
|
2052
|
+
self.backend.to_numpy(self.S3)[
|
|
2053
|
+
i, j, j2 == ij - noff, k, l_orient
|
|
2054
|
+
]
|
|
2055
|
+
)
|
|
2049
2056
|
s3[i, j, idx[:noff], k, l_orient] = (
|
|
2050
2057
|
self.add_data_from_slope(
|
|
2051
2058
|
self.backend.to_numpy(self.S3)[
|
|
@@ -2208,7 +2215,7 @@ class scat_cov:
|
|
|
2208
2215
|
|
|
2209
2216
|
|
|
2210
2217
|
class funct(FOC.FoCUS):
|
|
2211
|
-
|
|
2218
|
+
|
|
2212
2219
|
def fill(self, im, nullval=hp.UNSEEN):
|
|
2213
2220
|
if self.use_2D:
|
|
2214
2221
|
return self.fill_2d(im, nullval=nullval)
|
|
@@ -2277,14 +2284,14 @@ class funct(FOC.FoCUS):
|
|
|
2277
2284
|
if tmp.S1 is not None:
|
|
2278
2285
|
nS1 = np.expand_dims(tmp.S1, 0)
|
|
2279
2286
|
else:
|
|
2280
|
-
nS0 = np.expand_dims(tmp.S0
|
|
2281
|
-
nS2 = np.expand_dims(tmp.S2
|
|
2282
|
-
nS3 = np.expand_dims(tmp.S3
|
|
2283
|
-
nS4 = np.expand_dims(tmp.S4
|
|
2287
|
+
nS0 = np.expand_dims(self.backend.to_numpy(tmp.S0), 0)
|
|
2288
|
+
nS2 = np.expand_dims(self.backend.to_numpy(tmp.S2), 0)
|
|
2289
|
+
nS3 = np.expand_dims(self.backend.to_numpy(tmp.S3), 0)
|
|
2290
|
+
nS4 = np.expand_dims(self.backend.to_numpy(tmp.S4), 0)
|
|
2284
2291
|
if tmp.S3P is not None:
|
|
2285
|
-
nS3P = np.expand_dims(tmp.S3P
|
|
2292
|
+
nS3P = np.expand_dims(self.backend.to_numpy(tmp.S3P), 0)
|
|
2286
2293
|
if tmp.S1 is not None:
|
|
2287
|
-
nS1 = np.expand_dims(tmp.S1
|
|
2294
|
+
nS1 = np.expand_dims(self.backend.to_numpy(tmp.S1), 0)
|
|
2288
2295
|
|
|
2289
2296
|
if S0 is None:
|
|
2290
2297
|
S0 = nS0
|
|
@@ -2304,24 +2311,24 @@ class funct(FOC.FoCUS):
|
|
|
2304
2311
|
S3P = np.concatenate([S3P, nS3P], 0)
|
|
2305
2312
|
if tmp.S1 is not None:
|
|
2306
2313
|
S1 = np.concatenate([S1, nS1], 0)
|
|
2307
|
-
sS0 = np.std(S0, 0)
|
|
2308
|
-
sS2 = np.std(S2, 0)
|
|
2309
|
-
sS3 = np.std(S3, 0)
|
|
2310
|
-
sS4 = np.std(S4, 0)
|
|
2311
|
-
mS0 = np.mean(S0, 0)
|
|
2312
|
-
mS2 = np.mean(S2, 0)
|
|
2313
|
-
mS3 = np.mean(S3, 0)
|
|
2314
|
-
mS4 = np.mean(S4, 0)
|
|
2314
|
+
sS0 = self.backend.bk_cast(np.std(S0, 0))
|
|
2315
|
+
sS2 = self.backend.bk_cast(np.std(S2, 0))
|
|
2316
|
+
sS3 = self.backend.bk_cast(np.std(S3, 0))
|
|
2317
|
+
sS4 = self.backend.bk_cast(np.std(S4, 0))
|
|
2318
|
+
mS0 = self.backend.bk_cast(np.mean(S0, 0))
|
|
2319
|
+
mS2 = self.backend.bk_cast(np.mean(S2, 0))
|
|
2320
|
+
mS3 = self.backend.bk_cast(np.mean(S3, 0))
|
|
2321
|
+
mS4 = self.backend.bk_cast(np.mean(S4, 0))
|
|
2315
2322
|
if tmp.S3P is not None:
|
|
2316
|
-
sS3P = np.std(S3P, 0)
|
|
2317
|
-
mS3P = np.mean(S3P, 0)
|
|
2323
|
+
sS3P = self.backend.bk_cast(np.std(S3P, 0))
|
|
2324
|
+
mS3P = self.backend.bk_cast(np.mean(S3P, 0))
|
|
2318
2325
|
else:
|
|
2319
2326
|
sS3P = None
|
|
2320
2327
|
mS3P = None
|
|
2321
2328
|
|
|
2322
2329
|
if tmp.S1 is not None:
|
|
2323
|
-
sS1 = np.std(S1, 0)
|
|
2324
|
-
mS1 = np.mean(S1, 0)
|
|
2330
|
+
sS1 = self.backend.bk_cast(np.std(S1, 0))
|
|
2331
|
+
mS1 = self.backend.bk_cast(np.mean(S1, 0))
|
|
2325
2332
|
else:
|
|
2326
2333
|
sS1 = None
|
|
2327
2334
|
mS1 = None
|
|
@@ -2375,14 +2382,19 @@ class funct(FOC.FoCUS):
|
|
|
2375
2382
|
|
|
2376
2383
|
# instead of difference between "opposite" channels use weighted average
|
|
2377
2384
|
# of cosine and sine contributions using all channels
|
|
2378
|
-
angles = (
|
|
2379
|
-
2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
|
|
2385
|
+
angles = self.backend.bk_cast(
|
|
2386
|
+
(2 * np.pi * np.arange(self.NORIENT) / self.NORIENT).reshape(
|
|
2387
|
+
1, 1, self.NORIENT
|
|
2388
|
+
)
|
|
2380
2389
|
) # shape: (NORIENT,)
|
|
2381
|
-
angles = angles.reshape(1, 1, self.NORIENT)
|
|
2382
2390
|
|
|
2383
2391
|
# we use cosines and sines as weights for sim
|
|
2384
|
-
weighted_cos = self.backend.bk_reduce_mean(
|
|
2385
|
-
|
|
2392
|
+
weighted_cos = self.backend.bk_reduce_mean(
|
|
2393
|
+
sim * self.backend.bk_cos(angles), axis=-1
|
|
2394
|
+
)
|
|
2395
|
+
weighted_sin = self.backend.bk_reduce_mean(
|
|
2396
|
+
sim * self.backend.bk_sin(angles), axis=-1
|
|
2397
|
+
)
|
|
2386
2398
|
# For simplicity, take first element of the batch
|
|
2387
2399
|
cc = weighted_cos[0]
|
|
2388
2400
|
ss = weighted_sin[0]
|
|
@@ -2402,7 +2414,9 @@ class funct(FOC.FoCUS):
|
|
|
2402
2414
|
phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
|
|
2403
2415
|
else:
|
|
2404
2416
|
phase = np.fmod(
|
|
2405
|
-
np.arctan2(
|
|
2417
|
+
np.arctan2(self.backend.to_numpy(ss), self.backend.to_numpy(cc))
|
|
2418
|
+
+ 2 * np.pi,
|
|
2419
|
+
2 * np.pi,
|
|
2406
2420
|
)
|
|
2407
2421
|
|
|
2408
2422
|
# instead of linear interpolation cosine‐based interpolation
|
|
@@ -2416,10 +2430,10 @@ class funct(FOC.FoCUS):
|
|
|
2416
2430
|
# build rotation matrix
|
|
2417
2431
|
mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
|
|
2418
2432
|
lidx = np.arange(sim.shape[1])
|
|
2419
|
-
for
|
|
2433
|
+
for ell in range(self.NORIENT):
|
|
2420
2434
|
# Instead of simple linear weights, we use the cosine weights w0 and w1.
|
|
2421
|
-
col0 = self.NORIENT * ((
|
|
2422
|
-
col1 = self.NORIENT * ((
|
|
2435
|
+
col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
|
|
2436
|
+
col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
|
|
2423
2437
|
mat[lidx, col0] = w0
|
|
2424
2438
|
mat[lidx, col1] = w1
|
|
2425
2439
|
|
|
@@ -2435,7 +2449,9 @@ class funct(FOC.FoCUS):
|
|
|
2435
2449
|
|
|
2436
2450
|
sim2 = self.backend.bk_reduce_sum(
|
|
2437
2451
|
self.backend.bk_reshape(
|
|
2438
|
-
|
|
2452
|
+
self.backend.bk_cast(
|
|
2453
|
+
mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
|
|
2454
|
+
)
|
|
2439
2455
|
* tmp2,
|
|
2440
2456
|
[sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
|
|
2441
2457
|
),
|
|
@@ -2445,10 +2461,10 @@ class funct(FOC.FoCUS):
|
|
|
2445
2461
|
sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
|
|
2446
2462
|
|
|
2447
2463
|
weighted_cos2 = self.backend.bk_reduce_mean(
|
|
2448
|
-
sim2 *
|
|
2464
|
+
sim2 * self.backend.bk_cos(angles), axis=-1
|
|
2449
2465
|
)
|
|
2450
2466
|
weighted_sin2 = self.backend.bk_reduce_mean(
|
|
2451
|
-
sim2 *
|
|
2467
|
+
sim2 * self.backend.bk_sin(angles), axis=-1
|
|
2452
2468
|
)
|
|
2453
2469
|
|
|
2454
2470
|
cc2 = weighted_cos2[0]
|
|
@@ -2469,7 +2485,11 @@ class funct(FOC.FoCUS):
|
|
|
2469
2485
|
phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
|
|
2470
2486
|
else:
|
|
2471
2487
|
phase2 = np.fmod(
|
|
2472
|
-
np.arctan2(
|
|
2488
|
+
np.arctan2(
|
|
2489
|
+
self.backend.to_numpy(ss2), self.backend.to_numpy(cc2)
|
|
2490
|
+
)
|
|
2491
|
+
+ 2 * np.pi,
|
|
2492
|
+
2 * np.pi,
|
|
2473
2493
|
)
|
|
2474
2494
|
|
|
2475
2495
|
phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
|
|
@@ -2480,9 +2500,11 @@ class funct(FOC.FoCUS):
|
|
|
2480
2500
|
lidx = np.arange(sim.shape[1])
|
|
2481
2501
|
|
|
2482
2502
|
for m in range(self.NORIENT):
|
|
2483
|
-
for
|
|
2484
|
-
col0 = self.NORIENT * ((
|
|
2485
|
-
col1 =
|
|
2503
|
+
for ell in range(self.NORIENT):
|
|
2504
|
+
col0 = self.NORIENT * ((ell + iph2[:, m]) % self.NORIENT) + ell
|
|
2505
|
+
col1 = (
|
|
2506
|
+
self.NORIENT * ((ell + iph2[:, m] + 1) % self.NORIENT) + ell
|
|
2507
|
+
)
|
|
2486
2508
|
mat2[k2, lidx, m, col0] = w0_2[:, m]
|
|
2487
2509
|
mat2[k2, lidx, m, col1] = w1_2[:, m]
|
|
2488
2510
|
cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
|
|
@@ -2500,17 +2522,17 @@ class funct(FOC.FoCUS):
|
|
|
2500
2522
|
)
|
|
2501
2523
|
|
|
2502
2524
|
def eval(
|
|
2503
|
-
|
|
2504
|
-
|
|
2505
|
-
|
|
2506
|
-
|
|
2507
|
-
|
|
2508
|
-
|
|
2509
|
-
|
|
2510
|
-
|
|
2511
|
-
|
|
2512
|
-
|
|
2513
|
-
|
|
2525
|
+
self,
|
|
2526
|
+
image1,
|
|
2527
|
+
image2=None,
|
|
2528
|
+
mask=None,
|
|
2529
|
+
norm=None,
|
|
2530
|
+
calc_var=False,
|
|
2531
|
+
cmat=None,
|
|
2532
|
+
cmat2=None,
|
|
2533
|
+
Jmax=None,
|
|
2534
|
+
out_nside=None,
|
|
2535
|
+
edge=True,
|
|
2514
2536
|
):
|
|
2515
2537
|
"""
|
|
2516
2538
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
@@ -2540,9 +2562,9 @@ class funct(FOC.FoCUS):
|
|
|
2540
2562
|
-------
|
|
2541
2563
|
S1, S2, S3, S4 normalized
|
|
2542
2564
|
"""
|
|
2543
|
-
|
|
2565
|
+
|
|
2544
2566
|
return_data = self.return_data
|
|
2545
|
-
|
|
2567
|
+
|
|
2546
2568
|
# Check input consistency
|
|
2547
2569
|
if image2 is not None:
|
|
2548
2570
|
if list(image1.shape) != list(image2.shape):
|
|
@@ -2552,7 +2574,10 @@ class funct(FOC.FoCUS):
|
|
|
2552
2574
|
return None
|
|
2553
2575
|
if mask is not None:
|
|
2554
2576
|
if self.use_2D:
|
|
2555
|
-
if
|
|
2577
|
+
if (
|
|
2578
|
+
image1.shape[-2] != mask.shape[1]
|
|
2579
|
+
or image1.shape[-1] != mask.shape[2]
|
|
2580
|
+
):
|
|
2556
2581
|
print(
|
|
2557
2582
|
"The LAST 2 COLUMNs of the mask should have the same size ",
|
|
2558
2583
|
mask.shape,
|
|
@@ -2562,7 +2587,7 @@ class funct(FOC.FoCUS):
|
|
|
2562
2587
|
)
|
|
2563
2588
|
return None
|
|
2564
2589
|
else:
|
|
2565
|
-
if image1.shape[-1] != mask.shape[1]:
|
|
2590
|
+
if image1.shape[-1] != mask.shape[1]:
|
|
2566
2591
|
print(
|
|
2567
2592
|
"The LAST COLUMN of the mask should have the same size ",
|
|
2568
2593
|
mask.shape,
|
|
@@ -2616,16 +2641,17 @@ class funct(FOC.FoCUS):
|
|
|
2616
2641
|
nside = int(np.sqrt(npix // 12))
|
|
2617
2642
|
|
|
2618
2643
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2619
|
-
|
|
2620
|
-
if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
|
|
2621
|
-
J-=1
|
|
2644
|
+
|
|
2645
|
+
if (self.use_2D or self.use_1D) and self.KERNELSZ > 3:
|
|
2646
|
+
J -= 1
|
|
2622
2647
|
if Jmax is None:
|
|
2623
2648
|
Jmax = J # Number of steps for the loop on scales
|
|
2624
|
-
if Jmax>J:
|
|
2625
|
-
print(
|
|
2626
|
-
print(
|
|
2627
|
-
|
|
2628
|
-
|
|
2649
|
+
if Jmax > J:
|
|
2650
|
+
print("==========\n\n")
|
|
2651
|
+
print(
|
|
2652
|
+
"The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
|
|
2653
|
+
)
|
|
2654
|
+
print("\n\n==========")
|
|
2629
2655
|
|
|
2630
2656
|
### LOCAL VARIABLES (IMAGES and MASK)
|
|
2631
2657
|
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
|
|
@@ -2768,32 +2794,37 @@ class funct(FOC.FoCUS):
|
|
|
2768
2794
|
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
2769
2795
|
|
|
2770
2796
|
# a remettre comme avant
|
|
2771
|
-
M1_dic={}
|
|
2772
|
-
M2_dic={}
|
|
2773
|
-
|
|
2797
|
+
M1_dic = {}
|
|
2798
|
+
M2_dic = {}
|
|
2799
|
+
|
|
2774
2800
|
for j3 in range(Jmax):
|
|
2775
|
-
|
|
2801
|
+
|
|
2776
2802
|
if edge:
|
|
2777
2803
|
if self.mask_mask is None:
|
|
2778
|
-
self.mask_mask={}
|
|
2804
|
+
self.mask_mask = {}
|
|
2779
2805
|
if self.use_2D:
|
|
2780
|
-
if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
|
|
2781
|
-
mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
|
|
2782
|
-
mask_mask[
|
|
2783
|
-
|
|
2784
|
-
|
|
2785
|
-
|
|
2786
|
-
|
|
2787
|
-
|
|
2788
|
-
|
|
2806
|
+
if (vmask.shape[1], vmask.shape[2]) not in self.mask_mask:
|
|
2807
|
+
mask_mask = np.zeros([1, vmask.shape[1], vmask.shape[2]])
|
|
2808
|
+
mask_mask[
|
|
2809
|
+
0,
|
|
2810
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
2811
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
|
|
2812
|
+
] = 1.0
|
|
2813
|
+
self.mask_mask[(vmask.shape[1], vmask.shape[2])] = (
|
|
2814
|
+
self.backend.bk_cast(mask_mask)
|
|
2815
|
+
)
|
|
2816
|
+
vmask = vmask * self.mask_mask[(vmask.shape[1], vmask.shape[2])]
|
|
2817
|
+
# print(self.KERNELSZ//2,vmask,mask_mask)
|
|
2818
|
+
|
|
2789
2819
|
if self.use_1D:
|
|
2790
2820
|
if (vmask.shape[1]) not in self.mask_mask:
|
|
2791
|
-
mask_mask=np.zeros([1,vmask.shape[1]])
|
|
2792
|
-
mask_mask[0,
|
|
2793
|
-
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2821
|
+
mask_mask = np.zeros([1, vmask.shape[1]])
|
|
2822
|
+
mask_mask[0, self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1] = 1.0
|
|
2823
|
+
self.mask_mask[(vmask.shape[1])] = self.backend.bk_cast(
|
|
2824
|
+
mask_mask
|
|
2825
|
+
)
|
|
2826
|
+
vmask = vmask * self.mask_mask[(vmask.shape[1])]
|
|
2827
|
+
|
|
2797
2828
|
if return_data:
|
|
2798
2829
|
S3[j3] = None
|
|
2799
2830
|
S3P[j3] = None
|
|
@@ -3447,9 +3478,9 @@ class funct(FOC.FoCUS):
|
|
|
3447
3478
|
M2_dic[j2], axis=1
|
|
3448
3479
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3449
3480
|
M2_dic[j2] = self.ud_grade_2(
|
|
3450
|
-
|
|
3481
|
+
M2_smooth, axis=1
|
|
3451
3482
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3452
|
-
|
|
3483
|
+
|
|
3453
3484
|
### Mask
|
|
3454
3485
|
vmask = self.ud_grade_2(vmask, axis=1)
|
|
3455
3486
|
|
|
@@ -3541,1278 +3572,232 @@ class funct(FOC.FoCUS):
|
|
|
3541
3572
|
use_1D=self.use_1D,
|
|
3542
3573
|
)
|
|
3543
3574
|
|
|
3544
|
-
def
|
|
3545
|
-
|
|
3546
|
-
|
|
3547
|
-
|
|
3548
|
-
|
|
3549
|
-
|
|
3550
|
-
|
|
3551
|
-
|
|
3552
|
-
|
|
3553
|
-
|
|
3554
|
-
|
|
3555
|
-
|
|
3575
|
+
def clean_norm(self):
|
|
3576
|
+
self.P1_dic = None
|
|
3577
|
+
self.P2_dic = None
|
|
3578
|
+
return
|
|
3579
|
+
|
|
3580
|
+
def _compute_S3(
|
|
3581
|
+
self,
|
|
3582
|
+
j2,
|
|
3583
|
+
j3,
|
|
3584
|
+
conv,
|
|
3585
|
+
vmask,
|
|
3586
|
+
M_dic,
|
|
3587
|
+
MconvPsi_dic,
|
|
3588
|
+
calc_var=False,
|
|
3589
|
+
return_data=False,
|
|
3590
|
+
cmat2=None,
|
|
3556
3591
|
):
|
|
3557
3592
|
"""
|
|
3558
|
-
|
|
3559
|
-
|
|
3560
|
-
S1 = <|I * Psi_j3|>
|
|
3561
|
-
Normalization : take the log
|
|
3562
|
-
power spectrum:
|
|
3563
|
-
S2 = <|I * Psi_j3|^2>
|
|
3564
|
-
Normalization : take the log
|
|
3565
|
-
orig. x modulus:
|
|
3566
|
-
S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
|
|
3567
|
-
Normalization : divide by (S2_j2 * S2_j3)^0.5
|
|
3568
|
-
modulus x modulus:
|
|
3569
|
-
S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
|
|
3570
|
-
Normalization : divide by (S2_j1 * S2_j2)^0.5
|
|
3593
|
+
Compute the S3 coefficients (auto or cross)
|
|
3594
|
+
S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
|
|
3571
3595
|
Parameters
|
|
3572
3596
|
----------
|
|
3573
|
-
image1: tensor
|
|
3574
|
-
Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
|
|
3575
|
-
image2: tensor
|
|
3576
|
-
Second image. If not None, we compute cross-scattering covariance coefficients.
|
|
3577
|
-
mask:
|
|
3578
|
-
norm: None or str
|
|
3579
|
-
If None no normalization is applied, if 'auto' normalize by the reference S2,
|
|
3580
|
-
if 'self' normalize by the current S2.
|
|
3581
3597
|
Returns
|
|
3582
3598
|
-------
|
|
3583
|
-
|
|
3599
|
+
cs3, ss3: real and imag parts of S3 coeff
|
|
3584
3600
|
"""
|
|
3585
|
-
|
|
3586
|
-
|
|
3587
|
-
|
|
3588
|
-
|
|
3589
|
-
|
|
3590
|
-
|
|
3591
|
-
|
|
3592
|
-
|
|
3593
|
-
|
|
3594
|
-
|
|
3595
|
-
|
|
3596
|
-
|
|
3597
|
-
|
|
3598
|
-
|
|
3599
|
-
|
|
3600
|
-
|
|
3601
|
-
|
|
3602
|
-
)
|
|
3603
|
-
|
|
3604
|
-
if self.use_2D and len(image1.shape) < 2:
|
|
3605
|
-
print(
|
|
3606
|
-
"To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
|
|
3601
|
+
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
3602
|
+
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
3603
|
+
MconvPsi = self.convol(
|
|
3604
|
+
M_dic[j2], axis=1
|
|
3605
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3606
|
+
if cmat2 is not None:
|
|
3607
|
+
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
|
|
3608
|
+
MconvPsi = self.backend.bk_reduce_sum(
|
|
3609
|
+
self.backend.bk_reshape(
|
|
3610
|
+
cmat2[j3][j2] * tmp2,
|
|
3611
|
+
[
|
|
3612
|
+
tmp2.shape[0],
|
|
3613
|
+
cmat2[j3].shape[1],
|
|
3614
|
+
self.NORIENT,
|
|
3615
|
+
self.NORIENT,
|
|
3616
|
+
self.NORIENT,
|
|
3617
|
+
],
|
|
3618
|
+
),
|
|
3619
|
+
3,
|
|
3607
3620
|
)
|
|
3608
|
-
return None
|
|
3609
|
-
|
|
3610
|
-
### AUTO OR CROSS
|
|
3611
|
-
cross = False
|
|
3612
|
-
if image2 is not None:
|
|
3613
|
-
cross = True
|
|
3614
3621
|
|
|
3615
|
-
|
|
3616
|
-
|
|
3617
|
-
# determine jmax and nside corresponding to the input map
|
|
3618
|
-
im_shape = image1.shape
|
|
3619
|
-
if self.use_2D:
|
|
3620
|
-
if len(image1.shape) == 2:
|
|
3621
|
-
nside = np.min([im_shape[0], im_shape[1]])
|
|
3622
|
-
npix = im_shape[0] * im_shape[1] # Number of pixels
|
|
3623
|
-
x1 = im_shape[0]
|
|
3624
|
-
x2 = im_shape[1]
|
|
3625
|
-
else:
|
|
3626
|
-
nside = np.min([im_shape[1], im_shape[2]])
|
|
3627
|
-
npix = im_shape[1] * im_shape[2] # Number of pixels
|
|
3628
|
-
x1 = im_shape[1]
|
|
3629
|
-
x2 = im_shape[2]
|
|
3630
|
-
J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
|
|
3631
|
-
elif self.use_1D:
|
|
3632
|
-
if len(image1.shape) == 2:
|
|
3633
|
-
npix = int(im_shape[1]) # Number of pixels
|
|
3634
|
-
else:
|
|
3635
|
-
npix = int(im_shape[0]) # Number of pixels
|
|
3622
|
+
# Store it so we can use it in S4 computation
|
|
3623
|
+
MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3636
3624
|
|
|
3637
|
-
|
|
3625
|
+
### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
|
|
3626
|
+
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
3627
|
+
# cconv, sconv are [Nbatch, Npix_j3, Norient3]
|
|
3628
|
+
if self.use_1D:
|
|
3629
|
+
s3 = conv * self.backend.bk_conjugate(MconvPsi)
|
|
3630
|
+
else:
|
|
3631
|
+
s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
|
|
3632
|
+
MconvPsi
|
|
3633
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3638
3634
|
|
|
3639
|
-
|
|
3635
|
+
### Apply the mask [Nmask, Npix_j3] and sum over pixels
|
|
3636
|
+
if return_data:
|
|
3637
|
+
return s3
|
|
3640
3638
|
else:
|
|
3641
|
-
if
|
|
3642
|
-
|
|
3639
|
+
if calc_var:
|
|
3640
|
+
s3, vs3 = self.masked_mean(
|
|
3641
|
+
s3, vmask, axis=1, rank=j2, calc_var=True
|
|
3642
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3643
|
+
return s3, vs3
|
|
3643
3644
|
else:
|
|
3644
|
-
|
|
3645
|
+
s3 = self.masked_mean(
|
|
3646
|
+
s3, vmask, axis=1, rank=j2
|
|
3647
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3648
|
+
return s3
|
|
3645
3649
|
|
|
3646
|
-
|
|
3650
|
+
def _compute_S4(
|
|
3651
|
+
self,
|
|
3652
|
+
j1,
|
|
3653
|
+
j2,
|
|
3654
|
+
vmask,
|
|
3655
|
+
M1convPsi_dic,
|
|
3656
|
+
M2convPsi_dic=None,
|
|
3657
|
+
calc_var=False,
|
|
3658
|
+
return_data=False,
|
|
3659
|
+
):
|
|
3660
|
+
#### Simplify notations
|
|
3661
|
+
M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
|
|
3647
3662
|
|
|
3648
|
-
|
|
3649
|
-
|
|
3650
|
-
|
|
3651
|
-
|
|
3652
|
-
|
|
3653
|
-
Jmax = J # Number of steps for the loop on scales
|
|
3654
|
-
if Jmax>J:
|
|
3655
|
-
print('==========\n\n')
|
|
3656
|
-
print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
|
|
3657
|
-
print('\n\n==========')
|
|
3658
|
-
|
|
3663
|
+
# Auto or Cross coefficients
|
|
3664
|
+
if M2convPsi_dic is None: # Auto
|
|
3665
|
+
M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3666
|
+
else: # Cross
|
|
3667
|
+
M2 = M2convPsi_dic[j2]
|
|
3659
3668
|
|
|
3660
|
-
###
|
|
3661
|
-
|
|
3662
|
-
|
|
3663
|
-
|
|
3664
|
-
) # Local image1 [Nbatch, Npix]
|
|
3665
|
-
if cross:
|
|
3666
|
-
I2 = self.backend.bk_cast(
|
|
3667
|
-
self.backend.bk_expand_dims(image2, 0)
|
|
3668
|
-
) # Local image2 [Nbatch, Npix]
|
|
3669
|
+
### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
|
|
3670
|
+
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
3671
|
+
if self.use_1D:
|
|
3672
|
+
s4 = M1 * self.backend.bk_conjugate(M2)
|
|
3669
3673
|
else:
|
|
3670
|
-
|
|
3671
|
-
|
|
3672
|
-
|
|
3674
|
+
s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
|
|
3675
|
+
self.backend.bk_expand_dims(M2, -1)
|
|
3676
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
|
|
3673
3677
|
|
|
3674
|
-
|
|
3675
|
-
|
|
3676
|
-
|
|
3677
|
-
else:
|
|
3678
|
-
vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
|
|
3678
|
+
### Apply the mask and sum over pixels
|
|
3679
|
+
if return_data:
|
|
3680
|
+
return s4
|
|
3679
3681
|
else:
|
|
3680
|
-
|
|
3682
|
+
if calc_var:
|
|
3683
|
+
s4, vs4 = self.masked_mean(
|
|
3684
|
+
s4, vmask, axis=1, rank=j2, calc_var=True
|
|
3685
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3686
|
+
return s4, vs4
|
|
3687
|
+
else:
|
|
3688
|
+
s4 = self.masked_mean(
|
|
3689
|
+
s4, vmask, axis=1, rank=j2
|
|
3690
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3691
|
+
return s4
|
|
3681
3692
|
|
|
3682
|
-
|
|
3683
|
-
|
|
3684
|
-
|
|
3685
|
-
|
|
3686
|
-
|
|
3693
|
+
def computer_filter(self, M, N, J, L):
|
|
3694
|
+
"""
|
|
3695
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
3696
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
3697
|
+
"""
|
|
3698
|
+
|
|
3699
|
+
filter = np.zeros([J, L, M, N], dtype="complex64")
|
|
3700
|
+
|
|
3701
|
+
slant = 4.0 / L
|
|
3702
|
+
|
|
3703
|
+
for j in range(J):
|
|
3704
|
+
|
|
3705
|
+
for ell in range(L):
|
|
3706
|
+
|
|
3707
|
+
theta = (int(L - L / 2 - 1) - ell) * np.pi / L
|
|
3708
|
+
sigma = 0.8 * 2**j
|
|
3709
|
+
xi = 3.0 / 4.0 * np.pi / 2**j
|
|
3710
|
+
|
|
3711
|
+
R = np.array(
|
|
3712
|
+
[[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
|
|
3713
|
+
np.float64,
|
|
3687
3714
|
)
|
|
3688
|
-
|
|
3689
|
-
|
|
3715
|
+
R_inv = np.array(
|
|
3716
|
+
[[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
|
|
3717
|
+
np.float64,
|
|
3690
3718
|
)
|
|
3691
|
-
if cross:
|
|
3692
|
-
I2 = self.up_grade(
|
|
3693
|
-
I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
|
|
3694
|
-
)
|
|
3695
|
-
elif self.use_1D:
|
|
3696
|
-
vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
|
|
3697
|
-
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
3698
|
-
if cross:
|
|
3699
|
-
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
3700
|
-
else:
|
|
3701
|
-
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
3702
|
-
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
3703
|
-
if cross:
|
|
3704
|
-
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
3705
|
-
|
|
3706
|
-
if self.KERNELSZ > 5 and not self.use_2D:
|
|
3707
|
-
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
3708
|
-
if self.use_2D:
|
|
3709
|
-
vmask = self.up_grade(
|
|
3710
|
-
vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
|
|
3711
|
-
)
|
|
3712
|
-
I1 = self.up_grade(
|
|
3713
|
-
I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
|
|
3714
|
-
)
|
|
3715
|
-
if cross:
|
|
3716
|
-
I2 = self.up_grade(
|
|
3717
|
-
I2,
|
|
3718
|
-
I2.shape[axis] * 2,
|
|
3719
|
-
axis=axis,
|
|
3720
|
-
nouty=I2.shape[axis + 1] * 2,
|
|
3721
|
-
)
|
|
3722
|
-
elif self.use_1D:
|
|
3723
|
-
vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
|
|
3724
|
-
I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
|
|
3725
|
-
if cross:
|
|
3726
|
-
I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
|
|
3727
|
-
else:
|
|
3728
|
-
I1 = self.up_grade(I1, nside * 4, axis=axis)
|
|
3729
|
-
vmask = self.up_grade(vmask, nside * 4, axis=1)
|
|
3730
|
-
if cross:
|
|
3731
|
-
I2 = self.up_grade(I2, nside * 4, axis=axis)
|
|
3732
|
-
|
|
3733
|
-
# Normalize the masks because they have different pixel numbers
|
|
3734
|
-
# vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
|
|
3735
|
-
|
|
3736
|
-
### INITIALIZATION
|
|
3737
|
-
# Coefficients
|
|
3738
|
-
if return_data:
|
|
3739
|
-
S1 = {}
|
|
3740
|
-
S2 = {}
|
|
3741
|
-
S3 = {}
|
|
3742
|
-
S3P = {}
|
|
3743
|
-
S4 = {}
|
|
3744
|
-
else:
|
|
3745
|
-
result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3746
|
-
dtype=self.backend.backend.float32,
|
|
3747
|
-
device=self.backend.torch_device)
|
|
3748
|
-
vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3749
|
-
dtype=self.backend.backend.float32,
|
|
3750
|
-
device=self.backend.torch_device)
|
|
3751
|
-
S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3752
|
-
S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3753
|
-
S3 = []
|
|
3754
|
-
S4 = []
|
|
3755
|
-
S3P = []
|
|
3756
|
-
VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3757
|
-
VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3758
|
-
VS3 = []
|
|
3759
|
-
VS3P = []
|
|
3760
|
-
VS4 = []
|
|
3761
|
-
|
|
3762
|
-
off_S2 = -2
|
|
3763
|
-
off_S3 = -3
|
|
3764
|
-
off_S4 = -4
|
|
3765
|
-
if self.use_1D:
|
|
3766
|
-
off_S2 = -1
|
|
3767
|
-
off_S3 = -1
|
|
3768
|
-
off_S4 = -1
|
|
3769
|
-
|
|
3770
|
-
# S2 for normalization
|
|
3771
|
-
cond_init_P1_dic = (norm == "self") or (
|
|
3772
|
-
(norm == "auto") and (self.P1_dic is None)
|
|
3773
|
-
)
|
|
3774
|
-
if norm is None:
|
|
3775
|
-
pass
|
|
3776
|
-
elif cond_init_P1_dic:
|
|
3777
|
-
P1_dic = {}
|
|
3778
|
-
if cross:
|
|
3779
|
-
P2_dic = {}
|
|
3780
|
-
elif (norm == "auto") and (self.P1_dic is not None):
|
|
3781
|
-
P1_dic = self.P1_dic
|
|
3782
|
-
if cross:
|
|
3783
|
-
P2_dic = self.P2_dic
|
|
3784
|
-
|
|
3785
|
-
if return_data:
|
|
3786
|
-
s0 = I1
|
|
3787
|
-
if out_nside is not None:
|
|
3788
|
-
s0 = self.backend.bk_reduce_mean(
|
|
3789
|
-
self.backend.bk_reshape(
|
|
3790
|
-
s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
|
|
3791
|
-
),
|
|
3792
|
-
2,
|
|
3793
|
-
)
|
|
3794
|
-
else:
|
|
3795
|
-
if not cross:
|
|
3796
|
-
s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
|
|
3797
|
-
else:
|
|
3798
|
-
s0, l_vs0 = self.masked_mean(
|
|
3799
|
-
self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
|
|
3800
|
-
)
|
|
3801
|
-
#vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
|
|
3802
|
-
#s0 = self.backend.bk_concat([s0, l_vs0], 1)
|
|
3803
|
-
result[:,:,0]=s0
|
|
3804
|
-
result[:,:,1]=l_vs0
|
|
3805
|
-
vresult[:,:,0]=l_vs0
|
|
3806
|
-
vresult[:,:,1]=l_vs0
|
|
3807
|
-
#### COMPUTE S1, S2, S3 and S4
|
|
3808
|
-
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
3809
|
-
|
|
3810
|
-
# a remettre comme avant
|
|
3811
|
-
M1_dic={}
|
|
3812
|
-
M2_dic={}
|
|
3813
|
-
|
|
3814
|
-
for j3 in range(Jmax):
|
|
3815
|
-
|
|
3816
|
-
if edge:
|
|
3817
|
-
if self.mask_mask is None:
|
|
3818
|
-
self.mask_mask={}
|
|
3819
|
-
if self.use_2D:
|
|
3820
|
-
if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
|
|
3821
|
-
mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
|
|
3822
|
-
mask_mask[0,
|
|
3823
|
-
self.KERNELSZ//2:-self.KERNELSZ//2+1,
|
|
3824
|
-
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
3825
|
-
self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
|
|
3826
|
-
vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
|
|
3827
|
-
#print(self.KERNELSZ//2,vmask,mask_mask)
|
|
3828
|
-
|
|
3829
|
-
if self.use_1D:
|
|
3830
|
-
if (vmask.shape[1]) not in self.mask_mask:
|
|
3831
|
-
mask_mask=np.zeros([1,vmask.shape[1]])
|
|
3832
|
-
mask_mask[0,
|
|
3833
|
-
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
3834
|
-
self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
|
|
3835
|
-
vmask=vmask*self.mask_mask[(vmask.shape[1])]
|
|
3836
|
-
|
|
3837
|
-
if return_data:
|
|
3838
|
-
S3[j3] = None
|
|
3839
|
-
S3P[j3] = None
|
|
3840
|
-
|
|
3841
|
-
if S4 is None:
|
|
3842
|
-
S4 = {}
|
|
3843
|
-
S4[j3] = None
|
|
3844
|
-
|
|
3845
|
-
####### S1 and S2
|
|
3846
|
-
### Make the convolution I1 * Psi_j3
|
|
3847
|
-
conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
3848
|
-
if cmat is not None:
|
|
3849
|
-
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
|
|
3850
|
-
conv1 = self.backend.bk_reduce_sum(
|
|
3851
|
-
self.backend.bk_reshape(
|
|
3852
|
-
cmat[j3] * tmp2,
|
|
3853
|
-
[tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
|
|
3854
|
-
),
|
|
3855
|
-
2,
|
|
3856
|
-
)
|
|
3857
|
-
|
|
3858
|
-
### Take the module M1 = |I1 * Psi_j3|
|
|
3859
|
-
M1_square = conv1 * self.backend.bk_conjugate(
|
|
3860
|
-
conv1
|
|
3861
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
3862
|
-
|
|
3863
|
-
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
3864
|
-
|
|
3865
|
-
# Store M1_j3 in a dictionary
|
|
3866
|
-
M1_dic[j3] = M1
|
|
3867
|
-
|
|
3868
|
-
if not cross: # Auto
|
|
3869
|
-
M1_square = self.backend.bk_real(M1_square)
|
|
3870
|
-
|
|
3871
|
-
### S2_auto = < M1^2 >_pix
|
|
3872
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
3873
|
-
if return_data:
|
|
3874
|
-
s2 = M1_square
|
|
3875
|
-
else:
|
|
3876
|
-
if calc_var:
|
|
3877
|
-
s2, vs2 = self.masked_mean(
|
|
3878
|
-
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
3879
|
-
)
|
|
3880
|
-
#s2=self.backend.bk_flatten(self.backend.bk_real(s2))
|
|
3881
|
-
#vs2=self.backend.bk_flatten(vs2)
|
|
3882
|
-
else:
|
|
3883
|
-
s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
|
|
3884
|
-
|
|
3885
|
-
if cond_init_P1_dic:
|
|
3886
|
-
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
3887
|
-
P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
|
|
3888
|
-
|
|
3889
|
-
# We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
|
|
3890
|
-
if return_data:
|
|
3891
|
-
if S2 is None:
|
|
3892
|
-
S2 = {}
|
|
3893
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
3894
|
-
s2 = self.backend.bk_reduce_mean(
|
|
3895
|
-
self.backend.bk_reshape(
|
|
3896
|
-
s2,
|
|
3897
|
-
[
|
|
3898
|
-
s2.shape[0],
|
|
3899
|
-
12 * out_nside**2,
|
|
3900
|
-
(nside_j3 // out_nside) ** 2,
|
|
3901
|
-
s2.shape[2],
|
|
3902
|
-
],
|
|
3903
|
-
),
|
|
3904
|
-
2,
|
|
3905
|
-
)
|
|
3906
|
-
S2[j3] = s2
|
|
3907
|
-
else:
|
|
3908
|
-
if norm == "auto": # Normalize S2
|
|
3909
|
-
s2 /= P1_dic[j3]
|
|
3910
|
-
"""
|
|
3911
|
-
S2.append(
|
|
3912
|
-
self.backend.bk_expand_dims(s2, off_S2)
|
|
3913
|
-
) # Add a dimension for NS2
|
|
3914
|
-
if calc_var:
|
|
3915
|
-
VS2.append(
|
|
3916
|
-
self.backend.bk_expand_dims(vs2, off_S2)
|
|
3917
|
-
) # Add a dimension for NS2
|
|
3918
|
-
"""
|
|
3919
|
-
#print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
|
|
3920
|
-
result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
|
|
3921
|
-
if calc_var:
|
|
3922
|
-
vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
|
|
3923
|
-
#### S1_auto computation
|
|
3924
|
-
### Image 1 : S1 = < M1 >_pix
|
|
3925
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
3926
|
-
if return_data:
|
|
3927
|
-
s1 = M1
|
|
3928
|
-
else:
|
|
3929
|
-
if calc_var:
|
|
3930
|
-
s1, vs1 = self.masked_mean(
|
|
3931
|
-
M1, vmask, axis=1, rank=j3, calc_var=True
|
|
3932
|
-
) # [Nbatch, Nmask, Norient3]
|
|
3933
|
-
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
3934
|
-
#vs1=self.backend.bk_flatten(vs1)
|
|
3935
|
-
else:
|
|
3936
|
-
s1 = self.masked_mean(
|
|
3937
|
-
M1, vmask, axis=1, rank=j3
|
|
3938
|
-
) # [Nbatch, Nmask, Norient3]
|
|
3939
|
-
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
3940
|
-
|
|
3941
|
-
if return_data:
|
|
3942
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
3943
|
-
s1 = self.backend.bk_reduce_mean(
|
|
3944
|
-
self.backend.bk_reshape(
|
|
3945
|
-
s1,
|
|
3946
|
-
[
|
|
3947
|
-
s1.shape[0],
|
|
3948
|
-
12 * out_nside**2,
|
|
3949
|
-
(nside_j3 // out_nside) ** 2,
|
|
3950
|
-
s1.shape[2],
|
|
3951
|
-
],
|
|
3952
|
-
),
|
|
3953
|
-
2,
|
|
3954
|
-
)
|
|
3955
|
-
S1[j3] = s1
|
|
3956
|
-
else:
|
|
3957
|
-
### Normalize S1
|
|
3958
|
-
if norm is not None:
|
|
3959
|
-
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
3960
|
-
result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
|
|
3961
|
-
if calc_var:
|
|
3962
|
-
vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
|
|
3963
|
-
"""
|
|
3964
|
-
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
3965
|
-
S1.append(
|
|
3966
|
-
self.backend.bk_expand_dims(s1, off_S2)
|
|
3967
|
-
) # Add a dimension for NS1
|
|
3968
|
-
if calc_var:
|
|
3969
|
-
VS1.append(
|
|
3970
|
-
self.backend.bk_expand_dims(vs1, off_S2)
|
|
3971
|
-
) # Add a dimension for NS1
|
|
3972
|
-
"""
|
|
3973
|
-
|
|
3974
|
-
else: # Cross
|
|
3975
|
-
### Make the convolution I2 * Psi_j3
|
|
3976
|
-
conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
3977
|
-
if cmat is not None:
|
|
3978
|
-
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
|
|
3979
|
-
conv2 = self.backend.bk_reduce_sum(
|
|
3980
|
-
self.backend.bk_reshape(
|
|
3981
|
-
cmat[j3] * tmp2,
|
|
3982
|
-
[
|
|
3983
|
-
tmp2.shape[0],
|
|
3984
|
-
cmat[j3].shape[0],
|
|
3985
|
-
self.NORIENT,
|
|
3986
|
-
self.NORIENT,
|
|
3987
|
-
],
|
|
3988
|
-
),
|
|
3989
|
-
2,
|
|
3990
|
-
)
|
|
3991
|
-
### Take the module M2 = |I2 * Psi_j3|
|
|
3992
|
-
M2_square = conv2 * self.backend.bk_conjugate(
|
|
3993
|
-
conv2
|
|
3994
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
3995
|
-
M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
|
|
3996
|
-
# Store M2_j3 in a dictionary
|
|
3997
|
-
M2_dic[j3] = M2
|
|
3998
|
-
|
|
3999
|
-
### S2_auto = < M2^2 >_pix
|
|
4000
|
-
# Not returned, only for normalization
|
|
4001
|
-
if cond_init_P1_dic:
|
|
4002
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4003
|
-
if return_data:
|
|
4004
|
-
p1 = M1_square
|
|
4005
|
-
p2 = M2_square
|
|
4006
|
-
else:
|
|
4007
|
-
if calc_var:
|
|
4008
|
-
p1, vp1 = self.masked_mean(
|
|
4009
|
-
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
4010
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4011
|
-
p2, vp2 = self.masked_mean(
|
|
4012
|
-
M2_square, vmask, axis=1, rank=j3, calc_var=True
|
|
4013
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4014
|
-
else:
|
|
4015
|
-
p1 = self.masked_mean(
|
|
4016
|
-
M1_square, vmask, axis=1, rank=j3
|
|
4017
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4018
|
-
p2 = self.masked_mean(
|
|
4019
|
-
M2_square, vmask, axis=1, rank=j3
|
|
4020
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4021
|
-
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
4022
|
-
P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
|
|
4023
|
-
P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
|
|
4024
|
-
|
|
4025
|
-
### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
|
|
4026
|
-
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4027
|
-
s2 = conv1 * self.backend.bk_conjugate(conv2)
|
|
4028
|
-
MX = self.backend.bk_L1(s2)
|
|
4029
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4030
|
-
if return_data:
|
|
4031
|
-
s2 = s2
|
|
4032
|
-
else:
|
|
4033
|
-
if calc_var:
|
|
4034
|
-
s2, vs2 = self.masked_mean(
|
|
4035
|
-
s2, vmask, axis=1, rank=j3, calc_var=True
|
|
4036
|
-
)
|
|
4037
|
-
else:
|
|
4038
|
-
s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
|
|
4039
|
-
|
|
4040
|
-
if return_data:
|
|
4041
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
4042
|
-
s2 = self.backend.bk_reduce_mean(
|
|
4043
|
-
self.backend.bk_reshape(
|
|
4044
|
-
s2,
|
|
4045
|
-
[
|
|
4046
|
-
s2.shape[0],
|
|
4047
|
-
12 * out_nside**2,
|
|
4048
|
-
(nside_j3 // out_nside) ** 2,
|
|
4049
|
-
s2.shape[2],
|
|
4050
|
-
],
|
|
4051
|
-
),
|
|
4052
|
-
2,
|
|
4053
|
-
)
|
|
4054
|
-
S2[j3] = s2
|
|
4055
|
-
else:
|
|
4056
|
-
### Normalize S2_cross
|
|
4057
|
-
if norm == "auto":
|
|
4058
|
-
s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
|
|
4059
|
-
|
|
4060
|
-
### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
|
|
4061
|
-
s2 = self.backend.bk_real(s2)
|
|
4062
|
-
|
|
4063
|
-
S2.append(
|
|
4064
|
-
self.backend.bk_expand_dims(s2, off_S2)
|
|
4065
|
-
) # Add a dimension for NS2
|
|
4066
|
-
if calc_var:
|
|
4067
|
-
VS2.append(
|
|
4068
|
-
self.backend.bk_expand_dims(vs2, off_S2)
|
|
4069
|
-
) # Add a dimension for NS2
|
|
4070
|
-
|
|
4071
|
-
#### S1_auto computation
|
|
4072
|
-
### Image 1 : S1 = < M1 >_pix
|
|
4073
|
-
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4074
|
-
if return_data:
|
|
4075
|
-
s1 = MX
|
|
4076
|
-
else:
|
|
4077
|
-
if calc_var:
|
|
4078
|
-
s1, vs1 = self.masked_mean(
|
|
4079
|
-
MX, vmask, axis=1, rank=j3, calc_var=True
|
|
4080
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4081
|
-
else:
|
|
4082
|
-
s1 = self.masked_mean(
|
|
4083
|
-
MX, vmask, axis=1, rank=j3
|
|
4084
|
-
) # [Nbatch, Nmask, Norient3]
|
|
4085
|
-
if return_data:
|
|
4086
|
-
if out_nside is not None and out_nside < nside_j3:
|
|
4087
|
-
s1 = self.backend.bk_reduce_mean(
|
|
4088
|
-
self.backend.bk_reshape(
|
|
4089
|
-
s1,
|
|
4090
|
-
[
|
|
4091
|
-
s1.shape[0],
|
|
4092
|
-
12 * out_nside**2,
|
|
4093
|
-
(nside_j3 // out_nside) ** 2,
|
|
4094
|
-
s1.shape[2],
|
|
4095
|
-
],
|
|
4096
|
-
),
|
|
4097
|
-
2,
|
|
4098
|
-
)
|
|
4099
|
-
S1[j3] = s1
|
|
4100
|
-
else:
|
|
4101
|
-
### Normalize S1
|
|
4102
|
-
if norm is not None:
|
|
4103
|
-
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
4104
|
-
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
4105
|
-
S1.append(
|
|
4106
|
-
self.backend.bk_expand_dims(s1, off_S2)
|
|
4107
|
-
) # Add a dimension for NS1
|
|
4108
|
-
if calc_var:
|
|
4109
|
-
VS1.append(
|
|
4110
|
-
self.backend.bk_expand_dims(vs1, off_S2)
|
|
4111
|
-
) # Add a dimension for NS1
|
|
4112
|
-
|
|
4113
|
-
# Initialize dictionaries for |I1*Psi_j| * Psi_j3
|
|
4114
|
-
M1convPsi_dic = {}
|
|
4115
|
-
if cross:
|
|
4116
|
-
# Initialize dictionaries for |I2*Psi_j| * Psi_j3
|
|
4117
|
-
M2convPsi_dic = {}
|
|
4118
|
-
|
|
4119
|
-
###### S3
|
|
4120
|
-
nside_j2 = nside_j3
|
|
4121
|
-
for j2 in range(0,-1): # j3 + 1): # j2 <= j3
|
|
4122
|
-
if return_data:
|
|
4123
|
-
if S4[j3] is None:
|
|
4124
|
-
S4[j3] = {}
|
|
4125
|
-
S4[j3][j2] = None
|
|
4126
|
-
|
|
4127
|
-
### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
4128
|
-
if not cross:
|
|
4129
|
-
if calc_var:
|
|
4130
|
-
s3, vs3 = self._compute_S3(
|
|
4131
|
-
j2,
|
|
4132
|
-
j3,
|
|
4133
|
-
conv1,
|
|
4134
|
-
vmask,
|
|
4135
|
-
M1_dic,
|
|
4136
|
-
M1convPsi_dic,
|
|
4137
|
-
calc_var=True,
|
|
4138
|
-
cmat2=cmat2,
|
|
4139
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4140
|
-
else:
|
|
4141
|
-
s3 = self._compute_S3(
|
|
4142
|
-
j2,
|
|
4143
|
-
j3,
|
|
4144
|
-
conv1,
|
|
4145
|
-
vmask,
|
|
4146
|
-
M1_dic,
|
|
4147
|
-
M1convPsi_dic,
|
|
4148
|
-
return_data=return_data,
|
|
4149
|
-
cmat2=cmat2,
|
|
4150
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4151
|
-
|
|
4152
|
-
if return_data:
|
|
4153
|
-
if S3[j3] is None:
|
|
4154
|
-
S3[j3] = {}
|
|
4155
|
-
if out_nside is not None and out_nside < nside_j2:
|
|
4156
|
-
s3 = self.backend.bk_reduce_mean(
|
|
4157
|
-
self.backend.bk_reshape(
|
|
4158
|
-
s3,
|
|
4159
|
-
[
|
|
4160
|
-
s3.shape[0],
|
|
4161
|
-
12 * out_nside**2,
|
|
4162
|
-
(nside_j2 // out_nside) ** 2,
|
|
4163
|
-
s3.shape[2],
|
|
4164
|
-
s3.shape[3],
|
|
4165
|
-
],
|
|
4166
|
-
),
|
|
4167
|
-
2,
|
|
4168
|
-
)
|
|
4169
|
-
S3[j3][j2] = s3
|
|
4170
|
-
else:
|
|
4171
|
-
### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4172
|
-
if norm is not None:
|
|
4173
|
-
self.div_norm(
|
|
4174
|
-
s3,
|
|
4175
|
-
(
|
|
4176
|
-
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
4177
|
-
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
4178
|
-
)
|
|
4179
|
-
** 0.5,
|
|
4180
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4181
|
-
|
|
4182
|
-
### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
4183
|
-
|
|
4184
|
-
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
4185
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4186
|
-
S3.append(
|
|
4187
|
-
self.backend.bk_expand_dims(s3, off_S3)
|
|
4188
|
-
) # Add a dimension for NS3
|
|
4189
|
-
if calc_var:
|
|
4190
|
-
VS3.append(
|
|
4191
|
-
self.backend.bk_expand_dims(vs3, off_S3)
|
|
4192
|
-
) # Add a dimension for NS3
|
|
4193
|
-
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
4194
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4195
|
-
|
|
4196
|
-
### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
|
|
4197
|
-
### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
4198
|
-
else:
|
|
4199
|
-
if calc_var:
|
|
4200
|
-
s3, vs3 = self._compute_S3(
|
|
4201
|
-
j2,
|
|
4202
|
-
j3,
|
|
4203
|
-
conv1,
|
|
4204
|
-
vmask,
|
|
4205
|
-
M2_dic,
|
|
4206
|
-
M2convPsi_dic,
|
|
4207
|
-
calc_var=True,
|
|
4208
|
-
cmat2=cmat2,
|
|
4209
|
-
)
|
|
4210
|
-
s3p, vs3p = self._compute_S3(
|
|
4211
|
-
j2,
|
|
4212
|
-
j3,
|
|
4213
|
-
conv2,
|
|
4214
|
-
vmask,
|
|
4215
|
-
M1_dic,
|
|
4216
|
-
M1convPsi_dic,
|
|
4217
|
-
calc_var=True,
|
|
4218
|
-
cmat2=cmat2,
|
|
4219
|
-
)
|
|
4220
|
-
else:
|
|
4221
|
-
s3 = self._compute_S3(
|
|
4222
|
-
j2,
|
|
4223
|
-
j3,
|
|
4224
|
-
conv1,
|
|
4225
|
-
vmask,
|
|
4226
|
-
M2_dic,
|
|
4227
|
-
M2convPsi_dic,
|
|
4228
|
-
return_data=return_data,
|
|
4229
|
-
cmat2=cmat2,
|
|
4230
|
-
)
|
|
4231
|
-
s3p = self._compute_S3(
|
|
4232
|
-
j2,
|
|
4233
|
-
j3,
|
|
4234
|
-
conv2,
|
|
4235
|
-
vmask,
|
|
4236
|
-
M1_dic,
|
|
4237
|
-
M1convPsi_dic,
|
|
4238
|
-
return_data=return_data,
|
|
4239
|
-
cmat2=cmat2,
|
|
4240
|
-
)
|
|
4241
|
-
|
|
4242
|
-
if return_data:
|
|
4243
|
-
if S3[j3] is None:
|
|
4244
|
-
S3[j3] = {}
|
|
4245
|
-
S3P[j3] = {}
|
|
4246
|
-
if out_nside is not None and out_nside < nside_j2:
|
|
4247
|
-
s3 = self.backend.bk_reduce_mean(
|
|
4248
|
-
self.backend.bk_reshape(
|
|
4249
|
-
s3,
|
|
4250
|
-
[
|
|
4251
|
-
s3.shape[0],
|
|
4252
|
-
12 * out_nside**2,
|
|
4253
|
-
(nside_j2 // out_nside) ** 2,
|
|
4254
|
-
s3.shape[2],
|
|
4255
|
-
s3.shape[3],
|
|
4256
|
-
],
|
|
4257
|
-
),
|
|
4258
|
-
2,
|
|
4259
|
-
)
|
|
4260
|
-
s3p = self.backend.bk_reduce_mean(
|
|
4261
|
-
self.backend.bk_reshape(
|
|
4262
|
-
s3p,
|
|
4263
|
-
[
|
|
4264
|
-
s3.shape[0],
|
|
4265
|
-
12 * out_nside**2,
|
|
4266
|
-
(nside_j2 // out_nside) ** 2,
|
|
4267
|
-
s3.shape[2],
|
|
4268
|
-
s3.shape[3],
|
|
4269
|
-
],
|
|
4270
|
-
),
|
|
4271
|
-
2,
|
|
4272
|
-
)
|
|
4273
|
-
S3[j3][j2] = s3
|
|
4274
|
-
S3P[j3][j2] = s3p
|
|
4275
|
-
else:
|
|
4276
|
-
### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
|
|
4277
|
-
if norm is not None:
|
|
4278
|
-
self.div_norm(
|
|
4279
|
-
s3,
|
|
4280
|
-
(
|
|
4281
|
-
self.backend.bk_expand_dims(P2_dic[j2], off_S2)
|
|
4282
|
-
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
4283
|
-
)
|
|
4284
|
-
** 0.5,
|
|
4285
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4286
|
-
self.div_norm(
|
|
4287
|
-
s3p,
|
|
4288
|
-
(
|
|
4289
|
-
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
4290
|
-
* self.backend.bk_expand_dims(P2_dic[j3], -1)
|
|
4291
|
-
)
|
|
4292
|
-
** 0.5,
|
|
4293
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4294
|
-
|
|
4295
|
-
### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
4296
|
-
|
|
4297
|
-
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
4298
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4299
|
-
S3.append(
|
|
4300
|
-
self.backend.bk_expand_dims(s3, off_S3)
|
|
4301
|
-
) # Add a dimension for NS3
|
|
4302
|
-
if calc_var:
|
|
4303
|
-
VS3.append(
|
|
4304
|
-
self.backend.bk_expand_dims(vs3, off_S3)
|
|
4305
|
-
) # Add a dimension for NS3
|
|
4306
|
-
|
|
4307
|
-
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
4308
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4309
|
-
|
|
4310
|
-
# S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
|
|
4311
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4312
|
-
S3P.append(
|
|
4313
|
-
self.backend.bk_expand_dims(s3p, off_S3)
|
|
4314
|
-
) # Add a dimension for NS3
|
|
4315
|
-
if calc_var:
|
|
4316
|
-
VS3P.append(
|
|
4317
|
-
self.backend.bk_expand_dims(vs3p, off_S3)
|
|
4318
|
-
) # Add a dimension for NS3
|
|
4319
|
-
# VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
|
|
4320
|
-
# s3.shape[2]*s3.shape[3]]))
|
|
4321
|
-
|
|
4322
|
-
##### S4
|
|
4323
|
-
nside_j1 = nside_j2
|
|
4324
|
-
for j1 in range(0, j2 + 1): # j1 <= j2
|
|
4325
|
-
### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
|
|
4326
|
-
if not cross:
|
|
4327
|
-
if calc_var:
|
|
4328
|
-
s4, vs4 = self._compute_S4(
|
|
4329
|
-
j1,
|
|
4330
|
-
j2,
|
|
4331
|
-
vmask,
|
|
4332
|
-
M1convPsi_dic,
|
|
4333
|
-
M2convPsi_dic=None,
|
|
4334
|
-
calc_var=True,
|
|
4335
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4336
|
-
else:
|
|
4337
|
-
s4 = self._compute_S4(
|
|
4338
|
-
j1,
|
|
4339
|
-
j2,
|
|
4340
|
-
vmask,
|
|
4341
|
-
M1convPsi_dic,
|
|
4342
|
-
M2convPsi_dic=None,
|
|
4343
|
-
return_data=return_data,
|
|
4344
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4345
|
-
|
|
4346
|
-
if return_data:
|
|
4347
|
-
if S4[j3][j2] is None:
|
|
4348
|
-
S4[j3][j2] = {}
|
|
4349
|
-
if out_nside is not None and out_nside < nside_j1:
|
|
4350
|
-
s4 = self.backend.bk_reduce_mean(
|
|
4351
|
-
self.backend.bk_reshape(
|
|
4352
|
-
s4,
|
|
4353
|
-
[
|
|
4354
|
-
s4.shape[0],
|
|
4355
|
-
12 * out_nside**2,
|
|
4356
|
-
(nside_j1 // out_nside) ** 2,
|
|
4357
|
-
s4.shape[2],
|
|
4358
|
-
s4.shape[3],
|
|
4359
|
-
s4.shape[4],
|
|
4360
|
-
],
|
|
4361
|
-
),
|
|
4362
|
-
2,
|
|
4363
|
-
)
|
|
4364
|
-
S4[j3][j2][j1] = s4
|
|
4365
|
-
else:
|
|
4366
|
-
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4367
|
-
if norm is not None:
|
|
4368
|
-
self.div_norm(
|
|
4369
|
-
s4,
|
|
4370
|
-
(
|
|
4371
|
-
self.backend.bk_expand_dims(
|
|
4372
|
-
self.backend.bk_expand_dims(
|
|
4373
|
-
P1_dic[j1], off_S2
|
|
4374
|
-
),
|
|
4375
|
-
off_S2,
|
|
4376
|
-
)
|
|
4377
|
-
* self.backend.bk_expand_dims(
|
|
4378
|
-
self.backend.bk_expand_dims(
|
|
4379
|
-
P1_dic[j2], off_S2
|
|
4380
|
-
),
|
|
4381
|
-
-1,
|
|
4382
|
-
)
|
|
4383
|
-
)
|
|
4384
|
-
** 0.5,
|
|
4385
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4386
|
-
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
4387
|
-
|
|
4388
|
-
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
4389
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4390
|
-
S4.append(
|
|
4391
|
-
self.backend.bk_expand_dims(s4, off_S4)
|
|
4392
|
-
) # Add a dimension for NS4
|
|
4393
|
-
if calc_var:
|
|
4394
|
-
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
4395
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4396
|
-
VS4.append(
|
|
4397
|
-
self.backend.bk_expand_dims(vs4, off_S4)
|
|
4398
|
-
) # Add a dimension for NS4
|
|
4399
|
-
|
|
4400
|
-
### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
|
|
4401
|
-
else:
|
|
4402
|
-
if calc_var:
|
|
4403
|
-
s4, vs4 = self._compute_S4(
|
|
4404
|
-
j1,
|
|
4405
|
-
j2,
|
|
4406
|
-
vmask,
|
|
4407
|
-
M1convPsi_dic,
|
|
4408
|
-
M2convPsi_dic=M2convPsi_dic,
|
|
4409
|
-
calc_var=True,
|
|
4410
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4411
|
-
else:
|
|
4412
|
-
s4 = self._compute_S4(
|
|
4413
|
-
j1,
|
|
4414
|
-
j2,
|
|
4415
|
-
vmask,
|
|
4416
|
-
M1convPsi_dic,
|
|
4417
|
-
M2convPsi_dic=M2convPsi_dic,
|
|
4418
|
-
return_data=return_data,
|
|
4419
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4420
|
-
|
|
4421
|
-
if return_data:
|
|
4422
|
-
if S4[j3][j2] is None:
|
|
4423
|
-
S4[j3][j2] = {}
|
|
4424
|
-
if out_nside is not None and out_nside < nside_j1:
|
|
4425
|
-
s4 = self.backend.bk_reduce_mean(
|
|
4426
|
-
self.backend.bk_reshape(
|
|
4427
|
-
s4,
|
|
4428
|
-
[
|
|
4429
|
-
s4.shape[0],
|
|
4430
|
-
12 * out_nside**2,
|
|
4431
|
-
(nside_j1 // out_nside) ** 2,
|
|
4432
|
-
s4.shape[2],
|
|
4433
|
-
s4.shape[3],
|
|
4434
|
-
s4.shape[4],
|
|
4435
|
-
],
|
|
4436
|
-
),
|
|
4437
|
-
2,
|
|
4438
|
-
)
|
|
4439
|
-
S4[j3][j2][j1] = s4
|
|
4440
|
-
else:
|
|
4441
|
-
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4442
|
-
if norm is not None:
|
|
4443
|
-
self.div_norm(
|
|
4444
|
-
s4,
|
|
4445
|
-
(
|
|
4446
|
-
self.backend.bk_expand_dims(
|
|
4447
|
-
self.backend.bk_expand_dims(
|
|
4448
|
-
P1_dic[j1], off_S2
|
|
4449
|
-
),
|
|
4450
|
-
off_S2,
|
|
4451
|
-
)
|
|
4452
|
-
* self.backend.bk_expand_dims(
|
|
4453
|
-
self.backend.bk_expand_dims(
|
|
4454
|
-
P2_dic[j2], off_S2
|
|
4455
|
-
),
|
|
4456
|
-
-1,
|
|
4457
|
-
)
|
|
4458
|
-
)
|
|
4459
|
-
** 0.5,
|
|
4460
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4461
|
-
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
4462
|
-
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
4463
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4464
|
-
S4.append(
|
|
4465
|
-
self.backend.bk_expand_dims(s4, off_S4)
|
|
4466
|
-
) # Add a dimension for NS4
|
|
4467
|
-
if calc_var:
|
|
4468
|
-
|
|
4469
|
-
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
4470
|
-
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4471
|
-
VS4.append(
|
|
4472
|
-
self.backend.bk_expand_dims(vs4, off_S4)
|
|
4473
|
-
) # Add a dimension for NS4
|
|
4474
|
-
|
|
4475
|
-
nside_j1 = nside_j1 // 2
|
|
4476
|
-
nside_j2 = nside_j2 // 2
|
|
4477
|
-
|
|
4478
|
-
###### Reshape for next iteration on j3
|
|
4479
|
-
### Image I1,
|
|
4480
|
-
# downscale the I1 [Nbatch, Npix_j3]
|
|
4481
|
-
if j3 != Jmax - 1:
|
|
4482
|
-
I1 = self.smooth(I1, axis=1)
|
|
4483
|
-
I1 = self.ud_grade_2(I1, axis=1)
|
|
4484
|
-
|
|
4485
|
-
### Image I2
|
|
4486
|
-
if cross:
|
|
4487
|
-
I2 = self.smooth(I2, axis=1)
|
|
4488
|
-
I2 = self.ud_grade_2(I2, axis=1)
|
|
4489
|
-
|
|
4490
|
-
### Modules
|
|
4491
|
-
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
4492
|
-
### Dictionary M1_dic[j2]
|
|
4493
|
-
M1_smooth = self.smooth(
|
|
4494
|
-
M1_dic[j2], axis=1
|
|
4495
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4496
|
-
M1_dic[j2] = self.ud_grade_2(
|
|
4497
|
-
M1_smooth, axis=1
|
|
4498
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4499
|
-
|
|
4500
|
-
### Dictionary M2_dic[j2]
|
|
4501
|
-
if cross:
|
|
4502
|
-
M2_smooth = self.smooth(
|
|
4503
|
-
M2_dic[j2], axis=1
|
|
4504
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4505
|
-
M2_dic[j2] = self.ud_grade_2(
|
|
4506
|
-
M2, axis=1
|
|
4507
|
-
) # [Nbatch, Npix_j3, Norient3]
|
|
4508
|
-
|
|
4509
|
-
### Mask
|
|
4510
|
-
vmask = self.ud_grade_2(vmask, axis=1)
|
|
4511
|
-
|
|
4512
|
-
if self.mask_thres is not None:
|
|
4513
|
-
vmask = self.backend.bk_threshold(vmask, self.mask_thres)
|
|
4514
|
-
|
|
4515
|
-
### NSIDE_j3
|
|
4516
|
-
nside_j3 = nside_j3 // 2
|
|
4517
|
-
|
|
4518
|
-
### Store P1_dic and P2_dic in self
|
|
4519
|
-
if (norm == "auto") and (self.P1_dic is None):
|
|
4520
|
-
self.P1_dic = P1_dic
|
|
4521
|
-
if cross:
|
|
4522
|
-
self.P2_dic = P2_dic
|
|
4523
|
-
"""
|
|
4524
|
-
Sout=[s0]+S1+S2+S3+S4
|
|
4525
|
-
|
|
4526
|
-
if cross:
|
|
4527
|
-
Sout=Sout+S3P
|
|
4528
|
-
if calc_var:
|
|
4529
|
-
SVout=[vs0]+VS1+VS2+VS3+VS4
|
|
4530
|
-
if cross:
|
|
4531
|
-
VSout=VSout+VS3P
|
|
4532
|
-
return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
|
|
4533
|
-
|
|
4534
|
-
return self.backend.bk_concat(Sout, 2)
|
|
4535
|
-
"""
|
|
4536
|
-
if calc_var:
|
|
4537
|
-
return result,vresult
|
|
4538
|
-
else:
|
|
4539
|
-
return result
|
|
4540
|
-
if calc_var:
|
|
4541
|
-
for k in S1:
|
|
4542
|
-
print(k.shape,k.dtype)
|
|
4543
|
-
for k in S2:
|
|
4544
|
-
print(k.shape,k.dtype)
|
|
4545
|
-
print(s0.shape,s0.dtype)
|
|
4546
|
-
return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
|
|
4547
|
-
else:
|
|
4548
|
-
return self.backend.bk_concat([s0]+S1+S2,axis=1)
|
|
4549
|
-
|
|
4550
|
-
if not return_data:
|
|
4551
|
-
S1 = self.backend.bk_concat(S1, 2)
|
|
4552
|
-
S2 = self.backend.bk_concat(S2, 2)
|
|
4553
|
-
S3 = self.backend.bk_concat(S3, 2)
|
|
4554
|
-
S4 = self.backend.bk_concat(S4, 2)
|
|
4555
|
-
if cross:
|
|
4556
|
-
S3P = self.backend.bk_concat(S3P, 2)
|
|
4557
|
-
if calc_var:
|
|
4558
|
-
VS1 = self.backend.bk_concat(VS1, 2)
|
|
4559
|
-
VS2 = self.backend.bk_concat(VS2, 2)
|
|
4560
|
-
VS3 = self.backend.bk_concat(VS3, 2)
|
|
4561
|
-
VS4 = self.backend.bk_concat(VS4, 2)
|
|
4562
|
-
if cross:
|
|
4563
|
-
VS3P = self.backend.bk_concat(VS3P, 2)
|
|
4564
|
-
if calc_var:
|
|
4565
|
-
if not cross:
|
|
4566
|
-
return scat_cov(
|
|
4567
|
-
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
4568
|
-
), scat_cov(
|
|
4569
|
-
vs0,
|
|
4570
|
-
VS2,
|
|
4571
|
-
VS3,
|
|
4572
|
-
VS4,
|
|
4573
|
-
s1=VS1,
|
|
4574
|
-
backend=self.backend,
|
|
4575
|
-
use_1D=self.use_1D,
|
|
4576
|
-
)
|
|
4577
|
-
else:
|
|
4578
|
-
return scat_cov(
|
|
4579
|
-
s0,
|
|
4580
|
-
S2,
|
|
4581
|
-
S3,
|
|
4582
|
-
S4,
|
|
4583
|
-
s1=S1,
|
|
4584
|
-
s3p=S3P,
|
|
4585
|
-
backend=self.backend,
|
|
4586
|
-
use_1D=self.use_1D,
|
|
4587
|
-
), scat_cov(
|
|
4588
|
-
vs0,
|
|
4589
|
-
VS2,
|
|
4590
|
-
VS3,
|
|
4591
|
-
VS4,
|
|
4592
|
-
s1=VS1,
|
|
4593
|
-
s3p=VS3P,
|
|
4594
|
-
backend=self.backend,
|
|
4595
|
-
use_1D=self.use_1D,
|
|
4596
|
-
)
|
|
4597
|
-
else:
|
|
4598
|
-
if not cross:
|
|
4599
|
-
return scat_cov(
|
|
4600
|
-
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
4601
|
-
)
|
|
4602
|
-
else:
|
|
4603
|
-
return scat_cov(
|
|
4604
|
-
s0,
|
|
4605
|
-
S2,
|
|
4606
|
-
S3,
|
|
4607
|
-
S4,
|
|
4608
|
-
s1=S1,
|
|
4609
|
-
s3p=S3P,
|
|
4610
|
-
backend=self.backend,
|
|
4611
|
-
use_1D=self.use_1D,
|
|
4612
|
-
)
|
|
4613
|
-
def clean_norm(self):
|
|
4614
|
-
self.P1_dic = None
|
|
4615
|
-
self.P2_dic = None
|
|
4616
|
-
return
|
|
4617
|
-
|
|
4618
|
-
def _compute_S3(
|
|
4619
|
-
self,
|
|
4620
|
-
j2,
|
|
4621
|
-
j3,
|
|
4622
|
-
conv,
|
|
4623
|
-
vmask,
|
|
4624
|
-
M_dic,
|
|
4625
|
-
MconvPsi_dic,
|
|
4626
|
-
calc_var=False,
|
|
4627
|
-
return_data=False,
|
|
4628
|
-
cmat2=None,
|
|
4629
|
-
):
|
|
4630
|
-
"""
|
|
4631
|
-
Compute the S3 coefficients (auto or cross)
|
|
4632
|
-
S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
|
|
4633
|
-
Parameters
|
|
4634
|
-
----------
|
|
4635
|
-
Returns
|
|
4636
|
-
-------
|
|
4637
|
-
cs3, ss3: real and imag parts of S3 coeff
|
|
4638
|
-
"""
|
|
4639
|
-
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
4640
|
-
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
4641
|
-
MconvPsi = self.convol(
|
|
4642
|
-
M_dic[j2], axis=1
|
|
4643
|
-
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4644
|
-
if cmat2 is not None:
|
|
4645
|
-
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
|
|
4646
|
-
MconvPsi = self.backend.bk_reduce_sum(
|
|
4647
|
-
self.backend.bk_reshape(
|
|
4648
|
-
cmat2[j3][j2] * tmp2,
|
|
4649
|
-
[
|
|
4650
|
-
tmp2.shape[0],
|
|
4651
|
-
cmat2[j3].shape[1],
|
|
4652
|
-
self.NORIENT,
|
|
4653
|
-
self.NORIENT,
|
|
4654
|
-
self.NORIENT,
|
|
4655
|
-
],
|
|
4656
|
-
),
|
|
4657
|
-
3,
|
|
4658
|
-
)
|
|
4659
|
-
|
|
4660
|
-
# Store it so we can use it in S4 computation
|
|
4661
|
-
MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4662
|
-
|
|
4663
|
-
### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
|
|
4664
|
-
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4665
|
-
# cconv, sconv are [Nbatch, Npix_j3, Norient3]
|
|
4666
|
-
if self.use_1D:
|
|
4667
|
-
s3 = conv * self.backend.bk_conjugate(MconvPsi)
|
|
4668
|
-
else:
|
|
4669
|
-
s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
|
|
4670
|
-
MconvPsi
|
|
4671
|
-
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4672
|
-
|
|
4673
|
-
### Apply the mask [Nmask, Npix_j3] and sum over pixels
|
|
4674
|
-
if return_data:
|
|
4675
|
-
return s3
|
|
4676
|
-
else:
|
|
4677
|
-
if calc_var:
|
|
4678
|
-
s3, vs3 = self.masked_mean(
|
|
4679
|
-
s3, vmask, axis=1, rank=j2, calc_var=True
|
|
4680
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4681
|
-
return s3, vs3
|
|
4682
|
-
else:
|
|
4683
|
-
s3 = self.masked_mean(
|
|
4684
|
-
s3, vmask, axis=1, rank=j2
|
|
4685
|
-
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4686
|
-
return s3
|
|
4687
|
-
|
|
4688
|
-
def _compute_S4(
|
|
4689
|
-
self,
|
|
4690
|
-
j1,
|
|
4691
|
-
j2,
|
|
4692
|
-
vmask,
|
|
4693
|
-
M1convPsi_dic,
|
|
4694
|
-
M2convPsi_dic=None,
|
|
4695
|
-
calc_var=False,
|
|
4696
|
-
return_data=False,
|
|
4697
|
-
):
|
|
4698
|
-
#### Simplify notations
|
|
4699
|
-
M1 = M1convPsi_dic[j1] # [Nbatch, Npix_j3, Norient3, Norient1]
|
|
4700
|
-
|
|
4701
|
-
# Auto or Cross coefficients
|
|
4702
|
-
if M2convPsi_dic is None: # Auto
|
|
4703
|
-
M2 = M1convPsi_dic[j2] # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4704
|
-
else: # Cross
|
|
4705
|
-
M2 = M2convPsi_dic[j2]
|
|
4706
|
-
|
|
4707
|
-
### Compute the product (|I1 * Psi_j1| * Psi_j3)(|I2 * Psi_j2| * Psi_j3)
|
|
4708
|
-
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4709
|
-
if self.use_1D:
|
|
4710
|
-
s4 = M1 * self.backend.bk_conjugate(M2)
|
|
4711
|
-
else:
|
|
4712
|
-
s4 = self.backend.bk_expand_dims(M1, -2) * self.backend.bk_conjugate(
|
|
4713
|
-
self.backend.bk_expand_dims(M2, -1)
|
|
4714
|
-
) # [Nbatch, Npix_j3, Norient3, Norient2, Norient1]
|
|
4715
|
-
|
|
4716
|
-
### Apply the mask and sum over pixels
|
|
4717
|
-
if return_data:
|
|
4718
|
-
return s4
|
|
4719
|
-
else:
|
|
4720
|
-
if calc_var:
|
|
4721
|
-
s4, vs4 = self.masked_mean(
|
|
4722
|
-
s4, vmask, axis=1, rank=j2, calc_var=True
|
|
4723
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4724
|
-
return s4, vs4
|
|
4725
|
-
else:
|
|
4726
|
-
s4 = self.masked_mean(
|
|
4727
|
-
s4, vmask, axis=1, rank=j2
|
|
4728
|
-
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4729
|
-
return s4
|
|
4730
|
-
|
|
4731
|
-
def computer_filter(self,M,N,J,L):
|
|
4732
|
-
'''
|
|
4733
|
-
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4734
|
-
Done by Sihao Cheng and Rudy Morel.
|
|
4735
|
-
'''
|
|
4736
|
-
|
|
4737
|
-
filter = np.zeros([J, L, M, N],dtype='complex64')
|
|
4738
|
-
|
|
4739
|
-
slant=4.0 / L
|
|
4740
|
-
|
|
4741
|
-
for j in range(J):
|
|
4742
|
-
|
|
4743
|
-
for l in range(L):
|
|
4744
|
-
|
|
4745
|
-
theta = (int(L-L/2-1)-l) * np.pi / L
|
|
4746
|
-
sigma = 0.8 * 2**j
|
|
4747
|
-
xi = 3.0 / 4.0 * np.pi /2**j
|
|
4748
|
-
|
|
4749
|
-
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
|
|
4750
|
-
R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
|
|
4751
3719
|
D = np.array([[1, 0], [0, slant * slant]])
|
|
4752
|
-
curv = np.matmul(R, np.matmul(D, R_inv)) / (
|
|
4753
|
-
|
|
3720
|
+
curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)
|
|
3721
|
+
|
|
4754
3722
|
gab = np.zeros((M, N), np.complex128)
|
|
4755
|
-
xx = np.empty((2,2, M, N))
|
|
4756
|
-
yy = np.empty((2,2, M, N))
|
|
4757
|
-
|
|
3723
|
+
xx = np.empty((2, 2, M, N))
|
|
3724
|
+
yy = np.empty((2, 2, M, N))
|
|
3725
|
+
|
|
4758
3726
|
for ii, ex in enumerate([-1, 0]):
|
|
4759
3727
|
for jj, ey in enumerate([-1, 0]):
|
|
4760
|
-
xx[ii,jj], yy[ii,jj] = np.mgrid[
|
|
4761
|
-
ex * M : M + ex * M,
|
|
4762
|
-
|
|
4763
|
-
|
|
4764
|
-
arg = -(
|
|
4765
|
-
|
|
4766
|
-
|
|
4767
|
-
|
|
4768
|
-
|
|
4769
|
-
|
|
3728
|
+
xx[ii, jj], yy[ii, jj] = np.mgrid[
|
|
3729
|
+
ex * M : M + ex * M, ey * N : N + ey * N
|
|
3730
|
+
]
|
|
3731
|
+
|
|
3732
|
+
arg = -(
|
|
3733
|
+
curv[0, 0] * xx * xx
|
|
3734
|
+
+ (curv[0, 1] + curv[1, 0]) * xx * yy
|
|
3735
|
+
+ curv[1, 1] * yy * yy
|
|
3736
|
+
)
|
|
3737
|
+
argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
|
|
3738
|
+
|
|
3739
|
+
gabi = np.exp(argi).sum((0, 1))
|
|
3740
|
+
gab = np.exp(arg).sum((0, 1))
|
|
3741
|
+
|
|
4770
3742
|
norm_factor = 2 * np.pi * sigma * sigma / slant
|
|
4771
|
-
|
|
3743
|
+
|
|
4772
3744
|
gab = gab / norm_factor
|
|
4773
|
-
|
|
3745
|
+
|
|
4774
3746
|
gabi = gabi / norm_factor
|
|
4775
3747
|
|
|
4776
3748
|
K = gabi.sum() / gab.sum()
|
|
4777
3749
|
|
|
4778
3750
|
# Apply the Gaussian
|
|
4779
|
-
filter[j,
|
|
4780
|
-
filter[j,
|
|
4781
|
-
|
|
3751
|
+
filter[j, ell] = np.fft.fft2(gabi - K * gab)
|
|
3752
|
+
filter[j, ell, 0, 0] = 0.0
|
|
3753
|
+
|
|
4782
3754
|
return self.backend.bk_cast(filter)
|
|
4783
|
-
|
|
3755
|
+
|
|
4784
3756
|
# ------------------------------------------------------------------------------------------
|
|
4785
3757
|
#
|
|
4786
|
-
# utility functions
|
|
3758
|
+
# utility functions
|
|
4787
3759
|
#
|
|
4788
3760
|
# ------------------------------------------------------------------------------------------
|
|
4789
|
-
def cut_high_k_off(self,data_f, dx, dy):
|
|
4790
|
-
|
|
3761
|
+
def cut_high_k_off(self, data_f, dx, dy):
|
|
3762
|
+
"""
|
|
4791
3763
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4792
3764
|
Done by Sihao Cheng and Rudy Morel.
|
|
4793
|
-
|
|
4794
|
-
|
|
4795
|
-
if self.backend.BACKEND==
|
|
4796
|
-
if_xodd =
|
|
4797
|
-
if_yodd =
|
|
3765
|
+
"""
|
|
3766
|
+
|
|
3767
|
+
if self.backend.BACKEND == "torch":
|
|
3768
|
+
if_xodd = data_f.shape[-2] % 2 == 1
|
|
3769
|
+
if_yodd = data_f.shape[-1] % 2 == 1
|
|
4798
3770
|
result = self.backend.backend.cat(
|
|
4799
|
-
(
|
|
4800
|
-
(
|
|
4801
|
-
|
|
4802
|
-
|
|
4803
|
-
|
|
4804
|
-
|
|
4805
|
-
|
|
3771
|
+
(
|
|
3772
|
+
self.backend.backend.cat(
|
|
3773
|
+
(
|
|
3774
|
+
data_f[..., : dx + if_xodd, : dy + if_yodd],
|
|
3775
|
+
data_f[..., -dx:, : dy + if_yodd],
|
|
3776
|
+
),
|
|
3777
|
+
-2,
|
|
3778
|
+
),
|
|
3779
|
+
self.backend.backend.cat(
|
|
3780
|
+
(data_f[..., : dx + if_xodd, -dy:], data_f[..., -dx:, -dy:]), -2
|
|
3781
|
+
),
|
|
3782
|
+
),
|
|
3783
|
+
-1,
|
|
3784
|
+
)
|
|
4806
3785
|
return result
|
|
4807
3786
|
else:
|
|
4808
3787
|
# Check if the last two dimensions are odd
|
|
4809
|
-
if_xodd = self.backend.backend.cast(
|
|
4810
|
-
|
|
3788
|
+
if_xodd = self.backend.backend.cast(
|
|
3789
|
+
self.backend.backend.shape(data_f)[-2] % 2 == 1,
|
|
3790
|
+
self.backend.backend.int32,
|
|
3791
|
+
)
|
|
3792
|
+
if_yodd = self.backend.backend.cast(
|
|
3793
|
+
self.backend.backend.shape(data_f)[-1] % 2 == 1,
|
|
3794
|
+
self.backend.backend.int32,
|
|
3795
|
+
)
|
|
4811
3796
|
|
|
4812
3797
|
# Extract four regions
|
|
4813
|
-
top_left = data_f[..., :dx+if_xodd, :dy+if_yodd]
|
|
4814
|
-
top_right = data_f[..., -dx:, :dy+if_yodd]
|
|
4815
|
-
bottom_left = data_f[..., :dx+if_xodd, -dy:]
|
|
3798
|
+
top_left = data_f[..., : dx + if_xodd, : dy + if_yodd]
|
|
3799
|
+
top_right = data_f[..., -dx:, : dy + if_yodd]
|
|
3800
|
+
bottom_left = data_f[..., : dx + if_xodd, -dy:]
|
|
4816
3801
|
bottom_right = data_f[..., -dx:, -dy:]
|
|
4817
3802
|
|
|
4818
3803
|
# Concatenate along the last two dimensions
|
|
@@ -4827,70 +3812,74 @@ class funct(FOC.FoCUS):
|
|
|
4827
3812
|
# utility functions for computing scattering coef and covariance
|
|
4828
3813
|
#
|
|
4829
3814
|
# ---------------------------------------------------------------------------
|
|
4830
|
-
|
|
4831
|
-
def get_dxdy(self, j,M,N):
|
|
4832
|
-
|
|
3815
|
+
|
|
3816
|
+
def get_dxdy(self, j, M, N):
|
|
3817
|
+
"""
|
|
4833
3818
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4834
3819
|
Done by Sihao Cheng and Rudy Morel.
|
|
4835
|
-
|
|
4836
|
-
dx = int(max(
|
|
4837
|
-
dy = int(max(
|
|
3820
|
+
"""
|
|
3821
|
+
dx = int(max(8, min(np.ceil(M / 2**j), M // 2)))
|
|
3822
|
+
dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
|
|
4838
3823
|
return dx, dy
|
|
4839
|
-
|
|
4840
|
-
|
|
4841
3824
|
|
|
4842
|
-
def get_edge_masks(self,M, N, J, d0=1):
|
|
4843
|
-
|
|
3825
|
+
def get_edge_masks(self, M, N, J, d0=1):
|
|
3826
|
+
"""
|
|
4844
3827
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4845
3828
|
Done by Sihao Cheng and Rudy Morel.
|
|
4846
|
-
|
|
3829
|
+
"""
|
|
4847
3830
|
edge_masks = np.empty((J, M, N))
|
|
4848
|
-
X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing=
|
|
3831
|
+
X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
|
|
4849
3832
|
for j in range(J):
|
|
4850
|
-
edge_dx = min(M//4, 2**j*d0)
|
|
4851
|
-
edge_dy = min(N//4, 2**j*d0)
|
|
4852
|
-
edge_masks[j] = (
|
|
4853
|
-
|
|
4854
|
-
|
|
3833
|
+
edge_dx = min(M // 4, 2**j * d0)
|
|
3834
|
+
edge_dy = min(N // 4, 2**j * d0)
|
|
3835
|
+
edge_masks[j] = (
|
|
3836
|
+
(X >= edge_dx)
|
|
3837
|
+
* (X <= M - edge_dx)
|
|
3838
|
+
* (Y >= edge_dy)
|
|
3839
|
+
* (Y <= N - edge_dy)
|
|
3840
|
+
)
|
|
3841
|
+
edge_masks = edge_masks[:, None, :, :]
|
|
3842
|
+
edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
|
|
4855
3843
|
return self.backend.bk_cast(edge_masks)
|
|
4856
|
-
|
|
3844
|
+
|
|
4857
3845
|
# ---------------------------------------------------------------------------
|
|
4858
3846
|
#
|
|
4859
3847
|
# scattering cov
|
|
4860
3848
|
#
|
|
4861
3849
|
# ---------------------------------------------------------------------------
|
|
4862
3850
|
def scattering_cov(
|
|
4863
|
-
self,
|
|
3851
|
+
self,
|
|
3852
|
+
data,
|
|
4864
3853
|
data2=None,
|
|
4865
3854
|
Jmax=None,
|
|
4866
|
-
if_large_batch=False,
|
|
4867
|
-
S4_criteria=None,
|
|
4868
|
-
use_ref=False,
|
|
4869
|
-
normalization=
|
|
3855
|
+
if_large_batch=False,
|
|
3856
|
+
S4_criteria=None,
|
|
3857
|
+
use_ref=False,
|
|
3858
|
+
normalization="S2",
|
|
4870
3859
|
edge=False,
|
|
4871
|
-
pseudo_coef=1,
|
|
4872
|
-
get_variance=False,
|
|
3860
|
+
pseudo_coef=1,
|
|
3861
|
+
get_variance=False,
|
|
4873
3862
|
ref_sigma=None,
|
|
4874
|
-
iso_ang=False
|
|
3863
|
+
iso_ang=False,
|
|
4875
3864
|
):
|
|
4876
|
-
|
|
3865
|
+
"""
|
|
4877
3866
|
Calculates the scattering correlations for a batch of images, including:
|
|
4878
|
-
|
|
3867
|
+
|
|
4879
3868
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4880
3869
|
Done by Sihao Cheng and Rudy Morel.
|
|
4881
|
-
|
|
4882
|
-
orig. x orig.:
|
|
3870
|
+
|
|
3871
|
+
orig. x orig.:
|
|
4883
3872
|
P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
|
|
4884
|
-
orig. x modulus:
|
|
3873
|
+
orig. x modulus:
|
|
4885
3874
|
C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
|
|
4886
3875
|
when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
|
|
4887
3876
|
when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
|
|
4888
|
-
modulus x modulus:
|
|
3877
|
+
modulus x modulus:
|
|
4889
3878
|
C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
|
|
4890
3879
|
C11 = C11_pre_norm / factor
|
|
4891
3880
|
when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
|
|
4892
3881
|
when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
|
|
4893
|
-
modulus x modulus (auto):
|
|
3882
|
+
modulus x modulus (auto):
|
|
4894
3883
|
P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
|
|
4895
3884
|
Parameters
|
|
4896
3885
|
----------
|
|
@@ -4900,7 +3889,7 @@ class funct(FOC.FoCUS):
|
|
|
4900
3889
|
It is recommended to use "False" unless one meets a memory issue
|
|
4901
3890
|
C11_criteria : str or None (=None)
|
|
4902
3891
|
Only C11 coefficients that satisfy this criteria will be computed.
|
|
4903
|
-
Any expressions of j1, j2, and j3 that can be evaluated as a Bool
|
|
3892
|
+
Any expressions of j1, j2, and j3 that can be evaluated as a Bool
|
|
4904
3893
|
is accepted.The default "None" corresponds to "j1 <= j2 <= j3".
|
|
4905
3894
|
use_ref : Bool (=False)
|
|
4906
3895
|
When normalizing, whether or not to use the normalization factor
|
|
@@ -4914,7 +3903,7 @@ class funct(FOC.FoCUS):
|
|
|
4914
3903
|
If true, the edge region with a width of rougly the size of the largest
|
|
4915
3904
|
wavelet involved is excluded when taking the global average to obtain
|
|
4916
3905
|
the scattering coefficients.
|
|
4917
|
-
|
|
3906
|
+
|
|
4918
3907
|
Returns
|
|
4919
3908
|
-------
|
|
4920
3909
|
'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
|
|
@@ -4932,29 +3921,34 @@ class funct(FOC.FoCUS):
|
|
|
4932
3921
|
j1 <= j3 are set to np.nan and not computed.
|
|
4933
3922
|
'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
|
|
4934
3923
|
'P11' averaged over l1 while keeping l2-l1 constant.
|
|
4935
|
-
|
|
3924
|
+
"""
|
|
4936
3925
|
if S4_criteria is None:
|
|
4937
|
-
S4_criteria =
|
|
4938
|
-
|
|
3926
|
+
S4_criteria = "j2>=j1"
|
|
3927
|
+
|
|
3928
|
+
if self.all_bk_type == "float32":
|
|
3929
|
+
C_ONE = np.complex64(1.0)
|
|
3930
|
+
else:
|
|
3931
|
+
C_ONE = np.complex128(1.0)
|
|
3932
|
+
|
|
4939
3933
|
# determine jmax and nside corresponding to the input map
|
|
4940
3934
|
im_shape = data.shape
|
|
4941
3935
|
if self.use_2D:
|
|
4942
3936
|
if len(data.shape) == 2:
|
|
4943
3937
|
nside = np.min([im_shape[0], im_shape[1]])
|
|
4944
|
-
M,N = im_shape[0],im_shape[1]
|
|
4945
|
-
N_image =
|
|
3938
|
+
M, N = im_shape[0], im_shape[1]
|
|
3939
|
+
N_image = 1
|
|
4946
3940
|
N_image2 = 1
|
|
4947
3941
|
else:
|
|
4948
3942
|
nside = np.min([im_shape[1], im_shape[2]])
|
|
4949
|
-
M,N = im_shape[1],im_shape[2]
|
|
3943
|
+
M, N = im_shape[1], im_shape[2]
|
|
4950
3944
|
N_image = data.shape[0]
|
|
4951
3945
|
if data2 is not None:
|
|
4952
3946
|
N_image2 = data2.shape[0]
|
|
4953
|
-
J = int(np.log(nside) / np.log(2))-1 # Number of j scales
|
|
3947
|
+
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
|
|
4954
3948
|
elif self.use_1D:
|
|
4955
3949
|
if len(data.shape) == 2:
|
|
4956
3950
|
npix = int(im_shape[1]) # Number of pixels
|
|
4957
|
-
N_image =
|
|
3951
|
+
N_image = 1
|
|
4958
3952
|
N_image2 = 1
|
|
4959
3953
|
else:
|
|
4960
3954
|
npix = int(im_shape[0]) # Number of pixels
|
|
@@ -4964,12 +3958,12 @@ class funct(FOC.FoCUS):
|
|
|
4964
3958
|
|
|
4965
3959
|
nside = int(npix)
|
|
4966
3960
|
|
|
4967
|
-
J = int(np.log(nside) / np.log(2))-1 # Number of j scales
|
|
3961
|
+
J = int(np.log(nside) / np.log(2)) - 1 # Number of j scales
|
|
4968
3962
|
else:
|
|
4969
3963
|
if len(data.shape) == 2:
|
|
4970
3964
|
npix = int(im_shape[1]) # Number of pixels
|
|
4971
|
-
N_image =
|
|
4972
|
-
N_image2 =
|
|
3965
|
+
N_image = 1
|
|
3966
|
+
N_image2 = 1
|
|
4973
3967
|
else:
|
|
4974
3968
|
npix = int(im_shape[0]) # Number of pixels
|
|
4975
3969
|
N_image = data.shape[0]
|
|
@@ -4979,978 +3973,1715 @@ class funct(FOC.FoCUS):
|
|
|
4979
3973
|
nside = int(np.sqrt(npix // 12))
|
|
4980
3974
|
|
|
4981
3975
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
4982
|
-
|
|
3976
|
+
|
|
4983
3977
|
if Jmax is None:
|
|
4984
3978
|
Jmax = J # Number of steps for the loop on scales
|
|
4985
|
-
if Jmax>J:
|
|
4986
|
-
print(
|
|
4987
|
-
print(
|
|
4988
|
-
|
|
4989
|
-
|
|
4990
|
-
|
|
4991
|
-
|
|
4992
|
-
|
|
4993
|
-
|
|
4994
|
-
|
|
4995
|
-
|
|
4996
|
-
|
|
4997
|
-
|
|
4998
|
-
|
|
4999
|
-
|
|
3979
|
+
if Jmax > J:
|
|
3980
|
+
print("==========\n\n")
|
|
3981
|
+
print(
|
|
3982
|
+
"The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
|
|
3983
|
+
)
|
|
3984
|
+
print("\n\n==========")
|
|
3985
|
+
|
|
3986
|
+
L = self.NORIENT
|
|
3987
|
+
norm_factor_S3 = 1.0
|
|
3988
|
+
|
|
3989
|
+
if self.backend.BACKEND == "torch":
|
|
3990
|
+
if (M, N, J, L) not in self.filters_set:
|
|
3991
|
+
self.filters_set[(M, N, J, L)] = self.computer_filter(
|
|
3992
|
+
M, N, J, L
|
|
3993
|
+
) # self.computer_filter(M,N,J,L)
|
|
3994
|
+
|
|
3995
|
+
filters_set = self.filters_set[(M, N, J, L)]
|
|
3996
|
+
|
|
3997
|
+
# weight = self.weight
|
|
5000
3998
|
if use_ref:
|
|
5001
|
-
if normalization==
|
|
3999
|
+
if normalization == "S2":
|
|
5002
4000
|
ref_S2 = self.ref_scattering_cov_S2
|
|
5003
|
-
else:
|
|
5004
|
-
ref_P11 = self.ref_scattering_cov[
|
|
4001
|
+
else:
|
|
4002
|
+
ref_P11 = self.ref_scattering_cov["P11"]
|
|
5005
4003
|
|
|
5006
4004
|
# convert numpy array input into self.backend.bk_ tensors
|
|
5007
4005
|
data = self.backend.bk_cast(data)
|
|
5008
|
-
data_f = self.backend.bk_fftn(data, dim=(-2
|
|
4006
|
+
data_f = self.backend.bk_fftn(data, dim=(-2, -1))
|
|
5009
4007
|
if data2 is not None:
|
|
5010
4008
|
data2 = self.backend.bk_cast(data2)
|
|
5011
|
-
data2_f = self.backend.bk_fftn(data2, dim=(-2
|
|
5012
|
-
|
|
5013
|
-
# initialize tensors for scattering coefficients
|
|
5014
|
-
S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5015
|
-
S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5016
|
-
|
|
5017
|
-
Ndata_S3 = J*(J+1)//2
|
|
5018
|
-
Ndata_S4 = J*(J+1)*(J+2)//6
|
|
5019
|
-
J_S4={}
|
|
5020
|
-
|
|
5021
|
-
S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
|
|
4009
|
+
data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
|
|
4010
|
+
|
|
4011
|
+
# initialize tensors for scattering coefficients
|
|
4012
|
+
S2 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4013
|
+
S1 = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4014
|
+
|
|
4015
|
+
Ndata_S3 = J * (J + 1) // 2
|
|
4016
|
+
Ndata_S4 = J * (J + 1) * (J + 2) // 6
|
|
4017
|
+
J_S4 = {}
|
|
4018
|
+
|
|
4019
|
+
S3 = self.backend.bk_zeros((N_image, Ndata_S3, L, L), dtype=data_f.dtype)
|
|
5022
4020
|
if data2 is not None:
|
|
5023
|
-
S3p = self.backend.bk_zeros(
|
|
5024
|
-
|
|
5025
|
-
|
|
5026
|
-
|
|
4021
|
+
S3p = self.backend.bk_zeros(
|
|
4022
|
+
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
|
|
4023
|
+
)
|
|
4024
|
+
S4_pre_norm = self.backend.bk_zeros(
|
|
4025
|
+
(N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
|
|
4026
|
+
)
|
|
4027
|
+
S4 = self.backend.bk_zeros((N_image, Ndata_S4, L, L, L), dtype=data_f.dtype)
|
|
4028
|
+
|
|
5027
4029
|
# variance
|
|
5028
4030
|
if get_variance:
|
|
5029
|
-
S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5030
|
-
S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
5031
|
-
S3_sigma = self.backend.bk_zeros(
|
|
4031
|
+
S2_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4032
|
+
S1_sigma = self.backend.bk_zeros((N_image, J, L), dtype=data.dtype)
|
|
4033
|
+
S3_sigma = self.backend.bk_zeros(
|
|
4034
|
+
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
|
|
4035
|
+
)
|
|
5032
4036
|
if data2 is not None:
|
|
5033
|
-
S3p_sigma = self.backend.bk_zeros(
|
|
5034
|
-
|
|
5035
|
-
|
|
4037
|
+
S3p_sigma = self.backend.bk_zeros(
|
|
4038
|
+
(N_image, Ndata_S3, L, L), dtype=data_f.dtype
|
|
4039
|
+
)
|
|
4040
|
+
S4_sigma = self.backend.bk_zeros(
|
|
4041
|
+
(N_image, Ndata_S4, L, L, L), dtype=data_f.dtype
|
|
4042
|
+
)
|
|
4043
|
+
|
|
5036
4044
|
if iso_ang:
|
|
5037
|
-
S3_iso = self.backend.bk_zeros(
|
|
5038
|
-
|
|
4045
|
+
S3_iso = self.backend.bk_zeros(
|
|
4046
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4047
|
+
)
|
|
4048
|
+
S4_iso = self.backend.bk_zeros(
|
|
4049
|
+
(N_image, Ndata_S4, L, L), dtype=data_f.dtype
|
|
4050
|
+
)
|
|
5039
4051
|
if get_variance:
|
|
5040
|
-
S3_sigma_iso = self.backend.bk_zeros(
|
|
5041
|
-
|
|
4052
|
+
S3_sigma_iso = self.backend.bk_zeros(
|
|
4053
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4054
|
+
)
|
|
4055
|
+
S4_sigma_iso = self.backend.bk_zeros(
|
|
4056
|
+
(N_image, Ndata_S4, L, L), dtype=data_f.dtype
|
|
4057
|
+
)
|
|
5042
4058
|
if data2 is not None:
|
|
5043
|
-
S3p_iso = self.backend.bk_zeros(
|
|
4059
|
+
S3p_iso = self.backend.bk_zeros(
|
|
4060
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4061
|
+
)
|
|
5044
4062
|
if get_variance:
|
|
5045
|
-
S3p_sigma_iso = self.backend.bk_zeros(
|
|
5046
|
-
|
|
4063
|
+
S3p_sigma_iso = self.backend.bk_zeros(
|
|
4064
|
+
(N_image, Ndata_S3, L), dtype=data_f.dtype
|
|
4065
|
+
)
|
|
4066
|
+
|
|
5047
4067
|
#
|
|
5048
|
-
if edge:
|
|
5049
|
-
if (M,N,J) not in self.edge_masks:
|
|
5050
|
-
self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
|
|
5051
|
-
edge_mask=self.edge_masks[(M,N,J)]
|
|
5052
|
-
else:
|
|
4068
|
+
if edge:
|
|
4069
|
+
if (M, N, J) not in self.edge_masks:
|
|
4070
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
|
|
4071
|
+
edge_mask = self.edge_masks[(M, N, J)]
|
|
4072
|
+
else:
|
|
5053
4073
|
edge_mask = 1
|
|
5054
|
-
|
|
4074
|
+
|
|
5055
4075
|
# calculate scattering fields
|
|
5056
4076
|
if data2 is None:
|
|
5057
4077
|
if self.use_2D:
|
|
5058
4078
|
if len(data.shape) == 2:
|
|
5059
4079
|
I1 = self.backend.bk_ifftn(
|
|
5060
|
-
data_f[None,None,None
|
|
4080
|
+
data_f[None, None, None, :, :]
|
|
4081
|
+
* filters_set[None, :J, :, :, :],
|
|
4082
|
+
dim=(-2, -1),
|
|
5061
4083
|
).abs()
|
|
5062
4084
|
else:
|
|
5063
4085
|
I1 = self.backend.bk_ifftn(
|
|
5064
|
-
data_f[:,None,None
|
|
4086
|
+
data_f[:, None, None, :, :]
|
|
4087
|
+
* filters_set[None, :J, :, :, :],
|
|
4088
|
+
dim=(-2, -1),
|
|
5065
4089
|
).abs()
|
|
5066
4090
|
elif self.use_1D:
|
|
5067
4091
|
if len(data.shape) == 1:
|
|
5068
4092
|
I1 = self.backend.bk_ifftn(
|
|
5069
|
-
data_f[None,None,None
|
|
4093
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4094
|
+
dim=(-1),
|
|
5070
4095
|
).abs()
|
|
5071
4096
|
else:
|
|
5072
4097
|
I1 = self.backend.bk_ifftn(
|
|
5073
|
-
data_f[:,None,None
|
|
4098
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4099
|
+
dim=(-1),
|
|
5074
4100
|
).abs()
|
|
5075
4101
|
else:
|
|
5076
|
-
print(
|
|
5077
|
-
|
|
5078
|
-
S2 = (I1**2 * edge_mask).mean((-2
|
|
5079
|
-
S1
|
|
4102
|
+
print("todo")
|
|
4103
|
+
|
|
4104
|
+
S2 = (I1**2 * edge_mask).mean((-2, -1))
|
|
4105
|
+
S1 = (I1 * edge_mask).mean((-2, -1))
|
|
5080
4106
|
|
|
5081
4107
|
if get_variance:
|
|
5082
|
-
S2_sigma = (I1**2 * edge_mask).std((-2
|
|
5083
|
-
S1_sigma
|
|
5084
|
-
|
|
4108
|
+
S2_sigma = (I1**2 * edge_mask).std((-2, -1))
|
|
4109
|
+
S1_sigma = (I1 * edge_mask).std((-2, -1))
|
|
4110
|
+
|
|
5085
4111
|
else:
|
|
5086
4112
|
if self.use_2D:
|
|
5087
4113
|
if len(data.shape) == 2:
|
|
5088
4114
|
I1 = self.backend.bk_ifftn(
|
|
5089
|
-
data_f[None,None,None
|
|
4115
|
+
data_f[None, None, None, :, :]
|
|
4116
|
+
* filters_set[None, :J, :, :, :],
|
|
4117
|
+
dim=(-2, -1),
|
|
5090
4118
|
)
|
|
5091
4119
|
I2 = self.backend.bk_ifftn(
|
|
5092
|
-
data2_f[None,None,None
|
|
4120
|
+
data2_f[None, None, None, :, :]
|
|
4121
|
+
* filters_set[None, :J, :, :, :],
|
|
4122
|
+
dim=(-2, -1),
|
|
5093
4123
|
)
|
|
5094
4124
|
else:
|
|
5095
4125
|
I1 = self.backend.bk_ifftn(
|
|
5096
|
-
data_f[:,None,None
|
|
4126
|
+
data_f[:, None, None, :, :]
|
|
4127
|
+
* filters_set[None, :J, :, :, :],
|
|
4128
|
+
dim=(-2, -1),
|
|
5097
4129
|
)
|
|
5098
4130
|
I2 = self.backend.bk_ifftn(
|
|
5099
|
-
data2_f[:,None,None
|
|
4131
|
+
data2_f[:, None, None, :, :]
|
|
4132
|
+
* filters_set[None, :J, :, :, :],
|
|
4133
|
+
dim=(-2, -1),
|
|
5100
4134
|
)
|
|
5101
4135
|
elif self.use_1D:
|
|
5102
4136
|
if len(data.shape) == 1:
|
|
5103
4137
|
I1 = self.backend.bk_ifftn(
|
|
5104
|
-
data_f[None,None,None
|
|
4138
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4139
|
+
dim=(-1),
|
|
5105
4140
|
)
|
|
5106
4141
|
I2 = self.backend.bk_ifftn(
|
|
5107
|
-
data2_f[None,None,None
|
|
4142
|
+
data2_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4143
|
+
dim=(-1),
|
|
5108
4144
|
)
|
|
5109
4145
|
else:
|
|
5110
4146
|
I1 = self.backend.bk_ifftn(
|
|
5111
|
-
data_f[:,None,None
|
|
4147
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4148
|
+
dim=(-1),
|
|
5112
4149
|
)
|
|
5113
4150
|
I2 = self.backend.bk_ifftn(
|
|
5114
|
-
data2_f[:,None,None
|
|
4151
|
+
data2_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4152
|
+
dim=(-1),
|
|
5115
4153
|
)
|
|
5116
4154
|
else:
|
|
5117
|
-
print(
|
|
5118
|
-
|
|
5119
|
-
I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
|
|
5120
|
-
|
|
5121
|
-
S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2
|
|
4155
|
+
print("todo")
|
|
4156
|
+
|
|
4157
|
+
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
|
|
4158
|
+
|
|
4159
|
+
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5122
4160
|
if get_variance:
|
|
5123
|
-
S2_sigma = self.backend.bk_reduce_std(
|
|
5124
|
-
|
|
5125
|
-
|
|
5126
|
-
|
|
5127
|
-
|
|
4161
|
+
S2_sigma = self.backend.bk_reduce_std(
|
|
4162
|
+
(I1 * edge_mask), axis=(-2, -1)
|
|
4163
|
+
)
|
|
4164
|
+
|
|
4165
|
+
I1 = self.backend.bk_L1(I1)
|
|
4166
|
+
|
|
4167
|
+
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5128
4168
|
|
|
5129
4169
|
if get_variance:
|
|
5130
|
-
S1_sigma
|
|
5131
|
-
|
|
5132
|
-
|
|
5133
|
-
|
|
4170
|
+
S1_sigma = self.backend.bk_reduce_std(
|
|
4171
|
+
(I1 * edge_mask), axis=(-2, -1)
|
|
4172
|
+
)
|
|
4173
|
+
|
|
4174
|
+
I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
|
|
4175
|
+
|
|
5134
4176
|
if pseudo_coef != 1:
|
|
5135
4177
|
I1 = I1**pseudo_coef
|
|
5136
|
-
|
|
5137
|
-
Ndata_S3=0
|
|
5138
|
-
Ndata_S4=0
|
|
5139
|
-
|
|
4178
|
+
|
|
4179
|
+
Ndata_S3 = 0
|
|
4180
|
+
Ndata_S4 = 0
|
|
4181
|
+
|
|
5140
4182
|
# calculate the covariance and correlations of the scattering fields
|
|
5141
4183
|
# only use the low-k Fourier coefs when calculating large-j scattering coefs.
|
|
5142
|
-
for j3 in range(0,J):
|
|
5143
|
-
J_S4[j3]=Ndata_S4
|
|
5144
|
-
|
|
5145
|
-
dx3, dy3 = self.get_dxdy(j3,M,N)
|
|
5146
|
-
I1_f_small = self.cut_high_k_off(
|
|
4184
|
+
for j3 in range(0, J):
|
|
4185
|
+
J_S4[j3] = Ndata_S4
|
|
4186
|
+
|
|
4187
|
+
dx3, dy3 = self.get_dxdy(j3, M, N)
|
|
4188
|
+
I1_f_small = self.cut_high_k_off(
|
|
4189
|
+
I1_f[:, : j3 + 1], dx3, dy3
|
|
4190
|
+
) # Nimage, J, L, x, y
|
|
5147
4191
|
data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
|
|
5148
4192
|
if data2 is not None:
|
|
5149
4193
|
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
|
|
5150
4194
|
if edge:
|
|
5151
|
-
I1_small = self.backend.bk_ifftn(
|
|
5152
|
-
|
|
4195
|
+
I1_small = self.backend.bk_ifftn(
|
|
4196
|
+
I1_f_small, dim=(-2, -1), norm="ortho"
|
|
4197
|
+
)
|
|
4198
|
+
data_small = self.backend.bk_ifftn(
|
|
4199
|
+
data_f_small, dim=(-2, -1), norm="ortho"
|
|
4200
|
+
)
|
|
5153
4201
|
if data2 is not None:
|
|
5154
|
-
data2_small = self.backend.bk_ifftn(
|
|
5155
|
-
|
|
5156
|
-
|
|
4202
|
+
data2_small = self.backend.bk_ifftn(
|
|
4203
|
+
data2_f_small, dim=(-2, -1), norm="ortho"
|
|
4204
|
+
)
|
|
4205
|
+
|
|
4206
|
+
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
5157
4207
|
_, M3, N3 = wavelet_f3.shape
|
|
5158
4208
|
wavelet_f3_squared = wavelet_f3**2
|
|
5159
|
-
edge_dx = min(4, int(2**j3*dx3*2/M))
|
|
5160
|
-
edge_dy = min(4, int(2**j3*dy3*2/N))
|
|
5161
|
-
|
|
4209
|
+
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
|
|
4210
|
+
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
|
|
4211
|
+
|
|
5162
4212
|
# a normalization change due to the cutoff of frequency space
|
|
5163
|
-
fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
|
|
5164
|
-
for j2 in range(0,j3+1):
|
|
5165
|
-
I1_f2_wf3_small = I1_f_small[:,j2].view(
|
|
5166
|
-
|
|
4213
|
+
fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
|
|
4214
|
+
for j2 in range(0, j3 + 1):
|
|
4215
|
+
I1_f2_wf3_small = I1_f_small[:, j2].view(
|
|
4216
|
+
N_image, L, 1, M3, N3
|
|
4217
|
+
) * wavelet_f3.view(1, 1, L, M3, N3)
|
|
4218
|
+
I1_f2_wf3_2_small = I1_f_small[:, j2].view(
|
|
4219
|
+
N_image, L, 1, M3, N3
|
|
4220
|
+
) * wavelet_f3_squared.view(1, 1, L, M3, N3)
|
|
5167
4221
|
if edge:
|
|
5168
|
-
I12_w3_small = self.backend.bk_ifftn(
|
|
5169
|
-
|
|
4222
|
+
I12_w3_small = self.backend.bk_ifftn(
|
|
4223
|
+
I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
|
|
4224
|
+
)
|
|
4225
|
+
I12_w3_2_small = self.backend.bk_ifftn(
|
|
4226
|
+
I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
|
|
4227
|
+
)
|
|
5170
4228
|
if use_ref:
|
|
5171
|
-
if normalization==
|
|
5172
|
-
norm_factor_S3 = (
|
|
5173
|
-
|
|
5174
|
-
|
|
4229
|
+
if normalization == "P11":
|
|
4230
|
+
norm_factor_S3 = (
|
|
4231
|
+
ref_S2[:, None, j3, :]
|
|
4232
|
+
* ref_P11[:, j2, j3, :, :] ** pseudo_coef
|
|
4233
|
+
) ** 0.5
|
|
4234
|
+
if normalization == "S2":
|
|
4235
|
+
norm_factor_S3 = (
|
|
4236
|
+
ref_S2[:, None, j3, :]
|
|
4237
|
+
* ref_S2[:, j2, :, None] ** pseudo_coef
|
|
4238
|
+
) ** 0.5
|
|
5175
4239
|
else:
|
|
5176
|
-
if normalization==
|
|
4240
|
+
if normalization == "P11":
|
|
5177
4241
|
# [N_image,l2,l3,x,y]
|
|
5178
|
-
P11_temp = (I1_f2_wf3_small.abs()**2).mean(
|
|
5179
|
-
|
|
5180
|
-
|
|
5181
|
-
norm_factor_S3 = (
|
|
4242
|
+
P11_temp = (I1_f2_wf3_small.abs() ** 2).mean(
|
|
4243
|
+
(-2, -1)
|
|
4244
|
+
) * fft_factor
|
|
4245
|
+
norm_factor_S3 = (
|
|
4246
|
+
S2[:, None, j3, :] * P11_temp**pseudo_coef
|
|
4247
|
+
) ** 0.5
|
|
4248
|
+
if normalization == "S2":
|
|
4249
|
+
norm_factor_S3 = (
|
|
4250
|
+
S2[:, None, j3, :] * S2[:, j2, :, None] ** pseudo_coef
|
|
4251
|
+
) ** 0.5
|
|
5182
4252
|
|
|
5183
4253
|
if not edge:
|
|
5184
|
-
S3[:,Ndata_S3
|
|
5185
|
-
|
|
5186
|
-
|
|
5187
|
-
|
|
4254
|
+
S3[:, Ndata_S3, :, :] = (
|
|
4255
|
+
(
|
|
4256
|
+
data_f_small.view(N_image, 1, 1, M3, N3)
|
|
4257
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4258
|
+
).mean((-2, -1))
|
|
4259
|
+
* fft_factor
|
|
4260
|
+
/ norm_factor_S3
|
|
4261
|
+
)
|
|
4262
|
+
|
|
5188
4263
|
if get_variance:
|
|
5189
|
-
S3_sigma[:,Ndata_S3
|
|
5190
|
-
|
|
5191
|
-
|
|
4264
|
+
S3_sigma[:, Ndata_S3, :, :] = (
|
|
4265
|
+
(
|
|
4266
|
+
data_f_small.view(N_image, 1, 1, M3, N3)
|
|
4267
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4268
|
+
).std((-2, -1))
|
|
4269
|
+
* fft_factor
|
|
4270
|
+
/ norm_factor_S3
|
|
4271
|
+
)
|
|
5192
4272
|
else:
|
|
5193
|
-
S3[:,Ndata_S3
|
|
5194
|
-
|
|
5195
|
-
|
|
4273
|
+
S3[:, Ndata_S3, :, :] = (
|
|
4274
|
+
(
|
|
4275
|
+
data_small.view(N_image, 1, 1, M3, N3)
|
|
4276
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4277
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy].mean(
|
|
4278
|
+
(-2, -1)
|
|
4279
|
+
)
|
|
4280
|
+
* fft_factor
|
|
4281
|
+
/ norm_factor_S3
|
|
4282
|
+
)
|
|
5196
4283
|
if get_variance:
|
|
5197
|
-
S3_sigma[:,Ndata_S3
|
|
5198
|
-
|
|
5199
|
-
|
|
4284
|
+
S3_sigma[:, Ndata_S3, :, :] = (
|
|
4285
|
+
(
|
|
4286
|
+
data_small.view(N_image, 1, 1, M3, N3)
|
|
4287
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4288
|
+
)[
|
|
4289
|
+
..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
|
|
4290
|
+
].std(
|
|
4291
|
+
(-2, -1)
|
|
4292
|
+
)
|
|
4293
|
+
* fft_factor
|
|
4294
|
+
/ norm_factor_S3
|
|
4295
|
+
)
|
|
5200
4296
|
if data2 is not None:
|
|
5201
4297
|
if not edge:
|
|
5202
|
-
S3p[:,Ndata_S3
|
|
5203
|
-
|
|
5204
|
-
|
|
5205
|
-
|
|
4298
|
+
S3p[:, Ndata_S3, :, :] = (
|
|
4299
|
+
(
|
|
4300
|
+
data2_f_small.view(N_image2, 1, 1, M3, N3)
|
|
4301
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4302
|
+
).mean((-2, -1))
|
|
4303
|
+
* fft_factor
|
|
4304
|
+
/ norm_factor_S3
|
|
4305
|
+
)
|
|
4306
|
+
|
|
5206
4307
|
if get_variance:
|
|
5207
|
-
S3p_sigma[:,Ndata_S3
|
|
5208
|
-
|
|
5209
|
-
|
|
4308
|
+
S3p_sigma[:, Ndata_S3, :, :] = (
|
|
4309
|
+
(
|
|
4310
|
+
data2_f_small.view(N_image2, 1, 1, M3, N3)
|
|
4311
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
4312
|
+
).std((-2, -1))
|
|
4313
|
+
* fft_factor
|
|
4314
|
+
/ norm_factor_S3
|
|
4315
|
+
)
|
|
5210
4316
|
else:
|
|
5211
|
-
S3p[:,Ndata_S3
|
|
5212
|
-
|
|
5213
|
-
|
|
4317
|
+
S3p[:, Ndata_S3, :, :] = (
|
|
4318
|
+
(
|
|
4319
|
+
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4320
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4321
|
+
)[
|
|
4322
|
+
..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
|
|
4323
|
+
].mean(
|
|
4324
|
+
(-2, -1)
|
|
4325
|
+
)
|
|
4326
|
+
* fft_factor
|
|
4327
|
+
/ norm_factor_S3
|
|
4328
|
+
)
|
|
5214
4329
|
if get_variance:
|
|
5215
|
-
S3p_sigma[:,Ndata_S3
|
|
5216
|
-
|
|
5217
|
-
|
|
5218
|
-
|
|
4330
|
+
S3p_sigma[:, Ndata_S3, :, :] = (
|
|
4331
|
+
(
|
|
4332
|
+
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4333
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
4334
|
+
)[
|
|
4335
|
+
...,
|
|
4336
|
+
edge_dx : M3 - edge_dx,
|
|
4337
|
+
edge_dy : N3 - edge_dy,
|
|
4338
|
+
].std(
|
|
4339
|
+
(-2, -1)
|
|
4340
|
+
)
|
|
4341
|
+
* fft_factor
|
|
4342
|
+
/ norm_factor_S3
|
|
4343
|
+
)
|
|
4344
|
+
Ndata_S3 += 1
|
|
5219
4345
|
if j2 <= j3:
|
|
5220
|
-
beg_n=Ndata_S4
|
|
5221
|
-
for j1 in range(0, j2+1):
|
|
4346
|
+
beg_n = Ndata_S4
|
|
4347
|
+
for j1 in range(0, j2 + 1):
|
|
5222
4348
|
if eval(S4_criteria):
|
|
5223
4349
|
if not edge:
|
|
5224
4350
|
if not if_large_batch:
|
|
5225
4351
|
# [N_image,l1,l2,l3,x,y]
|
|
5226
|
-
S4_pre_norm[:,Ndata_S4
|
|
5227
|
-
I1_f_small[:,j1].view(
|
|
5228
|
-
|
|
5229
|
-
|
|
4352
|
+
S4_pre_norm[:, Ndata_S4, :, :, :] = (
|
|
4353
|
+
I1_f_small[:, j1].view(
|
|
4354
|
+
N_image, L, 1, 1, M3, N3
|
|
4355
|
+
)
|
|
4356
|
+
* self.backend.bk_conjugate(
|
|
4357
|
+
I1_f2_wf3_2_small.view(
|
|
4358
|
+
N_image, 1, L, L, M3, N3
|
|
4359
|
+
)
|
|
4360
|
+
)
|
|
4361
|
+
).mean((-2, -1)) * fft_factor
|
|
5230
4362
|
if get_variance:
|
|
5231
|
-
S4_sigma[:,Ndata_S4
|
|
5232
|
-
I1_f_small[:,j1].view(
|
|
5233
|
-
|
|
5234
|
-
|
|
4363
|
+
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4364
|
+
I1_f_small[:, j1].view(
|
|
4365
|
+
N_image, L, 1, 1, M3, N3
|
|
4366
|
+
)
|
|
4367
|
+
* self.backend.bk_conjugate(
|
|
4368
|
+
I1_f2_wf3_2_small.view(
|
|
4369
|
+
N_image, 1, L, L, M3, N3
|
|
4370
|
+
)
|
|
4371
|
+
)
|
|
4372
|
+
).std((-2, -1)) * fft_factor
|
|
5235
4373
|
else:
|
|
5236
4374
|
for l1 in range(L):
|
|
5237
4375
|
# [N_image,l2,l3,x,y]
|
|
5238
|
-
S4_pre_norm[:,Ndata_S4,l1
|
|
5239
|
-
I1_f_small[:,j1,l1].view(
|
|
5240
|
-
|
|
5241
|
-
|
|
4376
|
+
S4_pre_norm[:, Ndata_S4, l1, :, :] = (
|
|
4377
|
+
I1_f_small[:, j1, l1].view(
|
|
4378
|
+
N_image, 1, 1, M3, N3
|
|
4379
|
+
)
|
|
4380
|
+
* self.backend.bk_conjugate(
|
|
4381
|
+
I1_f2_wf3_2_small.view(
|
|
4382
|
+
N_image, L, L, M3, N3
|
|
4383
|
+
)
|
|
4384
|
+
)
|
|
4385
|
+
).mean((-2, -1)) * fft_factor
|
|
5242
4386
|
if get_variance:
|
|
5243
|
-
S4_sigma[:,Ndata_S4,l1
|
|
5244
|
-
I1_f_small[:,j1,l1].view(
|
|
5245
|
-
|
|
5246
|
-
|
|
4387
|
+
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4388
|
+
I1_f_small[:, j1, l1].view(
|
|
4389
|
+
N_image, 1, 1, M3, N3
|
|
4390
|
+
)
|
|
4391
|
+
* self.backend.bk_conjugate(
|
|
4392
|
+
I1_f2_wf3_2_small.view(
|
|
4393
|
+
N_image, L, L, M3, N3
|
|
4394
|
+
)
|
|
4395
|
+
)
|
|
4396
|
+
).std((-2, -1)) * fft_factor
|
|
5247
4397
|
else:
|
|
5248
4398
|
if not if_large_batch:
|
|
5249
4399
|
# [N_image,l1,l2,l3,x,y]
|
|
5250
|
-
S4_pre_norm[:,Ndata_S4
|
|
5251
|
-
I1_small[:,j1].view(
|
|
5252
|
-
|
|
4400
|
+
S4_pre_norm[:, Ndata_S4, :, :, :] = (
|
|
4401
|
+
I1_small[:, j1].view(
|
|
4402
|
+
N_image, L, 1, 1, M3, N3
|
|
4403
|
+
)
|
|
4404
|
+
* self.backend.bk_conjugate(
|
|
4405
|
+
I12_w3_2_small.view(
|
|
4406
|
+
N_image, 1, L, L, M3, N3
|
|
4407
|
+
)
|
|
5253
4408
|
)
|
|
5254
|
-
)[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
|
|
4409
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
|
|
4410
|
+
(-2, -1)
|
|
4411
|
+
) * fft_factor
|
|
5255
4412
|
if get_variance:
|
|
5256
|
-
S4_sigma[:,Ndata_S4
|
|
5257
|
-
I1_small[:,j1].view(
|
|
5258
|
-
|
|
4413
|
+
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4414
|
+
I1_small[:, j1].view(
|
|
4415
|
+
N_image, L, 1, 1, M3, N3
|
|
5259
4416
|
)
|
|
5260
|
-
|
|
4417
|
+
* self.backend.bk_conjugate(
|
|
4418
|
+
I12_w3_2_small.view(
|
|
4419
|
+
N_image, 1, L, L, M3, N3
|
|
4420
|
+
)
|
|
4421
|
+
)
|
|
4422
|
+
)[
|
|
4423
|
+
..., edge_dx:-edge_dx, edge_dy:-edge_dy
|
|
4424
|
+
].std(
|
|
4425
|
+
(-2, -1)
|
|
4426
|
+
) * fft_factor
|
|
5261
4427
|
else:
|
|
5262
4428
|
for l1 in range(L):
|
|
5263
|
-
|
|
5264
|
-
S4_pre_norm[:,Ndata_S4,l1
|
|
5265
|
-
I1_small[:,j1].view(
|
|
5266
|
-
|
|
4429
|
+
# [N_image,l2,l3,x,y]
|
|
4430
|
+
S4_pre_norm[:, Ndata_S4, l1, :, :] = (
|
|
4431
|
+
I1_small[:, j1].view(
|
|
4432
|
+
N_image, 1, 1, M3, N3
|
|
4433
|
+
)
|
|
4434
|
+
* self.backend.bk_conjugate(
|
|
4435
|
+
I12_w3_2_small.view(
|
|
4436
|
+
N_image, L, L, M3, N3
|
|
4437
|
+
)
|
|
5267
4438
|
)
|
|
5268
|
-
)[
|
|
4439
|
+
)[
|
|
4440
|
+
..., edge_dx:-edge_dx, edge_dy:-edge_dy
|
|
4441
|
+
].mean(
|
|
4442
|
+
(-2, -1)
|
|
4443
|
+
) * fft_factor
|
|
5269
4444
|
if get_variance:
|
|
5270
|
-
S4_sigma[:,Ndata_S4,l1
|
|
5271
|
-
I1_small[:,j1].view(
|
|
5272
|
-
|
|
4445
|
+
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4446
|
+
I1_small[:, j1].view(
|
|
4447
|
+
N_image, 1, 1, M3, N3
|
|
4448
|
+
)
|
|
4449
|
+
* self.backend.bk_conjugate(
|
|
4450
|
+
I12_w3_2_small.view(
|
|
4451
|
+
N_image, L, L, M3, N3
|
|
4452
|
+
)
|
|
5273
4453
|
)
|
|
5274
|
-
)[
|
|
5275
|
-
|
|
5276
|
-
|
|
5277
|
-
|
|
5278
|
-
|
|
5279
|
-
|
|
5280
|
-
|
|
5281
|
-
|
|
5282
|
-
|
|
5283
|
-
|
|
5284
|
-
|
|
5285
|
-
|
|
4454
|
+
)[
|
|
4455
|
+
...,
|
|
4456
|
+
edge_dx:-edge_dx,
|
|
4457
|
+
edge_dy:-edge_dy,
|
|
4458
|
+
].mean(
|
|
4459
|
+
(-2, -1)
|
|
4460
|
+
) * fft_factor
|
|
4461
|
+
|
|
4462
|
+
Ndata_S4 += 1
|
|
4463
|
+
|
|
4464
|
+
if normalization == "S2":
|
|
4465
|
+
if use_ref:
|
|
4466
|
+
P = (
|
|
4467
|
+
ref_S2[:, j3 : j3 + 1, :, None, None]
|
|
4468
|
+
* ref_S2[:, j2 : j2 + 1, None, :, None]
|
|
4469
|
+
) ** (0.5 * pseudo_coef)
|
|
4470
|
+
else:
|
|
4471
|
+
P = (
|
|
4472
|
+
S2[:, j3 : j3 + 1, :, None, None]
|
|
4473
|
+
* S2[:, j2 : j2 + 1, None, :, None]
|
|
4474
|
+
) ** (0.5 * pseudo_coef)
|
|
4475
|
+
|
|
4476
|
+
S4[:, beg_n:Ndata_S4, :, :, :] = (
|
|
4477
|
+
S4_pre_norm[:, beg_n:Ndata_S4, :, :, :].clone() / P
|
|
4478
|
+
)
|
|
4479
|
+
|
|
5286
4480
|
if get_variance:
|
|
5287
|
-
S4_sigma[:,beg_n:Ndata_S4
|
|
4481
|
+
S4_sigma[:, beg_n:Ndata_S4, :, :, :] = (
|
|
4482
|
+
S4_sigma[:, beg_n:Ndata_S4, :, :, :] / P
|
|
4483
|
+
)
|
|
5288
4484
|
else:
|
|
5289
|
-
S4=S4_pre_norm
|
|
5290
|
-
|
|
4485
|
+
S4 = S4_pre_norm
|
|
4486
|
+
|
|
5291
4487
|
# average over l1 to obtain simple isotropic statistics
|
|
5292
4488
|
if iso_ang:
|
|
5293
4489
|
S2_iso = S2.mean(-1)
|
|
5294
4490
|
S1_iso = S1.mean(-1)
|
|
5295
4491
|
for l1 in range(L):
|
|
5296
4492
|
for l2 in range(L):
|
|
5297
|
-
S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
|
|
4493
|
+
S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
|
|
5298
4494
|
if data2 is not None:
|
|
5299
|
-
S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
|
|
4495
|
+
S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
|
|
5300
4496
|
for l3 in range(L):
|
|
5301
|
-
S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[
|
|
5302
|
-
|
|
4497
|
+
S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[
|
|
4498
|
+
..., l1, l2, l3
|
|
4499
|
+
]
|
|
4500
|
+
S3_iso /= L
|
|
4501
|
+
S4_iso /= L
|
|
5303
4502
|
if data2 is not None:
|
|
5304
4503
|
S3p_iso /= L
|
|
5305
|
-
|
|
4504
|
+
|
|
5306
4505
|
if get_variance:
|
|
5307
4506
|
S2_sigma_iso = S2_sigma.mean(-1)
|
|
5308
4507
|
S1_sigma_iso = S1_sigma.mean(-1)
|
|
5309
4508
|
for l1 in range(L):
|
|
5310
4509
|
for l2 in range(L):
|
|
5311
|
-
S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
|
|
4510
|
+
S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
|
|
5312
4511
|
if data2 is not None:
|
|
5313
|
-
S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[
|
|
4512
|
+
S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[
|
|
4513
|
+
..., l1, l2
|
|
4514
|
+
]
|
|
5314
4515
|
for l3 in range(L):
|
|
5315
|
-
S4_sigma_iso[
|
|
5316
|
-
|
|
4516
|
+
S4_sigma_iso[
|
|
4517
|
+
..., (l2 - l1) % L, (l3 - l1) % L
|
|
4518
|
+
] += S4_sigma[..., l1, l2, l3]
|
|
4519
|
+
S3_sigma_iso /= L
|
|
4520
|
+
S4_sigma_iso /= L
|
|
5317
4521
|
if data2 is not None:
|
|
5318
4522
|
S3p_sigma_iso /= L
|
|
5319
|
-
|
|
5320
|
-
mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
|
|
5321
|
-
std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
|
|
5322
|
-
|
|
4523
|
+
|
|
4524
|
+
mean_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
|
|
4525
|
+
std_data = self.backend.bk_zeros((N_image, 1), dtype=data.dtype)
|
|
4526
|
+
|
|
5323
4527
|
if data2 is None:
|
|
5324
|
-
mean_data[:,0]=data.mean((-2
|
|
5325
|
-
std_data[:,0]=data.std((-2
|
|
4528
|
+
mean_data[:, 0] = data.mean((-2, -1))
|
|
4529
|
+
std_data[:, 0] = data.std((-2, -1))
|
|
5326
4530
|
else:
|
|
5327
|
-
mean_data[:,0]=(data2*data).mean((-2
|
|
5328
|
-
std_data[:,0]=(data2*data).std((-2
|
|
5329
|
-
|
|
4531
|
+
mean_data[:, 0] = (data2 * data).mean((-2, -1))
|
|
4532
|
+
std_data[:, 0] = (data2 * data).std((-2, -1))
|
|
4533
|
+
|
|
5330
4534
|
if get_variance:
|
|
5331
|
-
ref_sigma={}
|
|
4535
|
+
ref_sigma = {}
|
|
5332
4536
|
if iso_ang:
|
|
5333
|
-
ref_sigma[
|
|
5334
|
-
ref_sigma[
|
|
5335
|
-
ref_sigma[
|
|
5336
|
-
ref_sigma[
|
|
4537
|
+
ref_sigma["std_data"] = std_data
|
|
4538
|
+
ref_sigma["S1_sigma"] = S1_sigma_iso
|
|
4539
|
+
ref_sigma["S2_sigma"] = S2_sigma_iso
|
|
4540
|
+
ref_sigma["S3_sigma"] = S3_sigma_iso
|
|
5337
4541
|
if data2 is not None:
|
|
5338
|
-
ref_sigma[
|
|
5339
|
-
ref_sigma[
|
|
4542
|
+
ref_sigma["S3p_sigma"] = S3p_sigma_iso
|
|
4543
|
+
ref_sigma["S4_sigma"] = S4_sigma_iso
|
|
5340
4544
|
else:
|
|
5341
|
-
ref_sigma[
|
|
5342
|
-
ref_sigma[
|
|
5343
|
-
ref_sigma[
|
|
5344
|
-
ref_sigma[
|
|
4545
|
+
ref_sigma["std_data"] = std_data
|
|
4546
|
+
ref_sigma["S1_sigma"] = S1_sigma
|
|
4547
|
+
ref_sigma["S2_sigma"] = S2_sigma
|
|
4548
|
+
ref_sigma["S3_sigma"] = S3_sigma
|
|
5345
4549
|
if data2 is not None:
|
|
5346
|
-
ref_sigma[
|
|
5347
|
-
ref_sigma[
|
|
5348
|
-
|
|
4550
|
+
ref_sigma["S3p_sigma"] = S3p_sigma
|
|
4551
|
+
ref_sigma["S4_sigma"] = S4_sigma
|
|
4552
|
+
|
|
5349
4553
|
if data2 is None:
|
|
5350
4554
|
if iso_ang:
|
|
5351
4555
|
if ref_sigma is not None:
|
|
5352
|
-
for_synthesis = self.backend.backend.cat(
|
|
5353
|
-
|
|
5354
|
-
|
|
5355
|
-
|
|
5356
|
-
|
|
5357
|
-
|
|
5358
|
-
|
|
5359
|
-
|
|
5360
|
-
|
|
5361
|
-
|
|
4556
|
+
for_synthesis = self.backend.backend.cat(
|
|
4557
|
+
(
|
|
4558
|
+
mean_data / ref_sigma["std_data"],
|
|
4559
|
+
std_data / ref_sigma["std_data"],
|
|
4560
|
+
(S2_iso / ref_sigma["S2_sigma"])
|
|
4561
|
+
.reshape((N_image, -1))
|
|
4562
|
+
.log(),
|
|
4563
|
+
(S1_iso / ref_sigma["S1_sigma"])
|
|
4564
|
+
.reshape((N_image, -1))
|
|
4565
|
+
.log(),
|
|
4566
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4567
|
+
.reshape((N_image, -1))
|
|
4568
|
+
.real,
|
|
4569
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4570
|
+
.reshape((N_image, -1))
|
|
4571
|
+
.imag,
|
|
4572
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4573
|
+
.reshape((N_image, -1))
|
|
4574
|
+
.real,
|
|
4575
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4576
|
+
.reshape((N_image, -1))
|
|
4577
|
+
.imag,
|
|
4578
|
+
),
|
|
4579
|
+
dim=-1,
|
|
4580
|
+
)
|
|
5362
4581
|
else:
|
|
5363
|
-
for_synthesis = self.backend.backend.cat(
|
|
5364
|
-
|
|
5365
|
-
|
|
5366
|
-
|
|
5367
|
-
|
|
5368
|
-
|
|
5369
|
-
|
|
5370
|
-
|
|
5371
|
-
|
|
5372
|
-
|
|
4582
|
+
for_synthesis = self.backend.backend.cat(
|
|
4583
|
+
(
|
|
4584
|
+
mean_data / std_data,
|
|
4585
|
+
std_data,
|
|
4586
|
+
S2_iso.reshape((N_image, -1)).log(),
|
|
4587
|
+
S1_iso.reshape((N_image, -1)).log(),
|
|
4588
|
+
S3_iso.reshape((N_image, -1)).real,
|
|
4589
|
+
S3_iso.reshape((N_image, -1)).imag,
|
|
4590
|
+
S4_iso.reshape((N_image, -1)).real,
|
|
4591
|
+
S4_iso.reshape((N_image, -1)).imag,
|
|
4592
|
+
),
|
|
4593
|
+
dim=-1,
|
|
4594
|
+
)
|
|
5373
4595
|
else:
|
|
5374
4596
|
if ref_sigma is not None:
|
|
5375
|
-
for_synthesis = self.backend.backend.cat(
|
|
5376
|
-
|
|
5377
|
-
|
|
5378
|
-
|
|
5379
|
-
|
|
5380
|
-
|
|
5381
|
-
|
|
5382
|
-
|
|
5383
|
-
|
|
5384
|
-
|
|
4597
|
+
for_synthesis = self.backend.backend.cat(
|
|
4598
|
+
(
|
|
4599
|
+
mean_data / ref_sigma["std_data"],
|
|
4600
|
+
std_data / ref_sigma["std_data"],
|
|
4601
|
+
(S2 / ref_sigma["S2_sigma"])
|
|
4602
|
+
.reshape((N_image, -1))
|
|
4603
|
+
.log(),
|
|
4604
|
+
(S1 / ref_sigma["S1_sigma"])
|
|
4605
|
+
.reshape((N_image, -1))
|
|
4606
|
+
.log(),
|
|
4607
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4608
|
+
.reshape((N_image, -1))
|
|
4609
|
+
.real,
|
|
4610
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4611
|
+
.reshape((N_image, -1))
|
|
4612
|
+
.imag,
|
|
4613
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4614
|
+
.reshape((N_image, -1))
|
|
4615
|
+
.real,
|
|
4616
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4617
|
+
.reshape((N_image, -1))
|
|
4618
|
+
.imag,
|
|
4619
|
+
),
|
|
4620
|
+
dim=-1,
|
|
4621
|
+
)
|
|
5385
4622
|
else:
|
|
5386
|
-
for_synthesis = self.backend.backend.cat(
|
|
5387
|
-
|
|
5388
|
-
|
|
5389
|
-
|
|
5390
|
-
|
|
5391
|
-
|
|
5392
|
-
|
|
5393
|
-
|
|
5394
|
-
|
|
5395
|
-
|
|
4623
|
+
for_synthesis = self.backend.backend.cat(
|
|
4624
|
+
(
|
|
4625
|
+
mean_data / std_data,
|
|
4626
|
+
std_data,
|
|
4627
|
+
S2.reshape((N_image, -1)).log(),
|
|
4628
|
+
S1.reshape((N_image, -1)).log(),
|
|
4629
|
+
S3.reshape((N_image, -1)).real,
|
|
4630
|
+
S3.reshape((N_image, -1)).imag,
|
|
4631
|
+
S4.reshape((N_image, -1)).real,
|
|
4632
|
+
S4.reshape((N_image, -1)).imag,
|
|
4633
|
+
),
|
|
4634
|
+
dim=-1,
|
|
4635
|
+
)
|
|
5396
4636
|
else:
|
|
5397
4637
|
if iso_ang:
|
|
5398
4638
|
if ref_sigma is not None:
|
|
5399
|
-
for_synthesis = self.backend.backend.cat(
|
|
5400
|
-
|
|
5401
|
-
|
|
5402
|
-
|
|
5403
|
-
|
|
5404
|
-
|
|
5405
|
-
|
|
5406
|
-
|
|
5407
|
-
|
|
5408
|
-
|
|
5409
|
-
|
|
5410
|
-
|
|
4639
|
+
for_synthesis = self.backend.backend.cat(
|
|
4640
|
+
(
|
|
4641
|
+
mean_data / ref_sigma["std_data"],
|
|
4642
|
+
std_data / ref_sigma["std_data"],
|
|
4643
|
+
(S2_iso / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
|
|
4644
|
+
(S1_iso / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
|
|
4645
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4646
|
+
.reshape((N_image, -1))
|
|
4647
|
+
.real,
|
|
4648
|
+
(S3_iso / ref_sigma["S3_sigma"])
|
|
4649
|
+
.reshape((N_image, -1))
|
|
4650
|
+
.imag,
|
|
4651
|
+
(S3p_iso / ref_sigma["S3p_sigma"])
|
|
4652
|
+
.reshape((N_image, -1))
|
|
4653
|
+
.real,
|
|
4654
|
+
(S3p_iso / ref_sigma["S3p_sigma"])
|
|
4655
|
+
.reshape((N_image, -1))
|
|
4656
|
+
.imag,
|
|
4657
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4658
|
+
.reshape((N_image, -1))
|
|
4659
|
+
.real,
|
|
4660
|
+
(S4_iso / ref_sigma["S4_sigma"])
|
|
4661
|
+
.reshape((N_image, -1))
|
|
4662
|
+
.imag,
|
|
4663
|
+
),
|
|
4664
|
+
dim=-1,
|
|
4665
|
+
)
|
|
5411
4666
|
else:
|
|
5412
|
-
for_synthesis = self.backend.backend.cat(
|
|
5413
|
-
|
|
5414
|
-
|
|
5415
|
-
|
|
5416
|
-
|
|
5417
|
-
|
|
5418
|
-
|
|
5419
|
-
|
|
5420
|
-
|
|
5421
|
-
|
|
5422
|
-
|
|
5423
|
-
|
|
4667
|
+
for_synthesis = self.backend.backend.cat(
|
|
4668
|
+
(
|
|
4669
|
+
mean_data / std_data,
|
|
4670
|
+
std_data,
|
|
4671
|
+
S2_iso.reshape((N_image, -1)),
|
|
4672
|
+
S1_iso.reshape((N_image, -1)),
|
|
4673
|
+
S3_iso.reshape((N_image, -1)).real,
|
|
4674
|
+
S3_iso.reshape((N_image, -1)).imag,
|
|
4675
|
+
S3p_iso.reshape((N_image, -1)).real,
|
|
4676
|
+
S3p_iso.reshape((N_image, -1)).imag,
|
|
4677
|
+
S4_iso.reshape((N_image, -1)).real,
|
|
4678
|
+
S4_iso.reshape((N_image, -1)).imag,
|
|
4679
|
+
),
|
|
4680
|
+
dim=-1,
|
|
4681
|
+
)
|
|
5424
4682
|
else:
|
|
5425
4683
|
if ref_sigma is not None:
|
|
5426
|
-
for_synthesis = self.backend.backend.cat(
|
|
5427
|
-
|
|
5428
|
-
|
|
5429
|
-
|
|
5430
|
-
|
|
5431
|
-
|
|
5432
|
-
|
|
5433
|
-
|
|
5434
|
-
|
|
5435
|
-
|
|
5436
|
-
|
|
5437
|
-
|
|
4684
|
+
for_synthesis = self.backend.backend.cat(
|
|
4685
|
+
(
|
|
4686
|
+
mean_data / ref_sigma["std_data"],
|
|
4687
|
+
std_data / ref_sigma["std_data"],
|
|
4688
|
+
(S2 / ref_sigma["S2_sigma"]).reshape((N_image, -1)),
|
|
4689
|
+
(S1 / ref_sigma["S1_sigma"]).reshape((N_image, -1)),
|
|
4690
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4691
|
+
.reshape((N_image, -1))
|
|
4692
|
+
.real,
|
|
4693
|
+
(S3 / ref_sigma["S3_sigma"])
|
|
4694
|
+
.reshape((N_image, -1))
|
|
4695
|
+
.imag,
|
|
4696
|
+
(S3p / ref_sigma["S3p_sigma"])
|
|
4697
|
+
.reshape((N_image, -1))
|
|
4698
|
+
.real,
|
|
4699
|
+
(S3p / ref_sigma["S3p_sigma"])
|
|
4700
|
+
.reshape((N_image, -1))
|
|
4701
|
+
.imag,
|
|
4702
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4703
|
+
.reshape((N_image, -1))
|
|
4704
|
+
.real,
|
|
4705
|
+
(S4 / ref_sigma["S4_sigma"])
|
|
4706
|
+
.reshape((N_image, -1))
|
|
4707
|
+
.imag,
|
|
4708
|
+
),
|
|
4709
|
+
dim=-1,
|
|
4710
|
+
)
|
|
5438
4711
|
else:
|
|
5439
|
-
for_synthesis = self.backend.backend.cat(
|
|
5440
|
-
|
|
5441
|
-
|
|
5442
|
-
|
|
5443
|
-
|
|
5444
|
-
|
|
5445
|
-
|
|
5446
|
-
|
|
5447
|
-
|
|
5448
|
-
|
|
5449
|
-
|
|
5450
|
-
|
|
5451
|
-
|
|
5452
|
-
|
|
5453
|
-
|
|
5454
|
-
|
|
4712
|
+
for_synthesis = self.backend.backend.cat(
|
|
4713
|
+
(
|
|
4714
|
+
mean_data / std_data,
|
|
4715
|
+
std_data,
|
|
4716
|
+
S2.reshape((N_image, -1)),
|
|
4717
|
+
S1.reshape((N_image, -1)),
|
|
4718
|
+
S3.reshape((N_image, -1)).real,
|
|
4719
|
+
S3.reshape((N_image, -1)).imag,
|
|
4720
|
+
S3p.reshape((N_image, -1)).real,
|
|
4721
|
+
S3p.reshape((N_image, -1)).imag,
|
|
4722
|
+
S4.reshape((N_image, -1)).real,
|
|
4723
|
+
S4.reshape((N_image, -1)).imag,
|
|
4724
|
+
),
|
|
4725
|
+
dim=-1,
|
|
4726
|
+
)
|
|
4727
|
+
|
|
4728
|
+
if not use_ref:
|
|
4729
|
+
self.ref_scattering_cov_S2 = S2
|
|
4730
|
+
|
|
5455
4731
|
if get_variance:
|
|
5456
|
-
return for_synthesis,ref_sigma
|
|
5457
|
-
|
|
4732
|
+
return for_synthesis, ref_sigma
|
|
4733
|
+
|
|
5458
4734
|
return for_synthesis
|
|
5459
|
-
|
|
5460
|
-
if (M,N,J,L) not in self.filters_set:
|
|
5461
|
-
self.filters_set[(M,N,J,L)] = self.computer_filter(
|
|
5462
|
-
|
|
5463
|
-
|
|
5464
|
-
|
|
5465
|
-
|
|
4735
|
+
|
|
4736
|
+
if (M, N, J, L) not in self.filters_set:
|
|
4737
|
+
self.filters_set[(M, N, J, L)] = self.computer_filter(
|
|
4738
|
+
M, N, J, L
|
|
4739
|
+
) # self.computer_filter(M,N,J,L)
|
|
4740
|
+
|
|
4741
|
+
filters_set = self.filters_set[(M, N, J, L)]
|
|
4742
|
+
|
|
4743
|
+
# weight = self.weight
|
|
5466
4744
|
if use_ref:
|
|
5467
|
-
if normalization==
|
|
4745
|
+
if normalization == "S2":
|
|
5468
4746
|
ref_S2 = self.ref_scattering_cov_S2
|
|
5469
|
-
else:
|
|
5470
|
-
ref_P11 = self.ref_scattering_cov[
|
|
4747
|
+
else:
|
|
4748
|
+
ref_P11 = self.ref_scattering_cov["P11"]
|
|
5471
4749
|
|
|
5472
4750
|
# convert numpy array input into self.backend.bk_ tensors
|
|
5473
4751
|
data = self.backend.bk_cast(data)
|
|
5474
|
-
data_f = self.backend.bk_fftn(data, dim=(-2
|
|
4752
|
+
data_f = self.backend.bk_fftn(data, dim=(-2, -1))
|
|
5475
4753
|
if data2 is not None:
|
|
5476
4754
|
data2 = self.backend.bk_cast(data2)
|
|
5477
|
-
data2_f = self.backend.bk_fftn(data2, dim=(-2
|
|
5478
|
-
|
|
4755
|
+
data2_f = self.backend.bk_fftn(data2, dim=(-2, -1))
|
|
4756
|
+
|
|
5479
4757
|
# initialize tensors for scattering coefficients
|
|
5480
|
-
|
|
5481
|
-
Ndata_S3 = J*(J+1)//2
|
|
5482
|
-
Ndata_S4 = J*(J+1)*(J+2)//6
|
|
5483
|
-
J_S4={}
|
|
5484
|
-
|
|
4758
|
+
|
|
4759
|
+
Ndata_S3 = J * (J + 1) // 2
|
|
4760
|
+
Ndata_S4 = J * (J + 1) * (J + 2) // 6
|
|
4761
|
+
J_S4 = {}
|
|
4762
|
+
|
|
5485
4763
|
S3 = []
|
|
5486
4764
|
if data2 is not None:
|
|
5487
4765
|
S3p = []
|
|
5488
|
-
S4_pre_norm = []
|
|
5489
|
-
S4 = []
|
|
5490
|
-
|
|
4766
|
+
S4_pre_norm = []
|
|
4767
|
+
S4 = []
|
|
4768
|
+
|
|
5491
4769
|
# variance
|
|
5492
4770
|
if get_variance:
|
|
5493
|
-
S3_sigma = []
|
|
4771
|
+
S3_sigma = []
|
|
5494
4772
|
if data2 is not None:
|
|
5495
4773
|
S3p_sigma = []
|
|
5496
|
-
S4_sigma = []
|
|
5497
|
-
|
|
4774
|
+
S4_sigma = []
|
|
4775
|
+
|
|
5498
4776
|
if iso_ang:
|
|
5499
4777
|
S3_iso = []
|
|
5500
4778
|
if data2 is not None:
|
|
5501
4779
|
S3p_iso = []
|
|
5502
|
-
|
|
4780
|
+
|
|
5503
4781
|
S4_iso = []
|
|
5504
4782
|
if get_variance:
|
|
5505
4783
|
S3_sigma_iso = []
|
|
5506
4784
|
if data2 is not None:
|
|
5507
|
-
S3p_sigma_iso = []
|
|
5508
|
-
S4_sigma_iso = []
|
|
5509
|
-
|
|
4785
|
+
S3p_sigma_iso = []
|
|
4786
|
+
S4_sigma_iso = []
|
|
4787
|
+
|
|
5510
4788
|
#
|
|
5511
|
-
if edge:
|
|
5512
|
-
if (M,N,J) not in self.edge_masks:
|
|
5513
|
-
self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
|
|
5514
|
-
edge_mask = self.edge_masks[(M,N,J)]
|
|
5515
|
-
else:
|
|
4789
|
+
if edge:
|
|
4790
|
+
if (M, N, J) not in self.edge_masks:
|
|
4791
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
|
|
4792
|
+
edge_mask = self.edge_masks[(M, N, J)]
|
|
4793
|
+
else:
|
|
5516
4794
|
edge_mask = 1
|
|
5517
|
-
|
|
4795
|
+
|
|
5518
4796
|
# calculate scattering fields
|
|
5519
4797
|
if data2 is None:
|
|
5520
4798
|
if self.use_2D:
|
|
5521
4799
|
if len(data.shape) == 2:
|
|
5522
|
-
I1 = self.backend.bk_abs(
|
|
5523
|
-
|
|
5524
|
-
|
|
4800
|
+
I1 = self.backend.bk_abs(
|
|
4801
|
+
self.backend.bk_ifftn(
|
|
4802
|
+
data_f[None, None, None, :, :]
|
|
4803
|
+
* filters_set[None, :J, :, :, :],
|
|
4804
|
+
dim=(-2, -1),
|
|
4805
|
+
)
|
|
4806
|
+
)
|
|
5525
4807
|
else:
|
|
5526
|
-
I1 = self.backend.bk_abs(
|
|
5527
|
-
|
|
5528
|
-
|
|
4808
|
+
I1 = self.backend.bk_abs(
|
|
4809
|
+
self.backend.bk_ifftn(
|
|
4810
|
+
data_f[:, None, None, :, :]
|
|
4811
|
+
* filters_set[None, :J, :, :, :],
|
|
4812
|
+
dim=(-2, -1),
|
|
4813
|
+
)
|
|
4814
|
+
)
|
|
5529
4815
|
elif self.use_1D:
|
|
5530
4816
|
if len(data.shape) == 1:
|
|
5531
|
-
I1 = self.backend.bk_abs(
|
|
5532
|
-
|
|
5533
|
-
|
|
4817
|
+
I1 = self.backend.bk_abs(
|
|
4818
|
+
self.backend.bk_ifftn(
|
|
4819
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4820
|
+
dim=(-1),
|
|
4821
|
+
)
|
|
4822
|
+
)
|
|
5534
4823
|
else:
|
|
5535
|
-
I1 = self.backend.bk_abs(
|
|
5536
|
-
|
|
5537
|
-
|
|
4824
|
+
I1 = self.backend.bk_abs(
|
|
4825
|
+
self.backend.bk_ifftn(
|
|
4826
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4827
|
+
dim=(-1),
|
|
4828
|
+
)
|
|
4829
|
+
)
|
|
5538
4830
|
else:
|
|
5539
|
-
print(
|
|
5540
|
-
|
|
5541
|
-
S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask),axis=(-2
|
|
5542
|
-
S1 = self.backend.bk_reduce_mean(I1 * edge_mask,axis=(-2
|
|
4831
|
+
print("todo")
|
|
4832
|
+
|
|
4833
|
+
S2 = self.backend.bk_reduce_mean((I1**2 * edge_mask), axis=(-2, -1))
|
|
4834
|
+
S1 = self.backend.bk_reduce_mean(I1 * edge_mask, axis=(-2, -1))
|
|
5543
4835
|
|
|
5544
4836
|
if get_variance:
|
|
5545
|
-
S2_sigma = self.backend.bk_reduce_std(
|
|
5546
|
-
|
|
5547
|
-
|
|
5548
|
-
|
|
5549
|
-
|
|
5550
|
-
|
|
4837
|
+
S2_sigma = self.backend.bk_reduce_std(
|
|
4838
|
+
(I1**2 * edge_mask), axis=(-2, -1)
|
|
4839
|
+
)
|
|
4840
|
+
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
|
|
4841
|
+
|
|
4842
|
+
I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
|
|
4843
|
+
|
|
4844
|
+
else:
|
|
5551
4845
|
if self.use_2D:
|
|
5552
4846
|
if len(data.shape) == 2:
|
|
5553
4847
|
I1 = self.backend.bk_ifftn(
|
|
5554
|
-
data_f[None,None,None
|
|
4848
|
+
data_f[None, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
4849
|
+
dim=(-2, -1),
|
|
5555
4850
|
)
|
|
5556
4851
|
I2 = self.backend.bk_ifftn(
|
|
5557
|
-
data2_f[None,None,None
|
|
4852
|
+
data2_f[None, None, None, :, :]
|
|
4853
|
+
* filters_set[None, :J, :, :, :],
|
|
4854
|
+
dim=(-2, -1),
|
|
5558
4855
|
)
|
|
5559
4856
|
else:
|
|
5560
4857
|
I1 = self.backend.bk_ifftn(
|
|
5561
|
-
data_f[:,None,None
|
|
4858
|
+
data_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
4859
|
+
dim=(-2, -1),
|
|
5562
4860
|
)
|
|
5563
4861
|
I2 = self.backend.bk_ifftn(
|
|
5564
|
-
data2_f[:,None,None
|
|
4862
|
+
data2_f[:, None, None, :, :] * filters_set[None, :J, :, :, :],
|
|
4863
|
+
dim=(-2, -1),
|
|
5565
4864
|
)
|
|
5566
4865
|
elif self.use_1D:
|
|
5567
4866
|
if len(data.shape) == 1:
|
|
5568
4867
|
I1 = self.backend.bk_ifftn(
|
|
5569
|
-
data_f[None,None,None
|
|
4868
|
+
data_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4869
|
+
dim=(-1),
|
|
5570
4870
|
)
|
|
5571
4871
|
I2 = self.backend.bk_ifftn(
|
|
5572
|
-
data2_f[None,None,None
|
|
4872
|
+
data2_f[None, None, None, :] * filters_set[None, :J, :, :],
|
|
4873
|
+
dim=(-1),
|
|
5573
4874
|
)
|
|
5574
4875
|
else:
|
|
5575
4876
|
I1 = self.backend.bk_ifftn(
|
|
5576
|
-
data_f[:,None,None
|
|
4877
|
+
data_f[:, None, None, :] * filters_set[None, :J, :, :], dim=(-1)
|
|
5577
4878
|
)
|
|
5578
4879
|
I2 = self.backend.bk_ifftn(
|
|
5579
|
-
data2_f[:,None,None
|
|
4880
|
+
data2_f[:, None, None, :] * filters_set[None, :J, :, :],
|
|
4881
|
+
dim=(-1),
|
|
5580
4882
|
)
|
|
5581
4883
|
else:
|
|
5582
|
-
print(
|
|
5583
|
-
|
|
5584
|
-
I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
|
|
5585
|
-
|
|
5586
|
-
S2 = self.backend.bk_reduce_mean((I1 * edge_mask),axis=(-2
|
|
4884
|
+
print("todo")
|
|
4885
|
+
|
|
4886
|
+
I1 = self.backend.bk_real(I1 * self.backend.bk_conjugate(I2))
|
|
4887
|
+
|
|
4888
|
+
S2 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5587
4889
|
if get_variance:
|
|
5588
|
-
S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask),axis=(-2
|
|
5589
|
-
|
|
5590
|
-
I1=self.backend.bk_L1(I1)
|
|
5591
|
-
|
|
5592
|
-
S1
|
|
4890
|
+
S2_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
|
|
4891
|
+
|
|
4892
|
+
I1 = self.backend.bk_L1(I1)
|
|
4893
|
+
|
|
4894
|
+
S1 = self.backend.bk_reduce_mean((I1 * edge_mask), axis=(-2, -1))
|
|
5593
4895
|
|
|
5594
4896
|
if get_variance:
|
|
5595
|
-
S1_sigma
|
|
5596
|
-
|
|
5597
|
-
I1_f= self.backend.bk_fftn(I1, dim=(-2
|
|
5598
|
-
|
|
4897
|
+
S1_sigma = self.backend.bk_reduce_std((I1 * edge_mask), axis=(-2, -1))
|
|
4898
|
+
|
|
4899
|
+
I1_f = self.backend.bk_fftn(I1, dim=(-2, -1))
|
|
4900
|
+
|
|
5599
4901
|
if pseudo_coef != 1:
|
|
5600
4902
|
I1 = I1**pseudo_coef
|
|
5601
|
-
|
|
5602
|
-
Ndata_S3=0
|
|
5603
|
-
Ndata_S4=0
|
|
5604
|
-
|
|
4903
|
+
|
|
4904
|
+
Ndata_S3 = 0
|
|
4905
|
+
Ndata_S4 = 0
|
|
4906
|
+
|
|
5605
4907
|
# calculate the covariance and correlations of the scattering fields
|
|
5606
4908
|
# only use the low-k Fourier coefs when calculating large-j scattering coefs.
|
|
5607
|
-
for j3 in range(0,J):
|
|
5608
|
-
J_S4[j3]=Ndata_S4
|
|
5609
|
-
|
|
5610
|
-
dx3, dy3 = self.get_dxdy(j3,M,N)
|
|
5611
|
-
I1_f_small = self.cut_high_k_off(
|
|
4909
|
+
for j3 in range(0, J):
|
|
4910
|
+
J_S4[j3] = Ndata_S4
|
|
4911
|
+
|
|
4912
|
+
dx3, dy3 = self.get_dxdy(j3, M, N)
|
|
4913
|
+
I1_f_small = self.cut_high_k_off(
|
|
4914
|
+
I1_f[:, : j3 + 1], dx3, dy3
|
|
4915
|
+
) # Nimage, J, L, x, y
|
|
5612
4916
|
data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
|
|
5613
4917
|
if data2 is not None:
|
|
5614
4918
|
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
|
|
5615
4919
|
if edge:
|
|
5616
|
-
I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2
|
|
5617
|
-
data_small = self.backend.bk_ifftn(
|
|
4920
|
+
I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2, -1), norm="ortho")
|
|
4921
|
+
data_small = self.backend.bk_ifftn(
|
|
4922
|
+
data_f_small, dim=(-2, -1), norm="ortho"
|
|
4923
|
+
)
|
|
5618
4924
|
if data2 is not None:
|
|
5619
|
-
data2_small = self.backend.bk_ifftn(
|
|
5620
|
-
|
|
4925
|
+
data2_small = self.backend.bk_ifftn(
|
|
4926
|
+
data2_f_small, dim=(-2, -1), norm="ortho"
|
|
4927
|
+
)
|
|
4928
|
+
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
5621
4929
|
_, M3, N3 = wavelet_f3.shape
|
|
5622
4930
|
wavelet_f3_squared = wavelet_f3**2
|
|
5623
|
-
edge_dx = min(4, int(2**j3*dx3*2/M))
|
|
5624
|
-
edge_dy = min(4, int(2**j3*dy3*2/N))
|
|
4931
|
+
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
|
|
4932
|
+
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
|
|
5625
4933
|
# a normalization change due to the cutoff of frequency space
|
|
5626
|
-
|
|
5627
|
-
|
|
5628
|
-
|
|
5629
|
-
|
|
5630
|
-
|
|
5631
|
-
|
|
4934
|
+
if self.all_bk_type == "float32":
|
|
4935
|
+
fft_factor = np.complex64(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
|
|
4936
|
+
else:
|
|
4937
|
+
fft_factor = np.complex128(1 / (M3 * N3) * (M3 * N3 / M / N) ** 2)
|
|
4938
|
+
for j2 in range(0, j3 + 1):
|
|
4939
|
+
# I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
|
|
4940
|
+
# I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
|
|
4941
|
+
I1_f2_wf3_small = self.backend.bk_reshape(
|
|
4942
|
+
I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
|
|
4943
|
+
) * self.backend.bk_reshape(wavelet_f3, [1, 1, 1, L, M3, N3])
|
|
4944
|
+
I1_f2_wf3_2_small = self.backend.bk_reshape(
|
|
4945
|
+
I1_f_small[:, j2], [N_image, 1, L, 1, M3, N3]
|
|
4946
|
+
) * self.backend.bk_reshape(wavelet_f3_squared, [1, 1, 1, L, M3, N3])
|
|
5632
4947
|
if edge:
|
|
5633
|
-
I12_w3_small = self.backend.bk_ifftn(
|
|
5634
|
-
|
|
4948
|
+
I12_w3_small = self.backend.bk_ifftn(
|
|
4949
|
+
I1_f2_wf3_small, dim=(-2, -1), norm="ortho"
|
|
4950
|
+
)
|
|
4951
|
+
I12_w3_2_small = self.backend.bk_ifftn(
|
|
4952
|
+
I1_f2_wf3_2_small, dim=(-2, -1), norm="ortho"
|
|
4953
|
+
)
|
|
5635
4954
|
if use_ref:
|
|
5636
|
-
if normalization==
|
|
5637
|
-
norm_factor_S3 = (
|
|
5638
|
-
|
|
5639
|
-
|
|
4955
|
+
if normalization == "P11":
|
|
4956
|
+
norm_factor_S3 = (
|
|
4957
|
+
ref_S2[:, None, j3, :]
|
|
4958
|
+
* ref_P11[:, j2, j3, :, :] ** pseudo_coef
|
|
4959
|
+
) ** 0.5
|
|
4960
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4961
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4962
|
+
)
|
|
4963
|
+
elif normalization == "S2":
|
|
4964
|
+
norm_factor_S3 = (
|
|
4965
|
+
ref_S2[:, None, j3, :]
|
|
4966
|
+
* ref_S2[:, j2, :, None] ** pseudo_coef
|
|
4967
|
+
) ** 0.5
|
|
4968
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4969
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4970
|
+
)
|
|
5640
4971
|
else:
|
|
5641
|
-
norm_factor_S3 =
|
|
4972
|
+
norm_factor_S3 = C_ONE
|
|
5642
4973
|
else:
|
|
5643
|
-
if normalization==
|
|
4974
|
+
if normalization == "P11":
|
|
5644
4975
|
# [N_image,l2,l3,x,y]
|
|
5645
|
-
P11_temp =
|
|
5646
|
-
|
|
5647
|
-
|
|
5648
|
-
|
|
4976
|
+
P11_temp = (
|
|
4977
|
+
self.backend.bk_reduce_mean(
|
|
4978
|
+
(I1_f2_wf3_small.abs() ** 2), axis=(-2, -1)
|
|
4979
|
+
)
|
|
4980
|
+
* fft_factor
|
|
4981
|
+
)
|
|
4982
|
+
norm_factor_S3 = (
|
|
4983
|
+
S2[:, None, j3, :] * P11_temp**pseudo_coef
|
|
4984
|
+
) ** 0.5
|
|
4985
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4986
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4987
|
+
)
|
|
4988
|
+
elif normalization == "S2":
|
|
4989
|
+
norm_factor_S3 = (
|
|
4990
|
+
S2[:, None, j3, None, :]
|
|
4991
|
+
* S2[:, None, j2, :, None] ** pseudo_coef
|
|
4992
|
+
) ** 0.5
|
|
4993
|
+
norm_factor_S3 = self.backend.bk_complex(
|
|
4994
|
+
norm_factor_S3, 0 * norm_factor_S3
|
|
4995
|
+
)
|
|
5649
4996
|
else:
|
|
5650
|
-
norm_factor_S3 =
|
|
5651
|
-
|
|
5652
|
-
norm_factor_S3 = self.backend.bk_complex(norm_factor_S3,0*norm_factor_S3)
|
|
4997
|
+
norm_factor_S3 = C_ONE
|
|
5653
4998
|
|
|
5654
4999
|
if not edge:
|
|
5655
|
-
S3.append(
|
|
5656
|
-
self.backend.
|
|
5657
|
-
|
|
5000
|
+
S3.append(
|
|
5001
|
+
self.backend.bk_reduce_mean(
|
|
5002
|
+
self.backend.bk_reshape(
|
|
5003
|
+
data_f_small, [N_image, 1, 1, 1, M3, N3]
|
|
5004
|
+
)
|
|
5005
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small),
|
|
5006
|
+
axis=(-2, -1),
|
|
5007
|
+
)
|
|
5008
|
+
* fft_factor
|
|
5009
|
+
/ norm_factor_S3
|
|
5010
|
+
)
|
|
5658
5011
|
if get_variance:
|
|
5659
|
-
S3_sigma.append(
|
|
5660
|
-
self.backend.
|
|
5661
|
-
|
|
5012
|
+
S3_sigma.append(
|
|
5013
|
+
self.backend.bk_reduce_std(
|
|
5014
|
+
self.backend.bk_reshape(
|
|
5015
|
+
data_f_small, [N_image, 1, 1, 1, M3, N3]
|
|
5016
|
+
)
|
|
5017
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small),
|
|
5018
|
+
axis=(-2, -1),
|
|
5019
|
+
)
|
|
5020
|
+
* fft_factor
|
|
5021
|
+
/ norm_factor_S3
|
|
5022
|
+
)
|
|
5662
5023
|
else:
|
|
5663
|
-
S3.append(
|
|
5664
|
-
|
|
5665
|
-
|
|
5024
|
+
S3.append(
|
|
5025
|
+
self.backend.bk_reduce_mean(
|
|
5026
|
+
(
|
|
5027
|
+
self.backend.bk_reshape(
|
|
5028
|
+
data_small, [N_image, 1, 1, 1, M3, N3]
|
|
5029
|
+
)
|
|
5030
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5031
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5032
|
+
axis=(-2, -1),
|
|
5033
|
+
)
|
|
5034
|
+
* fft_factor
|
|
5035
|
+
/ norm_factor_S3
|
|
5036
|
+
)
|
|
5666
5037
|
if get_variance:
|
|
5667
|
-
S3_sigma.apend(
|
|
5668
|
-
|
|
5669
|
-
|
|
5038
|
+
S3_sigma.apend(
|
|
5039
|
+
self.backend.bk_reduce_std(
|
|
5040
|
+
(
|
|
5041
|
+
self.backend.bk_reshape(
|
|
5042
|
+
data_small, [N_image, 1, 1, 1, M3, N3]
|
|
5043
|
+
)
|
|
5044
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5045
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5046
|
+
axis=(-2, -1),
|
|
5047
|
+
)
|
|
5048
|
+
* fft_factor
|
|
5049
|
+
/ norm_factor_S3
|
|
5050
|
+
)
|
|
5670
5051
|
if data2 is not None:
|
|
5671
5052
|
if not edge:
|
|
5672
|
-
S3p.append(
|
|
5673
|
-
|
|
5674
|
-
|
|
5675
|
-
|
|
5053
|
+
S3p.append(
|
|
5054
|
+
self.backend.bk_reduce_mean(
|
|
5055
|
+
(
|
|
5056
|
+
self.backend.bk_reshape(
|
|
5057
|
+
data2_f_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5058
|
+
)
|
|
5059
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5060
|
+
),
|
|
5061
|
+
axis=(-2, -1),
|
|
5062
|
+
)
|
|
5063
|
+
* fft_factor
|
|
5064
|
+
/ norm_factor_S3
|
|
5065
|
+
)
|
|
5066
|
+
|
|
5676
5067
|
if get_variance:
|
|
5677
|
-
S3p_sigma.append(
|
|
5678
|
-
|
|
5679
|
-
|
|
5068
|
+
S3p_sigma.append(
|
|
5069
|
+
self.backend.bk_reduce_std(
|
|
5070
|
+
(
|
|
5071
|
+
self.backend.bk_reshape(
|
|
5072
|
+
data2_f_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5073
|
+
)
|
|
5074
|
+
* self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5075
|
+
),
|
|
5076
|
+
axis=(-2, -1),
|
|
5077
|
+
)
|
|
5078
|
+
* fft_factor
|
|
5079
|
+
/ norm_factor_S3
|
|
5080
|
+
)
|
|
5680
5081
|
else:
|
|
5681
|
-
|
|
5682
|
-
S3p.append(
|
|
5683
|
-
|
|
5684
|
-
|
|
5082
|
+
|
|
5083
|
+
S3p.append(
|
|
5084
|
+
self.backend.bk_reduce_mean(
|
|
5085
|
+
(
|
|
5086
|
+
self.backend.bk_reshape(
|
|
5087
|
+
data2_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5088
|
+
)
|
|
5089
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5090
|
+
)[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy],
|
|
5091
|
+
axis=(-2, -1),
|
|
5092
|
+
)
|
|
5093
|
+
* fft_factor
|
|
5094
|
+
/ norm_factor_S3
|
|
5095
|
+
)
|
|
5685
5096
|
if get_variance:
|
|
5686
|
-
S3p_sigma.append(
|
|
5687
|
-
|
|
5688
|
-
|
|
5689
|
-
|
|
5097
|
+
S3p_sigma.append(
|
|
5098
|
+
self.backend.bk_reduce_std(
|
|
5099
|
+
(
|
|
5100
|
+
self.backend.bk_reshape(
|
|
5101
|
+
data2_small, [N_image2, 1, 1, 1, M3, N3]
|
|
5102
|
+
)
|
|
5103
|
+
* self.backend.bk_conjugate(I12_w3_small)
|
|
5104
|
+
)[
|
|
5105
|
+
...,
|
|
5106
|
+
edge_dx : M3 - edge_dx,
|
|
5107
|
+
edge_dy : N3 - edge_dy,
|
|
5108
|
+
],
|
|
5109
|
+
axis=(-2, -1),
|
|
5110
|
+
)
|
|
5111
|
+
* fft_factor
|
|
5112
|
+
/ norm_factor_S3
|
|
5113
|
+
)
|
|
5114
|
+
|
|
5690
5115
|
if j2 <= j3:
|
|
5691
|
-
if normalization==
|
|
5692
|
-
if use_ref:
|
|
5693
|
-
P = 1/
|
|
5694
|
-
|
|
5695
|
-
|
|
5696
|
-
|
|
5116
|
+
if normalization == "S2":
|
|
5117
|
+
if use_ref:
|
|
5118
|
+
P = 1 / (
|
|
5119
|
+
(
|
|
5120
|
+
ref_S2[:, j3 : j3 + 1, :, None, None]
|
|
5121
|
+
* ref_S2[:, j2 : j2 + 1, None, :, None]
|
|
5122
|
+
)
|
|
5123
|
+
** (0.5 * pseudo_coef)
|
|
5124
|
+
)
|
|
5125
|
+
else:
|
|
5126
|
+
P = 1 / (
|
|
5127
|
+
(
|
|
5128
|
+
S2[:, j3 : j3 + 1, :, None, None]
|
|
5129
|
+
* S2[:, j2 : j2 + 1, None, :, None]
|
|
5130
|
+
)
|
|
5131
|
+
** (0.5 * pseudo_coef)
|
|
5132
|
+
)
|
|
5133
|
+
P = self.backend.bk_complex(P, 0.0 * P)
|
|
5697
5134
|
else:
|
|
5698
|
-
P=
|
|
5699
|
-
|
|
5700
|
-
for j1 in range(0, j2+1):
|
|
5701
|
-
|
|
5702
|
-
|
|
5703
|
-
|
|
5704
|
-
|
|
5705
|
-
|
|
5706
|
-
|
|
5707
|
-
|
|
5708
|
-
|
|
5709
|
-
|
|
5710
|
-
|
|
5711
|
-
self.backend.bk_conjugate(
|
|
5712
|
-
|
|
5713
|
-
|
|
5714
|
-
|
|
5715
|
-
|
|
5716
|
-
|
|
5717
|
-
|
|
5718
|
-
|
|
5719
|
-
|
|
5720
|
-
|
|
5721
|
-
|
|
5722
|
-
|
|
5723
|
-
|
|
5724
|
-
|
|
5135
|
+
P = C_ONE
|
|
5136
|
+
|
|
5137
|
+
for j1 in range(0, j2 + 1):
|
|
5138
|
+
if not edge:
|
|
5139
|
+
if not if_large_batch:
|
|
5140
|
+
# [N_image,l1,l2,l3,x,y]
|
|
5141
|
+
S4.append(
|
|
5142
|
+
self.backend.bk_reduce_mean(
|
|
5143
|
+
(
|
|
5144
|
+
self.backend.bk_reshape(
|
|
5145
|
+
I1_f_small[:, j1],
|
|
5146
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5147
|
+
)
|
|
5148
|
+
* self.backend.bk_conjugate(
|
|
5149
|
+
self.backend.bk_reshape(
|
|
5150
|
+
I1_f2_wf3_2_small,
|
|
5151
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5152
|
+
)
|
|
5153
|
+
)
|
|
5154
|
+
),
|
|
5155
|
+
axis=(-2, -1),
|
|
5156
|
+
)
|
|
5157
|
+
* fft_factor
|
|
5158
|
+
* P
|
|
5159
|
+
)
|
|
5160
|
+
if get_variance:
|
|
5161
|
+
S4_sigma.append(
|
|
5162
|
+
self.backend.bk_reduce_std(
|
|
5163
|
+
(
|
|
5164
|
+
self.backend.bk_reshape(
|
|
5165
|
+
I1_f_small[:, j1],
|
|
5166
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5167
|
+
)
|
|
5168
|
+
* self.backend.bk_conjugate(
|
|
5169
|
+
self.backend.bk_reshape(
|
|
5170
|
+
I1_f2_wf3_2_small,
|
|
5171
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5172
|
+
)
|
|
5173
|
+
)
|
|
5174
|
+
),
|
|
5175
|
+
axis=(-2, -1),
|
|
5176
|
+
)
|
|
5177
|
+
* fft_factor
|
|
5178
|
+
* P
|
|
5179
|
+
)
|
|
5725
5180
|
else:
|
|
5726
|
-
|
|
5727
|
-
# [N_image,
|
|
5728
|
-
S4.append(
|
|
5729
|
-
|
|
5730
|
-
|
|
5181
|
+
for l1 in range(L):
|
|
5182
|
+
# [N_image,l2,l3,x,y]
|
|
5183
|
+
S4.append(
|
|
5184
|
+
self.backend.bk_reduce_mean(
|
|
5185
|
+
(
|
|
5186
|
+
self.backend.bk_reshape(
|
|
5187
|
+
I1_f_small[:, j1, l1],
|
|
5188
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5189
|
+
)
|
|
5190
|
+
* self.backend.bk_conjugate(
|
|
5191
|
+
self.backend.bk_reshape(
|
|
5192
|
+
I1_f2_wf3_2_small,
|
|
5193
|
+
[N_image, 1, L, L, M3, N3],
|
|
5194
|
+
)
|
|
5195
|
+
)
|
|
5196
|
+
),
|
|
5197
|
+
axis=(-2, -1),
|
|
5731
5198
|
)
|
|
5732
|
-
|
|
5199
|
+
* fft_factor
|
|
5200
|
+
* P
|
|
5201
|
+
)
|
|
5733
5202
|
if get_variance:
|
|
5734
|
-
S4_sigma.append(
|
|
5735
|
-
|
|
5736
|
-
|
|
5203
|
+
S4_sigma.append(
|
|
5204
|
+
self.backend.bk_reduce_std(
|
|
5205
|
+
(
|
|
5206
|
+
self.backend.bk_reshape(
|
|
5207
|
+
I1_f_small[:, j1, l1],
|
|
5208
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5209
|
+
)
|
|
5210
|
+
* self.backend.bk_conjugate(
|
|
5211
|
+
self.backend.bk_reshape(
|
|
5212
|
+
I1_f2_wf3_2_small,
|
|
5213
|
+
[N_image, 1, L, L, M3, N3],
|
|
5214
|
+
)
|
|
5215
|
+
)
|
|
5216
|
+
),
|
|
5217
|
+
axis=(-2, -1),
|
|
5737
5218
|
)
|
|
5738
|
-
|
|
5739
|
-
|
|
5740
|
-
|
|
5741
|
-
|
|
5742
|
-
|
|
5743
|
-
|
|
5744
|
-
|
|
5219
|
+
* fft_factor
|
|
5220
|
+
* P
|
|
5221
|
+
)
|
|
5222
|
+
else:
|
|
5223
|
+
if not if_large_batch:
|
|
5224
|
+
# [N_image,l1,l2,l3,x,y]
|
|
5225
|
+
S4.append(
|
|
5226
|
+
self.backend.bk_reduce_mean(
|
|
5227
|
+
(
|
|
5228
|
+
self.backend.bk_reshape(
|
|
5229
|
+
I1_small[:, j1],
|
|
5230
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5745
5231
|
)
|
|
5746
|
-
|
|
5747
|
-
|
|
5748
|
-
|
|
5749
|
-
|
|
5750
|
-
self.backend.bk_reshape(I12_w3_2_small,[N_image,1,L,L,M3,N3])
|
|
5232
|
+
* self.backend.bk_conjugate(
|
|
5233
|
+
self.backend.bk_reshape(
|
|
5234
|
+
I12_w3_2_small,
|
|
5235
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5751
5236
|
)
|
|
5752
|
-
)
|
|
5753
|
-
|
|
5754
|
-
|
|
5755
|
-
|
|
5756
|
-
|
|
5237
|
+
)
|
|
5238
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5239
|
+
axis=(-2, -1),
|
|
5240
|
+
)
|
|
5241
|
+
* fft_factor
|
|
5242
|
+
* P
|
|
5243
|
+
)
|
|
5244
|
+
if get_variance:
|
|
5245
|
+
S4_sigma.append(
|
|
5246
|
+
self.backend.bk_reduce_std(
|
|
5247
|
+
(
|
|
5248
|
+
self.backend.bk_reshape(
|
|
5249
|
+
I1_small[:, j1],
|
|
5250
|
+
[N_image, 1, L, 1, 1, M3, N3],
|
|
5251
|
+
)
|
|
5252
|
+
* self.backend.bk_conjugate(
|
|
5253
|
+
self.backend.bk_reshape(
|
|
5254
|
+
I12_w3_2_small,
|
|
5255
|
+
[N_image, 1, 1, L, L, M3, N3],
|
|
5256
|
+
)
|
|
5257
|
+
)
|
|
5258
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5259
|
+
axis=(-2, -1),
|
|
5260
|
+
)
|
|
5261
|
+
* fft_factor
|
|
5262
|
+
* P
|
|
5263
|
+
)
|
|
5264
|
+
else:
|
|
5265
|
+
for l1 in range(L):
|
|
5266
|
+
# [N_image,l2,l3,x,y]
|
|
5267
|
+
S4.append(
|
|
5268
|
+
self.backend.bk_reduce_mean(
|
|
5269
|
+
(
|
|
5270
|
+
self.backend.bk_reshape(
|
|
5271
|
+
I1_small[:, j1],
|
|
5272
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5273
|
+
)
|
|
5274
|
+
* self.backend.bk_conjugate(
|
|
5275
|
+
self.backend.bk_reshape(
|
|
5276
|
+
I12_w3_2_small,
|
|
5277
|
+
[N_image, 1, L, L, M3, N3],
|
|
5278
|
+
)
|
|
5279
|
+
)
|
|
5280
|
+
)[..., edge_dx:-edge_dx, edge_dy:-edge_dy],
|
|
5281
|
+
axis=(-2, -1),
|
|
5282
|
+
)
|
|
5283
|
+
* fft_factor
|
|
5284
|
+
* P
|
|
5285
|
+
)
|
|
5286
|
+
if get_variance:
|
|
5287
|
+
S4_sigma.append(
|
|
5288
|
+
self.backend.bk_reduce_std(
|
|
5289
|
+
(
|
|
5290
|
+
self.backend.bk_reshape(
|
|
5291
|
+
I1_small[:, j1],
|
|
5292
|
+
[N_image, 1, 1, 1, M3, N3],
|
|
5293
|
+
)
|
|
5294
|
+
* self.backend.bk_conjugate(
|
|
5295
|
+
self.backend.bk_reshape(
|
|
5296
|
+
I12_w3_2_small,
|
|
5297
|
+
[N_image, 1, L, L, M3, N3],
|
|
5298
|
+
)
|
|
5299
|
+
)
|
|
5300
|
+
)[
|
|
5301
|
+
...,
|
|
5302
|
+
edge_dx:-edge_dx,
|
|
5303
|
+
edge_dy:-edge_dy,
|
|
5304
|
+
],
|
|
5305
|
+
axis=(-2, -1),
|
|
5306
|
+
)
|
|
5307
|
+
* fft_factor
|
|
5308
|
+
* P
|
|
5309
|
+
)
|
|
5310
|
+
|
|
5311
|
+
S3 = self.backend.bk_concat(S3, axis=1)
|
|
5312
|
+
S4 = self.backend.bk_concat(S4, axis=1)
|
|
5313
|
+
|
|
5757
5314
|
if get_variance:
|
|
5758
|
-
S3_sigma=self.backend.bk_concat(S3_sigma,axis=1)
|
|
5759
|
-
S4_sigma=self.backend.bk_concat(S4_sigma,axis=1)
|
|
5760
|
-
|
|
5315
|
+
S3_sigma = self.backend.bk_concat(S3_sigma, axis=1)
|
|
5316
|
+
S4_sigma = self.backend.bk_concat(S4_sigma, axis=1)
|
|
5317
|
+
|
|
5761
5318
|
if data2 is not None:
|
|
5762
|
-
S3p=self.backend.bk_concat(S3p,axis=1)
|
|
5319
|
+
S3p = self.backend.bk_concat(S3p, axis=1)
|
|
5763
5320
|
if get_variance:
|
|
5764
|
-
S3p_sigma=self.backend.bk_concat(S3p_sigma,axis=1)
|
|
5765
|
-
|
|
5321
|
+
S3p_sigma = self.backend.bk_concat(S3p_sigma, axis=1)
|
|
5322
|
+
|
|
5766
5323
|
# average over l1 to obtain simple isotropic statistics
|
|
5767
5324
|
if iso_ang:
|
|
5768
|
-
S2_iso = self.backend.bk_reduce_mean(S2,axis=(-1))
|
|
5769
|
-
S1_iso = self.backend.bk_reduce_mean(S1,axis=(-1))
|
|
5325
|
+
S2_iso = self.backend.bk_reduce_mean(S2, axis=(-1))
|
|
5326
|
+
S1_iso = self.backend.bk_reduce_mean(S1, axis=(-1))
|
|
5770
5327
|
for l1 in range(L):
|
|
5771
5328
|
for l2 in range(L):
|
|
5772
|
-
S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
|
|
5329
|
+
S3_iso[..., (l2 - l1) % L] += S3[..., l1, l2]
|
|
5773
5330
|
if data2 is not None:
|
|
5774
|
-
S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
|
|
5331
|
+
S3p_iso[..., (l2 - l1) % L] += S3p[..., l1, l2]
|
|
5775
5332
|
for l3 in range(L):
|
|
5776
|
-
S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
|
|
5777
|
-
S3_iso /= L
|
|
5333
|
+
S4_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4[..., l1, l2, l3]
|
|
5334
|
+
S3_iso /= L
|
|
5335
|
+
S4_iso /= L
|
|
5778
5336
|
if data2 is not None:
|
|
5779
5337
|
S3p_iso /= L
|
|
5780
|
-
|
|
5338
|
+
|
|
5781
5339
|
if get_variance:
|
|
5782
|
-
S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma,axis=(-1))
|
|
5783
|
-
S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma,axis=(-1))
|
|
5340
|
+
S2_sigma_iso = self.backend.bk_reduce_mean(S2_sigma, axis=(-1))
|
|
5341
|
+
S1_sigma_iso = self.backend.bk_reduce_mean(S1_sigma, axis=(-1))
|
|
5784
5342
|
for l1 in range(L):
|
|
5785
5343
|
for l2 in range(L):
|
|
5786
|
-
S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
|
|
5344
|
+
S3_sigma_iso[..., (l2 - l1) % L] += S3_sigma[..., l1, l2]
|
|
5787
5345
|
if data2 is not None:
|
|
5788
|
-
S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
|
|
5346
|
+
S3p_sigma_iso[..., (l2 - l1) % L] += S3p_sigma[..., l1, l2]
|
|
5789
5347
|
for l3 in range(L):
|
|
5790
|
-
S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[
|
|
5791
|
-
|
|
5348
|
+
S4_sigma_iso[..., (l2 - l1) % L, (l3 - l1) % L] += S4_sigma[
|
|
5349
|
+
..., l1, l2, l3
|
|
5350
|
+
]
|
|
5351
|
+
S3_sigma_iso /= L
|
|
5352
|
+
S4_sigma_iso /= L
|
|
5792
5353
|
if data2 is not None:
|
|
5793
5354
|
S3p_sigma_iso /= L
|
|
5794
|
-
|
|
5355
|
+
|
|
5795
5356
|
if data2 is None:
|
|
5796
|
-
mean_data=self.backend.bk_reshape(
|
|
5797
|
-
|
|
5357
|
+
mean_data = self.backend.bk_reshape(
|
|
5358
|
+
self.backend.bk_reduce_mean(data, axis=(-2, -1)), [N_image, 1]
|
|
5359
|
+
)
|
|
5360
|
+
std_data = self.backend.bk_reshape(
|
|
5361
|
+
self.backend.bk_reduce_std(data, axis=(-2, -1)), [N_image, 1]
|
|
5362
|
+
)
|
|
5798
5363
|
else:
|
|
5799
|
-
mean_data=self.backend.bk_reshape(
|
|
5800
|
-
|
|
5801
|
-
|
|
5364
|
+
mean_data = self.backend.bk_reshape(
|
|
5365
|
+
self.backend.bk_reduce_mean(data * data2, axis=(-2, -1)), [N_image, 1]
|
|
5366
|
+
)
|
|
5367
|
+
std_data = self.backend.bk_reshape(
|
|
5368
|
+
self.backend.bk_reduce_std(data * data2, axis=(-2, -1)), [N_image, 1]
|
|
5369
|
+
)
|
|
5370
|
+
|
|
5802
5371
|
if get_variance:
|
|
5803
|
-
ref_sigma={}
|
|
5372
|
+
ref_sigma = {}
|
|
5804
5373
|
if iso_ang:
|
|
5805
|
-
ref_sigma[
|
|
5806
|
-
ref_sigma[
|
|
5807
|
-
ref_sigma[
|
|
5808
|
-
ref_sigma[
|
|
5809
|
-
ref_sigma[
|
|
5374
|
+
ref_sigma["std_data"] = std_data
|
|
5375
|
+
ref_sigma["S1_sigma"] = S1_sigma_iso
|
|
5376
|
+
ref_sigma["S2_sigma"] = S2_sigma_iso
|
|
5377
|
+
ref_sigma["S3_sigma"] = S3_sigma_iso
|
|
5378
|
+
ref_sigma["S4_sigma"] = S4_sigma_iso
|
|
5810
5379
|
if data2 is not None:
|
|
5811
|
-
ref_sigma[
|
|
5380
|
+
ref_sigma["S3p_sigma"] = S3p_sigma_iso
|
|
5812
5381
|
else:
|
|
5813
|
-
ref_sigma[
|
|
5814
|
-
ref_sigma[
|
|
5815
|
-
ref_sigma[
|
|
5816
|
-
ref_sigma[
|
|
5817
|
-
ref_sigma[
|
|
5382
|
+
ref_sigma["std_data"] = std_data
|
|
5383
|
+
ref_sigma["S1_sigma"] = S1_sigma
|
|
5384
|
+
ref_sigma["S2_sigma"] = S2_sigma
|
|
5385
|
+
ref_sigma["S3_sigma"] = S3_sigma
|
|
5386
|
+
ref_sigma["S4_sigma"] = S4_sigma
|
|
5818
5387
|
if data2 is not None:
|
|
5819
|
-
ref_sigma[
|
|
5820
|
-
|
|
5388
|
+
ref_sigma["S3p_sigma"] = S3_sigma
|
|
5389
|
+
|
|
5821
5390
|
if data2 is None:
|
|
5822
5391
|
if iso_ang:
|
|
5823
5392
|
if ref_sigma is not None:
|
|
5824
|
-
for_synthesis = self.backend.bk_concat(
|
|
5825
|
-
|
|
5826
|
-
|
|
5827
|
-
|
|
5828
|
-
|
|
5829
|
-
|
|
5830
|
-
|
|
5831
|
-
|
|
5832
|
-
|
|
5833
|
-
|
|
5393
|
+
for_synthesis = self.backend.bk_concat(
|
|
5394
|
+
(
|
|
5395
|
+
mean_data / ref_sigma["std_data"],
|
|
5396
|
+
std_data / ref_sigma["std_data"],
|
|
5397
|
+
self.backend.bk_reshape(
|
|
5398
|
+
self.backend.bk_log(S2_iso / ref_sigma["S2_sigma"]),
|
|
5399
|
+
[N_image, -1],
|
|
5400
|
+
),
|
|
5401
|
+
self.backend.bk_reshape(
|
|
5402
|
+
self.backend.bk_log(S1_iso / ref_sigma["S1_sigma"]),
|
|
5403
|
+
[N_image, -1],
|
|
5404
|
+
),
|
|
5405
|
+
self.backend.bk_reshape(
|
|
5406
|
+
self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
|
|
5407
|
+
[N_image, -1],
|
|
5408
|
+
),
|
|
5409
|
+
self.backend.bk_reshape(
|
|
5410
|
+
self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
|
|
5411
|
+
[N_image, -1],
|
|
5412
|
+
),
|
|
5413
|
+
self.backend.bk_reshape(
|
|
5414
|
+
self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
|
|
5415
|
+
[N_image, -1],
|
|
5416
|
+
),
|
|
5417
|
+
self.backend.bk_reshape(
|
|
5418
|
+
self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
|
|
5419
|
+
[N_image, -1],
|
|
5420
|
+
),
|
|
5421
|
+
),
|
|
5422
|
+
axis=-1,
|
|
5423
|
+
)
|
|
5834
5424
|
else:
|
|
5835
|
-
for_synthesis = self.backend.bk_concat(
|
|
5836
|
-
|
|
5837
|
-
|
|
5838
|
-
|
|
5839
|
-
|
|
5840
|
-
|
|
5841
|
-
|
|
5842
|
-
|
|
5843
|
-
|
|
5844
|
-
|
|
5425
|
+
for_synthesis = self.backend.bk_concat(
|
|
5426
|
+
(
|
|
5427
|
+
mean_data / std_data,
|
|
5428
|
+
std_data,
|
|
5429
|
+
self.backend.bk_reshape(
|
|
5430
|
+
self.backend.bk_log(S2_iso), [N_image, -1]
|
|
5431
|
+
),
|
|
5432
|
+
self.backend.bk_reshape(
|
|
5433
|
+
self.backend.bk_log(S1_iso), [N_image, -1]
|
|
5434
|
+
),
|
|
5435
|
+
self.backend.bk_reshape(
|
|
5436
|
+
self.backend.bk_real(S3_iso), [N_image, -1]
|
|
5437
|
+
),
|
|
5438
|
+
self.backend.bk_reshape(
|
|
5439
|
+
self.backend.bk_imag(S3_iso), [N_image, -1]
|
|
5440
|
+
),
|
|
5441
|
+
self.backend.bk_reshape(
|
|
5442
|
+
self.backend.bk_real(S4_iso), [N_image, -1]
|
|
5443
|
+
),
|
|
5444
|
+
self.backend.bk_reshape(
|
|
5445
|
+
self.backend.bk_imag(S4_iso), [N_image, -1]
|
|
5446
|
+
),
|
|
5447
|
+
),
|
|
5448
|
+
axis=-1,
|
|
5449
|
+
)
|
|
5845
5450
|
else:
|
|
5846
5451
|
if ref_sigma is not None:
|
|
5847
|
-
for_synthesis = self.backend.bk_concat(
|
|
5848
|
-
|
|
5849
|
-
|
|
5850
|
-
|
|
5851
|
-
|
|
5852
|
-
|
|
5853
|
-
|
|
5854
|
-
|
|
5855
|
-
|
|
5856
|
-
|
|
5452
|
+
for_synthesis = self.backend.bk_concat(
|
|
5453
|
+
(
|
|
5454
|
+
mean_data / ref_sigma["std_data"],
|
|
5455
|
+
std_data / ref_sigma["std_data"],
|
|
5456
|
+
self.backend.bk_reshape(
|
|
5457
|
+
self.backend.bk_log(S2 / ref_sigma["S2_sigma"]),
|
|
5458
|
+
[N_image, -1],
|
|
5459
|
+
),
|
|
5460
|
+
self.backend.bk_reshape(
|
|
5461
|
+
self.backend.bk_log(S1 / ref_sigma["S1_sigma"]),
|
|
5462
|
+
[N_image, -1],
|
|
5463
|
+
),
|
|
5464
|
+
self.backend.bk_reshape(
|
|
5465
|
+
self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
|
|
5466
|
+
[N_image, -1],
|
|
5467
|
+
),
|
|
5468
|
+
self.backend.bk_reshape(
|
|
5469
|
+
self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
|
|
5470
|
+
[N_image, -1],
|
|
5471
|
+
),
|
|
5472
|
+
self.backend.bk_reshape(
|
|
5473
|
+
self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
|
|
5474
|
+
[N_image, -1],
|
|
5475
|
+
),
|
|
5476
|
+
self.backend.bk_reshape(
|
|
5477
|
+
self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
|
|
5478
|
+
[N_image, -1],
|
|
5479
|
+
),
|
|
5480
|
+
),
|
|
5481
|
+
axis=-1,
|
|
5482
|
+
)
|
|
5857
5483
|
else:
|
|
5858
|
-
for_synthesis = self.backend.bk_concat(
|
|
5859
|
-
|
|
5484
|
+
for_synthesis = self.backend.bk_concat(
|
|
5485
|
+
(
|
|
5486
|
+
mean_data / std_data,
|
|
5860
5487
|
std_data,
|
|
5861
|
-
self.backend.bk_reshape(
|
|
5862
|
-
|
|
5863
|
-
|
|
5864
|
-
self.backend.bk_reshape(
|
|
5865
|
-
|
|
5866
|
-
|
|
5867
|
-
|
|
5488
|
+
self.backend.bk_reshape(
|
|
5489
|
+
self.backend.bk_log(S2), [N_image, -1]
|
|
5490
|
+
),
|
|
5491
|
+
self.backend.bk_reshape(
|
|
5492
|
+
self.backend.bk_log(S1), [N_image, -1]
|
|
5493
|
+
),
|
|
5494
|
+
self.backend.bk_reshape(
|
|
5495
|
+
self.backend.bk_real(S3), [N_image, -1]
|
|
5496
|
+
),
|
|
5497
|
+
self.backend.bk_reshape(
|
|
5498
|
+
self.backend.bk_imag(S3), [N_image, -1]
|
|
5499
|
+
),
|
|
5500
|
+
self.backend.bk_reshape(
|
|
5501
|
+
self.backend.bk_real(S4), [N_image, -1]
|
|
5502
|
+
),
|
|
5503
|
+
self.backend.bk_reshape(
|
|
5504
|
+
self.backend.bk_imag(S4), [N_image, -1]
|
|
5505
|
+
),
|
|
5506
|
+
),
|
|
5507
|
+
axis=-1,
|
|
5508
|
+
)
|
|
5868
5509
|
else:
|
|
5869
5510
|
if iso_ang:
|
|
5870
5511
|
if ref_sigma is not None:
|
|
5871
|
-
for_synthesis = self.backend.backend.cat(
|
|
5872
|
-
|
|
5873
|
-
|
|
5874
|
-
|
|
5875
|
-
|
|
5876
|
-
|
|
5877
|
-
|
|
5878
|
-
|
|
5879
|
-
|
|
5880
|
-
|
|
5881
|
-
|
|
5882
|
-
|
|
5512
|
+
for_synthesis = self.backend.backend.cat(
|
|
5513
|
+
(
|
|
5514
|
+
mean_data / ref_sigma["std_data"],
|
|
5515
|
+
std_data / ref_sigma["std_data"],
|
|
5516
|
+
self.backend.bk_reshape(
|
|
5517
|
+
self.backend.bk_real(S2_iso / ref_sigma["S2_sigma"]),
|
|
5518
|
+
[N_image, -1],
|
|
5519
|
+
),
|
|
5520
|
+
self.backend.bk_reshape(
|
|
5521
|
+
self.backend.bk_real(S1_iso / ref_sigma["S1_sigma"]),
|
|
5522
|
+
[N_image, -1],
|
|
5523
|
+
),
|
|
5524
|
+
self.backend.bk_reshape(
|
|
5525
|
+
self.backend.bk_real(S3_iso / ref_sigma["S3_sigma"]),
|
|
5526
|
+
[N_image, -1],
|
|
5527
|
+
),
|
|
5528
|
+
self.backend.bk_reshape(
|
|
5529
|
+
self.backend.bk_imag(S3_iso / ref_sigma["S3_sigma"]),
|
|
5530
|
+
[N_image, -1],
|
|
5531
|
+
),
|
|
5532
|
+
self.backend.bk_reshape(
|
|
5533
|
+
self.backend.bk_real(S3p_iso / ref_sigma["S3p_sigma"]),
|
|
5534
|
+
[N_image, -1],
|
|
5535
|
+
),
|
|
5536
|
+
self.backend.bk_reshape(
|
|
5537
|
+
self.backend.bk_imag(S3p_iso / ref_sigma["S3p_sigma"]),
|
|
5538
|
+
[N_image, -1],
|
|
5539
|
+
),
|
|
5540
|
+
self.backend.bk_reshape(
|
|
5541
|
+
self.backend.bk_real(S4_iso / ref_sigma["S4_sigma"]),
|
|
5542
|
+
[N_image, -1],
|
|
5543
|
+
),
|
|
5544
|
+
self.backend.bk_reshape(
|
|
5545
|
+
self.backend.bk_imag(S4_iso / ref_sigma["S4_sigma"]),
|
|
5546
|
+
[N_image, -1],
|
|
5547
|
+
),
|
|
5548
|
+
),
|
|
5549
|
+
axis=-1,
|
|
5550
|
+
)
|
|
5883
5551
|
else:
|
|
5884
|
-
for_synthesis = self.backend.backend.cat(
|
|
5885
|
-
|
|
5886
|
-
|
|
5887
|
-
|
|
5888
|
-
|
|
5889
|
-
|
|
5890
|
-
|
|
5891
|
-
|
|
5892
|
-
|
|
5893
|
-
|
|
5894
|
-
|
|
5895
|
-
|
|
5552
|
+
for_synthesis = self.backend.backend.cat(
|
|
5553
|
+
(
|
|
5554
|
+
mean_data / std_data,
|
|
5555
|
+
std_data,
|
|
5556
|
+
self.backend.bk_reshape(
|
|
5557
|
+
self.backend.bk_real(S2_iso), [N_image, -1]
|
|
5558
|
+
),
|
|
5559
|
+
self.backend.bk_reshape(
|
|
5560
|
+
self.backend.bk_real(S1_iso), [N_image, -1]
|
|
5561
|
+
),
|
|
5562
|
+
self.backend.bk_reshape(
|
|
5563
|
+
self.backend.bk_real(S3_iso), [N_image, -1]
|
|
5564
|
+
),
|
|
5565
|
+
self.backend.bk_reshape(
|
|
5566
|
+
self.backend.bk_imag(S3_iso), [N_image, -1]
|
|
5567
|
+
),
|
|
5568
|
+
self.backend.bk_reshape(
|
|
5569
|
+
self.backend.bk_real(S3p_iso), [N_image, -1]
|
|
5570
|
+
),
|
|
5571
|
+
self.backend.bk_reshape(
|
|
5572
|
+
self.backend.bk_imag(S3p_iso), [N_image, -1]
|
|
5573
|
+
),
|
|
5574
|
+
self.backend.bk_reshape(
|
|
5575
|
+
self.backend.bk_real(S4_iso), [N_image, -1]
|
|
5576
|
+
),
|
|
5577
|
+
self.backend.bk_reshape(
|
|
5578
|
+
self.backend.bk_imag(S4_iso), [N_image, -1]
|
|
5579
|
+
),
|
|
5580
|
+
),
|
|
5581
|
+
axis=-1,
|
|
5582
|
+
)
|
|
5896
5583
|
else:
|
|
5897
5584
|
if ref_sigma is not None:
|
|
5898
|
-
for_synthesis = self.backend.backend.cat(
|
|
5899
|
-
|
|
5900
|
-
|
|
5901
|
-
|
|
5902
|
-
|
|
5903
|
-
|
|
5904
|
-
|
|
5905
|
-
|
|
5906
|
-
|
|
5907
|
-
|
|
5908
|
-
|
|
5909
|
-
|
|
5585
|
+
for_synthesis = self.backend.backend.cat(
|
|
5586
|
+
(
|
|
5587
|
+
mean_data / ref_sigma["std_data"],
|
|
5588
|
+
std_data / ref_sigma["std_data"],
|
|
5589
|
+
self.backend.bk_reshape(
|
|
5590
|
+
self.backend.bk_real(S2 / ref_sigma["S2_sigma"]),
|
|
5591
|
+
[N_image, -1],
|
|
5592
|
+
),
|
|
5593
|
+
self.backend.bk_reshape(
|
|
5594
|
+
self.backend.bk_real(S1 / ref_sigma["S1_sigma"]),
|
|
5595
|
+
[N_image, -1],
|
|
5596
|
+
),
|
|
5597
|
+
self.backend.bk_reshape(
|
|
5598
|
+
self.backend.bk_real(S3 / ref_sigma["S3_sigma"]),
|
|
5599
|
+
[N_image, -1],
|
|
5600
|
+
),
|
|
5601
|
+
self.backend.bk_reshape(
|
|
5602
|
+
self.backend.bk_imag(S3 / ref_sigma["S3_sigma"]),
|
|
5603
|
+
[N_image, -1],
|
|
5604
|
+
),
|
|
5605
|
+
self.backend.bk_reshape(
|
|
5606
|
+
self.backend.bk_real(S3p / ref_sigma["S3p_sigma"]),
|
|
5607
|
+
[N_image, -1],
|
|
5608
|
+
),
|
|
5609
|
+
self.backend.bk_reshape(
|
|
5610
|
+
self.backend.bk_imag(S3p / ref_sigma["S3p_sigma"]),
|
|
5611
|
+
[N_image, -1],
|
|
5612
|
+
),
|
|
5613
|
+
self.backend.bk_reshape(
|
|
5614
|
+
self.backend.bk_real(S4 / ref_sigma["S4_sigma"]),
|
|
5615
|
+
[N_image, -1],
|
|
5616
|
+
),
|
|
5617
|
+
self.backend.bk_reshape(
|
|
5618
|
+
self.backend.bk_imag(S4 / ref_sigma["S4_sigma"]),
|
|
5619
|
+
[N_image, -1],
|
|
5620
|
+
),
|
|
5621
|
+
),
|
|
5622
|
+
axis=-1,
|
|
5623
|
+
)
|
|
5910
5624
|
else:
|
|
5911
|
-
for_synthesis = self.backend.bk_concat(
|
|
5912
|
-
|
|
5625
|
+
for_synthesis = self.backend.bk_concat(
|
|
5626
|
+
(
|
|
5627
|
+
mean_data / std_data,
|
|
5913
5628
|
std_data,
|
|
5914
|
-
self.backend.bk_reshape(
|
|
5915
|
-
|
|
5916
|
-
|
|
5917
|
-
self.backend.bk_reshape(
|
|
5918
|
-
|
|
5919
|
-
|
|
5920
|
-
self.backend.bk_reshape(
|
|
5921
|
-
|
|
5922
|
-
|
|
5923
|
-
|
|
5924
|
-
|
|
5925
|
-
|
|
5926
|
-
|
|
5629
|
+
self.backend.bk_reshape(
|
|
5630
|
+
self.backend.bk_real(S2), [N_image, -1]
|
|
5631
|
+
),
|
|
5632
|
+
self.backend.bk_reshape(
|
|
5633
|
+
self.backend.bk_real(S1), [N_image, -1]
|
|
5634
|
+
),
|
|
5635
|
+
self.backend.bk_reshape(
|
|
5636
|
+
self.backend.bk_real(S3), [N_image, -1]
|
|
5637
|
+
),
|
|
5638
|
+
self.backend.bk_reshape(
|
|
5639
|
+
self.backend.bk_imag(S3), [N_image, -1]
|
|
5640
|
+
),
|
|
5641
|
+
self.backend.bk_reshape(
|
|
5642
|
+
self.backend.bk_real(S3p), [N_image, -1]
|
|
5643
|
+
),
|
|
5644
|
+
self.backend.bk_reshape(
|
|
5645
|
+
self.backend.bk_imag(S3p), [N_image, -1]
|
|
5646
|
+
),
|
|
5647
|
+
self.backend.bk_reshape(
|
|
5648
|
+
self.backend.bk_real(S4), [N_image, -1]
|
|
5649
|
+
),
|
|
5650
|
+
self.backend.bk_reshape(
|
|
5651
|
+
self.backend.bk_imag(S4), [N_image, -1]
|
|
5652
|
+
),
|
|
5653
|
+
),
|
|
5654
|
+
axis=-1,
|
|
5655
|
+
)
|
|
5656
|
+
|
|
5657
|
+
if not use_ref:
|
|
5658
|
+
self.ref_scattering_cov_S2 = S2
|
|
5659
|
+
|
|
5927
5660
|
if get_variance:
|
|
5928
|
-
return for_synthesis,ref_sigma
|
|
5929
|
-
|
|
5661
|
+
return for_synthesis, ref_sigma
|
|
5662
|
+
|
|
5930
5663
|
return for_synthesis
|
|
5931
|
-
|
|
5932
|
-
|
|
5933
|
-
def to_gaussian(self,x):
|
|
5934
|
-
from scipy.stats import norm
|
|
5664
|
+
|
|
5665
|
+
def to_gaussian(self, x):
|
|
5935
5666
|
from scipy.interpolate import interp1d
|
|
5667
|
+
from scipy.stats import norm
|
|
5936
5668
|
|
|
5937
|
-
idx=np.argsort(x.flatten())
|
|
5669
|
+
idx = np.argsort(x.flatten())
|
|
5938
5670
|
p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
|
|
5939
|
-
im_target=x.flatten()
|
|
5671
|
+
im_target = x.flatten()
|
|
5940
5672
|
im_target[idx] = norm.ppf(p)
|
|
5941
|
-
|
|
5673
|
+
|
|
5942
5674
|
# Interpolation cubique
|
|
5943
|
-
self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind=
|
|
5944
|
-
self.val_min=im_target[idx[0]]
|
|
5945
|
-
self.val_max=im_target[idx[-1]]
|
|
5675
|
+
self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
|
|
5676
|
+
self.val_min = im_target[idx[0]]
|
|
5677
|
+
self.val_max = im_target[idx[-1]]
|
|
5946
5678
|
return im_target.reshape(x.shape)
|
|
5947
5679
|
|
|
5680
|
+
def from_gaussian(self, x):
|
|
5948
5681
|
|
|
5949
|
-
|
|
5950
|
-
|
|
5951
|
-
x=self.backend.bk_clip_by_value(x,self.val_min,self.val_max)
|
|
5682
|
+
x = self.backend.bk_clip_by_value(x, self.val_min, self.val_max)
|
|
5952
5683
|
return self.f_gaussian(self.backend.to_numpy(x))
|
|
5953
|
-
|
|
5684
|
+
|
|
5954
5685
|
def square(self, x):
|
|
5955
5686
|
if isinstance(x, scat_cov):
|
|
5956
5687
|
if x.S1 is None:
|
|
@@ -6302,89 +6033,133 @@ class funct(FOC.FoCUS):
|
|
|
6302
6033
|
s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
|
|
6303
6034
|
)
|
|
6304
6035
|
|
|
6305
|
-
def synthesis(
|
|
6306
|
-
|
|
6307
|
-
|
|
6308
|
-
|
|
6309
|
-
|
|
6310
|
-
|
|
6311
|
-
|
|
6312
|
-
|
|
6313
|
-
|
|
6314
|
-
|
|
6315
|
-
|
|
6316
|
-
|
|
6317
|
-
|
|
6318
|
-
|
|
6319
|
-
|
|
6036
|
+
def synthesis(
|
|
6037
|
+
self,
|
|
6038
|
+
image_target,
|
|
6039
|
+
nstep=4,
|
|
6040
|
+
seed=1234,
|
|
6041
|
+
Jmax=None,
|
|
6042
|
+
edge=False,
|
|
6043
|
+
to_gaussian=True,
|
|
6044
|
+
use_variance=False,
|
|
6045
|
+
synthesised_N=1,
|
|
6046
|
+
input_image=None,
|
|
6047
|
+
grd_mask=None,
|
|
6048
|
+
iso_ang=False,
|
|
6049
|
+
EVAL_FREQUENCY=100,
|
|
6050
|
+
NUM_EPOCHS=300,
|
|
6051
|
+
):
|
|
6052
|
+
|
|
6320
6053
|
import time
|
|
6321
6054
|
|
|
6322
|
-
|
|
6323
|
-
|
|
6055
|
+
import foscat.Synthesis as synthe
|
|
6056
|
+
|
|
6057
|
+
def The_loss(u, scat_operator, args):
|
|
6058
|
+
ref = args[0]
|
|
6324
6059
|
sref = args[1]
|
|
6325
|
-
use_v= args[2]
|
|
6326
|
-
|
|
6060
|
+
use_v = args[2]
|
|
6061
|
+
|
|
6327
6062
|
# compute scattering covariance of the current synthetised map called u
|
|
6328
6063
|
if use_v:
|
|
6329
|
-
learn=scat_operator.reduce_mean_batch(
|
|
6064
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6065
|
+
scat_operator.scattering_cov(
|
|
6066
|
+
u,
|
|
6067
|
+
edge=edge,
|
|
6068
|
+
Jmax=Jmax,
|
|
6069
|
+
ref_sigma=sref,
|
|
6070
|
+
use_ref=True,
|
|
6071
|
+
iso_ang=iso_ang,
|
|
6072
|
+
)
|
|
6073
|
+
)
|
|
6330
6074
|
else:
|
|
6331
|
-
learn=scat_operator.reduce_mean_batch(
|
|
6332
|
-
|
|
6075
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6076
|
+
scat_operator.scattering_cov(
|
|
6077
|
+
u, edge=edge, Jmax=Jmax, use_ref=True, iso_ang=iso_ang
|
|
6078
|
+
)
|
|
6079
|
+
)
|
|
6080
|
+
|
|
6333
6081
|
# make the difference withe the reference coordinates
|
|
6334
|
-
loss=scat_operator.backend.bk_reduce_mean(
|
|
6082
|
+
loss = scat_operator.backend.bk_reduce_mean(
|
|
6083
|
+
scat_operator.backend.bk_square(learn - ref)
|
|
6084
|
+
)
|
|
6335
6085
|
return loss
|
|
6336
6086
|
|
|
6337
6087
|
if to_gaussian:
|
|
6338
6088
|
# Change the data histogram to gaussian distribution
|
|
6339
|
-
im_target=self.to_gaussian(image_target)
|
|
6089
|
+
im_target = self.to_gaussian(image_target)
|
|
6340
6090
|
else:
|
|
6341
|
-
im_target=image_target
|
|
6342
|
-
|
|
6343
|
-
axis=len(im_target.shape)-1
|
|
6091
|
+
im_target = image_target
|
|
6092
|
+
|
|
6093
|
+
axis = len(im_target.shape) - 1
|
|
6344
6094
|
if self.use_2D:
|
|
6345
|
-
axis-=1
|
|
6346
|
-
if axis==0:
|
|
6347
|
-
im_target=self.backend.bk_expand_dims(im_target,0)
|
|
6095
|
+
axis -= 1
|
|
6096
|
+
if axis == 0:
|
|
6097
|
+
im_target = self.backend.bk_expand_dims(im_target, 0)
|
|
6348
6098
|
|
|
6349
6099
|
# compute the number of possible steps
|
|
6350
6100
|
if self.use_2D:
|
|
6351
|
-
jmax=int(
|
|
6101
|
+
jmax = int(
|
|
6102
|
+
np.min([np.log(im_target.shape[1]), np.log(im_target.shape[2])])
|
|
6103
|
+
/ np.log(2)
|
|
6104
|
+
)
|
|
6352
6105
|
elif self.use_1D:
|
|
6353
|
-
jmax=int(np.log(im_target.shape[1])/np.log(2))
|
|
6106
|
+
jmax = int(np.log(im_target.shape[1]) / np.log(2))
|
|
6354
6107
|
else:
|
|
6355
|
-
jmax=int((np.log(im_target.shape[1]//12)/np.log(2))/2)
|
|
6356
|
-
nside=2**jmax
|
|
6108
|
+
jmax = int((np.log(im_target.shape[1] // 12) / np.log(2)) / 2)
|
|
6109
|
+
nside = 2**jmax
|
|
6357
6110
|
|
|
6358
|
-
if nstep>jmax-1:
|
|
6359
|
-
nstep=jmax-1
|
|
6111
|
+
if nstep > jmax - 1:
|
|
6112
|
+
nstep = jmax - 1
|
|
6360
6113
|
|
|
6361
|
-
t1=time.time()
|
|
6362
|
-
tmp={}
|
|
6363
|
-
|
|
6364
|
-
|
|
6365
|
-
|
|
6114
|
+
t1 = time.time()
|
|
6115
|
+
tmp = {}
|
|
6116
|
+
|
|
6117
|
+
l_grd_mask={}
|
|
6118
|
+
|
|
6119
|
+
tmp[nstep - 1] = self.backend.bk_cast(im_target)
|
|
6120
|
+
if grd_mask is not None:
|
|
6121
|
+
l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
|
|
6122
|
+
else:
|
|
6123
|
+
l_grd_mask[nstep - 1] = None
|
|
6366
6124
|
|
|
6367
|
-
|
|
6368
|
-
|
|
6125
|
+
for ell in range(nstep - 2, -1, -1):
|
|
6126
|
+
tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
|
|
6127
|
+
if grd_mask is not None:
|
|
6128
|
+
l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
|
|
6129
|
+
else:
|
|
6130
|
+
l_grd_mask[ell] = None
|
|
6369
6131
|
|
|
6132
|
+
|
|
6133
|
+
if not self.use_2D and not self.use_1D:
|
|
6134
|
+
l_nside = nside // (2 ** (nstep - 1))
|
|
6135
|
+
|
|
6370
6136
|
for k in range(nstep):
|
|
6371
|
-
if k==0:
|
|
6137
|
+
if k == 0:
|
|
6372
6138
|
if input_image is None:
|
|
6373
6139
|
np.random.seed(seed)
|
|
6374
6140
|
if self.use_2D:
|
|
6375
|
-
imap=np.random.randn(
|
|
6376
|
-
|
|
6377
|
-
|
|
6141
|
+
imap = self.backend.bk_cast(np.random.randn(
|
|
6142
|
+
synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
|
|
6143
|
+
))
|
|
6378
6144
|
else:
|
|
6379
|
-
imap=np.random.randn(synthesised_N,
|
|
6380
|
-
tmp[k].shape[1])
|
|
6145
|
+
imap = self.backend.bk_cast(np.random.randn(synthesised_N, tmp[k].shape[1]))
|
|
6381
6146
|
else:
|
|
6382
6147
|
if self.use_2D:
|
|
6383
|
-
imap=self.backend.bk_reshape(
|
|
6384
|
-
|
|
6148
|
+
imap = self.backend.bk_reshape(
|
|
6149
|
+
self.backend.bk_tile(
|
|
6150
|
+
self.backend.bk_cast(input_image.flatten()),
|
|
6151
|
+
synthesised_N,
|
|
6152
|
+
),
|
|
6153
|
+
[synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
|
|
6154
|
+
)
|
|
6385
6155
|
else:
|
|
6386
|
-
imap=self.backend.bk_reshape(
|
|
6387
|
-
|
|
6156
|
+
imap = self.backend.bk_reshape(
|
|
6157
|
+
self.backend.bk_tile(
|
|
6158
|
+
self.backend.bk_cast(input_image.flatten()),
|
|
6159
|
+
synthesised_N,
|
|
6160
|
+
),
|
|
6161
|
+
[synthesised_N, tmp[k].shape[1]],
|
|
6162
|
+
)
|
|
6388
6163
|
else:
|
|
6389
6164
|
# Increase the resolution between each step
|
|
6390
6165
|
if self.use_2D:
|
|
@@ -6395,48 +6170,49 @@ class funct(FOC.FoCUS):
|
|
|
6395
6170
|
imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
|
|
6396
6171
|
else:
|
|
6397
6172
|
imap = self.up_grade(omap, l_nside, axis=1)
|
|
6398
|
-
|
|
6173
|
+
|
|
6174
|
+
if grd_mask is not None:
|
|
6175
|
+
imap=imap*l_grd_mask[k]+tmp[k]*(1-l_grd_mask[k])
|
|
6176
|
+
|
|
6399
6177
|
# compute the coefficients for the target image
|
|
6400
6178
|
if use_variance:
|
|
6401
|
-
ref,sref=self.scattering_cov(
|
|
6179
|
+
ref, sref = self.scattering_cov(
|
|
6180
|
+
tmp[k], get_variance=True, edge=edge, Jmax=Jmax, iso_ang=iso_ang
|
|
6181
|
+
)
|
|
6402
6182
|
else:
|
|
6403
|
-
ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
|
|
6404
|
-
sref=ref
|
|
6405
|
-
|
|
6183
|
+
ref = self.scattering_cov(tmp[k], edge=edge, Jmax=Jmax, iso_ang=iso_ang)
|
|
6184
|
+
sref = ref
|
|
6185
|
+
|
|
6406
6186
|
# compute the mean of the population does nothing if only one map is given
|
|
6407
|
-
ref=self.reduce_mean_batch(ref)
|
|
6408
|
-
|
|
6187
|
+
ref = self.reduce_mean_batch(ref)
|
|
6188
|
+
|
|
6409
6189
|
# define a loss to minimize
|
|
6410
|
-
loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
|
|
6411
|
-
|
|
6190
|
+
loss = synthe.Loss(The_loss, self, ref, sref, use_variance)
|
|
6191
|
+
|
|
6412
6192
|
sy = synthe.Synthesis([loss])
|
|
6413
6193
|
|
|
6414
6194
|
# initialize the synthesised map
|
|
6415
6195
|
if self.use_2D:
|
|
6416
|
-
print(
|
|
6196
|
+
print("Synthesis scale [ %d x %d ]" % (imap.shape[1], imap.shape[2]))
|
|
6417
6197
|
elif self.use_1D:
|
|
6418
|
-
print(
|
|
6198
|
+
print("Synthesis scale [ %d ]" % (imap.shape[1]))
|
|
6419
6199
|
else:
|
|
6420
|
-
print(
|
|
6421
|
-
l_nside*=2
|
|
6422
|
-
|
|
6200
|
+
print("Synthesis scale nside=%d" % (l_nside))
|
|
6201
|
+
l_nside *= 2
|
|
6202
|
+
|
|
6423
6203
|
# do the minimization
|
|
6424
|
-
omap=sy.run(imap,
|
|
6425
|
-
EVAL_FREQUENCY=EVAL_FREQUENCY,
|
|
6426
|
-
NUM_EPOCHS = NUM_EPOCHS)
|
|
6427
|
-
|
|
6428
|
-
|
|
6204
|
+
omap = sy.run(imap, EVAL_FREQUENCY=EVAL_FREQUENCY, NUM_EPOCHS=NUM_EPOCHS,grd_mask=l_grd_mask[k])
|
|
6429
6205
|
|
|
6430
|
-
t2=time.time()
|
|
6431
|
-
print(
|
|
6206
|
+
t2 = time.time()
|
|
6207
|
+
print("Total computation %.2fs" % (t2 - t1))
|
|
6432
6208
|
|
|
6433
6209
|
if to_gaussian:
|
|
6434
|
-
omap=self.from_gaussian(omap)
|
|
6210
|
+
omap = self.from_gaussian(omap)
|
|
6435
6211
|
|
|
6436
|
-
if axis==0 and synthesised_N==1:
|
|
6212
|
+
if axis == 0 and synthesised_N == 1:
|
|
6437
6213
|
return omap[0]
|
|
6438
6214
|
else:
|
|
6439
6215
|
return omap
|
|
6440
|
-
|
|
6441
|
-
def to_numpy(self,x):
|
|
6216
|
+
|
|
6217
|
+
def to_numpy(self, x):
|
|
6442
6218
|
return self.backend.to_numpy(x)
|