foscat 3.6.1__py3-none-any.whl → 3.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/CircSpline.py +19 -22
- foscat/FoCUS.py +61 -51
- foscat/Spline1D.py +12 -13
- foscat/Synthesis.py +16 -13
- foscat/alm.py +718 -580
- foscat/backend.py +141 -37
- foscat/scat_cov.py +2311 -430
- foscat/scat_cov2D.py +102 -2
- foscat/scat_cov_map.py +15 -2
- {foscat-3.6.1.dist-info → foscat-3.7.1.dist-info}/METADATA +3 -3
- foscat-3.7.1.dist-info/RECORD +26 -0
- {foscat-3.6.1.dist-info → foscat-3.7.1.dist-info}/WHEEL +1 -1
- foscat/alm_tools.py +0 -11
- foscat-3.6.1.dist-info/RECORD +0 -27
- /foscat-3.6.1.dist-info/LICENCE → /foscat-3.7.1.dist-info/LICENSE +0 -0
- {foscat-3.6.1.dist-info → foscat-3.7.1.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -33,9 +33,7 @@ testwarn = 0
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class scat_cov:
|
|
36
|
-
def __init__(
|
|
37
|
-
self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False
|
|
38
|
-
):
|
|
36
|
+
def __init__(self, s0, s2, s3, s4, s1=None, s3p=None, backend=None, use_1D=False):
|
|
39
37
|
self.S0 = s0
|
|
40
38
|
self.S2 = s2
|
|
41
39
|
self.S3 = s3
|
|
@@ -54,17 +52,17 @@ class scat_cov:
|
|
|
54
52
|
if self.S1 is None:
|
|
55
53
|
s1 = None
|
|
56
54
|
else:
|
|
57
|
-
s1 = self.
|
|
55
|
+
s1 = self.backend.to_numpy(self.S1)
|
|
58
56
|
if self.S3P is None:
|
|
59
57
|
s3p = None
|
|
60
58
|
else:
|
|
61
|
-
s3p = self.
|
|
59
|
+
s3p = self.backend.to_numpy(self.S3P)
|
|
62
60
|
|
|
63
61
|
return scat_cov(
|
|
64
|
-
(self.S0
|
|
65
|
-
(self.S2
|
|
66
|
-
(self.S3
|
|
67
|
-
(self.S4
|
|
62
|
+
self.backend.to_numpy(self.S0),
|
|
63
|
+
self.backend.to_numpy(self.S2),
|
|
64
|
+
self.backend.to_numpy(self.S3),
|
|
65
|
+
self.backend.to_numpy(self.S4),
|
|
68
66
|
s1=s1,
|
|
69
67
|
s3p=s3p,
|
|
70
68
|
backend=self.backend,
|
|
@@ -94,7 +92,7 @@ class scat_cov:
|
|
|
94
92
|
)
|
|
95
93
|
|
|
96
94
|
def conv2complex(self, val):
|
|
97
|
-
if val.dtype == "complex64" or val.dtype == "complex128":
|
|
95
|
+
if val.dtype == "complex64" or val.dtype=="complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
|
|
98
96
|
return val
|
|
99
97
|
else:
|
|
100
98
|
return self.backend.bk_complex(val, 0 * val)
|
|
@@ -1531,7 +1529,7 @@ class scat_cov:
|
|
|
1531
1529
|
if isinstance(x, np.ndarray):
|
|
1532
1530
|
return x
|
|
1533
1531
|
else:
|
|
1534
|
-
return
|
|
1532
|
+
return self.backend.to_numpy(x)
|
|
1535
1533
|
else:
|
|
1536
1534
|
return None
|
|
1537
1535
|
|
|
@@ -1974,7 +1972,7 @@ class scat_cov:
|
|
|
1974
1972
|
if self.BACKEND == "numpy":
|
|
1975
1973
|
s2[:, :, noff:, :] = self.S2
|
|
1976
1974
|
else:
|
|
1977
|
-
s2[:, :, noff:, :] = self.
|
|
1975
|
+
s2[:, :, noff:, :] = self.backend.to_numpy(self.S2)
|
|
1978
1976
|
for i in range(self.S2.shape[0]):
|
|
1979
1977
|
for j in range(self.S2.shape[1]):
|
|
1980
1978
|
for k in range(self.S2.shape[3]):
|
|
@@ -1986,7 +1984,7 @@ class scat_cov:
|
|
|
1986
1984
|
if self.BACKEND == "numpy":
|
|
1987
1985
|
s1[:, :, noff:, :] = self.S1
|
|
1988
1986
|
else:
|
|
1989
|
-
s1[:, :, noff:, :] = self.
|
|
1987
|
+
s1[:, :, noff:, :] = self.backend.to_numpy(self.S1)
|
|
1990
1988
|
for i in range(self.S1.shape[0]):
|
|
1991
1989
|
for j in range(self.S1.shape[1]):
|
|
1992
1990
|
for k in range(self.S1.shape[3]):
|
|
@@ -2045,12 +2043,12 @@ class scat_cov:
|
|
|
2045
2043
|
)
|
|
2046
2044
|
)
|
|
2047
2045
|
else:
|
|
2048
|
-
s3[i, j, idx[noff:], k, l_orient] = self.
|
|
2046
|
+
s3[i, j, idx[noff:], k, l_orient] = self.backend.to_numpy(self.S3)[
|
|
2049
2047
|
i, j, j2 == ij - noff, k, l_orient
|
|
2050
2048
|
]
|
|
2051
2049
|
s3[i, j, idx[:noff], k, l_orient] = (
|
|
2052
2050
|
self.add_data_from_slope(
|
|
2053
|
-
self.
|
|
2051
|
+
self.backend.to_numpy(self.S3)[
|
|
2054
2052
|
i, j, j2 == ij - noff, k, l_orient
|
|
2055
2053
|
],
|
|
2056
2054
|
noff,
|
|
@@ -2358,13 +2356,13 @@ class funct(FOC.FoCUS):
|
|
|
2358
2356
|
tmp = self.up_grade(tmp, l_nside * 2, axis=1)
|
|
2359
2357
|
if image2 is not None:
|
|
2360
2358
|
tmpi2 = self.up_grade(tmpi2, l_nside * 2, axis=1)
|
|
2361
|
-
|
|
2362
2359
|
l_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2363
2360
|
nscale = int(np.log(l_nside) / np.log(2))
|
|
2364
2361
|
cmat = {}
|
|
2365
2362
|
cmat2 = {}
|
|
2363
|
+
|
|
2364
|
+
# Loop over scales
|
|
2366
2365
|
for k in range(nscale):
|
|
2367
|
-
sim = self.backend.bk_abs(self.convol(tmp, axis=1))
|
|
2368
2366
|
if image2 is not None:
|
|
2369
2367
|
sim = self.backend.bk_real(
|
|
2370
2368
|
self.backend.bk_L1(
|
|
@@ -2375,17 +2373,31 @@ class funct(FOC.FoCUS):
|
|
|
2375
2373
|
else:
|
|
2376
2374
|
sim = self.backend.bk_abs(self.convol(tmp, axis=1))
|
|
2377
2375
|
|
|
2378
|
-
|
|
2379
|
-
|
|
2380
|
-
|
|
2381
|
-
|
|
2382
|
-
|
|
2383
|
-
|
|
2376
|
+
# instead of difference between "opposite" channels use weighted average
|
|
2377
|
+
# of cosine and sine contributions using all channels
|
|
2378
|
+
angles = (
|
|
2379
|
+
2 * np.pi * np.arange(self.NORIENT) / self.NORIENT
|
|
2380
|
+
) # shape: (NORIENT,)
|
|
2381
|
+
angles = angles.reshape(1, 1, self.NORIENT)
|
|
2382
|
+
|
|
2383
|
+
# we use cosines and sines as weights for sim
|
|
2384
|
+
weighted_cos = self.backend.bk_reduce_mean(sim * np.cos(angles), axis=-1)
|
|
2385
|
+
weighted_sin = self.backend.bk_reduce_mean(sim * np.sin(angles), axis=-1)
|
|
2386
|
+
# For simplicity, take first element of the batch
|
|
2387
|
+
cc = weighted_cos[0]
|
|
2388
|
+
ss = weighted_sin[0]
|
|
2389
|
+
|
|
2390
|
+
if smooth_scale > 0:
|
|
2391
|
+
for m in range(smooth_scale):
|
|
2392
|
+
if cc.shape[0] > 12:
|
|
2393
|
+
cc = self.ud_grade_2(self.smooth(cc))
|
|
2394
|
+
ss = self.ud_grade_2(self.smooth(ss))
|
|
2384
2395
|
if cc.shape[0] != tmp.shape[0]:
|
|
2385
2396
|
ll_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2386
2397
|
cc = self.up_grade(cc, ll_nside)
|
|
2387
2398
|
ss = self.up_grade(ss, ll_nside)
|
|
2388
2399
|
|
|
2400
|
+
# compute local phase from weighted cos and sin (same as before)
|
|
2389
2401
|
if self.BACKEND == "numpy":
|
|
2390
2402
|
phase = np.fmod(np.arctan2(ss, cc) + 2 * np.pi, 2 * np.pi)
|
|
2391
2403
|
else:
|
|
@@ -2393,87 +2405,87 @@ class funct(FOC.FoCUS):
|
|
|
2393
2405
|
np.arctan2(ss.numpy(), cc.numpy()) + 2 * np.pi, 2 * np.pi
|
|
2394
2406
|
)
|
|
2395
2407
|
|
|
2396
|
-
|
|
2397
|
-
|
|
2398
|
-
|
|
2408
|
+
# instead of linear interpolation cosine‐based interpolation
|
|
2409
|
+
phase_scaled = self.NORIENT * phase / (2 * np.pi)
|
|
2410
|
+
iph = np.floor(phase_scaled).astype("int") # lower bin index
|
|
2411
|
+
delta = phase_scaled - iph # fractional part in [0,1)
|
|
2412
|
+
# interpolation weights
|
|
2413
|
+
w0 = np.cos(delta * np.pi / 2) ** 2
|
|
2414
|
+
w1 = np.sin(delta * np.pi / 2) ** 2
|
|
2415
|
+
|
|
2416
|
+
# build rotation matrix
|
|
2417
|
+
mat = np.zeros([sim.shape[1], self.NORIENT * self.NORIENT])
|
|
2399
2418
|
lidx = np.arange(sim.shape[1])
|
|
2400
|
-
for
|
|
2401
|
-
|
|
2402
|
-
|
|
2419
|
+
for l in range(self.NORIENT):
|
|
2420
|
+
# Instead of simple linear weights, we use the cosine weights w0 and w1.
|
|
2421
|
+
col0 = self.NORIENT * ((l + iph) % self.NORIENT) + l
|
|
2422
|
+
col1 = self.NORIENT * ((l + iph + 1) % self.NORIENT) + l
|
|
2423
|
+
mat[lidx, col0] = w0
|
|
2424
|
+
mat[lidx, col1] = w1
|
|
2403
2425
|
|
|
2404
2426
|
cmat[k] = self.backend.bk_cast(mat.astype("complex64"))
|
|
2405
2427
|
|
|
2406
|
-
|
|
2428
|
+
# do same modifications for mat2
|
|
2429
|
+
mat2 = np.zeros(
|
|
2430
|
+
[k + 1, sim.shape[1], self.NORIENT, self.NORIENT * self.NORIENT]
|
|
2431
|
+
)
|
|
2407
2432
|
|
|
2408
2433
|
for k2 in range(k + 1):
|
|
2409
|
-
tmp2 = self.backend.bk_repeat(sim,
|
|
2434
|
+
tmp2 = self.backend.bk_repeat(sim, self.NORIENT, axis=-1)
|
|
2435
|
+
|
|
2410
2436
|
sim2 = self.backend.bk_reduce_sum(
|
|
2411
2437
|
self.backend.bk_reshape(
|
|
2412
|
-
mat.reshape(1, mat.shape[0],
|
|
2413
|
-
|
|
2438
|
+
mat.reshape(1, mat.shape[0], self.NORIENT * self.NORIENT)
|
|
2439
|
+
* tmp2,
|
|
2440
|
+
[sim.shape[0], cmat[k].shape[0], self.NORIENT, self.NORIENT],
|
|
2414
2441
|
),
|
|
2415
2442
|
2,
|
|
2416
2443
|
)
|
|
2444
|
+
|
|
2417
2445
|
sim2 = self.backend.bk_abs(self.convol(sim2, axis=1))
|
|
2418
2446
|
|
|
2419
|
-
|
|
2420
|
-
|
|
2447
|
+
weighted_cos2 = self.backend.bk_reduce_mean(
|
|
2448
|
+
sim2 * np.cos(angles), axis=-1
|
|
2421
2449
|
)
|
|
2422
|
-
|
|
2423
|
-
|
|
2450
|
+
weighted_sin2 = self.backend.bk_reduce_mean(
|
|
2451
|
+
sim2 * np.sin(angles), axis=-1
|
|
2424
2452
|
)
|
|
2425
|
-
|
|
2426
|
-
|
|
2427
|
-
|
|
2428
|
-
|
|
2429
|
-
if
|
|
2453
|
+
|
|
2454
|
+
cc2 = weighted_cos2[0]
|
|
2455
|
+
ss2 = weighted_sin2[0]
|
|
2456
|
+
|
|
2457
|
+
if smooth_scale > 0:
|
|
2458
|
+
for m in range(smooth_scale):
|
|
2459
|
+
if cc2.shape[0] > 12:
|
|
2460
|
+
cc2 = self.ud_grade_2(self.smooth(cc2))
|
|
2461
|
+
ss2 = self.ud_grade_2(self.smooth(ss2))
|
|
2462
|
+
|
|
2463
|
+
if cc2.shape[0] != sim.shape[1]:
|
|
2430
2464
|
ll_nside = int(np.sqrt(sim.shape[1] // 12))
|
|
2431
|
-
|
|
2432
|
-
|
|
2465
|
+
cc2 = self.up_grade(cc2, ll_nside)
|
|
2466
|
+
ss2 = self.up_grade(ss2, ll_nside)
|
|
2433
2467
|
|
|
2434
2468
|
if self.BACKEND == "numpy":
|
|
2435
|
-
|
|
2469
|
+
phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
|
|
2436
2470
|
else:
|
|
2437
|
-
|
|
2438
|
-
np.arctan2(
|
|
2471
|
+
phase2 = np.fmod(
|
|
2472
|
+
np.arctan2(ss2.numpy(), cc2.numpy()) + 2 * np.pi, 2 * np.pi
|
|
2439
2473
|
)
|
|
2440
|
-
|
|
2441
|
-
|
|
2442
|
-
|
|
2443
|
-
|
|
2444
|
-
|
|
2445
|
-
|
|
2446
|
-
iph = (4 * phase / (2 * np.pi)).astype("int")
|
|
2447
|
-
alpha = 4 * phase / (2 * np.pi) - iph
|
|
2474
|
+
|
|
2475
|
+
phase2_scaled = self.NORIENT * phase2 / (2 * np.pi)
|
|
2476
|
+
iph2 = np.floor(phase2_scaled).astype("int")
|
|
2477
|
+
delta2 = phase2_scaled - iph2
|
|
2478
|
+
w0_2 = np.cos(delta2 * np.pi / 2) ** 2
|
|
2479
|
+
w1_2 = np.sin(delta2 * np.pi / 2) ** 2
|
|
2448
2480
|
lidx = np.arange(sim.shape[1])
|
|
2449
|
-
|
|
2450
|
-
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
2454
|
-
mat2[
|
|
2455
|
-
|
|
2456
|
-
|
|
2457
|
-
|
|
2458
|
-
cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
|
|
2459
|
-
"""
|
|
2460
|
-
tmp=self.backend.bk_repeat(sim[0],4,axis=1)
|
|
2461
|
-
sim2=self.backend.bk_reduce_sum(self.backend.bk_reshape(mat*tmp,[12*nside**2,4,4]),1)
|
|
2462
|
-
|
|
2463
|
-
cc2=(sim2[:,0]-sim2[:,2])
|
|
2464
|
-
ss2=(sim2[:,1]-sim2[:,3])
|
|
2465
|
-
phase2=np.fmod(np.arctan2(ss2.numpy(),cc2.numpy())+2*np.pi,2*np.pi)
|
|
2466
|
-
|
|
2467
|
-
plt.figure()
|
|
2468
|
-
hp.mollview(phase,cmap='jet',nest=True,hold=False,sub=(2,2,1))
|
|
2469
|
-
hp.mollview(np.fmod(phase2+np.pi,2*np.pi),cmap='jet',nest=True,hold=False,sub=(2,2,2))
|
|
2470
|
-
plt.figure()
|
|
2471
|
-
for k in range(4):
|
|
2472
|
-
hp.mollview((sim[0,:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,1+k),min=-10,max=10)
|
|
2473
|
-
hp.mollview((sim2[:,k]).numpy().real,cmap='jet',nest=True,hold=False,sub=(2,4,5+k),min=-10,max=10)
|
|
2474
|
-
|
|
2475
|
-
plt.show()
|
|
2476
|
-
"""
|
|
2481
|
+
|
|
2482
|
+
for m in range(self.NORIENT):
|
|
2483
|
+
for l in range(self.NORIENT):
|
|
2484
|
+
col0 = self.NORIENT * ((l + iph2[:, m]) % self.NORIENT) + l
|
|
2485
|
+
col1 = self.NORIENT * ((l + iph2[:, m] + 1) % self.NORIENT) + l
|
|
2486
|
+
mat2[k2, lidx, m, col0] = w0_2[:, m]
|
|
2487
|
+
mat2[k2, lidx, m, col1] = w1_2[:, m]
|
|
2488
|
+
cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
|
|
2477
2489
|
|
|
2478
2490
|
if k < l_nside - 1:
|
|
2479
2491
|
tmp = self.ud_grade_2(tmp, axis=1)
|
|
@@ -2493,11 +2505,12 @@ class funct(FOC.FoCUS):
|
|
|
2493
2505
|
image2=None,
|
|
2494
2506
|
mask=None,
|
|
2495
2507
|
norm=None,
|
|
2496
|
-
Auto=True,
|
|
2497
2508
|
calc_var=False,
|
|
2498
2509
|
cmat=None,
|
|
2499
2510
|
cmat2=None,
|
|
2500
|
-
|
|
2511
|
+
Jmax=None,
|
|
2512
|
+
out_nside=None,
|
|
2513
|
+
edge=True
|
|
2501
2514
|
):
|
|
2502
2515
|
"""
|
|
2503
2516
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
@@ -2523,9 +2536,6 @@ class funct(FOC.FoCUS):
|
|
|
2523
2536
|
norm: None or str
|
|
2524
2537
|
If None no normalization is applied, if 'auto' normalize by the reference S2,
|
|
2525
2538
|
if 'self' normalize by the current S2.
|
|
2526
|
-
all_cross: False or True
|
|
2527
|
-
If False compute all the coefficient even the Imaginary part,
|
|
2528
|
-
If True return only the terms computable in the auto case.
|
|
2529
2539
|
Returns
|
|
2530
2540
|
-------
|
|
2531
2541
|
S1, S2, S3, S4 normalized
|
|
@@ -2539,15 +2549,26 @@ class funct(FOC.FoCUS):
|
|
|
2539
2549
|
)
|
|
2540
2550
|
return None
|
|
2541
2551
|
if mask is not None:
|
|
2542
|
-
if
|
|
2543
|
-
|
|
2544
|
-
|
|
2545
|
-
|
|
2546
|
-
|
|
2547
|
-
|
|
2548
|
-
|
|
2549
|
-
|
|
2550
|
-
|
|
2552
|
+
if self.use_2D:
|
|
2553
|
+
if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
|
|
2554
|
+
print(
|
|
2555
|
+
"The LAST 2 COLUMNs of the mask should have the same size ",
|
|
2556
|
+
mask.shape,
|
|
2557
|
+
"than the input image ",
|
|
2558
|
+
image1.shape,
|
|
2559
|
+
"to eval Scattering Covariance",
|
|
2560
|
+
)
|
|
2561
|
+
return None
|
|
2562
|
+
else:
|
|
2563
|
+
if image1.shape[-1] != mask.shape[1]:
|
|
2564
|
+
print(
|
|
2565
|
+
"The LAST COLUMN of the mask should have the same size ",
|
|
2566
|
+
mask.shape,
|
|
2567
|
+
"than the input image ",
|
|
2568
|
+
image1.shape,
|
|
2569
|
+
"to eval Scattering Covariance",
|
|
2570
|
+
)
|
|
2571
|
+
return None
|
|
2551
2572
|
if self.use_2D and len(image1.shape) < 2:
|
|
2552
2573
|
print(
|
|
2553
2574
|
"To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
|
|
@@ -2558,9 +2579,6 @@ class funct(FOC.FoCUS):
|
|
|
2558
2579
|
cross = False
|
|
2559
2580
|
if image2 is not None:
|
|
2560
2581
|
cross = True
|
|
2561
|
-
all_cross = Auto
|
|
2562
|
-
else:
|
|
2563
|
-
all_cross = False
|
|
2564
2582
|
|
|
2565
2583
|
### PARAMETERS
|
|
2566
2584
|
axis = 1
|
|
@@ -2596,8 +2614,16 @@ class funct(FOC.FoCUS):
|
|
|
2596
2614
|
nside = int(np.sqrt(npix // 12))
|
|
2597
2615
|
|
|
2598
2616
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2599
|
-
|
|
2600
|
-
|
|
2617
|
+
|
|
2618
|
+
if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
|
|
2619
|
+
J-=1
|
|
2620
|
+
if Jmax is None:
|
|
2621
|
+
Jmax = J # Number of steps for the loop on scales
|
|
2622
|
+
if Jmax>J:
|
|
2623
|
+
print('==========\n\n')
|
|
2624
|
+
print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
|
|
2625
|
+
print('\n\n==========')
|
|
2626
|
+
|
|
2601
2627
|
|
|
2602
2628
|
### LOCAL VARIABLES (IMAGES and MASK)
|
|
2603
2629
|
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
|
|
@@ -2621,7 +2647,7 @@ class funct(FOC.FoCUS):
|
|
|
2621
2647
|
else:
|
|
2622
2648
|
vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
|
|
2623
2649
|
|
|
2624
|
-
if self.KERNELSZ > 3:
|
|
2650
|
+
if self.KERNELSZ > 3 and not self.use_2D:
|
|
2625
2651
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2626
2652
|
if self.use_2D:
|
|
2627
2653
|
vmask = self.up_grade(
|
|
@@ -2645,7 +2671,7 @@ class funct(FOC.FoCUS):
|
|
|
2645
2671
|
if cross:
|
|
2646
2672
|
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
2647
2673
|
|
|
2648
|
-
if self.KERNELSZ > 5:
|
|
2674
|
+
if self.KERNELSZ > 5 and not self.use_2D:
|
|
2649
2675
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2650
2676
|
if self.use_2D:
|
|
2651
2677
|
vmask = self.up_grade(
|
|
@@ -2677,7 +2703,23 @@ class funct(FOC.FoCUS):
|
|
|
2677
2703
|
|
|
2678
2704
|
### INITIALIZATION
|
|
2679
2705
|
# Coefficients
|
|
2680
|
-
|
|
2706
|
+
if return_data:
|
|
2707
|
+
S1 = {}
|
|
2708
|
+
S2 = {}
|
|
2709
|
+
S3 = {}
|
|
2710
|
+
S3P = {}
|
|
2711
|
+
S4 = {}
|
|
2712
|
+
else:
|
|
2713
|
+
S1 = []
|
|
2714
|
+
S2 = []
|
|
2715
|
+
S3 = []
|
|
2716
|
+
S4 = []
|
|
2717
|
+
S3P = []
|
|
2718
|
+
VS1 = []
|
|
2719
|
+
VS2 = []
|
|
2720
|
+
VS3 = []
|
|
2721
|
+
VS3P = []
|
|
2722
|
+
VS4 = []
|
|
2681
2723
|
|
|
2682
2724
|
off_S2 = -2
|
|
2683
2725
|
off_S3 = -3
|
|
@@ -2687,11 +2729,6 @@ class funct(FOC.FoCUS):
|
|
|
2687
2729
|
off_S3 = -1
|
|
2688
2730
|
off_S4 = -1
|
|
2689
2731
|
|
|
2690
|
-
# Dictionaries for S3 computation
|
|
2691
|
-
M1_dic = {} # M stands for Module M1 = |I1 * Psi|
|
|
2692
|
-
if cross:
|
|
2693
|
-
M2_dic = {}
|
|
2694
|
-
|
|
2695
2732
|
# S2 for normalization
|
|
2696
2733
|
cond_init_P1_dic = (norm == "self") or (
|
|
2697
2734
|
(norm == "auto") and (self.P1_dic is None)
|
|
@@ -2710,7 +2747,12 @@ class funct(FOC.FoCUS):
|
|
|
2710
2747
|
if return_data:
|
|
2711
2748
|
s0 = I1
|
|
2712
2749
|
if out_nside is not None:
|
|
2713
|
-
s0 = self.backend.bk_reduce_mean(
|
|
2750
|
+
s0 = self.backend.bk_reduce_mean(
|
|
2751
|
+
self.backend.bk_reshape(
|
|
2752
|
+
s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
|
|
2753
|
+
),
|
|
2754
|
+
2,
|
|
2755
|
+
)
|
|
2714
2756
|
else:
|
|
2715
2757
|
if not cross:
|
|
2716
2758
|
s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
|
|
@@ -2720,17 +2762,38 @@ class funct(FOC.FoCUS):
|
|
|
2720
2762
|
)
|
|
2721
2763
|
vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
|
|
2722
2764
|
s0 = self.backend.bk_concat([s0, l_vs0], 1)
|
|
2723
|
-
|
|
2724
2765
|
#### COMPUTE S1, S2, S3 and S4
|
|
2725
2766
|
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
2767
|
+
|
|
2768
|
+
# a remettre comme avant
|
|
2769
|
+
M1_dic={}
|
|
2770
|
+
M2_dic={}
|
|
2771
|
+
|
|
2726
2772
|
for j3 in range(Jmax):
|
|
2773
|
+
|
|
2774
|
+
if edge:
|
|
2775
|
+
if self.mask_mask is None:
|
|
2776
|
+
self.mask_mask={}
|
|
2777
|
+
if self.use_2D:
|
|
2778
|
+
if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
|
|
2779
|
+
mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
|
|
2780
|
+
mask_mask[0,
|
|
2781
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1,
|
|
2782
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
2783
|
+
self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
|
|
2784
|
+
vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
|
|
2785
|
+
#print(self.KERNELSZ//2,vmask,mask_mask)
|
|
2786
|
+
|
|
2787
|
+
if self.use_1D:
|
|
2788
|
+
if (vmask.shape[1]) not in self.mask_mask:
|
|
2789
|
+
mask_mask=np.zeros([1,vmask.shape[1]])
|
|
2790
|
+
mask_mask[0,
|
|
2791
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
2792
|
+
self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
|
|
2793
|
+
vmask=vmask*self.mask_mask[(vmask.shape[1])]
|
|
2794
|
+
|
|
2727
2795
|
if return_data:
|
|
2728
|
-
if S3 is None:
|
|
2729
|
-
S3 = {}
|
|
2730
2796
|
S3[j3] = None
|
|
2731
|
-
|
|
2732
|
-
if S3P is None:
|
|
2733
|
-
S3P = {}
|
|
2734
2797
|
S3P[j3] = None
|
|
2735
2798
|
|
|
2736
2799
|
if S4 is None:
|
|
@@ -2740,12 +2803,12 @@ class funct(FOC.FoCUS):
|
|
|
2740
2803
|
####### S1 and S2
|
|
2741
2804
|
### Make the convolution I1 * Psi_j3
|
|
2742
2805
|
conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
2743
|
-
|
|
2744
2806
|
if cmat is not None:
|
|
2745
|
-
tmp2 = self.backend.bk_repeat(conv1,
|
|
2807
|
+
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
|
|
2746
2808
|
conv1 = self.backend.bk_reduce_sum(
|
|
2747
2809
|
self.backend.bk_reshape(
|
|
2748
|
-
cmat[j3] * tmp2,
|
|
2810
|
+
cmat[j3] * tmp2,
|
|
2811
|
+
[tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
|
|
2749
2812
|
),
|
|
2750
2813
|
2,
|
|
2751
2814
|
)
|
|
@@ -2781,33 +2844,31 @@ class funct(FOC.FoCUS):
|
|
|
2781
2844
|
if return_data:
|
|
2782
2845
|
if S2 is None:
|
|
2783
2846
|
S2 = {}
|
|
2784
|
-
if out_nside is not None and out_nside<nside_j3:
|
|
2847
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
2785
2848
|
s2 = self.backend.bk_reduce_mean(
|
|
2786
|
-
self.backend.bk_reshape(
|
|
2787
|
-
|
|
2788
|
-
|
|
2789
|
-
|
|
2849
|
+
self.backend.bk_reshape(
|
|
2850
|
+
s2,
|
|
2851
|
+
[
|
|
2852
|
+
s2.shape[0],
|
|
2853
|
+
12 * out_nside**2,
|
|
2854
|
+
(nside_j3 // out_nside) ** 2,
|
|
2855
|
+
s2.shape[2],
|
|
2856
|
+
],
|
|
2857
|
+
),
|
|
2858
|
+
2,
|
|
2859
|
+
)
|
|
2790
2860
|
S2[j3] = s2
|
|
2791
2861
|
else:
|
|
2792
2862
|
if norm == "auto": # Normalize S2
|
|
2793
2863
|
s2 /= P1_dic[j3]
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2864
|
+
|
|
2865
|
+
S2.append(
|
|
2866
|
+
self.backend.bk_expand_dims(s2, off_S2)
|
|
2867
|
+
) # Add a dimension for NS2
|
|
2868
|
+
if calc_var:
|
|
2869
|
+
VS2.append(
|
|
2870
|
+
self.backend.bk_expand_dims(vs2, off_S2)
|
|
2797
2871
|
) # Add a dimension for NS2
|
|
2798
|
-
if calc_var:
|
|
2799
|
-
VS2 = self.backend.bk_expand_dims(
|
|
2800
|
-
vs2, off_S2
|
|
2801
|
-
) # Add a dimension for NS2
|
|
2802
|
-
else:
|
|
2803
|
-
S2 = self.backend.bk_concat(
|
|
2804
|
-
[S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
|
|
2805
|
-
)
|
|
2806
|
-
if calc_var:
|
|
2807
|
-
VS2 = self.backend.bk_concat(
|
|
2808
|
-
[VS2, self.backend.bk_expand_dims(vs2, off_S2)],
|
|
2809
|
-
axis=2,
|
|
2810
|
-
)
|
|
2811
2872
|
|
|
2812
2873
|
#### S1_auto computation
|
|
2813
2874
|
### Image 1 : S1 = < M1 >_pix
|
|
@@ -2825,45 +2886,47 @@ class funct(FOC.FoCUS):
|
|
|
2825
2886
|
) # [Nbatch, Nmask, Norient3]
|
|
2826
2887
|
|
|
2827
2888
|
if return_data:
|
|
2828
|
-
if
|
|
2829
|
-
S1 = {}
|
|
2830
|
-
if out_nside is not None and out_nside<nside_j3:
|
|
2889
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
2831
2890
|
s1 = self.backend.bk_reduce_mean(
|
|
2832
|
-
self.backend.bk_reshape(
|
|
2833
|
-
|
|
2834
|
-
|
|
2835
|
-
|
|
2891
|
+
self.backend.bk_reshape(
|
|
2892
|
+
s1,
|
|
2893
|
+
[
|
|
2894
|
+
s1.shape[0],
|
|
2895
|
+
12 * out_nside**2,
|
|
2896
|
+
(nside_j3 // out_nside) ** 2,
|
|
2897
|
+
s1.shape[2],
|
|
2898
|
+
],
|
|
2899
|
+
),
|
|
2900
|
+
2,
|
|
2901
|
+
)
|
|
2836
2902
|
S1[j3] = s1
|
|
2837
2903
|
else:
|
|
2838
2904
|
### Normalize S1
|
|
2839
2905
|
if norm is not None:
|
|
2840
2906
|
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
2841
2907
|
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
2842
|
-
|
|
2843
|
-
|
|
2844
|
-
|
|
2908
|
+
S1.append(
|
|
2909
|
+
self.backend.bk_expand_dims(s1, off_S2)
|
|
2910
|
+
) # Add a dimension for NS1
|
|
2911
|
+
if calc_var:
|
|
2912
|
+
VS1.append(
|
|
2913
|
+
self.backend.bk_expand_dims(vs1, off_S2)
|
|
2845
2914
|
) # Add a dimension for NS1
|
|
2846
|
-
if calc_var:
|
|
2847
|
-
VS1 = self.backend.bk_expand_dims(
|
|
2848
|
-
vs1, off_S2
|
|
2849
|
-
) # Add a dimension for NS1
|
|
2850
|
-
else:
|
|
2851
|
-
S1 = self.backend.bk_concat(
|
|
2852
|
-
[S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
|
|
2853
|
-
)
|
|
2854
|
-
if calc_var:
|
|
2855
|
-
VS1 = self.backend.bk_concat(
|
|
2856
|
-
[VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
|
|
2857
|
-
)
|
|
2858
2915
|
|
|
2859
2916
|
else: # Cross
|
|
2860
2917
|
### Make the convolution I2 * Psi_j3
|
|
2861
2918
|
conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
2862
2919
|
if cmat is not None:
|
|
2863
|
-
tmp2 = self.backend.bk_repeat(conv2,
|
|
2920
|
+
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
|
|
2864
2921
|
conv2 = self.backend.bk_reduce_sum(
|
|
2865
2922
|
self.backend.bk_reshape(
|
|
2866
|
-
cmat[j3] * tmp2,
|
|
2923
|
+
cmat[j3] * tmp2,
|
|
2924
|
+
[
|
|
2925
|
+
tmp2.shape[0],
|
|
2926
|
+
cmat[j3].shape[0],
|
|
2927
|
+
self.NORIENT,
|
|
2928
|
+
self.NORIENT,
|
|
2929
|
+
],
|
|
2867
2930
|
),
|
|
2868
2931
|
2,
|
|
2869
2932
|
)
|
|
@@ -2917,14 +2980,19 @@ class funct(FOC.FoCUS):
|
|
|
2917
2980
|
s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
|
|
2918
2981
|
|
|
2919
2982
|
if return_data:
|
|
2920
|
-
if
|
|
2921
|
-
S2 = {}
|
|
2922
|
-
if out_nside is not None and out_nside<nside_j3:
|
|
2983
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
2923
2984
|
s2 = self.backend.bk_reduce_mean(
|
|
2924
|
-
self.backend.bk_reshape(
|
|
2925
|
-
|
|
2926
|
-
|
|
2927
|
-
|
|
2985
|
+
self.backend.bk_reshape(
|
|
2986
|
+
s2,
|
|
2987
|
+
[
|
|
2988
|
+
s2.shape[0],
|
|
2989
|
+
12 * out_nside**2,
|
|
2990
|
+
(nside_j3 // out_nside) ** 2,
|
|
2991
|
+
s2.shape[2],
|
|
2992
|
+
],
|
|
2993
|
+
),
|
|
2994
|
+
2,
|
|
2995
|
+
)
|
|
2928
2996
|
S2[j3] = s2
|
|
2929
2997
|
else:
|
|
2930
2998
|
### Normalize S2_cross
|
|
@@ -2932,26 +3000,15 @@ class funct(FOC.FoCUS):
|
|
|
2932
3000
|
s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
|
|
2933
3001
|
|
|
2934
3002
|
### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
|
|
2935
|
-
|
|
2936
|
-
s2 = self.backend.bk_real(s2)
|
|
3003
|
+
s2 = self.backend.bk_real(s2)
|
|
2937
3004
|
|
|
2938
|
-
|
|
2939
|
-
|
|
2940
|
-
|
|
3005
|
+
S2.append(
|
|
3006
|
+
self.backend.bk_expand_dims(s2, off_S2)
|
|
3007
|
+
) # Add a dimension for NS2
|
|
3008
|
+
if calc_var:
|
|
3009
|
+
VS2.append(
|
|
3010
|
+
self.backend.bk_expand_dims(vs2, off_S2)
|
|
2941
3011
|
) # Add a dimension for NS2
|
|
2942
|
-
if calc_var:
|
|
2943
|
-
VS2 = self.backend.bk_expand_dims(
|
|
2944
|
-
vs2, off_S2
|
|
2945
|
-
) # Add a dimension for NS2
|
|
2946
|
-
else:
|
|
2947
|
-
S2 = self.backend.bk_concat(
|
|
2948
|
-
[S2, self.backend.bk_expand_dims(s2, off_S2)], axis=2
|
|
2949
|
-
)
|
|
2950
|
-
if calc_var:
|
|
2951
|
-
VS2 = self.backend.bk_concat(
|
|
2952
|
-
[VS2, self.backend.bk_expand_dims(vs2, off_S2)],
|
|
2953
|
-
axis=2,
|
|
2954
|
-
)
|
|
2955
3012
|
|
|
2956
3013
|
#### S1_auto computation
|
|
2957
3014
|
### Image 1 : S1 = < M1 >_pix
|
|
@@ -2968,36 +3025,32 @@ class funct(FOC.FoCUS):
|
|
|
2968
3025
|
MX, vmask, axis=1, rank=j3
|
|
2969
3026
|
) # [Nbatch, Nmask, Norient3]
|
|
2970
3027
|
if return_data:
|
|
2971
|
-
if
|
|
2972
|
-
S1 = {}
|
|
2973
|
-
if out_nside is not None and out_nside<nside_j3:
|
|
3028
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
2974
3029
|
s1 = self.backend.bk_reduce_mean(
|
|
2975
|
-
self.backend.bk_reshape(
|
|
2976
|
-
|
|
2977
|
-
|
|
2978
|
-
|
|
3030
|
+
self.backend.bk_reshape(
|
|
3031
|
+
s1,
|
|
3032
|
+
[
|
|
3033
|
+
s1.shape[0],
|
|
3034
|
+
12 * out_nside**2,
|
|
3035
|
+
(nside_j3 // out_nside) ** 2,
|
|
3036
|
+
s1.shape[2],
|
|
3037
|
+
],
|
|
3038
|
+
),
|
|
3039
|
+
2,
|
|
3040
|
+
)
|
|
2979
3041
|
S1[j3] = s1
|
|
2980
3042
|
else:
|
|
2981
3043
|
### Normalize S1
|
|
2982
3044
|
if norm is not None:
|
|
2983
3045
|
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
2984
3046
|
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
2985
|
-
|
|
2986
|
-
|
|
2987
|
-
|
|
3047
|
+
S1.append(
|
|
3048
|
+
self.backend.bk_expand_dims(s1, off_S2)
|
|
3049
|
+
) # Add a dimension for NS1
|
|
3050
|
+
if calc_var:
|
|
3051
|
+
VS1.append(
|
|
3052
|
+
self.backend.bk_expand_dims(vs1, off_S2)
|
|
2988
3053
|
) # Add a dimension for NS1
|
|
2989
|
-
if calc_var:
|
|
2990
|
-
VS1 = self.backend.bk_expand_dims(
|
|
2991
|
-
vs1, off_S2
|
|
2992
|
-
) # Add a dimension for NS1
|
|
2993
|
-
else:
|
|
2994
|
-
S1 = self.backend.bk_concat(
|
|
2995
|
-
[S1, self.backend.bk_expand_dims(s1, off_S2)], axis=2
|
|
2996
|
-
)
|
|
2997
|
-
if calc_var:
|
|
2998
|
-
VS1 = self.backend.bk_concat(
|
|
2999
|
-
[VS1, self.backend.bk_expand_dims(vs1, off_S2)], axis=2
|
|
3000
|
-
)
|
|
3001
3054
|
|
|
3002
3055
|
# Initialize dictionaries for |I1*Psi_j| * Psi_j3
|
|
3003
3056
|
M1convPsi_dic = {}
|
|
@@ -3006,7 +3059,7 @@ class funct(FOC.FoCUS):
|
|
|
3006
3059
|
M2convPsi_dic = {}
|
|
3007
3060
|
|
|
3008
3061
|
###### S3
|
|
3009
|
-
nside_j2=nside_j3
|
|
3062
|
+
nside_j2 = nside_j3
|
|
3010
3063
|
for j2 in range(0, j3 + 1): # j2 <= j3
|
|
3011
3064
|
if return_data:
|
|
3012
3065
|
if S4[j3] is None:
|
|
@@ -3041,13 +3094,20 @@ class funct(FOC.FoCUS):
|
|
|
3041
3094
|
if return_data:
|
|
3042
3095
|
if S3[j3] is None:
|
|
3043
3096
|
S3[j3] = {}
|
|
3044
|
-
if out_nside is not None and out_nside<nside_j2:
|
|
3097
|
+
if out_nside is not None and out_nside < nside_j2:
|
|
3045
3098
|
s3 = self.backend.bk_reduce_mean(
|
|
3046
|
-
self.backend.bk_reshape(
|
|
3047
|
-
|
|
3048
|
-
|
|
3049
|
-
|
|
3050
|
-
|
|
3099
|
+
self.backend.bk_reshape(
|
|
3100
|
+
s3,
|
|
3101
|
+
[
|
|
3102
|
+
s3.shape[0],
|
|
3103
|
+
12 * out_nside**2,
|
|
3104
|
+
(nside_j2 // out_nside) ** 2,
|
|
3105
|
+
s3.shape[2],
|
|
3106
|
+
s3.shape[3],
|
|
3107
|
+
],
|
|
3108
|
+
),
|
|
3109
|
+
2,
|
|
3110
|
+
)
|
|
3051
3111
|
S3[j3][j2] = s3
|
|
3052
3112
|
else:
|
|
3053
3113
|
### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
|
|
@@ -3062,23 +3122,18 @@ class funct(FOC.FoCUS):
|
|
|
3062
3122
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3063
3123
|
|
|
3064
3124
|
### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
3065
|
-
|
|
3066
|
-
|
|
3067
|
-
|
|
3068
|
-
|
|
3069
|
-
|
|
3070
|
-
|
|
3071
|
-
|
|
3072
|
-
|
|
3073
|
-
|
|
3074
|
-
S3 = self.backend.bk_concat(
|
|
3075
|
-
[S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
|
|
3125
|
+
|
|
3126
|
+
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
3127
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3128
|
+
S3.append(
|
|
3129
|
+
self.backend.bk_expand_dims(s3, off_S3)
|
|
3130
|
+
) # Add a dimension for NS3
|
|
3131
|
+
if calc_var:
|
|
3132
|
+
VS3.append(
|
|
3133
|
+
self.backend.bk_expand_dims(vs3, off_S3)
|
|
3076
3134
|
) # Add a dimension for NS3
|
|
3077
|
-
|
|
3078
|
-
|
|
3079
|
-
[VS3, self.backend.bk_expand_dims(vs3, off_S3)],
|
|
3080
|
-
axis=2,
|
|
3081
|
-
) # Add a dimension for NS3
|
|
3135
|
+
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
3136
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3082
3137
|
|
|
3083
3138
|
### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
|
|
3084
3139
|
### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
@@ -3130,19 +3185,33 @@ class funct(FOC.FoCUS):
|
|
|
3130
3185
|
if S3[j3] is None:
|
|
3131
3186
|
S3[j3] = {}
|
|
3132
3187
|
S3P[j3] = {}
|
|
3133
|
-
if out_nside is not None and out_nside<nside_j2:
|
|
3188
|
+
if out_nside is not None and out_nside < nside_j2:
|
|
3134
3189
|
s3 = self.backend.bk_reduce_mean(
|
|
3135
|
-
self.backend.bk_reshape(
|
|
3136
|
-
|
|
3137
|
-
|
|
3138
|
-
|
|
3139
|
-
|
|
3190
|
+
self.backend.bk_reshape(
|
|
3191
|
+
s3,
|
|
3192
|
+
[
|
|
3193
|
+
s3.shape[0],
|
|
3194
|
+
12 * out_nside**2,
|
|
3195
|
+
(nside_j2 // out_nside) ** 2,
|
|
3196
|
+
s3.shape[2],
|
|
3197
|
+
s3.shape[3],
|
|
3198
|
+
],
|
|
3199
|
+
),
|
|
3200
|
+
2,
|
|
3201
|
+
)
|
|
3140
3202
|
s3p = self.backend.bk_reduce_mean(
|
|
3141
|
-
self.backend.bk_reshape(
|
|
3142
|
-
|
|
3143
|
-
|
|
3144
|
-
|
|
3145
|
-
|
|
3203
|
+
self.backend.bk_reshape(
|
|
3204
|
+
s3p,
|
|
3205
|
+
[
|
|
3206
|
+
s3.shape[0],
|
|
3207
|
+
12 * out_nside**2,
|
|
3208
|
+
(nside_j2 // out_nside) ** 2,
|
|
3209
|
+
s3.shape[2],
|
|
3210
|
+
s3.shape[3],
|
|
3211
|
+
],
|
|
3212
|
+
),
|
|
3213
|
+
2,
|
|
3214
|
+
)
|
|
3146
3215
|
S3[j3][j2] = s3
|
|
3147
3216
|
S3P[j3][j2] = s3p
|
|
3148
3217
|
else:
|
|
@@ -3166,43 +3235,34 @@ class funct(FOC.FoCUS):
|
|
|
3166
3235
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3167
3236
|
|
|
3168
3237
|
### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
3169
|
-
|
|
3170
|
-
|
|
3171
|
-
|
|
3172
|
-
|
|
3173
|
-
|
|
3174
|
-
|
|
3175
|
-
|
|
3176
|
-
|
|
3177
|
-
|
|
3178
|
-
S3 = self.backend.bk_concat(
|
|
3179
|
-
[S3, self.backend.bk_expand_dims(s3, off_S3)], axis=2
|
|
3180
|
-
) # Add a dimension for NS3
|
|
3181
|
-
if calc_var:
|
|
3182
|
-
VS3 = self.backend.bk_concat(
|
|
3183
|
-
[VS3, self.backend.bk_expand_dims(vs3, off_S3)],
|
|
3184
|
-
axis=2,
|
|
3185
|
-
) # Add a dimension for NS3
|
|
3186
|
-
if S3P is None:
|
|
3187
|
-
S3P = self.backend.bk_expand_dims(
|
|
3188
|
-
s3p, off_S3
|
|
3238
|
+
|
|
3239
|
+
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
3240
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3241
|
+
S3.append(
|
|
3242
|
+
self.backend.bk_expand_dims(s3, off_S3)
|
|
3243
|
+
) # Add a dimension for NS3
|
|
3244
|
+
if calc_var:
|
|
3245
|
+
VS3.append(
|
|
3246
|
+
self.backend.bk_expand_dims(vs3, off_S3)
|
|
3189
3247
|
) # Add a dimension for NS3
|
|
3190
|
-
|
|
3191
|
-
|
|
3192
|
-
|
|
3193
|
-
|
|
3194
|
-
|
|
3195
|
-
|
|
3196
|
-
|
|
3248
|
+
|
|
3249
|
+
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
3250
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3251
|
+
|
|
3252
|
+
# S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
|
|
3253
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3254
|
+
S3P.append(
|
|
3255
|
+
self.backend.bk_expand_dims(s3p, off_S3)
|
|
3256
|
+
) # Add a dimension for NS3
|
|
3257
|
+
if calc_var:
|
|
3258
|
+
VS3P.append(
|
|
3259
|
+
self.backend.bk_expand_dims(vs3p, off_S3)
|
|
3197
3260
|
) # Add a dimension for NS3
|
|
3198
|
-
|
|
3199
|
-
|
|
3200
|
-
[VS3P, self.backend.bk_expand_dims(vs3p, off_S3)],
|
|
3201
|
-
axis=2,
|
|
3202
|
-
) # Add a dimension for NS3
|
|
3261
|
+
# VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
|
|
3262
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3203
3263
|
|
|
3204
3264
|
##### S4
|
|
3205
|
-
nside_j1=nside_j2
|
|
3265
|
+
nside_j1 = nside_j2
|
|
3206
3266
|
for j1 in range(0, j2 + 1): # j1 <= j2
|
|
3207
3267
|
### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
|
|
3208
3268
|
if not cross:
|
|
@@ -3228,14 +3288,21 @@ class funct(FOC.FoCUS):
|
|
|
3228
3288
|
if return_data:
|
|
3229
3289
|
if S4[j3][j2] is None:
|
|
3230
3290
|
S4[j3][j2] = {}
|
|
3231
|
-
if out_nside is not None and out_nside<nside_j1:
|
|
3291
|
+
if out_nside is not None and out_nside < nside_j1:
|
|
3232
3292
|
s4 = self.backend.bk_reduce_mean(
|
|
3233
|
-
self.backend.bk_reshape(
|
|
3234
|
-
|
|
3235
|
-
|
|
3236
|
-
|
|
3237
|
-
|
|
3238
|
-
|
|
3293
|
+
self.backend.bk_reshape(
|
|
3294
|
+
s4,
|
|
3295
|
+
[
|
|
3296
|
+
s4.shape[0],
|
|
3297
|
+
12 * out_nside**2,
|
|
3298
|
+
(nside_j1 // out_nside) ** 2,
|
|
3299
|
+
s4.shape[2],
|
|
3300
|
+
s4.shape[3],
|
|
3301
|
+
s4.shape[4],
|
|
3302
|
+
],
|
|
3303
|
+
),
|
|
3304
|
+
2,
|
|
3305
|
+
)
|
|
3239
3306
|
S4[j3][j2][j1] = s4
|
|
3240
3307
|
else:
|
|
3241
3308
|
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
@@ -3259,27 +3326,18 @@ class funct(FOC.FoCUS):
|
|
|
3259
3326
|
** 0.5,
|
|
3260
3327
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3261
3328
|
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
3262
|
-
|
|
3263
|
-
|
|
3264
|
-
|
|
3265
|
-
|
|
3266
|
-
|
|
3267
|
-
|
|
3268
|
-
|
|
3269
|
-
|
|
3270
|
-
|
|
3271
|
-
|
|
3272
|
-
|
|
3273
|
-
axis=2,
|
|
3329
|
+
|
|
3330
|
+
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
3331
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3332
|
+
S4.append(
|
|
3333
|
+
self.backend.bk_expand_dims(s4, off_S4)
|
|
3334
|
+
) # Add a dimension for NS4
|
|
3335
|
+
if calc_var:
|
|
3336
|
+
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
3337
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3338
|
+
VS4.append(
|
|
3339
|
+
self.backend.bk_expand_dims(vs4, off_S4)
|
|
3274
3340
|
) # Add a dimension for NS4
|
|
3275
|
-
if calc_var:
|
|
3276
|
-
VS4 = self.backend.bk_concat(
|
|
3277
|
-
[
|
|
3278
|
-
VS4,
|
|
3279
|
-
self.backend.bk_expand_dims(vs4, off_S4),
|
|
3280
|
-
],
|
|
3281
|
-
axis=2,
|
|
3282
|
-
) # Add a dimension for NS4
|
|
3283
3341
|
|
|
3284
3342
|
### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
|
|
3285
3343
|
else:
|
|
@@ -3305,14 +3363,21 @@ class funct(FOC.FoCUS):
|
|
|
3305
3363
|
if return_data:
|
|
3306
3364
|
if S4[j3][j2] is None:
|
|
3307
3365
|
S4[j3][j2] = {}
|
|
3308
|
-
if out_nside is not None and out_nside<nside_j1:
|
|
3366
|
+
if out_nside is not None and out_nside < nside_j1:
|
|
3309
3367
|
s4 = self.backend.bk_reduce_mean(
|
|
3310
|
-
self.backend.bk_reshape(
|
|
3311
|
-
|
|
3312
|
-
|
|
3313
|
-
|
|
3314
|
-
|
|
3315
|
-
|
|
3368
|
+
self.backend.bk_reshape(
|
|
3369
|
+
s4,
|
|
3370
|
+
[
|
|
3371
|
+
s4.shape[0],
|
|
3372
|
+
12 * out_nside**2,
|
|
3373
|
+
(nside_j1 // out_nside) ** 2,
|
|
3374
|
+
s4.shape[2],
|
|
3375
|
+
s4.shape[3],
|
|
3376
|
+
s4.shape[4],
|
|
3377
|
+
],
|
|
3378
|
+
),
|
|
3379
|
+
2,
|
|
3380
|
+
)
|
|
3316
3381
|
S4[j3][j2][j1] = s4
|
|
3317
3382
|
else:
|
|
3318
3383
|
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
@@ -3336,41 +3401,33 @@ class funct(FOC.FoCUS):
|
|
|
3336
3401
|
** 0.5,
|
|
3337
3402
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3338
3403
|
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
3339
|
-
|
|
3340
|
-
|
|
3341
|
-
|
|
3342
|
-
)
|
|
3343
|
-
|
|
3344
|
-
|
|
3345
|
-
|
|
3346
|
-
|
|
3347
|
-
|
|
3348
|
-
|
|
3349
|
-
|
|
3350
|
-
axis=2,
|
|
3404
|
+
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
3405
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3406
|
+
S4.append(
|
|
3407
|
+
self.backend.bk_expand_dims(s4, off_S4)
|
|
3408
|
+
) # Add a dimension for NS4
|
|
3409
|
+
if calc_var:
|
|
3410
|
+
|
|
3411
|
+
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
3412
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3413
|
+
VS4.append(
|
|
3414
|
+
self.backend.bk_expand_dims(vs4, off_S4)
|
|
3351
3415
|
) # Add a dimension for NS4
|
|
3352
|
-
|
|
3353
|
-
|
|
3354
|
-
|
|
3355
|
-
|
|
3356
|
-
self.backend.bk_expand_dims(vs4, off_S4),
|
|
3357
|
-
],
|
|
3358
|
-
axis=2,
|
|
3359
|
-
) # Add a dimension for NS4
|
|
3360
|
-
nside_j1=nside_j1 // 2
|
|
3361
|
-
nside_j2=nside_j2 // 2
|
|
3362
|
-
|
|
3416
|
+
|
|
3417
|
+
nside_j1 = nside_j1 // 2
|
|
3418
|
+
nside_j2 = nside_j2 // 2
|
|
3419
|
+
|
|
3363
3420
|
###### Reshape for next iteration on j3
|
|
3364
3421
|
### Image I1,
|
|
3365
3422
|
# downscale the I1 [Nbatch, Npix_j3]
|
|
3366
3423
|
if j3 != Jmax - 1:
|
|
3367
|
-
|
|
3368
|
-
I1 = self.ud_grade_2(
|
|
3424
|
+
I1 = self.smooth(I1, axis=1)
|
|
3425
|
+
I1 = self.ud_grade_2(I1, axis=1)
|
|
3369
3426
|
|
|
3370
3427
|
### Image I2
|
|
3371
3428
|
if cross:
|
|
3372
|
-
|
|
3373
|
-
I2 = self.ud_grade_2(
|
|
3429
|
+
I2 = self.smooth(I2, axis=1)
|
|
3430
|
+
I2 = self.ud_grade_2(I2, axis=1)
|
|
3374
3431
|
|
|
3375
3432
|
### Modules
|
|
3376
3433
|
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
@@ -3388,8 +3445,9 @@ class funct(FOC.FoCUS):
|
|
|
3388
3445
|
M2_dic[j2], axis=1
|
|
3389
3446
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3390
3447
|
M2_dic[j2] = self.ud_grade_2(
|
|
3391
|
-
|
|
3448
|
+
M2, axis=1
|
|
3392
3449
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3450
|
+
|
|
3393
3451
|
### Mask
|
|
3394
3452
|
vmask = self.ud_grade_2(vmask, axis=1)
|
|
3395
3453
|
|
|
@@ -3404,7 +3462,33 @@ class funct(FOC.FoCUS):
|
|
|
3404
3462
|
self.P1_dic = P1_dic
|
|
3405
3463
|
if cross:
|
|
3406
3464
|
self.P2_dic = P2_dic
|
|
3465
|
+
"""
|
|
3466
|
+
Sout=[s0]+S1+S2+S3+S4
|
|
3467
|
+
|
|
3468
|
+
if cross:
|
|
3469
|
+
Sout=Sout+S3P
|
|
3470
|
+
if calc_var:
|
|
3471
|
+
SVout=[vs0]+VS1+VS2+VS3+VS4
|
|
3472
|
+
if cross:
|
|
3473
|
+
VSout=VSout+VS3P
|
|
3474
|
+
return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
|
|
3407
3475
|
|
|
3476
|
+
return self.backend.bk_concat(Sout, 2)
|
|
3477
|
+
"""
|
|
3478
|
+
if not return_data:
|
|
3479
|
+
S1 = self.backend.bk_concat(S1, 2)
|
|
3480
|
+
S2 = self.backend.bk_concat(S2, 2)
|
|
3481
|
+
S3 = self.backend.bk_concat(S3, 2)
|
|
3482
|
+
S4 = self.backend.bk_concat(S4, 2)
|
|
3483
|
+
if cross:
|
|
3484
|
+
S3P = self.backend.bk_concat(S3P, 2)
|
|
3485
|
+
if calc_var:
|
|
3486
|
+
VS1 = self.backend.bk_concat(VS1, 2)
|
|
3487
|
+
VS2 = self.backend.bk_concat(VS2, 2)
|
|
3488
|
+
VS3 = self.backend.bk_concat(VS3, 2)
|
|
3489
|
+
VS4 = self.backend.bk_concat(VS4, 2)
|
|
3490
|
+
if cross:
|
|
3491
|
+
VS3P = self.backend.bk_concat(VS3P, 2)
|
|
3408
3492
|
if calc_var:
|
|
3409
3493
|
if not cross:
|
|
3410
3494
|
return scat_cov(
|
|
@@ -3455,70 +3539,1146 @@ class funct(FOC.FoCUS):
|
|
|
3455
3539
|
use_1D=self.use_1D,
|
|
3456
3540
|
)
|
|
3457
3541
|
|
|
3458
|
-
def
|
|
3459
|
-
|
|
3460
|
-
|
|
3461
|
-
|
|
3462
|
-
|
|
3463
|
-
|
|
3464
|
-
|
|
3465
|
-
|
|
3466
|
-
|
|
3467
|
-
|
|
3468
|
-
|
|
3469
|
-
|
|
3470
|
-
MconvPsi_dic,
|
|
3471
|
-
calc_var=False,
|
|
3472
|
-
return_data=False,
|
|
3473
|
-
cmat2=None,
|
|
3542
|
+
def eval_new(
|
|
3543
|
+
self,
|
|
3544
|
+
image1,
|
|
3545
|
+
image2=None,
|
|
3546
|
+
mask=None,
|
|
3547
|
+
norm=None,
|
|
3548
|
+
calc_var=False,
|
|
3549
|
+
cmat=None,
|
|
3550
|
+
cmat2=None,
|
|
3551
|
+
Jmax=None,
|
|
3552
|
+
out_nside=None,
|
|
3553
|
+
edge=True
|
|
3474
3554
|
):
|
|
3475
3555
|
"""
|
|
3476
|
-
|
|
3477
|
-
|
|
3556
|
+
Calculates the scattering correlations for a batch of images. Means are taken over pixels.
|
|
3557
|
+
mean of modulus:
|
|
3558
|
+
S1 = <|I * Psi_j3|>
|
|
3559
|
+
Normalization : take the log
|
|
3560
|
+
power spectrum:
|
|
3561
|
+
S2 = <|I * Psi_j3|^2>
|
|
3562
|
+
Normalization : take the log
|
|
3563
|
+
orig. x modulus:
|
|
3564
|
+
S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
|
|
3565
|
+
Normalization : divide by (S2_j2 * S2_j3)^0.5
|
|
3566
|
+
modulus x modulus:
|
|
3567
|
+
S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
|
|
3568
|
+
Normalization : divide by (S2_j1 * S2_j2)^0.5
|
|
3478
3569
|
Parameters
|
|
3479
3570
|
----------
|
|
3571
|
+
image1: tensor
|
|
3572
|
+
Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
|
|
3573
|
+
image2: tensor
|
|
3574
|
+
Second image. If not None, we compute cross-scattering covariance coefficients.
|
|
3575
|
+
mask:
|
|
3576
|
+
norm: None or str
|
|
3577
|
+
If None no normalization is applied, if 'auto' normalize by the reference S2,
|
|
3578
|
+
if 'self' normalize by the current S2.
|
|
3480
3579
|
Returns
|
|
3481
3580
|
-------
|
|
3482
|
-
|
|
3581
|
+
S1, S2, S3, S4 normalized
|
|
3483
3582
|
"""
|
|
3484
|
-
|
|
3485
|
-
|
|
3486
|
-
|
|
3487
|
-
|
|
3488
|
-
|
|
3489
|
-
|
|
3490
|
-
|
|
3491
|
-
|
|
3492
|
-
|
|
3493
|
-
|
|
3494
|
-
|
|
3495
|
-
|
|
3583
|
+
return_data = self.return_data
|
|
3584
|
+
NORIENT=self.NORIENT
|
|
3585
|
+
# Check input consistency
|
|
3586
|
+
if image2 is not None:
|
|
3587
|
+
if list(image1.shape) != list(image2.shape):
|
|
3588
|
+
print(
|
|
3589
|
+
"The two input image should have the same size to eval Scattering Covariance"
|
|
3590
|
+
)
|
|
3591
|
+
return None
|
|
3592
|
+
if mask is not None:
|
|
3593
|
+
if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
|
|
3594
|
+
print(
|
|
3595
|
+
"The LAST COLUMN of the mask should have the same size ",
|
|
3596
|
+
mask.shape,
|
|
3597
|
+
"than the input image ",
|
|
3598
|
+
image1.shape,
|
|
3599
|
+
"to eval Scattering Covariance",
|
|
3600
|
+
)
|
|
3601
|
+
return None
|
|
3602
|
+
if self.use_2D and len(image1.shape) < 2:
|
|
3603
|
+
print(
|
|
3604
|
+
"To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
|
|
3496
3605
|
)
|
|
3606
|
+
return None
|
|
3497
3607
|
|
|
3498
|
-
|
|
3499
|
-
|
|
3500
|
-
|
|
3501
|
-
|
|
3502
|
-
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
3503
|
-
# cconv, sconv are [Nbatch, Npix_j3, Norient3]
|
|
3504
|
-
if self.use_1D:
|
|
3505
|
-
s3 = conv * self.backend.bk_conjugate(MconvPsi)
|
|
3506
|
-
else:
|
|
3507
|
-
s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
|
|
3508
|
-
MconvPsi
|
|
3509
|
-
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3608
|
+
### AUTO OR CROSS
|
|
3609
|
+
cross = False
|
|
3610
|
+
if image2 is not None:
|
|
3611
|
+
cross = True
|
|
3510
3612
|
|
|
3511
|
-
###
|
|
3512
|
-
|
|
3513
|
-
|
|
3514
|
-
|
|
3515
|
-
|
|
3516
|
-
|
|
3517
|
-
|
|
3518
|
-
|
|
3519
|
-
|
|
3613
|
+
### PARAMETERS
|
|
3614
|
+
axis = 1
|
|
3615
|
+
# determine jmax and nside corresponding to the input map
|
|
3616
|
+
im_shape = image1.shape
|
|
3617
|
+
if self.use_2D:
|
|
3618
|
+
if len(image1.shape) == 2:
|
|
3619
|
+
nside = np.min([im_shape[0], im_shape[1]])
|
|
3620
|
+
npix = im_shape[0] * im_shape[1] # Number of pixels
|
|
3621
|
+
x1 = im_shape[0]
|
|
3622
|
+
x2 = im_shape[1]
|
|
3520
3623
|
else:
|
|
3521
|
-
|
|
3624
|
+
nside = np.min([im_shape[1], im_shape[2]])
|
|
3625
|
+
npix = im_shape[1] * im_shape[2] # Number of pixels
|
|
3626
|
+
x1 = im_shape[1]
|
|
3627
|
+
x2 = im_shape[2]
|
|
3628
|
+
J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
|
|
3629
|
+
elif self.use_1D:
|
|
3630
|
+
if len(image1.shape) == 2:
|
|
3631
|
+
npix = int(im_shape[1]) # Number of pixels
|
|
3632
|
+
else:
|
|
3633
|
+
npix = int(im_shape[0]) # Number of pixels
|
|
3634
|
+
|
|
3635
|
+
nside = int(npix)
|
|
3636
|
+
|
|
3637
|
+
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
3638
|
+
else:
|
|
3639
|
+
if len(image1.shape) == 2:
|
|
3640
|
+
npix = int(im_shape[1]) # Number of pixels
|
|
3641
|
+
else:
|
|
3642
|
+
npix = int(im_shape[0]) # Number of pixels
|
|
3643
|
+
|
|
3644
|
+
nside = int(np.sqrt(npix // 12))
|
|
3645
|
+
|
|
3646
|
+
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
3647
|
+
|
|
3648
|
+
if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
|
|
3649
|
+
J-=1
|
|
3650
|
+
if Jmax is None:
|
|
3651
|
+
Jmax = J # Number of steps for the loop on scales
|
|
3652
|
+
if Jmax>J:
|
|
3653
|
+
print('==========\n\n')
|
|
3654
|
+
print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
|
|
3655
|
+
print('\n\n==========')
|
|
3656
|
+
|
|
3657
|
+
|
|
3658
|
+
### LOCAL VARIABLES (IMAGES and MASK)
|
|
3659
|
+
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
|
|
3660
|
+
I1 = self.backend.bk_cast(
|
|
3661
|
+
self.backend.bk_expand_dims(image1, 0)
|
|
3662
|
+
) # Local image1 [Nbatch, Npix]
|
|
3663
|
+
if cross:
|
|
3664
|
+
I2 = self.backend.bk_cast(
|
|
3665
|
+
self.backend.bk_expand_dims(image2, 0)
|
|
3666
|
+
) # Local image2 [Nbatch, Npix]
|
|
3667
|
+
else:
|
|
3668
|
+
I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
|
|
3669
|
+
if cross:
|
|
3670
|
+
I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
|
|
3671
|
+
|
|
3672
|
+
if mask is None:
|
|
3673
|
+
if self.use_2D:
|
|
3674
|
+
vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
|
|
3675
|
+
else:
|
|
3676
|
+
vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
|
|
3677
|
+
else:
|
|
3678
|
+
vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
|
|
3679
|
+
|
|
3680
|
+
if self.KERNELSZ > 3 and not self.use_2D:
|
|
3681
|
+
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
3682
|
+
if self.use_2D:
|
|
3683
|
+
vmask = self.up_grade(
|
|
3684
|
+
vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
|
|
3685
|
+
)
|
|
3686
|
+
I1 = self.up_grade(
|
|
3687
|
+
I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
|
|
3688
|
+
)
|
|
3689
|
+
if cross:
|
|
3690
|
+
I2 = self.up_grade(
|
|
3691
|
+
I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
|
|
3692
|
+
)
|
|
3693
|
+
elif self.use_1D:
|
|
3694
|
+
vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
|
|
3695
|
+
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
3696
|
+
if cross:
|
|
3697
|
+
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
3698
|
+
else:
|
|
3699
|
+
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
3700
|
+
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
3701
|
+
if cross:
|
|
3702
|
+
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
3703
|
+
|
|
3704
|
+
if self.KERNELSZ > 5 and not self.use_2D:
|
|
3705
|
+
# if the kernel size is bigger than 5 increase the binning before smoothing
|
|
3706
|
+
if self.use_2D:
|
|
3707
|
+
vmask = self.up_grade(
|
|
3708
|
+
vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
|
|
3709
|
+
)
|
|
3710
|
+
I1 = self.up_grade(
|
|
3711
|
+
I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
|
|
3712
|
+
)
|
|
3713
|
+
if cross:
|
|
3714
|
+
I2 = self.up_grade(
|
|
3715
|
+
I2,
|
|
3716
|
+
I2.shape[axis] * 2,
|
|
3717
|
+
axis=axis,
|
|
3718
|
+
nouty=I2.shape[axis + 1] * 2,
|
|
3719
|
+
)
|
|
3720
|
+
elif self.use_1D:
|
|
3721
|
+
vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
|
|
3722
|
+
I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
|
|
3723
|
+
if cross:
|
|
3724
|
+
I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
|
|
3725
|
+
else:
|
|
3726
|
+
I1 = self.up_grade(I1, nside * 4, axis=axis)
|
|
3727
|
+
vmask = self.up_grade(vmask, nside * 4, axis=1)
|
|
3728
|
+
if cross:
|
|
3729
|
+
I2 = self.up_grade(I2, nside * 4, axis=axis)
|
|
3730
|
+
|
|
3731
|
+
# Normalize the masks because they have different pixel numbers
|
|
3732
|
+
# vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
|
|
3733
|
+
|
|
3734
|
+
### INITIALIZATION
|
|
3735
|
+
# Coefficients
|
|
3736
|
+
if return_data:
|
|
3737
|
+
S1 = {}
|
|
3738
|
+
S2 = {}
|
|
3739
|
+
S3 = {}
|
|
3740
|
+
S3P = {}
|
|
3741
|
+
S4 = {}
|
|
3742
|
+
else:
|
|
3743
|
+
result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3744
|
+
dtype=self.backend.backend.float32,
|
|
3745
|
+
device=self.backend.torch_device)
|
|
3746
|
+
vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3747
|
+
dtype=self.backend.backend.float32,
|
|
3748
|
+
device=self.backend.torch_device)
|
|
3749
|
+
S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3750
|
+
S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3751
|
+
S3 = []
|
|
3752
|
+
S4 = []
|
|
3753
|
+
S3P = []
|
|
3754
|
+
VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3755
|
+
VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3756
|
+
VS3 = []
|
|
3757
|
+
VS3P = []
|
|
3758
|
+
VS4 = []
|
|
3759
|
+
|
|
3760
|
+
off_S2 = -2
|
|
3761
|
+
off_S3 = -3
|
|
3762
|
+
off_S4 = -4
|
|
3763
|
+
if self.use_1D:
|
|
3764
|
+
off_S2 = -1
|
|
3765
|
+
off_S3 = -1
|
|
3766
|
+
off_S4 = -1
|
|
3767
|
+
|
|
3768
|
+
# S2 for normalization
|
|
3769
|
+
cond_init_P1_dic = (norm == "self") or (
|
|
3770
|
+
(norm == "auto") and (self.P1_dic is None)
|
|
3771
|
+
)
|
|
3772
|
+
if norm is None:
|
|
3773
|
+
pass
|
|
3774
|
+
elif cond_init_P1_dic:
|
|
3775
|
+
P1_dic = {}
|
|
3776
|
+
if cross:
|
|
3777
|
+
P2_dic = {}
|
|
3778
|
+
elif (norm == "auto") and (self.P1_dic is not None):
|
|
3779
|
+
P1_dic = self.P1_dic
|
|
3780
|
+
if cross:
|
|
3781
|
+
P2_dic = self.P2_dic
|
|
3782
|
+
|
|
3783
|
+
if return_data:
|
|
3784
|
+
s0 = I1
|
|
3785
|
+
if out_nside is not None:
|
|
3786
|
+
s0 = self.backend.bk_reduce_mean(
|
|
3787
|
+
self.backend.bk_reshape(
|
|
3788
|
+
s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
|
|
3789
|
+
),
|
|
3790
|
+
2,
|
|
3791
|
+
)
|
|
3792
|
+
else:
|
|
3793
|
+
if not cross:
|
|
3794
|
+
s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
|
|
3795
|
+
else:
|
|
3796
|
+
s0, l_vs0 = self.masked_mean(
|
|
3797
|
+
self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
|
|
3798
|
+
)
|
|
3799
|
+
#vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
|
|
3800
|
+
#s0 = self.backend.bk_concat([s0, l_vs0], 1)
|
|
3801
|
+
result[:,:,0]=s0
|
|
3802
|
+
result[:,:,1]=l_vs0
|
|
3803
|
+
vresult[:,:,0]=l_vs0
|
|
3804
|
+
vresult[:,:,1]=l_vs0
|
|
3805
|
+
#### COMPUTE S1, S2, S3 and S4
|
|
3806
|
+
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
3807
|
+
|
|
3808
|
+
# to be restored as it was before
|
|
3809
|
+
M1_dic={}
|
|
3810
|
+
|
|
3811
|
+
M2_dic={}
|
|
3812
|
+
|
|
3813
|
+
for j3 in range(Jmax):
|
|
3814
|
+
|
|
3815
|
+
if edge:
|
|
3816
|
+
if self.mask_mask is None:
|
|
3817
|
+
self.mask_mask={}
|
|
3818
|
+
if self.use_2D:
|
|
3819
|
+
if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
|
|
3820
|
+
mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
|
|
3821
|
+
mask_mask[0,
|
|
3822
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1,
|
|
3823
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
3824
|
+
self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
|
|
3825
|
+
vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
|
|
3826
|
+
#print(self.KERNELSZ//2,vmask,mask_mask)
|
|
3827
|
+
|
|
3828
|
+
if self.use_1D:
|
|
3829
|
+
if (vmask.shape[1]) not in self.mask_mask:
|
|
3830
|
+
mask_mask=np.zeros([1,vmask.shape[1]])
|
|
3831
|
+
mask_mask[0,
|
|
3832
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
3833
|
+
self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
|
|
3834
|
+
vmask=vmask*self.mask_mask[(vmask.shape[1])]
|
|
3835
|
+
|
|
3836
|
+
if return_data:
|
|
3837
|
+
S3[j3] = None
|
|
3838
|
+
S3P[j3] = None
|
|
3839
|
+
|
|
3840
|
+
if S4 is None:
|
|
3841
|
+
S4 = {}
|
|
3842
|
+
S4[j3] = None
|
|
3843
|
+
|
|
3844
|
+
####### S1 and S2
|
|
3845
|
+
### Make the convolution I1 * Psi_j3
|
|
3846
|
+
conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
3847
|
+
if cmat is not None:
|
|
3848
|
+
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
|
|
3849
|
+
conv1 = self.backend.bk_reduce_sum(
|
|
3850
|
+
self.backend.bk_reshape(
|
|
3851
|
+
cmat[j3] * tmp2,
|
|
3852
|
+
[tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
|
|
3853
|
+
),
|
|
3854
|
+
2,
|
|
3855
|
+
)
|
|
3856
|
+
|
|
3857
|
+
### Take the module M1 = |I1 * Psi_j3|
|
|
3858
|
+
M1_square = conv1 * self.backend.bk_conjugate(
|
|
3859
|
+
conv1
|
|
3860
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
3861
|
+
|
|
3862
|
+
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
3863
|
+
|
|
3864
|
+
# Store M1_j3 in a dictionary
|
|
3865
|
+
M1_dic[j3] = M1
|
|
3866
|
+
|
|
3867
|
+
if not cross: # Auto
|
|
3868
|
+
M1_square = self.backend.bk_real(M1_square)
|
|
3869
|
+
|
|
3870
|
+
### S2_auto = < M1^2 >_pix
|
|
3871
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
3872
|
+
if return_data:
|
|
3873
|
+
s2 = M1_square
|
|
3874
|
+
else:
|
|
3875
|
+
if calc_var:
|
|
3876
|
+
s2, vs2 = self.masked_mean(
|
|
3877
|
+
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
3878
|
+
)
|
|
3879
|
+
#s2=self.backend.bk_flatten(self.backend.bk_real(s2))
|
|
3880
|
+
#vs2=self.backend.bk_flatten(vs2)
|
|
3881
|
+
else:
|
|
3882
|
+
s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
|
|
3883
|
+
|
|
3884
|
+
if cond_init_P1_dic:
|
|
3885
|
+
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
3886
|
+
P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
|
|
3887
|
+
|
|
3888
|
+
# We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
|
|
3889
|
+
if return_data:
|
|
3890
|
+
if S2 is None:
|
|
3891
|
+
S2 = {}
|
|
3892
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3893
|
+
s2 = self.backend.bk_reduce_mean(
|
|
3894
|
+
self.backend.bk_reshape(
|
|
3895
|
+
s2,
|
|
3896
|
+
[
|
|
3897
|
+
s2.shape[0],
|
|
3898
|
+
12 * out_nside**2,
|
|
3899
|
+
(nside_j3 // out_nside) ** 2,
|
|
3900
|
+
s2.shape[2],
|
|
3901
|
+
],
|
|
3902
|
+
),
|
|
3903
|
+
2,
|
|
3904
|
+
)
|
|
3905
|
+
S2[j3] = s2
|
|
3906
|
+
else:
|
|
3907
|
+
if norm == "auto": # Normalize S2
|
|
3908
|
+
s2 /= P1_dic[j3]
|
|
3909
|
+
"""
|
|
3910
|
+
S2.append(
|
|
3911
|
+
self.backend.bk_expand_dims(s2, off_S2)
|
|
3912
|
+
) # Add a dimension for NS2
|
|
3913
|
+
if calc_var:
|
|
3914
|
+
VS2.append(
|
|
3915
|
+
self.backend.bk_expand_dims(vs2, off_S2)
|
|
3916
|
+
) # Add a dimension for NS2
|
|
3917
|
+
"""
|
|
3918
|
+
#print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
|
|
3919
|
+
result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
|
|
3920
|
+
if calc_var:
|
|
3921
|
+
vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
|
|
3922
|
+
#### S1_auto computation
|
|
3923
|
+
### Image 1 : S1 = < M1 >_pix
|
|
3924
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
3925
|
+
if return_data:
|
|
3926
|
+
s1 = M1
|
|
3927
|
+
else:
|
|
3928
|
+
if calc_var:
|
|
3929
|
+
s1, vs1 = self.masked_mean(
|
|
3930
|
+
M1, vmask, axis=1, rank=j3, calc_var=True
|
|
3931
|
+
) # [Nbatch, Nmask, Norient3]
|
|
3932
|
+
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
3933
|
+
#vs1=self.backend.bk_flatten(vs1)
|
|
3934
|
+
else:
|
|
3935
|
+
s1 = self.masked_mean(
|
|
3936
|
+
M1, vmask, axis=1, rank=j3
|
|
3937
|
+
) # [Nbatch, Nmask, Norient3]
|
|
3938
|
+
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
3939
|
+
|
|
3940
|
+
if return_data:
|
|
3941
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3942
|
+
s1 = self.backend.bk_reduce_mean(
|
|
3943
|
+
self.backend.bk_reshape(
|
|
3944
|
+
s1,
|
|
3945
|
+
[
|
|
3946
|
+
s1.shape[0],
|
|
3947
|
+
12 * out_nside**2,
|
|
3948
|
+
(nside_j3 // out_nside) ** 2,
|
|
3949
|
+
s1.shape[2],
|
|
3950
|
+
],
|
|
3951
|
+
),
|
|
3952
|
+
2,
|
|
3953
|
+
)
|
|
3954
|
+
S1[j3] = s1
|
|
3955
|
+
else:
|
|
3956
|
+
### Normalize S1
|
|
3957
|
+
if norm is not None:
|
|
3958
|
+
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
3959
|
+
result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
|
|
3960
|
+
if calc_var:
|
|
3961
|
+
vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
|
|
3962
|
+
"""
|
|
3963
|
+
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
3964
|
+
S1.append(
|
|
3965
|
+
self.backend.bk_expand_dims(s1, off_S2)
|
|
3966
|
+
) # Add a dimension for NS1
|
|
3967
|
+
if calc_var:
|
|
3968
|
+
VS1.append(
|
|
3969
|
+
self.backend.bk_expand_dims(vs1, off_S2)
|
|
3970
|
+
) # Add a dimension for NS1
|
|
3971
|
+
"""
|
|
3972
|
+
|
|
3973
|
+
else: # Cross
|
|
3974
|
+
### Make the convolution I2 * Psi_j3
|
|
3975
|
+
conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
3976
|
+
if cmat is not None:
|
|
3977
|
+
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
|
|
3978
|
+
conv2 = self.backend.bk_reduce_sum(
|
|
3979
|
+
self.backend.bk_reshape(
|
|
3980
|
+
cmat[j3] * tmp2,
|
|
3981
|
+
[
|
|
3982
|
+
tmp2.shape[0],
|
|
3983
|
+
cmat[j3].shape[0],
|
|
3984
|
+
self.NORIENT,
|
|
3985
|
+
self.NORIENT,
|
|
3986
|
+
],
|
|
3987
|
+
),
|
|
3988
|
+
2,
|
|
3989
|
+
)
|
|
3990
|
+
### Take the module M2 = |I2 * Psi_j3|
|
|
3991
|
+
M2_square = conv2 * self.backend.bk_conjugate(
|
|
3992
|
+
conv2
|
|
3993
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
3994
|
+
M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
|
|
3995
|
+
# Store M2_j3 in a dictionary
|
|
3996
|
+
M2_dic[j3] = M2
|
|
3997
|
+
|
|
3998
|
+
### S2_auto = < M2^2 >_pix
|
|
3999
|
+
# Not returned, only for normalization
|
|
4000
|
+
if cond_init_P1_dic:
|
|
4001
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4002
|
+
if return_data:
|
|
4003
|
+
p1 = M1_square
|
|
4004
|
+
p2 = M2_square
|
|
4005
|
+
else:
|
|
4006
|
+
if calc_var:
|
|
4007
|
+
p1, vp1 = self.masked_mean(
|
|
4008
|
+
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
4009
|
+
) # [Nbatch, Nmask, Norient3]
|
|
4010
|
+
p2, vp2 = self.masked_mean(
|
|
4011
|
+
M2_square, vmask, axis=1, rank=j3, calc_var=True
|
|
4012
|
+
) # [Nbatch, Nmask, Norient3]
|
|
4013
|
+
else:
|
|
4014
|
+
p1 = self.masked_mean(
|
|
4015
|
+
M1_square, vmask, axis=1, rank=j3
|
|
4016
|
+
) # [Nbatch, Nmask, Norient3]
|
|
4017
|
+
p2 = self.masked_mean(
|
|
4018
|
+
M2_square, vmask, axis=1, rank=j3
|
|
4019
|
+
) # [Nbatch, Nmask, Norient3]
|
|
4020
|
+
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
4021
|
+
P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
|
|
4022
|
+
P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
|
|
4023
|
+
|
|
4024
|
+
### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
|
|
4025
|
+
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4026
|
+
s2 = conv1 * self.backend.bk_conjugate(conv2)
|
|
4027
|
+
MX = self.backend.bk_L1(s2)
|
|
4028
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4029
|
+
if return_data:
|
|
4030
|
+
s2 = s2
|
|
4031
|
+
else:
|
|
4032
|
+
if calc_var:
|
|
4033
|
+
s2, vs2 = self.masked_mean(
|
|
4034
|
+
s2, vmask, axis=1, rank=j3, calc_var=True
|
|
4035
|
+
)
|
|
4036
|
+
else:
|
|
4037
|
+
s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
|
|
4038
|
+
|
|
4039
|
+
if return_data:
|
|
4040
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
4041
|
+
s2 = self.backend.bk_reduce_mean(
|
|
4042
|
+
self.backend.bk_reshape(
|
|
4043
|
+
s2,
|
|
4044
|
+
[
|
|
4045
|
+
s2.shape[0],
|
|
4046
|
+
12 * out_nside**2,
|
|
4047
|
+
(nside_j3 // out_nside) ** 2,
|
|
4048
|
+
s2.shape[2],
|
|
4049
|
+
],
|
|
4050
|
+
),
|
|
4051
|
+
2,
|
|
4052
|
+
)
|
|
4053
|
+
S2[j3] = s2
|
|
4054
|
+
else:
|
|
4055
|
+
### Normalize S2_cross
|
|
4056
|
+
if norm == "auto":
|
|
4057
|
+
s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
|
|
4058
|
+
|
|
4059
|
+
### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
|
|
4060
|
+
s2 = self.backend.bk_real(s2)
|
|
4061
|
+
|
|
4062
|
+
S2.append(
|
|
4063
|
+
self.backend.bk_expand_dims(s2, off_S2)
|
|
4064
|
+
) # Add a dimension for NS2
|
|
4065
|
+
if calc_var:
|
|
4066
|
+
VS2.append(
|
|
4067
|
+
self.backend.bk_expand_dims(vs2, off_S2)
|
|
4068
|
+
) # Add a dimension for NS2
|
|
4069
|
+
|
|
4070
|
+
#### S1_auto computation
|
|
4071
|
+
### Image 1 : S1 = < M1 >_pix
|
|
4072
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
4073
|
+
if return_data:
|
|
4074
|
+
s1 = MX
|
|
4075
|
+
else:
|
|
4076
|
+
if calc_var:
|
|
4077
|
+
s1, vs1 = self.masked_mean(
|
|
4078
|
+
MX, vmask, axis=1, rank=j3, calc_var=True
|
|
4079
|
+
) # [Nbatch, Nmask, Norient3]
|
|
4080
|
+
else:
|
|
4081
|
+
s1 = self.masked_mean(
|
|
4082
|
+
MX, vmask, axis=1, rank=j3
|
|
4083
|
+
) # [Nbatch, Nmask, Norient3]
|
|
4084
|
+
if return_data:
|
|
4085
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
4086
|
+
s1 = self.backend.bk_reduce_mean(
|
|
4087
|
+
self.backend.bk_reshape(
|
|
4088
|
+
s1,
|
|
4089
|
+
[
|
|
4090
|
+
s1.shape[0],
|
|
4091
|
+
12 * out_nside**2,
|
|
4092
|
+
(nside_j3 // out_nside) ** 2,
|
|
4093
|
+
s1.shape[2],
|
|
4094
|
+
],
|
|
4095
|
+
),
|
|
4096
|
+
2,
|
|
4097
|
+
)
|
|
4098
|
+
S1[j3] = s1
|
|
4099
|
+
else:
|
|
4100
|
+
### Normalize S1
|
|
4101
|
+
if norm is not None:
|
|
4102
|
+
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
4103
|
+
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
4104
|
+
S1.append(
|
|
4105
|
+
self.backend.bk_expand_dims(s1, off_S2)
|
|
4106
|
+
) # Add a dimension for NS1
|
|
4107
|
+
if calc_var:
|
|
4108
|
+
VS1.append(
|
|
4109
|
+
self.backend.bk_expand_dims(vs1, off_S2)
|
|
4110
|
+
) # Add a dimension for NS1
|
|
4111
|
+
|
|
4112
|
+
# Initialize dictionaries for |I1*Psi_j| * Psi_j3
|
|
4113
|
+
M1convPsi_dic = {}
|
|
4114
|
+
if cross:
|
|
4115
|
+
# Initialize dictionaries for |I2*Psi_j| * Psi_j3
|
|
4116
|
+
M2convPsi_dic = {}
|
|
4117
|
+
|
|
4118
|
+
###### S3
|
|
4119
|
+
nside_j2 = nside_j3
|
|
4120
|
+
for j2 in range(0,-1): # j3 + 1): # j2 <= j3
|
|
4121
|
+
if return_data:
|
|
4122
|
+
if S4[j3] is None:
|
|
4123
|
+
S4[j3] = {}
|
|
4124
|
+
S4[j3][j2] = None
|
|
4125
|
+
|
|
4126
|
+
### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
4127
|
+
if not cross:
|
|
4128
|
+
if calc_var:
|
|
4129
|
+
s3, vs3 = self._compute_S3(
|
|
4130
|
+
j2,
|
|
4131
|
+
j3,
|
|
4132
|
+
conv1,
|
|
4133
|
+
vmask,
|
|
4134
|
+
M1_dic,
|
|
4135
|
+
M1convPsi_dic,
|
|
4136
|
+
calc_var=True,
|
|
4137
|
+
cmat2=cmat2,
|
|
4138
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4139
|
+
else:
|
|
4140
|
+
s3 = self._compute_S3(
|
|
4141
|
+
j2,
|
|
4142
|
+
j3,
|
|
4143
|
+
conv1,
|
|
4144
|
+
vmask,
|
|
4145
|
+
M1_dic,
|
|
4146
|
+
M1convPsi_dic,
|
|
4147
|
+
return_data=return_data,
|
|
4148
|
+
cmat2=cmat2,
|
|
4149
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4150
|
+
|
|
4151
|
+
if return_data:
|
|
4152
|
+
if S3[j3] is None:
|
|
4153
|
+
S3[j3] = {}
|
|
4154
|
+
if out_nside is not None and out_nside < nside_j2:
|
|
4155
|
+
s3 = self.backend.bk_reduce_mean(
|
|
4156
|
+
self.backend.bk_reshape(
|
|
4157
|
+
s3,
|
|
4158
|
+
[
|
|
4159
|
+
s3.shape[0],
|
|
4160
|
+
12 * out_nside**2,
|
|
4161
|
+
(nside_j2 // out_nside) ** 2,
|
|
4162
|
+
s3.shape[2],
|
|
4163
|
+
s3.shape[3],
|
|
4164
|
+
],
|
|
4165
|
+
),
|
|
4166
|
+
2,
|
|
4167
|
+
)
|
|
4168
|
+
S3[j3][j2] = s3
|
|
4169
|
+
else:
|
|
4170
|
+
### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4171
|
+
if norm is not None:
|
|
4172
|
+
self.div_norm(
|
|
4173
|
+
s3,
|
|
4174
|
+
(
|
|
4175
|
+
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
4176
|
+
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
4177
|
+
)
|
|
4178
|
+
** 0.5,
|
|
4179
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4180
|
+
|
|
4181
|
+
### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
4182
|
+
|
|
4183
|
+
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
4184
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
4185
|
+
S3.append(
|
|
4186
|
+
self.backend.bk_expand_dims(s3, off_S3)
|
|
4187
|
+
) # Add a dimension for NS3
|
|
4188
|
+
if calc_var:
|
|
4189
|
+
VS3.append(
|
|
4190
|
+
self.backend.bk_expand_dims(vs3, off_S3)
|
|
4191
|
+
) # Add a dimension for NS3
|
|
4192
|
+
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
4193
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
4194
|
+
|
|
4195
|
+
### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
|
|
4196
|
+
### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
4197
|
+
else:
|
|
4198
|
+
if calc_var:
|
|
4199
|
+
s3, vs3 = self._compute_S3(
|
|
4200
|
+
j2,
|
|
4201
|
+
j3,
|
|
4202
|
+
conv1,
|
|
4203
|
+
vmask,
|
|
4204
|
+
M2_dic,
|
|
4205
|
+
M2convPsi_dic,
|
|
4206
|
+
calc_var=True,
|
|
4207
|
+
cmat2=cmat2,
|
|
4208
|
+
)
|
|
4209
|
+
s3p, vs3p = self._compute_S3(
|
|
4210
|
+
j2,
|
|
4211
|
+
j3,
|
|
4212
|
+
conv2,
|
|
4213
|
+
vmask,
|
|
4214
|
+
M1_dic,
|
|
4215
|
+
M1convPsi_dic,
|
|
4216
|
+
calc_var=True,
|
|
4217
|
+
cmat2=cmat2,
|
|
4218
|
+
)
|
|
4219
|
+
else:
|
|
4220
|
+
s3 = self._compute_S3(
|
|
4221
|
+
j2,
|
|
4222
|
+
j3,
|
|
4223
|
+
conv1,
|
|
4224
|
+
vmask,
|
|
4225
|
+
M2_dic,
|
|
4226
|
+
M2convPsi_dic,
|
|
4227
|
+
return_data=return_data,
|
|
4228
|
+
cmat2=cmat2,
|
|
4229
|
+
)
|
|
4230
|
+
s3p = self._compute_S3(
|
|
4231
|
+
j2,
|
|
4232
|
+
j3,
|
|
4233
|
+
conv2,
|
|
4234
|
+
vmask,
|
|
4235
|
+
M1_dic,
|
|
4236
|
+
M1convPsi_dic,
|
|
4237
|
+
return_data=return_data,
|
|
4238
|
+
cmat2=cmat2,
|
|
4239
|
+
)
|
|
4240
|
+
|
|
4241
|
+
if return_data:
|
|
4242
|
+
if S3[j3] is None:
|
|
4243
|
+
S3[j3] = {}
|
|
4244
|
+
S3P[j3] = {}
|
|
4245
|
+
if out_nside is not None and out_nside < nside_j2:
|
|
4246
|
+
s3 = self.backend.bk_reduce_mean(
|
|
4247
|
+
self.backend.bk_reshape(
|
|
4248
|
+
s3,
|
|
4249
|
+
[
|
|
4250
|
+
s3.shape[0],
|
|
4251
|
+
12 * out_nside**2,
|
|
4252
|
+
(nside_j2 // out_nside) ** 2,
|
|
4253
|
+
s3.shape[2],
|
|
4254
|
+
s3.shape[3],
|
|
4255
|
+
],
|
|
4256
|
+
),
|
|
4257
|
+
2,
|
|
4258
|
+
)
|
|
4259
|
+
s3p = self.backend.bk_reduce_mean(
|
|
4260
|
+
self.backend.bk_reshape(
|
|
4261
|
+
s3p,
|
|
4262
|
+
[
|
|
4263
|
+
s3.shape[0],
|
|
4264
|
+
12 * out_nside**2,
|
|
4265
|
+
(nside_j2 // out_nside) ** 2,
|
|
4266
|
+
s3.shape[2],
|
|
4267
|
+
s3.shape[3],
|
|
4268
|
+
],
|
|
4269
|
+
),
|
|
4270
|
+
2,
|
|
4271
|
+
)
|
|
4272
|
+
S3[j3][j2] = s3
|
|
4273
|
+
S3P[j3][j2] = s3p
|
|
4274
|
+
else:
|
|
4275
|
+
### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
|
|
4276
|
+
if norm is not None:
|
|
4277
|
+
self.div_norm(
|
|
4278
|
+
s3,
|
|
4279
|
+
(
|
|
4280
|
+
self.backend.bk_expand_dims(P2_dic[j2], off_S2)
|
|
4281
|
+
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
4282
|
+
)
|
|
4283
|
+
** 0.5,
|
|
4284
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4285
|
+
self.div_norm(
|
|
4286
|
+
s3p,
|
|
4287
|
+
(
|
|
4288
|
+
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
4289
|
+
* self.backend.bk_expand_dims(P2_dic[j3], -1)
|
|
4290
|
+
)
|
|
4291
|
+
** 0.5,
|
|
4292
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4293
|
+
|
|
4294
|
+
### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
4295
|
+
|
|
4296
|
+
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
4297
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
4298
|
+
S3.append(
|
|
4299
|
+
self.backend.bk_expand_dims(s3, off_S3)
|
|
4300
|
+
) # Add a dimension for NS3
|
|
4301
|
+
if calc_var:
|
|
4302
|
+
VS3.append(
|
|
4303
|
+
self.backend.bk_expand_dims(vs3, off_S3)
|
|
4304
|
+
) # Add a dimension for NS3
|
|
4305
|
+
|
|
4306
|
+
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
4307
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
4308
|
+
|
|
4309
|
+
# S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
|
|
4310
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
4311
|
+
S3P.append(
|
|
4312
|
+
self.backend.bk_expand_dims(s3p, off_S3)
|
|
4313
|
+
) # Add a dimension for NS3
|
|
4314
|
+
if calc_var:
|
|
4315
|
+
VS3P.append(
|
|
4316
|
+
self.backend.bk_expand_dims(vs3p, off_S3)
|
|
4317
|
+
) # Add a dimension for NS3
|
|
4318
|
+
# VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
|
|
4319
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
4320
|
+
|
|
4321
|
+
##### S4
|
|
4322
|
+
nside_j1 = nside_j2
|
|
4323
|
+
for j1 in range(0, j2 + 1): # j1 <= j2
|
|
4324
|
+
### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
|
|
4325
|
+
if not cross:
|
|
4326
|
+
if calc_var:
|
|
4327
|
+
s4, vs4 = self._compute_S4(
|
|
4328
|
+
j1,
|
|
4329
|
+
j2,
|
|
4330
|
+
vmask,
|
|
4331
|
+
M1convPsi_dic,
|
|
4332
|
+
M2convPsi_dic=None,
|
|
4333
|
+
calc_var=True,
|
|
4334
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4335
|
+
else:
|
|
4336
|
+
s4 = self._compute_S4(
|
|
4337
|
+
j1,
|
|
4338
|
+
j2,
|
|
4339
|
+
vmask,
|
|
4340
|
+
M1convPsi_dic,
|
|
4341
|
+
M2convPsi_dic=None,
|
|
4342
|
+
return_data=return_data,
|
|
4343
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4344
|
+
|
|
4345
|
+
if return_data:
|
|
4346
|
+
if S4[j3][j2] is None:
|
|
4347
|
+
S4[j3][j2] = {}
|
|
4348
|
+
if out_nside is not None and out_nside < nside_j1:
|
|
4349
|
+
s4 = self.backend.bk_reduce_mean(
|
|
4350
|
+
self.backend.bk_reshape(
|
|
4351
|
+
s4,
|
|
4352
|
+
[
|
|
4353
|
+
s4.shape[0],
|
|
4354
|
+
12 * out_nside**2,
|
|
4355
|
+
(nside_j1 // out_nside) ** 2,
|
|
4356
|
+
s4.shape[2],
|
|
4357
|
+
s4.shape[3],
|
|
4358
|
+
s4.shape[4],
|
|
4359
|
+
],
|
|
4360
|
+
),
|
|
4361
|
+
2,
|
|
4362
|
+
)
|
|
4363
|
+
S4[j3][j2][j1] = s4
|
|
4364
|
+
else:
|
|
4365
|
+
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4366
|
+
if norm is not None:
|
|
4367
|
+
self.div_norm(
|
|
4368
|
+
s4,
|
|
4369
|
+
(
|
|
4370
|
+
self.backend.bk_expand_dims(
|
|
4371
|
+
self.backend.bk_expand_dims(
|
|
4372
|
+
P1_dic[j1], off_S2
|
|
4373
|
+
),
|
|
4374
|
+
off_S2,
|
|
4375
|
+
)
|
|
4376
|
+
* self.backend.bk_expand_dims(
|
|
4377
|
+
self.backend.bk_expand_dims(
|
|
4378
|
+
P1_dic[j2], off_S2
|
|
4379
|
+
),
|
|
4380
|
+
-1,
|
|
4381
|
+
)
|
|
4382
|
+
)
|
|
4383
|
+
** 0.5,
|
|
4384
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4385
|
+
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
4386
|
+
|
|
4387
|
+
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
4388
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4389
|
+
S4.append(
|
|
4390
|
+
self.backend.bk_expand_dims(s4, off_S4)
|
|
4391
|
+
) # Add a dimension for NS4
|
|
4392
|
+
if calc_var:
|
|
4393
|
+
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
4394
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4395
|
+
VS4.append(
|
|
4396
|
+
self.backend.bk_expand_dims(vs4, off_S4)
|
|
4397
|
+
) # Add a dimension for NS4
|
|
4398
|
+
|
|
4399
|
+
### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
|
|
4400
|
+
else:
|
|
4401
|
+
if calc_var:
|
|
4402
|
+
s4, vs4 = self._compute_S4(
|
|
4403
|
+
j1,
|
|
4404
|
+
j2,
|
|
4405
|
+
vmask,
|
|
4406
|
+
M1convPsi_dic,
|
|
4407
|
+
M2convPsi_dic=M2convPsi_dic,
|
|
4408
|
+
calc_var=True,
|
|
4409
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4410
|
+
else:
|
|
4411
|
+
s4 = self._compute_S4(
|
|
4412
|
+
j1,
|
|
4413
|
+
j2,
|
|
4414
|
+
vmask,
|
|
4415
|
+
M1convPsi_dic,
|
|
4416
|
+
M2convPsi_dic=M2convPsi_dic,
|
|
4417
|
+
return_data=return_data,
|
|
4418
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4419
|
+
|
|
4420
|
+
if return_data:
|
|
4421
|
+
if S4[j3][j2] is None:
|
|
4422
|
+
S4[j3][j2] = {}
|
|
4423
|
+
if out_nside is not None and out_nside < nside_j1:
|
|
4424
|
+
s4 = self.backend.bk_reduce_mean(
|
|
4425
|
+
self.backend.bk_reshape(
|
|
4426
|
+
s4,
|
|
4427
|
+
[
|
|
4428
|
+
s4.shape[0],
|
|
4429
|
+
12 * out_nside**2,
|
|
4430
|
+
(nside_j1 // out_nside) ** 2,
|
|
4431
|
+
s4.shape[2],
|
|
4432
|
+
s4.shape[3],
|
|
4433
|
+
s4.shape[4],
|
|
4434
|
+
],
|
|
4435
|
+
),
|
|
4436
|
+
2,
|
|
4437
|
+
)
|
|
4438
|
+
S4[j3][j2][j1] = s4
|
|
4439
|
+
else:
|
|
4440
|
+
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
4441
|
+
if norm is not None:
|
|
4442
|
+
self.div_norm(
|
|
4443
|
+
s4,
|
|
4444
|
+
(
|
|
4445
|
+
self.backend.bk_expand_dims(
|
|
4446
|
+
self.backend.bk_expand_dims(
|
|
4447
|
+
P1_dic[j1], off_S2
|
|
4448
|
+
),
|
|
4449
|
+
off_S2,
|
|
4450
|
+
)
|
|
4451
|
+
* self.backend.bk_expand_dims(
|
|
4452
|
+
self.backend.bk_expand_dims(
|
|
4453
|
+
P2_dic[j2], off_S2
|
|
4454
|
+
),
|
|
4455
|
+
-1,
|
|
4456
|
+
)
|
|
4457
|
+
)
|
|
4458
|
+
** 0.5,
|
|
4459
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
4460
|
+
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
4461
|
+
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
4462
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4463
|
+
S4.append(
|
|
4464
|
+
self.backend.bk_expand_dims(s4, off_S4)
|
|
4465
|
+
) # Add a dimension for NS4
|
|
4466
|
+
if calc_var:
|
|
4467
|
+
|
|
4468
|
+
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
4469
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
4470
|
+
VS4.append(
|
|
4471
|
+
self.backend.bk_expand_dims(vs4, off_S4)
|
|
4472
|
+
) # Add a dimension for NS4
|
|
4473
|
+
|
|
4474
|
+
nside_j1 = nside_j1 // 2
|
|
4475
|
+
nside_j2 = nside_j2 // 2
|
|
4476
|
+
|
|
4477
|
+
###### Reshape for next iteration on j3
|
|
4478
|
+
### Image I1,
|
|
4479
|
+
# downscale the I1 [Nbatch, Npix_j3]
|
|
4480
|
+
if j3 != Jmax - 1:
|
|
4481
|
+
I1 = self.smooth(I1, axis=1)
|
|
4482
|
+
I1 = self.ud_grade_2(I1, axis=1)
|
|
4483
|
+
|
|
4484
|
+
### Image I2
|
|
4485
|
+
if cross:
|
|
4486
|
+
I2 = self.smooth(I2, axis=1)
|
|
4487
|
+
I2 = self.ud_grade_2(I2, axis=1)
|
|
4488
|
+
|
|
4489
|
+
### Modules
|
|
4490
|
+
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
4491
|
+
### Dictionary M1_dic[j2]
|
|
4492
|
+
M1_smooth = self.smooth(
|
|
4493
|
+
M1_dic[j2], axis=1
|
|
4494
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
4495
|
+
M1_dic[j2] = self.ud_grade_2(
|
|
4496
|
+
M1_smooth, axis=1
|
|
4497
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
4498
|
+
|
|
4499
|
+
### Dictionary M2_dic[j2]
|
|
4500
|
+
if cross:
|
|
4501
|
+
M2_smooth = self.smooth(
|
|
4502
|
+
M2_dic[j2], axis=1
|
|
4503
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
4504
|
+
M2_dic[j2] = self.ud_grade_2(
|
|
4505
|
+
M2, axis=1
|
|
4506
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
4507
|
+
|
|
4508
|
+
### Mask
|
|
4509
|
+
vmask = self.ud_grade_2(vmask, axis=1)
|
|
4510
|
+
|
|
4511
|
+
if self.mask_thres is not None:
|
|
4512
|
+
vmask = self.backend.bk_threshold(vmask, self.mask_thres)
|
|
4513
|
+
|
|
4514
|
+
### NSIDE_j3
|
|
4515
|
+
nside_j3 = nside_j3 // 2
|
|
4516
|
+
|
|
4517
|
+
### Store P1_dic and P2_dic in self
|
|
4518
|
+
if (norm == "auto") and (self.P1_dic is None):
|
|
4519
|
+
self.P1_dic = P1_dic
|
|
4520
|
+
if cross:
|
|
4521
|
+
self.P2_dic = P2_dic
|
|
4522
|
+
"""
|
|
4523
|
+
Sout=[s0]+S1+S2+S3+S4
|
|
4524
|
+
|
|
4525
|
+
if cross:
|
|
4526
|
+
Sout=Sout+S3P
|
|
4527
|
+
if calc_var:
|
|
4528
|
+
SVout=[vs0]+VS1+VS2+VS3+VS4
|
|
4529
|
+
if cross:
|
|
4530
|
+
VSout=VSout+VS3P
|
|
4531
|
+
return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
|
|
4532
|
+
|
|
4533
|
+
return self.backend.bk_concat(Sout, 2)
|
|
4534
|
+
"""
|
|
4535
|
+
if calc_var:
|
|
4536
|
+
return result,vresult
|
|
4537
|
+
else:
|
|
4538
|
+
return result
|
|
4539
|
+
if calc_var:
|
|
4540
|
+
for k in S1:
|
|
4541
|
+
print(k.shape,k.dtype)
|
|
4542
|
+
for k in S2:
|
|
4543
|
+
print(k.shape,k.dtype)
|
|
4544
|
+
print(s0.shape,s0.dtype)
|
|
4545
|
+
return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
|
|
4546
|
+
else:
|
|
4547
|
+
return self.backend.bk_concat([s0]+S1+S2,axis=1)
|
|
4548
|
+
if not return_data:
|
|
4549
|
+
S1 = self.backend.bk_concat(S1, 2)
|
|
4550
|
+
S2 = self.backend.bk_concat(S2, 2)
|
|
4551
|
+
S3 = self.backend.bk_concat(S3, 2)
|
|
4552
|
+
S4 = self.backend.bk_concat(S4, 2)
|
|
4553
|
+
if cross:
|
|
4554
|
+
S3P = self.backend.bk_concat(S3P, 2)
|
|
4555
|
+
if calc_var:
|
|
4556
|
+
VS1 = self.backend.bk_concat(VS1, 2)
|
|
4557
|
+
VS2 = self.backend.bk_concat(VS2, 2)
|
|
4558
|
+
VS3 = self.backend.bk_concat(VS3, 2)
|
|
4559
|
+
VS4 = self.backend.bk_concat(VS4, 2)
|
|
4560
|
+
if cross:
|
|
4561
|
+
VS3P = self.backend.bk_concat(VS3P, 2)
|
|
4562
|
+
if calc_var:
|
|
4563
|
+
if not cross:
|
|
4564
|
+
return scat_cov(
|
|
4565
|
+
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
4566
|
+
), scat_cov(
|
|
4567
|
+
vs0,
|
|
4568
|
+
VS2,
|
|
4569
|
+
VS3,
|
|
4570
|
+
VS4,
|
|
4571
|
+
s1=VS1,
|
|
4572
|
+
backend=self.backend,
|
|
4573
|
+
use_1D=self.use_1D,
|
|
4574
|
+
)
|
|
4575
|
+
else:
|
|
4576
|
+
return scat_cov(
|
|
4577
|
+
s0,
|
|
4578
|
+
S2,
|
|
4579
|
+
S3,
|
|
4580
|
+
S4,
|
|
4581
|
+
s1=S1,
|
|
4582
|
+
s3p=S3P,
|
|
4583
|
+
backend=self.backend,
|
|
4584
|
+
use_1D=self.use_1D,
|
|
4585
|
+
), scat_cov(
|
|
4586
|
+
vs0,
|
|
4587
|
+
VS2,
|
|
4588
|
+
VS3,
|
|
4589
|
+
VS4,
|
|
4590
|
+
s1=VS1,
|
|
4591
|
+
s3p=VS3P,
|
|
4592
|
+
backend=self.backend,
|
|
4593
|
+
use_1D=self.use_1D,
|
|
4594
|
+
)
|
|
4595
|
+
else:
|
|
4596
|
+
if not cross:
|
|
4597
|
+
return scat_cov(
|
|
4598
|
+
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
4599
|
+
)
|
|
4600
|
+
else:
|
|
4601
|
+
return scat_cov(
|
|
4602
|
+
s0,
|
|
4603
|
+
S2,
|
|
4604
|
+
S3,
|
|
4605
|
+
S4,
|
|
4606
|
+
s1=S1,
|
|
4607
|
+
s3p=S3P,
|
|
4608
|
+
backend=self.backend,
|
|
4609
|
+
use_1D=self.use_1D,
|
|
4610
|
+
)
|
|
4611
|
+
def clean_norm(self):
    """Forget the normalisation references stored on this instance.

    Resets both ``self.P1_dic`` and ``self.P2_dic`` to ``None``.  These are
    the attributes that the ``norm == "auto"`` path fills in when
    ``self.P1_dic`` is still ``None``, so clearing them means the next such
    call stores fresh values.

    Returns
    -------
    None
    """
    self.P1_dic = self.P2_dic = None
|
|
4615
|
+
|
|
4616
|
+
def _compute_S3(
|
|
4617
|
+
self,
|
|
4618
|
+
j2,
|
|
4619
|
+
j3,
|
|
4620
|
+
conv,
|
|
4621
|
+
vmask,
|
|
4622
|
+
M_dic,
|
|
4623
|
+
MconvPsi_dic,
|
|
4624
|
+
calc_var=False,
|
|
4625
|
+
return_data=False,
|
|
4626
|
+
cmat2=None,
|
|
4627
|
+
):
|
|
4628
|
+
"""
|
|
4629
|
+
Compute the S3 coefficients (auto or cross)
|
|
4630
|
+
S3 = < (Ia * Psi)_j3 x (|Ib * Psi_j2| * Psi_j3)^* >_pix
|
|
4631
|
+
Parameters
|
|
4632
|
+
----------
|
|
4633
|
+
Returns
|
|
4634
|
+
-------
|
|
4635
|
+
cs3, ss3: real and imag parts of S3 coeff
|
|
4636
|
+
"""
|
|
4637
|
+
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
4638
|
+
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
4639
|
+
MconvPsi = self.convol(
|
|
4640
|
+
M_dic[j2], axis=1
|
|
4641
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4642
|
+
if cmat2 is not None:
|
|
4643
|
+
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
|
|
4644
|
+
MconvPsi = self.backend.bk_reduce_sum(
|
|
4645
|
+
self.backend.bk_reshape(
|
|
4646
|
+
cmat2[j3][j2] * tmp2,
|
|
4647
|
+
[
|
|
4648
|
+
tmp2.shape[0],
|
|
4649
|
+
cmat2[j3].shape[1],
|
|
4650
|
+
self.NORIENT,
|
|
4651
|
+
self.NORIENT,
|
|
4652
|
+
self.NORIENT,
|
|
4653
|
+
],
|
|
4654
|
+
),
|
|
4655
|
+
3,
|
|
4656
|
+
)
|
|
4657
|
+
|
|
4658
|
+
# Store it so we can use it in S4 computation
|
|
4659
|
+
MconvPsi_dic[j2] = MconvPsi # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4660
|
+
|
|
4661
|
+
### Compute the product (I2 * Psi)_j3 x (M1_j2 * Psi_j3)^*
|
|
4662
|
+
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
4663
|
+
# cconv, sconv are [Nbatch, Npix_j3, Norient3]
|
|
4664
|
+
if self.use_1D:
|
|
4665
|
+
s3 = conv * self.backend.bk_conjugate(MconvPsi)
|
|
4666
|
+
else:
|
|
4667
|
+
s3 = self.backend.bk_expand_dims(conv, -1) * self.backend.bk_conjugate(
|
|
4668
|
+
MconvPsi
|
|
4669
|
+
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
4670
|
+
|
|
4671
|
+
### Apply the mask [Nmask, Npix_j3] and sum over pixels
|
|
4672
|
+
if return_data:
|
|
4673
|
+
return s3
|
|
4674
|
+
else:
|
|
4675
|
+
if calc_var:
|
|
4676
|
+
s3, vs3 = self.masked_mean(
|
|
4677
|
+
s3, vmask, axis=1, rank=j2, calc_var=True
|
|
4678
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
4679
|
+
return s3, vs3
|
|
4680
|
+
else:
|
|
4681
|
+
s3 = self.masked_mean(
|
|
3522
4682
|
s3, vmask, axis=1, rank=j2
|
|
3523
4683
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3524
4684
|
return s3
|
|
@@ -3565,7 +4725,591 @@ class funct(FOC.FoCUS):
|
|
|
3565
4725
|
s4, vmask, axis=1, rank=j2
|
|
3566
4726
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3567
4727
|
return s4
|
|
4728
|
+
|
|
4729
|
+
def computer_filter(self, M, N, J, L):
    """Build a bank of ``J * L`` oriented band-pass filters in Fourier space.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    For every scale ``j`` and orientation ``l`` an anisotropic Gaussian
    envelope (width ``sigma = 0.8 * 2**j``, anisotropy ``slant = 4 / L``) is
    modulated by a complex plane wave of frequency ``xi = 3*pi/4 / 2**j``.
    ``K`` times the bare envelope is subtracted so that the spatial sum of
    the filter is zero, the result is Fourier transformed with
    ``np.fft.fft2`` and its ``[0, 0]`` coefficient is forced to zero.

    Parameters
    ----------
    M, N : int
        Size of each filter along its two axes.
    J : int
        Number of scales.
    L : int
        Number of orientations (must be non-zero, it is used as a divisor).

    Returns
    -------
    The ``[J, L, M, N]`` ``complex64`` array passed through
    ``self.backend.bk_cast``.
    """
    # Named ``bank`` rather than ``filter`` to avoid shadowing the builtin.
    bank = np.zeros([J, L, M, N], dtype="complex64")

    slant = 4.0 / L

    # The coordinate grids do not depend on (j, l): build them once instead
    # of once per filter.  The four (ex, ey) combinations are the grid and
    # its copies shifted by -M and/or -N; summing over them below wraps the
    # envelope around the array.
    xx = np.empty((2, 2, M, N))
    yy = np.empty((2, 2, M, N))
    for ii, ex in enumerate([-1, 0]):
        for jj, ey in enumerate([-1, 0]):
            xx[ii, jj], yy[ii, jj] = np.mgrid[
                ex * M : M + ex * M,
                ey * N : N + ey * N,
            ]

    # The anisotropy matrix only depends on L.
    D = np.array([[1, 0], [0, slant * slant]])

    for j in range(J):
        for l_idx in range(L):
            theta = (int(L - L / 2 - 1) - l_idx) * np.pi / L
            sigma = 0.8 * 2**j
            xi = 3.0 / 4.0 * np.pi / 2**j

            # Rotate the anisotropic quadratic form to orientation theta.
            R = np.array(
                [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]],
                np.float64,
            )
            R_inv = np.array(
                [[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]],
                np.float64,
            )
            curv = np.matmul(R, np.matmul(D, R_inv)) / (2 * sigma * sigma)

            # Gaussian exponent, and the same exponent with the plane wave.
            arg = -(
                curv[0, 0] * xx * xx
                + (curv[0, 1] + curv[1, 0]) * xx * yy
                + curv[1, 1] * yy * yy
            )
            argi = arg + 1.0j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))

            # Sum the four shifted copies.
            gabi = np.exp(argi).sum((0, 1))
            gab = np.exp(arg).sum((0, 1))

            norm_factor = 2 * np.pi * sigma * sigma / slant

            gab = gab / norm_factor
            gabi = gabi / norm_factor

            # K makes the spatial sum of (gabi - K * gab) vanish.
            K = gabi.sum() / gab.sum()

            bank[j, l_idx] = np.fft.fft2(gabi - K * gab)
            bank[j, l_idx, 0, 0] = 0.0

    return self.backend.bk_cast(bank)
|
|
4781
|
+
|
|
4782
|
+
# ------------------------------------------------------------------------------------------
|
|
4783
|
+
#
|
|
4784
|
+
# utility functions
|
|
4785
|
+
#
|
|
4786
|
+
# ------------------------------------------------------------------------------------------
|
|
4787
|
+
def cut_high_k_off(self, data_f, dx, dy):
    """Keep only the four corner blocks of the last two axes of ``data_f``.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    Along axis -2 the first ``dx`` and the last ``dx`` entries are kept,
    along axis -1 the first ``dy`` and the last ``dy``; when an axis has odd
    length its leading part keeps one extra entry.  The pieces are glued
    back together with ``self.backend.backend.cat``.
    NOTE(review): presumably ``data_f`` is an un-shifted FFT, so the corners
    are the low-|k| modes -- confirm against the callers.

    Parameters
    ----------
    data_f :
        Array/tensor sliceable with ``...`` and exposing ``.shape``.
    dx, dy : int
        Number of entries kept on each side of axes -2 and -1.

    Returns
    -------
    The concatenated array, with last two axes of length
    ``2*dx + (odd x)`` and ``2*dy + (odd y)``.
    """
    cat = self.backend.backend.cat

    # bool + int -> int: odd-sized axes keep one more leading entry
    lead_x = dx + (data_f.shape[-2] % 2 == 1)
    lead_y = dy + (data_f.shape[-1] % 2 == 1)

    # Stack top/bottom rows for the leading and trailing column bands ...
    leading_cols = cat(
        (data_f[..., :lead_x, :lead_y], data_f[..., -dx:, :lead_y]),
        -2,
    )
    trailing_cols = cat(
        (data_f[..., :lead_x, -dy:], data_f[..., -dx:, -dy:]),
        -2,
    )
    # ... then put the two column bands side by side.
    return cat((leading_cols, trailing_cols), -1)
|
|
4803
|
+
# ---------------------------------------------------------------------------
|
|
4804
|
+
#
|
|
4805
|
+
# utility functions for computing scattering coef and covariance
|
|
4806
|
+
#
|
|
4807
|
+
# ---------------------------------------------------------------------------
|
|
4808
|
+
|
|
4809
|
+
def get_dxdy(self, j, M, N):
    """Return the half-widths ``(dx, dy)`` associated with scale ``j``.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    For each axis of length ``size`` the value is
    ``max(8, min(ceil(size / 2**j), size // 2))`` converted to ``int``.
    Note that the lower bound of 8 is applied last, so it wins even when
    ``size // 2`` is smaller than 8.

    Parameters
    ----------
    j : int
        Scale index.
    M, N : int
        Axis lengths.

    Returns
    -------
    tuple of int
        ``(dx, dy)``.
    """
    downscale = 2**j

    def _half_width(size):
        # shrink with the scale, cap at half the axis, never go below 8
        capped = min(np.ceil(size / downscale), size // 2)
        return int(max(8, capped))

    return _half_width(M), _half_width(N)
|
|
4817
|
+
|
|
4818
|
+
|
|
4819
|
+
|
|
4820
|
+
def get_edge_masks(self, M, N, J, d0=1):
    """Build one border-excluding mask per scale.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    by Sihao Cheng and Rudy Morel.

    For scale ``j`` the margin is ``2**j * d0``, capped at a quarter of the
    axis length.  A pixel ``(x, y)`` is kept when
    ``margin_x <= x <= M - margin_x`` and ``margin_y <= y <= N - margin_y``.

    Parameters
    ----------
    M, N : int
        Mask size along the two axes.
    J : int
        Number of scales (number of masks).
    d0 : int, optional
        Margin at scale 0.

    Returns
    -------
    The ``(J, M, N)`` array created with ``self.backend.backend.empty`` and
    moved with ``.to(self.backend.torch_device)``.
    """
    xp = self.backend.backend

    masks = xp.empty((J, M, N))
    rows, cols = xp.meshgrid(xp.arange(M), xp.arange(N), indexing='ij')

    for scale in range(J):
        margin_x = min(M // 4, 2**scale * d0)
        margin_y = min(N // 4, 2**scale * d0)

        # product of boolean arrays acts as a logical AND
        keep_rows = (rows >= margin_x) * (rows <= M - margin_x)
        keep_cols = (cols >= margin_y) * (cols <= N - margin_y)
        masks[scale] = keep_rows * keep_cols

    return masks.to(self.backend.torch_device)
|
|
4832
|
+
|
|
4833
|
+
# ---------------------------------------------------------------------------
|
|
4834
|
+
#
|
|
4835
|
+
# scattering cov
|
|
4836
|
+
#
|
|
4837
|
+
# ---------------------------------------------------------------------------
|
|
4838
|
+
def scattering_cov(
    self, data, Jmax=None,
    if_large_batch=False,
    S4_criteria=None,
    use_ref=False,
    normalization='S2',
    edge=False,
    pseudo_coef=1,
    get_variance=False,
    ref_sigma=None,
    iso_ang=False
):
    '''
    Compute scattering covariance statistics of 2D images and return them
    as one flattened feature vector per image.

    This function is strongly inspired by the package
    https://github.com/SihaoCheng/scattering_transform
    Done by Sihao Cheng and Rudy Morel.

    With I1[j, l] = |ifft(fft(I) * psi[j, l])| the modulus fields and
    fft_factor the correction for the reduced Fourier window:

        S2[j, l] = mean(I1[j, l]**2)
        S1[j, l] = mean(I1[j, l])
        S3 = mean(fft(I) * conj(fft(I1[j2]) * psi[j3])) * fft_factor / factor
             for j2 <= j3, stored per (l2, l3)
        S4 = mean(fft(I1[j1]) * conj(fft(I1[j2]) * psi[j3]**2)) * fft_factor / P
             for j1 <= j2 <= j3 (and S4_criteria), stored per (l1, l2, l3)

    When edge is True the same products are evaluated in pixel space on
    the down-sampled fields with a border of edge_dx / edge_dy pixels
    removed before averaging.

    Parameters
    ----------
    data : numpy array or torch tensor
        Single image [x, y] or image set [N_image, x, y].
    Jmax : int or None
        Only used to print a warning when larger than the number of
        scales J derived from the image size; J itself is always
        int(log2(min(x, y))) - 1.
    if_large_batch : bool (=False)
        Loop over the first orientation when computing S4 to reduce peak
        memory. Results are the same as with False.
    S4_criteria : str or None (=None)
        Python expression in j1, j2, j3 selecting which S4 terms are
        computed (evaluated with eval). None means 'j2>=j1'.
    use_ref : bool (=False)
        Normalise with the S2 stored by the last call made with
        use_ref=False (self.ref_scattering_cov_S2) instead of the S2 of
        `data`. Recommended during synthesis.
    normalization : 'S2' or 'P11' (='S2')
        Normalisation used for S3. S4 is only normalised (and only
        non-zero in the output) for 'S2'.
    edge : bool (=False)
        Exclude the image border when averaging.
    pseudo_coef : float (=1)
        Exponent applied inside the normalisation factors.
    get_variance : bool (=False)
        Also compute the spatial standard deviation of every statistic;
        these are used as ref_sigma and returned.
    ref_sigma : dict or None
        Dictionary with keys 'std_data', 'S1_sigma', 'S2_sigma',
        'S3_sigma', 'S4_sigma' used to rescale the statistics. Ignored
        (recomputed) when get_variance is True.
    iso_ang : bool (=False)
        Average S1/S2 over orientation and S3/S4 over l1 at fixed
        (l2 - l1), (l3 - l1) modulo L.

    Returns
    -------
    for_synthesis : tensor [N_image, n_features]
        Concatenation of mean, std, log(S2), log(S1), real and imaginary
        parts of S3 and S4 (each optionally divided by ref_sigma).
    ref_sigma : dict
        Only returned when get_variance is True.

    Raises
    ------
    ValueError
        If normalization is neither 'S2' nor 'P11'.
    NotImplementedError
        If the operator is not configured for 2D data.
    '''
    if S4_criteria is None:
        S4_criteria = 'j2>=j1'

    # Any other value used to fail later with an unbound norm_factor_S3.
    if normalization not in ('S2', 'P11'):
        raise ValueError(
            "normalization must be 'S2' or 'P11', got %r" % (normalization,)
        )

    # The 1D / HEALPix paths never defined M and N, so they always crashed
    # with a NameError at the filter lookup; fail explicitly instead.
    if not self.use_2D:
        raise NotImplementedError(
            'scattering_cov is only implemented for 2D data (use_2D=True)'
        )

    bk = self.backend
    conj = bk.bk_conjugate

    # determine the image size and the number of scales
    im_shape = data.shape
    if len(im_shape) == 2:
        M, N = im_shape[0], im_shape[1]
        N_image = 1
    else:
        M, N = im_shape[1], im_shape[2]
        N_image = im_shape[0]
    nside = np.min([M, N])
    J = int(np.log(nside) / np.log(2)) - 1  # Number of j scales

    if Jmax is None:
        Jmax = J
    if Jmax > J:
        print('==========\n\n')
        print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
        print('\n\n==========')

    L = self.NORIENT

    # wavelet filters are cached per (image size, scales, orientations)
    if (M, N, J, L) not in self.filters_set:
        self.filters_set[(M, N, J, L)] = self.computer_filter(M, N, J, L)
    filters_set = self.filters_set[(M, N, J, L)]

    if use_ref:
        # ref_S2 is needed by both normalisations (the 'P11' factor also
        # multiplies by ref_S2); it was previously only loaded for 'S2'.
        ref_S2 = self.ref_scattering_cov_S2
        if normalization == 'P11':
            ref_P11 = self.ref_scattering_cov['P11']

    # convert the input into a backend tensor and go to Fourier space
    data = bk.bk_cast(data)
    data_f = bk.bk_fftn(data, dim=(-2, -1))

    # storage for the cross statistics (S1 and S2 are computed directly)
    n_S3 = J * (J + 1) // 2
    n_S4 = J * (J + 1) * (J + 2) // 6
    S3 = bk.bk_zeros((N_image, n_S3, L, L), dtype=data_f.dtype)
    S4_pre_norm = bk.bk_zeros((N_image, n_S4, L, L, L), dtype=data_f.dtype)
    S4 = bk.bk_zeros((N_image, n_S4, L, L, L), dtype=data_f.dtype)
    if get_variance:
        S3_sigma = bk.bk_zeros((N_image, n_S3, L, L), dtype=data_f.dtype)
        S4_sigma = bk.bk_zeros((N_image, n_S4, L, L, L), dtype=data_f.dtype)
    if iso_ang:
        S3_iso = bk.bk_zeros((N_image, n_S3, L), dtype=data_f.dtype)
        S4_iso = bk.bk_zeros((N_image, n_S4, L, L), dtype=data_f.dtype)
        if get_variance:
            S3_sigma_iso = bk.bk_zeros((N_image, n_S3, L), dtype=data_f.dtype)
            S4_sigma_iso = bk.bk_zeros((N_image, n_S4, L, L), dtype=data_f.dtype)

    # calculate scattering fields: modulus of the wavelet-filtered images
    if len(data.shape) == 2:
        spectrum = data_f[None, None, None, :, :]
    else:
        spectrum = data_f[:, None, None, :, :]
    I1 = bk.bk_ifftn(
        spectrum * filters_set[None, :J, :, :, :], dim=(-2, -1)
    ).abs()
    I1_f = bk.bk_fftn(I1, dim=(-2, -1))

    if edge:
        if (M, N, J) not in self.edge_masks:
            self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
        edge_mask = self.edge_masks[(M, N, J)][:, None, :, :]
        # rescale so that the masked mean is a mean over kept pixels
        edge_mask = edge_mask / edge_mask.mean((-2, -1))[:, :, None, None]
    else:
        edge_mask = 1

    S2 = (I1**2 * edge_mask).mean((-2, -1))
    S1 = (I1 * edge_mask).mean((-2, -1))
    if get_variance:
        S2_sigma = (I1**2 * edge_mask).std((-2, -1))
        S1_sigma = (I1 * edge_mask).std((-2, -1))

    if pseudo_coef != 1:
        # NOTE(review): I1_f has already been computed above and I1 is not
        # read again below, so this only changes the local I1 -- confirm
        # whether the power was meant to be applied before the FFT.
        I1 = I1**pseudo_coef

    idx_S3 = 0
    idx_S4 = 0

    # calculate the covariance and correlations of the scattering fields
    # only use the low-k Fourier coefs when calculating large-j scattering coefs.
    for j3 in range(J):
        dx3, dy3 = self.get_dxdy(j3, M, N)
        I1_f_small = self.cut_high_k_off(I1_f[:, :j3 + 1], dx3, dy3)  # Nimage, J, L, x, y
        data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
        if edge:
            I1_small = bk.bk_ifftn(I1_f_small, dim=(-2, -1), norm='ortho')
            data_small = bk.bk_ifftn(data_f_small, dim=(-2, -1), norm='ortho')
        wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3)  # L,x,y
        _, M3, N3 = wavelet_f3.shape
        wavelet_f3_squared = wavelet_f3**2
        edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
        edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
        # Explicit upper bounds: the former "a:-a" form used for S4 selects
        # nothing when a == 0 (possible for odd image sizes).
        keep_x = slice(edge_dx, M3 - edge_dx)
        keep_y = slice(edge_dy, N3 - edge_dy)
        # a normalization change due to the cutoff of frequency space
        fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N)**2

        for j2 in range(j3 + 1):
            I1_f2_wf3_small = (
                I1_f_small[:, j2].view(N_image, L, 1, M3, N3)
                * wavelet_f3.view(1, 1, L, M3, N3)
            )
            I1_f2_wf3_2_small = (
                I1_f_small[:, j2].view(N_image, L, 1, M3, N3)
                * wavelet_f3_squared.view(1, 1, L, M3, N3)
            )
            if edge:
                I12_w3_small = bk.bk_ifftn(I1_f2_wf3_small, dim=(-2, -1), norm='ortho')
                I12_w3_2_small = bk.bk_ifftn(I1_f2_wf3_2_small, dim=(-2, -1), norm='ortho')

            # normalisation of S3, indexed [N_image, l2, l3]
            if use_ref:
                if normalization == 'P11':
                    norm_factor_S3 = (
                        ref_S2[:, None, j3, :] * ref_P11[:, j2, j3, :, :]**pseudo_coef
                    )**0.5
                else:
                    norm_factor_S3 = (
                        ref_S2[:, None, j3, :] * ref_S2[:, j2, :, None]**pseudo_coef
                    )**0.5
            elif normalization == 'P11':
                P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2, -1)) * fft_factor
                norm_factor_S3 = (S2[:, None, j3, :] * P11_temp**pseudo_coef)**0.5
            else:
                norm_factor_S3 = (
                    S2[:, None, j3, :] * S2[:, j2, :, None]**pseudo_coef
                )**0.5

            # S3: image x modulus field, [N_image, l2, l3, x, y]
            if edge:
                prod3 = (
                    data_small.view(N_image, 1, 1, M3, N3) * conj(I12_w3_small)
                )[..., keep_x, keep_y]
            else:
                prod3 = data_f_small.view(N_image, 1, 1, M3, N3) * conj(I1_f2_wf3_small)
            S3[:, idx_S3, :, :] = prod3.mean((-2, -1)) * fft_factor / norm_factor_S3
            if get_variance:
                S3_sigma[:, idx_S3, :, :] = (
                    prod3.std((-2, -1)) * fft_factor / norm_factor_S3
                )
            idx_S3 += 1

            # S4: modulus x modulus
            if edge:
                field_a, field_b = I1_small, I12_w3_2_small
            else:
                field_a, field_b = I1_f_small, I1_f2_wf3_2_small

            beg_n = idx_S4
            for j1 in range(j2 + 1):
                # SECURITY: S4_criteria is evaluated with eval(); it must
                # never come from untrusted input.
                if not eval(S4_criteria):
                    continue
                if not if_large_batch:
                    # [N_image,l1,l2,l3,x,y]
                    prod4 = (
                        field_a[:, j1].view(N_image, L, 1, 1, M3, N3)
                        * conj(field_b.view(N_image, 1, L, L, M3, N3))
                    )
                    if edge:
                        prod4 = prod4[..., keep_x, keep_y]
                    S4_pre_norm[:, idx_S4, :, :, :] = prod4.mean((-2, -1)) * fft_factor
                    if get_variance:
                        S4_sigma[:, idx_S4, :, :, :] = prod4.std((-2, -1)) * fft_factor
                else:
                    for l1 in range(L):
                        # [N_image,l2,l3,x,y]; one orientation l1 at a time
                        # (the edge branch used to omit the l1 index and to
                        # store a mean in S4_sigma).
                        prod4 = (
                            field_a[:, j1, l1].view(N_image, 1, 1, M3, N3)
                            * conj(field_b.view(N_image, L, L, M3, N3))
                        )
                        if edge:
                            prod4 = prod4[..., keep_x, keep_y]
                        S4_pre_norm[:, idx_S4, l1, :, :] = prod4.mean((-2, -1)) * fft_factor
                        if get_variance:
                            S4_sigma[:, idx_S4, l1, :, :] = prod4.std((-2, -1)) * fft_factor
                idx_S4 += 1

            # S4 is only normalised for 'S2'; with 'P11' it stays at zero.
            if normalization == 'S2':
                src_S2 = ref_S2 if use_ref else S2
                # NOTE(review): the factor uses scales (j3, j2) on the
                # (l1, l2) slots, as in the previous code -- confirm this
                # is intended rather than (j1, j2).
                P = (
                    src_S2[:, j3, :, None, None] * src_S2[:, j2, None, :, None]
                )**(0.5 * pseudo_coef)
                # Add the coefficient axis so that P broadcasts over the
                # j1 entries instead of being matched against N_image.
                P = P[:, None]
                S4[:, beg_n:idx_S4, :, :, :] = S4_pre_norm[:, beg_n:idx_S4, :, :, :] / P
                if get_variance:
                    S4_sigma[:, beg_n:idx_S4, :, :, :] = (
                        S4_sigma[:, beg_n:idx_S4, :, :, :] / P
                    )

    # average over l1 to obtain simple isotropic statistics
    if iso_ang:
        S2_iso = S2.mean(-1)
        S1_iso = S1.mean(-1)
        to_average = [(S3, S4, S3_iso, S4_iso)]
        if get_variance:
            S2_sigma_iso = S2_sigma.mean(-1)
            S1_sigma_iso = S1_sigma.mean(-1)
            to_average.append((S3_sigma, S4_sigma, S3_sigma_iso, S4_sigma_iso))
        for full3, full4, iso3, iso4 in to_average:
            for l1 in range(L):
                for l2 in range(L):
                    iso3[..., (l2 - l1) % L] += full3[..., l1, l2]
                    for l3 in range(L):
                        iso4[..., (l2 - l1) % L, (l3 - l1) % L] += full4[..., l1, l2, l3]
            iso3 /= L
            iso4 /= L

    mean_data = bk.bk_zeros((N_image, 1), dtype=data.dtype)
    std_data = bk.bk_zeros((N_image, 1), dtype=data.dtype)
    mean_data[:, 0] = data.mean((-2, -1))
    std_data[:, 0] = data.std((-2, -1))

    if get_variance:
        # the freshly computed deviations replace any ref_sigma argument
        ref_sigma = {}
        ref_sigma['std_data'] = std_data
        if iso_ang:
            ref_sigma['S1_sigma'] = S1_sigma_iso
            ref_sigma['S2_sigma'] = S2_sigma_iso
            ref_sigma['S3_sigma'] = S3_sigma_iso
            ref_sigma['S4_sigma'] = S4_sigma_iso
        else:
            ref_sigma['S1_sigma'] = S1_sigma
            ref_sigma['S2_sigma'] = S2_sigma
            ref_sigma['S3_sigma'] = S3_sigma
            ref_sigma['S4_sigma'] = S4_sigma

    # get a single, flattened data vector for synthesis
    if iso_ang:
        out_S1, out_S2, out_S3, out_S4 = S1_iso, S2_iso, S3_iso, S4_iso
    else:
        out_S1, out_S2, out_S3, out_S4 = S1, S2, S3, S4

    if ref_sigma is not None:
        scaled_S3 = (out_S3 / ref_sigma['S3_sigma']).reshape((N_image, -1))
        scaled_S4 = (out_S4 / ref_sigma['S4_sigma']).reshape((N_image, -1))
        pieces = (
            mean_data / ref_sigma['std_data'],
            std_data / ref_sigma['std_data'],
            (out_S2 / ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
            (out_S1 / ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
            scaled_S3.real,
            scaled_S3.imag,
            scaled_S4.real,
            scaled_S4.imag,
        )
    else:
        flat_S3 = out_S3.reshape((N_image, -1))
        flat_S4 = out_S4.reshape((N_image, -1))
        pieces = (
            mean_data / std_data,
            std_data,
            out_S2.reshape((N_image, -1)).log(),
            out_S1.reshape((N_image, -1)).log(),
            flat_S3.real,
            flat_S3.imag,
            flat_S4.real,
            flat_S4.imag,
        )
    for_synthesis = bk.backend.cat(pieces, dim=-1)

    if not use_ref:
        # remembered so that later use_ref=True calls can normalise with it
        self.ref_scattering_cov_S2 = S2

    if get_variance:
        return for_synthesis, ref_sigma

    return for_synthesis
|
|
5290
|
+
|
|
5291
|
+
|
|
5292
|
+
def to_gaussian(self, x):
    '''
    Map the values of x onto a standard normal distribution by rank.

    The k-th smallest of the n values is replaced by the normal quantile
    at probability (k - 0.5) / n. The inverse mapping is stored for
    from_gaussian:

    - self.f_gaussian : cubic interpolator from Gaussian values back to
      the original values,
    - self.val_min / self.val_max : smallest and largest Gaussian values,
      i.e. the valid input range of self.f_gaussian.

    Parameters
    ----------
    x : array
        Input values (any shape). Integer and boolean NumPy arrays are
        converted to float64 so that the quantiles are not truncated.

    Returns
    -------
    array with the shape of x holding the Gaussianised values.
    '''
    from scipy.stats import norm
    from scipy.interpolate import interp1d

    source = x.flatten()
    order = np.argsort(source)
    n_values = order.shape[0]
    p = (np.arange(1, n_values + 1) - 0.5) / n_values

    im_target = x.flatten()
    if isinstance(im_target, np.ndarray) and not np.issubdtype(
        im_target.dtype, np.inexact
    ):
        # Writing float quantiles into an integer buffer would truncate
        # them and produce duplicated interpolation knots.
        im_target = im_target.astype(np.float64)
    im_target[order] = norm.ppf(p)

    # Cubic interpolation from Gaussian values back to the original ones
    self.f_gaussian = interp1d(im_target[order], source[order], kind='cubic')
    self.val_min = im_target[order[0]]
    self.val_max = im_target[order[-1]]
    return im_target.reshape(x.shape)
|
|
3568
5306
|
|
|
5307
|
+
|
|
5308
|
+
def from_gaussian(self, x):
    '''
    Apply the stored interpolator self.f_gaussian to x.

    x is first clipped to [self.val_min, self.val_max] with the backend,
    then converted with self.backend.to_numpy before being interpolated.
    Relies on attributes that to_gaussian sets on the instance.
    '''
    clipped = self.backend.bk_clip_by_value(x, self.val_min, self.val_max)
    as_numpy = self.backend.to_numpy(clipped)
    return self.f_gaussian(as_numpy)
|
|
5312
|
+
|
|
3569
5313
|
def square(self, x):
|
|
3570
5314
|
if isinstance(x, scat_cov):
|
|
3571
5315
|
if x.S1 is None:
|
|
@@ -3615,42 +5359,47 @@ class funct(FOC.FoCUS):
|
|
|
3615
5359
|
return self.backend.bk_abs(self.backend.bk_sqrt(x))
|
|
3616
5360
|
|
|
3617
5361
|
def reduce_mean(self, x):
    '''
    Return the mean absolute value over all coefficients of a scat_cov
    (S0, S2, S3, S4, plus S1 and S3P when present). For any other input,
    return the backend mean of x along axis 0.
    '''
    if not isinstance(x, scat_cov):
        return self.backend.bk_reduce_mean(x, axis=0)

    # Mandatory fields first, then the optional ones, in the same order
    # as they are accumulated.
    fields = [x.S0, x.S2, x.S3, x.S4]
    if x.S1 is not None:
        fields.append(x.S1)
    if x.S3P is not None:
        fields.append(x.S3P)

    result = self.backend.bk_reduce_sum(self.backend.bk_abs(fields[0]))
    N = self.backend.bk_size(fields[0])
    for coef in fields[1:]:
        result = result + self.backend.bk_reduce_sum(self.backend.bk_abs(coef))
        N = N + self.backend.bk_size(coef)

    return result / self.backend.bk_cast(N)
|
|
3637
|
-
|
|
3638
5387
|
|
|
3639
5388
|
def reduce_mean_batch(self, x):
|
|
3640
|
-
|
|
5389
|
+
|
|
3641
5390
|
if isinstance(x, scat_cov):
|
|
3642
|
-
|
|
3643
|
-
sS0=self.backend.bk_reduce_mean(x.S0, axis=0)
|
|
3644
|
-
sS2=self.backend.bk_reduce_mean(x.S2, axis=0)
|
|
3645
|
-
sS3=self.backend.bk_reduce_mean(x.S3, axis=0)
|
|
3646
|
-
sS4=self.backend.bk_reduce_mean(x.S4, axis=0)
|
|
3647
|
-
sS1=None
|
|
3648
|
-
sS3P=None
|
|
5391
|
+
|
|
5392
|
+
sS0 = self.backend.bk_reduce_mean(x.S0, axis=0)
|
|
5393
|
+
sS2 = self.backend.bk_reduce_mean(x.S2, axis=0)
|
|
5394
|
+
sS3 = self.backend.bk_reduce_mean(x.S3, axis=0)
|
|
5395
|
+
sS4 = self.backend.bk_reduce_mean(x.S4, axis=0)
|
|
5396
|
+
sS1 = None
|
|
5397
|
+
sS3P = None
|
|
3649
5398
|
if x.S1 is not None:
|
|
3650
5399
|
sS1 = self.backend.bk_reduce_mean(x.S1, axis=0)
|
|
3651
5400
|
if x.S3P is not None:
|
|
3652
5401
|
sS3P = self.backend.bk_reduce_mean(x.S3P, axis=0)
|
|
3653
|
-
|
|
5402
|
+
|
|
3654
5403
|
result = scat_cov(
|
|
3655
5404
|
sS0,
|
|
3656
5405
|
sS2,
|
|
@@ -3664,22 +5413,22 @@ class funct(FOC.FoCUS):
|
|
|
3664
5413
|
return result
|
|
3665
5414
|
else:
|
|
3666
5415
|
return self.backend.bk_reduce_mean(x, axis=0)
|
|
3667
|
-
|
|
5416
|
+
|
|
3668
5417
|
def reduce_sum_batch(self, x):
|
|
3669
|
-
|
|
5418
|
+
|
|
3670
5419
|
if isinstance(x, scat_cov):
|
|
3671
|
-
|
|
3672
|
-
sS0=self.backend.bk_reduce_sum(x.S0, axis=0)
|
|
3673
|
-
sS2=self.backend.bk_reduce_sum(x.S2, axis=0)
|
|
3674
|
-
sS3=self.backend.bk_reduce_sum(x.S3, axis=0)
|
|
3675
|
-
sS4=self.backend.bk_reduce_sum(x.S4, axis=0)
|
|
3676
|
-
sS1=None
|
|
3677
|
-
sS3P=None
|
|
5420
|
+
|
|
5421
|
+
sS0 = self.backend.bk_reduce_sum(x.S0, axis=0)
|
|
5422
|
+
sS2 = self.backend.bk_reduce_sum(x.S2, axis=0)
|
|
5423
|
+
sS3 = self.backend.bk_reduce_sum(x.S3, axis=0)
|
|
5424
|
+
sS4 = self.backend.bk_reduce_sum(x.S4, axis=0)
|
|
5425
|
+
sS1 = None
|
|
5426
|
+
sS3P = None
|
|
3678
5427
|
if x.S1 is not None:
|
|
3679
5428
|
sS1 = self.backend.bk_reduce_sum(x.S1, axis=0)
|
|
3680
5429
|
if x.S3P is not None:
|
|
3681
5430
|
sS3P = self.backend.bk_reduce_sum(x.S3P, axis=0)
|
|
3682
|
-
|
|
5431
|
+
|
|
3683
5432
|
result = scat_cov(
|
|
3684
5433
|
sS0,
|
|
3685
5434
|
sS2,
|
|
@@ -3693,7 +5442,7 @@ class funct(FOC.FoCUS):
|
|
|
3693
5442
|
return result
|
|
3694
5443
|
else:
|
|
3695
5444
|
return self.backend.bk_reduce_mean(x, axis=0)
|
|
3696
|
-
|
|
5445
|
+
|
|
3697
5446
|
def reduce_distance(self, x, y, sigma=None):
|
|
3698
5447
|
|
|
3699
5448
|
if isinstance(x, scat_cov):
|
|
@@ -3729,11 +5478,13 @@ class funct(FOC.FoCUS):
|
|
|
3729
5478
|
return result
|
|
3730
5479
|
else:
|
|
3731
5480
|
if sigma is None:
|
|
3732
|
-
tmp=x-y
|
|
5481
|
+
tmp = x - y
|
|
3733
5482
|
else:
|
|
3734
|
-
tmp=(x-y)/sigma
|
|
5483
|
+
tmp = (x - y) / sigma
|
|
3735
5484
|
# do abs in case of complex values
|
|
3736
|
-
return self.backend.bk_abs(
|
|
5485
|
+
return self.backend.bk_abs(
|
|
5486
|
+
self.backend.bk_reduce_mean(self.backend.bk_square(tmp))
|
|
5487
|
+
)
|
|
3737
5488
|
|
|
3738
5489
|
def reduce_sum(self, x):
|
|
3739
5490
|
|
|
@@ -3909,3 +5660,133 @@ class funct(FOC.FoCUS):
|
|
|
3909
5660
|
return scat_cov(
|
|
3910
5661
|
s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
|
|
3911
5662
|
)
|
|
5663
|
+
|
|
5664
|
+
def synthesis(self,
              image_target,
              nstep=4,
              seed=1234,
              Jmax=None,
              edge=False,
              to_gaussian=True,
              use_variance=False,
              synthesised_N=1,
              iso_ang=False,
              EVAL_FREQUENCY=100,
              NUM_EPOCHS = 300):
    '''
    Synthesise maps whose scattering covariance statistics match those
    of image_target, refining the resolution over nstep steps.

    The target is down-graded nstep - 1 times with ud_grade_2. The
    synthesis starts from Gaussian white noise at the coarsest level and,
    after each minimisation, the result is up-graded and used as the
    starting point of the next level.

    Parameters
    ----------
    image_target : array
        Target map, or set of maps with a leading batch axis.
    nstep : int (=4)
        Number of resolution levels; reduced to jmax - 1 when larger,
        where jmax is derived from the size of the target.
    seed : int (=1234)
        Seed passed to np.random.seed before drawing the initial noise.
    Jmax, edge, iso_ang :
        Forwarded to scattering_cov.
    to_gaussian : bool (=True)
        Gaussianise the target with to_gaussian before the synthesis and
        map the result back with from_gaussian.
    use_variance : bool (=False)
        Compute the deviations of the target statistics
        (get_variance=True) and pass them as ref_sigma in the loss.
    synthesised_N : int (=1)
        Number of maps synthesised at once.
    EVAL_FREQUENCY, NUM_EPOCHS :
        Forwarded to Synthesis.run.

    Returns
    -------
    The synthesised map(s). The leading axis is dropped when a single
    un-batched target was given and synthesised_N == 1.

    Raises
    ------
    ValueError
        If no resolution level can be run (nstep < 1 after clamping).
    '''

    import foscat.Synthesis as synthe
    import time

    def The_loss(u, scat_operator, args):
        ref = args[0]
        sref = args[1]
        use_v = args[2]

        # compute scattering covariance of the current synthesised map called u
        if use_v:
            learn = scat_operator.reduce_mean_batch(
                scat_operator.scattering_cov(
                    u, edge=edge, Jmax=Jmax, ref_sigma=sref,
                    use_ref=True, iso_ang=iso_ang
                )
            )
        else:
            learn = scat_operator.reduce_mean_batch(
                scat_operator.scattering_cov(
                    u, edge=edge, Jmax=Jmax, use_ref=True, iso_ang=iso_ang
                )
            )

        # make the difference with the reference coordinates
        loss = scat_operator.backend.bk_reduce_mean(
            scat_operator.backend.bk_square((learn - ref))
        )
        return loss

    if to_gaussian:
        # Change the data histogram to gaussian distribution
        im_target = self.to_gaussian(image_target)
    else:
        im_target = image_target

    # axis == 0 means the target has no batch axis; add one
    axis = len(im_target.shape) - 1
    if self.use_2D:
        axis -= 1
    if axis == 0:
        im_target = self.backend.bk_expand_dims(im_target, 0)

    # compute the number of possible steps
    if self.use_2D:
        jmax = int(
            np.min([np.log(im_target.shape[1]), np.log(im_target.shape[2])])
            / np.log(2)
        )
    elif self.use_1D:
        jmax = int(np.log(im_target.shape[1]) / np.log(2))
    else:
        jmax = int((np.log(im_target.shape[1] // 12) / np.log(2)) / 2)
    nside = 2**jmax

    if nstep > jmax - 1:
        nstep = jmax - 1

    # Without at least one level the loop below never runs and no output
    # map exists (this used to surface as a NameError on omap).
    if nstep < 1:
        raise ValueError(
            'synthesis needs at least one resolution step '
            '(nstep=%d after clamping to jmax-1=%d)' % (nstep, jmax - 1)
        )

    t1 = time.perf_counter()

    # tmp[level] holds the target at each resolution, finest at nstep - 1
    tmp = {}
    tmp[nstep - 1] = im_target
    for level in range(nstep - 2, -1, -1):
        tmp[level] = self.ud_grade_2(tmp[level + 1], axis=1)

    if not self.use_2D and not self.use_1D:
        l_nside = nside // (2**(nstep - 1))

    for k in range(nstep):
        if k == 0:
            # start from white noise at the coarsest resolution
            np.random.seed(seed)
            if self.use_2D:
                imap = np.random.randn(synthesised_N,
                                       tmp[k].shape[1],
                                       tmp[k].shape[2])
            else:
                imap = np.random.randn(synthesised_N,
                                       tmp[k].shape[1])
        else:
            # Increase the resolution between each step
            if self.use_2D:
                imap = self.up_grade(
                    omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
                )
            elif self.use_1D:
                imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
            else:
                imap = self.up_grade(omap, l_nside, axis=1)

        # compute the coefficients for the target image
        if use_variance:
            ref, sref = self.scattering_cov(
                tmp[k], get_variance=True, edge=edge, Jmax=Jmax, iso_ang=iso_ang
            )
        else:
            ref = self.scattering_cov(tmp[k], edge=edge, Jmax=Jmax, iso_ang=iso_ang)
            sref = ref

        # compute the mean of the population; does nothing if only one map is given
        ref = self.reduce_mean_batch(ref)

        # define a loss to minimize
        loss = synthe.Loss(The_loss, self, ref, sref, use_variance)

        sy = synthe.Synthesis([loss])

        # report the resolution being synthesised
        if self.use_2D:
            print('Synthesis scale [ %d x %d ]'%(imap.shape[1],imap.shape[2]))
        elif self.use_1D:
            print('Synthesis scale [ %d ]'%(imap.shape[1]))
        else:
            print('Synthesis scale nside=%d'%(l_nside))
            l_nside *= 2

        # do the minimization
        omap = sy.run(imap,
                      EVAL_FREQUENCY=EVAL_FREQUENCY,
                      NUM_EPOCHS=NUM_EPOCHS)

    t2 = time.perf_counter()
    print('Total computation %.2fs'%(t2-t1))

    if to_gaussian:
        omap = self.from_gaussian(omap)

    if axis == 0 and synthesised_N == 1:
        return omap[0]
    return omap
|
|
5790
|
+
|
|
5791
|
+
def to_numpy(self, x):
    '''Return x converted by the backend's to_numpy method.'''
    convert = self.backend.to_numpy
    return convert(x)
|