foscat 3.7.0__py3-none-any.whl → 3.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -92,7 +92,7 @@ class scat_cov:
92
92
  )
93
93
 
94
94
  def conv2complex(self, val):
95
- if val.dtype == "complex64" in val.dtype or "complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
95
+ if val.dtype == "complex64" or val.dtype=="complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
96
96
  return val
97
97
  else:
98
98
  return self.backend.bk_complex(val, 0 * val)
@@ -2542,6 +2542,1047 @@ class funct(FOC.FoCUS):
2542
2542
  """
2543
2543
  return_data = self.return_data
2544
2544
  # Check input consistency
2545
+ if image2 is not None:
2546
+ if list(image1.shape) != list(image2.shape):
2547
+ print(
2548
+ "The two input image should have the same size to eval Scattering Covariance"
2549
+ )
2550
+ return None
2551
+ if mask is not None:
2552
+ if self.use_2D:
2553
+ if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
2554
+ print(
2555
+ "The LAST 2 COLUMNs of the mask should have the same size ",
2556
+ mask.shape,
2557
+ "than the input image ",
2558
+ image1.shape,
2559
+ "to eval Scattering Covariance",
2560
+ )
2561
+ return None
2562
+ else:
2563
+ if image1.shape[-1] != mask.shape[1]:
2564
+ print(
2565
+ "The LAST COLUMN of the mask should have the same size ",
2566
+ mask.shape,
2567
+ "than the input image ",
2568
+ image1.shape,
2569
+ "to eval Scattering Covariance",
2570
+ )
2571
+ return None
2572
+ if self.use_2D and len(image1.shape) < 2:
2573
+ print(
2574
+ "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
2575
+ )
2576
+ return None
2577
+
2578
+ ### AUTO OR CROSS
2579
+ cross = False
2580
+ if image2 is not None:
2581
+ cross = True
2582
+
2583
+ ### PARAMETERS
2584
+ axis = 1
2585
+ # determine jmax and nside corresponding to the input map
2586
+ im_shape = image1.shape
2587
+ if self.use_2D:
2588
+ if len(image1.shape) == 2:
2589
+ nside = np.min([im_shape[0], im_shape[1]])
2590
+ npix = im_shape[0] * im_shape[1] # Number of pixels
2591
+ x1 = im_shape[0]
2592
+ x2 = im_shape[1]
2593
+ else:
2594
+ nside = np.min([im_shape[1], im_shape[2]])
2595
+ npix = im_shape[1] * im_shape[2] # Number of pixels
2596
+ x1 = im_shape[1]
2597
+ x2 = im_shape[2]
2598
+ J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
2599
+ elif self.use_1D:
2600
+ if len(image1.shape) == 2:
2601
+ npix = int(im_shape[1]) # Number of pixels
2602
+ else:
2603
+ npix = int(im_shape[0]) # Number of pixels
2604
+
2605
+ nside = int(npix)
2606
+
2607
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
2608
+ else:
2609
+ if len(image1.shape) == 2:
2610
+ npix = int(im_shape[1]) # Number of pixels
2611
+ else:
2612
+ npix = int(im_shape[0]) # Number of pixels
2613
+
2614
+ nside = int(np.sqrt(npix // 12))
2615
+
2616
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
2617
+
2618
+ if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
2619
+ J-=1
2620
+ if Jmax is None:
2621
+ Jmax = J # Number of steps for the loop on scales
2622
+ if Jmax>J:
2623
+ print('==========\n\n')
2624
+ print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
2625
+ print('\n\n==========')
2626
+
2627
+
2628
+ ### LOCAL VARIABLES (IMAGES and MASK)
2629
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
2630
+ I1 = self.backend.bk_cast(
2631
+ self.backend.bk_expand_dims(image1, 0)
2632
+ ) # Local image1 [Nbatch, Npix]
2633
+ if cross:
2634
+ I2 = self.backend.bk_cast(
2635
+ self.backend.bk_expand_dims(image2, 0)
2636
+ ) # Local image2 [Nbatch, Npix]
2637
+ else:
2638
+ I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
2639
+ if cross:
2640
+ I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
2641
+
2642
+ if mask is None:
2643
+ if self.use_2D:
2644
+ vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
2645
+ else:
2646
+ vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
2647
+ else:
2648
+ vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
2649
+
2650
+ if self.KERNELSZ > 3 and not self.use_2D:
2651
+ # if the kernel size is bigger than 3 increase the binning before smoothing
2652
+ if self.use_2D:
2653
+ vmask = self.up_grade(
2654
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2655
+ )
2656
+ I1 = self.up_grade(
2657
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2658
+ )
2659
+ if cross:
2660
+ I2 = self.up_grade(
2661
+ I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
2662
+ )
2663
+ elif self.use_1D:
2664
+ vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2665
+ I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2666
+ if cross:
2667
+ I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2668
+ else:
2669
+ I1 = self.up_grade(I1, nside * 2, axis=axis)
2670
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
2671
+ if cross:
2672
+ I2 = self.up_grade(I2, nside * 2, axis=axis)
2673
+
2674
+ if self.KERNELSZ > 5 and not self.use_2D:
2675
+ # if the kernel size is bigger than 3 increase the binning before smoothing
2676
+ if self.use_2D:
2677
+ vmask = self.up_grade(
2678
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2679
+ )
2680
+ I1 = self.up_grade(
2681
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2682
+ )
2683
+ if cross:
2684
+ I2 = self.up_grade(
2685
+ I2,
2686
+ I2.shape[axis] * 2,
2687
+ axis=axis,
2688
+ nouty=I2.shape[axis + 1] * 2,
2689
+ )
2690
+ elif self.use_1D:
2691
+ vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
2692
+ I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
2693
+ if cross:
2694
+ I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
2695
+ else:
2696
+ I1 = self.up_grade(I1, nside * 4, axis=axis)
2697
+ vmask = self.up_grade(vmask, nside * 4, axis=1)
2698
+ if cross:
2699
+ I2 = self.up_grade(I2, nside * 4, axis=axis)
2700
+
2701
+ # Normalize the masks because they have different pixel numbers
2702
+ # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
2703
+
2704
+ ### INITIALIZATION
2705
+ # Coefficients
2706
+ if return_data:
2707
+ S1 = {}
2708
+ S2 = {}
2709
+ S3 = {}
2710
+ S3P = {}
2711
+ S4 = {}
2712
+ else:
2713
+ S1 = []
2714
+ S2 = []
2715
+ S3 = []
2716
+ S4 = []
2717
+ S3P = []
2718
+ VS1 = []
2719
+ VS2 = []
2720
+ VS3 = []
2721
+ VS3P = []
2722
+ VS4 = []
2723
+
2724
+ off_S2 = -2
2725
+ off_S3 = -3
2726
+ off_S4 = -4
2727
+ if self.use_1D:
2728
+ off_S2 = -1
2729
+ off_S3 = -1
2730
+ off_S4 = -1
2731
+
2732
+ # S2 for normalization
2733
+ cond_init_P1_dic = (norm == "self") or (
2734
+ (norm == "auto") and (self.P1_dic is None)
2735
+ )
2736
+ if norm is None:
2737
+ pass
2738
+ elif cond_init_P1_dic:
2739
+ P1_dic = {}
2740
+ if cross:
2741
+ P2_dic = {}
2742
+ elif (norm == "auto") and (self.P1_dic is not None):
2743
+ P1_dic = self.P1_dic
2744
+ if cross:
2745
+ P2_dic = self.P2_dic
2746
+
2747
+ if return_data:
2748
+ s0 = I1
2749
+ if out_nside is not None:
2750
+ s0 = self.backend.bk_reduce_mean(
2751
+ self.backend.bk_reshape(
2752
+ s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
2753
+ ),
2754
+ 2,
2755
+ )
2756
+ else:
2757
+ if not cross:
2758
+ s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
2759
+ else:
2760
+ s0, l_vs0 = self.masked_mean(
2761
+ self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
2762
+ )
2763
+ vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2764
+ s0 = self.backend.bk_concat([s0, l_vs0], 1)
2765
+ #### COMPUTE S1, S2, S3 and S4
2766
+ nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2767
+
2768
+ # a remettre comme avant
2769
+ M1_dic={}
2770
+ M2_dic={}
2771
+
2772
+ for j3 in range(Jmax):
2773
+
2774
+ if edge:
2775
+ if self.mask_mask is None:
2776
+ self.mask_mask={}
2777
+ if self.use_2D:
2778
+ if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
2779
+ mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
2780
+ mask_mask[0,
2781
+ self.KERNELSZ//2:-self.KERNELSZ//2+1,
2782
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2783
+ self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
2784
+ vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
2785
+ #print(self.KERNELSZ//2,vmask,mask_mask)
2786
+
2787
+ if self.use_1D:
2788
+ if (vmask.shape[1]) not in self.mask_mask:
2789
+ mask_mask=np.zeros([1,vmask.shape[1]])
2790
+ mask_mask[0,
2791
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2792
+ self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
2793
+ vmask=vmask*self.mask_mask[(vmask.shape[1])]
2794
+
2795
+ if return_data:
2796
+ S3[j3] = None
2797
+ S3P[j3] = None
2798
+
2799
+ if S4 is None:
2800
+ S4 = {}
2801
+ S4[j3] = None
2802
+
2803
+ ####### S1 and S2
2804
+ ### Make the convolution I1 * Psi_j3
2805
+ conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
2806
+ if cmat is not None:
2807
+ tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2808
+ conv1 = self.backend.bk_reduce_sum(
2809
+ self.backend.bk_reshape(
2810
+ cmat[j3] * tmp2,
2811
+ [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
2812
+ ),
2813
+ 2,
2814
+ )
2815
+
2816
+ ### Take the module M1 = |I1 * Psi_j3|
2817
+ M1_square = conv1 * self.backend.bk_conjugate(
2818
+ conv1
2819
+ ) # [Nbatch, Npix_j3, Norient3]
2820
+ M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
2821
+ # Store M1_j3 in a dictionary
2822
+ M1_dic[j3] = M1
2823
+
2824
+ if not cross: # Auto
2825
+ M1_square = self.backend.bk_real(M1_square)
2826
+
2827
+ ### S2_auto = < M1^2 >_pix
2828
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2829
+ if return_data:
2830
+ s2 = M1_square
2831
+ else:
2832
+ if calc_var:
2833
+ s2, vs2 = self.masked_mean(
2834
+ M1_square, vmask, axis=1, rank=j3, calc_var=True
2835
+ )
2836
+ else:
2837
+ s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
2838
+
2839
+ if cond_init_P1_dic:
2840
+ # We fill P1_dic with S2 for normalisation of S3 and S4
2841
+ P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
2842
+
2843
+ # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
2844
+ if return_data:
2845
+ if S2 is None:
2846
+ S2 = {}
2847
+ if out_nside is not None and out_nside < nside_j3:
2848
+ s2 = self.backend.bk_reduce_mean(
2849
+ self.backend.bk_reshape(
2850
+ s2,
2851
+ [
2852
+ s2.shape[0],
2853
+ 12 * out_nside**2,
2854
+ (nside_j3 // out_nside) ** 2,
2855
+ s2.shape[2],
2856
+ ],
2857
+ ),
2858
+ 2,
2859
+ )
2860
+ S2[j3] = s2
2861
+ else:
2862
+ if norm == "auto": # Normalize S2
2863
+ s2 /= P1_dic[j3]
2864
+
2865
+ S2.append(
2866
+ self.backend.bk_expand_dims(s2, off_S2)
2867
+ ) # Add a dimension for NS2
2868
+ if calc_var:
2869
+ VS2.append(
2870
+ self.backend.bk_expand_dims(vs2, off_S2)
2871
+ ) # Add a dimension for NS2
2872
+
2873
+ #### S1_auto computation
2874
+ ### Image 1 : S1 = < M1 >_pix
2875
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2876
+ if return_data:
2877
+ s1 = M1
2878
+ else:
2879
+ if calc_var:
2880
+ s1, vs1 = self.masked_mean(
2881
+ M1, vmask, axis=1, rank=j3, calc_var=True
2882
+ ) # [Nbatch, Nmask, Norient3]
2883
+ else:
2884
+ s1 = self.masked_mean(
2885
+ M1, vmask, axis=1, rank=j3
2886
+ ) # [Nbatch, Nmask, Norient3]
2887
+
2888
+ if return_data:
2889
+ if out_nside is not None and out_nside < nside_j3:
2890
+ s1 = self.backend.bk_reduce_mean(
2891
+ self.backend.bk_reshape(
2892
+ s1,
2893
+ [
2894
+ s1.shape[0],
2895
+ 12 * out_nside**2,
2896
+ (nside_j3 // out_nside) ** 2,
2897
+ s1.shape[2],
2898
+ ],
2899
+ ),
2900
+ 2,
2901
+ )
2902
+ S1[j3] = s1
2903
+ else:
2904
+ ### Normalize S1
2905
+ if norm is not None:
2906
+ self.div_norm(s1, (P1_dic[j3]) ** 0.5)
2907
+ ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2908
+ S1.append(
2909
+ self.backend.bk_expand_dims(s1, off_S2)
2910
+ ) # Add a dimension for NS1
2911
+ if calc_var:
2912
+ VS1.append(
2913
+ self.backend.bk_expand_dims(vs1, off_S2)
2914
+ ) # Add a dimension for NS1
2915
+
2916
+ else: # Cross
2917
+ ### Make the convolution I2 * Psi_j3
2918
+ conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
2919
+ if cmat is not None:
2920
+ tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2921
+ conv2 = self.backend.bk_reduce_sum(
2922
+ self.backend.bk_reshape(
2923
+ cmat[j3] * tmp2,
2924
+ [
2925
+ tmp2.shape[0],
2926
+ cmat[j3].shape[0],
2927
+ self.NORIENT,
2928
+ self.NORIENT,
2929
+ ],
2930
+ ),
2931
+ 2,
2932
+ )
2933
+ ### Take the module M2 = |I2 * Psi_j3|
2934
+ M2_square = conv2 * self.backend.bk_conjugate(
2935
+ conv2
2936
+ ) # [Nbatch, Npix_j3, Norient3]
2937
+ M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
2938
+ # Store M2_j3 in a dictionary
2939
+ M2_dic[j3] = M2
2940
+
2941
+ ### S2_auto = < M2^2 >_pix
2942
+ # Not returned, only for normalization
2943
+ if cond_init_P1_dic:
2944
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2945
+ if return_data:
2946
+ p1 = M1_square
2947
+ p2 = M2_square
2948
+ else:
2949
+ if calc_var:
2950
+ p1, vp1 = self.masked_mean(
2951
+ M1_square, vmask, axis=1, rank=j3, calc_var=True
2952
+ ) # [Nbatch, Nmask, Norient3]
2953
+ p2, vp2 = self.masked_mean(
2954
+ M2_square, vmask, axis=1, rank=j3, calc_var=True
2955
+ ) # [Nbatch, Nmask, Norient3]
2956
+ else:
2957
+ p1 = self.masked_mean(
2958
+ M1_square, vmask, axis=1, rank=j3
2959
+ ) # [Nbatch, Nmask, Norient3]
2960
+ p2 = self.masked_mean(
2961
+ M2_square, vmask, axis=1, rank=j3
2962
+ ) # [Nbatch, Nmask, Norient3]
2963
+ # We fill P1_dic with S2 for normalisation of S3 and S4
2964
+ P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
2965
+ P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
2966
+
2967
+ ### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
2968
+ # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2969
+ s2 = conv1 * self.backend.bk_conjugate(conv2)
2970
+ MX = self.backend.bk_L1(s2)
2971
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2972
+ if return_data:
2973
+ s2 = s2
2974
+ else:
2975
+ if calc_var:
2976
+ s2, vs2 = self.masked_mean(
2977
+ s2, vmask, axis=1, rank=j3, calc_var=True
2978
+ )
2979
+ else:
2980
+ s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
2981
+
2982
+ if return_data:
2983
+ if out_nside is not None and out_nside < nside_j3:
2984
+ s2 = self.backend.bk_reduce_mean(
2985
+ self.backend.bk_reshape(
2986
+ s2,
2987
+ [
2988
+ s2.shape[0],
2989
+ 12 * out_nside**2,
2990
+ (nside_j3 // out_nside) ** 2,
2991
+ s2.shape[2],
2992
+ ],
2993
+ ),
2994
+ 2,
2995
+ )
2996
+ S2[j3] = s2
2997
+ else:
2998
+ ### Normalize S2_cross
2999
+ if norm == "auto":
3000
+ s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
3001
+
3002
+ ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
3003
+ s2 = self.backend.bk_real(s2)
3004
+
3005
+ S2.append(
3006
+ self.backend.bk_expand_dims(s2, off_S2)
3007
+ ) # Add a dimension for NS2
3008
+ if calc_var:
3009
+ VS2.append(
3010
+ self.backend.bk_expand_dims(vs2, off_S2)
3011
+ ) # Add a dimension for NS2
3012
+
3013
+ #### S1_auto computation
3014
+ ### Image 1 : S1 = < M1 >_pix
3015
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
3016
+ if return_data:
3017
+ s1 = MX
3018
+ else:
3019
+ if calc_var:
3020
+ s1, vs1 = self.masked_mean(
3021
+ MX, vmask, axis=1, rank=j3, calc_var=True
3022
+ ) # [Nbatch, Nmask, Norient3]
3023
+ else:
3024
+ s1 = self.masked_mean(
3025
+ MX, vmask, axis=1, rank=j3
3026
+ ) # [Nbatch, Nmask, Norient3]
3027
+ if return_data:
3028
+ if out_nside is not None and out_nside < nside_j3:
3029
+ s1 = self.backend.bk_reduce_mean(
3030
+ self.backend.bk_reshape(
3031
+ s1,
3032
+ [
3033
+ s1.shape[0],
3034
+ 12 * out_nside**2,
3035
+ (nside_j3 // out_nside) ** 2,
3036
+ s1.shape[2],
3037
+ ],
3038
+ ),
3039
+ 2,
3040
+ )
3041
+ S1[j3] = s1
3042
+ else:
3043
+ ### Normalize S1
3044
+ if norm is not None:
3045
+ self.div_norm(s1, (P1_dic[j3]) ** 0.5)
3046
+ ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
3047
+ S1.append(
3048
+ self.backend.bk_expand_dims(s1, off_S2)
3049
+ ) # Add a dimension for NS1
3050
+ if calc_var:
3051
+ VS1.append(
3052
+ self.backend.bk_expand_dims(vs1, off_S2)
3053
+ ) # Add a dimension for NS1
3054
+
3055
+ # Initialize dictionaries for |I1*Psi_j| * Psi_j3
3056
+ M1convPsi_dic = {}
3057
+ if cross:
3058
+ # Initialize dictionaries for |I2*Psi_j| * Psi_j3
3059
+ M2convPsi_dic = {}
3060
+
3061
+ ###### S3
3062
+ nside_j2 = nside_j3
3063
+ for j2 in range(0, j3 + 1): # j2 <= j3
3064
+ if return_data:
3065
+ if S4[j3] is None:
3066
+ S4[j3] = {}
3067
+ S4[j3][j2] = None
3068
+
3069
+ ### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
3070
+ if not cross:
3071
+ if calc_var:
3072
+ s3, vs3 = self._compute_S3(
3073
+ j2,
3074
+ j3,
3075
+ conv1,
3076
+ vmask,
3077
+ M1_dic,
3078
+ M1convPsi_dic,
3079
+ calc_var=True,
3080
+ cmat2=cmat2,
3081
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3082
+ else:
3083
+ s3 = self._compute_S3(
3084
+ j2,
3085
+ j3,
3086
+ conv1,
3087
+ vmask,
3088
+ M1_dic,
3089
+ M1convPsi_dic,
3090
+ return_data=return_data,
3091
+ cmat2=cmat2,
3092
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3093
+
3094
+ if return_data:
3095
+ if S3[j3] is None:
3096
+ S3[j3] = {}
3097
+ if out_nside is not None and out_nside < nside_j2:
3098
+ s3 = self.backend.bk_reduce_mean(
3099
+ self.backend.bk_reshape(
3100
+ s3,
3101
+ [
3102
+ s3.shape[0],
3103
+ 12 * out_nside**2,
3104
+ (nside_j2 // out_nside) ** 2,
3105
+ s3.shape[2],
3106
+ s3.shape[3],
3107
+ ],
3108
+ ),
3109
+ 2,
3110
+ )
3111
+ S3[j3][j2] = s3
3112
+ else:
3113
+ ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
3114
+ if norm is not None:
3115
+ self.div_norm(
3116
+ s3,
3117
+ (
3118
+ self.backend.bk_expand_dims(P1_dic[j2], off_S2)
3119
+ * self.backend.bk_expand_dims(P1_dic[j3], -1)
3120
+ )
3121
+ ** 0.5,
3122
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3123
+
3124
+ ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3125
+
3126
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3127
+ # s3.shape[2]*s3.shape[3]]))
3128
+ S3.append(
3129
+ self.backend.bk_expand_dims(s3, off_S3)
3130
+ ) # Add a dimension for NS3
3131
+ if calc_var:
3132
+ VS3.append(
3133
+ self.backend.bk_expand_dims(vs3, off_S3)
3134
+ ) # Add a dimension for NS3
3135
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3136
+ # s3.shape[2]*s3.shape[3]]))
3137
+
3138
+ ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
3139
+ ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
3140
+ else:
3141
+ if calc_var:
3142
+ s3, vs3 = self._compute_S3(
3143
+ j2,
3144
+ j3,
3145
+ conv1,
3146
+ vmask,
3147
+ M2_dic,
3148
+ M2convPsi_dic,
3149
+ calc_var=True,
3150
+ cmat2=cmat2,
3151
+ )
3152
+ s3p, vs3p = self._compute_S3(
3153
+ j2,
3154
+ j3,
3155
+ conv2,
3156
+ vmask,
3157
+ M1_dic,
3158
+ M1convPsi_dic,
3159
+ calc_var=True,
3160
+ cmat2=cmat2,
3161
+ )
3162
+ else:
3163
+ s3 = self._compute_S3(
3164
+ j2,
3165
+ j3,
3166
+ conv1,
3167
+ vmask,
3168
+ M2_dic,
3169
+ M2convPsi_dic,
3170
+ return_data=return_data,
3171
+ cmat2=cmat2,
3172
+ )
3173
+ s3p = self._compute_S3(
3174
+ j2,
3175
+ j3,
3176
+ conv2,
3177
+ vmask,
3178
+ M1_dic,
3179
+ M1convPsi_dic,
3180
+ return_data=return_data,
3181
+ cmat2=cmat2,
3182
+ )
3183
+
3184
+ if return_data:
3185
+ if S3[j3] is None:
3186
+ S3[j3] = {}
3187
+ S3P[j3] = {}
3188
+ if out_nside is not None and out_nside < nside_j2:
3189
+ s3 = self.backend.bk_reduce_mean(
3190
+ self.backend.bk_reshape(
3191
+ s3,
3192
+ [
3193
+ s3.shape[0],
3194
+ 12 * out_nside**2,
3195
+ (nside_j2 // out_nside) ** 2,
3196
+ s3.shape[2],
3197
+ s3.shape[3],
3198
+ ],
3199
+ ),
3200
+ 2,
3201
+ )
3202
+ s3p = self.backend.bk_reduce_mean(
3203
+ self.backend.bk_reshape(
3204
+ s3p,
3205
+ [
3206
+ s3.shape[0],
3207
+ 12 * out_nside**2,
3208
+ (nside_j2 // out_nside) ** 2,
3209
+ s3.shape[2],
3210
+ s3.shape[3],
3211
+ ],
3212
+ ),
3213
+ 2,
3214
+ )
3215
+ S3[j3][j2] = s3
3216
+ S3P[j3][j2] = s3p
3217
+ else:
3218
+ ### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
3219
+ if norm is not None:
3220
+ self.div_norm(
3221
+ s3,
3222
+ (
3223
+ self.backend.bk_expand_dims(P2_dic[j2], off_S2)
3224
+ * self.backend.bk_expand_dims(P1_dic[j3], -1)
3225
+ )
3226
+ ** 0.5,
3227
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3228
+ self.div_norm(
3229
+ s3p,
3230
+ (
3231
+ self.backend.bk_expand_dims(P1_dic[j2], off_S2)
3232
+ * self.backend.bk_expand_dims(P2_dic[j3], -1)
3233
+ )
3234
+ ** 0.5,
3235
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3236
+
3237
+ ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3238
+
3239
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3240
+ # s3.shape[2]*s3.shape[3]]))
3241
+ S3.append(
3242
+ self.backend.bk_expand_dims(s3, off_S3)
3243
+ ) # Add a dimension for NS3
3244
+ if calc_var:
3245
+ VS3.append(
3246
+ self.backend.bk_expand_dims(vs3, off_S3)
3247
+ ) # Add a dimension for NS3
3248
+
3249
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3250
+ # s3.shape[2]*s3.shape[3]]))
3251
+
3252
+ # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
3253
+ # s3.shape[2]*s3.shape[3]]))
3254
+ S3P.append(
3255
+ self.backend.bk_expand_dims(s3p, off_S3)
3256
+ ) # Add a dimension for NS3
3257
+ if calc_var:
3258
+ VS3P.append(
3259
+ self.backend.bk_expand_dims(vs3p, off_S3)
3260
+ ) # Add a dimension for NS3
3261
+ # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
3262
+ # s3.shape[2]*s3.shape[3]]))
3263
+
3264
+ ##### S4
3265
+ nside_j1 = nside_j2
3266
+ for j1 in range(0, j2 + 1): # j1 <= j2
3267
+ ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
3268
+ if not cross:
3269
+ if calc_var:
3270
+ s4, vs4 = self._compute_S4(
3271
+ j1,
3272
+ j2,
3273
+ vmask,
3274
+ M1convPsi_dic,
3275
+ M2convPsi_dic=None,
3276
+ calc_var=True,
3277
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3278
+ else:
3279
+ s4 = self._compute_S4(
3280
+ j1,
3281
+ j2,
3282
+ vmask,
3283
+ M1convPsi_dic,
3284
+ M2convPsi_dic=None,
3285
+ return_data=return_data,
3286
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3287
+
3288
+ if return_data:
3289
+ if S4[j3][j2] is None:
3290
+ S4[j3][j2] = {}
3291
+ if out_nside is not None and out_nside < nside_j1:
3292
+ s4 = self.backend.bk_reduce_mean(
3293
+ self.backend.bk_reshape(
3294
+ s4,
3295
+ [
3296
+ s4.shape[0],
3297
+ 12 * out_nside**2,
3298
+ (nside_j1 // out_nside) ** 2,
3299
+ s4.shape[2],
3300
+ s4.shape[3],
3301
+ s4.shape[4],
3302
+ ],
3303
+ ),
3304
+ 2,
3305
+ )
3306
+ S4[j3][j2][j1] = s4
3307
+ else:
3308
+ ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
3309
+ if norm is not None:
3310
+ self.div_norm(
3311
+ s4,
3312
+ (
3313
+ self.backend.bk_expand_dims(
3314
+ self.backend.bk_expand_dims(
3315
+ P1_dic[j1], off_S2
3316
+ ),
3317
+ off_S2,
3318
+ )
3319
+ * self.backend.bk_expand_dims(
3320
+ self.backend.bk_expand_dims(
3321
+ P1_dic[j2], off_S2
3322
+ ),
3323
+ -1,
3324
+ )
3325
+ )
3326
+ ** 0.5,
3327
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3328
+ ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3329
+
3330
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3331
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3332
+ S4.append(
3333
+ self.backend.bk_expand_dims(s4, off_S4)
3334
+ ) # Add a dimension for NS4
3335
+ if calc_var:
3336
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3337
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3338
+ VS4.append(
3339
+ self.backend.bk_expand_dims(vs4, off_S4)
3340
+ ) # Add a dimension for NS4
3341
+
3342
+ ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
3343
+ else:
3344
+ if calc_var:
3345
+ s4, vs4 = self._compute_S4(
3346
+ j1,
3347
+ j2,
3348
+ vmask,
3349
+ M1convPsi_dic,
3350
+ M2convPsi_dic=M2convPsi_dic,
3351
+ calc_var=True,
3352
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3353
+ else:
3354
+ s4 = self._compute_S4(
3355
+ j1,
3356
+ j2,
3357
+ vmask,
3358
+ M1convPsi_dic,
3359
+ M2convPsi_dic=M2convPsi_dic,
3360
+ return_data=return_data,
3361
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3362
+
3363
+ if return_data:
3364
+ if S4[j3][j2] is None:
3365
+ S4[j3][j2] = {}
3366
+ if out_nside is not None and out_nside < nside_j1:
3367
+ s4 = self.backend.bk_reduce_mean(
3368
+ self.backend.bk_reshape(
3369
+ s4,
3370
+ [
3371
+ s4.shape[0],
3372
+ 12 * out_nside**2,
3373
+ (nside_j1 // out_nside) ** 2,
3374
+ s4.shape[2],
3375
+ s4.shape[3],
3376
+ s4.shape[4],
3377
+ ],
3378
+ ),
3379
+ 2,
3380
+ )
3381
+ S4[j3][j2][j1] = s4
3382
+ else:
3383
+ ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
3384
+ if norm is not None:
3385
+ self.div_norm(
3386
+ s4,
3387
+ (
3388
+ self.backend.bk_expand_dims(
3389
+ self.backend.bk_expand_dims(
3390
+ P1_dic[j1], off_S2
3391
+ ),
3392
+ off_S2,
3393
+ )
3394
+ * self.backend.bk_expand_dims(
3395
+ self.backend.bk_expand_dims(
3396
+ P2_dic[j2], off_S2
3397
+ ),
3398
+ -1,
3399
+ )
3400
+ )
3401
+ ** 0.5,
3402
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3403
+ ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3404
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3405
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3406
+ S4.append(
3407
+ self.backend.bk_expand_dims(s4, off_S4)
3408
+ ) # Add a dimension for NS4
3409
+ if calc_var:
3410
+
3411
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3412
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3413
+ VS4.append(
3414
+ self.backend.bk_expand_dims(vs4, off_S4)
3415
+ ) # Add a dimension for NS4
3416
+
3417
+ nside_j1 = nside_j1 // 2
3418
+ nside_j2 = nside_j2 // 2
3419
+
3420
+ ###### Reshape for next iteration on j3
3421
+ ### Image I1,
3422
+ # downscale the I1 [Nbatch, Npix_j3]
3423
+ if j3 != Jmax - 1:
3424
+ I1 = self.smooth(I1, axis=1)
3425
+ I1 = self.ud_grade_2(I1, axis=1)
3426
+
3427
+ ### Image I2
3428
+ if cross:
3429
+ I2 = self.smooth(I2, axis=1)
3430
+ I2 = self.ud_grade_2(I2, axis=1)
3431
+
3432
+ ### Modules
3433
+ for j2 in range(0, j3 + 1): # j2 =< j3
3434
+ ### Dictionary M1_dic[j2]
3435
+ M1_smooth = self.smooth(
3436
+ M1_dic[j2], axis=1
3437
+ ) # [Nbatch, Npix_j3, Norient3]
3438
+ M1_dic[j2] = self.ud_grade_2(
3439
+ M1_smooth, axis=1
3440
+ ) # [Nbatch, Npix_j3, Norient3]
3441
+
3442
+ ### Dictionary M2_dic[j2]
3443
+ if cross:
3444
+ M2_smooth = self.smooth(
3445
+ M2_dic[j2], axis=1
3446
+ ) # [Nbatch, Npix_j3, Norient3]
3447
+ M2_dic[j2] = self.ud_grade_2(
3448
+ M2, axis=1
3449
+ ) # [Nbatch, Npix_j3, Norient3]
3450
+
3451
+ ### Mask
3452
+ vmask = self.ud_grade_2(vmask, axis=1)
3453
+
3454
+ if self.mask_thres is not None:
3455
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
3456
+
3457
+ ### NSIDE_j3
3458
+ nside_j3 = nside_j3 // 2
3459
+
3460
+ ### Store P1_dic and P2_dic in self
3461
+ if (norm == "auto") and (self.P1_dic is None):
3462
+ self.P1_dic = P1_dic
3463
+ if cross:
3464
+ self.P2_dic = P2_dic
3465
+ """
3466
+ Sout=[s0]+S1+S2+S3+S4
3467
+
3468
+ if cross:
3469
+ Sout=Sout+S3P
3470
+ if calc_var:
3471
+ SVout=[vs0]+VS1+VS2+VS3+VS4
3472
+ if cross:
3473
+ VSout=VSout+VS3P
3474
+ return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
3475
+
3476
+ return self.backend.bk_concat(Sout, 2)
3477
+ """
3478
+ if not return_data:
3479
+ S1 = self.backend.bk_concat(S1, 2)
3480
+ S2 = self.backend.bk_concat(S2, 2)
3481
+ S3 = self.backend.bk_concat(S3, 2)
3482
+ S4 = self.backend.bk_concat(S4, 2)
3483
+ if cross:
3484
+ S3P = self.backend.bk_concat(S3P, 2)
3485
+ if calc_var:
3486
+ VS1 = self.backend.bk_concat(VS1, 2)
3487
+ VS2 = self.backend.bk_concat(VS2, 2)
3488
+ VS3 = self.backend.bk_concat(VS3, 2)
3489
+ VS4 = self.backend.bk_concat(VS4, 2)
3490
+ if cross:
3491
+ VS3P = self.backend.bk_concat(VS3P, 2)
3492
+ if calc_var:
3493
+ if not cross:
3494
+ return scat_cov(
3495
+ s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
3496
+ ), scat_cov(
3497
+ vs0,
3498
+ VS2,
3499
+ VS3,
3500
+ VS4,
3501
+ s1=VS1,
3502
+ backend=self.backend,
3503
+ use_1D=self.use_1D,
3504
+ )
3505
+ else:
3506
+ return scat_cov(
3507
+ s0,
3508
+ S2,
3509
+ S3,
3510
+ S4,
3511
+ s1=S1,
3512
+ s3p=S3P,
3513
+ backend=self.backend,
3514
+ use_1D=self.use_1D,
3515
+ ), scat_cov(
3516
+ vs0,
3517
+ VS2,
3518
+ VS3,
3519
+ VS4,
3520
+ s1=VS1,
3521
+ s3p=VS3P,
3522
+ backend=self.backend,
3523
+ use_1D=self.use_1D,
3524
+ )
3525
+ else:
3526
+ if not cross:
3527
+ return scat_cov(
3528
+ s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
3529
+ )
3530
+ else:
3531
+ return scat_cov(
3532
+ s0,
3533
+ S2,
3534
+ S3,
3535
+ S4,
3536
+ s1=S1,
3537
+ s3p=S3P,
3538
+ backend=self.backend,
3539
+ use_1D=self.use_1D,
3540
+ )
3541
+
3542
+ def eval_new(
3543
+ self,
3544
+ image1,
3545
+ image2=None,
3546
+ mask=None,
3547
+ norm=None,
3548
+ calc_var=False,
3549
+ cmat=None,
3550
+ cmat2=None,
3551
+ Jmax=None,
3552
+ out_nside=None,
3553
+ edge=True
3554
+ ):
3555
+ """
3556
+ Calculates the scattering correlations for a batch of images. Mean are done over pixels.
3557
+ mean of modulus:
3558
+ S1 = <|I * Psi_j3|>
3559
+ Normalization : take the log
3560
+ power spectrum:
3561
+ S2 = <|I * Psi_j3|^2>
3562
+ Normalization : take the log
3563
+ orig. x modulus:
3564
+ S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
3565
+ Normalization : divide by (S2_j2 * S2_j3)^0.5
3566
+ modulus x modulus:
3567
+ S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
3568
+ Normalization : divide by (S2_j1 * S2_j2)^0.5
3569
+ Parameters
3570
+ ----------
3571
+ image1: tensor
3572
+ Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
3573
+ image2: tensor
3574
+ Second image. If not None, we compute cross-scattering covariance coefficients.
3575
+ mask:
3576
+ norm: None or str
3577
+ If None no normalization is applied, if 'auto' normalize by the reference S2,
3578
+ if 'self' normalize by the current S2.
3579
+ Returns
3580
+ -------
3581
+ S1, S2, S3, S4 normalized
3582
+ """
3583
+ return_data = self.return_data
3584
+ NORIENT=self.NORIENT
3585
+ # Check input consistency
2545
3586
  if image2 is not None:
2546
3587
  if list(image1.shape) != list(image2.shape):
2547
3588
  print(
@@ -2699,13 +3740,19 @@ class funct(FOC.FoCUS):
2699
3740
  S3P = {}
2700
3741
  S4 = {}
2701
3742
  else:
2702
- S1 = []
2703
- S2 = []
3743
+ result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3744
+ dtype=self.backend.backend.float32,
3745
+ device=self.backend.torch_device)
3746
+ vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3747
+ dtype=self.backend.backend.float32,
3748
+ device=self.backend.torch_device)
3749
+ S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3750
+ S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
2704
3751
  S3 = []
2705
3752
  S4 = []
2706
3753
  S3P = []
2707
- VS1 = []
2708
- VS2 = []
3754
+ VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3755
+ VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
2709
3756
  VS3 = []
2710
3757
  VS3P = []
2711
3758
  VS4 = []
@@ -2749,14 +3796,20 @@ class funct(FOC.FoCUS):
2749
3796
  s0, l_vs0 = self.masked_mean(
2750
3797
  self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
2751
3798
  )
2752
- vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2753
- s0 = self.backend.bk_concat([s0, l_vs0], 1)
3799
+ #vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
3800
+ #s0 = self.backend.bk_concat([s0, l_vs0], 1)
3801
+ result[:,:,0]=s0
3802
+ result[:,:,1]=l_vs0
3803
+ vresult[:,:,0]=l_vs0
3804
+ vresult[:,:,1]=l_vs0
2754
3805
  #### COMPUTE S1, S2, S3 and S4
2755
3806
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2756
3807
 
2757
3808
  # a remettre comme avant
2758
3809
  M1_dic={}
2759
3810
 
3811
+ M2_dic={}
3812
+
2760
3813
  for j3 in range(Jmax):
2761
3814
 
2762
3815
  if edge:
@@ -2805,7 +3858,9 @@ class funct(FOC.FoCUS):
2805
3858
  M1_square = conv1 * self.backend.bk_conjugate(
2806
3859
  conv1
2807
3860
  ) # [Nbatch, Npix_j3, Norient3]
3861
+
2808
3862
  M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
3863
+
2809
3864
  # Store M1_j3 in a dictionary
2810
3865
  M1_dic[j3] = M1
2811
3866
 
@@ -2821,13 +3876,15 @@ class funct(FOC.FoCUS):
2821
3876
  s2, vs2 = self.masked_mean(
2822
3877
  M1_square, vmask, axis=1, rank=j3, calc_var=True
2823
3878
  )
3879
+ #s2=self.backend.bk_flatten(self.backend.bk_real(s2))
3880
+ #vs2=self.backend.bk_flatten(vs2)
2824
3881
  else:
2825
3882
  s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
2826
3883
 
2827
3884
  if cond_init_P1_dic:
2828
3885
  # We fill P1_dic with S2 for normalisation of S3 and S4
2829
- P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
2830
-
3886
+ P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
3887
+
2831
3888
  # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
2832
3889
  if return_data:
2833
3890
  if S2 is None:
@@ -2849,7 +3906,7 @@ class funct(FOC.FoCUS):
2849
3906
  else:
2850
3907
  if norm == "auto": # Normalize S2
2851
3908
  s2 /= P1_dic[j3]
2852
-
3909
+ """
2853
3910
  S2.append(
2854
3911
  self.backend.bk_expand_dims(s2, off_S2)
2855
3912
  ) # Add a dimension for NS2
