foscat 3.7.0__py3-none-any.whl → 3.7.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/FoCUS.py +4 -1
- foscat/backend.py +62 -10
- foscat/scat_cov.py +1851 -31
- foscat/scat_cov2D.py +65 -25
- {foscat-3.7.0.dist-info → foscat-3.7.2.dist-info}/METADATA +2 -2
- {foscat-3.7.0.dist-info → foscat-3.7.2.dist-info}/RECORD +9 -9
- {foscat-3.7.0.dist-info → foscat-3.7.2.dist-info}/WHEEL +1 -1
- /foscat-3.7.0.dist-info/LICENCE → /foscat-3.7.2.dist-info/LICENSE +0 -0
- {foscat-3.7.0.dist-info → foscat-3.7.2.dist-info}/top_level.txt +0 -0
foscat/scat_cov.py
CHANGED
|
@@ -92,7 +92,7 @@ class scat_cov:
|
|
|
92
92
|
)
|
|
93
93
|
|
|
94
94
|
def conv2complex(self, val):
|
|
95
|
-
if val.dtype == "complex64"
|
|
95
|
+
if val.dtype == "complex64" or val.dtype=="complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
|
|
96
96
|
return val
|
|
97
97
|
else:
|
|
98
98
|
return self.backend.bk_complex(val, 0 * val)
|
|
@@ -2540,7 +2540,1050 @@ class funct(FOC.FoCUS):
|
|
|
2540
2540
|
-------
|
|
2541
2541
|
S1, S2, S3, S4 normalized
|
|
2542
2542
|
"""
|
|
2543
|
+
|
|
2544
|
+
return_data = self.return_data
|
|
2545
|
+
|
|
2546
|
+
# Check input consistency
|
|
2547
|
+
if image2 is not None:
|
|
2548
|
+
if list(image1.shape) != list(image2.shape):
|
|
2549
|
+
print(
|
|
2550
|
+
"The two input image should have the same size to eval Scattering Covariance"
|
|
2551
|
+
)
|
|
2552
|
+
return None
|
|
2553
|
+
if mask is not None:
|
|
2554
|
+
if self.use_2D:
|
|
2555
|
+
if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
|
|
2556
|
+
print(
|
|
2557
|
+
"The LAST 2 COLUMNs of the mask should have the same size ",
|
|
2558
|
+
mask.shape,
|
|
2559
|
+
"than the input image ",
|
|
2560
|
+
image1.shape,
|
|
2561
|
+
"to eval Scattering Covariance",
|
|
2562
|
+
)
|
|
2563
|
+
return None
|
|
2564
|
+
else:
|
|
2565
|
+
if image1.shape[-1] != mask.shape[1]:
|
|
2566
|
+
print(
|
|
2567
|
+
"The LAST COLUMN of the mask should have the same size ",
|
|
2568
|
+
mask.shape,
|
|
2569
|
+
"than the input image ",
|
|
2570
|
+
image1.shape,
|
|
2571
|
+
"to eval Scattering Covariance",
|
|
2572
|
+
)
|
|
2573
|
+
return None
|
|
2574
|
+
if self.use_2D and len(image1.shape) < 2:
|
|
2575
|
+
print(
|
|
2576
|
+
"To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
|
|
2577
|
+
)
|
|
2578
|
+
return None
|
|
2579
|
+
|
|
2580
|
+
### AUTO OR CROSS
|
|
2581
|
+
cross = False
|
|
2582
|
+
if image2 is not None:
|
|
2583
|
+
cross = True
|
|
2584
|
+
|
|
2585
|
+
### PARAMETERS
|
|
2586
|
+
axis = 1
|
|
2587
|
+
# determine jmax and nside corresponding to the input map
|
|
2588
|
+
im_shape = image1.shape
|
|
2589
|
+
if self.use_2D:
|
|
2590
|
+
if len(image1.shape) == 2:
|
|
2591
|
+
nside = np.min([im_shape[0], im_shape[1]])
|
|
2592
|
+
npix = im_shape[0] * im_shape[1] # Number of pixels
|
|
2593
|
+
x1 = im_shape[0]
|
|
2594
|
+
x2 = im_shape[1]
|
|
2595
|
+
else:
|
|
2596
|
+
nside = np.min([im_shape[1], im_shape[2]])
|
|
2597
|
+
npix = im_shape[1] * im_shape[2] # Number of pixels
|
|
2598
|
+
x1 = im_shape[1]
|
|
2599
|
+
x2 = im_shape[2]
|
|
2600
|
+
J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
|
|
2601
|
+
elif self.use_1D:
|
|
2602
|
+
if len(image1.shape) == 2:
|
|
2603
|
+
npix = int(im_shape[1]) # Number of pixels
|
|
2604
|
+
else:
|
|
2605
|
+
npix = int(im_shape[0]) # Number of pixels
|
|
2606
|
+
|
|
2607
|
+
nside = int(npix)
|
|
2608
|
+
|
|
2609
|
+
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2610
|
+
else:
|
|
2611
|
+
if len(image1.shape) == 2:
|
|
2612
|
+
npix = int(im_shape[1]) # Number of pixels
|
|
2613
|
+
else:
|
|
2614
|
+
npix = int(im_shape[0]) # Number of pixels
|
|
2615
|
+
|
|
2616
|
+
nside = int(np.sqrt(npix // 12))
|
|
2617
|
+
|
|
2618
|
+
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
2619
|
+
|
|
2620
|
+
if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
|
|
2621
|
+
J-=1
|
|
2622
|
+
if Jmax is None:
|
|
2623
|
+
Jmax = J # Number of steps for the loop on scales
|
|
2624
|
+
if Jmax>J:
|
|
2625
|
+
print('==========\n\n')
|
|
2626
|
+
print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
|
|
2627
|
+
print('\n\n==========')
|
|
2628
|
+
|
|
2629
|
+
|
|
2630
|
+
### LOCAL VARIABLES (IMAGES and MASK)
|
|
2631
|
+
if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
|
|
2632
|
+
I1 = self.backend.bk_cast(
|
|
2633
|
+
self.backend.bk_expand_dims(image1, 0)
|
|
2634
|
+
) # Local image1 [Nbatch, Npix]
|
|
2635
|
+
if cross:
|
|
2636
|
+
I2 = self.backend.bk_cast(
|
|
2637
|
+
self.backend.bk_expand_dims(image2, 0)
|
|
2638
|
+
) # Local image2 [Nbatch, Npix]
|
|
2639
|
+
else:
|
|
2640
|
+
I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
|
|
2641
|
+
if cross:
|
|
2642
|
+
I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
|
|
2643
|
+
|
|
2644
|
+
if mask is None:
|
|
2645
|
+
if self.use_2D:
|
|
2646
|
+
vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
|
|
2647
|
+
else:
|
|
2648
|
+
vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
|
|
2649
|
+
else:
|
|
2650
|
+
vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
|
|
2651
|
+
|
|
2652
|
+
if self.KERNELSZ > 3 and not self.use_2D:
|
|
2653
|
+
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2654
|
+
if self.use_2D:
|
|
2655
|
+
vmask = self.up_grade(
|
|
2656
|
+
vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
|
|
2657
|
+
)
|
|
2658
|
+
I1 = self.up_grade(
|
|
2659
|
+
I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
|
|
2660
|
+
)
|
|
2661
|
+
if cross:
|
|
2662
|
+
I2 = self.up_grade(
|
|
2663
|
+
I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
|
|
2664
|
+
)
|
|
2665
|
+
elif self.use_1D:
|
|
2666
|
+
vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
|
|
2667
|
+
I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
|
|
2668
|
+
if cross:
|
|
2669
|
+
I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
|
|
2670
|
+
else:
|
|
2671
|
+
I1 = self.up_grade(I1, nside * 2, axis=axis)
|
|
2672
|
+
vmask = self.up_grade(vmask, nside * 2, axis=1)
|
|
2673
|
+
if cross:
|
|
2674
|
+
I2 = self.up_grade(I2, nside * 2, axis=axis)
|
|
2675
|
+
|
|
2676
|
+
if self.KERNELSZ > 5 and not self.use_2D:
|
|
2677
|
+
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
2678
|
+
if self.use_2D:
|
|
2679
|
+
vmask = self.up_grade(
|
|
2680
|
+
vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
|
|
2681
|
+
)
|
|
2682
|
+
I1 = self.up_grade(
|
|
2683
|
+
I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
|
|
2684
|
+
)
|
|
2685
|
+
if cross:
|
|
2686
|
+
I2 = self.up_grade(
|
|
2687
|
+
I2,
|
|
2688
|
+
I2.shape[axis] * 2,
|
|
2689
|
+
axis=axis,
|
|
2690
|
+
nouty=I2.shape[axis + 1] * 2,
|
|
2691
|
+
)
|
|
2692
|
+
elif self.use_1D:
|
|
2693
|
+
vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
|
|
2694
|
+
I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
|
|
2695
|
+
if cross:
|
|
2696
|
+
I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
|
|
2697
|
+
else:
|
|
2698
|
+
I1 = self.up_grade(I1, nside * 4, axis=axis)
|
|
2699
|
+
vmask = self.up_grade(vmask, nside * 4, axis=1)
|
|
2700
|
+
if cross:
|
|
2701
|
+
I2 = self.up_grade(I2, nside * 4, axis=axis)
|
|
2702
|
+
|
|
2703
|
+
# Normalize the masks because they have different pixel numbers
|
|
2704
|
+
# vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
|
|
2705
|
+
|
|
2706
|
+
### INITIALIZATION
|
|
2707
|
+
# Coefficients
|
|
2708
|
+
if return_data:
|
|
2709
|
+
S1 = {}
|
|
2710
|
+
S2 = {}
|
|
2711
|
+
S3 = {}
|
|
2712
|
+
S3P = {}
|
|
2713
|
+
S4 = {}
|
|
2714
|
+
else:
|
|
2715
|
+
S1 = []
|
|
2716
|
+
S2 = []
|
|
2717
|
+
S3 = []
|
|
2718
|
+
S4 = []
|
|
2719
|
+
S3P = []
|
|
2720
|
+
VS1 = []
|
|
2721
|
+
VS2 = []
|
|
2722
|
+
VS3 = []
|
|
2723
|
+
VS3P = []
|
|
2724
|
+
VS4 = []
|
|
2725
|
+
|
|
2726
|
+
off_S2 = -2
|
|
2727
|
+
off_S3 = -3
|
|
2728
|
+
off_S4 = -4
|
|
2729
|
+
if self.use_1D:
|
|
2730
|
+
off_S2 = -1
|
|
2731
|
+
off_S3 = -1
|
|
2732
|
+
off_S4 = -1
|
|
2733
|
+
|
|
2734
|
+
# S2 for normalization
|
|
2735
|
+
cond_init_P1_dic = (norm == "self") or (
|
|
2736
|
+
(norm == "auto") and (self.P1_dic is None)
|
|
2737
|
+
)
|
|
2738
|
+
if norm is None:
|
|
2739
|
+
pass
|
|
2740
|
+
elif cond_init_P1_dic:
|
|
2741
|
+
P1_dic = {}
|
|
2742
|
+
if cross:
|
|
2743
|
+
P2_dic = {}
|
|
2744
|
+
elif (norm == "auto") and (self.P1_dic is not None):
|
|
2745
|
+
P1_dic = self.P1_dic
|
|
2746
|
+
if cross:
|
|
2747
|
+
P2_dic = self.P2_dic
|
|
2748
|
+
|
|
2749
|
+
if return_data:
|
|
2750
|
+
s0 = I1
|
|
2751
|
+
if out_nside is not None:
|
|
2752
|
+
s0 = self.backend.bk_reduce_mean(
|
|
2753
|
+
self.backend.bk_reshape(
|
|
2754
|
+
s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
|
|
2755
|
+
),
|
|
2756
|
+
2,
|
|
2757
|
+
)
|
|
2758
|
+
else:
|
|
2759
|
+
if not cross:
|
|
2760
|
+
s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
|
|
2761
|
+
else:
|
|
2762
|
+
s0, l_vs0 = self.masked_mean(
|
|
2763
|
+
self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
|
|
2764
|
+
)
|
|
2765
|
+
vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
|
|
2766
|
+
s0 = self.backend.bk_concat([s0, l_vs0], 1)
|
|
2767
|
+
#### COMPUTE S1, S2, S3 and S4
|
|
2768
|
+
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
2769
|
+
|
|
2770
|
+
# a remettre comme avant
|
|
2771
|
+
M1_dic={}
|
|
2772
|
+
M2_dic={}
|
|
2773
|
+
|
|
2774
|
+
for j3 in range(Jmax):
|
|
2775
|
+
|
|
2776
|
+
if edge:
|
|
2777
|
+
if self.mask_mask is None:
|
|
2778
|
+
self.mask_mask={}
|
|
2779
|
+
if self.use_2D:
|
|
2780
|
+
if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
|
|
2781
|
+
mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
|
|
2782
|
+
mask_mask[0,
|
|
2783
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1,
|
|
2784
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
2785
|
+
self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
|
|
2786
|
+
vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
|
|
2787
|
+
#print(self.KERNELSZ//2,vmask,mask_mask)
|
|
2788
|
+
|
|
2789
|
+
if self.use_1D:
|
|
2790
|
+
if (vmask.shape[1]) not in self.mask_mask:
|
|
2791
|
+
mask_mask=np.zeros([1,vmask.shape[1]])
|
|
2792
|
+
mask_mask[0,
|
|
2793
|
+
self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
|
|
2794
|
+
self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
|
|
2795
|
+
vmask=vmask*self.mask_mask[(vmask.shape[1])]
|
|
2796
|
+
|
|
2797
|
+
if return_data:
|
|
2798
|
+
S3[j3] = None
|
|
2799
|
+
S3P[j3] = None
|
|
2800
|
+
|
|
2801
|
+
if S4 is None:
|
|
2802
|
+
S4 = {}
|
|
2803
|
+
S4[j3] = None
|
|
2804
|
+
|
|
2805
|
+
####### S1 and S2
|
|
2806
|
+
### Make the convolution I1 * Psi_j3
|
|
2807
|
+
conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
2808
|
+
if cmat is not None:
|
|
2809
|
+
tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
|
|
2810
|
+
conv1 = self.backend.bk_reduce_sum(
|
|
2811
|
+
self.backend.bk_reshape(
|
|
2812
|
+
cmat[j3] * tmp2,
|
|
2813
|
+
[tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
|
|
2814
|
+
),
|
|
2815
|
+
2,
|
|
2816
|
+
)
|
|
2817
|
+
|
|
2818
|
+
### Take the module M1 = |I1 * Psi_j3|
|
|
2819
|
+
M1_square = conv1 * self.backend.bk_conjugate(
|
|
2820
|
+
conv1
|
|
2821
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
2822
|
+
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
2823
|
+
# Store M1_j3 in a dictionary
|
|
2824
|
+
M1_dic[j3] = M1
|
|
2825
|
+
|
|
2826
|
+
if not cross: # Auto
|
|
2827
|
+
M1_square = self.backend.bk_real(M1_square)
|
|
2828
|
+
|
|
2829
|
+
### S2_auto = < M1^2 >_pix
|
|
2830
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
2831
|
+
if return_data:
|
|
2832
|
+
s2 = M1_square
|
|
2833
|
+
else:
|
|
2834
|
+
if calc_var:
|
|
2835
|
+
s2, vs2 = self.masked_mean(
|
|
2836
|
+
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
2837
|
+
)
|
|
2838
|
+
else:
|
|
2839
|
+
s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
|
|
2840
|
+
|
|
2841
|
+
if cond_init_P1_dic:
|
|
2842
|
+
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
2843
|
+
P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
|
|
2844
|
+
|
|
2845
|
+
# We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
|
|
2846
|
+
if return_data:
|
|
2847
|
+
if S2 is None:
|
|
2848
|
+
S2 = {}
|
|
2849
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
2850
|
+
s2 = self.backend.bk_reduce_mean(
|
|
2851
|
+
self.backend.bk_reshape(
|
|
2852
|
+
s2,
|
|
2853
|
+
[
|
|
2854
|
+
s2.shape[0],
|
|
2855
|
+
12 * out_nside**2,
|
|
2856
|
+
(nside_j3 // out_nside) ** 2,
|
|
2857
|
+
s2.shape[2],
|
|
2858
|
+
],
|
|
2859
|
+
),
|
|
2860
|
+
2,
|
|
2861
|
+
)
|
|
2862
|
+
S2[j3] = s2
|
|
2863
|
+
else:
|
|
2864
|
+
if norm == "auto": # Normalize S2
|
|
2865
|
+
s2 /= P1_dic[j3]
|
|
2866
|
+
|
|
2867
|
+
S2.append(
|
|
2868
|
+
self.backend.bk_expand_dims(s2, off_S2)
|
|
2869
|
+
) # Add a dimension for NS2
|
|
2870
|
+
if calc_var:
|
|
2871
|
+
VS2.append(
|
|
2872
|
+
self.backend.bk_expand_dims(vs2, off_S2)
|
|
2873
|
+
) # Add a dimension for NS2
|
|
2874
|
+
|
|
2875
|
+
#### S1_auto computation
|
|
2876
|
+
### Image 1 : S1 = < M1 >_pix
|
|
2877
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
2878
|
+
if return_data:
|
|
2879
|
+
s1 = M1
|
|
2880
|
+
else:
|
|
2881
|
+
if calc_var:
|
|
2882
|
+
s1, vs1 = self.masked_mean(
|
|
2883
|
+
M1, vmask, axis=1, rank=j3, calc_var=True
|
|
2884
|
+
) # [Nbatch, Nmask, Norient3]
|
|
2885
|
+
else:
|
|
2886
|
+
s1 = self.masked_mean(
|
|
2887
|
+
M1, vmask, axis=1, rank=j3
|
|
2888
|
+
) # [Nbatch, Nmask, Norient3]
|
|
2889
|
+
|
|
2890
|
+
if return_data:
|
|
2891
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
2892
|
+
s1 = self.backend.bk_reduce_mean(
|
|
2893
|
+
self.backend.bk_reshape(
|
|
2894
|
+
s1,
|
|
2895
|
+
[
|
|
2896
|
+
s1.shape[0],
|
|
2897
|
+
12 * out_nside**2,
|
|
2898
|
+
(nside_j3 // out_nside) ** 2,
|
|
2899
|
+
s1.shape[2],
|
|
2900
|
+
],
|
|
2901
|
+
),
|
|
2902
|
+
2,
|
|
2903
|
+
)
|
|
2904
|
+
S1[j3] = s1
|
|
2905
|
+
else:
|
|
2906
|
+
### Normalize S1
|
|
2907
|
+
if norm is not None:
|
|
2908
|
+
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
2909
|
+
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
2910
|
+
S1.append(
|
|
2911
|
+
self.backend.bk_expand_dims(s1, off_S2)
|
|
2912
|
+
) # Add a dimension for NS1
|
|
2913
|
+
if calc_var:
|
|
2914
|
+
VS1.append(
|
|
2915
|
+
self.backend.bk_expand_dims(vs1, off_S2)
|
|
2916
|
+
) # Add a dimension for NS1
|
|
2917
|
+
|
|
2918
|
+
else: # Cross
|
|
2919
|
+
### Make the convolution I2 * Psi_j3
|
|
2920
|
+
conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
|
|
2921
|
+
if cmat is not None:
|
|
2922
|
+
tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
|
|
2923
|
+
conv2 = self.backend.bk_reduce_sum(
|
|
2924
|
+
self.backend.bk_reshape(
|
|
2925
|
+
cmat[j3] * tmp2,
|
|
2926
|
+
[
|
|
2927
|
+
tmp2.shape[0],
|
|
2928
|
+
cmat[j3].shape[0],
|
|
2929
|
+
self.NORIENT,
|
|
2930
|
+
self.NORIENT,
|
|
2931
|
+
],
|
|
2932
|
+
),
|
|
2933
|
+
2,
|
|
2934
|
+
)
|
|
2935
|
+
### Take the module M2 = |I2 * Psi_j3|
|
|
2936
|
+
M2_square = conv2 * self.backend.bk_conjugate(
|
|
2937
|
+
conv2
|
|
2938
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
2939
|
+
M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
|
|
2940
|
+
# Store M2_j3 in a dictionary
|
|
2941
|
+
M2_dic[j3] = M2
|
|
2942
|
+
|
|
2943
|
+
### S2_auto = < M2^2 >_pix
|
|
2944
|
+
# Not returned, only for normalization
|
|
2945
|
+
if cond_init_P1_dic:
|
|
2946
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
2947
|
+
if return_data:
|
|
2948
|
+
p1 = M1_square
|
|
2949
|
+
p2 = M2_square
|
|
2950
|
+
else:
|
|
2951
|
+
if calc_var:
|
|
2952
|
+
p1, vp1 = self.masked_mean(
|
|
2953
|
+
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
2954
|
+
) # [Nbatch, Nmask, Norient3]
|
|
2955
|
+
p2, vp2 = self.masked_mean(
|
|
2956
|
+
M2_square, vmask, axis=1, rank=j3, calc_var=True
|
|
2957
|
+
) # [Nbatch, Nmask, Norient3]
|
|
2958
|
+
else:
|
|
2959
|
+
p1 = self.masked_mean(
|
|
2960
|
+
M1_square, vmask, axis=1, rank=j3
|
|
2961
|
+
) # [Nbatch, Nmask, Norient3]
|
|
2962
|
+
p2 = self.masked_mean(
|
|
2963
|
+
M2_square, vmask, axis=1, rank=j3
|
|
2964
|
+
) # [Nbatch, Nmask, Norient3]
|
|
2965
|
+
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
2966
|
+
P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
|
|
2967
|
+
P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
|
|
2968
|
+
|
|
2969
|
+
### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
|
|
2970
|
+
# z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
|
|
2971
|
+
s2 = conv1 * self.backend.bk_conjugate(conv2)
|
|
2972
|
+
MX = self.backend.bk_L1(s2)
|
|
2973
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
2974
|
+
if return_data:
|
|
2975
|
+
s2 = s2
|
|
2976
|
+
else:
|
|
2977
|
+
if calc_var:
|
|
2978
|
+
s2, vs2 = self.masked_mean(
|
|
2979
|
+
s2, vmask, axis=1, rank=j3, calc_var=True
|
|
2980
|
+
)
|
|
2981
|
+
else:
|
|
2982
|
+
s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
|
|
2983
|
+
|
|
2984
|
+
if return_data:
|
|
2985
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
2986
|
+
s2 = self.backend.bk_reduce_mean(
|
|
2987
|
+
self.backend.bk_reshape(
|
|
2988
|
+
s2,
|
|
2989
|
+
[
|
|
2990
|
+
s2.shape[0],
|
|
2991
|
+
12 * out_nside**2,
|
|
2992
|
+
(nside_j3 // out_nside) ** 2,
|
|
2993
|
+
s2.shape[2],
|
|
2994
|
+
],
|
|
2995
|
+
),
|
|
2996
|
+
2,
|
|
2997
|
+
)
|
|
2998
|
+
S2[j3] = s2
|
|
2999
|
+
else:
|
|
3000
|
+
### Normalize S2_cross
|
|
3001
|
+
if norm == "auto":
|
|
3002
|
+
s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
|
|
3003
|
+
|
|
3004
|
+
### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
|
|
3005
|
+
s2 = self.backend.bk_real(s2)
|
|
3006
|
+
|
|
3007
|
+
S2.append(
|
|
3008
|
+
self.backend.bk_expand_dims(s2, off_S2)
|
|
3009
|
+
) # Add a dimension for NS2
|
|
3010
|
+
if calc_var:
|
|
3011
|
+
VS2.append(
|
|
3012
|
+
self.backend.bk_expand_dims(vs2, off_S2)
|
|
3013
|
+
) # Add a dimension for NS2
|
|
3014
|
+
|
|
3015
|
+
#### S1_auto computation
|
|
3016
|
+
### Image 1 : S1 = < M1 >_pix
|
|
3017
|
+
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
3018
|
+
if return_data:
|
|
3019
|
+
s1 = MX
|
|
3020
|
+
else:
|
|
3021
|
+
if calc_var:
|
|
3022
|
+
s1, vs1 = self.masked_mean(
|
|
3023
|
+
MX, vmask, axis=1, rank=j3, calc_var=True
|
|
3024
|
+
) # [Nbatch, Nmask, Norient3]
|
|
3025
|
+
else:
|
|
3026
|
+
s1 = self.masked_mean(
|
|
3027
|
+
MX, vmask, axis=1, rank=j3
|
|
3028
|
+
) # [Nbatch, Nmask, Norient3]
|
|
3029
|
+
if return_data:
|
|
3030
|
+
if out_nside is not None and out_nside < nside_j3:
|
|
3031
|
+
s1 = self.backend.bk_reduce_mean(
|
|
3032
|
+
self.backend.bk_reshape(
|
|
3033
|
+
s1,
|
|
3034
|
+
[
|
|
3035
|
+
s1.shape[0],
|
|
3036
|
+
12 * out_nside**2,
|
|
3037
|
+
(nside_j3 // out_nside) ** 2,
|
|
3038
|
+
s1.shape[2],
|
|
3039
|
+
],
|
|
3040
|
+
),
|
|
3041
|
+
2,
|
|
3042
|
+
)
|
|
3043
|
+
S1[j3] = s1
|
|
3044
|
+
else:
|
|
3045
|
+
### Normalize S1
|
|
3046
|
+
if norm is not None:
|
|
3047
|
+
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
3048
|
+
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
3049
|
+
S1.append(
|
|
3050
|
+
self.backend.bk_expand_dims(s1, off_S2)
|
|
3051
|
+
) # Add a dimension for NS1
|
|
3052
|
+
if calc_var:
|
|
3053
|
+
VS1.append(
|
|
3054
|
+
self.backend.bk_expand_dims(vs1, off_S2)
|
|
3055
|
+
) # Add a dimension for NS1
|
|
3056
|
+
|
|
3057
|
+
# Initialize dictionaries for |I1*Psi_j| * Psi_j3
|
|
3058
|
+
M1convPsi_dic = {}
|
|
3059
|
+
if cross:
|
|
3060
|
+
# Initialize dictionaries for |I2*Psi_j| * Psi_j3
|
|
3061
|
+
M2convPsi_dic = {}
|
|
3062
|
+
|
|
3063
|
+
###### S3
|
|
3064
|
+
nside_j2 = nside_j3
|
|
3065
|
+
for j2 in range(0, j3 + 1): # j2 <= j3
|
|
3066
|
+
if return_data:
|
|
3067
|
+
if S4[j3] is None:
|
|
3068
|
+
S4[j3] = {}
|
|
3069
|
+
S4[j3][j2] = None
|
|
3070
|
+
|
|
3071
|
+
### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
3072
|
+
if not cross:
|
|
3073
|
+
if calc_var:
|
|
3074
|
+
s3, vs3 = self._compute_S3(
|
|
3075
|
+
j2,
|
|
3076
|
+
j3,
|
|
3077
|
+
conv1,
|
|
3078
|
+
vmask,
|
|
3079
|
+
M1_dic,
|
|
3080
|
+
M1convPsi_dic,
|
|
3081
|
+
calc_var=True,
|
|
3082
|
+
cmat2=cmat2,
|
|
3083
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3084
|
+
else:
|
|
3085
|
+
s3 = self._compute_S3(
|
|
3086
|
+
j2,
|
|
3087
|
+
j3,
|
|
3088
|
+
conv1,
|
|
3089
|
+
vmask,
|
|
3090
|
+
M1_dic,
|
|
3091
|
+
M1convPsi_dic,
|
|
3092
|
+
return_data=return_data,
|
|
3093
|
+
cmat2=cmat2,
|
|
3094
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3095
|
+
|
|
3096
|
+
if return_data:
|
|
3097
|
+
if S3[j3] is None:
|
|
3098
|
+
S3[j3] = {}
|
|
3099
|
+
if out_nside is not None and out_nside < nside_j2:
|
|
3100
|
+
s3 = self.backend.bk_reduce_mean(
|
|
3101
|
+
self.backend.bk_reshape(
|
|
3102
|
+
s3,
|
|
3103
|
+
[
|
|
3104
|
+
s3.shape[0],
|
|
3105
|
+
12 * out_nside**2,
|
|
3106
|
+
(nside_j2 // out_nside) ** 2,
|
|
3107
|
+
s3.shape[2],
|
|
3108
|
+
s3.shape[3],
|
|
3109
|
+
],
|
|
3110
|
+
),
|
|
3111
|
+
2,
|
|
3112
|
+
)
|
|
3113
|
+
S3[j3][j2] = s3
|
|
3114
|
+
else:
|
|
3115
|
+
### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
|
|
3116
|
+
if norm is not None:
|
|
3117
|
+
self.div_norm(
|
|
3118
|
+
s3,
|
|
3119
|
+
(
|
|
3120
|
+
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
3121
|
+
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
3122
|
+
)
|
|
3123
|
+
** 0.5,
|
|
3124
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3125
|
+
|
|
3126
|
+
### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
3127
|
+
|
|
3128
|
+
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
3129
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3130
|
+
S3.append(
|
|
3131
|
+
self.backend.bk_expand_dims(s3, off_S3)
|
|
3132
|
+
) # Add a dimension for NS3
|
|
3133
|
+
if calc_var:
|
|
3134
|
+
VS3.append(
|
|
3135
|
+
self.backend.bk_expand_dims(vs3, off_S3)
|
|
3136
|
+
) # Add a dimension for NS3
|
|
3137
|
+
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
3138
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3139
|
+
|
|
3140
|
+
### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
|
|
3141
|
+
### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
|
|
3142
|
+
else:
|
|
3143
|
+
if calc_var:
|
|
3144
|
+
s3, vs3 = self._compute_S3(
|
|
3145
|
+
j2,
|
|
3146
|
+
j3,
|
|
3147
|
+
conv1,
|
|
3148
|
+
vmask,
|
|
3149
|
+
M2_dic,
|
|
3150
|
+
M2convPsi_dic,
|
|
3151
|
+
calc_var=True,
|
|
3152
|
+
cmat2=cmat2,
|
|
3153
|
+
)
|
|
3154
|
+
s3p, vs3p = self._compute_S3(
|
|
3155
|
+
j2,
|
|
3156
|
+
j3,
|
|
3157
|
+
conv2,
|
|
3158
|
+
vmask,
|
|
3159
|
+
M1_dic,
|
|
3160
|
+
M1convPsi_dic,
|
|
3161
|
+
calc_var=True,
|
|
3162
|
+
cmat2=cmat2,
|
|
3163
|
+
)
|
|
3164
|
+
else:
|
|
3165
|
+
s3 = self._compute_S3(
|
|
3166
|
+
j2,
|
|
3167
|
+
j3,
|
|
3168
|
+
conv1,
|
|
3169
|
+
vmask,
|
|
3170
|
+
M2_dic,
|
|
3171
|
+
M2convPsi_dic,
|
|
3172
|
+
return_data=return_data,
|
|
3173
|
+
cmat2=cmat2,
|
|
3174
|
+
)
|
|
3175
|
+
s3p = self._compute_S3(
|
|
3176
|
+
j2,
|
|
3177
|
+
j3,
|
|
3178
|
+
conv2,
|
|
3179
|
+
vmask,
|
|
3180
|
+
M1_dic,
|
|
3181
|
+
M1convPsi_dic,
|
|
3182
|
+
return_data=return_data,
|
|
3183
|
+
cmat2=cmat2,
|
|
3184
|
+
)
|
|
3185
|
+
|
|
3186
|
+
if return_data:
|
|
3187
|
+
if S3[j3] is None:
|
|
3188
|
+
S3[j3] = {}
|
|
3189
|
+
S3P[j3] = {}
|
|
3190
|
+
if out_nside is not None and out_nside < nside_j2:
|
|
3191
|
+
s3 = self.backend.bk_reduce_mean(
|
|
3192
|
+
self.backend.bk_reshape(
|
|
3193
|
+
s3,
|
|
3194
|
+
[
|
|
3195
|
+
s3.shape[0],
|
|
3196
|
+
12 * out_nside**2,
|
|
3197
|
+
(nside_j2 // out_nside) ** 2,
|
|
3198
|
+
s3.shape[2],
|
|
3199
|
+
s3.shape[3],
|
|
3200
|
+
],
|
|
3201
|
+
),
|
|
3202
|
+
2,
|
|
3203
|
+
)
|
|
3204
|
+
s3p = self.backend.bk_reduce_mean(
|
|
3205
|
+
self.backend.bk_reshape(
|
|
3206
|
+
s3p,
|
|
3207
|
+
[
|
|
3208
|
+
s3.shape[0],
|
|
3209
|
+
12 * out_nside**2,
|
|
3210
|
+
(nside_j2 // out_nside) ** 2,
|
|
3211
|
+
s3.shape[2],
|
|
3212
|
+
s3.shape[3],
|
|
3213
|
+
],
|
|
3214
|
+
),
|
|
3215
|
+
2,
|
|
3216
|
+
)
|
|
3217
|
+
S3[j3][j2] = s3
|
|
3218
|
+
S3P[j3][j2] = s3p
|
|
3219
|
+
else:
|
|
3220
|
+
### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
|
|
3221
|
+
if norm is not None:
|
|
3222
|
+
self.div_norm(
|
|
3223
|
+
s3,
|
|
3224
|
+
(
|
|
3225
|
+
self.backend.bk_expand_dims(P2_dic[j2], off_S2)
|
|
3226
|
+
* self.backend.bk_expand_dims(P1_dic[j3], -1)
|
|
3227
|
+
)
|
|
3228
|
+
** 0.5,
|
|
3229
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3230
|
+
self.div_norm(
|
|
3231
|
+
s3p,
|
|
3232
|
+
(
|
|
3233
|
+
self.backend.bk_expand_dims(P1_dic[j2], off_S2)
|
|
3234
|
+
* self.backend.bk_expand_dims(P2_dic[j3], -1)
|
|
3235
|
+
)
|
|
3236
|
+
** 0.5,
|
|
3237
|
+
) # [Nbatch, Nmask, Norient3, Norient2]
|
|
3238
|
+
|
|
3239
|
+
### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
|
|
3240
|
+
|
|
3241
|
+
# S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
|
|
3242
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3243
|
+
S3.append(
|
|
3244
|
+
self.backend.bk_expand_dims(s3, off_S3)
|
|
3245
|
+
) # Add a dimension for NS3
|
|
3246
|
+
if calc_var:
|
|
3247
|
+
VS3.append(
|
|
3248
|
+
self.backend.bk_expand_dims(vs3, off_S3)
|
|
3249
|
+
) # Add a dimension for NS3
|
|
3250
|
+
|
|
3251
|
+
# VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
|
|
3252
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3253
|
+
|
|
3254
|
+
# S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
|
|
3255
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3256
|
+
S3P.append(
|
|
3257
|
+
self.backend.bk_expand_dims(s3p, off_S3)
|
|
3258
|
+
) # Add a dimension for NS3
|
|
3259
|
+
if calc_var:
|
|
3260
|
+
VS3P.append(
|
|
3261
|
+
self.backend.bk_expand_dims(vs3p, off_S3)
|
|
3262
|
+
) # Add a dimension for NS3
|
|
3263
|
+
# VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
|
|
3264
|
+
# s3.shape[2]*s3.shape[3]]))
|
|
3265
|
+
|
|
3266
|
+
##### S4
|
|
3267
|
+
nside_j1 = nside_j2
|
|
3268
|
+
for j1 in range(0, j2 + 1): # j1 <= j2
|
|
3269
|
+
### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
|
|
3270
|
+
if not cross:
|
|
3271
|
+
if calc_var:
|
|
3272
|
+
s4, vs4 = self._compute_S4(
|
|
3273
|
+
j1,
|
|
3274
|
+
j2,
|
|
3275
|
+
vmask,
|
|
3276
|
+
M1convPsi_dic,
|
|
3277
|
+
M2convPsi_dic=None,
|
|
3278
|
+
calc_var=True,
|
|
3279
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3280
|
+
else:
|
|
3281
|
+
s4 = self._compute_S4(
|
|
3282
|
+
j1,
|
|
3283
|
+
j2,
|
|
3284
|
+
vmask,
|
|
3285
|
+
M1convPsi_dic,
|
|
3286
|
+
M2convPsi_dic=None,
|
|
3287
|
+
return_data=return_data,
|
|
3288
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3289
|
+
|
|
3290
|
+
if return_data:
|
|
3291
|
+
if S4[j3][j2] is None:
|
|
3292
|
+
S4[j3][j2] = {}
|
|
3293
|
+
if out_nside is not None and out_nside < nside_j1:
|
|
3294
|
+
s4 = self.backend.bk_reduce_mean(
|
|
3295
|
+
self.backend.bk_reshape(
|
|
3296
|
+
s4,
|
|
3297
|
+
[
|
|
3298
|
+
s4.shape[0],
|
|
3299
|
+
12 * out_nside**2,
|
|
3300
|
+
(nside_j1 // out_nside) ** 2,
|
|
3301
|
+
s4.shape[2],
|
|
3302
|
+
s4.shape[3],
|
|
3303
|
+
s4.shape[4],
|
|
3304
|
+
],
|
|
3305
|
+
),
|
|
3306
|
+
2,
|
|
3307
|
+
)
|
|
3308
|
+
S4[j3][j2][j1] = s4
|
|
3309
|
+
else:
|
|
3310
|
+
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
3311
|
+
if norm is not None:
|
|
3312
|
+
self.div_norm(
|
|
3313
|
+
s4,
|
|
3314
|
+
(
|
|
3315
|
+
self.backend.bk_expand_dims(
|
|
3316
|
+
self.backend.bk_expand_dims(
|
|
3317
|
+
P1_dic[j1], off_S2
|
|
3318
|
+
),
|
|
3319
|
+
off_S2,
|
|
3320
|
+
)
|
|
3321
|
+
* self.backend.bk_expand_dims(
|
|
3322
|
+
self.backend.bk_expand_dims(
|
|
3323
|
+
P1_dic[j2], off_S2
|
|
3324
|
+
),
|
|
3325
|
+
-1,
|
|
3326
|
+
)
|
|
3327
|
+
)
|
|
3328
|
+
** 0.5,
|
|
3329
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3330
|
+
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
3331
|
+
|
|
3332
|
+
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
3333
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3334
|
+
S4.append(
|
|
3335
|
+
self.backend.bk_expand_dims(s4, off_S4)
|
|
3336
|
+
) # Add a dimension for NS4
|
|
3337
|
+
if calc_var:
|
|
3338
|
+
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
3339
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3340
|
+
VS4.append(
|
|
3341
|
+
self.backend.bk_expand_dims(vs4, off_S4)
|
|
3342
|
+
) # Add a dimension for NS4
|
|
3343
|
+
|
|
3344
|
+
### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
|
|
3345
|
+
else:
|
|
3346
|
+
if calc_var:
|
|
3347
|
+
s4, vs4 = self._compute_S4(
|
|
3348
|
+
j1,
|
|
3349
|
+
j2,
|
|
3350
|
+
vmask,
|
|
3351
|
+
M1convPsi_dic,
|
|
3352
|
+
M2convPsi_dic=M2convPsi_dic,
|
|
3353
|
+
calc_var=True,
|
|
3354
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3355
|
+
else:
|
|
3356
|
+
s4 = self._compute_S4(
|
|
3357
|
+
j1,
|
|
3358
|
+
j2,
|
|
3359
|
+
vmask,
|
|
3360
|
+
M1convPsi_dic,
|
|
3361
|
+
M2convPsi_dic=M2convPsi_dic,
|
|
3362
|
+
return_data=return_data,
|
|
3363
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3364
|
+
|
|
3365
|
+
if return_data:
|
|
3366
|
+
if S4[j3][j2] is None:
|
|
3367
|
+
S4[j3][j2] = {}
|
|
3368
|
+
if out_nside is not None and out_nside < nside_j1:
|
|
3369
|
+
s4 = self.backend.bk_reduce_mean(
|
|
3370
|
+
self.backend.bk_reshape(
|
|
3371
|
+
s4,
|
|
3372
|
+
[
|
|
3373
|
+
s4.shape[0],
|
|
3374
|
+
12 * out_nside**2,
|
|
3375
|
+
(nside_j1 // out_nside) ** 2,
|
|
3376
|
+
s4.shape[2],
|
|
3377
|
+
s4.shape[3],
|
|
3378
|
+
s4.shape[4],
|
|
3379
|
+
],
|
|
3380
|
+
),
|
|
3381
|
+
2,
|
|
3382
|
+
)
|
|
3383
|
+
S4[j3][j2][j1] = s4
|
|
3384
|
+
else:
|
|
3385
|
+
### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
|
|
3386
|
+
if norm is not None:
|
|
3387
|
+
self.div_norm(
|
|
3388
|
+
s4,
|
|
3389
|
+
(
|
|
3390
|
+
self.backend.bk_expand_dims(
|
|
3391
|
+
self.backend.bk_expand_dims(
|
|
3392
|
+
P1_dic[j1], off_S2
|
|
3393
|
+
),
|
|
3394
|
+
off_S2,
|
|
3395
|
+
)
|
|
3396
|
+
* self.backend.bk_expand_dims(
|
|
3397
|
+
self.backend.bk_expand_dims(
|
|
3398
|
+
P2_dic[j2], off_S2
|
|
3399
|
+
),
|
|
3400
|
+
-1,
|
|
3401
|
+
)
|
|
3402
|
+
)
|
|
3403
|
+
** 0.5,
|
|
3404
|
+
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3405
|
+
### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
|
|
3406
|
+
# S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
|
|
3407
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3408
|
+
S4.append(
|
|
3409
|
+
self.backend.bk_expand_dims(s4, off_S4)
|
|
3410
|
+
) # Add a dimension for NS4
|
|
3411
|
+
if calc_var:
|
|
3412
|
+
|
|
3413
|
+
# VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
|
|
3414
|
+
# s4.shape[2]*s4.shape[3]*s4.shape[4]]))
|
|
3415
|
+
VS4.append(
|
|
3416
|
+
self.backend.bk_expand_dims(vs4, off_S4)
|
|
3417
|
+
) # Add a dimension for NS4
|
|
3418
|
+
|
|
3419
|
+
nside_j1 = nside_j1 // 2
|
|
3420
|
+
nside_j2 = nside_j2 // 2
|
|
3421
|
+
|
|
3422
|
+
###### Reshape for next iteration on j3
|
|
3423
|
+
### Image I1,
|
|
3424
|
+
# downscale the I1 [Nbatch, Npix_j3]
|
|
3425
|
+
if j3 != Jmax - 1:
|
|
3426
|
+
I1 = self.smooth(I1, axis=1)
|
|
3427
|
+
I1 = self.ud_grade_2(I1, axis=1)
|
|
3428
|
+
|
|
3429
|
+
### Image I2
|
|
3430
|
+
if cross:
|
|
3431
|
+
I2 = self.smooth(I2, axis=1)
|
|
3432
|
+
I2 = self.ud_grade_2(I2, axis=1)
|
|
3433
|
+
|
|
3434
|
+
### Modules
|
|
3435
|
+
for j2 in range(0, j3 + 1): # j2 =< j3
|
|
3436
|
+
### Dictionary M1_dic[j2]
|
|
3437
|
+
M1_smooth = self.smooth(
|
|
3438
|
+
M1_dic[j2], axis=1
|
|
3439
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
3440
|
+
M1_dic[j2] = self.ud_grade_2(
|
|
3441
|
+
M1_smooth, axis=1
|
|
3442
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
3443
|
+
|
|
3444
|
+
### Dictionary M2_dic[j2]
|
|
3445
|
+
if cross:
|
|
3446
|
+
M2_smooth = self.smooth(
|
|
3447
|
+
M2_dic[j2], axis=1
|
|
3448
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
3449
|
+
M2_dic[j2] = self.ud_grade_2(
|
|
3450
|
+
M2, axis=1
|
|
3451
|
+
) # [Nbatch, Npix_j3, Norient3]
|
|
3452
|
+
|
|
3453
|
+
### Mask
|
|
3454
|
+
vmask = self.ud_grade_2(vmask, axis=1)
|
|
3455
|
+
|
|
3456
|
+
if self.mask_thres is not None:
|
|
3457
|
+
vmask = self.backend.bk_threshold(vmask, self.mask_thres)
|
|
3458
|
+
|
|
3459
|
+
### NSIDE_j3
|
|
3460
|
+
nside_j3 = nside_j3 // 2
|
|
3461
|
+
|
|
3462
|
+
### Store P1_dic and P2_dic in self
|
|
3463
|
+
if (norm == "auto") and (self.P1_dic is None):
|
|
3464
|
+
self.P1_dic = P1_dic
|
|
3465
|
+
if cross:
|
|
3466
|
+
self.P2_dic = P2_dic
|
|
3467
|
+
"""
|
|
3468
|
+
Sout=[s0]+S1+S2+S3+S4
|
|
3469
|
+
|
|
3470
|
+
if cross:
|
|
3471
|
+
Sout=Sout+S3P
|
|
3472
|
+
if calc_var:
|
|
3473
|
+
SVout=[vs0]+VS1+VS2+VS3+VS4
|
|
3474
|
+
if cross:
|
|
3475
|
+
VSout=VSout+VS3P
|
|
3476
|
+
return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
|
|
3477
|
+
|
|
3478
|
+
return self.backend.bk_concat(Sout, 2)
|
|
3479
|
+
"""
|
|
3480
|
+
if not return_data:
|
|
3481
|
+
S1 = self.backend.bk_concat(S1, 2)
|
|
3482
|
+
S2 = self.backend.bk_concat(S2, 2)
|
|
3483
|
+
S3 = self.backend.bk_concat(S3, 2)
|
|
3484
|
+
S4 = self.backend.bk_concat(S4, 2)
|
|
3485
|
+
if cross:
|
|
3486
|
+
S3P = self.backend.bk_concat(S3P, 2)
|
|
3487
|
+
if calc_var:
|
|
3488
|
+
VS1 = self.backend.bk_concat(VS1, 2)
|
|
3489
|
+
VS2 = self.backend.bk_concat(VS2, 2)
|
|
3490
|
+
VS3 = self.backend.bk_concat(VS3, 2)
|
|
3491
|
+
VS4 = self.backend.bk_concat(VS4, 2)
|
|
3492
|
+
if cross:
|
|
3493
|
+
VS3P = self.backend.bk_concat(VS3P, 2)
|
|
3494
|
+
if calc_var:
|
|
3495
|
+
if not cross:
|
|
3496
|
+
return scat_cov(
|
|
3497
|
+
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
3498
|
+
), scat_cov(
|
|
3499
|
+
vs0,
|
|
3500
|
+
VS2,
|
|
3501
|
+
VS3,
|
|
3502
|
+
VS4,
|
|
3503
|
+
s1=VS1,
|
|
3504
|
+
backend=self.backend,
|
|
3505
|
+
use_1D=self.use_1D,
|
|
3506
|
+
)
|
|
3507
|
+
else:
|
|
3508
|
+
return scat_cov(
|
|
3509
|
+
s0,
|
|
3510
|
+
S2,
|
|
3511
|
+
S3,
|
|
3512
|
+
S4,
|
|
3513
|
+
s1=S1,
|
|
3514
|
+
s3p=S3P,
|
|
3515
|
+
backend=self.backend,
|
|
3516
|
+
use_1D=self.use_1D,
|
|
3517
|
+
), scat_cov(
|
|
3518
|
+
vs0,
|
|
3519
|
+
VS2,
|
|
3520
|
+
VS3,
|
|
3521
|
+
VS4,
|
|
3522
|
+
s1=VS1,
|
|
3523
|
+
s3p=VS3P,
|
|
3524
|
+
backend=self.backend,
|
|
3525
|
+
use_1D=self.use_1D,
|
|
3526
|
+
)
|
|
3527
|
+
else:
|
|
3528
|
+
if not cross:
|
|
3529
|
+
return scat_cov(
|
|
3530
|
+
s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
|
|
3531
|
+
)
|
|
3532
|
+
else:
|
|
3533
|
+
return scat_cov(
|
|
3534
|
+
s0,
|
|
3535
|
+
S2,
|
|
3536
|
+
S3,
|
|
3537
|
+
S4,
|
|
3538
|
+
s1=S1,
|
|
3539
|
+
s3p=S3P,
|
|
3540
|
+
backend=self.backend,
|
|
3541
|
+
use_1D=self.use_1D,
|
|
3542
|
+
)
|
|
3543
|
+
|
|
3544
|
+
def eval_new(
|
|
3545
|
+
self,
|
|
3546
|
+
image1,
|
|
3547
|
+
image2=None,
|
|
3548
|
+
mask=None,
|
|
3549
|
+
norm=None,
|
|
3550
|
+
calc_var=False,
|
|
3551
|
+
cmat=None,
|
|
3552
|
+
cmat2=None,
|
|
3553
|
+
Jmax=None,
|
|
3554
|
+
out_nside=None,
|
|
3555
|
+
edge=True
|
|
3556
|
+
):
|
|
3557
|
+
"""
|
|
3558
|
+
Calculates the scattering correlations for a batch of images. Mean are done over pixels.
|
|
3559
|
+
mean of modulus:
|
|
3560
|
+
S1 = <|I * Psi_j3|>
|
|
3561
|
+
Normalization : take the log
|
|
3562
|
+
power spectrum:
|
|
3563
|
+
S2 = <|I * Psi_j3|^2>
|
|
3564
|
+
Normalization : take the log
|
|
3565
|
+
orig. x modulus:
|
|
3566
|
+
S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
|
|
3567
|
+
Normalization : divide by (S2_j2 * S2_j3)^0.5
|
|
3568
|
+
modulus x modulus:
|
|
3569
|
+
S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
|
|
3570
|
+
Normalization : divide by (S2_j1 * S2_j2)^0.5
|
|
3571
|
+
Parameters
|
|
3572
|
+
----------
|
|
3573
|
+
image1: tensor
|
|
3574
|
+
Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
|
|
3575
|
+
image2: tensor
|
|
3576
|
+
Second image. If not None, we compute cross-scattering covariance coefficients.
|
|
3577
|
+
mask:
|
|
3578
|
+
norm: None or str
|
|
3579
|
+
If None no normalization is applied, if 'auto' normalize by the reference S2,
|
|
3580
|
+
if 'self' normalize by the current S2.
|
|
3581
|
+
Returns
|
|
3582
|
+
-------
|
|
3583
|
+
S1, S2, S3, S4 normalized
|
|
3584
|
+
"""
|
|
2543
3585
|
return_data = self.return_data
|
|
3586
|
+
NORIENT=self.NORIENT
|
|
2544
3587
|
# Check input consistency
|
|
2545
3588
|
if image2 is not None:
|
|
2546
3589
|
if list(image1.shape) != list(image2.shape):
|
|
@@ -2699,13 +3742,19 @@ class funct(FOC.FoCUS):
|
|
|
2699
3742
|
S3P = {}
|
|
2700
3743
|
S4 = {}
|
|
2701
3744
|
else:
|
|
2702
|
-
|
|
2703
|
-
|
|
3745
|
+
result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3746
|
+
dtype=self.backend.backend.float32,
|
|
3747
|
+
device=self.backend.torch_device)
|
|
3748
|
+
vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
|
|
3749
|
+
dtype=self.backend.backend.float32,
|
|
3750
|
+
device=self.backend.torch_device)
|
|
3751
|
+
S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3752
|
+
S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
2704
3753
|
S3 = []
|
|
2705
3754
|
S4 = []
|
|
2706
3755
|
S3P = []
|
|
2707
|
-
VS1 = []
|
|
2708
|
-
VS2 = []
|
|
3756
|
+
VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
3757
|
+
VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
|
|
2709
3758
|
VS3 = []
|
|
2710
3759
|
VS3P = []
|
|
2711
3760
|
VS4 = []
|
|
@@ -2749,13 +3798,18 @@ class funct(FOC.FoCUS):
|
|
|
2749
3798
|
s0, l_vs0 = self.masked_mean(
|
|
2750
3799
|
self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
|
|
2751
3800
|
)
|
|
2752
|
-
vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
|
|
2753
|
-
s0 = self.backend.bk_concat([s0, l_vs0], 1)
|
|
3801
|
+
#vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
|
|
3802
|
+
#s0 = self.backend.bk_concat([s0, l_vs0], 1)
|
|
3803
|
+
result[:,:,0]=s0
|
|
3804
|
+
result[:,:,1]=l_vs0
|
|
3805
|
+
vresult[:,:,0]=l_vs0
|
|
3806
|
+
vresult[:,:,1]=l_vs0
|
|
2754
3807
|
#### COMPUTE S1, S2, S3 and S4
|
|
2755
3808
|
nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
|
|
2756
3809
|
|
|
2757
3810
|
# a remettre comme avant
|
|
2758
|
-
M1_dic={}
|
|
3811
|
+
M1_dic={}
|
|
3812
|
+
M2_dic={}
|
|
2759
3813
|
|
|
2760
3814
|
for j3 in range(Jmax):
|
|
2761
3815
|
|
|
@@ -2805,7 +3859,9 @@ class funct(FOC.FoCUS):
|
|
|
2805
3859
|
M1_square = conv1 * self.backend.bk_conjugate(
|
|
2806
3860
|
conv1
|
|
2807
3861
|
) # [Nbatch, Npix_j3, Norient3]
|
|
3862
|
+
|
|
2808
3863
|
M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
|
|
3864
|
+
|
|
2809
3865
|
# Store M1_j3 in a dictionary
|
|
2810
3866
|
M1_dic[j3] = M1
|
|
2811
3867
|
|
|
@@ -2821,13 +3877,15 @@ class funct(FOC.FoCUS):
|
|
|
2821
3877
|
s2, vs2 = self.masked_mean(
|
|
2822
3878
|
M1_square, vmask, axis=1, rank=j3, calc_var=True
|
|
2823
3879
|
)
|
|
3880
|
+
#s2=self.backend.bk_flatten(self.backend.bk_real(s2))
|
|
3881
|
+
#vs2=self.backend.bk_flatten(vs2)
|
|
2824
3882
|
else:
|
|
2825
3883
|
s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
|
|
2826
3884
|
|
|
2827
3885
|
if cond_init_P1_dic:
|
|
2828
3886
|
# We fill P1_dic with S2 for normalisation of S3 and S4
|
|
2829
|
-
P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
|
|
2830
|
-
|
|
3887
|
+
P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
|
|
3888
|
+
|
|
2831
3889
|
# We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
|
|
2832
3890
|
if return_data:
|
|
2833
3891
|
if S2 is None:
|
|
@@ -2849,7 +3907,7 @@ class funct(FOC.FoCUS):
|
|
|
2849
3907
|
else:
|
|
2850
3908
|
if norm == "auto": # Normalize S2
|
|
2851
3909
|
s2 /= P1_dic[j3]
|
|
2852
|
-
|
|
3910
|
+
"""
|
|
2853
3911
|
S2.append(
|
|
2854
3912
|
self.backend.bk_expand_dims(s2, off_S2)
|
|
2855
3913
|
) # Add a dimension for NS2
|
|
@@ -2857,7 +3915,11 @@ class funct(FOC.FoCUS):
|
|
|
2857
3915
|
VS2.append(
|
|
2858
3916
|
self.backend.bk_expand_dims(vs2, off_S2)
|
|
2859
3917
|
) # Add a dimension for NS2
|
|
2860
|
-
|
|
3918
|
+
"""
|
|
3919
|
+
#print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
|
|
3920
|
+
result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
|
|
3921
|
+
if calc_var:
|
|
3922
|
+
vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
|
|
2861
3923
|
#### S1_auto computation
|
|
2862
3924
|
### Image 1 : S1 = < M1 >_pix
|
|
2863
3925
|
# Apply the mask [Nmask, Npix_j3] and average over pixels
|
|
@@ -2868,10 +3930,13 @@ class funct(FOC.FoCUS):
|
|
|
2868
3930
|
s1, vs1 = self.masked_mean(
|
|
2869
3931
|
M1, vmask, axis=1, rank=j3, calc_var=True
|
|
2870
3932
|
) # [Nbatch, Nmask, Norient3]
|
|
3933
|
+
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
3934
|
+
#vs1=self.backend.bk_flatten(vs1)
|
|
2871
3935
|
else:
|
|
2872
3936
|
s1 = self.masked_mean(
|
|
2873
3937
|
M1, vmask, axis=1, rank=j3
|
|
2874
3938
|
) # [Nbatch, Nmask, Norient3]
|
|
3939
|
+
#s1=self.backend.bk_flatten(self.backend.bk_real(s1))
|
|
2875
3940
|
|
|
2876
3941
|
if return_data:
|
|
2877
3942
|
if out_nside is not None and out_nside < nside_j3:
|
|
@@ -2892,6 +3957,10 @@ class funct(FOC.FoCUS):
|
|
|
2892
3957
|
### Normalize S1
|
|
2893
3958
|
if norm is not None:
|
|
2894
3959
|
self.div_norm(s1, (P1_dic[j3]) ** 0.5)
|
|
3960
|
+
result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
|
|
3961
|
+
if calc_var:
|
|
3962
|
+
vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
|
|
3963
|
+
"""
|
|
2895
3964
|
### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
|
|
2896
3965
|
S1.append(
|
|
2897
3966
|
self.backend.bk_expand_dims(s1, off_S2)
|
|
@@ -2900,6 +3969,7 @@ class funct(FOC.FoCUS):
|
|
|
2900
3969
|
VS1.append(
|
|
2901
3970
|
self.backend.bk_expand_dims(vs1, off_S2)
|
|
2902
3971
|
) # Add a dimension for NS1
|
|
3972
|
+
"""
|
|
2903
3973
|
|
|
2904
3974
|
else: # Cross
|
|
2905
3975
|
### Make the convolution I2 * Psi_j3
|
|
@@ -3048,7 +4118,7 @@ class funct(FOC.FoCUS):
|
|
|
3048
4118
|
|
|
3049
4119
|
###### S3
|
|
3050
4120
|
nside_j2 = nside_j3
|
|
3051
|
-
for j2 in range(0
|
|
4121
|
+
for j2 in range(0,-1): # j3 + 1): # j2 <= j3
|
|
3052
4122
|
if return_data:
|
|
3053
4123
|
if S4[j3] is None:
|
|
3054
4124
|
S4[j3] = {}
|
|
@@ -3463,6 +4533,20 @@ class funct(FOC.FoCUS):
|
|
|
3463
4533
|
|
|
3464
4534
|
return self.backend.bk_concat(Sout, 2)
|
|
3465
4535
|
"""
|
|
4536
|
+
if calc_var:
|
|
4537
|
+
return result,vresult
|
|
4538
|
+
else:
|
|
4539
|
+
return result
|
|
4540
|
+
if calc_var:
|
|
4541
|
+
for k in S1:
|
|
4542
|
+
print(k.shape,k.dtype)
|
|
4543
|
+
for k in S2:
|
|
4544
|
+
print(k.shape,k.dtype)
|
|
4545
|
+
print(s0.shape,s0.dtype)
|
|
4546
|
+
return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
|
|
4547
|
+
else:
|
|
4548
|
+
return self.backend.bk_concat([s0]+S1+S2,axis=1)
|
|
4549
|
+
|
|
3466
4550
|
if not return_data:
|
|
3467
4551
|
S1 = self.backend.bk_concat(S1, 2)
|
|
3468
4552
|
S2 = self.backend.bk_concat(S2, 2)
|
|
@@ -3526,7 +4610,6 @@ class funct(FOC.FoCUS):
|
|
|
3526
4610
|
backend=self.backend,
|
|
3527
4611
|
use_1D=self.use_1D,
|
|
3528
4612
|
)
|
|
3529
|
-
|
|
3530
4613
|
def clean_norm(self):
|
|
3531
4614
|
self.P1_dic = None
|
|
3532
4615
|
self.P2_dic = None
|
|
@@ -3644,7 +4727,733 @@ class funct(FOC.FoCUS):
|
|
|
3644
4727
|
s4, vmask, axis=1, rank=j2
|
|
3645
4728
|
) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
|
|
3646
4729
|
return s4
|
|
4730
|
+
|
|
4731
|
+
def computer_filter(self,M,N,J,L):
|
|
4732
|
+
'''
|
|
4733
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4734
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4735
|
+
'''
|
|
4736
|
+
|
|
4737
|
+
filter = np.zeros([J, L, M, N],dtype='complex64')
|
|
4738
|
+
|
|
4739
|
+
slant=4.0 / L
|
|
4740
|
+
|
|
4741
|
+
for j in range(J):
|
|
4742
|
+
|
|
4743
|
+
for l in range(L):
|
|
4744
|
+
|
|
4745
|
+
theta = (int(L-L/2-1)-l) * np.pi / L
|
|
4746
|
+
sigma = 0.8 * 2**j
|
|
4747
|
+
xi = 3.0 / 4.0 * np.pi /2**j
|
|
4748
|
+
|
|
4749
|
+
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
|
|
4750
|
+
R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
|
|
4751
|
+
D = np.array([[1, 0], [0, slant * slant]])
|
|
4752
|
+
curv = np.matmul(R, np.matmul(D, R_inv)) / ( 2 * sigma * sigma)
|
|
4753
|
+
|
|
4754
|
+
gab = np.zeros((M, N), np.complex128)
|
|
4755
|
+
xx = np.empty((2,2, M, N))
|
|
4756
|
+
yy = np.empty((2,2, M, N))
|
|
4757
|
+
|
|
4758
|
+
for ii, ex in enumerate([-1, 0]):
|
|
4759
|
+
for jj, ey in enumerate([-1, 0]):
|
|
4760
|
+
xx[ii,jj], yy[ii,jj] = np.mgrid[
|
|
4761
|
+
ex * M : M + ex * M,
|
|
4762
|
+
ey * N : N + ey * N]
|
|
4763
|
+
|
|
4764
|
+
arg = -(curv[0, 0] * xx * xx + (curv[0, 1] + curv[1, 0]) * xx * yy + curv[1, 1] * yy * yy)
|
|
4765
|
+
argi = arg + 1.j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
|
|
4766
|
+
|
|
4767
|
+
gabi = np.exp(argi).sum((0,1))
|
|
4768
|
+
gab = np.exp(arg).sum((0,1))
|
|
4769
|
+
|
|
4770
|
+
norm_factor = 2 * np.pi * sigma * sigma / slant
|
|
4771
|
+
|
|
4772
|
+
gab = gab / norm_factor
|
|
4773
|
+
|
|
4774
|
+
gabi = gabi / norm_factor
|
|
4775
|
+
|
|
4776
|
+
K = gabi.sum() / gab.sum()
|
|
4777
|
+
|
|
4778
|
+
# Apply the Gaussian
|
|
4779
|
+
filter[j, l] = np.fft.fft2(gabi-K*gab)
|
|
4780
|
+
filter[j,l,0,0]=0.0
|
|
4781
|
+
|
|
4782
|
+
return self.backend.bk_cast(filter)
|
|
4783
|
+
|
|
4784
|
+
# ------------------------------------------------------------------------------------------
|
|
4785
|
+
#
|
|
4786
|
+
# utility functions
|
|
4787
|
+
#
|
|
4788
|
+
# ------------------------------------------------------------------------------------------
|
|
4789
|
+
def cut_high_k_off(self,data_f, dx, dy):
|
|
4790
|
+
'''
|
|
4791
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4792
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4793
|
+
'''
|
|
4794
|
+
if_xodd = (data_f.shape[-2]%2==1)
|
|
4795
|
+
if_yodd = (data_f.shape[-1]%2==1)
|
|
4796
|
+
result = self.backend.backend.cat(
|
|
4797
|
+
(self.backend.backend.cat(
|
|
4798
|
+
( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
|
|
4799
|
+
), -2),
|
|
4800
|
+
self.backend.backend.cat(
|
|
4801
|
+
( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
|
|
4802
|
+
), -2)
|
|
4803
|
+
),-1)
|
|
4804
|
+
return result
|
|
4805
|
+
# ---------------------------------------------------------------------------
|
|
4806
|
+
#
|
|
4807
|
+
# utility functions for computing scattering coef and covariance
|
|
4808
|
+
#
|
|
4809
|
+
# ---------------------------------------------------------------------------
|
|
4810
|
+
|
|
4811
|
+
def get_dxdy(self, j,M,N):
|
|
4812
|
+
'''
|
|
4813
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4814
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4815
|
+
'''
|
|
4816
|
+
dx = int(max( 8, min( np.ceil(M/2**j), M//2 ) ))
|
|
4817
|
+
dy = int(max( 8, min( np.ceil(N/2**j), N//2 ) ))
|
|
4818
|
+
return dx, dy
|
|
4819
|
+
|
|
4820
|
+
|
|
4821
|
+
|
|
4822
|
+
def get_edge_masks(self,M, N, J, d0=1):
|
|
4823
|
+
'''
|
|
4824
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4825
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4826
|
+
'''
|
|
4827
|
+
edge_masks = self.backend.backend.empty((J, M, N))
|
|
4828
|
+
X, Y = self.backend.backend.meshgrid(self.backend.backend.arange(M), self.backend.backend.arange(N), indexing='ij')
|
|
4829
|
+
for j in range(J):
|
|
4830
|
+
edge_dx = min(M//4, 2**j*d0)
|
|
4831
|
+
edge_dy = min(N//4, 2**j*d0)
|
|
4832
|
+
edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
|
|
4833
|
+
return edge_masks.to(self.backend.torch_device)
|
|
4834
|
+
|
|
4835
|
+
# ---------------------------------------------------------------------------
|
|
4836
|
+
#
|
|
4837
|
+
# scattering cov
|
|
4838
|
+
#
|
|
4839
|
+
# ---------------------------------------------------------------------------
|
|
4840
|
+
def scattering_cov(
|
|
4841
|
+
self, data,
|
|
4842
|
+
data2=None,
|
|
4843
|
+
Jmax=None,
|
|
4844
|
+
if_large_batch=False,
|
|
4845
|
+
S4_criteria=None,
|
|
4846
|
+
use_ref=False,
|
|
4847
|
+
normalization='S2',
|
|
4848
|
+
edge=False,
|
|
4849
|
+
pseudo_coef=1,
|
|
4850
|
+
get_variance=False,
|
|
4851
|
+
ref_sigma=None,
|
|
4852
|
+
iso_ang=False
|
|
4853
|
+
):
|
|
4854
|
+
'''
|
|
4855
|
+
Calculates the scattering correlations for a batch of images, including:
|
|
4856
|
+
|
|
4857
|
+
This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
|
|
4858
|
+
Done by Sihao Cheng and Rudy Morel.
|
|
4859
|
+
|
|
4860
|
+
orig. x orig.:
|
|
4861
|
+
P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
|
|
4862
|
+
orig. x modulus:
|
|
4863
|
+
C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
|
|
4864
|
+
when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
|
|
4865
|
+
when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
|
|
4866
|
+
modulus x modulus:
|
|
4867
|
+
C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
|
|
4868
|
+
C11 = C11_pre_norm / factor
|
|
4869
|
+
when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
|
|
4870
|
+
when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
|
|
4871
|
+
modulus x modulus (auto):
|
|
4872
|
+
P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
|
|
4873
|
+
Parameters
|
|
4874
|
+
----------
|
|
4875
|
+
data : numpy array or torch tensor
|
|
4876
|
+
image set, with size [N_image, x-sidelength, y-sidelength]
|
|
4877
|
+
if_large_batch : Bool (=False)
|
|
4878
|
+
It is recommended to use "False" unless one meets a memory issue
|
|
4879
|
+
C11_criteria : str or None (=None)
|
|
4880
|
+
Only C11 coefficients that satisfy this criteria will be computed.
|
|
4881
|
+
Any expressions of j1, j2, and j3 that can be evaluated as a Bool
|
|
4882
|
+
is accepted.The default "None" corresponds to "j1 <= j2 <= j3".
|
|
4883
|
+
use_ref : Bool (=False)
|
|
4884
|
+
When normalizing, whether or not to use the normalization factor
|
|
4885
|
+
computed from a reference field. For just computing the statistics,
|
|
4886
|
+
the default is False. However, for synthesis, set it to "True" will
|
|
4887
|
+
stablize the optimization process.
|
|
4888
|
+
normalization : str 'P00' or 'P11' (='P00')
|
|
4889
|
+
Whether 'P00' or 'P11' is used as the normalization factor for C01
|
|
4890
|
+
and C11.
|
|
4891
|
+
remove_edge : Bool (=False)
|
|
4892
|
+
If true, the edge region with a width of rougly the size of the largest
|
|
4893
|
+
wavelet involved is excluded when taking the global average to obtain
|
|
4894
|
+
the scattering coefficients.
|
|
4895
|
+
|
|
4896
|
+
Returns
|
|
4897
|
+
-------
|
|
4898
|
+
'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
|
|
4899
|
+
the power in each wavelet bands (the orig. x orig. term)
|
|
4900
|
+
'S1' : torch tensor with size [N_image, J, L] (# image, j1, l1)
|
|
4901
|
+
the 1st-order scattering coefficients, i.e., the mean of wavelet modulus fields
|
|
4902
|
+
'C01' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
|
|
4903
|
+
the orig. x modulus terms. Elements with j1 < j2 are all set to np.nan and not computed.
|
|
4904
|
+
'C11' : torch tensor with size [N_image, J, J, J, L, L, L] (# image, j1, j2, j3, l1, l2, l3)
|
|
4905
|
+
the modulus x modulus terms. Elements not satisfying j1 <= j2 <= j3 and the conditions
|
|
4906
|
+
defined in 'C11_criteria' are all set to np.nan and not computed.
|
|
4907
|
+
'C11_pre_norm' and 'C11_pre_norm_iso': pre-normalized modulus x modulus terms.
|
|
4908
|
+
'P11' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
|
|
4909
|
+
the modulus x modulus terms with the two wavelets within modulus the same. Elements not following
|
|
4910
|
+
j1 <= j3 are set to np.nan and not computed.
|
|
4911
|
+
'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
|
|
4912
|
+
'P11' averaged over l1 while keeping l2-l1 constant.
|
|
4913
|
+
'''
|
|
4914
|
+
if S4_criteria is None:
|
|
4915
|
+
S4_criteria = 'j2>=j1'
|
|
4916
|
+
|
|
4917
|
+
# determine jmax and nside corresponding to the input map
|
|
4918
|
+
im_shape = data.shape
|
|
4919
|
+
if self.use_2D:
|
|
4920
|
+
if len(data.shape) == 2:
|
|
4921
|
+
nside = np.min([im_shape[0], im_shape[1]])
|
|
4922
|
+
M,N = im_shape[0],im_shape[1]
|
|
4923
|
+
N_image = 1
|
|
4924
|
+
else:
|
|
4925
|
+
nside = np.min([im_shape[1], im_shape[2]])
|
|
4926
|
+
M,N = im_shape[1],im_shape[2]
|
|
4927
|
+
N_image = data.shape[0]
|
|
4928
|
+
J = int(np.log(nside) / np.log(2))-1 # Number of j scales
|
|
4929
|
+
elif self.use_1D:
|
|
4930
|
+
if len(data.shape) == 2:
|
|
4931
|
+
npix = int(im_shape[1]) # Number of pixels
|
|
4932
|
+
N_image = 1
|
|
4933
|
+
else:
|
|
4934
|
+
npix = int(im_shape[0]) # Number of pixels
|
|
4935
|
+
N_image = data.shape[0]
|
|
4936
|
+
|
|
4937
|
+
nside = int(npix)
|
|
4938
|
+
|
|
4939
|
+
J = int(np.log(nside) / np.log(2))-1 # Number of j scales
|
|
4940
|
+
else:
|
|
4941
|
+
if len(data.shape) == 2:
|
|
4942
|
+
npix = int(im_shape[1]) # Number of pixels
|
|
4943
|
+
N_image = 1
|
|
4944
|
+
else:
|
|
4945
|
+
npix = int(im_shape[0]) # Number of pixels
|
|
4946
|
+
N_image = data.shape[0]
|
|
4947
|
+
|
|
4948
|
+
nside = int(np.sqrt(npix // 12))
|
|
4949
|
+
|
|
4950
|
+
J = int(np.log(nside) / np.log(2)) # Number of j scales
|
|
4951
|
+
|
|
4952
|
+
if Jmax is None:
|
|
4953
|
+
Jmax = J # Number of steps for the loop on scales
|
|
4954
|
+
if Jmax>J:
|
|
4955
|
+
print('==========\n\n')
|
|
4956
|
+
print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
|
|
4957
|
+
print('\n\n==========')
|
|
4958
|
+
|
|
4959
|
+
L=self.NORIENT
|
|
4960
|
+
|
|
4961
|
+
if (M,N,J,L) not in self.filters_set:
|
|
4962
|
+
self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
|
|
4963
|
+
|
|
4964
|
+
filters_set = self.filters_set[(M,N,J,L)]
|
|
4965
|
+
|
|
4966
|
+
#weight = self.weight
|
|
4967
|
+
if use_ref:
|
|
4968
|
+
if normalization=='S2':
|
|
4969
|
+
ref_S2 = self.ref_scattering_cov_S2
|
|
4970
|
+
else:
|
|
4971
|
+
ref_P11 = self.ref_scattering_cov['P11']
|
|
4972
|
+
|
|
4973
|
+
# convert numpy array input into self.backend.bk_ tensors
|
|
4974
|
+
data = self.backend.bk_cast(data)
|
|
4975
|
+
data_f = self.backend.bk_fftn(data, dim=(-2,-1))
|
|
4976
|
+
if data2 is not None:
|
|
4977
|
+
data2 = self.backend.bk_cast(data2)
|
|
4978
|
+
data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
|
|
4979
|
+
|
|
4980
|
+
# initialize tensors for scattering coefficients
|
|
4981
|
+
S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
4982
|
+
S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
4983
|
+
|
|
4984
|
+
Ndata_S3 = J*(J+1)//2
|
|
4985
|
+
Ndata_S4 = J*(J+1)*(J+2)//6
|
|
4986
|
+
J_S4={}
|
|
4987
|
+
|
|
4988
|
+
S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
|
|
4989
|
+
if data2 is not None:
|
|
4990
|
+
S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
|
|
4991
|
+
S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
|
|
4992
|
+
S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
|
|
4993
|
+
|
|
4994
|
+
# variance
|
|
4995
|
+
if get_variance:
|
|
4996
|
+
S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
4997
|
+
S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
|
|
4998
|
+
S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
|
|
4999
|
+
if data2 is not None:
|
|
5000
|
+
S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
|
|
5001
|
+
S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
|
|
5002
|
+
|
|
5003
|
+
if iso_ang:
|
|
5004
|
+
S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
|
|
5005
|
+
if data2 is not None:
|
|
5006
|
+
S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
|
|
5007
|
+
|
|
5008
|
+
S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
|
|
5009
|
+
if get_variance:
|
|
5010
|
+
S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
|
|
5011
|
+
if data2 is not None:
|
|
5012
|
+
S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
|
|
5013
|
+
S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
|
|
5014
|
+
|
|
5015
|
+
#
|
|
5016
|
+
if edge:
|
|
5017
|
+
if (M,N,J) not in self.edge_masks:
|
|
5018
|
+
self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
|
|
5019
|
+
edge_mask = self.edge_masks[(M,N,J)][:,None,:,:]
|
|
5020
|
+
edge_mask = edge_mask / edge_mask.mean((-2,-1))[:,:,None,None]
|
|
5021
|
+
else:
|
|
5022
|
+
edge_mask = 1
|
|
5023
|
+
|
|
5024
|
+
# calculate scattering fields
|
|
5025
|
+
if data2 is None:
|
|
5026
|
+
if self.use_2D:
|
|
5027
|
+
if len(data.shape) == 2:
|
|
5028
|
+
I1 = self.backend.bk_ifftn(
|
|
5029
|
+
data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
|
|
5030
|
+
).abs()
|
|
5031
|
+
else:
|
|
5032
|
+
I1 = self.backend.bk_ifftn(
|
|
5033
|
+
data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
|
|
5034
|
+
).abs()
|
|
5035
|
+
elif self.use_1D:
|
|
5036
|
+
if len(data.shape) == 1:
|
|
5037
|
+
I1 = self.backend.bk_ifftn(
|
|
5038
|
+
data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
|
|
5039
|
+
).abs()
|
|
5040
|
+
else:
|
|
5041
|
+
I1 = self.backend.bk_ifftn(
|
|
5042
|
+
data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
|
|
5043
|
+
).abs()
|
|
5044
|
+
else:
|
|
5045
|
+
print('todo')
|
|
5046
|
+
|
|
5047
|
+
S2 = (I1**2 * edge_mask).mean((-2,-1))
|
|
5048
|
+
S1 = (I1 * edge_mask).mean((-2,-1))
|
|
5049
|
+
|
|
5050
|
+
if get_variance:
|
|
5051
|
+
S2_sigma = (I1**2 * edge_mask).std((-2,-1))
|
|
5052
|
+
S1_sigma = (I1 * edge_mask).std((-2,-1))
|
|
5053
|
+
|
|
5054
|
+
I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
|
|
5055
|
+
|
|
5056
|
+
else:
|
|
5057
|
+
if self.use_2D:
|
|
5058
|
+
if len(data.shape) == 2:
|
|
5059
|
+
I1 = self.backend.bk_ifftn(
|
|
5060
|
+
data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
|
|
5061
|
+
)
|
|
5062
|
+
I2 = self.backend.bk_ifftn(
|
|
5063
|
+
data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
|
|
5064
|
+
)
|
|
5065
|
+
else:
|
|
5066
|
+
I1 = self.backend.bk_ifftn(
|
|
5067
|
+
data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
|
|
5068
|
+
)
|
|
5069
|
+
I2 = self.backend.bk_ifftn(
|
|
5070
|
+
data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
|
|
5071
|
+
)
|
|
5072
|
+
elif self.use_1D:
|
|
5073
|
+
if len(data.shape) == 1:
|
|
5074
|
+
I1 = self.backend.bk_ifftn(
|
|
5075
|
+
data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
|
|
5076
|
+
)
|
|
5077
|
+
I2 = self.backend.bk_ifftn(
|
|
5078
|
+
data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
|
|
5079
|
+
)
|
|
5080
|
+
else:
|
|
5081
|
+
I1 = self.backend.bk_ifftn(
|
|
5082
|
+
data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
|
|
5083
|
+
)
|
|
5084
|
+
I2 = self.backend.bk_ifftn(
|
|
5085
|
+
data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
|
|
5086
|
+
)
|
|
5087
|
+
else:
|
|
5088
|
+
print('todo')
|
|
5089
|
+
|
|
5090
|
+
I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
|
|
5091
|
+
|
|
5092
|
+
S2 = (I1 * edge_mask).mean((-2,-1))
|
|
5093
|
+
if get_variance:
|
|
5094
|
+
S2_sigma = (I1 * edge_mask).std((-2,-1))
|
|
5095
|
+
|
|
5096
|
+
I1=self.backend.bk_L1(I1)
|
|
5097
|
+
|
|
5098
|
+
S1 = (I1 * edge_mask).mean((-2,-1))
|
|
3647
5099
|
|
|
5100
|
+
if get_variance:
|
|
5101
|
+
S1_sigma = (I1 * edge_mask).std((-2,-1))
|
|
5102
|
+
|
|
5103
|
+
I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
|
|
5104
|
+
|
|
5105
|
+
if pseudo_coef != 1:
|
|
5106
|
+
I1 = I1**pseudo_coef
|
|
5107
|
+
|
|
5108
|
+
Ndata_S3=0
|
|
5109
|
+
Ndata_S4=0
|
|
5110
|
+
|
|
5111
|
+
# calculate the covariance and correlations of the scattering fields
|
|
5112
|
+
# only use the low-k Fourier coefs when calculating large-j scattering coefs.
|
|
5113
|
+
for j3 in range(0,J):
|
|
5114
|
+
J_S4[j3]=Ndata_S4
|
|
5115
|
+
|
|
5116
|
+
dx3, dy3 = self.get_dxdy(j3,M,N)
|
|
5117
|
+
I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
|
|
5118
|
+
data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
|
|
5119
|
+
if data2 is not None:
|
|
5120
|
+
data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
|
|
5121
|
+
if edge:
|
|
5122
|
+
I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
|
|
5123
|
+
data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
|
|
5124
|
+
if data2 is not None:
|
|
5125
|
+
data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
|
|
5126
|
+
wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
|
|
5127
|
+
_, M3, N3 = wavelet_f3.shape
|
|
5128
|
+
wavelet_f3_squared = wavelet_f3**2
|
|
5129
|
+
edge_dx = min(4, int(2**j3*dx3*2/M))
|
|
5130
|
+
edge_dy = min(4, int(2**j3*dy3*2/N))
|
|
5131
|
+
# a normalization change due to the cutoff of frequency space
|
|
5132
|
+
fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
|
|
5133
|
+
for j2 in range(0,j3+1):
|
|
5134
|
+
I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
|
|
5135
|
+
I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
|
|
5136
|
+
if edge:
|
|
5137
|
+
I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
|
|
5138
|
+
I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
|
|
5139
|
+
if use_ref:
|
|
5140
|
+
if normalization=='P11':
|
|
5141
|
+
norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
|
|
5142
|
+
elif normalization=='S2':
|
|
5143
|
+
norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
|
|
5144
|
+
norm_factor_S3 = 1.0
|
|
5145
|
+
else:
|
|
5146
|
+
if normalization=='P11':
|
|
5147
|
+
# [N_image,l2,l3,x,y]
|
|
5148
|
+
P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
|
|
5149
|
+
norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
|
|
5150
|
+
elif normalization=='S2':
|
|
5151
|
+
norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
|
|
5152
|
+
norm_factor_S3 = 1.0
|
|
5153
|
+
|
|
5154
|
+
if not edge:
|
|
5155
|
+
S3[:,Ndata_S3,:,:] = (
|
|
5156
|
+
data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5157
|
+
).mean((-2,-1)) * fft_factor / norm_factor_S3
|
|
5158
|
+
|
|
5159
|
+
if get_variance:
|
|
5160
|
+
S3_sigma[:,Ndata_S3,:,:] = (
|
|
5161
|
+
data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5162
|
+
).std((-2,-1)) * fft_factor / norm_factor_S3
|
|
5163
|
+
else:
|
|
5164
|
+
|
|
5165
|
+
S3[:,Ndata_S3,:,:] = (
|
|
5166
|
+
data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
|
|
5167
|
+
)[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
|
|
5168
|
+
if get_variance:
|
|
5169
|
+
S3_sigma[:,Ndata_S3,:,:] = (
|
|
5170
|
+
data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
|
|
5171
|
+
)[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
|
|
5172
|
+
if data2 is not None:
|
|
5173
|
+
if not edge:
|
|
5174
|
+
S3p[:,Ndata_S3,:,:] = (
|
|
5175
|
+
data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5176
|
+
).mean((-2,-1)) * fft_factor / norm_factor_S3
|
|
5177
|
+
|
|
5178
|
+
if get_variance:
|
|
5179
|
+
S3p_sigma[:,Ndata_S3,:,:] = (
|
|
5180
|
+
data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
|
|
5181
|
+
).std((-2,-1)) * fft_factor / norm_factor_S3
|
|
5182
|
+
else:
|
|
5183
|
+
|
|
5184
|
+
S3p[:,Ndata_S3,:,:] = (
|
|
5185
|
+
data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
|
|
5186
|
+
)[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
|
|
5187
|
+
if get_variance:
|
|
5188
|
+
S3p_sigma[:,Ndata_S3,:,:] = (
|
|
5189
|
+
data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
|
|
5190
|
+
)[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
|
|
5191
|
+
|
|
5192
|
+
Ndata_S3+=1
|
|
5193
|
+
if j2 <= j3:
|
|
5194
|
+
beg_n=Ndata_S4
|
|
5195
|
+
for j1 in range(0, j2+1):
|
|
5196
|
+
if eval(S4_criteria):
|
|
5197
|
+
if not edge:
|
|
5198
|
+
if not if_large_batch:
|
|
5199
|
+
# [N_image,l1,l2,l3,x,y]
|
|
5200
|
+
S4_pre_norm[:,Ndata_S4,:,:,:] = (
|
|
5201
|
+
I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
|
|
5202
|
+
self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
|
|
5203
|
+
).mean((-2,-1)) * fft_factor
|
|
5204
|
+
if get_variance:
|
|
5205
|
+
S4_sigma[:,Ndata_S4,:,:,:] = (
|
|
5206
|
+
I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
|
|
5207
|
+
self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
|
|
5208
|
+
).std((-2,-1)) * fft_factor
|
|
5209
|
+
else:
|
|
5210
|
+
for l1 in range(L):
|
|
5211
|
+
# [N_image,l2,l3,x,y]
|
|
5212
|
+
S4_pre_norm[:,Ndata_S4,l1,:,:] = (
|
|
5213
|
+
I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
|
|
5214
|
+
self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
|
|
5215
|
+
).mean((-2,-1)) * fft_factor
|
|
5216
|
+
if get_variance:
|
|
5217
|
+
S4_sigma[:,Ndata_S4,l1,:,:] = (
|
|
5218
|
+
I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
|
|
5219
|
+
self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
|
|
5220
|
+
).std((-2,-1)) * fft_factor
|
|
5221
|
+
else:
|
|
5222
|
+
if not if_large_batch:
|
|
5223
|
+
# [N_image,l1,l2,l3,x,y]
|
|
5224
|
+
S4_pre_norm[:,Ndata_S4,:,:,:] = (
|
|
5225
|
+
I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
|
|
5226
|
+
I12_w3_2_small.view(N_image,1,L,L,M3,N3)
|
|
5227
|
+
)
|
|
5228
|
+
)[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
|
|
5229
|
+
if get_variance:
|
|
5230
|
+
S4_sigma[:,Ndata_S4,:,:,:] = (
|
|
5231
|
+
I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
|
|
5232
|
+
I12_w3_2_small.view(N_image,1,L,L,M3,N3)
|
|
5233
|
+
)
|
|
5234
|
+
)[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
|
|
5235
|
+
else:
|
|
5236
|
+
for l1 in range(L):
|
|
5237
|
+
# [N_image,l2,l3,x,y]
|
|
5238
|
+
S4_pre_norm[:,Ndata_S4,l1,:,:] = (
|
|
5239
|
+
I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
|
|
5240
|
+
I12_w3_2_small.view(N_image,L,L,M3,N3)
|
|
5241
|
+
)
|
|
5242
|
+
)[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
|
|
5243
|
+
if get_variance:
|
|
5244
|
+
S4_sigma[:,Ndata_S4,l1,:,:] = (
|
|
5245
|
+
I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
|
|
5246
|
+
I12_w3_2_small.view(N_image,L,L,M3,N3)
|
|
5247
|
+
)
|
|
5248
|
+
)[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
|
|
5249
|
+
|
|
5250
|
+
Ndata_S4+=1
|
|
5251
|
+
|
|
5252
|
+
if normalization=='S2':
|
|
5253
|
+
if use_ref:
|
|
5254
|
+
P = (ref_S2[:,j3,:,None,None] * ref_S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
|
|
5255
|
+
else:
|
|
5256
|
+
P = ((S2[:,j3,:,None,None] * S2[:,j2,None,:,None] )**(0.5*pseudo_coef))
|
|
5257
|
+
S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/(P.clone())
|
|
5258
|
+
|
|
5259
|
+
if get_variance:
|
|
5260
|
+
S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:] / (P)
|
|
5261
|
+
else:
|
|
5262
|
+
S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()
|
|
5263
|
+
|
|
5264
|
+
if get_variance:
|
|
5265
|
+
S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:]
|
|
5266
|
+
|
|
5267
|
+
"""
|
|
5268
|
+
# define P11 from diagonals of S4
|
|
5269
|
+
for j1 in range(J):
|
|
5270
|
+
for l1 in range(L):
|
|
5271
|
+
P11[:,j1,:,l1,:] = S4_pre_norm[:,j1,j1,:,l1,l1,:].real
|
|
5272
|
+
|
|
5273
|
+
|
|
5274
|
+
if normalization=='S4':
|
|
5275
|
+
if use_ref:
|
|
5276
|
+
P = ref_P11
|
|
5277
|
+
else:
|
|
5278
|
+
P = P11
|
|
5279
|
+
#.view(N_image,J,1,J,L,1,L) * .view(N_image,1,J,J,1,L,L)
|
|
5280
|
+
S4 = S4_pre_norm / (
|
|
5281
|
+
P[:,:,None,:,:,None,:] * P[:,None,:,:,None,:,:]
|
|
5282
|
+
)**(0.5*pseudo_coef)
|
|
5283
|
+
|
|
5284
|
+
|
|
5285
|
+
|
|
5286
|
+
|
|
5287
|
+
# get a single, flattened data vector for_synthesis
|
|
5288
|
+
select_and_index = self.get_scattering_index(J, L, normalization, S4_criteria)
|
|
5289
|
+
index_for_synthesis = select_and_index['index_for_synthesis']
|
|
5290
|
+
index_for_synthesis_iso = select_and_index['index_for_synthesis_iso']
|
|
5291
|
+
"""
|
|
5292
|
+
# average over l1 to obtain simple isotropic statistics
|
|
5293
|
+
if iso_ang:
|
|
5294
|
+
S2_iso = S2.mean(-1)
|
|
5295
|
+
S1_iso = S1.mean(-1)
|
|
5296
|
+
for l1 in range(L):
|
|
5297
|
+
for l2 in range(L):
|
|
5298
|
+
S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
|
|
5299
|
+
if data2 is not None:
|
|
5300
|
+
S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
|
|
5301
|
+
for l3 in range(L):
|
|
5302
|
+
S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
|
|
5303
|
+
S3_iso /= L; S4_iso /= L
|
|
5304
|
+
if data2 is not None:
|
|
5305
|
+
S3p_iso /= L
|
|
5306
|
+
|
|
5307
|
+
if get_variance:
|
|
5308
|
+
S2_sigma_iso = S2_sigma.mean(-1)
|
|
5309
|
+
S1_sigma_iso = S1_sigma.mean(-1)
|
|
5310
|
+
for l1 in range(L):
|
|
5311
|
+
for l2 in range(L):
|
|
5312
|
+
S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
|
|
5313
|
+
if data2 is not None:
|
|
5314
|
+
S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
|
|
5315
|
+
for l3 in range(L):
|
|
5316
|
+
S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
|
|
5317
|
+
S3_sigma_iso /= L; S4_sigma_iso /= L
|
|
5318
|
+
if data2 is not None:
|
|
5319
|
+
S3p_sigma_iso /= L
|
|
5320
|
+
|
|
5321
|
+
mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
|
|
5322
|
+
std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
|
|
5323
|
+
mean_data[:,0]=data.mean((-2,-1))
|
|
5324
|
+
std_data[:,0]=data.std((-2,-1))
|
|
5325
|
+
|
|
5326
|
+
if get_variance:
|
|
5327
|
+
ref_sigma={}
|
|
5328
|
+
if iso_ang:
|
|
5329
|
+
ref_sigma['std_data']=std_data
|
|
5330
|
+
ref_sigma['S1_sigma']=S1_sigma_iso
|
|
5331
|
+
ref_sigma['S2_sigma']=S2_sigma_iso
|
|
5332
|
+
ref_sigma['S3_sigma']=S3_sigma_iso
|
|
5333
|
+
ref_sigma['S4_sigma']=S4_sigma_iso
|
|
5334
|
+
if data2 is not None:
|
|
5335
|
+
ref_sigma['S3p_sigma']=S3p_sigma_iso
|
|
5336
|
+
else:
|
|
5337
|
+
ref_sigma['std_data']=std_data
|
|
5338
|
+
ref_sigma['S1_sigma']=S1_sigma
|
|
5339
|
+
ref_sigma['S2_sigma']=S2_sigma
|
|
5340
|
+
ref_sigma['S3_sigma']=S3_sigma
|
|
5341
|
+
ref_sigma['S4_sigma']=S4_sigma
|
|
5342
|
+
if data2 is not None:
|
|
5343
|
+
ref_sigma['S3p_sigma']=S3_sigma
|
|
5344
|
+
|
|
5345
|
+
if data2 is None:
|
|
5346
|
+
if iso_ang:
|
|
5347
|
+
if ref_sigma is not None:
|
|
5348
|
+
for_synthesis = self.backend.backend.cat((
|
|
5349
|
+
mean_data/ref_sigma['std_data'],
|
|
5350
|
+
std_data/ref_sigma['std_data'],
|
|
5351
|
+
(S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
|
|
5352
|
+
(S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
|
|
5353
|
+
(S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
|
|
5354
|
+
(S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
|
|
5355
|
+
(S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
|
|
5356
|
+
(S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
|
|
5357
|
+
),dim=-1)
|
|
5358
|
+
else:
|
|
5359
|
+
for_synthesis = self.backend.backend.cat((
|
|
5360
|
+
mean_data/std_data,
|
|
5361
|
+
std_data,
|
|
5362
|
+
S2_iso.reshape((N_image, -1)).log(),
|
|
5363
|
+
S1_iso.reshape((N_image, -1)).log(),
|
|
5364
|
+
S3_iso.reshape((N_image, -1)).real,
|
|
5365
|
+
S3_iso.reshape((N_image, -1)).imag,
|
|
5366
|
+
S4_iso.reshape((N_image, -1)).real,
|
|
5367
|
+
S4_iso.reshape((N_image, -1)).imag,
|
|
5368
|
+
),dim=-1)
|
|
5369
|
+
else:
|
|
5370
|
+
if ref_sigma is not None:
|
|
5371
|
+
for_synthesis = self.backend.backend.cat((
|
|
5372
|
+
mean_data/ref_sigma['std_data'],
|
|
5373
|
+
std_data/ref_sigma['std_data'],
|
|
5374
|
+
(S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
|
|
5375
|
+
(S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
|
|
5376
|
+
(S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
|
|
5377
|
+
(S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
|
|
5378
|
+
(S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
|
|
5379
|
+
(S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
|
|
5380
|
+
),dim=-1)
|
|
5381
|
+
else:
|
|
5382
|
+
for_synthesis = self.backend.backend.cat((
|
|
5383
|
+
mean_data/std_data,
|
|
5384
|
+
std_data,
|
|
5385
|
+
S2.reshape((N_image, -1)).log(),
|
|
5386
|
+
S1.reshape((N_image, -1)).log(),
|
|
5387
|
+
S3.reshape((N_image, -1)).real,
|
|
5388
|
+
S3.reshape((N_image, -1)).imag,
|
|
5389
|
+
S4.reshape((N_image, -1)).real,
|
|
5390
|
+
S4.reshape((N_image, -1)).imag,
|
|
5391
|
+
),dim=-1)
|
|
5392
|
+
else:
|
|
5393
|
+
if iso_ang:
|
|
5394
|
+
if ref_sigma is not None:
|
|
5395
|
+
for_synthesis = self.backend.backend.cat((
|
|
5396
|
+
mean_data/ref_sigma['std_data'],
|
|
5397
|
+
std_data/ref_sigma['std_data'],
|
|
5398
|
+
(S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
|
|
5399
|
+
(S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
|
|
5400
|
+
(S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
|
|
5401
|
+
(S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
|
|
5402
|
+
(S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
|
|
5403
|
+
(S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
|
|
5404
|
+
(S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
|
|
5405
|
+
(S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
|
|
5406
|
+
),dim=-1)
|
|
5407
|
+
else:
|
|
5408
|
+
for_synthesis = self.backend.backend.cat((
|
|
5409
|
+
mean_data/std_data,
|
|
5410
|
+
std_data,
|
|
5411
|
+
S2_iso.reshape((N_image, -1)),
|
|
5412
|
+
S1_iso.reshape((N_image, -1)),
|
|
5413
|
+
S3_iso.reshape((N_image, -1)).real,
|
|
5414
|
+
S3_iso.reshape((N_image, -1)).imag,
|
|
5415
|
+
S3p_iso.reshape((N_image, -1)).real,
|
|
5416
|
+
S3p_iso.reshape((N_image, -1)).imag,
|
|
5417
|
+
S4_iso.reshape((N_image, -1)).real,
|
|
5418
|
+
S4_iso.reshape((N_image, -1)).imag,
|
|
5419
|
+
),dim=-1)
|
|
5420
|
+
else:
|
|
5421
|
+
if ref_sigma is not None:
|
|
5422
|
+
for_synthesis = self.backend.backend.cat((
|
|
5423
|
+
mean_data/ref_sigma['std_data'],
|
|
5424
|
+
std_data/ref_sigma['std_data'],
|
|
5425
|
+
(S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
|
|
5426
|
+
(S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
|
|
5427
|
+
(S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
|
|
5428
|
+
(S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
|
|
5429
|
+
(S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
|
|
5430
|
+
(S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
|
|
5431
|
+
(S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
|
|
5432
|
+
(S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
|
|
5433
|
+
),dim=-1)
|
|
5434
|
+
else:
|
|
5435
|
+
for_synthesis = self.backend.backend.cat((
|
|
5436
|
+
mean_data/std_data,
|
|
5437
|
+
std_data,
|
|
5438
|
+
S2.reshape((N_image, -1)),
|
|
5439
|
+
S1.reshape((N_image, -1)),
|
|
5440
|
+
S3.reshape((N_image, -1)).real,
|
|
5441
|
+
S3.reshape((N_image, -1)).imag,
|
|
5442
|
+
S3p.reshape((N_image, -1)).real,
|
|
5443
|
+
S3p.reshape((N_image, -1)).imag,
|
|
5444
|
+
S4.reshape((N_image, -1)).real,
|
|
5445
|
+
S4.reshape((N_image, -1)).imag,
|
|
5446
|
+
),dim=-1)
|
|
5447
|
+
|
|
5448
|
+
if not use_ref:
|
|
5449
|
+
self.ref_scattering_cov_S2=S2
|
|
5450
|
+
|
|
5451
|
+
if get_variance:
|
|
5452
|
+
return for_synthesis,ref_sigma
|
|
5453
|
+
|
|
5454
|
+
return for_synthesis
|
|
5455
|
+
|
|
5456
|
+
|
|
3648
5457
|
def to_gaussian(self,x):
|
|
3649
5458
|
from scipy.stats import norm
|
|
3650
5459
|
from scipy.interpolate import interp1d
|
|
@@ -4021,8 +5830,12 @@ class funct(FOC.FoCUS):
|
|
|
4021
5830
|
image_target,
|
|
4022
5831
|
nstep=4,
|
|
4023
5832
|
seed=1234,
|
|
4024
|
-
|
|
5833
|
+
Jmax=None,
|
|
5834
|
+
edge=False,
|
|
4025
5835
|
to_gaussian=True,
|
|
5836
|
+
use_variance=False,
|
|
5837
|
+
synthesised_N=1,
|
|
5838
|
+
iso_ang=False,
|
|
4026
5839
|
EVAL_FREQUENCY=100,
|
|
4027
5840
|
NUM_EPOCHS = 300):
|
|
4028
5841
|
|
|
@@ -4032,13 +5845,16 @@ class funct(FOC.FoCUS):
|
|
|
4032
5845
|
def The_loss(u,scat_operator,args):
|
|
4033
5846
|
ref = args[0]
|
|
4034
5847
|
sref = args[1]
|
|
5848
|
+
use_v= args[2]
|
|
4035
5849
|
|
|
4036
5850
|
# compute scattering covariance of the current synthetised map called u
|
|
4037
|
-
|
|
5851
|
+
if use_v:
|
|
5852
|
+
learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,ref_sigma=sref,use_ref=True,iso_ang=iso_ang))
|
|
5853
|
+
else:
|
|
5854
|
+
learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,use_ref=True,iso_ang=iso_ang))
|
|
4038
5855
|
|
|
4039
5856
|
# make the difference withe the reference coordinates
|
|
4040
|
-
loss=scat_operator.
|
|
4041
|
-
|
|
5857
|
+
loss=scat_operator.backend.bk_reduce_mean(scat_operator.backend.bk_square((learn-ref)))
|
|
4042
5858
|
return loss
|
|
4043
5859
|
|
|
4044
5860
|
if to_gaussian:
|
|
@@ -4078,33 +5894,35 @@ class funct(FOC.FoCUS):
|
|
|
4078
5894
|
if k==0:
|
|
4079
5895
|
np.random.seed(seed)
|
|
4080
5896
|
if self.use_2D:
|
|
4081
|
-
imap=np.random.randn(
|
|
5897
|
+
imap=np.random.randn(synthesised_N,
|
|
4082
5898
|
tmp[k].shape[1],
|
|
4083
5899
|
tmp[k].shape[2])
|
|
4084
5900
|
else:
|
|
4085
|
-
imap=np.random.randn(
|
|
5901
|
+
imap=np.random.randn(synthesised_N,
|
|
4086
5902
|
tmp[k].shape[1])
|
|
4087
5903
|
else:
|
|
4088
|
-
|
|
4089
|
-
# if the kernel size is bigger than 3 increase the binning before smoothing
|
|
5904
|
+
# Increase the resolution between each step
|
|
4090
5905
|
if self.use_2D:
|
|
4091
5906
|
imap = self.up_grade(
|
|
4092
|
-
omap, imap.shape[
|
|
5907
|
+
omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
|
|
4093
5908
|
)
|
|
4094
5909
|
elif self.use_1D:
|
|
4095
|
-
imap = self.up_grade(omap, imap.shape[
|
|
5910
|
+
imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
|
|
4096
5911
|
else:
|
|
4097
|
-
imap = self.up_grade(omap, l_nside, axis=
|
|
5912
|
+
imap = self.up_grade(omap, l_nside, axis=1)
|
|
4098
5913
|
|
|
4099
5914
|
# compute the coefficients for the target image
|
|
4100
|
-
|
|
4101
|
-
|
|
5915
|
+
if use_variance:
|
|
5916
|
+
ref,sref=self.scattering_cov(tmp[k],get_variance=True,edge=edge,Jmax=Jmax,iso_ang=iso_ang)
|
|
5917
|
+
else:
|
|
5918
|
+
ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
|
|
5919
|
+
sref=ref
|
|
5920
|
+
|
|
4102
5921
|
# compute the mean of the population does nothing if only one map is given
|
|
4103
5922
|
ref=self.reduce_mean_batch(ref)
|
|
4104
|
-
|
|
4105
|
-
|
|
5923
|
+
|
|
4106
5924
|
# define a loss to minimize
|
|
4107
|
-
loss=synthe.Loss(The_loss,self,ref,sref)
|
|
5925
|
+
loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
|
|
4108
5926
|
|
|
4109
5927
|
sy = synthe.Synthesis([loss])
|
|
4110
5928
|
|
|
@@ -4121,6 +5939,8 @@ class funct(FOC.FoCUS):
|
|
|
4121
5939
|
omap=sy.run(imap,
|
|
4122
5940
|
EVAL_FREQUENCY=EVAL_FREQUENCY,
|
|
4123
5941
|
NUM_EPOCHS = NUM_EPOCHS)
|
|
5942
|
+
|
|
5943
|
+
|
|
4124
5944
|
|
|
4125
5945
|
t2=time.time()
|
|
4126
5946
|
print('Total computation %.2fs'%(t2-t1))
|
|
@@ -4128,7 +5948,7 @@ class funct(FOC.FoCUS):
|
|
|
4128
5948
|
if to_gaussian:
|
|
4129
5949
|
omap=self.from_gaussian(omap)
|
|
4130
5950
|
|
|
4131
|
-
if axis==0:
|
|
5951
|
+
if axis==0 and synthesised_N==1:
|
|
4132
5952
|
return omap[0]
|
|
4133
5953
|
else:
|
|
4134
5954
|
return omap
|