foscat 3.7.0__py3-none-any.whl → 3.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +4 -1
- foscat/backend.py +45 -10
- foscat/scat_cov.py +1685 -30
- foscat/scat_cov2D.py +65 -25
- {foscat-3.7.0.dist-info → foscat-3.7.1.dist-info}/METADATA +2 -2
- {foscat-3.7.0.dist-info → foscat-3.7.1.dist-info}/RECORD +9 -9
- {foscat-3.7.0.dist-info → foscat-3.7.1.dist-info}/WHEEL +1 -1
- /foscat-3.7.0.dist-info/LICENCE → /foscat-3.7.1.dist-info/LICENSE +0 -0
- {foscat-3.7.0.dist-info → foscat-3.7.1.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -92,7 +92,7 @@ class scat_cov:
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
def conv2complex(self, val):
|
|
95
|
-
if val.dtype == "complex64"
|
|
95
|
+
if val.dtype == "complex64" or val.dtype=="complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
|
|
96
96
|
return val
|
|
97
97
|
else:
|
|
98
98
|
return self.backend.bk_complex(val, 0 * val)
|
|
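The broadened test above compares a dtype against string names. A minimal sketch (illustrative only, not foscat code) of the comparisons it relies on: NumPy dtypes compare equal to their string names, and torch dtype names carry the "torch." prefix that the two added literals mirror.

```python
# Minimal sketch (assumption: illustrative only, not part of foscat).
import numpy as np
import torch

val = np.zeros(4, dtype="complex128")
print(val.dtype == "complex128")           # True: NumPy dtypes compare equal to their names

t = torch.zeros(4, dtype=torch.complex64)
print(str(t.dtype) == "torch.complex64")   # True: torch dtype names are prefixed with "torch."
```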
@@ -2542,6 +2542,1047 @@ class funct(FOC.FoCUS):
         """
         return_data = self.return_data
         # Check input consistency
+        if image2 is not None:
+            if list(image1.shape) != list(image2.shape):
+                print(
+                    "The two input image should have the same size to eval Scattering Covariance"
+                )
+                return None
+        if mask is not None:
+            if self.use_2D:
+                if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
+                    print(
+                        "The LAST 2 COLUMNs of the mask should have the same size ",
+                        mask.shape,
+                        "than the input image ",
+                        image1.shape,
+                        "to eval Scattering Covariance",
+                    )
+                    return None
+            else:
+                if image1.shape[-1] != mask.shape[1]:
+                    print(
+                        "The LAST COLUMN of the mask should have the same size ",
+                        mask.shape,
+                        "than the input image ",
+                        image1.shape,
+                        "to eval Scattering Covariance",
+                    )
+                    return None
+        if self.use_2D and len(image1.shape) < 2:
+            print(
+                "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
+            )
+            return None
+
+        ### AUTO OR CROSS
+        cross = False
+        if image2 is not None:
+            cross = True
+
+        ### PARAMETERS
+        axis = 1
+        # determine jmax and nside corresponding to the input map
+        im_shape = image1.shape
+        if self.use_2D:
+            if len(image1.shape) == 2:
+                nside = np.min([im_shape[0], im_shape[1]])
+                npix = im_shape[0] * im_shape[1]  # Number of pixels
+                x1 = im_shape[0]
+                x2 = im_shape[1]
+            else:
+                nside = np.min([im_shape[1], im_shape[2]])
+                npix = im_shape[1] * im_shape[2]  # Number of pixels
+                x1 = im_shape[1]
+                x2 = im_shape[2]
+            J = int(np.log(nside - self.KERNELSZ) / np.log(2))  # Number of j scales
+        elif self.use_1D:
+            if len(image1.shape) == 2:
+                npix = int(im_shape[1])  # Number of pixels
+            else:
+                npix = int(im_shape[0])  # Number of pixels
+
+            nside = int(npix)
+
+            J = int(np.log(nside) / np.log(2))  # Number of j scales
+        else:
+            if len(image1.shape) == 2:
+                npix = int(im_shape[1])  # Number of pixels
+            else:
+                npix = int(im_shape[0])  # Number of pixels
+
+            nside = int(np.sqrt(npix // 12))
+
+            J = int(np.log(nside) / np.log(2))  # Number of j scales
+
+        if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
+            J-=1
+        if Jmax is None:
+            Jmax = J  # Number of steps for the loop on scales
+        if Jmax>J:
+            print('==========\n\n')
+            print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
+            print('\n\n==========')
+
+
+        ### LOCAL VARIABLES (IMAGES and MASK)
+        if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
+            I1 = self.backend.bk_cast(
+                self.backend.bk_expand_dims(image1, 0)
+            )  # Local image1 [Nbatch, Npix]
+            if cross:
+                I2 = self.backend.bk_cast(
+                    self.backend.bk_expand_dims(image2, 0)
+                )  # Local image2 [Nbatch, Npix]
+        else:
+            I1 = self.backend.bk_cast(image1)  # Local image1 [Nbatch, Npix]
+            if cross:
+                I2 = self.backend.bk_cast(image2)  # Local image2 [Nbatch, Npix]
+
+        if mask is None:
+            if self.use_2D:
+                vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
+            else:
+                vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
+        else:
+            vmask = self.backend.bk_cast(mask)  # [Nmask, Npix]
+
+        if self.KERNELSZ > 3 and not self.use_2D:
+            # if the kernel size is bigger than 3 increase the binning before smoothing
+            if self.use_2D:
+                vmask = self.up_grade(
+                    vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
+                )
+                I1 = self.up_grade(
+                    I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
+                )
+                if cross:
+                    I2 = self.up_grade(
+                        I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
+                    )
+            elif self.use_1D:
+                vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
+                I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
+                if cross:
+                    I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
+            else:
+                I1 = self.up_grade(I1, nside * 2, axis=axis)
+                vmask = self.up_grade(vmask, nside * 2, axis=1)
+                if cross:
+                    I2 = self.up_grade(I2, nside * 2, axis=axis)
+
+        if self.KERNELSZ > 5 and not self.use_2D:
+            # if the kernel size is bigger than 3 increase the binning before smoothing
+            if self.use_2D:
+                vmask = self.up_grade(
+                    vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
+                )
+                I1 = self.up_grade(
+                    I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
+                )
+                if cross:
+                    I2 = self.up_grade(
+                        I2,
+                        I2.shape[axis] * 2,
+                        axis=axis,
+                        nouty=I2.shape[axis + 1] * 2,
+                    )
+            elif self.use_1D:
+                vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
+                I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
+                if cross:
+                    I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
+            else:
+                I1 = self.up_grade(I1, nside * 4, axis=axis)
+                vmask = self.up_grade(vmask, nside * 4, axis=1)
+                if cross:
+                    I2 = self.up_grade(I2, nside * 4, axis=axis)
+
+        # Normalize the masks because they have different pixel numbers
+        # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None]  # [Nmask, Npix]
+
+        ### INITIALIZATION
+        # Coefficients
+        if return_data:
+            S1 = {}
+            S2 = {}
+            S3 = {}
+            S3P = {}
+            S4 = {}
+        else:
+            S1 = []
+            S2 = []
+            S3 = []
+            S4 = []
+            S3P = []
+            VS1 = []
+            VS2 = []
+            VS3 = []
+            VS3P = []
+            VS4 = []
+
+        off_S2 = -2
+        off_S3 = -3
+        off_S4 = -4
+        if self.use_1D:
+            off_S2 = -1
+            off_S3 = -1
+            off_S4 = -1
+
+        # S2 for normalization
+        cond_init_P1_dic = (norm == "self") or (
+            (norm == "auto") and (self.P1_dic is None)
+        )
+        if norm is None:
+            pass
+        elif cond_init_P1_dic:
+            P1_dic = {}
+            if cross:
+                P2_dic = {}
+        elif (norm == "auto") and (self.P1_dic is not None):
+            P1_dic = self.P1_dic
+            if cross:
+                P2_dic = self.P2_dic
+
+        if return_data:
+            s0 = I1
+            if out_nside is not None:
+                s0 = self.backend.bk_reduce_mean(
+                    self.backend.bk_reshape(
+                        s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
+                    ),
+                    2,
+                )
+        else:
+            if not cross:
+                s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
+            else:
+                s0, l_vs0 = self.masked_mean(
+                    self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
+                )
+            vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
+            s0 = self.backend.bk_concat([s0, l_vs0], 1)
+        #### COMPUTE S1, S2, S3 and S4
+        nside_j3 = nside  # NSIDE start (nside_j3 = nside / 2^j3)
+
+        # a remettre comme avant
+        M1_dic={}
+        M2_dic={}
+
+        for j3 in range(Jmax):
+
+            if edge:
+                if self.mask_mask is None:
+                    self.mask_mask={}
+                if self.use_2D:
+                    if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
+                        mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
+                        mask_mask[0,
+                                  self.KERNELSZ//2:-self.KERNELSZ//2+1,
+                                  self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
+                        self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
+                    vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
+                    #print(self.KERNELSZ//2,vmask,mask_mask)
+
+                if self.use_1D:
+                    if (vmask.shape[1]) not in self.mask_mask:
+                        mask_mask=np.zeros([1,vmask.shape[1]])
+                        mask_mask[0,
+                                  self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
+                        self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
+                    vmask=vmask*self.mask_mask[(vmask.shape[1])]
+
+            if return_data:
+                S3[j3] = None
+                S3P[j3] = None
+
+                if S4 is None:
+                    S4 = {}
+                S4[j3] = None
+
+            ####### S1 and S2
+            ### Make the convolution I1 * Psi_j3
+            conv1 = self.convol(I1, axis=1)  # [Nbatch, Npix_j3, Norient3]
+            if cmat is not None:
+                tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
+                conv1 = self.backend.bk_reduce_sum(
+                    self.backend.bk_reshape(
+                        cmat[j3] * tmp2,
+                        [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
+                    ),
+                    2,
+                )
+
+            ### Take the module M1 = |I1 * Psi_j3|
+            M1_square = conv1 * self.backend.bk_conjugate(
+                conv1
+            )  # [Nbatch, Npix_j3, Norient3]
+            M1 = self.backend.bk_L1(M1_square)  # [Nbatch, Npix_j3, Norient3]
+            # Store M1_j3 in a dictionary
+            M1_dic[j3] = M1
+
+            if not cross:  # Auto
+                M1_square = self.backend.bk_real(M1_square)
+
+                ### S2_auto = < M1^2 >_pix
+                # Apply the mask [Nmask, Npix_j3] and average over pixels
+                if return_data:
+                    s2 = M1_square
+                else:
+                    if calc_var:
+                        s2, vs2 = self.masked_mean(
+                            M1_square, vmask, axis=1, rank=j3, calc_var=True
+                        )
+                    else:
+                        s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
+
+                if cond_init_P1_dic:
+                    # We fill P1_dic with S2 for normalisation of S3 and S4
+                    P1_dic[j3] = self.backend.bk_real(s2)  # [Nbatch, Nmask, Norient3]
+
+                # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
+                if return_data:
+                    if S2 is None:
+                        S2 = {}
+                    if out_nside is not None and out_nside < nside_j3:
+                        s2 = self.backend.bk_reduce_mean(
+                            self.backend.bk_reshape(
+                                s2,
+                                [
+                                    s2.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s2.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
+                    S2[j3] = s2
+                else:
+                    if norm == "auto":  # Normalize S2
+                        s2 /= P1_dic[j3]
+
+                    S2.append(
+                        self.backend.bk_expand_dims(s2, off_S2)
+                    )  # Add a dimension for NS2
+                    if calc_var:
+                        VS2.append(
+                            self.backend.bk_expand_dims(vs2, off_S2)
+                        )  # Add a dimension for NS2
+
+                #### S1_auto computation
+                ### Image 1 : S1 = < M1 >_pix
+                # Apply the mask [Nmask, Npix_j3] and average over pixels
+                if return_data:
+                    s1 = M1
+                else:
+                    if calc_var:
+                        s1, vs1 = self.masked_mean(
+                            M1, vmask, axis=1, rank=j3, calc_var=True
+                        )  # [Nbatch, Nmask, Norient3]
+                    else:
+                        s1 = self.masked_mean(
+                            M1, vmask, axis=1, rank=j3
+                        )  # [Nbatch, Nmask, Norient3]
+
+                if return_data:
+                    if out_nside is not None and out_nside < nside_j3:
+                        s1 = self.backend.bk_reduce_mean(
+                            self.backend.bk_reshape(
+                                s1,
+                                [
+                                    s1.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s1.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
+                    S1[j3] = s1
+                else:
+                    ### Normalize S1
+                    if norm is not None:
+                        self.div_norm(s1, (P1_dic[j3]) ** 0.5)
+                    ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
+                    S1.append(
+                        self.backend.bk_expand_dims(s1, off_S2)
+                    )  # Add a dimension for NS1
+                    if calc_var:
+                        VS1.append(
+                            self.backend.bk_expand_dims(vs1, off_S2)
+                        )  # Add a dimension for NS1
+
+            else:  # Cross
+                ### Make the convolution I2 * Psi_j3
+                conv2 = self.convol(I2, axis=1)  # [Nbatch, Npix_j3, Norient3]
+                if cmat is not None:
+                    tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
+                    conv2 = self.backend.bk_reduce_sum(
+                        self.backend.bk_reshape(
+                            cmat[j3] * tmp2,
+                            [
+                                tmp2.shape[0],
+                                cmat[j3].shape[0],
+                                self.NORIENT,
+                                self.NORIENT,
+                            ],
+                        ),
+                        2,
+                    )
+                ### Take the module M2 = |I2 * Psi_j3|
+                M2_square = conv2 * self.backend.bk_conjugate(
+                    conv2
+                )  # [Nbatch, Npix_j3, Norient3]
+                M2 = self.backend.bk_L1(M2_square)  # [Nbatch, Npix_j3, Norient3]
+                # Store M2_j3 in a dictionary
+                M2_dic[j3] = M2
+
+                ### S2_auto = < M2^2 >_pix
+                # Not returned, only for normalization
+                if cond_init_P1_dic:
+                    # Apply the mask [Nmask, Npix_j3] and average over pixels
+                    if return_data:
+                        p1 = M1_square
+                        p2 = M2_square
+                    else:
+                        if calc_var:
+                            p1, vp1 = self.masked_mean(
+                                M1_square, vmask, axis=1, rank=j3, calc_var=True
+                            )  # [Nbatch, Nmask, Norient3]
+                            p2, vp2 = self.masked_mean(
+                                M2_square, vmask, axis=1, rank=j3, calc_var=True
+                            )  # [Nbatch, Nmask, Norient3]
+                        else:
+                            p1 = self.masked_mean(
+                                M1_square, vmask, axis=1, rank=j3
+                            )  # [Nbatch, Nmask, Norient3]
+                            p2 = self.masked_mean(
+                                M2_square, vmask, axis=1, rank=j3
+                            )  # [Nbatch, Nmask, Norient3]
+                    # We fill P1_dic with S2 for normalisation of S3 and S4
+                    P1_dic[j3] = self.backend.bk_real(p1)  # [Nbatch, Nmask, Norient3]
+                    P2_dic[j3] = self.backend.bk_real(p2)  # [Nbatch, Nmask, Norient3]
+
+                ### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
+                # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
+                s2 = conv1 * self.backend.bk_conjugate(conv2)
+                MX = self.backend.bk_L1(s2)
+                # Apply the mask [Nmask, Npix_j3] and average over pixels
+                if return_data:
+                    s2 = s2
+                else:
+                    if calc_var:
+                        s2, vs2 = self.masked_mean(
+                            s2, vmask, axis=1, rank=j3, calc_var=True
+                        )
+                    else:
+                        s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
+
+                if return_data:
+                    if out_nside is not None and out_nside < nside_j3:
+                        s2 = self.backend.bk_reduce_mean(
+                            self.backend.bk_reshape(
+                                s2,
+                                [
+                                    s2.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s2.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
+                    S2[j3] = s2
+                else:
+                    ### Normalize S2_cross
+                    if norm == "auto":
+                        s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
+
+                    ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
+                    s2 = self.backend.bk_real(s2)
+
+                    S2.append(
+                        self.backend.bk_expand_dims(s2, off_S2)
+                    )  # Add a dimension for NS2
+                    if calc_var:
+                        VS2.append(
+                            self.backend.bk_expand_dims(vs2, off_S2)
+                        )  # Add a dimension for NS2
+
+                #### S1_auto computation
+                ### Image 1 : S1 = < M1 >_pix
+                # Apply the mask [Nmask, Npix_j3] and average over pixels
+                if return_data:
+                    s1 = MX
+                else:
+                    if calc_var:
+                        s1, vs1 = self.masked_mean(
+                            MX, vmask, axis=1, rank=j3, calc_var=True
+                        )  # [Nbatch, Nmask, Norient3]
+                    else:
+                        s1 = self.masked_mean(
+                            MX, vmask, axis=1, rank=j3
+                        )  # [Nbatch, Nmask, Norient3]
+                if return_data:
+                    if out_nside is not None and out_nside < nside_j3:
+                        s1 = self.backend.bk_reduce_mean(
+                            self.backend.bk_reshape(
+                                s1,
+                                [
+                                    s1.shape[0],
+                                    12 * out_nside**2,
+                                    (nside_j3 // out_nside) ** 2,
+                                    s1.shape[2],
+                                ],
+                            ),
+                            2,
+                        )
+                    S1[j3] = s1
+                else:
+                    ### Normalize S1
+                    if norm is not None:
+                        self.div_norm(s1, (P1_dic[j3]) ** 0.5)
+                    ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
+                    S1.append(
+                        self.backend.bk_expand_dims(s1, off_S2)
+                    )  # Add a dimension for NS1
+                    if calc_var:
+                        VS1.append(
+                            self.backend.bk_expand_dims(vs1, off_S2)
+                        )  # Add a dimension for NS1
+
+            # Initialize dictionaries for |I1*Psi_j| * Psi_j3
+            M1convPsi_dic = {}
+            if cross:
+                # Initialize dictionaries for |I2*Psi_j| * Psi_j3
+                M2convPsi_dic = {}
+
+            ###### S3
+            nside_j2 = nside_j3
+            for j2 in range(0, j3 + 1):  # j2 <= j3
+                if return_data:
+                    if S4[j3] is None:
+                        S4[j3] = {}
+                    S4[j3][j2] = None
+
+                ### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
+                if not cross:
+                    if calc_var:
+                        s3, vs3 = self._compute_S3(
+                            j2,
+                            j3,
+                            conv1,
+                            vmask,
+                            M1_dic,
+                            M1convPsi_dic,
+                            calc_var=True,
+                            cmat2=cmat2,
+                        )  # [Nbatch, Nmask, Norient3, Norient2]
+                    else:
+                        s3 = self._compute_S3(
+                            j2,
+                            j3,
+                            conv1,
+                            vmask,
+                            M1_dic,
+                            M1convPsi_dic,
+                            return_data=return_data,
+                            cmat2=cmat2,
+                        )  # [Nbatch, Nmask, Norient3, Norient2]
+
+                    if return_data:
+                        if S3[j3] is None:
+                            S3[j3] = {}
+                        if out_nside is not None and out_nside < nside_j2:
+                            s3 = self.backend.bk_reduce_mean(
+                                self.backend.bk_reshape(
+                                    s3,
+                                    [
+                                        s3.shape[0],
+                                        12 * out_nside**2,
+                                        (nside_j2 // out_nside) ** 2,
+                                        s3.shape[2],
+                                        s3.shape[3],
+                                    ],
+                                ),
+                                2,
+                            )
+                        S3[j3][j2] = s3
+                    else:
+                        ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
+                        if norm is not None:
+                            self.div_norm(
+                                s3,
+                                (
+                                    self.backend.bk_expand_dims(P1_dic[j2], off_S2)
+                                    * self.backend.bk_expand_dims(P1_dic[j3], -1)
+                                )
+                                ** 0.5,
+                            )  # [Nbatch, Nmask, Norient3, Norient2]
+
+                        ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
+
+                        # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
+                        #                                       s3.shape[2]*s3.shape[3]]))
+                        S3.append(
+                            self.backend.bk_expand_dims(s3, off_S3)
+                        )  # Add a dimension for NS3
+                        if calc_var:
+                            VS3.append(
+                                self.backend.bk_expand_dims(vs3, off_S3)
+                            )  # Add a dimension for NS3
+                            # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
+                            #                                         s3.shape[2]*s3.shape[3]]))
+
+                ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
+                ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
+                else:
+                    if calc_var:
+                        s3, vs3 = self._compute_S3(
+                            j2,
+                            j3,
+                            conv1,
+                            vmask,
+                            M2_dic,
+                            M2convPsi_dic,
+                            calc_var=True,
+                            cmat2=cmat2,
+                        )
+                        s3p, vs3p = self._compute_S3(
+                            j2,
+                            j3,
+                            conv2,
+                            vmask,
+                            M1_dic,
+                            M1convPsi_dic,
+                            calc_var=True,
+                            cmat2=cmat2,
+                        )
+                    else:
+                        s3 = self._compute_S3(
+                            j2,
+                            j3,
+                            conv1,
+                            vmask,
+                            M2_dic,
+                            M2convPsi_dic,
+                            return_data=return_data,
+                            cmat2=cmat2,
+                        )
+                        s3p = self._compute_S3(
+                            j2,
+                            j3,
+                            conv2,
+                            vmask,
+                            M1_dic,
+                            M1convPsi_dic,
+                            return_data=return_data,
+                            cmat2=cmat2,
+                        )
+
+                    if return_data:
+                        if S3[j3] is None:
+                            S3[j3] = {}
+                            S3P[j3] = {}
+                        if out_nside is not None and out_nside < nside_j2:
+                            s3 = self.backend.bk_reduce_mean(
+                                self.backend.bk_reshape(
+                                    s3,
+                                    [
+                                        s3.shape[0],
+                                        12 * out_nside**2,
+                                        (nside_j2 // out_nside) ** 2,
+                                        s3.shape[2],
+                                        s3.shape[3],
+                                    ],
+                                ),
+                                2,
+                            )
+                            s3p = self.backend.bk_reduce_mean(
+                                self.backend.bk_reshape(
+                                    s3p,
+                                    [
+                                        s3.shape[0],
+                                        12 * out_nside**2,
+                                        (nside_j2 // out_nside) ** 2,
+                                        s3.shape[2],
+                                        s3.shape[3],
+                                    ],
+                                ),
+                                2,
+                            )
+                        S3[j3][j2] = s3
+                        S3P[j3][j2] = s3p
+                    else:
+                        ### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
+                        if norm is not None:
+                            self.div_norm(
+                                s3,
+                                (
+                                    self.backend.bk_expand_dims(P2_dic[j2], off_S2)
+                                    * self.backend.bk_expand_dims(P1_dic[j3], -1)
+                                )
+                                ** 0.5,
+                            )  # [Nbatch, Nmask, Norient3, Norient2]
+                            self.div_norm(
+                                s3p,
+                                (
+                                    self.backend.bk_expand_dims(P1_dic[j2], off_S2)
+                                    * self.backend.bk_expand_dims(P2_dic[j3], -1)
+                                )
+                                ** 0.5,
+                            )  # [Nbatch, Nmask, Norient3, Norient2]
+
+                        ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
+
+                        # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
+                        #                                       s3.shape[2]*s3.shape[3]]))
+                        S3.append(
+                            self.backend.bk_expand_dims(s3, off_S3)
+                        )  # Add a dimension for NS3
+                        if calc_var:
+                            VS3.append(
+                                self.backend.bk_expand_dims(vs3, off_S3)
+                            )  # Add a dimension for NS3
+
+                            # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
+                            #                                         s3.shape[2]*s3.shape[3]]))
+
+                        # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
+                        #                                         s3.shape[2]*s3.shape[3]]))
+                        S3P.append(
+                            self.backend.bk_expand_dims(s3p, off_S3)
+                        )  # Add a dimension for NS3
+                        if calc_var:
+                            VS3P.append(
+                                self.backend.bk_expand_dims(vs3p, off_S3)
+                            )  # Add a dimension for NS3
+                            # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
+                            #                                           s3.shape[2]*s3.shape[3]]))
+
+                ##### S4
+                nside_j1 = nside_j2
+                for j1 in range(0, j2 + 1):  # j1 <= j2
+                    ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
+                    if not cross:
+                        if calc_var:
+                            s4, vs4 = self._compute_S4(
+                                j1,
+                                j2,
+                                vmask,
+                                M1convPsi_dic,
+                                M2convPsi_dic=None,
+                                calc_var=True,
+                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+                        else:
+                            s4 = self._compute_S4(
+                                j1,
+                                j2,
+                                vmask,
+                                M1convPsi_dic,
+                                M2convPsi_dic=None,
+                                return_data=return_data,
+                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+
+                        if return_data:
+                            if S4[j3][j2] is None:
+                                S4[j3][j2] = {}
+                            if out_nside is not None and out_nside < nside_j1:
+                                s4 = self.backend.bk_reduce_mean(
+                                    self.backend.bk_reshape(
+                                        s4,
+                                        [
+                                            s4.shape[0],
+                                            12 * out_nside**2,
+                                            (nside_j1 // out_nside) ** 2,
+                                            s4.shape[2],
+                                            s4.shape[3],
+                                            s4.shape[4],
+                                        ],
+                                    ),
+                                    2,
+                                )
+                            S4[j3][j2][j1] = s4
+                        else:
+                            ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
+                            if norm is not None:
+                                self.div_norm(
+                                    s4,
+                                    (
+                                        self.backend.bk_expand_dims(
+                                            self.backend.bk_expand_dims(
+                                                P1_dic[j1], off_S2
+                                            ),
+                                            off_S2,
+                                        )
+                                        * self.backend.bk_expand_dims(
+                                            self.backend.bk_expand_dims(
+                                                P1_dic[j2], off_S2
+                                            ),
+                                            -1,
+                                        )
+                                    )
+                                    ** 0.5,
+                                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+                            ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
+
+                            # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
+                            #                                       s4.shape[2]*s4.shape[3]*s4.shape[4]]))
+                            S4.append(
+                                self.backend.bk_expand_dims(s4, off_S4)
+                            )  # Add a dimension for NS4
+                            if calc_var:
+                                # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
+                                #                                         s4.shape[2]*s4.shape[3]*s4.shape[4]]))
+                                VS4.append(
+                                    self.backend.bk_expand_dims(vs4, off_S4)
+                                )  # Add a dimension for NS4
+
+                    ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
+                    else:
+                        if calc_var:
+                            s4, vs4 = self._compute_S4(
+                                j1,
+                                j2,
+                                vmask,
+                                M1convPsi_dic,
+                                M2convPsi_dic=M2convPsi_dic,
+                                calc_var=True,
+                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+                        else:
+                            s4 = self._compute_S4(
+                                j1,
+                                j2,
+                                vmask,
+                                M1convPsi_dic,
+                                M2convPsi_dic=M2convPsi_dic,
+                                return_data=return_data,
+                            )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+
+                        if return_data:
+                            if S4[j3][j2] is None:
+                                S4[j3][j2] = {}
+                            if out_nside is not None and out_nside < nside_j1:
+                                s4 = self.backend.bk_reduce_mean(
+                                    self.backend.bk_reshape(
+                                        s4,
+                                        [
+                                            s4.shape[0],
+                                            12 * out_nside**2,
+                                            (nside_j1 // out_nside) ** 2,
+                                            s4.shape[2],
+                                            s4.shape[3],
+                                            s4.shape[4],
+                                        ],
+                                    ),
+                                    2,
+                                )
+                            S4[j3][j2][j1] = s4
+                        else:
+                            ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
+                            if norm is not None:
+                                self.div_norm(
+                                    s4,
+                                    (
+                                        self.backend.bk_expand_dims(
+                                            self.backend.bk_expand_dims(
+                                                P1_dic[j1], off_S2
+                                            ),
+                                            off_S2,
+                                        )
+                                        * self.backend.bk_expand_dims(
+                                            self.backend.bk_expand_dims(
+                                                P2_dic[j2], off_S2
+                                            ),
+                                            -1,
+                                        )
+                                    )
+                                    ** 0.5,
+                                )  # [Nbatch, Nmask, Norient3, Norient2, Norient1]
+                            ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
+                            # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
+                            #                                       s4.shape[2]*s4.shape[3]*s4.shape[4]]))
+                            S4.append(
+                                self.backend.bk_expand_dims(s4, off_S4)
+                            )  # Add a dimension for NS4
+                            if calc_var:
+
+                                # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
+                                #                                         s4.shape[2]*s4.shape[3]*s4.shape[4]]))
+                                VS4.append(
+                                    self.backend.bk_expand_dims(vs4, off_S4)
+                                )  # Add a dimension for NS4
+
+                    nside_j1 = nside_j1 // 2
+                nside_j2 = nside_j2 // 2
+
+            ###### Reshape for next iteration on j3
+            ### Image I1,
+            # downscale the I1 [Nbatch, Npix_j3]
+            if j3 != Jmax - 1:
+                I1 = self.smooth(I1, axis=1)
+                I1 = self.ud_grade_2(I1, axis=1)
+
+                ### Image I2
+                if cross:
+                    I2 = self.smooth(I2, axis=1)
+                    I2 = self.ud_grade_2(I2, axis=1)
+
+                ### Modules
+                for j2 in range(0, j3 + 1):  # j2 =< j3
+                    ### Dictionary M1_dic[j2]
+                    M1_smooth = self.smooth(
+                        M1_dic[j2], axis=1
+                    )  # [Nbatch, Npix_j3, Norient3]
+                    M1_dic[j2] = self.ud_grade_2(
+                        M1_smooth, axis=1
+                    )  # [Nbatch, Npix_j3, Norient3]
+
+                    ### Dictionary M2_dic[j2]
+                    if cross:
+                        M2_smooth = self.smooth(
+                            M2_dic[j2], axis=1
+                        )  # [Nbatch, Npix_j3, Norient3]
+                        M2_dic[j2] = self.ud_grade_2(
+                            M2, axis=1
+                        )  # [Nbatch, Npix_j3, Norient3]
+
+                ### Mask
+                vmask = self.ud_grade_2(vmask, axis=1)
+
+                if self.mask_thres is not None:
+                    vmask = self.backend.bk_threshold(vmask, self.mask_thres)
+
+            ### NSIDE_j3
+            nside_j3 = nside_j3 // 2
+
+        ### Store P1_dic and P2_dic in self
+        if (norm == "auto") and (self.P1_dic is None):
+            self.P1_dic = P1_dic
+            if cross:
+                self.P2_dic = P2_dic
+        """
+        Sout=[s0]+S1+S2+S3+S4
+
+        if cross:
+            Sout=Sout+S3P
+        if calc_var:
+            SVout=[vs0]+VS1+VS2+VS3+VS4
+            if cross:
+                VSout=VSout+VS3P
+            return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
+
+        return self.backend.bk_concat(Sout, 2)
+        """
+        if not return_data:
+            S1 = self.backend.bk_concat(S1, 2)
+            S2 = self.backend.bk_concat(S2, 2)
+            S3 = self.backend.bk_concat(S3, 2)
+            S4 = self.backend.bk_concat(S4, 2)
+            if cross:
+                S3P = self.backend.bk_concat(S3P, 2)
+            if calc_var:
+                VS1 = self.backend.bk_concat(VS1, 2)
+                VS2 = self.backend.bk_concat(VS2, 2)
+                VS3 = self.backend.bk_concat(VS3, 2)
+                VS4 = self.backend.bk_concat(VS4, 2)
+                if cross:
+                    VS3P = self.backend.bk_concat(VS3P, 2)
+        if calc_var:
+            if not cross:
+                return scat_cov(
+                    s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
+                ), scat_cov(
+                    vs0,
+                    VS2,
+                    VS3,
+                    VS4,
+                    s1=VS1,
+                    backend=self.backend,
+                    use_1D=self.use_1D,
+                )
+            else:
+                return scat_cov(
+                    s0,
+                    S2,
+                    S3,
+                    S4,
+                    s1=S1,
+                    s3p=S3P,
+                    backend=self.backend,
+                    use_1D=self.use_1D,
+                ), scat_cov(
+                    vs0,
+                    VS2,
+                    VS3,
+                    VS4,
+                    s1=VS1,
+                    s3p=VS3P,
+                    backend=self.backend,
+                    use_1D=self.use_1D,
+                )
+        else:
+            if not cross:
+                return scat_cov(
+                    s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
+                )
+            else:
+                return scat_cov(
+                    s0,
+                    S2,
+                    S3,
+                    S4,
+                    s1=S1,
+                    s3p=S3P,
+                    backend=self.backend,
+                    use_1D=self.use_1D,
+                )
+
+    def eval_new(
+        self,
+        image1,
+        image2=None,
+        mask=None,
+        norm=None,
+        calc_var=False,
+        cmat=None,
+        cmat2=None,
+        Jmax=None,
+        out_nside=None,
+        edge=True
+    ):
+        """
+        Calculates the scattering correlations for a batch of images. Mean are done over pixels.
+        mean of modulus:
+            S1 = <|I * Psi_j3|>
+            Normalization : take the log
+        power spectrum:
+            S2 = <|I * Psi_j3|^2>
+            Normalization : take the log
+        orig. x modulus:
+            S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
+            Normalization : divide by (S2_j2 * S2_j3)^0.5
+        modulus x modulus:
+            S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
+            Normalization : divide by (S2_j1 * S2_j2)^0.5
+        Parameters
+        ----------
+        image1: tensor
+            Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
+        image2: tensor
+            Second image. If not None, we compute cross-scattering covariance coefficients.
+        mask:
+        norm: None or str
+            If None no normalization is applied, if 'auto' normalize by the reference S2,
+            if 'self' normalize by the current S2.
+        Returns
+        -------
+        S1, S2, S3, S4 normalized
+        """
+        return_data = self.return_data
+        NORIENT=self.NORIENT
+        # Check input consistency
         if image2 is not None:
             if list(image1.shape) != list(image2.shape):
                 print(
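The added `eval` body derives the number of scales from the map geometry. A worked example (illustrative sketch, not foscat code) of the HEALPix branch, where npix = 12 * nside**2:

```python
# Worked example (assumption: illustrative only) of the J/nside derivation above.
import numpy as np

npix = 12 * 256**2                  # a HEALPix map at nside=256
nside = int(np.sqrt(npix // 12))    # recovers 256
J = int(np.log(nside) / np.log(2))  # number of dyadic j scales: log2(256) = 8
```

For 2D inputs the diff instead uses the smaller side minus KERNELSZ, and for 2D/1D with KERNELSZ > 3 one scale is dropped (J -= 1).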
@@ -2699,13 +3740,19 @@ class funct(FOC.FoCUS):
             S3P = {}
             S4 = {}
         else:
-            S1 = []
-            S2 = []
+            result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
+                                              dtype=self.backend.backend.float32,
+                                              device=self.backend.torch_device)
+            vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
+                                               dtype=self.backend.backend.float32,
+                                               device=self.backend.torch_device)
+            S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
+            S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
             S3 = []
             S4 = []
             S3P = []
-            VS1 = []
-            VS2 = []
+            VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
+            VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
             VS3 = []
             VS3P = []
             VS4 = []
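The width 2 + 2*Jmax*NORIENT of the preallocated `result`/`vresult` tensors matches the flat packing filled by the slice assignments shown further down: slots 0 and 1 hold s0 and its variance, then each scale j3 contributes NORIENT S2 values followed by NORIENT S1 values. A hypothetical index helper (inferred from those slices, not part of the foscat API) makes the layout explicit:

```python
# Hypothetical helper (assumption: inferred from the result[...] slices below).
def s2_slots(j3, NORIENT):
    start = 2 + j3 * 2 * NORIENT      # S2 block for scale j3
    return start, start + NORIENT

def s1_slots(j3, NORIENT):
    start = 2 + j3 * 2 * NORIENT + NORIENT  # S1 block follows the S2 block
    return start, start + NORIENT

print(s2_slots(0, 4), s1_slots(0, 4))  # (2, 6) (6, 10) for NORIENT=4, j3=0
```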
@@ -2749,14 +3796,20 @@ class funct(FOC.FoCUS):
                 s0, l_vs0 = self.masked_mean(
                     self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
                 )
-            vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
-            s0 = self.backend.bk_concat([s0, l_vs0], 1)
+            #vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
+            #s0 = self.backend.bk_concat([s0, l_vs0], 1)
+        result[:,:,0]=s0
+        result[:,:,1]=l_vs0
+        vresult[:,:,0]=l_vs0
+        vresult[:,:,1]=l_vs0
         #### COMPUTE S1, S2, S3 and S4
         nside_j3 = nside  # NSIDE start (nside_j3 = nside / 2^j3)

         # a remettre comme avant
         M1_dic={}
+        M2_dic={}
+
         for j3 in range(Jmax):

             if edge:
@@ -2805,7 +3858,9 @@ class funct(FOC.FoCUS):
             M1_square = conv1 * self.backend.bk_conjugate(
                 conv1
             )  # [Nbatch, Npix_j3, Norient3]
+
             M1 = self.backend.bk_L1(M1_square)  # [Nbatch, Npix_j3, Norient3]
+
             # Store M1_j3 in a dictionary
             M1_dic[j3] = M1

@@ -2821,13 +3876,15 @@ class funct(FOC.FoCUS):
                         s2, vs2 = self.masked_mean(
                             M1_square, vmask, axis=1, rank=j3, calc_var=True
                         )
+                        #s2=self.backend.bk_flatten(self.backend.bk_real(s2))
+                        #vs2=self.backend.bk_flatten(vs2)
                     else:
                         s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)

                 if cond_init_P1_dic:
                     # We fill P1_dic with S2 for normalisation of S3 and S4
-                    P1_dic[j3] = self.backend.bk_real(s2)  # [Nbatch, Nmask, Norient3]
-
+                    P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2))  # [Nbatch, Nmask, Norient3]
+
                 # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
                 if return_data:
                     if S2 is None:
@@ -2849,7 +3906,7 @@ class funct(FOC.FoCUS):
                 else:
                     if norm == "auto":  # Normalize S2
                         s2 /= P1_dic[j3]
-
+                    """
                     S2.append(
                         self.backend.bk_expand_dims(s2, off_S2)
                     )  # Add a dimension for NS2
@@ -2857,7 +3914,11 @@ class funct(FOC.FoCUS):
                         VS2.append(
                             self.backend.bk_expand_dims(vs2, off_S2)
                         )  # Add a dimension for NS2
-
+                    """
+                    #print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
+                    result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
+                    if calc_var:
+                        vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
                 #### S1_auto computation
                 ### Image 1 : S1 = < M1 >_pix
                 # Apply the mask [Nmask, Npix_j3] and average over pixels
@@ -2868,10 +3929,13 @@ class funct(FOC.FoCUS):
                         s1, vs1 = self.masked_mean(
                             M1, vmask, axis=1, rank=j3, calc_var=True
                         )  # [Nbatch, Nmask, Norient3]
+                        #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
+                        #vs1=self.backend.bk_flatten(vs1)
                     else:
                         s1 = self.masked_mean(
                             M1, vmask, axis=1, rank=j3
                         )  # [Nbatch, Nmask, Norient3]
+                        #s1=self.backend.bk_flatten(self.backend.bk_real(s1))

                 if return_data:
                     if out_nside is not None and out_nside < nside_j3:
@@ -2892,6 +3956,10 @@ class funct(FOC.FoCUS):
                     ### Normalize S1
                     if norm is not None:
                         self.div_norm(s1, (P1_dic[j3]) ** 0.5)
+                    result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
+                    if calc_var:
+                        vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
+                    """
                     ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
                     S1.append(
                         self.backend.bk_expand_dims(s1, off_S2)
@@ -2900,6 +3968,7 @@ class funct(FOC.FoCUS):
                         VS1.append(
                             self.backend.bk_expand_dims(vs1, off_S2)
                         )  # Add a dimension for NS1
+                    """

             else:  # Cross
                 ### Make the convolution I2 * Psi_j3
@@ -3048,7 +4117,7 @@ class funct(FOC.FoCUS):

             ###### S3
             nside_j2 = nside_j3
-            for j2 in range(0, j3 + 1):  # j2 <= j3
+            for j2 in range(0,-1): # j3 + 1): # j2 <= j3
                 if return_data:
                     if S4[j3] is None:
                         S4[j3] = {}
@@ -3463,6 +4532,19 @@ class funct(FOC.FoCUS):

         return self.backend.bk_concat(Sout, 2)
         """
+        if calc_var:
+            return result,vresult
+        else:
+            return result
+        if calc_var:
+            for k in S1:
+                print(k.shape,k.dtype)
+            for k in S2:
+                print(k.shape,k.dtype)
+            print(s0.shape,s0.dtype)
+            return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
+        else:
+            return self.backend.bk_concat([s0]+S1+S2,axis=1)
         if not return_data:
             S1 = self.backend.bk_concat(S1, 2)
             S2 = self.backend.bk_concat(S2, 2)
@@ -3526,7 +4608,6 @@ class funct(FOC.FoCUS):
                     backend=self.backend,
                     use_1D=self.use_1D,
                 )
-
     def clean_norm(self):
         self.P1_dic = None
         self.P2_dic = None
@@ -3644,7 +4725,570 @@ class funct(FOC.FoCUS):
|
|
|
3644
4725
|
s4, vmask, axis=1, rank=j2
|
|
3645
4726
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3646
4727
|
return s4
|
|
4728
|
+
|
|
4729
|
+
def computer_filter(self,M,N,J,L):
|
|
4730
|
+
'''
|
|
4731
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4732
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4733
|
+
'''
|
|
4734
|
+
|
|
4735
|
+
filter = np.zeros([J, L, M, N],dtype='complex64')
|
|
4736
|
+
|
|
4737
|
+
slant=4.0 / L
|
|
4738
|
+
|
|
4739
|
+
for j in range(J):
|
|
4740
|
+
|
|
4741
|
+
for l in range(L):
|
|
4742
|
+
|
|
4743
|
+
theta = (int(L-L/2-1)-l) * np.pi / L
|
|
4744
|
+
sigma = 0.8 * 2**j
|
|
4745
|
+
xi = 3.0 / 4.0 * np.pi /2**j
|
|
4746
|
+
|
|
4747
|
+
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
|
|
4748
|
+
R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
|
|
4749
|
+
D = np.array([[1, 0], [0, slant * slant]])
|
|
4750
|
+
curv = np.matmul(R, np.matmul(D, R_inv)) / ( 2 * sigma * sigma)
|
|
4751
|
+
|
|
4752
|
+
gab = np.zeros((M, N), np.complex128)
|
|
4753
|
+
xx = np.empty((2,2, M, N))
|
|
4754
|
+
yy = np.empty((2,2, M, N))
|
|
4755
|
+
|
|
4756
|
+
for ii, ex in enumerate([-1, 0]):
|
|
4757
|
+
for jj, ey in enumerate([-1, 0]):
|
|
4758
|
+
xx[ii,jj], yy[ii,jj] = np.mgrid[
|
|
4759
|
+
ex * M : M + ex * M,
|
|
4760
|
+
ey * N : N + ey * N]
|
|
4761
|
+
|
|
4762
|
+
arg = -(curv[0, 0] * xx * xx + (curv[0, 1] + curv[1, 0]) * xx * yy + curv[1, 1] * yy * yy)
|
|
4763
|
+
argi = arg + 1.j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
|
|
4764
|
+
|
|
4765
|
+
gabi = np.exp(argi).sum((0,1))
|
|
4766
|
+
gab = np.exp(arg).sum((0,1))
|
|
4767
|
+
|
|
4768
|
+
norm_factor = 2 * np.pi * sigma * sigma / slant
|
|
4769
|
+
|
|
4770
|
+
gab = gab / norm_factor
|
|
4771
|
+
|
|
4772
|
+
gabi = gabi / norm_factor
|
|
4773
|
+
|
|
4774
|
+
K = gabi.sum() / gab.sum()
|
|
4775
|
+
|
|
4776
|
+
# Apply the Gaussian
|
|
4777
|
+
filter[j, l] = np.fft.fft2(gabi-K*gab)
|
|
4778
|
+
filter[j,l,0,0]=0.0
|
|
4779
|
+
|
|
4780
|
+
return self.backend.bk_cast(filter)
|
|
4781
|
+
|
|
4782
|
+
# ------------------------------------------------------------------------------------------
|
|
4783
|
+
#
|
|
4784
|
+
# utility functions
|
|
4785
|
+
#
|
|
4786
|
+
# ------------------------------------------------------------------------------------------
|
|
4787
|
+
def cut_high_k_off(self,data_f, dx, dy):
|
|
4788
|
+
'''
|
|
4789
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4790
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4791
|
+
'''
|
|
4792
|
+
if_xodd = (data_f.shape[-2]%2==1)
|
|
4793
|
+
if_yodd = (data_f.shape[-1]%2==1)
|
|
4794
|
+
result = self.backend.backend.cat(
|
|
4795
|
+
(self.backend.backend.cat(
|
|
4796
|
+
( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
|
|
4797
|
+
), -2),
|
|
4798
|
+
self.backend.backend.cat(
|
|
4799
|
+
( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
|
|
4800
|
+
), -2)
|
|
4801
|
+
),-1)
|
|
4802
|
+
return result
|
|
4803
|
+
# ---------------------------------------------------------------------------
|
|
4804
|
+
#
|
|
4805
|
+
# utility functions for computing scattering coef and covariance
|
|
4806
|
+
#
|
|
4807
|
+
# ---------------------------------------------------------------------------
|
|
4808
|
+
|
|
4809
|
+
def get_dxdy(self, j,M,N):
|
|
4810
|
+
'''
|
|
4811
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4812
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4813
|
+
'''
|
|
4814
|
+
dx = int(max( 8, min( np.ceil(M/2**j), M//2 ) ))
|
|
4815
|
+
dy = int(max( 8, min( np.ceil(N/2**j), N//2 ) ))
|
|
4816
|
+
return dx, dy
|
|
4817
|
+
|
|
4818
|
+
|
|
4819
|
+
|
|
4820
|
+
def get_edge_masks(self,M, N, J, d0=1):
|
|
4821
|
+
'''
|
|
4822
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4823
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4824
|
+
'''
|
|
4825
|
+
edge_masks = self.backend.backend.empty((J, M, N))
|
|
4826
|
+
X, Y = self.backend.backend.meshgrid(self.backend.backend.arange(M), self.backend.backend.arange(N), indexing='ij')
|
|
4827
|
+
for j in range(J):
|
|
4828
|
+
edge_dx = min(M//4, 2**j*d0)
|
|
4829
|
+
edge_dy = min(N//4, 2**j*d0)
|
|
4830
|
+
edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
|
|
4831
|
+
return edge_masks.to(self.backend.torch_device)
|
|
4832
|
+
|
|
4833
|
+
# ---------------------------------------------------------------------------
|
|
4834
|
+
#
|
|
4835
|
+
# scattering cov
|
|
4836
|
+
#
|
|
4837
|
+
# ---------------------------------------------------------------------------
|
|
4838
|
+
def scattering_cov(
|
|
4839
|
+
self, data, Jmax=None,
|
|
4840
|
+
if_large_batch=False,
|
|
4841
|
+
S4_criteria=None,
|
|
4842
|
+
use_ref=False,
|
|
4843
|
+
normalization='S2',
|
|
4844
|
+
edge=False,
|
|
4845
|
+
pseudo_coef=1,
|
|
4846
|
+
get_variance=False,
|
|
4847
|
+
ref_sigma=None,
|
|
4848
|
+
iso_ang=False
|
|
4849
|
+
):
|
|
4850
|
+
'''
|
|
4851
|
+
Calculates the scattering correlations for a batch of images, including:
|
|
4852
|
+
|
|
4853
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4854
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4855
|
+
|
|
4856
|
+
orig. x orig.:
|
|
4857
|
+
P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
|
|
4858
|
+
orig. x modulus:
|
|
4859
|
+
C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
|
|
4860
|
+
when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
|
|
4861
|
+
when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
|
|
4862
|
+
modulus x modulus:
|
|
4863
|
+
C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
|
|
4864
|
+
C11 = C11_pre_norm / factor
|
|
4865
|
+
when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
|
|
4866
|
+
when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
|
|
4867
|
+
modulus x modulus (auto):
|
|
4868
|
+
P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
|
|
4869
|
+
Parameters
|
|
4870
|
+
----------
|
|
4871
|
+
data : numpy array or torch tensor
|
|
4872
|
+
image set, with size [N_image, x-sidelength, y-sidelength]
|
|
4873
|
+
if_large_batch : Bool (=False)
|
|
4874
|
+
It is recommended to use "False" unless one meets a memory issue
|
|
4875
|
+
C11_criteria : str or None (=None)
|
|
4876
|
+
Only C11 coefficients that satisfy this criteria will be computed.
|
|
4877
|
+
Any expressions of j1, j2, and j3 that can be evaluated as a Bool
|
|
4878
|
+
is accepted.The default "None" corresponds to "j1 <= j2 <= j3".
|
|
4879
|
+
use_ref : Bool (=False)
|
|
4880
|
+
When normalizing, whether or not to use the normalization factor
|
|
4881
|
+
computed from a reference field. For just computing the statistics,
|
|
4882
|
+
the default is False. However, for synthesis, set it to "True" will
|
|
4883
|
+
stablize the optimization process.
|
|
4884
|
+
normalization : str 'P00' or 'P11' (='P00')
|
|
4885
|
+
Whether 'P00' or 'P11' is used as the normalization factor for C01
|
|
4886
|
+
and C11.
|
|
4887
|
+
remove_edge : Bool (=False)
|
|
4888
|
+
If true, the edge region with a width of rougly the size of the largest
|
|
4889
|
+
wavelet involved is excluded when taking the global average to obtain
|
|
4890
|
+
the scattering coefficients.
|
|
4891
|
+
|
|
4892
|
+
Returns
|
|
4893
|
+
-------
|
|
4894
|
+
'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
|
|
4895
|
+
the power in each wavelet bands (the orig. x orig. term)
|
|
4896
|
+
'S1' : torch tensor with size [N_image, J, L] (# image, j1, l1)
|
|
4897
|
+
the 1st-order scattering coefficients, i.e., the mean of wavelet modulus fields
|
|
4898
|
+
'C01' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
|
|
4899
|
+
the orig. x modulus terms. Elements with j1 < j2 are all set to np.nan and not computed.
|
|
4900
|
+
'C11' : torch tensor with size [N_image, J, J, J, L, L, L] (# image, j1, j2, j3, l1, l2, l3)
|
|
4901
|
+
the modulus x modulus terms. Elements not satisfying j1 <= j2 <= j3 and the conditions
|
|
4902
|
+
defined in 'C11_criteria' are all set to np.nan and not computed.
|
|
4903
|
+
'C11_pre_norm' and 'C11_pre_norm_iso': pre-normalized modulus x modulus terms.
|
|
4904
|
+
'P11' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
|
|
4905
|
+
the modulus x modulus terms with the two wavelets within modulus the same. Elements not following
|
|
4906
|
+
j1 <= j3 are set to np.nan and not computed.
|
|
4907
|
+
'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
|
|
4908
|
+
'P11' averaged over l1 while keeping l2-l1 constant.
|
|
4909
|
+
'''
|
|
4910
|
+
+        if S4_criteria is None:
+            S4_criteria = 'j2>=j1'
+
+        # determine jmax and nside corresponding to the input map
+        im_shape = data.shape
+        if self.use_2D:
+            if len(data.shape) == 2:
+                nside = np.min([im_shape[0], im_shape[1]])
+                M,N = im_shape[0],im_shape[1]
+                N_image = 1
+            else:
+                nside = np.min([im_shape[1], im_shape[2]])
+                M,N = im_shape[1],im_shape[2]
+                N_image = data.shape[0]
+            J = int(np.log(nside) / np.log(2))-1 # Number of j scales
+        elif self.use_1D:
+            if len(data.shape) == 2:
+                npix = int(im_shape[1]) # Number of pixels
+                N_image = 1
+            else:
+                npix = int(im_shape[0]) # Number of pixels
+                N_image = data.shape[0]
+
+            nside = int(npix)
+
+            J = int(np.log(nside) / np.log(2))-1 # Number of j scales
+        else:
+            if len(data.shape) == 2:
+                npix = int(im_shape[1]) # Number of pixels
+                N_image = 1
+            else:
+                npix = int(im_shape[0]) # Number of pixels
+                N_image = data.shape[0]
+
+            nside = int(np.sqrt(npix // 12))
 
+            J = int(np.log(nside) / np.log(2)) # Number of j scales
+
+        if Jmax is None:
+            Jmax = J # Number of steps for the loop on scales
+        if Jmax>J:
+            print('==========\n\n')
+            print('The Jmax you requested is larger than the number of scales allowed by the data size, which may cause problems while computing the scattering transform.')
+            print('\n\n==========')
+
+        L=self.NORIENT
+
+        if (M,N,J,L) not in self.filters_set:
+            self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L)
+
+        filters_set = self.filters_set[(M,N,J,L)]
+
+        #weight = self.weight
+        if use_ref:
+            if normalization=='S2':
+                ref_S2 = self.ref_scattering_cov_S2
+            else:
+                ref_P11 = self.ref_scattering_cov['P11']
+
+        # convert numpy array input into self.backend.bk_ tensors
+        data = self.backend.bk_cast(data)
+        data_f = self.backend.bk_fftn(data, dim=(-2,-1))
+
+        # initialize tensors for scattering coefficients
+        S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
+        S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
+
+        Ndata_S3 = J*(J+1)//2
+        Ndata_S4 = J*(J+1)*(J+2)//6
+        J_S4={}
+
+        S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
+        S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
+        S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
+
+        # variance
+        if get_variance:
+            S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
+            S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
+            S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
+            S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
+
+        if iso_ang:
+            S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
+            S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
+            if get_variance:
+                S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
+                S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
+
+        # calculate scattering fields
+        if self.use_2D:
+            if len(data.shape) == 2:
+                I1 = self.backend.bk_ifftn(
+                    data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
+                ).abs()
+            else:
+                I1 = self.backend.bk_ifftn(
+                    data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
+                ).abs()
+        elif self.use_1D:
+            if len(data.shape) == 1:
+                I1 = self.backend.bk_ifftn(
+                    data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
+                ).abs()
+            else:
+                I1 = self.backend.bk_ifftn(
+                    data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
+                ).abs()
+        else:
+            print('todo')
+
+        I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
+
+        #
+        if edge:
+            if (M,N,J) not in self.edge_masks:
+                self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
+            edge_mask = self.edge_masks[(M,N,J)][:,None,:,:]
+            edge_mask = edge_mask / edge_mask.mean((-2,-1))[:,:,None,None]
+        else:
+            edge_mask = 1
+        S2 = (I1**2 * edge_mask).mean((-2,-1))
+        S1 = (I1 * edge_mask).mean((-2,-1))
+
+        if get_variance:
+            S2_sigma = (I1**2 * edge_mask).std((-2,-1))
+            S1_sigma = (I1 * edge_mask).std((-2,-1))
+
+        if pseudo_coef != 1:
+            I1 = I1**pseudo_coef
+
+        Ndata_S3=0
+        Ndata_S4=0
+
+        # calculate the covariance and correlations of the scattering fields
+        # only use the low-k Fourier coefs when calculating large-j scattering coefs.
+        for j3 in range(0,J):
+            J_S4[j3]=Ndata_S4
+
+            dx3, dy3 = self.get_dxdy(j3,M,N)
+            I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
+            data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
+            if edge:
+                I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
+                data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
+            wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
+            _, M3, N3 = wavelet_f3.shape
+            wavelet_f3_squared = wavelet_f3**2
+            edge_dx = min(4, int(2**j3*dx3*2/M))
+            edge_dy = min(4, int(2**j3*dy3*2/N))
+            # a normalization change due to the cutoff of frequency space
+            fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
+            for j2 in range(0,j3+1):
+                I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
+                I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
+                if edge:
+                    I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
+                    I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
+                if use_ref:
+                    if normalization=='P11':
+                        norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
+                    if normalization=='S2':
+                        norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
+                else:
+                    if normalization=='P11':
+                        # [N_image,l2,l3,x,y]
+                        P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
+                        norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
+                    if normalization=='S2':
+                        norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
+
+                if not edge:
+                    S3[:,Ndata_S3,:,:] = (
+                        data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
+                    ).mean((-2,-1)) * fft_factor / norm_factor_S3
+
+                    if get_variance:
+                        S3_sigma[:,Ndata_S3,:,:] = (
+                            data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
+                        ).std((-2,-1)) * fft_factor / norm_factor_S3
+                else:
+
+                    S3[:,Ndata_S3,:,:] = (
+                        data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
+                    )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
+                    if get_variance:
+                        S3_sigma[:,Ndata_S3,:,:] = (
+                            data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
+                        )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
+                Ndata_S3+=1
+                if j2 <= j3:
+                    beg_n=Ndata_S4
+                    for j1 in range(0, j2+1):
+                        if eval(S4_criteria):
+                            if not edge:
+                                if not if_large_batch:
+                                    # [N_image,l1,l2,l3,x,y]
+                                    S4_pre_norm[:,Ndata_S4,:,:,:] = (
+                                        I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
+                                        self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
+                                    ).mean((-2,-1)) * fft_factor
+                                    if get_variance:
+                                        S4_sigma[:,Ndata_S4,:,:,:] = (
+                                            I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
+                                            self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
+                                        ).std((-2,-1)) * fft_factor
+                                else:
+                                    for l1 in range(L):
+                                        # [N_image,l2,l3,x,y]
+                                        S4_pre_norm[:,Ndata_S4,l1,:,:] = (
+                                            I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
+                                            self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
+                                        ).mean((-2,-1)) * fft_factor
+                                        if get_variance:
+                                            S4_sigma[:,Ndata_S4,l1,:,:] = (
+                                                I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
+                                                self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
+                                            ).std((-2,-1)) * fft_factor
+                            else:
+                                if not if_large_batch:
+                                    # [N_image,l1,l2,l3,x,y]
+                                    S4_pre_norm[:,Ndata_S4,:,:,:] = (
+                                        I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
+                                            I12_w3_2_small.view(N_image,1,L,L,M3,N3)
+                                        )
+                                    )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
+                                    if get_variance:
+                                        S4_sigma[:,Ndata_S4,:,:,:] = (
+                                            I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
+                                                I12_w3_2_small.view(N_image,1,L,L,M3,N3)
+                                            )
+                                        )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
+                                else:
+                                    for l1 in range(L):
+                                        # [N_image,l2,l3,x,y]
+                                        S4_pre_norm[:,Ndata_S4,l1,:,:] = (
+                                            I1_small[:,j1,l1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
+                                                I12_w3_2_small.view(N_image,L,L,M3,N3)
+                                            )
+                                        )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
+                                        if get_variance:
+                                            S4_sigma[:,Ndata_S4,l1,:,:] = (
+                                                I1_small[:,j1,l1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
+                                                    I12_w3_2_small.view(N_image,L,L,M3,N3)
+                                                )
+                                            )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
+
+                            Ndata_S4+=1
+
+                    if normalization=='S2':
+                        if use_ref:
+                            P = (ref_S2[:,j3,:,None,None] * ref_S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
+                        else:
+                            P = (S2[:,j3,:,None,None] * S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
+
+                        S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:]/P
+
+                        if get_variance:
+                            S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:] / P
+
+        """
+        # define P11 from diagonals of S4
+        for j1 in range(J):
+            for l1 in range(L):
+                P11[:,j1,:,l1,:] = S4_pre_norm[:,j1,j1,:,l1,l1,:].real
+
+
+        if normalization=='S4':
+            if use_ref:
+                P = ref_P11
+            else:
+                P = P11
+            #.view(N_image,J,1,J,L,1,L) * .view(N_image,1,J,J,1,L,L)
+            S4 = S4_pre_norm / (
+                P[:,:,None,:,:,None,:] * P[:,None,:,:,None,:,:]
+            )**(0.5*pseudo_coef)
+
+
+
+
+        # get a single, flattened data vector for_synthesis
+        select_and_index = self.get_scattering_index(J, L, normalization, S4_criteria)
+        index_for_synthesis = select_and_index['index_for_synthesis']
+        index_for_synthesis_iso = select_and_index['index_for_synthesis_iso']
+        """
+        # average over l1 to obtain simple isotropic statistics
+        if iso_ang:
+            S2_iso = S2.mean(-1)
+            S1_iso = S1.mean(-1)
+            for l1 in range(L):
+                for l2 in range(L):
+                    S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
+                    for l3 in range(L):
+                        S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
+            S3_iso /= L; S4_iso /= L
+
+            if get_variance:
+                S2_sigma_iso = S2_sigma.mean(-1)
+                S1_sigma_iso = S1_sigma.mean(-1)
+                for l1 in range(L):
+                    for l2 in range(L):
+                        S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
+                        for l3 in range(L):
+                            S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
+                S3_sigma_iso /= L; S4_sigma_iso /= L
+
+        mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
+        std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
+        mean_data[:,0]=data.mean((-2,-1))
+        std_data[:,0]=data.std((-2,-1))
+
+        if get_variance:
+            ref_sigma={}
+            if iso_ang:
+                ref_sigma['std_data']=std_data
+                ref_sigma['S1_sigma']=S1_sigma_iso
+                ref_sigma['S2_sigma']=S2_sigma_iso
+                ref_sigma['S3_sigma']=S3_sigma_iso
+                ref_sigma['S4_sigma']=S4_sigma_iso
+            else:
+                ref_sigma['std_data']=std_data
+                ref_sigma['S1_sigma']=S1_sigma
+                ref_sigma['S2_sigma']=S2_sigma
+                ref_sigma['S3_sigma']=S3_sigma
+                ref_sigma['S4_sigma']=S4_sigma
+
+        if iso_ang:
+            if ref_sigma is not None:
+                for_synthesis = self.backend.backend.cat((
+                    mean_data/ref_sigma['std_data'],
+                    std_data/ref_sigma['std_data'],
+                    (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
+                    (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
+                    (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
+                    (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
+                    (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
+                    (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
+                ),dim=-1)
+            else:
+                for_synthesis = self.backend.backend.cat((
+                    mean_data/std_data,
+                    std_data,
+                    S2_iso.reshape((N_image, -1)).log(),
+                    S1_iso.reshape((N_image, -1)).log(),
+                    S3_iso.reshape((N_image, -1)).real,
+                    S3_iso.reshape((N_image, -1)).imag,
+                    S4_iso.reshape((N_image, -1)).real,
+                    S4_iso.reshape((N_image, -1)).imag,
+                ),dim=-1)
+        else:
+            if ref_sigma is not None:
+                for_synthesis = self.backend.backend.cat((
+                    mean_data/ref_sigma['std_data'],
+                    std_data/ref_sigma['std_data'],
+                    (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
+                    (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
+                    (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
+                    (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
+                    (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
+                    (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
+                ),dim=-1)
+            else:
+                for_synthesis = self.backend.backend.cat((
+                    mean_data/std_data,
+                    std_data,
+                    S2.reshape((N_image, -1)).log(),
+                    S1.reshape((N_image, -1)).log(),
+                    S3.reshape((N_image, -1)).real,
+                    S3.reshape((N_image, -1)).imag,
+                    S4.reshape((N_image, -1)).real,
+                    S4.reshape((N_image, -1)).imag,
+                ),dim=-1)
+
+        if not use_ref:
+            self.ref_scattering_cov_S2=S2
+
+        if get_variance:
+            return for_synthesis,ref_sigma
+
+        return for_synthesis
+
+
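The block above completes the new `scattering_cov` path: for each image it flattens the map mean and standard deviation, log(S2), log(S1), and the real and imaginary parts of S3 and S4 into a single vector, optionally after the rotation average over l1 and optionally whitened by `ref_sigma`. A minimal usage sketch follows; the constructor arguments and map shapes are assumptions for illustration, since only `scattering_cov` and its keywords appear in this diff.

import numpy as np
import foscat.scat_cov as sc

# Assumed constructor options: this diff only shows that `funct` exposes
# NORIENT, use_2D and the scattering_cov method used below.
op = sc.funct(NORIENT=4, use_2D=True)
maps = np.random.randn(8, 256, 256)  # hypothetical batch, N_image=8

# With get_variance=True the method returns the flattened coefficient
# vector (shape (8, Ncoef)) plus a dict of per-coefficient scatters.
coeffs, sigma = op.scattering_cov(maps, get_variance=True, iso_ang=True, edge=True)

The `sigma` dict is what the synthesis code further down feeds back through `ref_sigma=` to whiten its loss.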
     def to_gaussian(self,x):
         from scipy.stats import norm
         from scipy.interpolate import interp1d
@@ -4021,8 +5665,12 @@ class funct(FOC.FoCUS):
             image_target,
             nstep=4,
             seed=1234,
-
+            Jmax=None,
+            edge=False,
             to_gaussian=True,
+            use_variance=False,
+            synthesised_N=1,
+            iso_ang=False,
             EVAL_FREQUENCY=100,
             NUM_EPOCHS = 300):
 
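The hunk above threads five new keyword arguments through the synthesis entry point. A hypothetical call exercising them (the enclosing method's `def` line sits outside this hunk, so the name `synthesis` and the receiver `op` are assumptions):

omap = op.synthesis(image_target,
                    nstep=4,
                    seed=1234,
                    Jmax=6,             # cap on the number of wavelet scales
                    edge=True,          # trim map borders when averaging coefficients
                    to_gaussian=True,
                    use_variance=True,  # divide the loss terms by reference sigmas
                    synthesised_N=4,    # optimise a batch of 4 random starts at once
                    iso_ang=True)       # match rotation-averaged statistics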
@@ -4032,13 +5680,16 @@ class funct(FOC.FoCUS):
         def The_loss(u,scat_operator,args):
             ref = args[0]
             sref = args[1]
+            use_v= args[2]
 
             # compute the scattering covariance of the current synthesised map u
-
+            if use_v:
+                learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,ref_sigma=sref,use_ref=True,iso_ang=iso_ang))
+            else:
+                learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,use_ref=True,iso_ang=iso_ang))
 
             # take the difference with the reference coefficients
-            loss=scat_operator.
-
+            loss=scat_operator.backend.bk_reduce_mean(scat_operator.backend.bk_square((learn-ref)))
             return loss
 
         if to_gaussian:
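Stripped of the backend indirection, the rewritten The_loss is a plain mean-squared error between the batch-averaged scattering coefficients of the current map and those of the target. A backend-free sketch, with numpy standing in for the foscat backend:

import numpy as np

def the_loss_sketch(u, scat, ref):
    # `scat` abstracts scattering_cov(...) followed by reduce_mean_batch;
    # `ref` is the target coefficient vector. The scalar returned here is
    # what the optimiser drives toward zero.
    learn = scat(u)
    return np.mean((learn - ref) ** 2)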
@@ -4078,33 +5729,35 @@ class funct(FOC.FoCUS):
             if k==0:
                 np.random.seed(seed)
                 if self.use_2D:
-                    imap=np.random.randn(
+                    imap=np.random.randn(synthesised_N,
                                          tmp[k].shape[1],
                                          tmp[k].shape[2])
                 else:
-                    imap=np.random.randn(
+                    imap=np.random.randn(synthesised_N,
                                          tmp[k].shape[1])
             else:
-
-                # if the kernel size is bigger than 3 increase the binning before smoothing
+                # Increase the resolution between each step
                 if self.use_2D:
                     imap = self.up_grade(
-                        omap, imap.shape[
+                        omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
                     )
                 elif self.use_1D:
-                    imap = self.up_grade(omap, imap.shape[
+                    imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
                 else:
-                    imap = self.up_grade(omap, l_nside, axis=
+                    imap = self.up_grade(omap, l_nside, axis=1)
 
             # compute the coefficients for the target image
-
-
+            if use_variance:
+                ref,sref=self.scattering_cov(tmp[k],get_variance=True,edge=edge,Jmax=Jmax,iso_ang=iso_ang)
+            else:
+                ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
+                sref=ref
+
             # compute the mean of the population; does nothing if only one map is given
             ref=self.reduce_mean_batch(ref)
-
-
+
             # define a loss to minimize
-            loss=synthe.Loss(The_loss,self,ref,sref)
+            loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
 
             sy = synthe.Synthesis([loss])
 
@@ -4121,6 +5774,8 @@ class funct(FOC.FoCUS):
             omap=sy.run(imap,
                         EVAL_FREQUENCY=EVAL_FREQUENCY,
                         NUM_EPOCHS = NUM_EPOCHS)
+
+
 
             t2=time.time()
             print('Total computation %.2fs'%(t2-t1))
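The loop patched above is the usual coarse-to-fine schedule: step 0 starts from white noise on the coarsest grid, and every later step doubles the resolution of the previous result via `up_grade` before re-optimising. Assuming the per-step targets tmp[k] halve in size down from the full map, the 2D grid widths follow this progression:

# Sizes only; a sketch of the schedule implied by the `* 2` up_grade calls.
nstep, n = 4, 256
sizes = [n // 2 ** (nstep - 1 - k) for k in range(nstep)]
print(sizes)  # [32, 64, 128, 256]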
@@ -4128,7 +5783,7 @@ class funct(FOC.FoCUS):
         if to_gaussian:
             omap=self.from_gaussian(omap)
 
-        if axis==0:
+        if axis==0 and synthesised_N==1:
             return omap[0]
         else:
             return omap
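The final guard changes the return convention: the leading batch axis is now stripped only when a single synthesis was requested, so multi-start runs keep their batch dimension. A sketch of the resulting shapes for a single 2D target (same assumed method name as above):

single = op.synthesis(image_target)                   # synthesised_N=1 -> one map, shape (ny, nx)
batch  = op.synthesis(image_target, synthesised_N=8)  # -> stack of maps, shape (8, ny, nx)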