@@ -2857,7 +3914,11 @@ class funct(FOC.FoCUS):
2857
3914
  VS2.append(
2858
3915
  self.backend.bk_expand_dims(vs2, off_S2)
2859
3916
  ) # Add a dimension for NS2
2860
-
3917
+ """
3918
+ #print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
3919
+ result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
3920
+ if calc_var:
3921
+ vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
2861
3922
  #### S1_auto computation
2862
3923
  ### Image 1 : S1 = < M1 >_pix
2863
3924
  # Apply the mask [Nmask, Npix_j3] and average over pixels
@@ -2868,10 +3929,13 @@ class funct(FOC.FoCUS):
2868
3929
  s1, vs1 = self.masked_mean(
2869
3930
  M1, vmask, axis=1, rank=j3, calc_var=True
2870
3931
  ) # [Nbatch, Nmask, Norient3]
3932
+ #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
3933
+ #vs1=self.backend.bk_flatten(vs1)
2871
3934
  else:
2872
3935
  s1 = self.masked_mean(
2873
3936
  M1, vmask, axis=1, rank=j3
2874
3937
  ) # [Nbatch, Nmask, Norient3]
3938
+ #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
2875
3939
 
2876
3940
  if return_data:
2877
3941
  if out_nside is not None and out_nside < nside_j3:
@@ -2892,6 +3956,10 @@ class funct(FOC.FoCUS):
2892
3956
  ### Normalize S1
2893
3957
  if norm is not None:
2894
3958
  self.div_norm(s1, (P1_dic[j3]) ** 0.5)
3959
+ result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
3960
+ if calc_var:
3961
+ vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
3962
+ """
2895
3963
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2896
3964
  S1.append(
2897
3965
  self.backend.bk_expand_dims(s1, off_S2)
@@ -2900,6 +3968,7 @@ class funct(FOC.FoCUS):
2900
3968
  VS1.append(
2901
3969
  self.backend.bk_expand_dims(vs1, off_S2)
2902
3970
  ) # Add a dimension for NS1
3971
+ """
2903
3972
 
2904
3973
  else: # Cross
2905
3974
  ### Make the convolution I2 * Psi_j3
@@ -3048,7 +4117,7 @@ class funct(FOC.FoCUS):
3048
4117
 
3049
4118
  ###### S3
3050
4119
  nside_j2 = nside_j3
3051
- for j2 in range(0, j3 + 1): # j2 <= j3
4120
+ for j2 in range(0,-1): # j3 + 1): # j2 <= j3
3052
4121
  if return_data:
3053
4122
  if S4[j3] is None:
3054
4123
  S4[j3] = {}
@@ -3463,6 +4532,19 @@ class funct(FOC.FoCUS):
3463
4532
 
3464
4533
  return self.backend.bk_concat(Sout, 2)
3465
4534
  """
4535
+ if calc_var:
4536
+ return result,vresult
4537
+ else:
4538
+ return result
4539
+ if calc_var:
4540
+ for k in S1:
4541
+ print(k.shape,k.dtype)
4542
+ for k in S2:
4543
+ print(k.shape,k.dtype)
4544
+ print(s0.shape,s0.dtype)
4545
+ return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
4546
+ else:
4547
+ return self.backend.bk_concat([s0]+S1+S2,axis=1)
3466
4548
  if not return_data:
3467
4549
  S1 = self.backend.bk_concat(S1, 2)
3468
4550
  S2 = self.backend.bk_concat(S2, 2)
@@ -3526,7 +4608,6 @@ class funct(FOC.FoCUS):
3526
4608
  backend=self.backend,
3527
4609
  use_1D=self.use_1D,
3528
4610
  )
