foscat 3.9.0__py3-none-any.whl → 2025.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkBase.py +9 -0
- foscat/BkTorch.py +68 -0
- foscat/FoCUS.py +157 -34
- foscat/Synthesis.py +1 -1
- foscat/scat_cov.py +417 -246
- foscat/scat_cov_map.py +0 -2
- {foscat-3.9.0.dist-info → foscat-2025.5.0.dist-info}/METADATA +3 -2
- {foscat-3.9.0.dist-info → foscat-2025.5.0.dist-info}/RECORD +11 -11
- {foscat-3.9.0.dist-info → foscat-2025.5.0.dist-info}/WHEEL +1 -1
- {foscat-3.9.0.dist-info → foscat-2025.5.0.dist-info/licenses}/LICENSE +0 -0
- {foscat-3.9.0.dist-info → foscat-2025.5.0.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -177,9 +177,21 @@ class scat_cov:
|
|
|
177
177
|
],
|
|
178
178
|
)
|
|
179
179
|
),
|
|
180
|
-
self.
|
|
180
|
+
self.backend.bk_reshape(
|
|
181
|
+
self.S3,
|
|
182
|
+
[
|
|
183
|
+
self.S3.shape[0],
|
|
184
|
+
self.S3.shape[1]
|
|
185
|
+
* self.S3.shape[2]
|
|
186
|
+
* self.S3.shape[3]
|
|
187
|
+
* self.S3.shape[4],
|
|
188
|
+
],
|
|
189
|
+
),
|
|
190
|
+
]
|
|
191
|
+
if self.S3P is not None:
|
|
192
|
+
tmp = tmp + [
|
|
181
193
|
self.backend.bk_reshape(
|
|
182
|
-
self.
|
|
194
|
+
self.S3P,
|
|
183
195
|
[
|
|
184
196
|
self.S3.shape[0],
|
|
185
197
|
self.S3.shape[1]
|
|
@@ -188,37 +200,19 @@ class scat_cov:
|
|
|
188
200
|
* self.S3.shape[4],
|
|
189
201
|
],
|
|
190
202
|
)
|
|
191
|
-
),
|
|
192
|
-
]
|
|
193
|
-
if self.S3P is not None:
|
|
194
|
-
tmp = tmp + [
|
|
195
|
-
self.conv2complex(
|
|
196
|
-
self.backend.bk_reshape(
|
|
197
|
-
self.S3P,
|
|
198
|
-
[
|
|
199
|
-
self.S3.shape[0],
|
|
200
|
-
self.S3.shape[1]
|
|
201
|
-
* self.S3.shape[2]
|
|
202
|
-
* self.S3.shape[3]
|
|
203
|
-
* self.S3.shape[4],
|
|
204
|
-
],
|
|
205
|
-
)
|
|
206
|
-
)
|
|
207
203
|
]
|
|
208
204
|
|
|
209
205
|
tmp = tmp + [
|
|
210
|
-
self.
|
|
211
|
-
self.
|
|
212
|
-
|
|
213
|
-
[
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
],
|
|
221
|
-
)
|
|
206
|
+
self.backend.bk_reshape(
|
|
207
|
+
self.S4,
|
|
208
|
+
[
|
|
209
|
+
self.S3.shape[0],
|
|
210
|
+
self.S4.shape[1]
|
|
211
|
+
* self.S4.shape[2]
|
|
212
|
+
* self.S4.shape[3]
|
|
213
|
+
* self.S4.shape[4]
|
|
214
|
+
* self.S4.shape[5],
|
|
215
|
+
],
|
|
222
216
|
)
|
|
223
217
|
]
|
|
224
218
|
|
|
@@ -504,7 +498,7 @@ class scat_cov:
|
|
|
504
498
|
if other.S1 is None:
|
|
505
499
|
s1 = None
|
|
506
500
|
else:
|
|
507
|
-
s1 = self.S1
|
|
501
|
+
s1 = self.doadd(self.S1, other.S1)
|
|
508
502
|
else:
|
|
509
503
|
s1 = self.S1 + other
|
|
510
504
|
|
|
@@ -534,7 +528,7 @@ class scat_cov:
|
|
|
534
528
|
return scat_cov(
|
|
535
529
|
self.doadd(self.S0, other.S0),
|
|
536
530
|
self.doadd(self.S2, other.S2),
|
|
537
|
-
(self.S3
|
|
531
|
+
self.doadd(self.S3, other.S3),
|
|
538
532
|
s4,
|
|
539
533
|
s1=s1,
|
|
540
534
|
s3p=s3p,
|
|
@@ -662,7 +656,7 @@ class scat_cov:
|
|
|
662
656
|
s1 = None
|
|
663
657
|
else:
|
|
664
658
|
if isinstance(other, scat_cov):
|
|
665
|
-
s1 = other.S1
|
|
659
|
+
s1 = self.dodiv(other.S1, self.S1)
|
|
666
660
|
else:
|
|
667
661
|
s1 = other / self.S1
|
|
668
662
|
|
|
@@ -689,7 +683,7 @@ class scat_cov:
|
|
|
689
683
|
return scat_cov(
|
|
690
684
|
self.dodiv(other.S0, self.S0),
|
|
691
685
|
self.dodiv(other.S2, self.S2),
|
|
692
|
-
(other.S3
|
|
686
|
+
self.dodiv(other.S3, self.S3),
|
|
693
687
|
s4,
|
|
694
688
|
s1=s1,
|
|
695
689
|
s3p=s3p,
|
|
@@ -725,7 +719,7 @@ class scat_cov:
|
|
|
725
719
|
if other.S1 is None:
|
|
726
720
|
s1 = None
|
|
727
721
|
else:
|
|
728
|
-
s1 = other.S1
|
|
722
|
+
s1 = self.domin(other.S1, self.S1)
|
|
729
723
|
else:
|
|
730
724
|
s1 = other - self.S1
|
|
731
725
|
|
|
@@ -755,7 +749,7 @@ class scat_cov:
|
|
|
755
749
|
return scat_cov(
|
|
756
750
|
self.domin(other.S0, self.S0),
|
|
757
751
|
self.domin(other.S2, self.S2),
|
|
758
|
-
(other.S3
|
|
752
|
+
self.domin(other.S3, self.S3),
|
|
759
753
|
s4,
|
|
760
754
|
s1=s1,
|
|
761
755
|
s3p=s3p,
|
|
@@ -790,7 +784,7 @@ class scat_cov:
|
|
|
790
784
|
if other.S1 is None:
|
|
791
785
|
s1 = None
|
|
792
786
|
else:
|
|
793
|
-
s1 = self.S1
|
|
787
|
+
s1 = self.domin(self.S1, other.S1)
|
|
794
788
|
else:
|
|
795
789
|
s1 = self.S1 - other
|
|
796
790
|
|
|
@@ -820,7 +814,7 @@ class scat_cov:
|
|
|
820
814
|
return scat_cov(
|
|
821
815
|
self.domin(self.S0, other.S0),
|
|
822
816
|
self.domin(self.S2, other.S2),
|
|
823
|
-
(self.S3
|
|
817
|
+
self.domin(self.S3, other.S3),
|
|
824
818
|
s4,
|
|
825
819
|
s1=s1,
|
|
826
820
|
s3p=s3p,
|
|
@@ -920,7 +914,7 @@ class scat_cov:
|
|
|
920
914
|
if other.S1 is None:
|
|
921
915
|
s1 = None
|
|
922
916
|
else:
|
|
923
|
-
s1 = self.S1
|
|
917
|
+
s1 = self.domult(self.S1, other.S1)
|
|
924
918
|
else:
|
|
925
919
|
s1 = self.S1 * other
|
|
926
920
|
|
|
@@ -2215,7 +2209,7 @@ class scat_cov:
|
|
|
2215
2209
|
|
|
2216
2210
|
|
|
2217
2211
|
class funct(FOC.FoCUS):
|
|
2218
|
-
|
|
2212
|
+
|
|
2219
2213
|
def fill(self, im, nullval=hp.UNSEEN):
|
|
2220
2214
|
if self.use_2D:
|
|
2221
2215
|
return self.fill_2d(im, nullval=nullval)
|
|
@@ -2402,8 +2396,9 @@ class funct(FOC.FoCUS):
|
|
|
2402
2396
|
if smooth_scale > 0:
|
|
2403
2397
|
for m in range(smooth_scale):
|
|
2404
2398
|
if cc.shape[0] > 12:
|
|
2405
|
-
cc = self.ud_grade_2(self.smooth(cc))
|
|
2406
|
-
ss = self.ud_grade_2(self.smooth(ss))
|
|
2399
|
+
cc, _ = self.ud_grade_2(self.smooth(cc))
|
|
2400
|
+
ss, _ = self.ud_grade_2(self.smooth(ss))
|
|
2401
|
+
|
|
2407
2402
|
if cc.shape[0] != tmp.shape[0]:
|
|
2408
2403
|
ll_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2409
2404
|
cc = self.up_grade(cc, ll_nside)
|
|
@@ -2473,8 +2468,8 @@ class funct(FOC.FoCUS):
|
|
|
2473
2468
|
if smooth_scale > 0:
|
|
2474
2469
|
for m in range(smooth_scale):
|
|
2475
2470
|
if cc2.shape[0] > 12:
|
|
2476
|
-
cc2 = self.ud_grade_2(self.smooth(cc2))
|
|
2477
|
-
ss2 = self.ud_grade_2(self.smooth(ss2))
|
|
2471
|
+
cc2, _ = self.ud_grade_2(self.smooth(cc2))
|
|
2472
|
+
ss2, _ = self.ud_grade_2(self.smooth(ss2))
|
|
2478
2473
|
|
|
2479
2474
|
if cc2.shape[0] != sim.shape[1]:
|
|
2480
2475
|
ll_nside = int(np.sqrt(sim.shape[1] // 12))
|
|
@@ -2510,9 +2505,9 @@ class funct(FOC.FoCUS):
|
|
|
2510
2505
|
cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
|
|
2511
2506
|
|
|
2512
2507
|
if k < l_nside - 1:
|
|
2513
|
-
tmp = self.ud_grade_2(tmp, axis=1)
|
|
2508
|
+
tmp, _ = self.ud_grade_2(tmp, axis=1)
|
|
2514
2509
|
if image2 is not None:
|
|
2515
|
-
tmpi2 = self.ud_grade_2(tmpi2, axis=1)
|
|
2510
|
+
tmpi2, _ = self.ud_grade_2(tmpi2, axis=1)
|
|
2516
2511
|
return cmat, cmat2
|
|
2517
2512
|
|
|
2518
2513
|
def div_norm(self, complex_value, float_value):
|
|
@@ -2533,6 +2528,8 @@ class funct(FOC.FoCUS):
|
|
|
2533
2528
|
Jmax=None,
|
|
2534
2529
|
out_nside=None,
|
|
2535
2530
|
edge=True,
|
|
2531
|
+
nside=None,
|
|
2532
|
+
cell_ids=None,
|
|
2536
2533
|
):
|
|
2537
2534
|
"""
|
|
2538
2535
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
@@ -2606,7 +2603,7 @@ class funct(FOC.FoCUS):
|
|
|
2606
2603
|
cross = False
|
|
2607
2604
|
if image2 is not None:
|
|
2608
2605
|
cross = True
|
|
2609
|
-
|
|
2606
|
+
l_nside = 2**32 # not initialize if 1D or 2D
|
|
2610
2607
|
### PARAMETERS
|
|
2611
2608
|
axis = 1
|
|
2612
2609
|
# determine jmax and nside corresponding to the input map
|
|
@@ -2638,7 +2635,8 @@ class funct(FOC.FoCUS):
|
|
|
2638
2635
|
else:
|
|
2639
2636
|
npix = int(im_shape[0]) # Number of pixels
|
|
2640
2637
|
|
|
2641
|
-
nside
|
|
2638
|
+
if nside is None:
|
|
2639
|
+
nside = int(np.sqrt(npix // 12))
|
|
2642
2640
|
|
|
2643
2641
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2644
2642
|
|
|
@@ -2675,7 +2673,7 @@ class funct(FOC.FoCUS):
|
|
|
2675
2673
|
else:
|
|
2676
2674
|
vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
|
|
2677
2675
|
|
|
2678
|
-
if self.KERNELSZ > 3 and not self.use_2D:
|
|
2676
|
+
if self.KERNELSZ > 3 and not self.use_2D and cell_ids is None:
|
|
2679
2677
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2680
2678
|
if self.use_2D:
|
|
2681
2679
|
vmask = self.up_grade(
|
|
@@ -2693,12 +2691,15 @@ class funct(FOC.FoCUS):
|
|
|
2693
2691
|
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
2694
2692
|
if cross:
|
|
2695
2693
|
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
2694
|
+
nside = nside * 2
|
|
2696
2695
|
else:
|
|
2697
2696
|
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
2698
2697
|
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
2699
2698
|
if cross:
|
|
2700
2699
|
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
2701
2700
|
|
|
2701
|
+
nside = nside * 2
|
|
2702
|
+
|
|
2702
2703
|
if self.KERNELSZ > 5 and not self.use_2D:
|
|
2703
2704
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2704
2705
|
if self.use_2D:
|
|
@@ -2716,15 +2717,17 @@ class funct(FOC.FoCUS):
|
|
|
2716
2717
|
nouty=I2.shape[axis + 1] * 2,
|
|
2717
2718
|
)
|
|
2718
2719
|
elif self.use_1D:
|
|
2719
|
-
vmask = self.up_grade(vmask, I1.shape[axis] *
|
|
2720
|
-
I1 = self.up_grade(I1, I1.shape[axis] *
|
|
2720
|
+
vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
|
|
2721
|
+
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
2721
2722
|
if cross:
|
|
2722
|
-
I2 = self.up_grade(I2, I2.shape[axis] *
|
|
2723
|
+
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
2724
|
+
nside = nside * 2
|
|
2723
2725
|
else:
|
|
2724
|
-
I1 = self.up_grade(I1, nside *
|
|
2725
|
-
vmask = self.up_grade(vmask, nside *
|
|
2726
|
+
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
2727
|
+
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
2726
2728
|
if cross:
|
|
2727
|
-
I2 = self.up_grade(I2, nside *
|
|
2729
|
+
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
2730
|
+
nside = nside * 2
|
|
2728
2731
|
|
|
2729
2732
|
# Normalize the masks because they have different pixel numbers
|
|
2730
2733
|
# vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
|
|
@@ -2797,6 +2800,8 @@ class funct(FOC.FoCUS):
|
|
|
2797
2800
|
M1_dic = {}
|
|
2798
2801
|
M2_dic = {}
|
|
2799
2802
|
|
|
2803
|
+
cell_ids_j3 = cell_ids
|
|
2804
|
+
|
|
2800
2805
|
for j3 in range(Jmax):
|
|
2801
2806
|
|
|
2802
2807
|
if edge:
|
|
@@ -2835,7 +2840,9 @@ class funct(FOC.FoCUS):
|
|
|
2835
2840
|
|
|
2836
2841
|
####### S1 and S2
|
|
2837
2842
|
### Make the convolution I1 * Psi_j3
|
|
2838
|
-
conv1 = self.convol(
|
|
2843
|
+
conv1 = self.convol(
|
|
2844
|
+
I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
2845
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
2839
2846
|
if cmat is not None:
|
|
2840
2847
|
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
|
|
2841
2848
|
conv1 = self.backend.bk_reduce_sum(
|
|
@@ -2948,7 +2955,9 @@ class funct(FOC.FoCUS):
|
|
|
2948
2955
|
|
|
2949
2956
|
else: # Cross
|
|
2950
2957
|
### Make the convolution I2 * Psi_j3
|
|
2951
|
-
conv2 = self.convol(
|
|
2958
|
+
conv2 = self.convol(
|
|
2959
|
+
I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
2960
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
2952
2961
|
if cmat is not None:
|
|
2953
2962
|
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
|
|
2954
2963
|
conv2 = self.backend.bk_reduce_sum(
|
|
@@ -3028,12 +3037,13 @@ class funct(FOC.FoCUS):
|
|
|
3028
3037
|
)
|
|
3029
3038
|
S2[j3] = s2
|
|
3030
3039
|
else:
|
|
3031
|
-
### Normalize S2_cross
|
|
3032
|
-
if norm == "auto":
|
|
3033
|
-
s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
|
|
3034
3040
|
|
|
3035
3041
|
### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
|
|
3036
3042
|
s2 = self.backend.bk_real(s2)
|
|
3043
|
+
|
|
3044
|
+
### Normalize S2_cross
|
|
3045
|
+
if norm == "auto":
|
|
3046
|
+
s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
|
|
3037
3047
|
|
|
3038
3048
|
S2.append(
|
|
3039
3049
|
self.backend.bk_expand_dims(s2, off_S2)
|
|
@@ -3092,7 +3102,6 @@ class funct(FOC.FoCUS):
|
|
|
3092
3102
|
M2convPsi_dic = {}
|
|
3093
3103
|
|
|
3094
3104
|
###### S3
|
|
3095
|
-
nside_j2 = nside_j3
|
|
3096
3105
|
for j2 in range(0, j3 + 1): # j2 <= j3
|
|
3097
3106
|
if return_data:
|
|
3098
3107
|
if S4[j3] is None:
|
|
@@ -3111,6 +3120,8 @@ class funct(FOC.FoCUS):
|
|
|
3111
3120
|
M1convPsi_dic,
|
|
3112
3121
|
calc_var=True,
|
|
3113
3122
|
cmat2=cmat2,
|
|
3123
|
+
cell_ids=cell_ids_j3,
|
|
3124
|
+
nside_j2=nside_j3,
|
|
3114
3125
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3115
3126
|
else:
|
|
3116
3127
|
s3 = self._compute_S3(
|
|
@@ -3122,19 +3133,21 @@ class funct(FOC.FoCUS):
|
|
|
3122
3133
|
M1convPsi_dic,
|
|
3123
3134
|
return_data=return_data,
|
|
3124
3135
|
cmat2=cmat2,
|
|
3136
|
+
cell_ids=cell_ids_j3,
|
|
3137
|
+
nside_j2=nside_j3,
|
|
3125
3138
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3126
3139
|
|
|
3127
3140
|
if return_data:
|
|
3128
3141
|
if S3[j3] is None:
|
|
3129
3142
|
S3[j3] = {}
|
|
3130
|
-
if out_nside is not None and out_nside <
|
|
3143
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3131
3144
|
s3 = self.backend.bk_reduce_mean(
|
|
3132
3145
|
self.backend.bk_reshape(
|
|
3133
3146
|
s3,
|
|
3134
3147
|
[
|
|
3135
3148
|
s3.shape[0],
|
|
3136
3149
|
12 * out_nside**2,
|
|
3137
|
-
(
|
|
3150
|
+
(nside_j3 // out_nside) ** 2,
|
|
3138
3151
|
s3.shape[2],
|
|
3139
3152
|
s3.shape[3],
|
|
3140
3153
|
],
|
|
@@ -3181,6 +3194,8 @@ class funct(FOC.FoCUS):
|
|
|
3181
3194
|
M2convPsi_dic,
|
|
3182
3195
|
calc_var=True,
|
|
3183
3196
|
cmat2=cmat2,
|
|
3197
|
+
cell_ids=cell_ids_j3,
|
|
3198
|
+
nside_j2=nside_j3,
|
|
3184
3199
|
)
|
|
3185
3200
|
s3p, vs3p = self._compute_S3(
|
|
3186
3201
|
j2,
|
|
@@ -3191,41 +3206,47 @@ class funct(FOC.FoCUS):
|
|
|
3191
3206
|
M1convPsi_dic,
|
|
3192
3207
|
calc_var=True,
|
|
3193
3208
|
cmat2=cmat2,
|
|
3209
|
+
cell_ids=cell_ids_j3,
|
|
3210
|
+
nside_j2=nside_j3,
|
|
3194
3211
|
)
|
|
3195
3212
|
else:
|
|
3196
|
-
|
|
3213
|
+
s3p = self._compute_S3(
|
|
3197
3214
|
j2,
|
|
3198
3215
|
j3,
|
|
3199
|
-
|
|
3216
|
+
conv2,
|
|
3200
3217
|
vmask,
|
|
3201
|
-
|
|
3202
|
-
|
|
3218
|
+
M1_dic,
|
|
3219
|
+
M1convPsi_dic,
|
|
3203
3220
|
return_data=return_data,
|
|
3204
3221
|
cmat2=cmat2,
|
|
3222
|
+
cell_ids=cell_ids_j3,
|
|
3223
|
+
nside_j2=nside_j3,
|
|
3205
3224
|
)
|
|
3206
|
-
|
|
3225
|
+
s3 = self._compute_S3(
|
|
3207
3226
|
j2,
|
|
3208
3227
|
j3,
|
|
3209
|
-
|
|
3228
|
+
conv1,
|
|
3210
3229
|
vmask,
|
|
3211
|
-
|
|
3212
|
-
|
|
3230
|
+
M2_dic,
|
|
3231
|
+
M2convPsi_dic,
|
|
3213
3232
|
return_data=return_data,
|
|
3214
3233
|
cmat2=cmat2,
|
|
3234
|
+
cell_ids=cell_ids_j3,
|
|
3235
|
+
nside_j2=nside_j3,
|
|
3215
3236
|
)
|
|
3216
3237
|
|
|
3217
3238
|
if return_data:
|
|
3218
3239
|
if S3[j3] is None:
|
|
3219
3240
|
S3[j3] = {}
|
|
3220
3241
|
S3P[j3] = {}
|
|
3221
|
-
if out_nside is not None and out_nside <
|
|
3242
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3222
3243
|
s3 = self.backend.bk_reduce_mean(
|
|
3223
3244
|
self.backend.bk_reshape(
|
|
3224
3245
|
s3,
|
|
3225
3246
|
[
|
|
3226
3247
|
s3.shape[0],
|
|
3227
3248
|
12 * out_nside**2,
|
|
3228
|
-
(
|
|
3249
|
+
(nside_j3 // out_nside) ** 2,
|
|
3229
3250
|
s3.shape[2],
|
|
3230
3251
|
s3.shape[3],
|
|
3231
3252
|
],
|
|
@@ -3238,7 +3259,7 @@ class funct(FOC.FoCUS):
|
|
|
3238
3259
|
[
|
|
3239
3260
|
s3.shape[0],
|
|
3240
3261
|
12 * out_nside**2,
|
|
3241
|
-
(
|
|
3262
|
+
(nside_j3 // out_nside) ** 2,
|
|
3242
3263
|
s3.shape[2],
|
|
3243
3264
|
s3.shape[3],
|
|
3244
3265
|
],
|
|
@@ -3295,7 +3316,6 @@ class funct(FOC.FoCUS):
|
|
|
3295
3316
|
# s3.shape[2]*s3.shape[3]]))
|
|
3296
3317
|
|
|
3297
3318
|
##### S4
|
|
3298
|
-
nside_j1 = nside_j2
|
|
3299
3319
|
for j1 in range(0, j2 + 1): # j1 <= j2
|
|
3300
3320
|
### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
|
|
3301
3321
|
if not cross:
|
|
@@ -3321,14 +3341,14 @@ class funct(FOC.FoCUS):
|
|
|
3321
3341
|
if return_data:
|
|
3322
3342
|
if S4[j3][j2] is None:
|
|
3323
3343
|
S4[j3][j2] = {}
|
|
3324
|
-
if out_nside is not None and out_nside <
|
|
3344
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3325
3345
|
s4 = self.backend.bk_reduce_mean(
|
|
3326
3346
|
self.backend.bk_reshape(
|
|
3327
3347
|
s4,
|
|
3328
3348
|
[
|
|
3329
3349
|
s4.shape[0],
|
|
3330
3350
|
12 * out_nside**2,
|
|
3331
|
-
(
|
|
3351
|
+
(nside_j3 // out_nside) ** 2,
|
|
3332
3352
|
s4.shape[2],
|
|
3333
3353
|
s4.shape[3],
|
|
3334
3354
|
s4.shape[4],
|
|
@@ -3396,14 +3416,14 @@ class funct(FOC.FoCUS):
|
|
|
3396
3416
|
if return_data:
|
|
3397
3417
|
if S4[j3][j2] is None:
|
|
3398
3418
|
S4[j3][j2] = {}
|
|
3399
|
-
if out_nside is not None and out_nside <
|
|
3419
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3400
3420
|
s4 = self.backend.bk_reduce_mean(
|
|
3401
3421
|
self.backend.bk_reshape(
|
|
3402
3422
|
s4,
|
|
3403
3423
|
[
|
|
3404
3424
|
s4.shape[0],
|
|
3405
3425
|
12 * out_nside**2,
|
|
3406
|
-
(
|
|
3426
|
+
(nside_j3 // out_nside) ** 2,
|
|
3407
3427
|
s4.shape[2],
|
|
3408
3428
|
s4.shape[3],
|
|
3409
3429
|
s4.shape[4],
|
|
@@ -3447,48 +3467,51 @@ class funct(FOC.FoCUS):
|
|
|
3447
3467
|
self.backend.bk_expand_dims(vs4, off_S4)
|
|
3448
3468
|
) # Add a dimension for NS4
|
|
3449
3469
|
|
|
3450
|
-
nside_j1 = nside_j1 // 2
|
|
3451
|
-
nside_j2 = nside_j2 // 2
|
|
3452
|
-
|
|
3453
3470
|
###### Reshape for next iteration on j3
|
|
3454
3471
|
### Image I1,
|
|
3455
3472
|
# downscale the I1 [Nbatch, Npix_j3]
|
|
3456
3473
|
if j3 != Jmax - 1:
|
|
3457
|
-
I1 = self.smooth(I1, axis=1)
|
|
3458
|
-
I1 = self.ud_grade_2(
|
|
3474
|
+
I1 = self.smooth(I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3475
|
+
I1, new_cell_ids_j3 = self.ud_grade_2(
|
|
3476
|
+
I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3477
|
+
)
|
|
3459
3478
|
|
|
3460
3479
|
### Image I2
|
|
3461
3480
|
if cross:
|
|
3462
|
-
I2 = self.smooth(I2, axis=1)
|
|
3463
|
-
I2 = self.ud_grade_2(
|
|
3481
|
+
I2 = self.smooth(I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3482
|
+
I2, new_cell_ids_j3 = self.ud_grade_2(
|
|
3483
|
+
I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3484
|
+
)
|
|
3464
3485
|
|
|
3465
3486
|
### Modules
|
|
3466
3487
|
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
3467
3488
|
### Dictionary M1_dic[j2]
|
|
3468
3489
|
M1_smooth = self.smooth(
|
|
3469
|
-
M1_dic[j2], axis=1
|
|
3490
|
+
M1_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3470
3491
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3471
|
-
M1_dic[j2] = self.ud_grade_2(
|
|
3472
|
-
M1_smooth, axis=1
|
|
3492
|
+
M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3493
|
+
M1_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3473
3494
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3474
3495
|
|
|
3475
3496
|
### Dictionary M2_dic[j2]
|
|
3476
3497
|
if cross:
|
|
3477
3498
|
M2_smooth = self.smooth(
|
|
3478
|
-
M2_dic[j2], axis=1
|
|
3499
|
+
M2_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3479
3500
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3480
|
-
M2_dic[j2] = self.ud_grade_2(
|
|
3481
|
-
M2_smooth, axis=1
|
|
3501
|
+
M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3502
|
+
M2_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3482
3503
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3483
|
-
|
|
3484
3504
|
### Mask
|
|
3485
|
-
vmask = self.ud_grade_2(
|
|
3505
|
+
vmask, new_cell_ids_j3 = self.ud_grade_2(
|
|
3506
|
+
vmask, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3507
|
+
)
|
|
3486
3508
|
|
|
3487
3509
|
if self.mask_thres is not None:
|
|
3488
3510
|
vmask = self.backend.bk_threshold(vmask, self.mask_thres)
|
|
3489
3511
|
|
|
3490
3512
|
### NSIDE_j3
|
|
3491
3513
|
nside_j3 = nside_j3 // 2
|
|
3514
|
+
cell_ids_j3 = new_cell_ids_j3
|
|
3492
3515
|
|
|
3493
3516
|
### Store P1_dic and P2_dic in self
|
|
3494
3517
|
if (norm == "auto") and (self.P1_dic is None):
|
|
@@ -3588,6 +3611,8 @@ class funct(FOC.FoCUS):
|
|
|
3588
3611
|
calc_var=False,
|
|
3589
3612
|
return_data=False,
|
|
3590
3613
|
cmat2=None,
|
|
3614
|
+
cell_ids=None,
|
|
3615
|
+
nside_j2=None,
|
|
3591
3616
|
):
|
|
3592
3617
|
"""
|
|
3593
3618
|
Compute the S3 coefficients (auto or cross)
|
|
@@ -3601,7 +3626,7 @@ class funct(FOC.FoCUS):
|
|
|
3601
3626
|
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
3602
3627
|
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
3603
3628
|
MconvPsi = self.convol(
|
|
3604
|
-
M_dic[j2], axis=1
|
|
3629
|
+
M_dic[j2], axis=1, cell_ids=cell_ids, nside=nside_j2
|
|
3605
3630
|
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3606
3631
|
if cmat2 is not None:
|
|
3607
3632
|
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
|
|
@@ -3822,24 +3847,62 @@ class funct(FOC.FoCUS):
|
|
|
3822
3847
|
dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
|
|
3823
3848
|
return dx, dy
|
|
3824
3849
|
|
|
3825
|
-
def get_edge_masks(self, M, N, J, d0=1):
|
|
3850
|
+
def get_edge_masks(self, M, N, J, d0=1, in_mask=None, edge_dx=None, edge_dy=None):
|
|
3826
3851
|
"""
|
|
3827
3852
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
3828
3853
|
Done by Sihao Cheng and Rudy Morel.
|
|
3829
3854
|
"""
|
|
3830
3855
|
edge_masks = np.empty((J, M, N))
|
|
3856
|
+
|
|
3831
3857
|
X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
|
|
3832
|
-
|
|
3833
|
-
|
|
3834
|
-
|
|
3835
|
-
|
|
3836
|
-
|
|
3837
|
-
|
|
3838
|
-
|
|
3839
|
-
|
|
3858
|
+
if in_mask is not None:
|
|
3859
|
+
from scipy.ndimage import binary_erosion
|
|
3860
|
+
|
|
3861
|
+
if in_mask is not None:
|
|
3862
|
+
if in_mask.shape[0] != M or in_mask.shape[0] != N:
|
|
3863
|
+
l_mask = in_mask.reshape(
|
|
3864
|
+
M, in_mask.shape[0] // M, N, in_mask.shape[1] // N
|
|
3865
|
+
)
|
|
3866
|
+
l_mask = (
|
|
3867
|
+
np.sum(np.sum(l_mask, 1), 2)
|
|
3868
|
+
* (M * N)
|
|
3869
|
+
/ (in_mask.shape[0] * in_mask.shape[1])
|
|
3870
|
+
)
|
|
3871
|
+
else:
|
|
3872
|
+
l_mask = in_mask
|
|
3873
|
+
|
|
3874
|
+
if edge_dx is None:
|
|
3875
|
+
for j in range(J):
|
|
3876
|
+
edge_dx = min(M // 4, 2**j * d0)
|
|
3877
|
+
edge_dy = min(N // 4, 2**j * d0)
|
|
3878
|
+
|
|
3879
|
+
edge_masks[j] = (
|
|
3880
|
+
(X >= edge_dx)
|
|
3881
|
+
* (X < M - edge_dx)
|
|
3882
|
+
* (Y >= edge_dy)
|
|
3883
|
+
* (Y < N - edge_dy)
|
|
3884
|
+
)
|
|
3885
|
+
if in_mask is not None:
|
|
3886
|
+
l_mask = binary_erosion(
|
|
3887
|
+
l_mask, iterations=1 + np.max([edge_dx, edge_dy])
|
|
3888
|
+
)
|
|
3889
|
+
edge_masks[j] *= l_mask
|
|
3890
|
+
|
|
3891
|
+
edge_masks = edge_masks[:, None, :, :]
|
|
3892
|
+
|
|
3893
|
+
edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
|
|
3894
|
+
else:
|
|
3895
|
+
edge_masks = (
|
|
3896
|
+
(X >= edge_dx) * (X < M - edge_dx) * (Y >= edge_dy) * (Y < N - edge_dy)
|
|
3840
3897
|
)
|
|
3841
|
-
|
|
3842
|
-
|
|
3898
|
+
if in_mask is not None:
|
|
3899
|
+
l_mask = binary_erosion(
|
|
3900
|
+
l_mask, iterations=1 + np.max([edge_dx, edge_dy])
|
|
3901
|
+
)
|
|
3902
|
+
edge_masks *= l_mask
|
|
3903
|
+
|
|
3904
|
+
edge_masks = edge_masks / edge_masks.mean((-2, -1))
|
|
3905
|
+
|
|
3843
3906
|
return self.backend.bk_cast(edge_masks)
|
|
3844
3907
|
|
|
3845
3908
|
# ---------------------------------------------------------------------------
|
|
@@ -3857,6 +3920,7 @@ class funct(FOC.FoCUS):
|
|
|
3857
3920
|
use_ref=False,
|
|
3858
3921
|
normalization="S2",
|
|
3859
3922
|
edge=False,
|
|
3923
|
+
in_mask=None,
|
|
3860
3924
|
pseudo_coef=1,
|
|
3861
3925
|
get_variance=False,
|
|
3862
3926
|
ref_sigma=None,
|
|
@@ -3925,6 +3989,9 @@ class funct(FOC.FoCUS):
|
|
|
3925
3989
|
if S4_criteria is None:
|
|
3926
3990
|
S4_criteria = "j2>=j1"
|
|
3927
3991
|
|
|
3992
|
+
if not edge and in_mask is not None:
|
|
3993
|
+
edge = True
|
|
3994
|
+
|
|
3928
3995
|
if self.all_bk_type == "float32":
|
|
3929
3996
|
C_ONE = np.complex64(1.0)
|
|
3930
3997
|
else:
|
|
@@ -3974,14 +4041,15 @@ class funct(FOC.FoCUS):
|
|
|
3974
4041
|
|
|
3975
4042
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
3976
4043
|
|
|
3977
|
-
if Jmax is None:
|
|
3978
|
-
|
|
3979
|
-
|
|
3980
|
-
|
|
3981
|
-
|
|
3982
|
-
|
|
3983
|
-
|
|
3984
|
-
|
|
4044
|
+
if Jmax is not None:
|
|
4045
|
+
|
|
4046
|
+
if Jmax > J:
|
|
4047
|
+
print("==========\n\n")
|
|
4048
|
+
print(
|
|
4049
|
+
"The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
|
|
4050
|
+
)
|
|
4051
|
+
print("\n\n==========")
|
|
4052
|
+
J = Jmax # Number of steps for the loop on scales
|
|
3985
4053
|
|
|
3986
4054
|
L = self.NORIENT
|
|
3987
4055
|
norm_factor_S3 = 1.0
|
|
@@ -4067,7 +4135,10 @@ class funct(FOC.FoCUS):
|
|
|
4067
4135
|
#
|
|
4068
4136
|
if edge:
|
|
4069
4137
|
if (M, N, J) not in self.edge_masks:
|
|
4070
|
-
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4138
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4139
|
+
M, N, J, in_mask=in_mask
|
|
4140
|
+
)
|
|
4141
|
+
|
|
4071
4142
|
edge_mask = self.edge_masks[(M, N, J)]
|
|
4072
4143
|
else:
|
|
4073
4144
|
edge_mask = 1
|
|
@@ -4206,8 +4277,19 @@ class funct(FOC.FoCUS):
|
|
|
4206
4277
|
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
4207
4278
|
_, M3, N3 = wavelet_f3.shape
|
|
4208
4279
|
wavelet_f3_squared = wavelet_f3**2
|
|
4209
|
-
|
|
4210
|
-
|
|
4280
|
+
if edge is True:
|
|
4281
|
+
if (M3, N3, J, j3) not in self.edge_masks:
|
|
4282
|
+
|
|
4283
|
+
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
|
|
4284
|
+
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
|
|
4285
|
+
|
|
4286
|
+
self.edge_masks[(M3, N3, J, j3)] = self.get_edge_masks(
|
|
4287
|
+
M3, N3, J, in_mask=in_mask, edge_dx=edge_dx, edge_dy=edge_dy
|
|
4288
|
+
)
|
|
4289
|
+
|
|
4290
|
+
edge_mask = self.edge_masks[(M3, N3, J, j3)]
|
|
4291
|
+
else:
|
|
4292
|
+
edge_mask = 1
|
|
4211
4293
|
|
|
4212
4294
|
# a normalization change due to the cutoff of frequency space
|
|
4213
4295
|
fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
|
|
@@ -4274,7 +4356,8 @@ class funct(FOC.FoCUS):
|
|
|
4274
4356
|
(
|
|
4275
4357
|
data_small.view(N_image, 1, 1, M3, N3)
|
|
4276
4358
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4277
|
-
|
|
4359
|
+
* edge_mask[None, None, None, :, :]
|
|
4360
|
+
).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
|
|
4278
4361
|
(-2, -1)
|
|
4279
4362
|
)
|
|
4280
4363
|
* fft_factor
|
|
@@ -4285,11 +4368,8 @@ class funct(FOC.FoCUS):
|
|
|
4285
4368
|
(
|
|
4286
4369
|
data_small.view(N_image, 1, 1, M3, N3)
|
|
4287
4370
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4288
|
-
|
|
4289
|
-
|
|
4290
|
-
].std(
|
|
4291
|
-
(-2, -1)
|
|
4292
|
-
)
|
|
4371
|
+
* edge_mask[None, None, None, :, :]
|
|
4372
|
+
).std((-2, -1))
|
|
4293
4373
|
* fft_factor
|
|
4294
4374
|
/ norm_factor_S3
|
|
4295
4375
|
)
|
|
@@ -4318,11 +4398,8 @@ class funct(FOC.FoCUS):
|
|
|
4318
4398
|
(
|
|
4319
4399
|
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4320
4400
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4321
|
-
|
|
4322
|
-
|
|
4323
|
-
].mean(
|
|
4324
|
-
(-2, -1)
|
|
4325
|
-
)
|
|
4401
|
+
* edge_mask[None, None, None, :, :]
|
|
4402
|
+
).mean((-2, -1))
|
|
4326
4403
|
* fft_factor
|
|
4327
4404
|
/ norm_factor_S3
|
|
4328
4405
|
)
|
|
@@ -4331,13 +4408,8 @@ class funct(FOC.FoCUS):
|
|
|
4331
4408
|
(
|
|
4332
4409
|
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4333
4410
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4334
|
-
|
|
4335
|
-
|
|
4336
|
-
edge_dx : M3 - edge_dx,
|
|
4337
|
-
edge_dy : N3 - edge_dy,
|
|
4338
|
-
].std(
|
|
4339
|
-
(-2, -1)
|
|
4340
|
-
)
|
|
4411
|
+
* edge_mask[None, None, None, :, :]
|
|
4412
|
+
).std((-2, -1))
|
|
4341
4413
|
* fft_factor
|
|
4342
4414
|
/ norm_factor_S3
|
|
4343
4415
|
)
|
|
@@ -4406,9 +4478,8 @@ class funct(FOC.FoCUS):
|
|
|
4406
4478
|
N_image, 1, L, L, M3, N3
|
|
4407
4479
|
)
|
|
4408
4480
|
)
|
|
4409
|
-
|
|
4410
|
-
|
|
4411
|
-
) * fft_factor
|
|
4481
|
+
* edge_mask[None, None, None, None, :, :]
|
|
4482
|
+
).mean((-2, -1)) * fft_factor
|
|
4412
4483
|
if get_variance:
|
|
4413
4484
|
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4414
4485
|
I1_small[:, j1].view(
|
|
@@ -4419,11 +4490,10 @@ class funct(FOC.FoCUS):
|
|
|
4419
4490
|
N_image, 1, L, L, M3, N3
|
|
4420
4491
|
)
|
|
4421
4492
|
)
|
|
4422
|
-
|
|
4423
|
-
|
|
4424
|
-
|
|
4425
|
-
|
|
4426
|
-
) * fft_factor
|
|
4493
|
+
* edge_mask[
|
|
4494
|
+
None, None, None, None, :, :
|
|
4495
|
+
]
|
|
4496
|
+
).std((-2, -1)) * fft_factor
|
|
4427
4497
|
else:
|
|
4428
4498
|
for l1 in range(L):
|
|
4429
4499
|
# [N_image,l2,l3,x,y]
|
|
@@ -4436,11 +4506,10 @@ class funct(FOC.FoCUS):
|
|
|
4436
4506
|
N_image, L, L, M3, N3
|
|
4437
4507
|
)
|
|
4438
4508
|
)
|
|
4439
|
-
|
|
4440
|
-
|
|
4441
|
-
|
|
4442
|
-
|
|
4443
|
-
) * fft_factor
|
|
4509
|
+
* edge_mask[
|
|
4510
|
+
None, None, None, None, :, :
|
|
4511
|
+
]
|
|
4512
|
+
).mean((-2, -1)) * fft_factor
|
|
4444
4513
|
if get_variance:
|
|
4445
4514
|
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4446
4515
|
I1_small[:, j1].view(
|
|
@@ -4451,13 +4520,10 @@ class funct(FOC.FoCUS):
|
|
|
4451
4520
|
N_image, L, L, M3, N3
|
|
4452
4521
|
)
|
|
4453
4522
|
)
|
|
4454
|
-
|
|
4455
|
-
|
|
4456
|
-
|
|
4457
|
-
|
|
4458
|
-
].mean(
|
|
4459
|
-
(-2, -1)
|
|
4460
|
-
) * fft_factor
|
|
4523
|
+
* edge_mask[
|
|
4524
|
+
None, None, None, None, :, :
|
|
4525
|
+
]
|
|
4526
|
+
).std((-2, -1)) * fft_factor
|
|
4461
4527
|
|
|
4462
4528
|
Ndata_S4 += 1
|
|
4463
4529
|
|
|
@@ -4788,7 +4854,9 @@ class funct(FOC.FoCUS):
|
|
|
4788
4854
|
#
|
|
4789
4855
|
if edge:
|
|
4790
4856
|
if (M, N, J) not in self.edge_masks:
|
|
4791
|
-
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4857
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4858
|
+
M, N, J, in_mask=in_mask
|
|
4859
|
+
)
|
|
4792
4860
|
edge_mask = self.edge_masks[(M, N, J)]
|
|
4793
4861
|
else:
|
|
4794
4862
|
edge_mask = 1
|
|
@@ -5662,19 +5730,42 @@ class funct(FOC.FoCUS):
|
|
|
5662
5730
|
|
|
5663
5731
|
return for_synthesis
|
|
5664
5732
|
|
|
5665
|
-
def
|
|
5733
|
+
def purge_edge_mask(self):
|
|
5734
|
+
|
|
5735
|
+
list_edge = []
|
|
5736
|
+
for k in self.edge_masks:
|
|
5737
|
+
list_edge.append(k)
|
|
5738
|
+
for k in list_edge:
|
|
5739
|
+
del self.edge_masks[k]
|
|
5740
|
+
|
|
5741
|
+
self.edge_masks = {}
|
|
5742
|
+
|
|
5743
|
+
def to_gaussian(self, x, in_mask=None):
|
|
5666
5744
|
from scipy.interpolate import interp1d
|
|
5667
5745
|
from scipy.stats import norm
|
|
5668
5746
|
|
|
5669
|
-
|
|
5670
|
-
|
|
5671
|
-
|
|
5672
|
-
|
|
5747
|
+
if in_mask is not None:
|
|
5748
|
+
m_idx = np.where(in_mask.flatten() > 0)[0]
|
|
5749
|
+
idx = np.argsort(x.flatten()[m_idx])
|
|
5750
|
+
p = norm.ppf((np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0])
|
|
5751
|
+
im_target = x.flatten()
|
|
5752
|
+
im_target[m_idx[idx]] = p
|
|
5673
5753
|
|
|
5674
|
-
|
|
5675
|
-
|
|
5676
|
-
|
|
5677
|
-
|
|
5754
|
+
self.f_gaussian = interp1d(
|
|
5755
|
+
im_target[m_idx[idx]], x.flatten()[m_idx[idx]], kind="cubic"
|
|
5756
|
+
)
|
|
5757
|
+
self.val_min = im_target[m_idx][idx[0]]
|
|
5758
|
+
self.val_max = im_target[m_idx][idx[-1]]
|
|
5759
|
+
else:
|
|
5760
|
+
idx = np.argsort(x.flatten())
|
|
5761
|
+
p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
|
|
5762
|
+
im_target = x.flatten()
|
|
5763
|
+
im_target[idx] = norm.ppf(p)
|
|
5764
|
+
|
|
5765
|
+
# Interpolation cubique
|
|
5766
|
+
self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
|
|
5767
|
+
self.val_min = im_target[idx[0]]
|
|
5768
|
+
self.val_max = im_target[idx[-1]]
|
|
5678
5769
|
return im_target.reshape(x.shape)
|
|
5679
5770
|
|
|
5680
5771
|
def from_gaussian(self, x):
|
|
@@ -5948,57 +6039,6 @@ class funct(FOC.FoCUS):
|
|
|
5948
6039
|
|
|
5949
6040
|
return result
|
|
5950
6041
|
|
|
5951
|
-
# # ---------------------------------------------−---------
|
|
5952
|
-
# def std(self, list_of_sc):
|
|
5953
|
-
# n = len(list_of_sc)
|
|
5954
|
-
# res = list_of_sc[0]
|
|
5955
|
-
# res2 = list_of_sc[0] * list_of_sc[0]
|
|
5956
|
-
# for k in range(1, n):
|
|
5957
|
-
# res = res + list_of_sc[k]
|
|
5958
|
-
# res2 = res2 + list_of_sc[k] * list_of_sc[k]
|
|
5959
|
-
#
|
|
5960
|
-
# if res.S1 is None:
|
|
5961
|
-
# if res.S3P is not None:
|
|
5962
|
-
# return scat_cov(
|
|
5963
|
-
# res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
|
|
5964
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5965
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5966
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5967
|
-
# S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
|
|
5968
|
-
# backend=self.backend,
|
|
5969
|
-
# use_1D=self.use_1D,
|
|
5970
|
-
# )
|
|
5971
|
-
# else:
|
|
5972
|
-
# return scat_cov(
|
|
5973
|
-
# res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
|
|
5974
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5975
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5976
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5977
|
-
# backend=self.backend,
|
|
5978
|
-
# use_1D=self.use_1D,
|
|
5979
|
-
# )
|
|
5980
|
-
# else:
|
|
5981
|
-
# if res.S3P is None:
|
|
5982
|
-
# return scat_cov(
|
|
5983
|
-
# res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
|
|
5984
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5985
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5986
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5987
|
-
# S1=res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
|
|
5988
|
-
# S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
|
|
5989
|
-
# backend=self.backend,
|
|
5990
|
-
# )
|
|
5991
|
-
# else:
|
|
5992
|
-
# return scat_cov(
|
|
5993
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5994
|
-
# res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
|
|
5995
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5996
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5997
|
-
# backend=self.backend,
|
|
5998
|
-
# use_1D=self.use_1D,
|
|
5999
|
-
# )
|
|
6000
|
-
# return self.NORIENT
|
|
6001
|
-
|
|
6002
6042
|
@tf_function
|
|
6003
6043
|
def eval_comp_fast(
|
|
6004
6044
|
self,
|
|
@@ -6006,13 +6046,12 @@ class funct(FOC.FoCUS):
|
|
|
6006
6046
|
image2=None,
|
|
6007
6047
|
mask=None,
|
|
6008
6048
|
norm=None,
|
|
6009
|
-
Auto=True,
|
|
6010
6049
|
cmat=None,
|
|
6011
6050
|
cmat2=None,
|
|
6012
6051
|
):
|
|
6013
6052
|
|
|
6014
6053
|
res = self.eval(
|
|
6015
|
-
image1, image2=image2, mask=mask,
|
|
6054
|
+
image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
|
|
6016
6055
|
)
|
|
6017
6056
|
return res.S0, res.S2, res.S1, res.S3, res.S4, res.S3P
|
|
6018
6057
|
|
|
@@ -6022,12 +6061,11 @@ class funct(FOC.FoCUS):
|
|
|
6022
6061
|
image2=None,
|
|
6023
6062
|
mask=None,
|
|
6024
6063
|
norm=None,
|
|
6025
|
-
Auto=True,
|
|
6026
6064
|
cmat=None,
|
|
6027
6065
|
cmat2=None,
|
|
6028
6066
|
):
|
|
6029
6067
|
s0, s2, s1, s3, s4, s3p = self.eval_comp_fast(
|
|
6030
|
-
image1, image2=image2, mask=mask,
|
|
6068
|
+
image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
|
|
6031
6069
|
)
|
|
6032
6070
|
return scat_cov(
|
|
6033
6071
|
s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
|
|
@@ -6036,6 +6074,7 @@ class funct(FOC.FoCUS):
|
|
|
6036
6074
|
def synthesis(
|
|
6037
6075
|
self,
|
|
6038
6076
|
image_target,
|
|
6077
|
+
reference=None,
|
|
6039
6078
|
nstep=4,
|
|
6040
6079
|
seed=1234,
|
|
6041
6080
|
Jmax=None,
|
|
@@ -6045,6 +6084,7 @@ class funct(FOC.FoCUS):
|
|
|
6045
6084
|
synthesised_N=1,
|
|
6046
6085
|
input_image=None,
|
|
6047
6086
|
grd_mask=None,
|
|
6087
|
+
in_mask=None,
|
|
6048
6088
|
iso_ang=False,
|
|
6049
6089
|
EVAL_FREQUENCY=100,
|
|
6050
6090
|
NUM_EPOCHS=300,
|
|
@@ -6054,18 +6094,68 @@ class funct(FOC.FoCUS):
|
|
|
6054
6094
|
|
|
6055
6095
|
import foscat.Synthesis as synthe
|
|
6056
6096
|
|
|
6097
|
+
l_edge = edge
|
|
6098
|
+
if in_mask is not None:
|
|
6099
|
+
l_edge = True
|
|
6100
|
+
|
|
6101
|
+
if edge:
|
|
6102
|
+
self.purge_edge_mask()
|
|
6103
|
+
|
|
6104
|
+
def The_loss_ref_image(u, scat_operator, args):
|
|
6105
|
+
input_image = args[0]
|
|
6106
|
+
mask = args[1]
|
|
6107
|
+
|
|
6108
|
+
loss = 1E-3*scat_operator.backend.bk_reduce_mean(
|
|
6109
|
+
scat_operator.backend.bk_square(mask*(input_image - u))
|
|
6110
|
+
)
|
|
6111
|
+
return loss
|
|
6112
|
+
|
|
6057
6113
|
def The_loss(u, scat_operator, args):
|
|
6058
6114
|
ref = args[0]
|
|
6059
6115
|
sref = args[1]
|
|
6060
6116
|
use_v = args[2]
|
|
6117
|
+
ljmax = args[3]
|
|
6118
|
+
|
|
6119
|
+
# compute scattering covariance of the current synthetised map called u
|
|
6120
|
+
if use_v:
|
|
6121
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6122
|
+
scat_operator.scattering_cov(
|
|
6123
|
+
u,
|
|
6124
|
+
edge=l_edge,
|
|
6125
|
+
Jmax=ljmax,
|
|
6126
|
+
ref_sigma=sref,
|
|
6127
|
+
use_ref=True,
|
|
6128
|
+
iso_ang=iso_ang,
|
|
6129
|
+
)
|
|
6130
|
+
)
|
|
6131
|
+
else:
|
|
6132
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6133
|
+
scat_operator.scattering_cov(
|
|
6134
|
+
u, edge=l_edge, Jmax=ljmax, use_ref=True, iso_ang=iso_ang
|
|
6135
|
+
)
|
|
6136
|
+
)
|
|
6137
|
+
|
|
6138
|
+
# make the difference withe the reference coordinates
|
|
6139
|
+
loss = scat_operator.backend.bk_reduce_mean(
|
|
6140
|
+
scat_operator.backend.bk_square(learn - ref)
|
|
6141
|
+
)
|
|
6142
|
+
return loss
|
|
6143
|
+
|
|
6144
|
+
def The_lossX(u, scat_operator, args):
|
|
6145
|
+
ref = args[0]
|
|
6146
|
+
sref = args[1]
|
|
6147
|
+
use_v = args[2]
|
|
6148
|
+
im2 = args[3]
|
|
6149
|
+
ljmax = args[4]
|
|
6061
6150
|
|
|
6062
6151
|
# compute scattering covariance of the current synthetised map called u
|
|
6063
6152
|
if use_v:
|
|
6064
6153
|
learn = scat_operator.reduce_mean_batch(
|
|
6065
6154
|
scat_operator.scattering_cov(
|
|
6066
6155
|
u,
|
|
6067
|
-
|
|
6068
|
-
|
|
6156
|
+
data2=im2,
|
|
6157
|
+
edge=l_edge,
|
|
6158
|
+
Jmax=ljmax,
|
|
6069
6159
|
ref_sigma=sref,
|
|
6070
6160
|
use_ref=True,
|
|
6071
6161
|
iso_ang=iso_ang,
|
|
@@ -6074,7 +6164,12 @@ class funct(FOC.FoCUS):
|
|
|
6074
6164
|
else:
|
|
6075
6165
|
learn = scat_operator.reduce_mean_batch(
|
|
6076
6166
|
scat_operator.scattering_cov(
|
|
6077
|
-
u,
|
|
6167
|
+
u,
|
|
6168
|
+
data2=im2,
|
|
6169
|
+
edge=l_edge,
|
|
6170
|
+
Jmax=ljmax,
|
|
6171
|
+
use_ref=True,
|
|
6172
|
+
iso_ang=iso_ang,
|
|
6078
6173
|
)
|
|
6079
6174
|
)
|
|
6080
6175
|
|
|
@@ -6086,7 +6181,7 @@ class funct(FOC.FoCUS):
|
|
|
6086
6181
|
|
|
6087
6182
|
if to_gaussian:
|
|
6088
6183
|
# Change the data histogram to gaussian distribution
|
|
6089
|
-
im_target = self.to_gaussian(image_target)
|
|
6184
|
+
im_target = self.to_gaussian(image_target, in_mask=in_mask)
|
|
6090
6185
|
else:
|
|
6091
6186
|
im_target = image_target
|
|
6092
6187
|
|
|
@@ -6113,22 +6208,59 @@ class funct(FOC.FoCUS):
|
|
|
6113
6208
|
|
|
6114
6209
|
t1 = time.time()
|
|
6115
6210
|
tmp = {}
|
|
6116
|
-
|
|
6117
|
-
l_grd_mask={}
|
|
6118
|
-
|
|
6211
|
+
|
|
6212
|
+
l_grd_mask = {}
|
|
6213
|
+
l_in_mask = {}
|
|
6214
|
+
l_input_image = {}
|
|
6215
|
+
l_ref = {}
|
|
6216
|
+
l_jmax = {}
|
|
6217
|
+
|
|
6119
6218
|
tmp[nstep - 1] = self.backend.bk_cast(im_target)
|
|
6219
|
+
l_jmax[nstep - 1] = Jmax
|
|
6220
|
+
|
|
6221
|
+
if reference is not None:
|
|
6222
|
+
l_ref[nstep - 1] = self.backend.bk_cast(reference)
|
|
6223
|
+
else:
|
|
6224
|
+
l_ref[nstep - 1] = None
|
|
6225
|
+
|
|
6120
6226
|
if grd_mask is not None:
|
|
6121
6227
|
l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
|
|
6122
6228
|
else:
|
|
6123
6229
|
l_grd_mask[nstep - 1] = None
|
|
6124
|
-
|
|
6230
|
+
if in_mask is not None:
|
|
6231
|
+
l_in_mask[nstep - 1] = in_mask
|
|
6232
|
+
else:
|
|
6233
|
+
l_in_mask[nstep - 1] = None
|
|
6234
|
+
|
|
6235
|
+
if input_image is not None:
|
|
6236
|
+
l_input_image[nstep - 1] = input_image
|
|
6237
|
+
|
|
6125
6238
|
for ell in range(nstep - 2, -1, -1):
|
|
6126
|
-
tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
|
|
6239
|
+
tmp[ell], _ = self.ud_grade_2(tmp[ell + 1], axis=1)
|
|
6240
|
+
|
|
6127
6241
|
if grd_mask is not None:
|
|
6128
|
-
l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
|
|
6242
|
+
l_grd_mask[ell], _ = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
|
|
6129
6243
|
else:
|
|
6130
6244
|
l_grd_mask[ell] = None
|
|
6131
|
-
|
|
6245
|
+
|
|
6246
|
+
if in_mask is not None:
|
|
6247
|
+
l_in_mask[ell], _ = self.ud_grade_2(l_in_mask[ell + 1])
|
|
6248
|
+
l_in_mask[ell] = self.backend.to_numpy(l_in_mask[ell])
|
|
6249
|
+
else:
|
|
6250
|
+
l_in_mask[ell] = None
|
|
6251
|
+
|
|
6252
|
+
if input_image is not None:
|
|
6253
|
+
l_input_image[ell], _ = self.ud_grade_2(l_input_image[ell + 1], axis=1)
|
|
6254
|
+
|
|
6255
|
+
if reference is not None:
|
|
6256
|
+
l_ref[ell], _ = self.ud_grade_2(l_ref[ell + 1], axis=1)
|
|
6257
|
+
else:
|
|
6258
|
+
l_ref[ell] = None
|
|
6259
|
+
|
|
6260
|
+
if l_jmax[ell + 1] is None:
|
|
6261
|
+
l_jmax[ell] = None
|
|
6262
|
+
else:
|
|
6263
|
+
l_jmax[ell] = l_jmax[ell + 1] - 1
|
|
6132
6264
|
|
|
6133
6265
|
if not self.use_2D and not self.use_1D:
|
|
6134
6266
|
l_nside = nside // (2 ** (nstep - 1))
|
|
@@ -6138,16 +6270,20 @@ class funct(FOC.FoCUS):
|
|
|
6138
6270
|
if input_image is None:
|
|
6139
6271
|
np.random.seed(seed)
|
|
6140
6272
|
if self.use_2D:
|
|
6141
|
-
imap = self.backend.bk_cast(
|
|
6142
|
-
|
|
6143
|
-
|
|
6273
|
+
imap = self.backend.bk_cast(
|
|
6274
|
+
np.random.randn(
|
|
6275
|
+
synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
|
|
6276
|
+
)
|
|
6277
|
+
)
|
|
6144
6278
|
else:
|
|
6145
|
-
imap = self.backend.bk_cast(
|
|
6279
|
+
imap = self.backend.bk_cast(
|
|
6280
|
+
np.random.randn(synthesised_N, tmp[k].shape[1])
|
|
6281
|
+
)
|
|
6146
6282
|
else:
|
|
6147
6283
|
if self.use_2D:
|
|
6148
6284
|
imap = self.backend.bk_reshape(
|
|
6149
6285
|
self.backend.bk_tile(
|
|
6150
|
-
self.backend.bk_cast(
|
|
6286
|
+
self.backend.bk_cast(l_input_image[k].flatten()),
|
|
6151
6287
|
synthesised_N,
|
|
6152
6288
|
),
|
|
6153
6289
|
[synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
|
|
@@ -6155,7 +6291,7 @@ class funct(FOC.FoCUS):
|
|
|
6155
6291
|
else:
|
|
6156
6292
|
imap = self.backend.bk_reshape(
|
|
6157
6293
|
self.backend.bk_tile(
|
|
6158
|
-
self.backend.bk_cast(
|
|
6294
|
+
self.backend.bk_cast(l_input_image[k].flatten()),
|
|
6159
6295
|
synthesised_N,
|
|
6160
6296
|
),
|
|
6161
6297
|
[synthesised_N, tmp[k].shape[1]],
|
|
@@ -6170,26 +6306,56 @@ class funct(FOC.FoCUS):
|
|
|
6170
6306
|
imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
|
|
6171
6307
|
else:
|
|
6172
6308
|
imap = self.up_grade(omap, l_nside, axis=1)
|
|
6173
|
-
|
|
6309
|
+
|
|
6174
6310
|
if grd_mask is not None:
|
|
6175
|
-
imap=imap*l_grd_mask[k]+tmp[k]*(1-l_grd_mask[k])
|
|
6176
|
-
|
|
6311
|
+
imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
|
|
6312
|
+
|
|
6177
6313
|
# compute the coefficients for the target image
|
|
6178
6314
|
if use_variance:
|
|
6179
6315
|
ref, sref = self.scattering_cov(
|
|
6180
|
-
tmp[k],
|
|
6316
|
+
tmp[k],
|
|
6317
|
+
data2=l_ref[k],
|
|
6318
|
+
get_variance=True,
|
|
6319
|
+
edge=l_edge,
|
|
6320
|
+
Jmax=l_jmax[k],
|
|
6321
|
+
in_mask=l_in_mask[k],
|
|
6322
|
+
iso_ang=iso_ang,
|
|
6181
6323
|
)
|
|
6182
6324
|
else:
|
|
6183
|
-
ref = self.scattering_cov(
|
|
6325
|
+
ref = self.scattering_cov(
|
|
6326
|
+
tmp[k],
|
|
6327
|
+
data2=l_ref[k],
|
|
6328
|
+
in_mask=l_in_mask[k],
|
|
6329
|
+
edge=l_edge,
|
|
6330
|
+
Jmax=l_jmax[k],
|
|
6331
|
+
iso_ang=iso_ang,
|
|
6332
|
+
)
|
|
6184
6333
|
sref = ref
|
|
6185
6334
|
|
|
6186
6335
|
# compute the mean of the population does nothing if only one map is given
|
|
6187
6336
|
ref = self.reduce_mean_batch(ref)
|
|
6188
6337
|
|
|
6189
|
-
|
|
6190
|
-
|
|
6338
|
+
if l_in_mask[k] is not None:
|
|
6339
|
+
self.purge_edge_mask()
|
|
6191
6340
|
|
|
6192
|
-
|
|
6341
|
+
if l_ref[k] is None:
|
|
6342
|
+
# define a loss to minimize
|
|
6343
|
+
loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
|
|
6344
|
+
else:
|
|
6345
|
+
# define a loss to minimize
|
|
6346
|
+
loss = synthe.Loss(
|
|
6347
|
+
The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
|
|
6348
|
+
)
|
|
6349
|
+
|
|
6350
|
+
if input_image is not None:
|
|
6351
|
+
# define a loss to minimize
|
|
6352
|
+
loss_input = synthe.Loss(The_loss_ref_image, self,
|
|
6353
|
+
self.backend.bk_cast(l_input_image[k]),
|
|
6354
|
+
self.backend.bk_cast(l_in_mask[k]))
|
|
6355
|
+
|
|
6356
|
+
sy = synthe.Synthesis([loss]) #,loss_input])
|
|
6357
|
+
else:
|
|
6358
|
+
sy = synthe.Synthesis([loss])
|
|
6193
6359
|
|
|
6194
6360
|
# initialize the synthesised map
|
|
6195
6361
|
if self.use_2D:
|
|
@@ -6201,7 +6367,12 @@ class funct(FOC.FoCUS):
|
|
|
6201
6367
|
l_nside *= 2
|
|
6202
6368
|
|
|
6203
6369
|
# do the minimization
|
|
6204
|
-
omap = sy.run(
|
|
6370
|
+
omap = sy.run(
|
|
6371
|
+
imap,
|
|
6372
|
+
EVAL_FREQUENCY=EVAL_FREQUENCY,
|
|
6373
|
+
NUM_EPOCHS=NUM_EPOCHS,
|
|
6374
|
+
grd_mask=l_grd_mask[k],
|
|
6375
|
+
)
|
|
6205
6376
|
|
|
6206
6377
|
t2 = time.time()
|
|
6207
6378
|
print("Total computation %.2fs" % (t2 - t1))
|