foscat 2025.6.1__py3-none-any.whl → 2025.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkBase.py +3 -0
- foscat/BkNumpy.py +5 -0
- foscat/BkTensorflow.py +5 -0
- foscat/BkTorch.py +6 -0
- foscat/FoCUS.py +740 -639
- foscat/HealSpline.py +211 -0
- foscat/heal_NN.py +43 -24
- foscat/scat_cov.py +433 -277
- {foscat-2025.6.1.dist-info → foscat-2025.7.1.dist-info}/METADATA +1 -1
- {foscat-2025.6.1.dist-info → foscat-2025.7.1.dist-info}/RECORD +13 -12
- {foscat-2025.6.1.dist-info → foscat-2025.7.1.dist-info}/WHEEL +0 -0
- {foscat-2025.6.1.dist-info → foscat-2025.7.1.dist-info}/licenses/LICENSE +0 -0
- {foscat-2025.6.1.dist-info → foscat-2025.7.1.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -44,6 +44,12 @@ class scat_cov:
|
|
|
44
44
|
self.idx1 = None
|
|
45
45
|
self.idx2 = None
|
|
46
46
|
self.use_1D = use_1D
|
|
47
|
+
self.numel = self.backend.bk_len(s0)+ \
|
|
48
|
+
self.backend.bk_len(s1)+ \
|
|
49
|
+
self.backend.bk_len(s2)+ \
|
|
50
|
+
self.backend.bk_len(s3)+ \
|
|
51
|
+
self.backend.bk_len(s4)+ \
|
|
52
|
+
self.backend.bk_len(s3p)
|
|
47
53
|
|
|
48
54
|
def numpy(self):
|
|
49
55
|
if self.BACKEND == "numpy":
|
|
@@ -92,12 +98,7 @@ class scat_cov:
|
|
|
92
98
|
)
|
|
93
99
|
|
|
94
100
|
def conv2complex(self, val):
|
|
95
|
-
if (
|
|
96
|
-
val.dtype == "complex64"
|
|
97
|
-
or val.dtype == "complex128"
|
|
98
|
-
or val.dtype == "torch.complex64"
|
|
99
|
-
or val.dtype == "torch.complex128"
|
|
100
|
-
):
|
|
101
|
+
if self.backend.bk_is_complex(val):
|
|
101
102
|
return val
|
|
102
103
|
else:
|
|
103
104
|
return self.backend.bk_complex(val, 0 * val)
|
|
@@ -107,7 +108,7 @@ class scat_cov:
|
|
|
107
108
|
def flatten(self):
|
|
108
109
|
tmp = [
|
|
109
110
|
self.conv2complex(
|
|
110
|
-
self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]])
|
|
111
|
+
self.backend.bk_reshape(self.S0, [self.S1.shape[0], self.S0.shape[1]*self.S0.shape[2]])
|
|
111
112
|
)
|
|
112
113
|
]
|
|
113
114
|
if self.use_1D:
|
|
@@ -177,21 +178,9 @@ class scat_cov:
|
|
|
177
178
|
],
|
|
178
179
|
)
|
|
179
180
|
),
|
|
180
|
-
self.
|
|
181
|
-
self.S3,
|
|
182
|
-
[
|
|
183
|
-
self.S3.shape[0],
|
|
184
|
-
self.S3.shape[1]
|
|
185
|
-
* self.S3.shape[2]
|
|
186
|
-
* self.S3.shape[3]
|
|
187
|
-
* self.S3.shape[4],
|
|
188
|
-
],
|
|
189
|
-
),
|
|
190
|
-
]
|
|
191
|
-
if self.S3P is not None:
|
|
192
|
-
tmp = tmp + [
|
|
181
|
+
self.conv2complex(
|
|
193
182
|
self.backend.bk_reshape(
|
|
194
|
-
self.
|
|
183
|
+
self.S3,
|
|
195
184
|
[
|
|
196
185
|
self.S3.shape[0],
|
|
197
186
|
self.S3.shape[1]
|
|
@@ -200,22 +189,39 @@ class scat_cov:
|
|
|
200
189
|
* self.S3.shape[4],
|
|
201
190
|
],
|
|
202
191
|
)
|
|
192
|
+
),
|
|
193
|
+
]
|
|
194
|
+
if self.S3P is not None:
|
|
195
|
+
tmp = tmp + [
|
|
196
|
+
self.conv2complex(
|
|
197
|
+
self.backend.bk_reshape(
|
|
198
|
+
self.S3P,
|
|
199
|
+
[
|
|
200
|
+
self.S3.shape[0],
|
|
201
|
+
self.S3.shape[1]
|
|
202
|
+
* self.S3.shape[2]
|
|
203
|
+
* self.S3.shape[3]
|
|
204
|
+
* self.S3.shape[4],
|
|
205
|
+
],
|
|
206
|
+
)
|
|
207
|
+
)
|
|
203
208
|
]
|
|
204
209
|
|
|
205
210
|
tmp = tmp + [
|
|
206
|
-
self.
|
|
207
|
-
self.
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
211
|
+
self.conv2complex(
|
|
212
|
+
self.backend.bk_reshape(
|
|
213
|
+
self.S4,
|
|
214
|
+
[
|
|
215
|
+
self.S4.shape[0],
|
|
216
|
+
self.S4.shape[1]
|
|
217
|
+
* self.S4.shape[2]
|
|
218
|
+
* self.S4.shape[3]
|
|
219
|
+
* self.S4.shape[4]
|
|
220
|
+
* self.S4.shape[5],
|
|
221
|
+
],
|
|
222
|
+
)
|
|
216
223
|
)
|
|
217
224
|
]
|
|
218
|
-
|
|
219
225
|
return self.backend.bk_concat(tmp, 1)
|
|
220
226
|
|
|
221
227
|
# ---------------------------------------------−---------
|
|
@@ -1589,20 +1595,20 @@ class scat_cov:
|
|
|
1589
1595
|
def mean(self):
|
|
1590
1596
|
if self.S1 is not None: # Auto
|
|
1591
1597
|
return (
|
|
1592
|
-
abs(self.get_np(self.S0)).
|
|
1593
|
-
+ abs(self.get_np(self.S1)).
|
|
1594
|
-
+ abs(self.get_np(self.S3)).
|
|
1595
|
-
+ abs(self.get_np(self.S4)).
|
|
1596
|
-
+ abs(self.get_np(self.S2)).
|
|
1597
|
-
) /
|
|
1598
|
+
abs(self.get_np(self.S0)).sum()
|
|
1599
|
+
+ abs(self.get_np(self.S1)).sum()
|
|
1600
|
+
+ abs(self.get_np(self.S3)).sum()
|
|
1601
|
+
+ abs(self.get_np(self.S4)).sum()
|
|
1602
|
+
+ abs(self.get_np(self.S2)).sum()
|
|
1603
|
+
) / self.numel
|
|
1598
1604
|
else: # Cross
|
|
1599
1605
|
return (
|
|
1600
|
-
abs(self.get_np(self.S0)).
|
|
1601
|
-
+ abs(self.get_np(self.S3)).
|
|
1602
|
-
+ abs(self.get_np(self.S3P)).
|
|
1603
|
-
+ abs(self.get_np(self.S4)).
|
|
1604
|
-
+ abs(self.get_np(self.S2)).
|
|
1605
|
-
) /
|
|
1606
|
+
abs(self.get_np(self.S0)).sum()
|
|
1607
|
+
+ abs(self.get_np(self.S3)).sum()
|
|
1608
|
+
+ abs(self.get_np(self.S3P)).sum()
|
|
1609
|
+
+ abs(self.get_np(self.S4)).sum()
|
|
1610
|
+
+ abs(self.get_np(self.S2)).sum()
|
|
1611
|
+
) / self.numel
|
|
1606
1612
|
|
|
1607
1613
|
def initdx(self, norient):
|
|
1608
1614
|
idx1 = np.zeros([norient * norient], dtype="int")
|
|
@@ -2348,16 +2354,16 @@ class funct(FOC.FoCUS):
|
|
|
2348
2354
|
)
|
|
2349
2355
|
|
|
2350
2356
|
# compute local direction to make the statistical analysis more efficient
|
|
2351
|
-
def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0):
|
|
2357
|
+
def stat_cfft(self, im, image2=None, upscale=False, smooth_scale=0,spin=0):
|
|
2352
2358
|
tmp = im
|
|
2353
2359
|
if image2 is not None:
|
|
2354
2360
|
tmpi2 = image2
|
|
2355
2361
|
if upscale:
|
|
2356
|
-
l_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2357
|
-
tmp = self.up_grade(tmp, l_nside * 2
|
|
2362
|
+
l_nside = int(np.sqrt(tmp.shape[-1] // 12))
|
|
2363
|
+
tmp = self.up_grade(tmp, l_nside * 2)
|
|
2358
2364
|
if image2 is not None:
|
|
2359
|
-
tmpi2 = self.up_grade(tmpi2, l_nside * 2
|
|
2360
|
-
l_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2365
|
+
tmpi2 = self.up_grade(tmpi2, l_nside * 2)
|
|
2366
|
+
l_nside = int(np.sqrt(tmp.shape[-1] // 12))
|
|
2361
2367
|
nscale = int(np.log(l_nside) / np.log(2))
|
|
2362
2368
|
cmat = {}
|
|
2363
2369
|
cmat2 = {}
|
|
@@ -2367,20 +2373,23 @@ class funct(FOC.FoCUS):
|
|
|
2367
2373
|
if image2 is not None:
|
|
2368
2374
|
sim = self.backend.bk_real(
|
|
2369
2375
|
self.backend.bk_L1(
|
|
2370
|
-
self.convol(tmp,
|
|
2371
|
-
* self.backend.bk_conjugate(self.convol(tmpi2,
|
|
2376
|
+
self.convol(tmp,spin=spin)
|
|
2377
|
+
* self.backend.bk_conjugate(self.convol(tmpi2,spin=spin))
|
|
2372
2378
|
)
|
|
2373
2379
|
)
|
|
2374
2380
|
else:
|
|
2375
|
-
sim = self.backend.bk_abs(self.convol(tmp,
|
|
2381
|
+
sim = self.backend.bk_abs(self.convol(tmp,spin=spin))
|
|
2376
2382
|
|
|
2377
2383
|
# instead of difference between "opposite" channels use weighted average
|
|
2378
2384
|
# of cosine and sine contributions using all channels
|
|
2379
|
-
|
|
2380
|
-
|
|
2381
|
-
|
|
2382
|
-
|
|
2383
|
-
|
|
2385
|
+
if spin==0:
|
|
2386
|
+
angles = self.backend.bk_cast(
|
|
2387
|
+
(2 * np.pi * np.arange(self.NORIENT)
|
|
2388
|
+
/ self.NORIENT).reshape(1,self.NORIENT,1)) # shape: (NORIENT,)
|
|
2389
|
+
else:
|
|
2390
|
+
angles = self.backend.bk_cast(
|
|
2391
|
+
(2 * np.pi * np.arange(self.NORIENT)
|
|
2392
|
+
/ self.NORIENT).reshape(1,1,self.NORIENT,1)) # shape: (NORIENT,)
|
|
2384
2393
|
|
|
2385
2394
|
# we use cosines and sines as weights for sim
|
|
2386
2395
|
weighted_cos = self.backend.bk_reduce_mean(
|
|
@@ -2399,8 +2408,8 @@ class funct(FOC.FoCUS):
|
|
|
2399
2408
|
cc, _ = self.ud_grade_2(self.smooth(cc))
|
|
2400
2409
|
ss, _ = self.ud_grade_2(self.smooth(ss))
|
|
2401
2410
|
|
|
2402
|
-
if cc.shape[
|
|
2403
|
-
ll_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2411
|
+
if cc.shape[-1] != tmp.shape[-1]:
|
|
2412
|
+
ll_nside = int(np.sqrt(tmp.shape[-1] // 12))
|
|
2404
2413
|
cc = self.up_grade(cc, ll_nside)
|
|
2405
2414
|
ss = self.up_grade(ss, ll_nside)
|
|
2406
2415
|
|
|
@@ -2423,46 +2432,70 @@ class funct(FOC.FoCUS):
|
|
|
2423
2432
|
w1 = np.sin(delta * np.pi / 2) ** 2
|
|
2424
2433
|
|
|
2425
2434
|
# build rotation matrix
|
|
2426
|
-
|
|
2427
|
-
|
|
2435
|
+
if spin==0:
|
|
2436
|
+
mat = np.zeros([self.NORIENT * self.NORIENT, sim.shape[-1]])
|
|
2437
|
+
else:
|
|
2438
|
+
mat = np.zeros([2,self.NORIENT * self.NORIENT, sim.shape[-1]])
|
|
2439
|
+
lidx = np.arange(sim.shape[-1])
|
|
2428
2440
|
for ell in range(self.NORIENT):
|
|
2429
2441
|
# Instead of simple linear weights, we use the cosine weights w0 and w1.
|
|
2430
2442
|
col0 = self.NORIENT * ((ell + iph) % self.NORIENT) + ell
|
|
2431
2443
|
col1 = self.NORIENT * ((ell + iph + 1) % self.NORIENT) + ell
|
|
2432
2444
|
|
|
2433
|
-
|
|
2434
|
-
|
|
2445
|
+
if spin==0:
|
|
2446
|
+
mat[col0, lidx] = w0
|
|
2447
|
+
mat[col1, lidx] = w1
|
|
2448
|
+
else:
|
|
2449
|
+
mat[0,col0, lidx] = w0[0]
|
|
2450
|
+
mat[0,col1, lidx] = w1[0]
|
|
2451
|
+
mat[1,col0, lidx] = w0[1]
|
|
2452
|
+
mat[1,col1, lidx] = w1[1]
|
|
2435
2453
|
|
|
2436
2454
|
cmat[k] = self.backend.bk_cast(mat[None, ...].astype("complex64"))
|
|
2437
2455
|
|
|
2438
2456
|
# do same modifications for mat2
|
|
2439
|
-
|
|
2440
|
-
|
|
2441
|
-
|
|
2457
|
+
if spin==0:
|
|
2458
|
+
mat2 = np.zeros(
|
|
2459
|
+
[k + 1, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[-1]]
|
|
2460
|
+
)
|
|
2461
|
+
else:
|
|
2462
|
+
mat2 = np.zeros(
|
|
2463
|
+
[k + 1, 2, self.NORIENT * self.NORIENT, self.NORIENT, sim.shape[-1]]
|
|
2464
|
+
)
|
|
2442
2465
|
|
|
2443
2466
|
for k2 in range(k + 1):
|
|
2444
2467
|
|
|
2445
|
-
tmp2 = self.backend.
|
|
2446
|
-
|
|
2447
|
-
|
|
2448
|
-
|
|
2449
|
-
|
|
2450
|
-
|
|
2451
|
-
|
|
2452
|
-
|
|
2453
|
-
|
|
2454
|
-
|
|
2455
|
-
|
|
2456
|
-
|
|
2457
|
-
|
|
2458
|
-
|
|
2468
|
+
tmp2 = self.backend.bk_expand_dims(sim,-2)
|
|
2469
|
+
if spin==0:
|
|
2470
|
+
sim2 = self.backend.bk_reduce_sum(
|
|
2471
|
+
self.backend.bk_reshape(
|
|
2472
|
+
self.backend.bk_cast(
|
|
2473
|
+
mat.reshape(1, self.NORIENT, self.NORIENT, mat.shape[-1])
|
|
2474
|
+
)
|
|
2475
|
+
* tmp2,
|
|
2476
|
+
[sim.shape[0], self.NORIENT, self.NORIENT, mat.shape[-1]],
|
|
2477
|
+
),
|
|
2478
|
+
1,
|
|
2479
|
+
)
|
|
2480
|
+
else:
|
|
2481
|
+
sim2 = self.backend.bk_reduce_sum(
|
|
2482
|
+
self.backend.bk_reshape(
|
|
2483
|
+
self.backend.bk_cast(
|
|
2484
|
+
mat.reshape(1, 2, self.NORIENT, self.NORIENT, mat.shape[-1])
|
|
2485
|
+
)
|
|
2486
|
+
* tmp2,
|
|
2487
|
+
[sim.shape[0], 2, self.NORIENT, self.NORIENT, mat.shape[-1]],
|
|
2488
|
+
),
|
|
2489
|
+
2,
|
|
2490
|
+
)
|
|
2459
2491
|
|
|
2492
|
+
sim2 = self.backend.bk_abs(self.convol(sim2))
|
|
2460
2493
|
angles = self.backend.bk_reshape(angles, [1, self.NORIENT, 1, 1])
|
|
2461
2494
|
weighted_cos2 = self.backend.bk_reduce_mean(
|
|
2462
|
-
sim2 * self.backend.bk_cos(angles), axis
|
|
2495
|
+
sim2 * self.backend.bk_cos(angles), axis=-3
|
|
2463
2496
|
)
|
|
2464
2497
|
weighted_sin2 = self.backend.bk_reduce_mean(
|
|
2465
|
-
sim2 * self.backend.bk_sin(angles), axis
|
|
2498
|
+
sim2 * self.backend.bk_sin(angles), axis=-3
|
|
2466
2499
|
)
|
|
2467
2500
|
|
|
2468
2501
|
cc2 = weighted_cos2[0]
|
|
@@ -2471,13 +2504,13 @@ class funct(FOC.FoCUS):
|
|
|
2471
2504
|
if smooth_scale > 0:
|
|
2472
2505
|
for m in range(smooth_scale):
|
|
2473
2506
|
if cc2.shape[1] > 12:
|
|
2474
|
-
cc2, _ = self.ud_grade_2(self.smooth(cc2
|
|
2475
|
-
ss2, _ = self.ud_grade_2(self.smooth(ss2
|
|
2507
|
+
cc2, _ = self.ud_grade_2(self.smooth(cc2))
|
|
2508
|
+
ss2, _ = self.ud_grade_2(self.smooth(ss2))
|
|
2476
2509
|
|
|
2477
|
-
if cc2.shape[1] != sim.shape[
|
|
2478
|
-
ll_nside = int(np.sqrt(sim.shape[
|
|
2479
|
-
cc2 = self.up_grade(cc2, ll_nside
|
|
2480
|
-
ss2 = self.up_grade(ss2, ll_nside
|
|
2510
|
+
if cc2.shape[-1] != sim.shape[-1]:
|
|
2511
|
+
ll_nside = int(np.sqrt(sim.shape[-1] // 12))
|
|
2512
|
+
cc2 = self.up_grade(cc2, ll_nside)
|
|
2513
|
+
ss2 = self.up_grade(ss2, ll_nside)
|
|
2481
2514
|
|
|
2482
2515
|
if self.BACKEND == "numpy":
|
|
2483
2516
|
phase2 = np.fmod(np.arctan2(ss2, cc2) + 2 * np.pi, 2 * np.pi)
|
|
@@ -2495,23 +2528,32 @@ class funct(FOC.FoCUS):
|
|
|
2495
2528
|
delta2 = phase2_scaled - iph2
|
|
2496
2529
|
w0_2 = np.cos(delta2 * np.pi / 2) ** 2
|
|
2497
2530
|
w1_2 = np.sin(delta2 * np.pi / 2) ** 2
|
|
2498
|
-
lidx = np.arange(sim.shape[
|
|
2499
|
-
|
|
2500
|
-
|
|
2501
|
-
for
|
|
2502
|
-
|
|
2503
|
-
|
|
2504
|
-
|
|
2505
|
-
|
|
2531
|
+
lidx = np.arange(sim.shape[-1])
|
|
2532
|
+
|
|
2533
|
+
if spin==0:
|
|
2534
|
+
for m in range(self.NORIENT):
|
|
2535
|
+
for ell in range(self.NORIENT):
|
|
2536
|
+
col0 = self.NORIENT * ((ell + iph2[m]) % self.NORIENT) + ell
|
|
2537
|
+
col1 = self.NORIENT * ((ell + iph2[m] + 1) % self.NORIENT) + ell
|
|
2538
|
+
mat2[k2, col0, m, lidx] = w0_2[m, lidx]
|
|
2539
|
+
mat2[k2, col1, m, lidx] = w1_2[m, lidx]
|
|
2540
|
+
else:
|
|
2541
|
+
for sidx in range(2):
|
|
2542
|
+
for m in range(self.NORIENT):
|
|
2543
|
+
for ell in range(self.NORIENT):
|
|
2544
|
+
col0 = self.NORIENT * ((ell + iph2[sidx,m]) % self.NORIENT) + ell
|
|
2545
|
+
col1 = self.NORIENT * ((ell + iph2[sidx,m] + 1) % self.NORIENT) + ell
|
|
2546
|
+
mat2[k2, sidx, col0, m, lidx] = w0_2[sidx,m, lidx]
|
|
2547
|
+
mat2[k2, sidx, col1, m, lidx] = w1_2[sidx,m, lidx]
|
|
2506
2548
|
|
|
2507
2549
|
cmat2[k] = self.backend.bk_cast(
|
|
2508
2550
|
mat2[0 : k + 1, None, ...].astype("complex64")
|
|
2509
2551
|
)
|
|
2510
2552
|
|
|
2511
2553
|
if k < l_nside - 1:
|
|
2512
|
-
tmp, _ = self.ud_grade_2(tmp
|
|
2554
|
+
tmp, _ = self.ud_grade_2(tmp)
|
|
2513
2555
|
if image2 is not None:
|
|
2514
|
-
tmpi2, _ = self.ud_grade_2(
|
|
2556
|
+
tmpi2, _ = self.ud_grade_2(tmpi)
|
|
2515
2557
|
return cmat, cmat2
|
|
2516
2558
|
|
|
2517
2559
|
def div_norm(self, complex_value, float_value):
|
|
@@ -2534,6 +2576,7 @@ class funct(FOC.FoCUS):
|
|
|
2534
2576
|
edge=True,
|
|
2535
2577
|
nside=None,
|
|
2536
2578
|
cell_ids=None,
|
|
2579
|
+
spin=0
|
|
2537
2580
|
):
|
|
2538
2581
|
"""
|
|
2539
2582
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
@@ -2559,11 +2602,14 @@ class funct(FOC.FoCUS):
|
|
|
2559
2602
|
norm: None or str
|
|
2560
2603
|
If None no normalization is applied, if 'auto' normalize by the reference S2,
|
|
2561
2604
|
if 'self' normalize by the current S2.
|
|
2605
|
+
spin : Integer
|
|
2606
|
+
If different from 0 compute spinned data (U,V to Divergence/Rotational spin==1) or (Q,U to E,B spin=2).
|
|
2607
|
+
This implies that the input data is 2*12*nside^2.
|
|
2562
2608
|
Returns
|
|
2563
2609
|
-------
|
|
2564
2610
|
S1, S2, S3, S4 normalized
|
|
2565
2611
|
"""
|
|
2566
|
-
|
|
2612
|
+
|
|
2567
2613
|
return_data = self.return_data
|
|
2568
2614
|
|
|
2569
2615
|
# Check input consistency
|
|
@@ -2576,8 +2622,8 @@ class funct(FOC.FoCUS):
|
|
|
2576
2622
|
if mask is not None:
|
|
2577
2623
|
if self.use_2D:
|
|
2578
2624
|
if (
|
|
2579
|
-
image1.shape[-2] != mask.shape[1]
|
|
2580
|
-
or image1.shape[-1] != mask.shape[
|
|
2625
|
+
image1.shape[-2] != mask.shape[-1]
|
|
2626
|
+
or image1.shape[-1] != mask.shape[-1]
|
|
2581
2627
|
):
|
|
2582
2628
|
print(
|
|
2583
2629
|
"The LAST 2 COLUMNs of the mask should have the same size ",
|
|
@@ -2588,7 +2634,7 @@ class funct(FOC.FoCUS):
|
|
|
2588
2634
|
)
|
|
2589
2635
|
return None
|
|
2590
2636
|
else:
|
|
2591
|
-
if image1.shape[-1] != mask.shape[1]:
|
|
2637
|
+
if image1.shape[-1] != mask.shape[-1]:
|
|
2592
2638
|
print(
|
|
2593
2639
|
"The LAST COLUMN of the mask should have the same size ",
|
|
2594
2640
|
mask.shape,
|
|
@@ -2609,7 +2655,6 @@ class funct(FOC.FoCUS):
|
|
|
2609
2655
|
cross = True
|
|
2610
2656
|
l_nside = 2**32 # not initialize if 1D or 2D
|
|
2611
2657
|
### PARAMETERS
|
|
2612
|
-
axis = 1
|
|
2613
2658
|
# determine jmax and nside corresponding to the input map
|
|
2614
2659
|
im_shape = image1.shape
|
|
2615
2660
|
if self.use_2D:
|
|
@@ -2637,10 +2682,7 @@ class funct(FOC.FoCUS):
|
|
|
2637
2682
|
|
|
2638
2683
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2639
2684
|
else:
|
|
2640
|
-
|
|
2641
|
-
npix = int(im_shape[1]) # Number of pixels
|
|
2642
|
-
else:
|
|
2643
|
-
npix = int(im_shape[0]) # Number of pixels
|
|
2685
|
+
npix=int(im_shape[-1])
|
|
2644
2686
|
|
|
2645
2687
|
if nside is None:
|
|
2646
2688
|
nside = int(np.sqrt(npix // 12))
|
|
@@ -2659,7 +2701,7 @@ class funct(FOC.FoCUS):
|
|
|
2659
2701
|
print("\n\n==========")
|
|
2660
2702
|
|
|
2661
2703
|
### LOCAL VARIABLES (IMAGES and MASK)
|
|
2662
|
-
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
|
|
2704
|
+
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D) or (len(image1.shape) == 2 and spin>0):
|
|
2663
2705
|
I1 = self.backend.bk_cast(
|
|
2664
2706
|
self.backend.bk_expand_dims(image1, 0)
|
|
2665
2707
|
) # Local image1 [Nbatch, Npix]
|
|
@@ -2684,26 +2726,26 @@ class funct(FOC.FoCUS):
|
|
|
2684
2726
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2685
2727
|
if self.use_2D:
|
|
2686
2728
|
vmask = self.up_grade(
|
|
2687
|
-
vmask, I1.shape[
|
|
2729
|
+
vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
|
|
2688
2730
|
)
|
|
2689
2731
|
I1 = self.up_grade(
|
|
2690
|
-
I1, I1.shape[
|
|
2732
|
+
I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
|
|
2691
2733
|
)
|
|
2692
2734
|
if cross:
|
|
2693
2735
|
I2 = self.up_grade(
|
|
2694
|
-
I2, I2.shape[
|
|
2736
|
+
I2, I2.shape[-2] * 2, nouty=I2.shape[-1] * 2,axis=-2
|
|
2695
2737
|
)
|
|
2696
2738
|
elif self.use_1D:
|
|
2697
|
-
vmask = self.up_grade(vmask, I1.shape[
|
|
2698
|
-
I1 = self.up_grade(I1, I1.shape[
|
|
2739
|
+
vmask = self.up_grade(vmask, I1.shape[-1] * 2)
|
|
2740
|
+
I1 = self.up_grade(I1, I1.shape[-1] * 2)
|
|
2699
2741
|
if cross:
|
|
2700
|
-
I2 = self.up_grade(I2, I2.shape[
|
|
2742
|
+
I2 = self.up_grade(I2, I2.shape[-1] * 2)
|
|
2701
2743
|
nside = nside * 2
|
|
2702
2744
|
else:
|
|
2703
|
-
I1 = self.up_grade(I1, nside * 2
|
|
2704
|
-
vmask = self.up_grade(vmask, nside * 2
|
|
2745
|
+
I1 = self.up_grade(I1, nside * 2)
|
|
2746
|
+
vmask = self.up_grade(vmask, nside * 2)
|
|
2705
2747
|
if cross:
|
|
2706
|
-
I2 = self.up_grade(I2, nside * 2
|
|
2748
|
+
I2 = self.up_grade(I2, nside * 2)
|
|
2707
2749
|
|
|
2708
2750
|
nside = nside * 2
|
|
2709
2751
|
|
|
@@ -2711,29 +2753,28 @@ class funct(FOC.FoCUS):
|
|
|
2711
2753
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2712
2754
|
if self.use_2D:
|
|
2713
2755
|
vmask = self.up_grade(
|
|
2714
|
-
vmask, I1.shape[
|
|
2756
|
+
vmask, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
|
|
2715
2757
|
)
|
|
2716
2758
|
I1 = self.up_grade(
|
|
2717
|
-
I1, I1.shape[
|
|
2759
|
+
I1, I1.shape[-2] * 2, nouty=I1.shape[-1] * 2,axis=-2
|
|
2718
2760
|
)
|
|
2719
2761
|
if cross:
|
|
2720
2762
|
I2 = self.up_grade(
|
|
2721
2763
|
I2,
|
|
2722
|
-
I2.shape[
|
|
2723
|
-
|
|
2724
|
-
nouty=I2.shape[axis + 1] * 2,
|
|
2764
|
+
I2.shape[-2] * 2,
|
|
2765
|
+
nouty=I2.shape[-1] * 2,axis=-2
|
|
2725
2766
|
)
|
|
2726
2767
|
elif self.use_1D:
|
|
2727
|
-
vmask = self.up_grade(vmask, I1.shape[
|
|
2728
|
-
I1 = self.up_grade(I1, I1.shape[
|
|
2768
|
+
vmask = self.up_grade(vmask, I1.shape[-1] * 2)
|
|
2769
|
+
I1 = self.up_grade(I1, I1.shape[-1] * 2)
|
|
2729
2770
|
if cross:
|
|
2730
|
-
I2 = self.up_grade(I2, I2.shape[
|
|
2771
|
+
I2 = self.up_grade(I2, I2.shape[-1] * 2)
|
|
2731
2772
|
nside = nside * 2
|
|
2732
2773
|
else:
|
|
2733
|
-
I1 = self.up_grade(I1, nside * 2
|
|
2734
|
-
vmask = self.up_grade(vmask, nside * 2
|
|
2774
|
+
I1 = self.up_grade(I1, nside * 2)
|
|
2775
|
+
vmask = self.up_grade(vmask, nside * 2)
|
|
2735
2776
|
if cross:
|
|
2736
|
-
I2 = self.up_grade(I2, nside * 2
|
|
2777
|
+
I2 = self.up_grade(I2, nside * 2)
|
|
2737
2778
|
nside = nside * 2
|
|
2738
2779
|
|
|
2739
2780
|
# Normalize the masks because they have different pixel numbers
|
|
@@ -2762,6 +2803,7 @@ class funct(FOC.FoCUS):
|
|
|
2762
2803
|
off_S2 = -2
|
|
2763
2804
|
off_S3 = -3
|
|
2764
2805
|
off_S4 = -4
|
|
2806
|
+
|
|
2765
2807
|
if self.use_1D:
|
|
2766
2808
|
off_S2 = -1
|
|
2767
2809
|
off_S3 = -1
|
|
@@ -2793,13 +2835,20 @@ class funct(FOC.FoCUS):
|
|
|
2793
2835
|
)
|
|
2794
2836
|
else:
|
|
2795
2837
|
if not cross:
|
|
2796
|
-
s0, l_vs0 = self.masked_mean(I1,
|
|
2838
|
+
s0, l_vs0 = self.masked_mean(I1,
|
|
2839
|
+
vmask,
|
|
2840
|
+
calc_var=True)
|
|
2797
2841
|
else:
|
|
2798
2842
|
s0, l_vs0 = self.masked_mean(
|
|
2799
|
-
self.backend.bk_L1(I1 * I2),
|
|
2800
|
-
|
|
2801
|
-
|
|
2802
|
-
|
|
2843
|
+
self.backend.bk_L1(I1 * I2),
|
|
2844
|
+
vmask,
|
|
2845
|
+
calc_var=True)
|
|
2846
|
+
|
|
2847
|
+
vs0 = self.backend.bk_concat([l_vs0, l_vs0], -1)
|
|
2848
|
+
s0 = self.backend.bk_concat([s0, l_vs0], -1)
|
|
2849
|
+
if spin>0:
|
|
2850
|
+
vs0=self.backend.bk_reshape(vs0,[vs0.shape[0],vs0.shape[1],2,vs0.shape[2]//2])
|
|
2851
|
+
s0=self.backend.bk_reshape(s0,[s0.shape[0],s0.shape[1],2,s0.shape[2]//2])
|
|
2803
2852
|
#### COMPUTE S1, S2, S3 and S4
|
|
2804
2853
|
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
2805
2854
|
|
|
@@ -2848,20 +2897,29 @@ class funct(FOC.FoCUS):
|
|
|
2848
2897
|
####### S1 and S2
|
|
2849
2898
|
### Make the convolution I1 * Psi_j3
|
|
2850
2899
|
conv1 = self.convol(
|
|
2851
|
-
I1,
|
|
2900
|
+
I1, cell_ids=cell_ids_j3, nside=nside_j3,
|
|
2901
|
+
spin=spin
|
|
2852
2902
|
) # [Nbatch, Norient3 , Npix_j3]
|
|
2853
2903
|
|
|
2854
2904
|
if cmat is not None:
|
|
2855
|
-
|
|
2856
2905
|
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-2)
|
|
2857
|
-
|
|
2858
|
-
|
|
2859
|
-
self.backend.
|
|
2860
|
-
|
|
2861
|
-
|
|
2862
|
-
|
|
2863
|
-
|
|
2864
|
-
|
|
2906
|
+
|
|
2907
|
+
if spin==0:
|
|
2908
|
+
conv1 = self.backend.bk_reduce_sum(
|
|
2909
|
+
self.backend.bk_reshape(
|
|
2910
|
+
cmat[j3] * tmp2,
|
|
2911
|
+
[tmp2.shape[0], self.NORIENT, self.NORIENT, cmat[j3].shape[2]],
|
|
2912
|
+
),
|
|
2913
|
+
1,
|
|
2914
|
+
)
|
|
2915
|
+
else:
|
|
2916
|
+
conv1 = self.backend.bk_reduce_sum(
|
|
2917
|
+
self.backend.bk_reshape(
|
|
2918
|
+
cmat[j3] * tmp2,
|
|
2919
|
+
[tmp2.shape[0], 2,self.NORIENT, self.NORIENT, cmat[j3].shape[3]],
|
|
2920
|
+
),
|
|
2921
|
+
2,
|
|
2922
|
+
)
|
|
2865
2923
|
|
|
2866
2924
|
### Take the module M1 = |I1 * Psi_j3|
|
|
2867
2925
|
M1_square = conv1 * self.backend.bk_conjugate(
|
|
@@ -2871,7 +2929,6 @@ class funct(FOC.FoCUS):
|
|
|
2871
2929
|
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
2872
2930
|
# Store M1_j3 in a dictionary
|
|
2873
2931
|
M1_dic[j3] = M1
|
|
2874
|
-
|
|
2875
2932
|
if not cross: # Auto
|
|
2876
2933
|
M1_square = self.backend.bk_real(M1_square)
|
|
2877
2934
|
|
|
@@ -2882,11 +2939,11 @@ class funct(FOC.FoCUS):
|
|
|
2882
2939
|
else:
|
|
2883
2940
|
if calc_var:
|
|
2884
2941
|
s2, vs2 = self.masked_mean(
|
|
2885
|
-
M1_square, vmask,
|
|
2942
|
+
M1_square, vmask, rank=j3, calc_var=True
|
|
2886
2943
|
)
|
|
2887
2944
|
else:
|
|
2888
|
-
s2 = self.masked_mean(M1_square, vmask,
|
|
2889
|
-
|
|
2945
|
+
s2 = self.masked_mean(M1_square, vmask, rank=j3)
|
|
2946
|
+
|
|
2890
2947
|
if cond_init_P1_dic:
|
|
2891
2948
|
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
2892
2949
|
P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
|
|
@@ -2929,11 +2986,11 @@ class funct(FOC.FoCUS):
|
|
|
2929
2986
|
else:
|
|
2930
2987
|
if calc_var:
|
|
2931
2988
|
s1, vs1 = self.masked_mean(
|
|
2932
|
-
M1, vmask,
|
|
2989
|
+
M1, vmask, rank=j3, calc_var=True
|
|
2933
2990
|
) # [Nbatch, Nmask, Norient3]
|
|
2934
2991
|
else:
|
|
2935
2992
|
s1 = self.masked_mean(
|
|
2936
|
-
M1, vmask,
|
|
2993
|
+
M1, vmask, rank=j3
|
|
2937
2994
|
) # [Nbatch, Nmask, Norient3]
|
|
2938
2995
|
|
|
2939
2996
|
if return_data:
|
|
@@ -2967,22 +3024,38 @@ class funct(FOC.FoCUS):
|
|
|
2967
3024
|
else: # Cross
|
|
2968
3025
|
### Make the convolution I2 * Psi_j3
|
|
2969
3026
|
conv2 = self.convol(
|
|
2970
|
-
I2,
|
|
3027
|
+
I2, cell_ids=cell_ids_j3, nside=nside_j3,
|
|
3028
|
+
spin=spin
|
|
2971
3029
|
) # [Nbatch, Npix_j3, Norient3]
|
|
2972
3030
|
if cmat is not None:
|
|
2973
3031
|
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-2)
|
|
2974
|
-
|
|
2975
|
-
self.backend.
|
|
2976
|
-
|
|
2977
|
-
|
|
2978
|
-
|
|
2979
|
-
|
|
2980
|
-
|
|
2981
|
-
|
|
2982
|
-
|
|
2983
|
-
|
|
2984
|
-
|
|
2985
|
-
|
|
3032
|
+
if spin==0:
|
|
3033
|
+
conv2 = self.backend.bk_reduce_sum(
|
|
3034
|
+
self.backend.bk_reshape(
|
|
3035
|
+
cmat[j3] * tmp2,
|
|
3036
|
+
[
|
|
3037
|
+
tmp2.shape[0],
|
|
3038
|
+
self.NORIENT,
|
|
3039
|
+
self.NORIENT,
|
|
3040
|
+
cmat[j3].shape[2],
|
|
3041
|
+
],
|
|
3042
|
+
),
|
|
3043
|
+
1,
|
|
3044
|
+
)
|
|
3045
|
+
else:
|
|
3046
|
+
conv2 = self.backend.bk_reduce_sum(
|
|
3047
|
+
self.backend.bk_reshape(
|
|
3048
|
+
cmat[j3] * tmp2,
|
|
3049
|
+
[
|
|
3050
|
+
tmp2.shape[0],
|
|
3051
|
+
2,
|
|
3052
|
+
self.NORIENT,
|
|
3053
|
+
self.NORIENT,
|
|
3054
|
+
cmat[j3].shape[3],
|
|
3055
|
+
],
|
|
3056
|
+
),
|
|
3057
|
+
2,
|
|
3058
|
+
)
|
|
2986
3059
|
### Take the module M2 = |I2 * Psi_j3|
|
|
2987
3060
|
M2_square = conv2 * self.backend.bk_conjugate(
|
|
2988
3061
|
conv2
|
|
@@ -3001,17 +3074,17 @@ class funct(FOC.FoCUS):
|
|
|
3001
3074
|
else:
|
|
3002
3075
|
if calc_var:
|
|
3003
3076
|
p1, vp1 = self.masked_mean(
|
|
3004
|
-
M1_square, vmask,
|
|
3077
|
+
M1_square, vmask, rank=j3, calc_var=True
|
|
3005
3078
|
) # [Nbatch, Nmask, Norient3]
|
|
3006
3079
|
p2, vp2 = self.masked_mean(
|
|
3007
|
-
M2_square, vmask,
|
|
3080
|
+
M2_square, vmask, rank=j3, calc_var=True
|
|
3008
3081
|
) # [Nbatch, Nmask, Norient3]
|
|
3009
3082
|
else:
|
|
3010
3083
|
p1 = self.masked_mean(
|
|
3011
|
-
M1_square, vmask,
|
|
3084
|
+
M1_square, vmask, rank=j3
|
|
3012
3085
|
) # [Nbatch, Nmask, Norient3]
|
|
3013
3086
|
p2 = self.masked_mean(
|
|
3014
|
-
M2_square, vmask,
|
|
3087
|
+
M2_square, vmask, rank=j3
|
|
3015
3088
|
) # [Nbatch, Nmask, Norient3]
|
|
3016
3089
|
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
3017
3090
|
P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
|
|
@@ -3027,10 +3100,10 @@ class funct(FOC.FoCUS):
|
|
|
3027
3100
|
else:
|
|
3028
3101
|
if calc_var:
|
|
3029
3102
|
s2, vs2 = self.masked_mean(
|
|
3030
|
-
s2, vmask,
|
|
3103
|
+
s2, vmask, rank=j3, calc_var=True
|
|
3031
3104
|
)
|
|
3032
3105
|
else:
|
|
3033
|
-
s2 = self.masked_mean(s2, vmask,
|
|
3106
|
+
s2 = self.masked_mean(s2, vmask, rank=j3)
|
|
3034
3107
|
|
|
3035
3108
|
if return_data:
|
|
3036
3109
|
if out_nside is not None and out_nside < nside_j3:
|
|
@@ -3072,11 +3145,11 @@ class funct(FOC.FoCUS):
|
|
|
3072
3145
|
else:
|
|
3073
3146
|
if calc_var:
|
|
3074
3147
|
s1, vs1 = self.masked_mean(
|
|
3075
|
-
MX, vmask,
|
|
3148
|
+
MX, vmask, rank=j3, calc_var=True
|
|
3076
3149
|
) # [Nbatch, Nmask, Norient3]
|
|
3077
3150
|
else:
|
|
3078
3151
|
s1 = self.masked_mean(
|
|
3079
|
-
MX, vmask,
|
|
3152
|
+
MX, vmask, rank=j3
|
|
3080
3153
|
) # [Nbatch, Nmask, Norient3]
|
|
3081
3154
|
if return_data:
|
|
3082
3155
|
if out_nside is not None and out_nside < nside_j3:
|
|
@@ -3133,6 +3206,7 @@ class funct(FOC.FoCUS):
|
|
|
3133
3206
|
cmat2=cmat2,
|
|
3134
3207
|
cell_ids=cell_ids_j3,
|
|
3135
3208
|
nside_j2=nside_j3,
|
|
3209
|
+
spin=spin,
|
|
3136
3210
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3137
3211
|
else:
|
|
3138
3212
|
s3 = self._compute_S3(
|
|
@@ -3146,6 +3220,7 @@ class funct(FOC.FoCUS):
|
|
|
3146
3220
|
cmat2=cmat2,
|
|
3147
3221
|
cell_ids=cell_ids_j3,
|
|
3148
3222
|
nside_j2=nside_j3,
|
|
3223
|
+
spin=spin,
|
|
3149
3224
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3150
3225
|
|
|
3151
3226
|
if return_data:
|
|
@@ -3207,6 +3282,7 @@ class funct(FOC.FoCUS):
|
|
|
3207
3282
|
cmat2=cmat2,
|
|
3208
3283
|
cell_ids=cell_ids_j3,
|
|
3209
3284
|
nside_j2=nside_j3,
|
|
3285
|
+
spin=spin,
|
|
3210
3286
|
)
|
|
3211
3287
|
s3p, vs3p = self._compute_S3(
|
|
3212
3288
|
j2,
|
|
@@ -3219,6 +3295,7 @@ class funct(FOC.FoCUS):
|
|
|
3219
3295
|
cmat2=cmat2,
|
|
3220
3296
|
cell_ids=cell_ids_j3,
|
|
3221
3297
|
nside_j2=nside_j3,
|
|
3298
|
+
spin=spin,
|
|
3222
3299
|
)
|
|
3223
3300
|
else:
|
|
3224
3301
|
s3p = self._compute_S3(
|
|
@@ -3232,6 +3309,7 @@ class funct(FOC.FoCUS):
|
|
|
3232
3309
|
cmat2=cmat2,
|
|
3233
3310
|
cell_ids=cell_ids_j3,
|
|
3234
3311
|
nside_j2=nside_j3,
|
|
3312
|
+
spin=spin,
|
|
3235
3313
|
)
|
|
3236
3314
|
s3 = self._compute_S3(
|
|
3237
3315
|
j2,
|
|
@@ -3244,6 +3322,7 @@ class funct(FOC.FoCUS):
|
|
|
3244
3322
|
cmat2=cmat2,
|
|
3245
3323
|
cell_ids=cell_ids_j3,
|
|
3246
3324
|
nside_j2=nside_j3,
|
|
3325
|
+
spin=spin,
|
|
3247
3326
|
)
|
|
3248
3327
|
|
|
3249
3328
|
if return_data:
|
|
@@ -3482,39 +3561,39 @@ class funct(FOC.FoCUS):
|
|
|
3482
3561
|
### Image I1,
|
|
3483
3562
|
# downscale the I1 [Nbatch, Npix_j3]
|
|
3484
3563
|
if j3 != Jmax - 1:
|
|
3485
|
-
I1 = self.smooth(I1,
|
|
3564
|
+
I1 = self.smooth(I1, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3486
3565
|
I1, new_cell_ids_j3 = self.ud_grade_2(
|
|
3487
|
-
I1,
|
|
3566
|
+
I1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3488
3567
|
)
|
|
3489
3568
|
|
|
3490
3569
|
### Image I2
|
|
3491
3570
|
if cross:
|
|
3492
|
-
I2 = self.smooth(I2,
|
|
3571
|
+
I2 = self.smooth(I2, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3493
3572
|
I2, new_cell_ids_j3 = self.ud_grade_2(
|
|
3494
|
-
I2,
|
|
3573
|
+
I2, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3495
3574
|
)
|
|
3496
3575
|
|
|
3497
3576
|
### Modules
|
|
3498
3577
|
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
3499
3578
|
### Dictionary M1_dic[j2]
|
|
3500
3579
|
M1_smooth = self.smooth(
|
|
3501
|
-
M1_dic[j2],
|
|
3580
|
+
M1_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
|
|
3502
3581
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3503
3582
|
M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3504
|
-
M1_smooth,
|
|
3583
|
+
M1_smooth, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3505
3584
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3506
3585
|
|
|
3507
3586
|
### Dictionary M2_dic[j2]
|
|
3508
3587
|
if cross:
|
|
3509
3588
|
M2_smooth = self.smooth(
|
|
3510
|
-
M2_dic[j2],
|
|
3589
|
+
M2_dic[j2], cell_ids=cell_ids_j3, nside=nside_j3
|
|
3511
3590
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3512
3591
|
M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3513
|
-
M2_smooth,
|
|
3592
|
+
M2_smooth, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3514
3593
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3515
3594
|
### Mask
|
|
3516
3595
|
vmask, new_cell_ids_j3 = self.ud_grade_2(
|
|
3517
|
-
vmask,
|
|
3596
|
+
vmask, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3518
3597
|
)
|
|
3519
3598
|
|
|
3520
3599
|
if self.mask_thres is not None:
|
|
@@ -3529,33 +3608,21 @@ class funct(FOC.FoCUS):
|
|
|
3529
3608
|
self.P1_dic = P1_dic
|
|
3530
3609
|
if cross:
|
|
3531
3610
|
self.P2_dic = P2_dic
|
|
3532
|
-
|
|
3533
|
-
Sout=[s0]+S1+S2+S3+S4
|
|
3534
|
-
|
|
3535
|
-
if cross:
|
|
3536
|
-
Sout=Sout+S3P
|
|
3537
|
-
if calc_var:
|
|
3538
|
-
SVout=[vs0]+VS1+VS2+VS3+VS4
|
|
3539
|
-
if cross:
|
|
3540
|
-
VSout=VSout+VS3P
|
|
3541
|
-
return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
|
|
3542
|
-
|
|
3543
|
-
return self.backend.bk_concat(Sout, 2)
|
|
3544
|
-
"""
|
|
3611
|
+
|
|
3545
3612
|
if not return_data:
|
|
3546
|
-
S1 = self.backend.bk_concat(S1, 2)
|
|
3547
|
-
S2 = self.backend.bk_concat(S2, 2)
|
|
3548
|
-
S3 = self.backend.bk_concat(S3,
|
|
3549
|
-
S4 = self.backend.bk_concat(S4,
|
|
3613
|
+
S1 = self.backend.bk_concat(S1, -2)
|
|
3614
|
+
S2 = self.backend.bk_concat(S2, -2)
|
|
3615
|
+
S3 = self.backend.bk_concat(S3, -3)
|
|
3616
|
+
S4 = self.backend.bk_concat(S4, -4)
|
|
3550
3617
|
if cross:
|
|
3551
|
-
S3P = self.backend.bk_concat(S3P,
|
|
3618
|
+
S3P = self.backend.bk_concat(S3P, -3)
|
|
3552
3619
|
if calc_var:
|
|
3553
|
-
VS1 = self.backend.bk_concat(VS1, 2)
|
|
3554
|
-
VS2 = self.backend.bk_concat(VS2, 2)
|
|
3555
|
-
VS3 = self.backend.bk_concat(VS3,
|
|
3556
|
-
VS4 = self.backend.bk_concat(VS4,
|
|
3620
|
+
VS1 = self.backend.bk_concat(VS1, -2)
|
|
3621
|
+
VS2 = self.backend.bk_concat(VS2, -2)
|
|
3622
|
+
VS3 = self.backend.bk_concat(VS3, -3)
|
|
3623
|
+
VS4 = self.backend.bk_concat(VS4, -4)
|
|
3557
3624
|
if cross:
|
|
3558
|
-
VS3P = self.backend.bk_concat(VS3P,
|
|
3625
|
+
VS3P = self.backend.bk_concat(VS3P, -3)
|
|
3559
3626
|
if calc_var:
|
|
3560
3627
|
if not cross:
|
|
3561
3628
|
return scat_cov(
|
|
@@ -3612,18 +3679,19 @@ class funct(FOC.FoCUS):
|
|
|
3612
3679
|
return
|
|
3613
3680
|
|
|
3614
3681
|
def _compute_S3(
|
|
3615
|
-
|
|
3616
|
-
|
|
3617
|
-
|
|
3618
|
-
|
|
3619
|
-
|
|
3620
|
-
|
|
3621
|
-
|
|
3622
|
-
|
|
3623
|
-
|
|
3624
|
-
|
|
3625
|
-
|
|
3626
|
-
|
|
3682
|
+
self,
|
|
3683
|
+
j2,
|
|
3684
|
+
j3,
|
|
3685
|
+
conv,
|
|
3686
|
+
vmask,
|
|
3687
|
+
M_dic,
|
|
3688
|
+
MconvPsi_dic,
|
|
3689
|
+
calc_var=False,
|
|
3690
|
+
return_data=False,
|
|
3691
|
+
cmat2=None,
|
|
3692
|
+
cell_ids=None,
|
|
3693
|
+
nside_j2=None,
|
|
3694
|
+
spin=0,
|
|
3627
3695
|
):
|
|
3628
3696
|
"""
|
|
3629
3697
|
Compute the S3 coefficients (auto or cross)
|
|
@@ -3637,24 +3705,40 @@ class funct(FOC.FoCUS):
|
|
|
3637
3705
|
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
3638
3706
|
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Norient3, Npix_j3]
|
|
3639
3707
|
MconvPsi = self.convol(
|
|
3640
|
-
M_dic[j2],
|
|
3708
|
+
M_dic[j2], cell_ids=cell_ids, nside=nside_j2
|
|
3641
3709
|
) # [Nbatch, Norient3, Norient2, Npix_j3]
|
|
3642
3710
|
|
|
3643
3711
|
if cmat2 is not None:
|
|
3644
3712
|
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-3)
|
|
3645
|
-
|
|
3646
|
-
self.backend.
|
|
3647
|
-
|
|
3648
|
-
|
|
3649
|
-
|
|
3650
|
-
|
|
3651
|
-
|
|
3652
|
-
|
|
3653
|
-
|
|
3654
|
-
|
|
3655
|
-
|
|
3656
|
-
|
|
3657
|
-
|
|
3713
|
+
if spin==0:
|
|
3714
|
+
MconvPsi = self.backend.bk_reduce_sum(
|
|
3715
|
+
self.backend.bk_reshape(
|
|
3716
|
+
cmat2[j3][j2] * tmp2,
|
|
3717
|
+
[
|
|
3718
|
+
tmp2.shape[0],
|
|
3719
|
+
self.NORIENT,
|
|
3720
|
+
self.NORIENT,
|
|
3721
|
+
self.NORIENT,
|
|
3722
|
+
cmat2[j3][j2].shape[3],
|
|
3723
|
+
],
|
|
3724
|
+
),
|
|
3725
|
+
1,
|
|
3726
|
+
)
|
|
3727
|
+
else:
|
|
3728
|
+
MconvPsi = self.backend.bk_reduce_sum(
|
|
3729
|
+
self.backend.bk_reshape(
|
|
3730
|
+
cmat2[j3][j2] * tmp2,
|
|
3731
|
+
[
|
|
3732
|
+
tmp2.shape[0],
|
|
3733
|
+
2,
|
|
3734
|
+
self.NORIENT,
|
|
3735
|
+
self.NORIENT,
|
|
3736
|
+
self.NORIENT,
|
|
3737
|
+
cmat2[j3][j2].shape[4],
|
|
3738
|
+
],
|
|
3739
|
+
),
|
|
3740
|
+
2,
|
|
3741
|
+
)
|
|
3658
3742
|
|
|
3659
3743
|
# Store it so we can use it in S4 computation
|
|
3660
3744
|
MconvPsi_dic[j2] = MconvPsi # [Nbatch, Norient3, Norient2, Npix_j3]
|
|
@@ -3674,12 +3758,12 @@ class funct(FOC.FoCUS):
|
|
|
3674
3758
|
else:
|
|
3675
3759
|
if calc_var:
|
|
3676
3760
|
s3, vs3 = self.masked_mean(
|
|
3677
|
-
s3, vmask,
|
|
3761
|
+
s3, vmask, rank=j2, calc_var=True
|
|
3678
3762
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3679
3763
|
return s3, vs3
|
|
3680
3764
|
else:
|
|
3681
3765
|
s3 = self.masked_mean(
|
|
3682
|
-
s3, vmask,
|
|
3766
|
+
s3, vmask, rank=j2
|
|
3683
3767
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3684
3768
|
return s3
|
|
3685
3769
|
|
|
@@ -3717,12 +3801,12 @@ class funct(FOC.FoCUS):
|
|
|
3717
3801
|
else:
|
|
3718
3802
|
if calc_var:
|
|
3719
3803
|
s4, vs4 = self.masked_mean(
|
|
3720
|
-
s4, vmask,
|
|
3804
|
+
s4, vmask, rank=j2, calc_var=True
|
|
3721
3805
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3722
3806
|
return s4, vs4
|
|
3723
3807
|
else:
|
|
3724
3808
|
s4 = self.masked_mean(
|
|
3725
|
-
s4, vmask,
|
|
3809
|
+
s4, vmask, rank=j2
|
|
3726
3810
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3727
3811
|
return s4
|
|
3728
3812
|
|
|
@@ -3922,20 +4006,20 @@ class funct(FOC.FoCUS):
|
|
|
3922
4006
|
#
|
|
3923
4007
|
# ---------------------------------------------------------------------------
|
|
3924
4008
|
def scattering_cov(
|
|
3925
|
-
|
|
3926
|
-
|
|
3927
|
-
|
|
3928
|
-
|
|
3929
|
-
|
|
3930
|
-
|
|
3931
|
-
|
|
3932
|
-
|
|
3933
|
-
|
|
3934
|
-
|
|
3935
|
-
|
|
3936
|
-
|
|
3937
|
-
|
|
3938
|
-
|
|
4009
|
+
self,
|
|
4010
|
+
data,
|
|
4011
|
+
data2=None,
|
|
4012
|
+
Jmax=None,
|
|
4013
|
+
if_large_batch=False,
|
|
4014
|
+
S4_criteria=None,
|
|
4015
|
+
use_ref=False,
|
|
4016
|
+
normalization="S2",
|
|
4017
|
+
edge=False,
|
|
4018
|
+
in_mask=None,
|
|
4019
|
+
pseudo_coef=1,
|
|
4020
|
+
get_variance=False,
|
|
4021
|
+
ref_sigma=None,
|
|
4022
|
+
iso_ang=False,
|
|
3939
4023
|
):
|
|
3940
4024
|
"""
|
|
3941
4025
|
Calculates the scattering correlations for a batch of images, including:
|
|
@@ -4048,7 +4132,10 @@ class funct(FOC.FoCUS):
|
|
|
4048
4132
|
if data2 is not None:
|
|
4049
4133
|
N_image2 = data2.shape[0]
|
|
4050
4134
|
|
|
4051
|
-
|
|
4135
|
+
if spin==0:
|
|
4136
|
+
nside = int(np.sqrt(npix // 12))
|
|
4137
|
+
else:
|
|
4138
|
+
nside = int(np.sqrt(npix // 24))
|
|
4052
4139
|
|
|
4053
4140
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
4054
4141
|
|
|
@@ -5781,7 +5868,9 @@ class funct(FOC.FoCUS):
|
|
|
5781
5868
|
|
|
5782
5869
|
def from_gaussian(self, x):
|
|
5783
5870
|
|
|
5784
|
-
x = self.backend.bk_clip_by_value(x,
|
|
5871
|
+
x = self.backend.bk_clip_by_value(x,
|
|
5872
|
+
self.val_min+1E-4*abs(self.val_min),
|
|
5873
|
+
self.val_max-1E-4*abs(self.val_max))
|
|
5785
5874
|
return self.f_gaussian(self.backend.to_numpy(x))
|
|
5786
5875
|
|
|
5787
5876
|
def square(self, x):
|
|
@@ -6181,6 +6270,28 @@ class funct(FOC.FoCUS):
|
|
|
6181
6270
|
)
|
|
6182
6271
|
return loss
|
|
6183
6272
|
|
|
6273
|
+
def The_lossH(u, scat_operator, args):
|
|
6274
|
+
ref = args[0]
|
|
6275
|
+
sref = args[1]
|
|
6276
|
+
use_v = args[2]
|
|
6277
|
+
ljmax = args[3]
|
|
6278
|
+
|
|
6279
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6280
|
+
scat_operator.eval(
|
|
6281
|
+
u,
|
|
6282
|
+
Jmax=ljmax,
|
|
6283
|
+
norm='self'
|
|
6284
|
+
)
|
|
6285
|
+
)
|
|
6286
|
+
|
|
6287
|
+
# compute scattering covariance of the current synthetised map called u
|
|
6288
|
+
if use_v:
|
|
6289
|
+
loss = scat_operator.reduce_distance(learn,ref,sigma=sref)
|
|
6290
|
+
else:
|
|
6291
|
+
loss = scat_operator.reduce_distance(learn,ref)
|
|
6292
|
+
|
|
6293
|
+
return loss
|
|
6294
|
+
|
|
6184
6295
|
def The_lossX(u, scat_operator, args):
|
|
6185
6296
|
ref = args[0]
|
|
6186
6297
|
sref = args[1]
|
|
@@ -6340,37 +6451,74 @@ class funct(FOC.FoCUS):
|
|
|
6340
6451
|
# Increase the resolution between each step
|
|
6341
6452
|
if self.use_2D:
|
|
6342
6453
|
imap = self.up_grade(
|
|
6343
|
-
omap,
|
|
6454
|
+
omap,
|
|
6455
|
+
imap.shape[1] * 2,
|
|
6456
|
+
axis=-2,
|
|
6457
|
+
nouty=imap.shape[2] * 2
|
|
6344
6458
|
)
|
|
6345
6459
|
elif self.use_1D:
|
|
6346
|
-
imap = self.up_grade(omap, imap.shape[1] * 2
|
|
6460
|
+
imap = self.up_grade(omap, imap.shape[1] * 2)
|
|
6347
6461
|
else:
|
|
6348
|
-
imap = self.up_grade(omap, l_nside
|
|
6349
|
-
|
|
6462
|
+
imap = self.up_grade(omap, l_nside)
|
|
6463
|
+
|
|
6350
6464
|
if grd_mask is not None:
|
|
6351
6465
|
imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
|
|
6352
6466
|
|
|
6353
|
-
|
|
6354
|
-
if
|
|
6355
|
-
|
|
6356
|
-
|
|
6357
|
-
|
|
6358
|
-
|
|
6359
|
-
|
|
6360
|
-
|
|
6361
|
-
|
|
6362
|
-
|
|
6363
|
-
|
|
6467
|
+
|
|
6468
|
+
if self.use_2D:
|
|
6469
|
+
# compute the coefficients for the target image
|
|
6470
|
+
if use_variance:
|
|
6471
|
+
ref, sref = self.scattering_cov(
|
|
6472
|
+
tmp[k],
|
|
6473
|
+
data2=l_ref[k],
|
|
6474
|
+
get_variance=True,
|
|
6475
|
+
edge=l_edge,
|
|
6476
|
+
Jmax=l_jmax[k],
|
|
6477
|
+
in_mask=l_in_mask[k],
|
|
6478
|
+
iso_ang=iso_ang,
|
|
6479
|
+
)
|
|
6480
|
+
else:
|
|
6481
|
+
ref = self.scattering_cov(
|
|
6482
|
+
tmp[k],
|
|
6483
|
+
data2=l_ref[k],
|
|
6484
|
+
in_mask=l_in_mask[k],
|
|
6485
|
+
edge=l_edge,
|
|
6486
|
+
Jmax=l_jmax[k],
|
|
6487
|
+
iso_ang=iso_ang,
|
|
6488
|
+
)
|
|
6489
|
+
sref = ref
|
|
6364
6490
|
else:
|
|
6365
|
-
ref = self.
|
|
6366
|
-
|
|
6367
|
-
|
|
6368
|
-
|
|
6369
|
-
|
|
6370
|
-
|
|
6371
|
-
|
|
6372
|
-
|
|
6373
|
-
|
|
6491
|
+
ref = self.eval(
|
|
6492
|
+
tmp[k],
|
|
6493
|
+
image2=l_ref[k],
|
|
6494
|
+
mask=l_in_mask[k],
|
|
6495
|
+
Jmax=l_jmax[k],
|
|
6496
|
+
norm='auto'
|
|
6497
|
+
)
|
|
6498
|
+
|
|
6499
|
+
# compute the coefficients for the target image
|
|
6500
|
+
if use_variance:
|
|
6501
|
+
ref, sref = self.eval(
|
|
6502
|
+
tmp[k],
|
|
6503
|
+
image2=l_ref[k],
|
|
6504
|
+
mask=l_in_mask[k],
|
|
6505
|
+
Jmax=l_jmax[k],
|
|
6506
|
+
calc_var=True,
|
|
6507
|
+
norm='self'
|
|
6508
|
+
)
|
|
6509
|
+
else:
|
|
6510
|
+
ref = self.eval(
|
|
6511
|
+
tmp[k],
|
|
6512
|
+
image2=l_ref[k],
|
|
6513
|
+
mask=l_in_mask[k],
|
|
6514
|
+
Jmax=l_jmax[k],
|
|
6515
|
+
norm='self'
|
|
6516
|
+
)
|
|
6517
|
+
sref = ref
|
|
6518
|
+
|
|
6519
|
+
if iso_ang:
|
|
6520
|
+
ref=ref.iso_mean()
|
|
6521
|
+
sref=sref.iso_mean()
|
|
6374
6522
|
|
|
6375
6523
|
# compute the mean of the population does nothing if only one map is given
|
|
6376
6524
|
ref = self.reduce_mean_batch(ref)
|
|
@@ -6379,13 +6527,21 @@ class funct(FOC.FoCUS):
|
|
|
6379
6527
|
self.purge_edge_mask()
|
|
6380
6528
|
|
|
6381
6529
|
if l_ref[k] is None:
|
|
6382
|
-
|
|
6383
|
-
|
|
6530
|
+
if self.use_2D:
|
|
6531
|
+
# define a loss to minimize
|
|
6532
|
+
loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
|
|
6533
|
+
else:
|
|
6534
|
+
loss = synthe.Loss(The_lossH, self, ref, sref, use_variance, l_jmax[k])
|
|
6384
6535
|
else:
|
|
6385
6536
|
# define a loss to minimize
|
|
6386
|
-
|
|
6387
|
-
|
|
6388
|
-
|
|
6537
|
+
if self.use_2D:
|
|
6538
|
+
loss = synthe.Loss(
|
|
6539
|
+
The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
|
|
6540
|
+
)
|
|
6541
|
+
else:
|
|
6542
|
+
loss = synthe.Loss(
|
|
6543
|
+
The_lossXH, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
|
|
6544
|
+
)
|
|
6389
6545
|
|
|
6390
6546
|
if input_image is not None:
|
|
6391
6547
|
# define a loss to minimize
|