foscat 3.9.0__py3-none-any.whl → 2025.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/BkTorch.py +68 -0
- foscat/FoCUS.py +157 -34
- foscat/Synthesis.py +1 -1
- foscat/scat_cov.py +393 -238
- {foscat-3.9.0.dist-info → foscat-2025.3.0.dist-info}/METADATA +3 -2
- {foscat-3.9.0.dist-info → foscat-2025.3.0.dist-info}/RECORD +9 -9
- {foscat-3.9.0.dist-info → foscat-2025.3.0.dist-info}/WHEEL +1 -1
- {foscat-3.9.0.dist-info → foscat-2025.3.0.dist-info/licenses}/LICENSE +0 -0
- {foscat-3.9.0.dist-info → foscat-2025.3.0.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -177,9 +177,21 @@ class scat_cov:
|
|
|
177
177
|
],
|
|
178
178
|
)
|
|
179
179
|
),
|
|
180
|
-
self.
|
|
180
|
+
self.backend.bk_reshape(
|
|
181
|
+
self.S3,
|
|
182
|
+
[
|
|
183
|
+
self.S3.shape[0],
|
|
184
|
+
self.S3.shape[1]
|
|
185
|
+
* self.S3.shape[2]
|
|
186
|
+
* self.S3.shape[3]
|
|
187
|
+
* self.S3.shape[4],
|
|
188
|
+
],
|
|
189
|
+
),
|
|
190
|
+
]
|
|
191
|
+
if self.S3P is not None:
|
|
192
|
+
tmp = tmp + [
|
|
181
193
|
self.backend.bk_reshape(
|
|
182
|
-
self.
|
|
194
|
+
self.S3P,
|
|
183
195
|
[
|
|
184
196
|
self.S3.shape[0],
|
|
185
197
|
self.S3.shape[1]
|
|
@@ -188,37 +200,19 @@ class scat_cov:
|
|
|
188
200
|
* self.S3.shape[4],
|
|
189
201
|
],
|
|
190
202
|
)
|
|
191
|
-
),
|
|
192
|
-
]
|
|
193
|
-
if self.S3P is not None:
|
|
194
|
-
tmp = tmp + [
|
|
195
|
-
self.conv2complex(
|
|
196
|
-
self.backend.bk_reshape(
|
|
197
|
-
self.S3P,
|
|
198
|
-
[
|
|
199
|
-
self.S3.shape[0],
|
|
200
|
-
self.S3.shape[1]
|
|
201
|
-
* self.S3.shape[2]
|
|
202
|
-
* self.S3.shape[3]
|
|
203
|
-
* self.S3.shape[4],
|
|
204
|
-
],
|
|
205
|
-
)
|
|
206
|
-
)
|
|
207
203
|
]
|
|
208
204
|
|
|
209
205
|
tmp = tmp + [
|
|
210
|
-
self.
|
|
211
|
-
self.
|
|
212
|
-
|
|
213
|
-
[
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
],
|
|
221
|
-
)
|
|
206
|
+
self.backend.bk_reshape(
|
|
207
|
+
self.S4,
|
|
208
|
+
[
|
|
209
|
+
self.S3.shape[0],
|
|
210
|
+
self.S4.shape[1]
|
|
211
|
+
* self.S4.shape[2]
|
|
212
|
+
* self.S4.shape[3]
|
|
213
|
+
* self.S4.shape[4]
|
|
214
|
+
* self.S4.shape[5],
|
|
215
|
+
],
|
|
222
216
|
)
|
|
223
217
|
]
|
|
224
218
|
|
|
@@ -504,7 +498,7 @@ class scat_cov:
|
|
|
504
498
|
if other.S1 is None:
|
|
505
499
|
s1 = None
|
|
506
500
|
else:
|
|
507
|
-
s1 = self.S1
|
|
501
|
+
s1 = self.doadd(self.S1, other.S1)
|
|
508
502
|
else:
|
|
509
503
|
s1 = self.S1 + other
|
|
510
504
|
|
|
@@ -534,7 +528,7 @@ class scat_cov:
|
|
|
534
528
|
return scat_cov(
|
|
535
529
|
self.doadd(self.S0, other.S0),
|
|
536
530
|
self.doadd(self.S2, other.S2),
|
|
537
|
-
(self.S3
|
|
531
|
+
self.doadd(self.S3, other.S3),
|
|
538
532
|
s4,
|
|
539
533
|
s1=s1,
|
|
540
534
|
s3p=s3p,
|
|
@@ -662,7 +656,7 @@ class scat_cov:
|
|
|
662
656
|
s1 = None
|
|
663
657
|
else:
|
|
664
658
|
if isinstance(other, scat_cov):
|
|
665
|
-
s1 = other.S1
|
|
659
|
+
s1 = self.dodiv(other.S1, self.S1)
|
|
666
660
|
else:
|
|
667
661
|
s1 = other / self.S1
|
|
668
662
|
|
|
@@ -689,7 +683,7 @@ class scat_cov:
|
|
|
689
683
|
return scat_cov(
|
|
690
684
|
self.dodiv(other.S0, self.S0),
|
|
691
685
|
self.dodiv(other.S2, self.S2),
|
|
692
|
-
(other.S3
|
|
686
|
+
self.dodiv(other.S3, self.S3),
|
|
693
687
|
s4,
|
|
694
688
|
s1=s1,
|
|
695
689
|
s3p=s3p,
|
|
@@ -725,7 +719,7 @@ class scat_cov:
|
|
|
725
719
|
if other.S1 is None:
|
|
726
720
|
s1 = None
|
|
727
721
|
else:
|
|
728
|
-
s1 = other.S1
|
|
722
|
+
s1 = self.domin(other.S1, self.S1)
|
|
729
723
|
else:
|
|
730
724
|
s1 = other - self.S1
|
|
731
725
|
|
|
@@ -755,7 +749,7 @@ class scat_cov:
|
|
|
755
749
|
return scat_cov(
|
|
756
750
|
self.domin(other.S0, self.S0),
|
|
757
751
|
self.domin(other.S2, self.S2),
|
|
758
|
-
(other.S3
|
|
752
|
+
self.domin(other.S3, self.S3),
|
|
759
753
|
s4,
|
|
760
754
|
s1=s1,
|
|
761
755
|
s3p=s3p,
|
|
@@ -790,7 +784,7 @@ class scat_cov:
|
|
|
790
784
|
if other.S1 is None:
|
|
791
785
|
s1 = None
|
|
792
786
|
else:
|
|
793
|
-
s1 = self.S1
|
|
787
|
+
s1 = self.domin(self.S1, other.S1)
|
|
794
788
|
else:
|
|
795
789
|
s1 = self.S1 - other
|
|
796
790
|
|
|
@@ -820,7 +814,7 @@ class scat_cov:
|
|
|
820
814
|
return scat_cov(
|
|
821
815
|
self.domin(self.S0, other.S0),
|
|
822
816
|
self.domin(self.S2, other.S2),
|
|
823
|
-
(self.S3
|
|
817
|
+
self.domin(self.S3, other.S3),
|
|
824
818
|
s4,
|
|
825
819
|
s1=s1,
|
|
826
820
|
s3p=s3p,
|
|
@@ -920,7 +914,7 @@ class scat_cov:
|
|
|
920
914
|
if other.S1 is None:
|
|
921
915
|
s1 = None
|
|
922
916
|
else:
|
|
923
|
-
s1 = self.S1
|
|
917
|
+
s1 = self.domult(self.S1, other.S1)
|
|
924
918
|
else:
|
|
925
919
|
s1 = self.S1 * other
|
|
926
920
|
|
|
@@ -2215,7 +2209,7 @@ class scat_cov:
|
|
|
2215
2209
|
|
|
2216
2210
|
|
|
2217
2211
|
class funct(FOC.FoCUS):
|
|
2218
|
-
|
|
2212
|
+
|
|
2219
2213
|
def fill(self, im, nullval=hp.UNSEEN):
|
|
2220
2214
|
if self.use_2D:
|
|
2221
2215
|
return self.fill_2d(im, nullval=nullval)
|
|
@@ -2402,8 +2396,9 @@ class funct(FOC.FoCUS):
|
|
|
2402
2396
|
if smooth_scale > 0:
|
|
2403
2397
|
for m in range(smooth_scale):
|
|
2404
2398
|
if cc.shape[0] > 12:
|
|
2405
|
-
cc = self.ud_grade_2(self.smooth(cc))
|
|
2406
|
-
ss = self.ud_grade_2(self.smooth(ss))
|
|
2399
|
+
cc, _ = self.ud_grade_2(self.smooth(cc))
|
|
2400
|
+
ss, _ = self.ud_grade_2(self.smooth(ss))
|
|
2401
|
+
|
|
2407
2402
|
if cc.shape[0] != tmp.shape[0]:
|
|
2408
2403
|
ll_nside = int(np.sqrt(tmp.shape[1] // 12))
|
|
2409
2404
|
cc = self.up_grade(cc, ll_nside)
|
|
@@ -2473,8 +2468,8 @@ class funct(FOC.FoCUS):
|
|
|
2473
2468
|
if smooth_scale > 0:
|
|
2474
2469
|
for m in range(smooth_scale):
|
|
2475
2470
|
if cc2.shape[0] > 12:
|
|
2476
|
-
cc2 = self.ud_grade_2(self.smooth(cc2))
|
|
2477
|
-
ss2 = self.ud_grade_2(self.smooth(ss2))
|
|
2471
|
+
cc2, _ = self.ud_grade_2(self.smooth(cc2))
|
|
2472
|
+
ss2, _ = self.ud_grade_2(self.smooth(ss2))
|
|
2478
2473
|
|
|
2479
2474
|
if cc2.shape[0] != sim.shape[1]:
|
|
2480
2475
|
ll_nside = int(np.sqrt(sim.shape[1] // 12))
|
|
@@ -2510,9 +2505,9 @@ class funct(FOC.FoCUS):
|
|
|
2510
2505
|
cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
|
|
2511
2506
|
|
|
2512
2507
|
if k < l_nside - 1:
|
|
2513
|
-
tmp = self.ud_grade_2(tmp, axis=1)
|
|
2508
|
+
tmp, _ = self.ud_grade_2(tmp, axis=1)
|
|
2514
2509
|
if image2 is not None:
|
|
2515
|
-
tmpi2 = self.ud_grade_2(tmpi2, axis=1)
|
|
2510
|
+
tmpi2, _ = self.ud_grade_2(tmpi2, axis=1)
|
|
2516
2511
|
return cmat, cmat2
|
|
2517
2512
|
|
|
2518
2513
|
def div_norm(self, complex_value, float_value):
|
|
@@ -2533,6 +2528,8 @@ class funct(FOC.FoCUS):
|
|
|
2533
2528
|
Jmax=None,
|
|
2534
2529
|
out_nside=None,
|
|
2535
2530
|
edge=True,
|
|
2531
|
+
nside=None,
|
|
2532
|
+
cell_ids=None,
|
|
2536
2533
|
):
|
|
2537
2534
|
"""
|
|
2538
2535
|
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
@@ -2606,7 +2603,7 @@ class funct(FOC.FoCUS):
|
|
|
2606
2603
|
cross = False
|
|
2607
2604
|
if image2 is not None:
|
|
2608
2605
|
cross = True
|
|
2609
|
-
|
|
2606
|
+
l_nside = 2**32 # not initialize if 1D or 2D
|
|
2610
2607
|
### PARAMETERS
|
|
2611
2608
|
axis = 1
|
|
2612
2609
|
# determine jmax and nside corresponding to the input map
|
|
@@ -2638,7 +2635,8 @@ class funct(FOC.FoCUS):
|
|
|
2638
2635
|
else:
|
|
2639
2636
|
npix = int(im_shape[0]) # Number of pixels
|
|
2640
2637
|
|
|
2641
|
-
nside
|
|
2638
|
+
if nside is None:
|
|
2639
|
+
nside = int(np.sqrt(npix // 12))
|
|
2642
2640
|
|
|
2643
2641
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2644
2642
|
|
|
@@ -2675,7 +2673,7 @@ class funct(FOC.FoCUS):
|
|
|
2675
2673
|
else:
|
|
2676
2674
|
vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
|
|
2677
2675
|
|
|
2678
|
-
if self.KERNELSZ > 3 and not self.use_2D:
|
|
2676
|
+
if self.KERNELSZ > 3 and not self.use_2D and cell_ids is None:
|
|
2679
2677
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2680
2678
|
if self.use_2D:
|
|
2681
2679
|
vmask = self.up_grade(
|
|
@@ -2693,12 +2691,15 @@ class funct(FOC.FoCUS):
|
|
|
2693
2691
|
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
2694
2692
|
if cross:
|
|
2695
2693
|
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
2694
|
+
nside = nside * 2
|
|
2696
2695
|
else:
|
|
2697
2696
|
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
2698
2697
|
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
2699
2698
|
if cross:
|
|
2700
2699
|
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
2701
2700
|
|
|
2701
|
+
nside = nside * 2
|
|
2702
|
+
|
|
2702
2703
|
if self.KERNELSZ > 5 and not self.use_2D:
|
|
2703
2704
|
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2704
2705
|
if self.use_2D:
|
|
@@ -2716,15 +2717,17 @@ class funct(FOC.FoCUS):
|
|
|
2716
2717
|
nouty=I2.shape[axis + 1] * 2,
|
|
2717
2718
|
)
|
|
2718
2719
|
elif self.use_1D:
|
|
2719
|
-
vmask = self.up_grade(vmask, I1.shape[axis] *
|
|
2720
|
-
I1 = self.up_grade(I1, I1.shape[axis] *
|
|
2720
|
+
vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
|
|
2721
|
+
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
2721
2722
|
if cross:
|
|
2722
|
-
I2 = self.up_grade(I2, I2.shape[axis] *
|
|
2723
|
+
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
2724
|
+
nside = nside * 2
|
|
2723
2725
|
else:
|
|
2724
|
-
I1 = self.up_grade(I1, nside *
|
|
2725
|
-
vmask = self.up_grade(vmask, nside *
|
|
2726
|
+
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
2727
|
+
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
2726
2728
|
if cross:
|
|
2727
|
-
I2 = self.up_grade(I2, nside *
|
|
2729
|
+
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
2730
|
+
nside = nside * 2
|
|
2728
2731
|
|
|
2729
2732
|
# Normalize the masks because they have different pixel numbers
|
|
2730
2733
|
# vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
|
|
@@ -2797,6 +2800,8 @@ class funct(FOC.FoCUS):
|
|
|
2797
2800
|
M1_dic = {}
|
|
2798
2801
|
M2_dic = {}
|
|
2799
2802
|
|
|
2803
|
+
cell_ids_j3 = cell_ids
|
|
2804
|
+
|
|
2800
2805
|
for j3 in range(Jmax):
|
|
2801
2806
|
|
|
2802
2807
|
if edge:
|
|
@@ -2835,7 +2840,9 @@ class funct(FOC.FoCUS):
|
|
|
2835
2840
|
|
|
2836
2841
|
####### S1 and S2
|
|
2837
2842
|
### Make the convolution I1 * Psi_j3
|
|
2838
|
-
conv1 = self.convol(
|
|
2843
|
+
conv1 = self.convol(
|
|
2844
|
+
I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
2845
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
2839
2846
|
if cmat is not None:
|
|
2840
2847
|
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
|
|
2841
2848
|
conv1 = self.backend.bk_reduce_sum(
|
|
@@ -2948,7 +2955,9 @@ class funct(FOC.FoCUS):
|
|
|
2948
2955
|
|
|
2949
2956
|
else: # Cross
|
|
2950
2957
|
### Make the convolution I2 * Psi_j3
|
|
2951
|
-
conv2 = self.convol(
|
|
2958
|
+
conv2 = self.convol(
|
|
2959
|
+
I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
2960
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
2952
2961
|
if cmat is not None:
|
|
2953
2962
|
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
|
|
2954
2963
|
conv2 = self.backend.bk_reduce_sum(
|
|
@@ -3092,7 +3101,6 @@ class funct(FOC.FoCUS):
|
|
|
3092
3101
|
M2convPsi_dic = {}
|
|
3093
3102
|
|
|
3094
3103
|
###### S3
|
|
3095
|
-
nside_j2 = nside_j3
|
|
3096
3104
|
for j2 in range(0, j3 + 1): # j2 <= j3
|
|
3097
3105
|
if return_data:
|
|
3098
3106
|
if S4[j3] is None:
|
|
@@ -3111,6 +3119,8 @@ class funct(FOC.FoCUS):
|
|
|
3111
3119
|
M1convPsi_dic,
|
|
3112
3120
|
calc_var=True,
|
|
3113
3121
|
cmat2=cmat2,
|
|
3122
|
+
cell_ids=cell_ids_j3,
|
|
3123
|
+
nside_j2=nside_j3,
|
|
3114
3124
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3115
3125
|
else:
|
|
3116
3126
|
s3 = self._compute_S3(
|
|
@@ -3122,19 +3132,21 @@ class funct(FOC.FoCUS):
|
|
|
3122
3132
|
M1convPsi_dic,
|
|
3123
3133
|
return_data=return_data,
|
|
3124
3134
|
cmat2=cmat2,
|
|
3135
|
+
cell_ids=cell_ids_j3,
|
|
3136
|
+
nside_j2=nside_j3,
|
|
3125
3137
|
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3126
3138
|
|
|
3127
3139
|
if return_data:
|
|
3128
3140
|
if S3[j3] is None:
|
|
3129
3141
|
S3[j3] = {}
|
|
3130
|
-
if out_nside is not None and out_nside <
|
|
3142
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3131
3143
|
s3 = self.backend.bk_reduce_mean(
|
|
3132
3144
|
self.backend.bk_reshape(
|
|
3133
3145
|
s3,
|
|
3134
3146
|
[
|
|
3135
3147
|
s3.shape[0],
|
|
3136
3148
|
12 * out_nside**2,
|
|
3137
|
-
(
|
|
3149
|
+
(nside_j3 // out_nside) ** 2,
|
|
3138
3150
|
s3.shape[2],
|
|
3139
3151
|
s3.shape[3],
|
|
3140
3152
|
],
|
|
@@ -3181,6 +3193,8 @@ class funct(FOC.FoCUS):
|
|
|
3181
3193
|
M2convPsi_dic,
|
|
3182
3194
|
calc_var=True,
|
|
3183
3195
|
cmat2=cmat2,
|
|
3196
|
+
cell_ids=cell_ids_j3,
|
|
3197
|
+
nside_j2=nside_j3,
|
|
3184
3198
|
)
|
|
3185
3199
|
s3p, vs3p = self._compute_S3(
|
|
3186
3200
|
j2,
|
|
@@ -3191,41 +3205,47 @@ class funct(FOC.FoCUS):
|
|
|
3191
3205
|
M1convPsi_dic,
|
|
3192
3206
|
calc_var=True,
|
|
3193
3207
|
cmat2=cmat2,
|
|
3208
|
+
cell_ids=cell_ids_j3,
|
|
3209
|
+
nside_j2=nside_j3,
|
|
3194
3210
|
)
|
|
3195
3211
|
else:
|
|
3196
|
-
|
|
3212
|
+
s3p = self._compute_S3(
|
|
3197
3213
|
j2,
|
|
3198
3214
|
j3,
|
|
3199
|
-
|
|
3215
|
+
conv2,
|
|
3200
3216
|
vmask,
|
|
3201
|
-
|
|
3202
|
-
|
|
3217
|
+
M1_dic,
|
|
3218
|
+
M1convPsi_dic,
|
|
3203
3219
|
return_data=return_data,
|
|
3204
3220
|
cmat2=cmat2,
|
|
3221
|
+
cell_ids=cell_ids_j3,
|
|
3222
|
+
nside_j2=nside_j3,
|
|
3205
3223
|
)
|
|
3206
|
-
|
|
3224
|
+
s3 = self._compute_S3(
|
|
3207
3225
|
j2,
|
|
3208
3226
|
j3,
|
|
3209
|
-
|
|
3227
|
+
conv1,
|
|
3210
3228
|
vmask,
|
|
3211
|
-
|
|
3212
|
-
|
|
3229
|
+
M2_dic,
|
|
3230
|
+
M2convPsi_dic,
|
|
3213
3231
|
return_data=return_data,
|
|
3214
3232
|
cmat2=cmat2,
|
|
3233
|
+
cell_ids=cell_ids_j3,
|
|
3234
|
+
nside_j2=nside_j3,
|
|
3215
3235
|
)
|
|
3216
3236
|
|
|
3217
3237
|
if return_data:
|
|
3218
3238
|
if S3[j3] is None:
|
|
3219
3239
|
S3[j3] = {}
|
|
3220
3240
|
S3P[j3] = {}
|
|
3221
|
-
if out_nside is not None and out_nside <
|
|
3241
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3222
3242
|
s3 = self.backend.bk_reduce_mean(
|
|
3223
3243
|
self.backend.bk_reshape(
|
|
3224
3244
|
s3,
|
|
3225
3245
|
[
|
|
3226
3246
|
s3.shape[0],
|
|
3227
3247
|
12 * out_nside**2,
|
|
3228
|
-
(
|
|
3248
|
+
(nside_j3 // out_nside) ** 2,
|
|
3229
3249
|
s3.shape[2],
|
|
3230
3250
|
s3.shape[3],
|
|
3231
3251
|
],
|
|
@@ -3238,7 +3258,7 @@ class funct(FOC.FoCUS):
|
|
|
3238
3258
|
[
|
|
3239
3259
|
s3.shape[0],
|
|
3240
3260
|
12 * out_nside**2,
|
|
3241
|
-
(
|
|
3261
|
+
(nside_j3 // out_nside) ** 2,
|
|
3242
3262
|
s3.shape[2],
|
|
3243
3263
|
s3.shape[3],
|
|
3244
3264
|
],
|
|
@@ -3295,7 +3315,6 @@ class funct(FOC.FoCUS):
|
|
|
3295
3315
|
# s3.shape[2]*s3.shape[3]]))
|
|
3296
3316
|
|
|
3297
3317
|
##### S4
|
|
3298
|
-
nside_j1 = nside_j2
|
|
3299
3318
|
for j1 in range(0, j2 + 1): # j1 <= j2
|
|
3300
3319
|
### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
|
|
3301
3320
|
if not cross:
|
|
@@ -3321,14 +3340,14 @@ class funct(FOC.FoCUS):
|
|
|
3321
3340
|
if return_data:
|
|
3322
3341
|
if S4[j3][j2] is None:
|
|
3323
3342
|
S4[j3][j2] = {}
|
|
3324
|
-
if out_nside is not None and out_nside <
|
|
3343
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3325
3344
|
s4 = self.backend.bk_reduce_mean(
|
|
3326
3345
|
self.backend.bk_reshape(
|
|
3327
3346
|
s4,
|
|
3328
3347
|
[
|
|
3329
3348
|
s4.shape[0],
|
|
3330
3349
|
12 * out_nside**2,
|
|
3331
|
-
(
|
|
3350
|
+
(nside_j3 // out_nside) ** 2,
|
|
3332
3351
|
s4.shape[2],
|
|
3333
3352
|
s4.shape[3],
|
|
3334
3353
|
s4.shape[4],
|
|
@@ -3396,14 +3415,14 @@ class funct(FOC.FoCUS):
|
|
|
3396
3415
|
if return_data:
|
|
3397
3416
|
if S4[j3][j2] is None:
|
|
3398
3417
|
S4[j3][j2] = {}
|
|
3399
|
-
if out_nside is not None and out_nside <
|
|
3418
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3400
3419
|
s4 = self.backend.bk_reduce_mean(
|
|
3401
3420
|
self.backend.bk_reshape(
|
|
3402
3421
|
s4,
|
|
3403
3422
|
[
|
|
3404
3423
|
s4.shape[0],
|
|
3405
3424
|
12 * out_nside**2,
|
|
3406
|
-
(
|
|
3425
|
+
(nside_j3 // out_nside) ** 2,
|
|
3407
3426
|
s4.shape[2],
|
|
3408
3427
|
s4.shape[3],
|
|
3409
3428
|
s4.shape[4],
|
|
@@ -3447,48 +3466,51 @@ class funct(FOC.FoCUS):
|
|
|
3447
3466
|
self.backend.bk_expand_dims(vs4, off_S4)
|
|
3448
3467
|
) # Add a dimension for NS4
|
|
3449
3468
|
|
|
3450
|
-
nside_j1 = nside_j1 // 2
|
|
3451
|
-
nside_j2 = nside_j2 // 2
|
|
3452
|
-
|
|
3453
3469
|
###### Reshape for next iteration on j3
|
|
3454
3470
|
### Image I1,
|
|
3455
3471
|
# downscale the I1 [Nbatch, Npix_j3]
|
|
3456
3472
|
if j3 != Jmax - 1:
|
|
3457
|
-
I1 = self.smooth(I1, axis=1)
|
|
3458
|
-
I1 = self.ud_grade_2(
|
|
3473
|
+
I1 = self.smooth(I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3474
|
+
I1, new_cell_ids_j3 = self.ud_grade_2(
|
|
3475
|
+
I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3476
|
+
)
|
|
3459
3477
|
|
|
3460
3478
|
### Image I2
|
|
3461
3479
|
if cross:
|
|
3462
|
-
I2 = self.smooth(I2, axis=1)
|
|
3463
|
-
I2 = self.ud_grade_2(
|
|
3480
|
+
I2 = self.smooth(I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
|
|
3481
|
+
I2, new_cell_ids_j3 = self.ud_grade_2(
|
|
3482
|
+
I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3483
|
+
)
|
|
3464
3484
|
|
|
3465
3485
|
### Modules
|
|
3466
3486
|
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
3467
3487
|
### Dictionary M1_dic[j2]
|
|
3468
3488
|
M1_smooth = self.smooth(
|
|
3469
|
-
M1_dic[j2], axis=1
|
|
3489
|
+
M1_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3470
3490
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3471
|
-
M1_dic[j2] = self.ud_grade_2(
|
|
3472
|
-
M1_smooth, axis=1
|
|
3491
|
+
M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3492
|
+
M1_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3473
3493
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3474
3494
|
|
|
3475
3495
|
### Dictionary M2_dic[j2]
|
|
3476
3496
|
if cross:
|
|
3477
3497
|
M2_smooth = self.smooth(
|
|
3478
|
-
M2_dic[j2], axis=1
|
|
3498
|
+
M2_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3479
3499
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3480
|
-
M2_dic[j2] = self.ud_grade_2(
|
|
3481
|
-
M2_smooth, axis=1
|
|
3500
|
+
M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
|
|
3501
|
+
M2_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3482
3502
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3483
|
-
|
|
3484
3503
|
### Mask
|
|
3485
|
-
vmask = self.ud_grade_2(
|
|
3504
|
+
vmask, new_cell_ids_j3 = self.ud_grade_2(
|
|
3505
|
+
vmask, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
|
|
3506
|
+
)
|
|
3486
3507
|
|
|
3487
3508
|
if self.mask_thres is not None:
|
|
3488
3509
|
vmask = self.backend.bk_threshold(vmask, self.mask_thres)
|
|
3489
3510
|
|
|
3490
3511
|
### NSIDE_j3
|
|
3491
3512
|
nside_j3 = nside_j3 // 2
|
|
3513
|
+
cell_ids_j3 = new_cell_ids_j3
|
|
3492
3514
|
|
|
3493
3515
|
### Store P1_dic and P2_dic in self
|
|
3494
3516
|
if (norm == "auto") and (self.P1_dic is None):
|
|
@@ -3588,6 +3610,8 @@ class funct(FOC.FoCUS):
|
|
|
3588
3610
|
calc_var=False,
|
|
3589
3611
|
return_data=False,
|
|
3590
3612
|
cmat2=None,
|
|
3613
|
+
cell_ids=None,
|
|
3614
|
+
nside_j2=None,
|
|
3591
3615
|
):
|
|
3592
3616
|
"""
|
|
3593
3617
|
Compute the S3 coefficients (auto or cross)
|
|
@@ -3601,7 +3625,7 @@ class funct(FOC.FoCUS):
|
|
|
3601
3625
|
### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
|
|
3602
3626
|
# Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
|
|
3603
3627
|
MconvPsi = self.convol(
|
|
3604
|
-
M_dic[j2], axis=1
|
|
3628
|
+
M_dic[j2], axis=1, cell_ids=cell_ids, nside=nside_j2
|
|
3605
3629
|
) # [Nbatch, Npix_j3, Norient3, Norient2]
|
|
3606
3630
|
if cmat2 is not None:
|
|
3607
3631
|
tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
|
|
@@ -3822,24 +3846,62 @@ class funct(FOC.FoCUS):
|
|
|
3822
3846
|
dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
|
|
3823
3847
|
return dx, dy
|
|
3824
3848
|
|
|
3825
|
-
def get_edge_masks(self, M, N, J, d0=1):
|
|
3849
|
+
def get_edge_masks(self, M, N, J, d0=1, in_mask=None, edge_dx=None, edge_dy=None):
|
|
3826
3850
|
"""
|
|
3827
3851
|
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
3828
3852
|
Done by Sihao Cheng and Rudy Morel.
|
|
3829
3853
|
"""
|
|
3830
3854
|
edge_masks = np.empty((J, M, N))
|
|
3855
|
+
|
|
3831
3856
|
X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
|
|
3832
|
-
|
|
3833
|
-
|
|
3834
|
-
|
|
3835
|
-
|
|
3836
|
-
|
|
3837
|
-
|
|
3838
|
-
|
|
3839
|
-
|
|
3857
|
+
if in_mask is not None:
|
|
3858
|
+
from scipy.ndimage import binary_erosion
|
|
3859
|
+
|
|
3860
|
+
if in_mask is not None:
|
|
3861
|
+
if in_mask.shape[0] != M or in_mask.shape[0] != N:
|
|
3862
|
+
l_mask = in_mask.reshape(
|
|
3863
|
+
M, in_mask.shape[0] // M, N, in_mask.shape[1] // N
|
|
3864
|
+
)
|
|
3865
|
+
l_mask = (
|
|
3866
|
+
np.sum(np.sum(l_mask, 1), 2)
|
|
3867
|
+
* (M * N)
|
|
3868
|
+
/ (in_mask.shape[0] * in_mask.shape[1])
|
|
3869
|
+
)
|
|
3870
|
+
else:
|
|
3871
|
+
l_mask = in_mask
|
|
3872
|
+
|
|
3873
|
+
if edge_dx is None:
|
|
3874
|
+
for j in range(J):
|
|
3875
|
+
edge_dx = min(M // 4, 2**j * d0)
|
|
3876
|
+
edge_dy = min(N // 4, 2**j * d0)
|
|
3877
|
+
|
|
3878
|
+
edge_masks[j] = (
|
|
3879
|
+
(X >= edge_dx)
|
|
3880
|
+
* (X < M - edge_dx)
|
|
3881
|
+
* (Y >= edge_dy)
|
|
3882
|
+
* (Y < N - edge_dy)
|
|
3883
|
+
)
|
|
3884
|
+
if in_mask is not None:
|
|
3885
|
+
l_mask = binary_erosion(
|
|
3886
|
+
l_mask, iterations=1 + np.max([edge_dx, edge_dy])
|
|
3887
|
+
)
|
|
3888
|
+
edge_masks[j] *= l_mask
|
|
3889
|
+
|
|
3890
|
+
edge_masks = edge_masks[:, None, :, :]
|
|
3891
|
+
|
|
3892
|
+
edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
|
|
3893
|
+
else:
|
|
3894
|
+
edge_masks = (
|
|
3895
|
+
(X >= edge_dx) * (X < M - edge_dx) * (Y >= edge_dy) * (Y < N - edge_dy)
|
|
3840
3896
|
)
|
|
3841
|
-
|
|
3842
|
-
|
|
3897
|
+
if in_mask is not None:
|
|
3898
|
+
l_mask = binary_erosion(
|
|
3899
|
+
l_mask, iterations=1 + np.max([edge_dx, edge_dy])
|
|
3900
|
+
)
|
|
3901
|
+
edge_masks *= l_mask
|
|
3902
|
+
|
|
3903
|
+
edge_masks = edge_masks / edge_masks.mean((-2, -1))
|
|
3904
|
+
|
|
3843
3905
|
return self.backend.bk_cast(edge_masks)
|
|
3844
3906
|
|
|
3845
3907
|
# ---------------------------------------------------------------------------
|
|
@@ -3857,6 +3919,7 @@ class funct(FOC.FoCUS):
|
|
|
3857
3919
|
use_ref=False,
|
|
3858
3920
|
normalization="S2",
|
|
3859
3921
|
edge=False,
|
|
3922
|
+
in_mask=None,
|
|
3860
3923
|
pseudo_coef=1,
|
|
3861
3924
|
get_variance=False,
|
|
3862
3925
|
ref_sigma=None,
|
|
@@ -3925,6 +3988,9 @@ class funct(FOC.FoCUS):
|
|
|
3925
3988
|
if S4_criteria is None:
|
|
3926
3989
|
S4_criteria = "j2>=j1"
|
|
3927
3990
|
|
|
3991
|
+
if not edge and in_mask is not None:
|
|
3992
|
+
edge = True
|
|
3993
|
+
|
|
3928
3994
|
if self.all_bk_type == "float32":
|
|
3929
3995
|
C_ONE = np.complex64(1.0)
|
|
3930
3996
|
else:
|
|
@@ -3974,14 +4040,15 @@ class funct(FOC.FoCUS):
|
|
|
3974
4040
|
|
|
3975
4041
|
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
3976
4042
|
|
|
3977
|
-
if Jmax is None:
|
|
3978
|
-
|
|
3979
|
-
|
|
3980
|
-
|
|
3981
|
-
|
|
3982
|
-
|
|
3983
|
-
|
|
3984
|
-
|
|
4043
|
+
if Jmax is not None:
|
|
4044
|
+
|
|
4045
|
+
if Jmax > J:
|
|
4046
|
+
print("==========\n\n")
|
|
4047
|
+
print(
|
|
4048
|
+
"The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
|
|
4049
|
+
)
|
|
4050
|
+
print("\n\n==========")
|
|
4051
|
+
J = Jmax # Number of steps for the loop on scales
|
|
3985
4052
|
|
|
3986
4053
|
L = self.NORIENT
|
|
3987
4054
|
norm_factor_S3 = 1.0
|
|
@@ -4067,7 +4134,10 @@ class funct(FOC.FoCUS):
|
|
|
4067
4134
|
#
|
|
4068
4135
|
if edge:
|
|
4069
4136
|
if (M, N, J) not in self.edge_masks:
|
|
4070
|
-
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4137
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4138
|
+
M, N, J, in_mask=in_mask
|
|
4139
|
+
)
|
|
4140
|
+
|
|
4071
4141
|
edge_mask = self.edge_masks[(M, N, J)]
|
|
4072
4142
|
else:
|
|
4073
4143
|
edge_mask = 1
|
|
@@ -4206,8 +4276,19 @@ class funct(FOC.FoCUS):
|
|
|
4206
4276
|
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
4207
4277
|
_, M3, N3 = wavelet_f3.shape
|
|
4208
4278
|
wavelet_f3_squared = wavelet_f3**2
|
|
4209
|
-
|
|
4210
|
-
|
|
4279
|
+
if edge is True:
|
|
4280
|
+
if (M3, N3, J, j3) not in self.edge_masks:
|
|
4281
|
+
|
|
4282
|
+
edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
|
|
4283
|
+
edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
|
|
4284
|
+
|
|
4285
|
+
self.edge_masks[(M3, N3, J, j3)] = self.get_edge_masks(
|
|
4286
|
+
M3, N3, J, in_mask=in_mask, edge_dx=edge_dx, edge_dy=edge_dy
|
|
4287
|
+
)
|
|
4288
|
+
|
|
4289
|
+
edge_mask = self.edge_masks[(M3, N3, J, j3)]
|
|
4290
|
+
else:
|
|
4291
|
+
edge_mask = 1
|
|
4211
4292
|
|
|
4212
4293
|
# a normalization change due to the cutoff of frequency space
|
|
4213
4294
|
fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
|
|
@@ -4274,7 +4355,8 @@ class funct(FOC.FoCUS):
|
|
|
4274
4355
|
(
|
|
4275
4356
|
data_small.view(N_image, 1, 1, M3, N3)
|
|
4276
4357
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4277
|
-
|
|
4358
|
+
* edge_mask[None, None, None, :, :]
|
|
4359
|
+
).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
|
|
4278
4360
|
(-2, -1)
|
|
4279
4361
|
)
|
|
4280
4362
|
* fft_factor
|
|
@@ -4285,11 +4367,8 @@ class funct(FOC.FoCUS):
|
|
|
4285
4367
|
(
|
|
4286
4368
|
data_small.view(N_image, 1, 1, M3, N3)
|
|
4287
4369
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4288
|
-
|
|
4289
|
-
|
|
4290
|
-
].std(
|
|
4291
|
-
(-2, -1)
|
|
4292
|
-
)
|
|
4370
|
+
* edge_mask[None, None, None, :, :]
|
|
4371
|
+
).std((-2, -1))
|
|
4293
4372
|
* fft_factor
|
|
4294
4373
|
/ norm_factor_S3
|
|
4295
4374
|
)
|
|
@@ -4318,11 +4397,8 @@ class funct(FOC.FoCUS):
|
|
|
4318
4397
|
(
|
|
4319
4398
|
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4320
4399
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4321
|
-
|
|
4322
|
-
|
|
4323
|
-
].mean(
|
|
4324
|
-
(-2, -1)
|
|
4325
|
-
)
|
|
4400
|
+
* edge_mask[None, None, None, :, :]
|
|
4401
|
+
).mean((-2, -1))
|
|
4326
4402
|
* fft_factor
|
|
4327
4403
|
/ norm_factor_S3
|
|
4328
4404
|
)
|
|
@@ -4331,13 +4407,8 @@ class funct(FOC.FoCUS):
|
|
|
4331
4407
|
(
|
|
4332
4408
|
data2_small.view(N_image2, 1, 1, M3, N3)
|
|
4333
4409
|
* self.backend.bk_conjugate(I12_w3_small)
|
|
4334
|
-
|
|
4335
|
-
|
|
4336
|
-
edge_dx : M3 - edge_dx,
|
|
4337
|
-
edge_dy : N3 - edge_dy,
|
|
4338
|
-
].std(
|
|
4339
|
-
(-2, -1)
|
|
4340
|
-
)
|
|
4410
|
+
* edge_mask[None, None, None, :, :]
|
|
4411
|
+
).std((-2, -1))
|
|
4341
4412
|
* fft_factor
|
|
4342
4413
|
/ norm_factor_S3
|
|
4343
4414
|
)
|
|
@@ -4406,9 +4477,8 @@ class funct(FOC.FoCUS):
|
|
|
4406
4477
|
N_image, 1, L, L, M3, N3
|
|
4407
4478
|
)
|
|
4408
4479
|
)
|
|
4409
|
-
|
|
4410
|
-
|
|
4411
|
-
) * fft_factor
|
|
4480
|
+
* edge_mask[None, None, None, None, :, :]
|
|
4481
|
+
).mean((-2, -1)) * fft_factor
|
|
4412
4482
|
if get_variance:
|
|
4413
4483
|
S4_sigma[:, Ndata_S4, :, :, :] = (
|
|
4414
4484
|
I1_small[:, j1].view(
|
|
@@ -4419,11 +4489,10 @@ class funct(FOC.FoCUS):
|
|
|
4419
4489
|
N_image, 1, L, L, M3, N3
|
|
4420
4490
|
)
|
|
4421
4491
|
)
|
|
4422
|
-
|
|
4423
|
-
|
|
4424
|
-
|
|
4425
|
-
|
|
4426
|
-
) * fft_factor
|
|
4492
|
+
* edge_mask[
|
|
4493
|
+
None, None, None, None, :, :
|
|
4494
|
+
]
|
|
4495
|
+
).std((-2, -1)) * fft_factor
|
|
4427
4496
|
else:
|
|
4428
4497
|
for l1 in range(L):
|
|
4429
4498
|
# [N_image,l2,l3,x,y]
|
|
@@ -4436,11 +4505,10 @@ class funct(FOC.FoCUS):
|
|
|
4436
4505
|
N_image, L, L, M3, N3
|
|
4437
4506
|
)
|
|
4438
4507
|
)
|
|
4439
|
-
|
|
4440
|
-
|
|
4441
|
-
|
|
4442
|
-
|
|
4443
|
-
) * fft_factor
|
|
4508
|
+
* edge_mask[
|
|
4509
|
+
None, None, None, None, :, :
|
|
4510
|
+
]
|
|
4511
|
+
).mean((-2, -1)) * fft_factor
|
|
4444
4512
|
if get_variance:
|
|
4445
4513
|
S4_sigma[:, Ndata_S4, l1, :, :] = (
|
|
4446
4514
|
I1_small[:, j1].view(
|
|
@@ -4451,13 +4519,10 @@ class funct(FOC.FoCUS):
|
|
|
4451
4519
|
N_image, L, L, M3, N3
|
|
4452
4520
|
)
|
|
4453
4521
|
)
|
|
4454
|
-
|
|
4455
|
-
|
|
4456
|
-
|
|
4457
|
-
|
|
4458
|
-
].mean(
|
|
4459
|
-
(-2, -1)
|
|
4460
|
-
) * fft_factor
|
|
4522
|
+
* edge_mask[
|
|
4523
|
+
None, None, None, None, :, :
|
|
4524
|
+
]
|
|
4525
|
+
).std((-2, -1)) * fft_factor
|
|
4461
4526
|
|
|
4462
4527
|
Ndata_S4 += 1
|
|
4463
4528
|
|
|
@@ -4788,7 +4853,9 @@ class funct(FOC.FoCUS):
|
|
|
4788
4853
|
#
|
|
4789
4854
|
if edge:
|
|
4790
4855
|
if (M, N, J) not in self.edge_masks:
|
|
4791
|
-
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4856
|
+
self.edge_masks[(M, N, J)] = self.get_edge_masks(
|
|
4857
|
+
M, N, J, in_mask=in_mask
|
|
4858
|
+
)
|
|
4792
4859
|
edge_mask = self.edge_masks[(M, N, J)]
|
|
4793
4860
|
else:
|
|
4794
4861
|
edge_mask = 1
|
|
@@ -5662,19 +5729,42 @@ class funct(FOC.FoCUS):
|
|
|
5662
5729
|
|
|
5663
5730
|
return for_synthesis
|
|
5664
5731
|
|
|
5665
|
-
def
|
|
5732
|
+
def purge_edge_mask(self):
|
|
5733
|
+
|
|
5734
|
+
list_edge = []
|
|
5735
|
+
for k in self.edge_masks:
|
|
5736
|
+
list_edge.append(k)
|
|
5737
|
+
for k in list_edge:
|
|
5738
|
+
del self.edge_masks[k]
|
|
5739
|
+
|
|
5740
|
+
self.edge_masks = {}
|
|
5741
|
+
|
|
5742
|
+
def to_gaussian(self, x, in_mask=None):
|
|
5666
5743
|
from scipy.interpolate import interp1d
|
|
5667
5744
|
from scipy.stats import norm
|
|
5668
5745
|
|
|
5669
|
-
|
|
5670
|
-
|
|
5671
|
-
|
|
5672
|
-
|
|
5746
|
+
if in_mask is not None:
|
|
5747
|
+
m_idx = np.where(in_mask.flatten() > 0)[0]
|
|
5748
|
+
idx = np.argsort(x.flatten()[m_idx])
|
|
5749
|
+
p = norm.ppf((np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0])
|
|
5750
|
+
im_target = x.flatten()
|
|
5751
|
+
im_target[m_idx[idx]] = p
|
|
5673
5752
|
|
|
5674
|
-
|
|
5675
|
-
|
|
5676
|
-
|
|
5677
|
-
|
|
5753
|
+
self.f_gaussian = interp1d(
|
|
5754
|
+
im_target[m_idx[idx]], x.flatten()[m_idx[idx]], kind="cubic"
|
|
5755
|
+
)
|
|
5756
|
+
self.val_min = im_target[m_idx][idx[0]]
|
|
5757
|
+
self.val_max = im_target[m_idx][idx[-1]]
|
|
5758
|
+
else:
|
|
5759
|
+
idx = np.argsort(x.flatten())
|
|
5760
|
+
p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
|
|
5761
|
+
im_target = x.flatten()
|
|
5762
|
+
im_target[idx] = norm.ppf(p)
|
|
5763
|
+
|
|
5764
|
+
# Interpolation cubique
|
|
5765
|
+
self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
|
|
5766
|
+
self.val_min = im_target[idx[0]]
|
|
5767
|
+
self.val_max = im_target[idx[-1]]
|
|
5678
5768
|
return im_target.reshape(x.shape)
|
|
5679
5769
|
|
|
5680
5770
|
def from_gaussian(self, x):
|
|
@@ -5948,57 +6038,6 @@ class funct(FOC.FoCUS):
|
|
|
5948
6038
|
|
|
5949
6039
|
return result
|
|
5950
6040
|
|
|
5951
|
-
# # ---------------------------------------------−---------
|
|
5952
|
-
# def std(self, list_of_sc):
|
|
5953
|
-
# n = len(list_of_sc)
|
|
5954
|
-
# res = list_of_sc[0]
|
|
5955
|
-
# res2 = list_of_sc[0] * list_of_sc[0]
|
|
5956
|
-
# for k in range(1, n):
|
|
5957
|
-
# res = res + list_of_sc[k]
|
|
5958
|
-
# res2 = res2 + list_of_sc[k] * list_of_sc[k]
|
|
5959
|
-
#
|
|
5960
|
-
# if res.S1 is None:
|
|
5961
|
-
# if res.S3P is not None:
|
|
5962
|
-
# return scat_cov(
|
|
5963
|
-
# res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
|
|
5964
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5965
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5966
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5967
|
-
# S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
|
|
5968
|
-
# backend=self.backend,
|
|
5969
|
-
# use_1D=self.use_1D,
|
|
5970
|
-
# )
|
|
5971
|
-
# else:
|
|
5972
|
-
# return scat_cov(
|
|
5973
|
-
# res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
|
|
5974
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5975
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5976
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5977
|
-
# backend=self.backend,
|
|
5978
|
-
# use_1D=self.use_1D,
|
|
5979
|
-
# )
|
|
5980
|
-
# else:
|
|
5981
|
-
# if res.S3P is None:
|
|
5982
|
-
# return scat_cov(
|
|
5983
|
-
# res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
|
|
5984
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5985
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5986
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5987
|
-
# S1=res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
|
|
5988
|
-
# S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
|
|
5989
|
-
# backend=self.backend,
|
|
5990
|
-
# )
|
|
5991
|
-
# else:
|
|
5992
|
-
# return scat_cov(
|
|
5993
|
-
# res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
|
|
5994
|
-
# res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
|
|
5995
|
-
# res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
|
|
5996
|
-
# res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
|
|
5997
|
-
# backend=self.backend,
|
|
5998
|
-
# use_1D=self.use_1D,
|
|
5999
|
-
# )
|
|
6000
|
-
# return self.NORIENT
|
|
6001
|
-
|
|
6002
6041
|
@tf_function
|
|
6003
6042
|
def eval_comp_fast(
|
|
6004
6043
|
self,
|
|
@@ -6036,6 +6075,7 @@ class funct(FOC.FoCUS):
|
|
|
6036
6075
|
def synthesis(
|
|
6037
6076
|
self,
|
|
6038
6077
|
image_target,
|
|
6078
|
+
reference=None,
|
|
6039
6079
|
nstep=4,
|
|
6040
6080
|
seed=1234,
|
|
6041
6081
|
Jmax=None,
|
|
@@ -6045,6 +6085,7 @@ class funct(FOC.FoCUS):
|
|
|
6045
6085
|
synthesised_N=1,
|
|
6046
6086
|
input_image=None,
|
|
6047
6087
|
grd_mask=None,
|
|
6088
|
+
in_mask=None,
|
|
6048
6089
|
iso_ang=False,
|
|
6049
6090
|
EVAL_FREQUENCY=100,
|
|
6050
6091
|
NUM_EPOCHS=300,
|
|
@@ -6054,18 +6095,59 @@ class funct(FOC.FoCUS):
|
|
|
6054
6095
|
|
|
6055
6096
|
import foscat.Synthesis as synthe
|
|
6056
6097
|
|
|
6098
|
+
l_edge = edge
|
|
6099
|
+
if in_mask is not None:
|
|
6100
|
+
l_edge = True
|
|
6101
|
+
|
|
6102
|
+
if edge:
|
|
6103
|
+
self.purge_edge_mask()
|
|
6104
|
+
|
|
6057
6105
|
def The_loss(u, scat_operator, args):
|
|
6058
6106
|
ref = args[0]
|
|
6059
6107
|
sref = args[1]
|
|
6060
6108
|
use_v = args[2]
|
|
6109
|
+
ljmax = args[3]
|
|
6110
|
+
|
|
6111
|
+
# compute scattering covariance of the current synthetised map called u
|
|
6112
|
+
if use_v:
|
|
6113
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6114
|
+
scat_operator.scattering_cov(
|
|
6115
|
+
u,
|
|
6116
|
+
edge=l_edge,
|
|
6117
|
+
Jmax=ljmax,
|
|
6118
|
+
ref_sigma=sref,
|
|
6119
|
+
use_ref=True,
|
|
6120
|
+
iso_ang=iso_ang,
|
|
6121
|
+
)
|
|
6122
|
+
)
|
|
6123
|
+
else:
|
|
6124
|
+
learn = scat_operator.reduce_mean_batch(
|
|
6125
|
+
scat_operator.scattering_cov(
|
|
6126
|
+
u, edge=l_edge, Jmax=ljmax, use_ref=True, iso_ang=iso_ang
|
|
6127
|
+
)
|
|
6128
|
+
)
|
|
6129
|
+
|
|
6130
|
+
# make the difference withe the reference coordinates
|
|
6131
|
+
loss = scat_operator.backend.bk_reduce_mean(
|
|
6132
|
+
scat_operator.backend.bk_square(learn - ref)
|
|
6133
|
+
)
|
|
6134
|
+
return loss
|
|
6135
|
+
|
|
6136
|
+
def The_lossX(u, scat_operator, args):
|
|
6137
|
+
ref = args[0]
|
|
6138
|
+
sref = args[1]
|
|
6139
|
+
use_v = args[2]
|
|
6140
|
+
im2 = args[3]
|
|
6141
|
+
ljmax = args[4]
|
|
6061
6142
|
|
|
6062
6143
|
# compute scattering covariance of the current synthetised map called u
|
|
6063
6144
|
if use_v:
|
|
6064
6145
|
learn = scat_operator.reduce_mean_batch(
|
|
6065
6146
|
scat_operator.scattering_cov(
|
|
6066
6147
|
u,
|
|
6067
|
-
|
|
6068
|
-
|
|
6148
|
+
data2=im2,
|
|
6149
|
+
edge=l_edge,
|
|
6150
|
+
Jmax=ljmax,
|
|
6069
6151
|
ref_sigma=sref,
|
|
6070
6152
|
use_ref=True,
|
|
6071
6153
|
iso_ang=iso_ang,
|
|
@@ -6074,7 +6156,12 @@ class funct(FOC.FoCUS):
|
|
|
6074
6156
|
else:
|
|
6075
6157
|
learn = scat_operator.reduce_mean_batch(
|
|
6076
6158
|
scat_operator.scattering_cov(
|
|
6077
|
-
u,
|
|
6159
|
+
u,
|
|
6160
|
+
data2=im2,
|
|
6161
|
+
edge=l_edge,
|
|
6162
|
+
Jmax=ljmax,
|
|
6163
|
+
use_ref=True,
|
|
6164
|
+
iso_ang=iso_ang,
|
|
6078
6165
|
)
|
|
6079
6166
|
)
|
|
6080
6167
|
|
|
@@ -6086,7 +6173,7 @@ class funct(FOC.FoCUS):
|
|
|
6086
6173
|
|
|
6087
6174
|
if to_gaussian:
|
|
6088
6175
|
# Change the data histogram to gaussian distribution
|
|
6089
|
-
im_target = self.to_gaussian(image_target)
|
|
6176
|
+
im_target = self.to_gaussian(image_target, in_mask=in_mask)
|
|
6090
6177
|
else:
|
|
6091
6178
|
im_target = image_target
|
|
6092
6179
|
|
|
@@ -6113,22 +6200,59 @@ class funct(FOC.FoCUS):
|
|
|
6113
6200
|
|
|
6114
6201
|
t1 = time.time()
|
|
6115
6202
|
tmp = {}
|
|
6116
|
-
|
|
6117
|
-
l_grd_mask={}
|
|
6118
|
-
|
|
6203
|
+
|
|
6204
|
+
l_grd_mask = {}
|
|
6205
|
+
l_in_mask = {}
|
|
6206
|
+
l_input_image = {}
|
|
6207
|
+
l_ref = {}
|
|
6208
|
+
l_jmax = {}
|
|
6209
|
+
|
|
6119
6210
|
tmp[nstep - 1] = self.backend.bk_cast(im_target)
|
|
6211
|
+
l_jmax[nstep - 1] = Jmax
|
|
6212
|
+
|
|
6213
|
+
if reference is not None:
|
|
6214
|
+
l_ref[nstep - 1] = self.backend.bk_cast(reference)
|
|
6215
|
+
else:
|
|
6216
|
+
l_ref[nstep - 1] = None
|
|
6217
|
+
|
|
6120
6218
|
if grd_mask is not None:
|
|
6121
6219
|
l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
|
|
6122
6220
|
else:
|
|
6123
6221
|
l_grd_mask[nstep - 1] = None
|
|
6124
|
-
|
|
6222
|
+
if in_mask is not None:
|
|
6223
|
+
l_in_mask[nstep - 1] = in_mask
|
|
6224
|
+
else:
|
|
6225
|
+
l_in_mask[nstep - 1] = None
|
|
6226
|
+
|
|
6227
|
+
if input_image is not None:
|
|
6228
|
+
l_input_image[nstep - 1] = input_image
|
|
6229
|
+
|
|
6125
6230
|
for ell in range(nstep - 2, -1, -1):
|
|
6126
|
-
tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
|
|
6231
|
+
tmp[ell], _ = self.ud_grade_2(tmp[ell + 1], axis=1)
|
|
6232
|
+
|
|
6127
6233
|
if grd_mask is not None:
|
|
6128
|
-
l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
|
|
6234
|
+
l_grd_mask[ell], _ = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
|
|
6129
6235
|
else:
|
|
6130
6236
|
l_grd_mask[ell] = None
|
|
6131
|
-
|
|
6237
|
+
|
|
6238
|
+
if in_mask is not None:
|
|
6239
|
+
l_in_mask[ell], _ = self.ud_grade_2(l_in_mask[ell + 1])
|
|
6240
|
+
l_in_mask[ell] = self.backend.to_numpy(l_in_mask[ell])
|
|
6241
|
+
else:
|
|
6242
|
+
l_in_mask[ell] = None
|
|
6243
|
+
|
|
6244
|
+
if input_image is not None:
|
|
6245
|
+
l_input_image[ell], _ = self.ud_grade_2(l_input_image[ell + 1], axis=1)
|
|
6246
|
+
|
|
6247
|
+
if reference is not None:
|
|
6248
|
+
l_ref[ell], _ = self.ud_grade_2(l_ref[ell + 1], axis=1)
|
|
6249
|
+
else:
|
|
6250
|
+
l_ref[ell] = None
|
|
6251
|
+
|
|
6252
|
+
if l_jmax[ell + 1] is None:
|
|
6253
|
+
l_jmax[ell] = None
|
|
6254
|
+
else:
|
|
6255
|
+
l_jmax[ell] = l_jmax[ell + 1] - 1
|
|
6132
6256
|
|
|
6133
6257
|
if not self.use_2D and not self.use_1D:
|
|
6134
6258
|
l_nside = nside // (2 ** (nstep - 1))
|
|
@@ -6138,16 +6262,20 @@ class funct(FOC.FoCUS):
|
|
|
6138
6262
|
if input_image is None:
|
|
6139
6263
|
np.random.seed(seed)
|
|
6140
6264
|
if self.use_2D:
|
|
6141
|
-
imap = self.backend.bk_cast(
|
|
6142
|
-
|
|
6143
|
-
|
|
6265
|
+
imap = self.backend.bk_cast(
|
|
6266
|
+
np.random.randn(
|
|
6267
|
+
synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
|
|
6268
|
+
)
|
|
6269
|
+
)
|
|
6144
6270
|
else:
|
|
6145
|
-
imap = self.backend.bk_cast(
|
|
6271
|
+
imap = self.backend.bk_cast(
|
|
6272
|
+
np.random.randn(synthesised_N, tmp[k].shape[1])
|
|
6273
|
+
)
|
|
6146
6274
|
else:
|
|
6147
6275
|
if self.use_2D:
|
|
6148
6276
|
imap = self.backend.bk_reshape(
|
|
6149
6277
|
self.backend.bk_tile(
|
|
6150
|
-
self.backend.bk_cast(
|
|
6278
|
+
self.backend.bk_cast(l_input_image[k].flatten()),
|
|
6151
6279
|
synthesised_N,
|
|
6152
6280
|
),
|
|
6153
6281
|
[synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
|
|
@@ -6155,7 +6283,7 @@ class funct(FOC.FoCUS):
|
|
|
6155
6283
|
else:
|
|
6156
6284
|
imap = self.backend.bk_reshape(
|
|
6157
6285
|
self.backend.bk_tile(
|
|
6158
|
-
self.backend.bk_cast(
|
|
6286
|
+
self.backend.bk_cast(l_input_image[k].flatten()),
|
|
6159
6287
|
synthesised_N,
|
|
6160
6288
|
),
|
|
6161
6289
|
[synthesised_N, tmp[k].shape[1]],
|
|
@@ -6170,24 +6298,46 @@ class funct(FOC.FoCUS):
|
|
|
6170
6298
|
imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
|
|
6171
6299
|
else:
|
|
6172
6300
|
imap = self.up_grade(omap, l_nside, axis=1)
|
|
6173
|
-
|
|
6301
|
+
|
|
6174
6302
|
if grd_mask is not None:
|
|
6175
|
-
imap=imap*l_grd_mask[k]+tmp[k]*(1-l_grd_mask[k])
|
|
6176
|
-
|
|
6303
|
+
imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
|
|
6304
|
+
|
|
6177
6305
|
# compute the coefficients for the target image
|
|
6178
6306
|
if use_variance:
|
|
6179
6307
|
ref, sref = self.scattering_cov(
|
|
6180
|
-
tmp[k],
|
|
6308
|
+
tmp[k],
|
|
6309
|
+
data2=l_ref[k],
|
|
6310
|
+
get_variance=True,
|
|
6311
|
+
edge=l_edge,
|
|
6312
|
+
Jmax=l_jmax[k],
|
|
6313
|
+
in_mask=l_in_mask[k],
|
|
6314
|
+
iso_ang=iso_ang,
|
|
6181
6315
|
)
|
|
6182
6316
|
else:
|
|
6183
|
-
ref = self.scattering_cov(
|
|
6317
|
+
ref = self.scattering_cov(
|
|
6318
|
+
tmp[k],
|
|
6319
|
+
data2=l_ref[k],
|
|
6320
|
+
in_mask=l_in_mask[k],
|
|
6321
|
+
edge=l_edge,
|
|
6322
|
+
Jmax=l_jmax[k],
|
|
6323
|
+
iso_ang=iso_ang,
|
|
6324
|
+
)
|
|
6184
6325
|
sref = ref
|
|
6185
6326
|
|
|
6186
6327
|
# compute the mean of the population does nothing if only one map is given
|
|
6187
6328
|
ref = self.reduce_mean_batch(ref)
|
|
6188
6329
|
|
|
6189
|
-
|
|
6190
|
-
|
|
6330
|
+
if l_in_mask[k] is not None:
|
|
6331
|
+
self.purge_edge_mask()
|
|
6332
|
+
|
|
6333
|
+
if l_ref[k] is None:
|
|
6334
|
+
# define a loss to minimize
|
|
6335
|
+
loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
|
|
6336
|
+
else:
|
|
6337
|
+
# define a loss to minimize
|
|
6338
|
+
loss = synthe.Loss(
|
|
6339
|
+
The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
|
|
6340
|
+
)
|
|
6191
6341
|
|
|
6192
6342
|
sy = synthe.Synthesis([loss])
|
|
6193
6343
|
|
|
@@ -6201,7 +6351,12 @@ class funct(FOC.FoCUS):
|
|
|
6201
6351
|
l_nside *= 2
|
|
6202
6352
|
|
|
6203
6353
|
# do the minimization
|
|
6204
|
-
omap = sy.run(
|
|
6354
|
+
omap = sy.run(
|
|
6355
|
+
imap,
|
|
6356
|
+
EVAL_FREQUENCY=EVAL_FREQUENCY,
|
|
6357
|
+
NUM_EPOCHS=NUM_EPOCHS,
|
|
6358
|
+
grd_mask=l_grd_mask[k],
|
|
6359
|
+
)
|
|
6205
6360
|
|
|
6206
6361
|
t2 = time.time()
|
|
6207
6362
|
print("Total computation %.2fs" % (t2 - t1))
|