3529
-
3530
4611
  def clean_norm(self):
3531
4612
  self.P1_dic = None
3532
4613
  self.P2_dic = None
@@ -3644,7 +4725,570 @@ class funct(FOC.FoCUS):
3644
4725
  s4, vmask, axis=1, rank=j2
3645
4726
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3646
4727
  return s4
4728
+
4729
+ def computer_filter(self,M,N,J,L):
4730
+ '''
4731
+ This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4732
+ Done by Sihao Cheng and Rudy Morel.
4733
+ '''
4734
+
4735
+ filter = np.zeros([J, L, M, N],dtype='complex64')
4736
+
4737
+ slant=4.0 / L
4738
+
4739
+ for j in range(J):
4740
+
4741
+ for l in range(L):
4742
+
4743
+ theta = (int(L-L/2-1)-l) * np.pi / L
4744
+ sigma = 0.8 * 2**j
4745
+ xi = 3.0 / 4.0 * np.pi /2**j
4746
+
4747
+ R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
4748
+ R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
4749
+ D = np.array([[1, 0], [0, slant * slant]])
4750
+ curv = np.matmul(R, np.matmul(D, R_inv)) / ( 2 * sigma * sigma)
4751
+
4752
+ gab = np.zeros((M, N), np.complex128)
4753
+ xx = np.empty((2,2, M, N))
4754
+ yy = np.empty((2,2, M, N))
4755
+
4756
+ for ii, ex in enumerate([-1, 0]):
4757
+ for jj, ey in enumerate([-1, 0]):
4758
+ xx[ii,jj], yy[ii,jj] = np.mgrid[
4759
+ ex * M : M + ex * M,
4760
+ ey * N : N + ey * N]
4761
+
4762
+ arg = -(curv[0, 0] * xx * xx + (curv[0, 1] + curv[1, 0]) * xx * yy + curv[1, 1] * yy * yy)
4763
+ argi = arg + 1.j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
4764
+
4765
+ gabi = np.exp(argi).sum((0,1))
4766
+ gab = np.exp(arg).sum((0,1))
4767
+
4768
+ norm_factor = 2 * np.pi * sigma * sigma / slant
4769
+
4770
+ gab = gab / norm_factor
4771
+
4772
+ gabi = gabi / norm_factor
4773
+
4774
+ K = gabi.sum() / gab.sum()
4775
+
4776
+ # Apply the Gaussian
4777
+ filter[j, l] = np.fft.fft2(gabi-K*gab)
4778
+ filter[j,l,0,0]=0.0
4779
+
4780
+ return self.backend.bk_cast(filter)
4781
+
4782
+ # ------------------------------------------------------------------------------------------
4783
+ #
4784
+ # utility functions
4785
+ #
4786
+ # ------------------------------------------------------------------------------------------
4787
+ def cut_high_k_off(self,data_f, dx, dy):
4788
+ '''
4789
+ This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4790
+ Done by Sihao Cheng and Rudy Morel.
4791
+ '''
4792
+ if_xodd = (data_f.shape[-2]%2==1)
4793
+ if_yodd = (data_f.shape[-1]%2==1)
4794
+ result = self.backend.backend.cat(
4795
+ (self.backend.backend.cat(
4796
+ ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4797
+ ), -2),
4798
+ self.backend.backend.cat(
4799
+ ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4800
+ ), -2)
4801
+ ),-1)
4802
+ return result
4803
+ # ---------------------------------------------------------------------------
4804
+ #
4805
+ # utility functions for computing scattering coef and covariance
4806
+ #
4807
+ # ---------------------------------------------------------------------------
4808
+
4809
+ def get_dxdy(self, j,M,N):
4810
+ '''
4811
+ This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4812
+ Done by Sihao Cheng and Rudy Morel.
4813
+ '''
4814
+ dx = int(max( 8, min( np.ceil(M/2**j), M//2 ) ))
4815
+ dy = int(max( 8, min( np.ceil(N/2**j), N//2 ) ))
4816
+ return dx, dy
4817
+
4818
+
4819
+
4820
+ def get_edge_masks(self,M, N, J, d0=1):
4821
+ '''
4822
+ This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4823
+ Done by Sihao Cheng and Rudy Morel.
4824
+ '''
4825
+ edge_masks = self.backend.backend.empty((J, M, N))
4826
+ X, Y = self.backend.backend.meshgrid(self.backend.backend.arange(M), self.backend.backend.arange(N), indexing='ij')
4827
+ for j in range(J):
4828
+ edge_dx = min(M//4, 2**j*d0)
4829
+ edge_dy = min(N//4, 2**j*d0)
4830
+ edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
4831
+ return edge_masks.to(self.backend.torch_device)
4832
+
4833
+ # ---------------------------------------------------------------------------
4834
+ #
4835
+ # scattering cov
4836
+ #
4837
+ # ---------------------------------------------------------------------------
4838
+ def scattering_cov(
4839
+ self, data, Jmax=None,
4840
+ if_large_batch=False,
4841
+ S4_criteria=None,
4842
+ use_ref=False,
4843
+ normalization='S2',
4844
+ edge=False,
4845
+ pseudo_coef=1,
4846
+ get_variance=False,
4847
+ ref_sigma=None,
4848
+ iso_ang=False
4849
+ ):
4850
+ '''
4851
+ Calculates the scattering correlations for a batch of images, including:
4852
+
4853
+ This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
4854
+ Done by Sihao Cheng and Rudy Morel.
4855
+
4856
+ orig. x orig.:
4857
+ P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
4858
+ orig. x modulus:
4859
+ C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
4860
+ when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
4861
+ when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
4862
+ modulus x modulus:
4863
+ C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
4864
+ C11 = C11_pre_norm / factor
4865
+ when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
4866
+ when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
4867
+ modulus x modulus (auto):
4868
+ P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
4869
+ Parameters
4870
+ ----------
4871
+ data : numpy array or torch tensor
4872
+ image set, with size [N_image, x-sidelength, y-sidelength]
4873
+ if_large_batch : Bool (=False)
4874
+ It is recommended to use "False" unless one meets a memory issue
4875
+ C11_criteria : str or None (=None)
4876
+ Only C11 coefficients that satisfy this criteria will be computed.
4877
+ Any expressions of j1, j2, and j3 that can be evaluated as a Bool
4878
+ is accepted.The default "None" corresponds to "j1 <= j2 <= j3".
4879
+ use_ref : Bool (=False)
4880
+ When normalizing, whether or not to use the normalization factor
4881
+ computed from a reference field. For just computing the statistics,
4882
+ the default is False. However, for synthesis, set it to "True" will
4883
+ stablize the optimization process.
4884
+ normalization : str 'P00' or 'P11' (='P00')
4885
+ Whether 'P00' or 'P11' is used as the normalization factor for C01
4886
+ and C11.
4887
+ remove_edge : Bool (=False)
4888
+ If true, the edge region with a width of rougly the size of the largest
4889
+ wavelet involved is excluded when taking the global average to obtain
4890
+ the scattering coefficients.
4891
+
4892
+ Returns
4893
+ -------
4894
+ 'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
4895
+ the power in each wavelet bands (the orig. x orig. term)
4896
+ 'S1' : torch tensor with size [N_image, J, L] (# image, j1, l1)
4897
+ the 1st-order scattering coefficients, i.e., the mean of wavelet modulus fields
4898
+ 'C01' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
4899
+ the orig. x modulus terms. Elements with j1 < j2 are all set to np.nan and not computed.
4900
+ 'C11' : torch tensor with size [N_image, J, J, J, L, L, L] (# image, j1, j2, j3, l1, l2, l3)
4901
+ the modulus x modulus terms. Elements not satisfying j1 <= j2 <= j3 and the conditions
4902
+ defined in 'C11_criteria' are all set to np.nan and not computed.
4903
+ 'C11_pre_norm' and 'C11_pre_norm_iso': pre-normalized modulus x modulus terms.
4904
+ 'P11' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
4905
+ the modulus x modulus terms with the two wavelets within modulus the same. Elements not following
4906
+ j1 <= j3 are set to np.nan and not computed.
4907
+ 'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
4908
+ 'P11' averaged over l1 while keeping l2-l1 constant.
4909
+ '''
4910
+ if S4_criteria is None:
4911
+ S4_criteria = 'j2>=j1'
4912
+
4913
+ # determine jmax and nside corresponding to the input map
4914
+ im_shape = data.shape
4915
+ if self.use_2D:
4916
+ if len(data.shape) == 2:
4917
+ nside = np.min([im_shape[0], im_shape[1]])
4918
+ M,N = im_shape[0],im_shape[1]
4919
+ N_image = 1
4920
+ else:
4921
+ nside = np.min([im_shape[1], im_shape[2]])
4922
+ M,N = im_shape[1],im_shape[2]
4923
+ N_image = data.shape[0]
4924
+ J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4925
+ elif self.use_1D:
4926
+ if len(data.shape) == 2:
4927
+ npix = int(im_shape[1]) # Number of pixels
4928
+ N_image = 1
4929
+ else:
4930
+ npix = int(im_shape[0]) # Number of pixels
4931
+ N_image = data.shape[0]
4932
+
4933
+ nside = int(npix)
4934
+
4935
+ J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4936
+ else:
4937
+ if len(data.shape) == 2:
4938
+ npix = int(im_shape[1]) # Number of pixels
4939
+ N_image = 1
4940
+ else:
4941
+ npix = int(im_shape[0]) # Number of pixels
4942
+ N_image = data.shape[0]
4943
+
4944
+ nside = int(np.sqrt(npix // 12))
3647
4945
 
4946
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
4947
+
4948
+ if Jmax is None:
4949
+ Jmax = J # Number of steps for the loop on scales
4950
+ if Jmax>J:
4951
+ print('==========\n\n')
4952
+ print('The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform.')
4953
+ print('\n\n==========')
4954
+
4955
+ L=self.NORIENT
4956
+
4957
+ if (M,N,J,L) not in self.filters_set:
4958
+ self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
4959
+
4960
+ filters_set = self.filters_set[(M,N,J,L)]
4961
+
4962
+ #weight = self.weight
4963
+ if use_ref:
4964
+ if normalization=='S2':
4965
+ ref_S2 = self.ref_scattering_cov_S2
4966
+ else:
4967
+ ref_P11 = self.ref_scattering_cov['P11']
4968
+
4969
+ # convert numpy array input into self.backend.bk_ tensors
4970
+ data = self.backend.bk_cast(data)
4971
+ data_f = self.backend.bk_fftn(data, dim=(-2,-1))
4972
+
4973
+ # initialize tensors for scattering coefficients
4974
+ S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4975
+ S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4976
+
4977
+ Ndata_S3 = J*(J+1)//2
4978
+ Ndata_S4 = J*(J+1)*(J+2)//6
4979
+ J_S4={}
4980
+
4981
+ S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4982
+ S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4983
+ S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4984
+
4985
+ # variance
4986
+ if get_variance:
4987
+ S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4988
+ S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4989
+ S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4990
+ S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4991
+
4992
+ if iso_ang:
4993
+ S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
4994
+ S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4995
+ if get_variance:
4996
+ S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
4997
+ S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
4998
+
4999
+ # calculate scattering fields
5000
+ if self.use_2D:
5001
+ if len(data.shape) == 2:
5002
+ I1 = self.backend.bk_ifftn(
5003
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5004
+ ).abs()
5005
+ else:
5006
+ I1 = self.backend.bk_ifftn(
5007
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5008
+ ).abs()
5009
+ elif self.use_1D:
5010
+ if len(data.shape) == 1:
5011
+ I1 = self.backend.bk_ifftn(
5012
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5013
+ ).abs()
5014
+ else:
5015
+ I1 = self.backend.bk_ifftn(
5016
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5017
+ ).abs()
5018
+ else:
5019
+ print('todo')
5020
+
5021
+ I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5022
+
5023
+ #
5024
+ if edge:
5025
+ if (M,N,J) not in self.edge_masks:
5026
+ self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5027
+ edge_mask = self.edge_masks[(M,N,J)][:,None,:,:]
5028
+ edge_mask = edge_mask / edge_mask.mean((-2,-1))[:,:,None,None]
5029
+ else:
5030
+ edge_mask = 1
5031
+ S2 = (I1**2 * edge_mask).mean((-2,-1))
5032
+ S1 = (I1 * edge_mask).mean((-2,-1))
5033
+
5034
+ if get_variance:
5035
+ S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5036
+ S1_sigma = (I1 * edge_mask).std((-2,-1))
5037
+
5038
+ if pseudo_coef != 1:
5039
+ I1 = I1**pseudo_coef
5040
+
5041
+ Ndata_S3=0
5042
+ Ndata_S4=0
5043
+
5044
+ # calculate the covariance and correlations of the scattering fields
5045
+ # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5046
+ for j3 in range(0,J):
5047
+ J_S4[j3]=Ndata_S4
5048
+
5049
+ dx3, dy3 = self.get_dxdy(j3,M,N)
5050
+ I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
5051
+ data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5052
+ if edge:
5053
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5054
+ data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
5055
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5056
+ _, M3, N3 = wavelet_f3.shape
5057
+ wavelet_f3_squared = wavelet_f3**2
5058
+ edge_dx = min(4, int(2**j3*dx3*2/M))
5059
+ edge_dy = min(4, int(2**j3*dy3*2/N))
5060
+ # a normalization change due to the cutoff of frequency space
5061
+ fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
5062
+ for j2 in range(0,j3+1):
5063
+ I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5064
+ I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5065
+ if edge:
5066
+ I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5067
+ I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
5068
+ if use_ref:
5069
+ if normalization=='P11':
5070
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5071
+ if normalization=='S2':
5072
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5073
+ else:
5074
+ if normalization=='P11':
5075
+ # [N_image,l2,l3,x,y]
5076
+ P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5077
+ norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5078
+ if normalization=='S2':
5079
+ norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
5080
+
5081
+ if not edge:
5082
+ S3[:,Ndata_S3,:,:] = (
5083
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5084
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5085
+
5086
+ if get_variance:
5087
+ S3_sigma[:,Ndata_S3,:,:] = (
5088
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5089
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5090
+ else:
5091
+
5092
+ S3[:,Ndata_S3,:,:] = (
5093
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5094
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5095
+ if get_variance:
5096
+ S3_sigma[:,Ndata_S3,:,:] = (
5097
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5098
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5099
+ Ndata_S3+=1
5100
+ if j2 <= j3:
5101
+ beg_n=Ndata_S4
5102
+ for j1 in range(0, j2+1):
5103
+ if eval(S4_criteria):
5104
+ if not edge:
5105
+ if not if_large_batch:
5106
+ # [N_image,l1,l2,l3,x,y]
5107
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5108
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5109
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5110
+ ).mean((-2,-1)) * fft_factor
5111
+ if get_variance:
5112
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5113
+ I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
5114
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
5115
+ ).std((-2,-1)) * fft_factor
5116
+ else:
5117
+ for l1 in range(L):
5118
+ # [N_image,l2,l3,x,y]
5119
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5120
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5121
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5122
+ ).mean((-2,-1)) * fft_factor
5123
+ if get_variance:
5124
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5125
+ I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
5126
+ self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
5127
+ ).std((-2,-1)) * fft_factor
5128
+ else:
5129
+ if not if_large_batch:
5130
+ # [N_image,l1,l2,l3,x,y]
5131
+ S4_pre_norm[:,Ndata_S4,:,:,:] = (
5132
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5133
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5134
+ )
5135
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5136
+ if get_variance:
5137
+ S4_sigma[:,Ndata_S4,:,:,:] = (
5138
+ I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
5139
+ I12_w3_2_small.view(N_image,1,L,L,M3,N3)
5140
+ )
5141
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
5142
+ else:
5143
+ for l1 in range(L):
5144
+ # [N_image,l2,l3,x,y]
5145
+ S4_pre_norm[:,Ndata_S4,l1,:,:] = (
5146
+ I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5147
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5148
+ )
5149
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5150
+ if get_variance:
5151
+ S4_sigma[:,Ndata_S4,l1,:,:] = (
5152
+ I1_small[:,j1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
5153
+ I12_w3_2_small.view(N_image,L,L,M3,N3)
5154
+ )
5155
+ )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
5156
+
5157
+ Ndata_S4+=1
5158
+
5159
+ if normalization=='S2':
5160
+ if use_ref:
5161
+ P = (ref_S2[:,j3,:,None,None] * ref_S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
5162
+ else:
5163
+ P = (S2[:,j3,:,None,None] * S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
5164
+
5165
+ S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:]/P
5166
+
5167
+ if get_variance:
5168
+ S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:] / P
5169
+
5170
+ """
5171
+ # define P11 from diagonals of S4
5172
+ for j1 in range(J):
5173
+ for l1 in range(L):
5174
+ P11[:,j1,:,l1,:] = S4_pre_norm[:,j1,j1,:,l1,l1,:].real
5175
+
5176
+
5177
+ if normalization=='S4':
5178
+ if use_ref:
5179
+ P = ref_P11
5180
+ else:
5181
+ P = P11
5182
+ #.view(N_image,J,1,J,L,1,L) * .view(N_image,1,J,J,1,L,L)
5183
+ S4 = S4_pre_norm / (
5184
+ P[:,:,None,:,:,None,:] * P[:,None,:,:,None,:,:]
5185
+ )**(0.5*pseudo_coef)
5186
+
5187
+
5188
+
5189
+
5190
+ # get a single, flattened data vector for_synthesis
5191
+ select_and_index = self.get_scattering_index(J, L, normalization, S4_criteria)
5192
+ index_for_synthesis = select_and_index['index_for_synthesis']
5193
+ index_for_synthesis_iso = select_and_index['index_for_synthesis_iso']
5194
+ """
5195
+ # average over l1 to obtain simple isotropic statistics
5196
+ if iso_ang:
5197
+ S2_iso = S2.mean(-1)
5198
+ S1_iso = S1.mean(-1)
5199
+ for l1 in range(L):
5200
+ for l2 in range(L):
5201
+ S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
5202
+ for l3 in range(L):
5203
+ S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
5204
+ S3_iso /= L; S4_iso /= L
5205
+
5206
+ if get_variance:
5207
+ S2_sigma_iso = S2_sigma.mean(-1)
5208
+ S1_sigma_iso = S1_sigma.mean(-1)
5209
+ for l1 in range(L):
5210
+ for l2 in range(L):
5211
+ S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
5212
+ for l3 in range(L):
5213
+ S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
5214
+ S3_sigma_iso /= L; S4_sigma_iso /= L
5215
+
5216
+ mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5217
+ std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
5218
+ mean_data[:,0]=data.mean((-2,-1))
5219
+ std_data[:,0]=data.std((-2,-1))
5220
+
5221
+ if get_variance:
5222
+ ref_sigma={}
5223
+ if iso_ang:
5224
+ ref_sigma['std_data']=std_data
5225
+ ref_sigma['S1_sigma']=S1_sigma_iso
5226
+ ref_sigma['S2_sigma']=S2_sigma_iso
5227
+ ref_sigma['S3_sigma']=S3_sigma_iso
5228
+ ref_sigma['S4_sigma']=S4_sigma_iso
5229
+ else:
5230
+ ref_sigma['std_data']=std_data
5231
+ ref_sigma['S1_sigma']=S1_sigma
5232
+ ref_sigma['S2_sigma']=S2_sigma
5233
+ ref_sigma['S3_sigma']=S3_sigma
5234
+ ref_sigma['S4_sigma']=S4_sigma
5235
+
5236
+ if iso_ang:
5237
+ if ref_sigma is not None:
5238
+ for_synthesis = self.backend.backend.cat((
5239
+ mean_data/ref_sigma['std_data'],
5240
+ std_data/ref_sigma['std_data'],
5241
+ (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5242
+ (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5243
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5244
+ (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5245
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5246
+ (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5247
+ ),dim=-1)
5248
+ else:
5249
+ for_synthesis = self.backend.backend.cat((
5250
+ mean_data/std_data,
5251
+ std_data,
5252
+ S2_iso.reshape((N_image, -1)).log(),
5253
+ S1_iso.reshape((N_image, -1)).log(),
5254
+ S3_iso.reshape((N_image, -1)).real,
5255
+ S3_iso.reshape((N_image, -1)).imag,
5256
+ S4_iso.reshape((N_image, -1)).real,
5257
+ S4_iso.reshape((N_image, -1)).imag,
5258
+ ),dim=-1)
5259
+ else:
5260
+ if ref_sigma is not None:
5261
+ for_synthesis = self.backend.backend.cat((
5262
+ mean_data/ref_sigma['std_data'],
5263
+ std_data/ref_sigma['std_data'],
5264
+ (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
5265
+ (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
5266
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
5267
+ (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
5268
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
5269
+ (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
5270
+ ),dim=-1)
5271
+ else:
5272
+ for_synthesis = self.backend.backend.cat((
5273
+ mean_data/std_data,
5274
+ std_data,
5275
+ S2.reshape((N_image, -1)).log(),
5276
+ S1.reshape((N_image, -1)).log(),
5277
+ S3.reshape((N_image, -1)).real,
5278
+ S3.reshape((N_image, -1)).imag,
5279
+ S4.reshape((N_image, -1)).real,
5280
+ S4.reshape((N_image, -1)).imag,
5281
+ ),dim=-1)
5282
+
5283
+ if not use_ref:
5284
+ self.ref_scattering_cov_S2=S2
5285
+
5286
+ if get_variance:
5287
+ return for_synthesis,ref_sigma
5288
+
5289
+ return for_synthesis
5290
+
5291
+
3648
5292
  def to_gaussian(self,x):
3649
5293
  from scipy.stats import norm
3650
5294
  from scipy.interpolate import interp1d
@@ -4021,8 +5665,12 @@ class funct(FOC.FoCUS):
4021
5665
  image_target,
4022
5666
  nstep=4,
4023
5667
  seed=1234,
4024
- edge=True,
5668
+ Jmax=None,
5669
+ edge=False,
4025
5670
  to_gaussian=True,
5671
+ use_variance=False,
5672
+ synthesised_N=1,
5673
+ iso_ang=False,
4026
5674
  EVAL_FREQUENCY=100,
4027
5675
  NUM_EPOCHS = 300):
4028
5676
 
@@ -4032,13 +5680,16 @@ class funct(FOC.FoCUS):
4032
5680
  def The_loss(u,scat_operator,args):
4033
5681
  ref = args[0]
4034
5682
  sref = args[1]
5683
+ use_v= args[2]
4035
5684
 
4036
5685
  # compute scattering covariance of the current synthetised map called u
4037
- learn=scat_operator.reduce_mean_batch(scat_operator.eval(u,edge=edge))
5686
+ if use_v:
5687
+ learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,ref_sigma=sref,use_ref=True,iso_ang=iso_ang))
5688
+ else:
5689
+ learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,use_ref=True,iso_ang=iso_ang))
4038
5690
 
4039
5691
  # make the difference withe the reference coordinates
4040
- loss=scat_operator.reduce_distance(learn,ref,sigma=sref)
4041
-
5692
+ loss=scat_operator.backend.bk_reduce_mean(scat_operator.backend.bk_square((learn-ref)))
4042
5693
  return loss
4043
5694
 
4044
5695
  if to_gaussian:
@@ -4078,33 +5729,35 @@ class funct(FOC.FoCUS):
4078
5729
  if k==0:
4079
5730
  np.random.seed(seed)
4080
5731
  if self.use_2D:
4081
- imap=np.random.randn(tmp[k].shape[0],
5732
+ imap=np.random.randn(synthesised_N,
4082
5733
  tmp[k].shape[1],
4083
5734
  tmp[k].shape[2])
4084
5735
  else:
4085
- imap=np.random.randn(tmp[k].shape[0],
5736
+ imap=np.random.randn(synthesised_N,
4086
5737
  tmp[k].shape[1])
4087
5738
  else:
4088
- axis=1
4089
- # if the kernel size is bigger than 3 increase the binning before smoothing
5739
+ # Increase the resolution between each step
4090
5740
  if self.use_2D:
4091
5741
  imap = self.up_grade(
4092
- omap, imap.shape[axis] * 2, axis=1, nouty=imap.shape[axis + 1] * 2
5742
+ omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
4093
5743
  )
4094
5744
  elif self.use_1D:
4095
- imap = self.up_grade(omap, imap.shape[axis] * 2, axis=1)
5745
+ imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
4096
5746
  else:
4097
- imap = self.up_grade(omap, l_nside, axis=axis)
5747
+ imap = self.up_grade(omap, l_nside, axis=1)
4098
5748
 
4099
5749
  # compute the coefficients for the target image
4100
- ref,sref=self.eval(tmp[k],calc_var=True,edge=edge)
4101
-
5750
+ if use_variance:
5751
+ ref,sref=self.scattering_cov(tmp[k],get_variance=True,edge=edge,Jmax=Jmax,iso_ang=iso_ang)
5752
+ else:
5753
+ ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
5754
+ sref=ref
5755
+
4102
5756
  # compute the mean of the population does nothing if only one map is given
4103
5757
  ref=self.reduce_mean_batch(ref)
4104
- sref=self.reduce_mean_batch(sref)
4105
-
5758
+
4106
5759
  # define a loss to minimize
4107
- loss=synthe.Loss(The_loss,self,ref,sref)
5760
+ loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
4108
5761
 
4109
5762
  sy = synthe.Synthesis([loss])
4110
5763
 
@@ -4121,6 +5774,8 @@ class funct(FOC.FoCUS):
4121
5774
  omap=sy.run(imap,
4122
5775
  EVAL_FREQUENCY=EVAL_FREQUENCY,
4123
5776
  NUM_EPOCHS = NUM_EPOCHS)
5777
+
5778
+
4124
5779
 
4125
5780
  t2=time.time()
4126
5781
  print('Total computation %.2fs'%(t2-t1))
@@ -4128,7 +5783,7 @@ class funct(FOC.FoCUS):
4128
5783
  if to_gaussian:
4129
5784
  omap=self.from_gaussian(omap)
4130
5785
 
4131
- if axis==0:
5786
+ if axis==0 and synthesised_N==1:
4132
5787
  return omap[0]
4133
5788
  else:
4134
5789
  return omap