foscat 3.7.0__py3-none-any.whl → 3.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -92,7 +92,7 @@ class scat_cov:
92
92
  )
93
93
 
94
94
  def conv2complex(self, val):
95
- if val.dtype == "complex64" in val.dtype or "complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
95
+ if val.dtype == "complex64" or val.dtype=="complex128" or val.dtype == "torch.complex64" or val.dtype == "torch.complex128" :
96
96
  return val
97
97
  else:
98
98
  return self.backend.bk_complex(val, 0 * val)
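The corrected predicate above matches dtypes by name. A minimal standalone sketch of the same real-to-complex promotion, assuming plain NumPy arrays in place of the backend tensors (np.iscomplexobj stands in for the string comparison):

    import numpy as np

    def to_complex(val):
        # Leave complex arrays untouched; promote real ones, as bk_complex(val, 0*val) does.
        if np.iscomplexobj(val):
            return val
        return val.astype(np.complex64)

    assert to_complex(np.ones(3, dtype=np.float32)).dtype == np.complex64
    assert to_complex(np.ones(3, dtype=np.complex128)).dtype == np.complex128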
@@ -2540,7 +2540,1050 @@ class funct(FOC.FoCUS):
2540
2540
  -------
2541
2541
  S1, S2, S3, S4 normalized
2542
2542
  """
2543
+
2544
+ return_data = self.return_data
2545
+
2546
+ # Check input consistency
2547
+ if image2 is not None:
2548
+ if list(image1.shape) != list(image2.shape):
2549
+ print(
2550
+ "The two input image should have the same size to eval Scattering Covariance"
2551
+ )
2552
+ return None
2553
+ if mask is not None:
2554
+ if self.use_2D:
2555
+ if image1.shape[-2] != mask.shape[1] or image1.shape[-1] != mask.shape[2]:
2556
+ print(
2557
+ "The LAST 2 COLUMNs of the mask should have the same size ",
2558
+ mask.shape,
2559
+ "than the input image ",
2560
+ image1.shape,
2561
+ "to eval Scattering Covariance",
2562
+ )
2563
+ return None
2564
+ else:
2565
+ if image1.shape[-1] != mask.shape[1]:
2566
+ print(
2567
+ "The LAST COLUMN of the mask should have the same size ",
2568
+ mask.shape,
2569
+ "than the input image ",
2570
+ image1.shape,
2571
+ "to eval Scattering Covariance",
2572
+ )
2573
+ return None
2574
+ if self.use_2D and len(image1.shape) < 2:
2575
+ print(
2576
+ "To work with 2D scattering transform, two dimension is needed, input map has only on dimension"
2577
+ )
2578
+ return None
2579
+
2580
+ ### AUTO OR CROSS
2581
+ cross = False
2582
+ if image2 is not None:
2583
+ cross = True
2584
+
2585
+ ### PARAMETERS
2586
+ axis = 1
2587
+ # determine jmax and nside corresponding to the input map
2588
+ im_shape = image1.shape
2589
+ if self.use_2D:
2590
+ if len(image1.shape) == 2:
2591
+ nside = np.min([im_shape[0], im_shape[1]])
2592
+ npix = im_shape[0] * im_shape[1] # Number of pixels
2593
+ x1 = im_shape[0]
2594
+ x2 = im_shape[1]
2595
+ else:
2596
+ nside = np.min([im_shape[1], im_shape[2]])
2597
+ npix = im_shape[1] * im_shape[2] # Number of pixels
2598
+ x1 = im_shape[1]
2599
+ x2 = im_shape[2]
2600
+ J = int(np.log(nside - self.KERNELSZ) / np.log(2)) # Number of j scales
2601
+ elif self.use_1D:
2602
+ if len(image1.shape) == 2:
2603
+ npix = int(im_shape[1]) # Number of pixels
2604
+ else:
2605
+ npix = int(im_shape[0]) # Number of pixels
2606
+
2607
+ nside = int(npix)
2608
+
2609
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
2610
+ else:
2611
+ if len(image1.shape) == 2:
2612
+ npix = int(im_shape[1]) # Number of pixels
2613
+ else:
2614
+ npix = int(im_shape[0]) # Number of pixels
2615
+
2616
+ nside = int(np.sqrt(npix // 12))
2617
+
2618
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
2619
+
2620
+ if (self.use_2D or self.use_1D) and self.KERNELSZ>3:
2621
+ J-=1
2622
+ if Jmax is None:
2623
+ Jmax = J # Number of steps for the loop on scales
2624
+ if Jmax>J:
2625
+ print('==========\n\n')
2626
+ print('The requested Jmax is larger than the data size allows, which may cause problems while computing the scattering transform.')
2627
+ print('\n\n==========')
2628
+
2629
+
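For orientation, the number of dyadic scales J is just the base-2 logarithm of the effective side length, so the warning above fires whenever the requested Jmax exceeds it. A small worked example, assuming a HEALPix map with npix = 12 * nside**2:

    import numpy as np

    npix = 12 * 64 ** 2                  # hypothetical map at nside=64
    nside = int(np.sqrt(npix // 12))     # -> 64
    J = int(np.log(nside) / np.log(2))   # -> 6 dyadic scales
    assert (nside, J) == (64, 6)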
2630
+ ### LOCAL VARIABLES (IMAGES and MASK)
2631
+ if len(image1.shape) == 1 or (len(image1.shape) == 2 and self.use_2D):
2632
+ I1 = self.backend.bk_cast(
2633
+ self.backend.bk_expand_dims(image1, 0)
2634
+ ) # Local image1 [Nbatch, Npix]
2635
+ if cross:
2636
+ I2 = self.backend.bk_cast(
2637
+ self.backend.bk_expand_dims(image2, 0)
2638
+ ) # Local image2 [Nbatch, Npix]
2639
+ else:
2640
+ I1 = self.backend.bk_cast(image1) # Local image1 [Nbatch, Npix]
2641
+ if cross:
2642
+ I2 = self.backend.bk_cast(image2) # Local image2 [Nbatch, Npix]
2643
+
2644
+ if mask is None:
2645
+ if self.use_2D:
2646
+ vmask = self.backend.bk_ones([1, x1, x2], dtype=self.all_type)
2647
+ else:
2648
+ vmask = self.backend.bk_ones([1, npix], dtype=self.all_type)
2649
+ else:
2650
+ vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
2651
+
2652
+ if self.KERNELSZ > 3 and not self.use_2D:
2653
+ # if the kernel size is bigger than 3, increase the binning before smoothing
2654
+ if self.use_2D:
2655
+ vmask = self.up_grade(
2656
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2657
+ )
2658
+ I1 = self.up_grade(
2659
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2660
+ )
2661
+ if cross:
2662
+ I2 = self.up_grade(
2663
+ I2, I2.shape[axis] * 2, axis=axis, nouty=I2.shape[axis + 1] * 2
2664
+ )
2665
+ elif self.use_1D:
2666
+ vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2667
+ I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2668
+ if cross:
2669
+ I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2670
+ else:
2671
+ I1 = self.up_grade(I1, nside * 2, axis=axis)
2672
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
2673
+ if cross:
2674
+ I2 = self.up_grade(I2, nside * 2, axis=axis)
2675
+
2676
+ if self.KERNELSZ > 5 and not self.use_2D:
2677
+ # if the kernel size is bigger than 5, increase the binning again before smoothing
2678
+ if self.use_2D:
2679
+ vmask = self.up_grade(
2680
+ vmask, I1.shape[axis] * 2, axis=1, nouty=I1.shape[axis + 1] * 2
2681
+ )
2682
+ I1 = self.up_grade(
2683
+ I1, I1.shape[axis] * 2, axis=axis, nouty=I1.shape[axis + 1] * 2
2684
+ )
2685
+ if cross:
2686
+ I2 = self.up_grade(
2687
+ I2,
2688
+ I2.shape[axis] * 2,
2689
+ axis=axis,
2690
+ nouty=I2.shape[axis + 1] * 2,
2691
+ )
2692
+ elif self.use_1D:
2693
+ vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
2694
+ I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
2695
+ if cross:
2696
+ I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
2697
+ else:
2698
+ I1 = self.up_grade(I1, nside * 4, axis=axis)
2699
+ vmask = self.up_grade(vmask, nside * 4, axis=1)
2700
+ if cross:
2701
+ I2 = self.up_grade(I2, nside * 4, axis=axis)
2702
+
2703
+ # Normalize the masks because they have different pixel numbers
2704
+ # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
2705
+
2706
+ ### INITIALIZATION
2707
+ # Coefficients
2708
+ if return_data:
2709
+ S1 = {}
2710
+ S2 = {}
2711
+ S3 = {}
2712
+ S3P = {}
2713
+ S4 = {}
2714
+ else:
2715
+ S1 = []
2716
+ S2 = []
2717
+ S3 = []
2718
+ S4 = []
2719
+ S3P = []
2720
+ VS1 = []
2721
+ VS2 = []
2722
+ VS3 = []
2723
+ VS3P = []
2724
+ VS4 = []
2725
+
2726
+ off_S2 = -2
2727
+ off_S3 = -3
2728
+ off_S4 = -4
2729
+ if self.use_1D:
2730
+ off_S2 = -1
2731
+ off_S3 = -1
2732
+ off_S4 = -1
2733
+
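The negative offsets above choose where the extra coefficient axis is inserted relative to the trailing orientation axes. A toy illustration, relying on NumPy's handling of negative axis values:

    import numpy as np

    s2 = np.zeros((2, 3, 4))          # [Nbatch, Nmask, Norient3]
    off_S2 = -2                        # 2D / HEALPix case
    assert np.expand_dims(s2, off_S2).shape == (2, 3, 1, 4)   # NS2 axis before Norient3

    s2_1d = np.zeros((2, 3))           # use_1D case: no orientation axis
    assert np.expand_dims(s2_1d, -1).shape == (2, 3, 1)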
2734
+ # S2 for normalization
2735
+ cond_init_P1_dic = (norm == "self") or (
2736
+ (norm == "auto") and (self.P1_dic is None)
2737
+ )
2738
+ if norm is None:
2739
+ pass
2740
+ elif cond_init_P1_dic:
2741
+ P1_dic = {}
2742
+ if cross:
2743
+ P2_dic = {}
2744
+ elif (norm == "auto") and (self.P1_dic is not None):
2745
+ P1_dic = self.P1_dic
2746
+ if cross:
2747
+ P2_dic = self.P2_dic
2748
+
2749
+ if return_data:
2750
+ s0 = I1
2751
+ if out_nside is not None:
2752
+ s0 = self.backend.bk_reduce_mean(
2753
+ self.backend.bk_reshape(
2754
+ s0, [s0.shape[0], 12 * out_nside**2, (nside // out_nside) ** 2]
2755
+ ),
2756
+ 2,
2757
+ )
2758
+ else:
2759
+ if not cross:
2760
+ s0, l_vs0 = self.masked_mean(I1, vmask, axis=1, calc_var=True)
2761
+ else:
2762
+ s0, l_vs0 = self.masked_mean(
2763
+ self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
2764
+ )
2765
+ vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2766
+ s0 = self.backend.bk_concat([s0, l_vs0], 1)
2767
+ #### COMPUTE S1, S2, S3 and S4
2768
+ nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2769
+
2770
+ # TODO: restore as before
2771
+ M1_dic={}
2772
+ M2_dic={}
2773
+
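The loop below leans on masked_mean, which averages a field under each mask and can optionally return a per-mask variance. A minimal NumPy sketch of that contract (an assumption about the backend, written out only to make the loop readable):

    import numpy as np

    def masked_mean(field, mask, calc_var=False):
        # field: [Nbatch, Npix], mask: [Nmask, Npix] -> [Nbatch, Nmask]
        w = mask / mask.sum(axis=1, keepdims=True)
        mean = np.einsum('bp,mp->bm', field, w)
        if not calc_var:
            return mean
        var = np.einsum('bp,mp->bm', field ** 2, w) - mean ** 2
        return mean, var

    m, v = masked_mean(np.random.randn(2, 8), np.ones((1, 8)), calc_var=True)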
2774
+ for j3 in range(Jmax):
2775
+
2776
+ if edge:
2777
+ if self.mask_mask is None:
2778
+ self.mask_mask={}
2779
+ if self.use_2D:
2780
+ if (vmask.shape[1],vmask.shape[2]) not in self.mask_mask:
2781
+ mask_mask=np.zeros([1,vmask.shape[1],vmask.shape[2]])
2782
+ mask_mask[0,
2783
+ self.KERNELSZ//2:-self.KERNELSZ//2+1,
2784
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2785
+ self.mask_mask[(vmask.shape[1],vmask.shape[2])]=self.backend.bk_cast(mask_mask)
2786
+ vmask=vmask*self.mask_mask[(vmask.shape[1],vmask.shape[2])]
2787
+ #print(self.KERNELSZ//2,vmask,mask_mask)
2788
+
2789
+ if self.use_1D:
2790
+ if (vmask.shape[1]) not in self.mask_mask:
2791
+ mask_mask=np.zeros([1,vmask.shape[1]])
2792
+ mask_mask[0,
2793
+ self.KERNELSZ//2:-self.KERNELSZ//2+1]=1.0
2794
+ self.mask_mask[(vmask.shape[1])]=self.backend.bk_cast(mask_mask)
2795
+ vmask=vmask*self.mask_mask[(vmask.shape[1])]
2796
+
2797
+ if return_data:
2798
+ S3[j3] = None
2799
+ S3P[j3] = None
2800
+
2801
+ if S4 is None:
2802
+ S4 = {}
2803
+ S4[j3] = None
2804
+
2805
+ ####### S1 and S2
2806
+ ### Make the convolution I1 * Psi_j3
2807
+ conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
2808
+ if cmat is not None:
2809
+ tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2810
+ conv1 = self.backend.bk_reduce_sum(
2811
+ self.backend.bk_reshape(
2812
+ cmat[j3] * tmp2,
2813
+ [tmp2.shape[0], cmat[j3].shape[0], self.NORIENT, self.NORIENT],
2814
+ ),
2815
+ 2,
2816
+ )
2817
+
2818
+ ### Take the module M1 = |I1 * Psi_j3|
2819
+ M1_square = conv1 * self.backend.bk_conjugate(
2820
+ conv1
2821
+ ) # [Nbatch, Npix_j3, Norient3]
2822
+ M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
2823
+ # Store M1_j3 in a dictionary
2824
+ M1_dic[j3] = M1
2825
+
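Each scale's building block is the wavelet modulus field stored here. A compact sketch of the S1/S2 pair it feeds, assuming a precomputed complex convolution in NumPy:

    import numpy as np

    rng = np.random.default_rng(0)
    conv1 = rng.normal(size=(1, 128, 4)) + 1j * rng.normal(size=(1, 128, 4))

    M1_square = conv1 * np.conj(conv1)   # |I * psi|^2, [Nbatch, Npix, Norient]
    M1 = np.sqrt(M1_square.real)         # |I * psi| (stand-in for bk_L1)
    S2 = M1_square.real.mean(axis=1)     # band power per orientation
    S1 = M1.mean(axis=1)                 # mean modulus per orientation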
2826
+ if not cross: # Auto
2827
+ M1_square = self.backend.bk_real(M1_square)
2828
+
2829
+ ### S2_auto = < M1^2 >_pix
2830
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2831
+ if return_data:
2832
+ s2 = M1_square
2833
+ else:
2834
+ if calc_var:
2835
+ s2, vs2 = self.masked_mean(
2836
+ M1_square, vmask, axis=1, rank=j3, calc_var=True
2837
+ )
2838
+ else:
2839
+ s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
2840
+
2841
+ if cond_init_P1_dic:
2842
+ # We fill P1_dic with S2 for normalisation of S3 and S4
2843
+ P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
2844
+
2845
+ # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
2846
+ if return_data:
2847
+ if S2 is None:
2848
+ S2 = {}
2849
+ if out_nside is not None and out_nside < nside_j3:
2850
+ s2 = self.backend.bk_reduce_mean(
2851
+ self.backend.bk_reshape(
2852
+ s2,
2853
+ [
2854
+ s2.shape[0],
2855
+ 12 * out_nside**2,
2856
+ (nside_j3 // out_nside) ** 2,
2857
+ s2.shape[2],
2858
+ ],
2859
+ ),
2860
+ 2,
2861
+ )
2862
+ S2[j3] = s2
2863
+ else:
2864
+ if norm == "auto": # Normalize S2
2865
+ s2 /= P1_dic[j3]
2866
+
2867
+ S2.append(
2868
+ self.backend.bk_expand_dims(s2, off_S2)
2869
+ ) # Add a dimension for NS2
2870
+ if calc_var:
2871
+ VS2.append(
2872
+ self.backend.bk_expand_dims(vs2, off_S2)
2873
+ ) # Add a dimension for NS2
2874
+
2875
+ #### S1_auto computation
2876
+ ### Image 1 : S1 = < M1 >_pix
2877
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2878
+ if return_data:
2879
+ s1 = M1
2880
+ else:
2881
+ if calc_var:
2882
+ s1, vs1 = self.masked_mean(
2883
+ M1, vmask, axis=1, rank=j3, calc_var=True
2884
+ ) # [Nbatch, Nmask, Norient3]
2885
+ else:
2886
+ s1 = self.masked_mean(
2887
+ M1, vmask, axis=1, rank=j3
2888
+ ) # [Nbatch, Nmask, Norient3]
2889
+
2890
+ if return_data:
2891
+ if out_nside is not None and out_nside < nside_j3:
2892
+ s1 = self.backend.bk_reduce_mean(
2893
+ self.backend.bk_reshape(
2894
+ s1,
2895
+ [
2896
+ s1.shape[0],
2897
+ 12 * out_nside**2,
2898
+ (nside_j3 // out_nside) ** 2,
2899
+ s1.shape[2],
2900
+ ],
2901
+ ),
2902
+ 2,
2903
+ )
2904
+ S1[j3] = s1
2905
+ else:
2906
+ ### Normalize S1
2907
+ if norm is not None:
2908
+ self.div_norm(s1, (P1_dic[j3]) ** 0.5)
2909
+ ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2910
+ S1.append(
2911
+ self.backend.bk_expand_dims(s1, off_S2)
2912
+ ) # Add a dimension for NS1
2913
+ if calc_var:
2914
+ VS1.append(
2915
+ self.backend.bk_expand_dims(vs1, off_S2)
2916
+ ) # Add a dimension for NS1
2917
+
2918
+ else: # Cross
2919
+ ### Make the convolution I2 * Psi_j3
2920
+ conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
2921
+ if cmat is not None:
2922
+ tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2923
+ conv2 = self.backend.bk_reduce_sum(
2924
+ self.backend.bk_reshape(
2925
+ cmat[j3] * tmp2,
2926
+ [
2927
+ tmp2.shape[0],
2928
+ cmat[j3].shape[0],
2929
+ self.NORIENT,
2930
+ self.NORIENT,
2931
+ ],
2932
+ ),
2933
+ 2,
2934
+ )
2935
+ ### Take the module M2 = |I2 * Psi_j3|
2936
+ M2_square = conv2 * self.backend.bk_conjugate(
2937
+ conv2
2938
+ ) # [Nbatch, Npix_j3, Norient3]
2939
+ M2 = self.backend.bk_L1(M2_square) # [Nbatch, Npix_j3, Norient3]
2940
+ # Store M2_j3 in a dictionary
2941
+ M2_dic[j3] = M2
2942
+
2943
+ ### S2_auto = < M2^2 >_pix
2944
+ # Not returned, only for normalization
2945
+ if cond_init_P1_dic:
2946
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2947
+ if return_data:
2948
+ p1 = M1_square
2949
+ p2 = M2_square
2950
+ else:
2951
+ if calc_var:
2952
+ p1, vp1 = self.masked_mean(
2953
+ M1_square, vmask, axis=1, rank=j3, calc_var=True
2954
+ ) # [Nbatch, Nmask, Norient3]
2955
+ p2, vp2 = self.masked_mean(
2956
+ M2_square, vmask, axis=1, rank=j3, calc_var=True
2957
+ ) # [Nbatch, Nmask, Norient3]
2958
+ else:
2959
+ p1 = self.masked_mean(
2960
+ M1_square, vmask, axis=1, rank=j3
2961
+ ) # [Nbatch, Nmask, Norient3]
2962
+ p2 = self.masked_mean(
2963
+ M2_square, vmask, axis=1, rank=j3
2964
+ ) # [Nbatch, Nmask, Norient3]
2965
+ # We fill P1_dic with S2 for normalisation of S3 and S4
2966
+ P1_dic[j3] = self.backend.bk_real(p1) # [Nbatch, Nmask, Norient3]
2967
+ P2_dic[j3] = self.backend.bk_real(p2) # [Nbatch, Nmask, Norient3]
2968
+
2969
+ ### S2_cross = < (I1 * Psi_j3) (I2 * Psi_j3)^* >_pix
2970
+ # z_1 x z_2^* = (a1a2 + b1b2) + i(b1a2 - a1b2)
2971
+ s2 = conv1 * self.backend.bk_conjugate(conv2)
2972
+ MX = self.backend.bk_L1(s2)
2973
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
2974
+ if return_data:
2975
+ s2 = s2
2976
+ else:
2977
+ if calc_var:
2978
+ s2, vs2 = self.masked_mean(
2979
+ s2, vmask, axis=1, rank=j3, calc_var=True
2980
+ )
2981
+ else:
2982
+ s2 = self.masked_mean(s2, vmask, axis=1, rank=j3)
2983
+
2984
+ if return_data:
2985
+ if out_nside is not None and out_nside < nside_j3:
2986
+ s2 = self.backend.bk_reduce_mean(
2987
+ self.backend.bk_reshape(
2988
+ s2,
2989
+ [
2990
+ s2.shape[0],
2991
+ 12 * out_nside**2,
2992
+ (nside_j3 // out_nside) ** 2,
2993
+ s2.shape[2],
2994
+ ],
2995
+ ),
2996
+ 2,
2997
+ )
2998
+ S2[j3] = s2
2999
+ else:
3000
+ ### Normalize S2_cross
3001
+ if norm == "auto":
3002
+ s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
3003
+
3004
+ ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
3005
+ s2 = self.backend.bk_real(s2)
3006
+
3007
+ S2.append(
3008
+ self.backend.bk_expand_dims(s2, off_S2)
3009
+ ) # Add a dimension for NS2
3010
+ if calc_var:
3011
+ VS2.append(
3012
+ self.backend.bk_expand_dims(vs2, off_S2)
3013
+ ) # Add a dimension for NS2
3014
+
3015
+ #### S1_auto computation
3016
+ ### Image 1 : S1 = < M1 >_pix
3017
+ # Apply the mask [Nmask, Npix_j3] and average over pixels
3018
+ if return_data:
3019
+ s1 = MX
3020
+ else:
3021
+ if calc_var:
3022
+ s1, vs1 = self.masked_mean(
3023
+ MX, vmask, axis=1, rank=j3, calc_var=True
3024
+ ) # [Nbatch, Nmask, Norient3]
3025
+ else:
3026
+ s1 = self.masked_mean(
3027
+ MX, vmask, axis=1, rank=j3
3028
+ ) # [Nbatch, Nmask, Norient3]
3029
+ if return_data:
3030
+ if out_nside is not None and out_nside < nside_j3:
3031
+ s1 = self.backend.bk_reduce_mean(
3032
+ self.backend.bk_reshape(
3033
+ s1,
3034
+ [
3035
+ s1.shape[0],
3036
+ 12 * out_nside**2,
3037
+ (nside_j3 // out_nside) ** 2,
3038
+ s1.shape[2],
3039
+ ],
3040
+ ),
3041
+ 2,
3042
+ )
3043
+ S1[j3] = s1
3044
+ else:
3045
+ ### Normalize S1
3046
+ if norm is not None:
3047
+ self.div_norm(s1, (P1_dic[j3]) ** 0.5)
3048
+ ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
3049
+ S1.append(
3050
+ self.backend.bk_expand_dims(s1, off_S2)
3051
+ ) # Add a dimension for NS1
3052
+ if calc_var:
3053
+ VS1.append(
3054
+ self.backend.bk_expand_dims(vs1, off_S2)
3055
+ ) # Add a dimension for NS1
3056
+
3057
+ # Initialize dictionaries for |I1*Psi_j| * Psi_j3
3058
+ M1convPsi_dic = {}
3059
+ if cross:
3060
+ # Initialize dictionaries for |I2*Psi_j| * Psi_j3
3061
+ M2convPsi_dic = {}
3062
+
3063
+ ###### S3
3064
+ nside_j2 = nside_j3
3065
+ for j2 in range(0, j3 + 1): # j2 <= j3
3066
+ if return_data:
3067
+ if S4[j3] is None:
3068
+ S4[j3] = {}
3069
+ S4[j3][j2] = None
3070
+
3071
+ ### S3_auto = < (I1 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
3072
+ if not cross:
3073
+ if calc_var:
3074
+ s3, vs3 = self._compute_S3(
3075
+ j2,
3076
+ j3,
3077
+ conv1,
3078
+ vmask,
3079
+ M1_dic,
3080
+ M1convPsi_dic,
3081
+ calc_var=True,
3082
+ cmat2=cmat2,
3083
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3084
+ else:
3085
+ s3 = self._compute_S3(
3086
+ j2,
3087
+ j3,
3088
+ conv1,
3089
+ vmask,
3090
+ M1_dic,
3091
+ M1convPsi_dic,
3092
+ return_data=return_data,
3093
+ cmat2=cmat2,
3094
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3095
+
3096
+ if return_data:
3097
+ if S3[j3] is None:
3098
+ S3[j3] = {}
3099
+ if out_nside is not None and out_nside < nside_j2:
3100
+ s3 = self.backend.bk_reduce_mean(
3101
+ self.backend.bk_reshape(
3102
+ s3,
3103
+ [
3104
+ s3.shape[0],
3105
+ 12 * out_nside**2,
3106
+ (nside_j2 // out_nside) ** 2,
3107
+ s3.shape[2],
3108
+ s3.shape[3],
3109
+ ],
3110
+ ),
3111
+ 2,
3112
+ )
3113
+ S3[j3][j2] = s3
3114
+ else:
3115
+ ### Normalize S3 with S2_j [Nbatch, Nmask, Norient_j]
3116
+ if norm is not None:
3117
+ self.div_norm(
3118
+ s3,
3119
+ (
3120
+ self.backend.bk_expand_dims(P1_dic[j2], off_S2)
3121
+ * self.backend.bk_expand_dims(P1_dic[j3], -1)
3122
+ )
3123
+ ** 0.5,
3124
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3125
+
3126
+ ### Store S3 as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3127
+
3128
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3129
+ # s3.shape[2]*s3.shape[3]]))
3130
+ S3.append(
3131
+ self.backend.bk_expand_dims(s3, off_S3)
3132
+ ) # Add a dimension for NS3
3133
+ if calc_var:
3134
+ VS3.append(
3135
+ self.backend.bk_expand_dims(vs3, off_S3)
3136
+ ) # Add a dimension for NS3
3137
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3138
+ # s3.shape[2]*s3.shape[3]]))
3139
+
3140
+ ### S3_cross = < (I1 * Psi)_j3 x (|I2 * Psi_j2| * Psi_j3)^* >_pix
3141
+ ### S3P_cross = < (I2 * Psi)_j3 x (|I1 * Psi_j2| * Psi_j3)^* >_pix
3142
+ else:
3143
+ if calc_var:
3144
+ s3, vs3 = self._compute_S3(
3145
+ j2,
3146
+ j3,
3147
+ conv1,
3148
+ vmask,
3149
+ M2_dic,
3150
+ M2convPsi_dic,
3151
+ calc_var=True,
3152
+ cmat2=cmat2,
3153
+ )
3154
+ s3p, vs3p = self._compute_S3(
3155
+ j2,
3156
+ j3,
3157
+ conv2,
3158
+ vmask,
3159
+ M1_dic,
3160
+ M1convPsi_dic,
3161
+ calc_var=True,
3162
+ cmat2=cmat2,
3163
+ )
3164
+ else:
3165
+ s3 = self._compute_S3(
3166
+ j2,
3167
+ j3,
3168
+ conv1,
3169
+ vmask,
3170
+ M2_dic,
3171
+ M2convPsi_dic,
3172
+ return_data=return_data,
3173
+ cmat2=cmat2,
3174
+ )
3175
+ s3p = self._compute_S3(
3176
+ j2,
3177
+ j3,
3178
+ conv2,
3179
+ vmask,
3180
+ M1_dic,
3181
+ M1convPsi_dic,
3182
+ return_data=return_data,
3183
+ cmat2=cmat2,
3184
+ )
3185
+
3186
+ if return_data:
3187
+ if S3[j3] is None:
3188
+ S3[j3] = {}
3189
+ S3P[j3] = {}
3190
+ if out_nside is not None and out_nside < nside_j2:
3191
+ s3 = self.backend.bk_reduce_mean(
3192
+ self.backend.bk_reshape(
3193
+ s3,
3194
+ [
3195
+ s3.shape[0],
3196
+ 12 * out_nside**2,
3197
+ (nside_j2 // out_nside) ** 2,
3198
+ s3.shape[2],
3199
+ s3.shape[3],
3200
+ ],
3201
+ ),
3202
+ 2,
3203
+ )
3204
+ s3p = self.backend.bk_reduce_mean(
3205
+ self.backend.bk_reshape(
3206
+ s3p,
3207
+ [
3208
+ s3.shape[0],
3209
+ 12 * out_nside**2,
3210
+ (nside_j2 // out_nside) ** 2,
3211
+ s3.shape[2],
3212
+ s3.shape[3],
3213
+ ],
3214
+ ),
3215
+ 2,
3216
+ )
3217
+ S3[j3][j2] = s3
3218
+ S3P[j3][j2] = s3p
3219
+ else:
3220
+ ### Normalize S3 and S3P with S2_j [Nbatch, Nmask, Norient_j]
3221
+ if norm is not None:
3222
+ self.div_norm(
3223
+ s3,
3224
+ (
3225
+ self.backend.bk_expand_dims(P2_dic[j2], off_S2)
3226
+ * self.backend.bk_expand_dims(P1_dic[j3], -1)
3227
+ )
3228
+ ** 0.5,
3229
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3230
+ self.div_norm(
3231
+ s3p,
3232
+ (
3233
+ self.backend.bk_expand_dims(P1_dic[j2], off_S2)
3234
+ * self.backend.bk_expand_dims(P2_dic[j3], -1)
3235
+ )
3236
+ ** 0.5,
3237
+ ) # [Nbatch, Nmask, Norient3, Norient2]
3238
+
3239
+ ### Store S3 and S3P as a complex [Nbatch, Nmask, NS3, Norient3, Norient2]
3240
+
3241
+ # S3.append(self.backend.bk_reshape(s3,[s3.shape[0],s3.shape[1],
3242
+ # s3.shape[2]*s3.shape[3]]))
3243
+ S3.append(
3244
+ self.backend.bk_expand_dims(s3, off_S3)
3245
+ ) # Add a dimension for NS3
3246
+ if calc_var:
3247
+ VS3.append(
3248
+ self.backend.bk_expand_dims(vs3, off_S3)
3249
+ ) # Add a dimension for NS3
3250
+
3251
+ # VS3.append(self.backend.bk_reshape(vs3,[s3.shape[0],s3.shape[1],
3252
+ # s3.shape[2]*s3.shape[3]]))
3253
+
3254
+ # S3P.append(self.backend.bk_reshape(s3p,[s3.shape[0],s3.shape[1],
3255
+ # s3.shape[2]*s3.shape[3]]))
3256
+ S3P.append(
3257
+ self.backend.bk_expand_dims(s3p, off_S3)
3258
+ ) # Add a dimension for NS3
3259
+ if calc_var:
3260
+ VS3P.append(
3261
+ self.backend.bk_expand_dims(vs3p, off_S3)
3262
+ ) # Add a dimension for NS3
3263
+ # VS3P.append(self.backend.bk_reshape(vs3p,[s3.shape[0],s3.shape[1],
3264
+ # s3.shape[2]*s3.shape[3]]))
3265
+
3266
+ ##### S4
3267
+ nside_j1 = nside_j2
3268
+ for j1 in range(0, j2 + 1): # j1 <= j2
3269
+ ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
3270
+ if not cross:
3271
+ if calc_var:
3272
+ s4, vs4 = self._compute_S4(
3273
+ j1,
3274
+ j2,
3275
+ vmask,
3276
+ M1convPsi_dic,
3277
+ M2convPsi_dic=None,
3278
+ calc_var=True,
3279
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3280
+ else:
3281
+ s4 = self._compute_S4(
3282
+ j1,
3283
+ j2,
3284
+ vmask,
3285
+ M1convPsi_dic,
3286
+ M2convPsi_dic=None,
3287
+ return_data=return_data,
3288
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3289
+
3290
+ if return_data:
3291
+ if S4[j3][j2] is None:
3292
+ S4[j3][j2] = {}
3293
+ if out_nside is not None and out_nside < nside_j1:
3294
+ s4 = self.backend.bk_reduce_mean(
3295
+ self.backend.bk_reshape(
3296
+ s4,
3297
+ [
3298
+ s4.shape[0],
3299
+ 12 * out_nside**2,
3300
+ (nside_j1 // out_nside) ** 2,
3301
+ s4.shape[2],
3302
+ s4.shape[3],
3303
+ s4.shape[4],
3304
+ ],
3305
+ ),
3306
+ 2,
3307
+ )
3308
+ S4[j3][j2][j1] = s4
3309
+ else:
3310
+ ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
3311
+ if norm is not None:
3312
+ self.div_norm(
3313
+ s4,
3314
+ (
3315
+ self.backend.bk_expand_dims(
3316
+ self.backend.bk_expand_dims(
3317
+ P1_dic[j1], off_S2
3318
+ ),
3319
+ off_S2,
3320
+ )
3321
+ * self.backend.bk_expand_dims(
3322
+ self.backend.bk_expand_dims(
3323
+ P1_dic[j2], off_S2
3324
+ ),
3325
+ -1,
3326
+ )
3327
+ )
3328
+ ** 0.5,
3329
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3330
+ ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3331
+
3332
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3333
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3334
+ S4.append(
3335
+ self.backend.bk_expand_dims(s4, off_S4)
3336
+ ) # Add a dimension for NS4
3337
+ if calc_var:
3338
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3339
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3340
+ VS4.append(
3341
+ self.backend.bk_expand_dims(vs4, off_S4)
3342
+ ) # Add a dimension for NS4
3343
+
3344
+ ### S4_cross = <(|I1 * psi1| * psi3)(|I2 * psi2| * psi3)^*>
3345
+ else:
3346
+ if calc_var:
3347
+ s4, vs4 = self._compute_S4(
3348
+ j1,
3349
+ j2,
3350
+ vmask,
3351
+ M1convPsi_dic,
3352
+ M2convPsi_dic=M2convPsi_dic,
3353
+ calc_var=True,
3354
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3355
+ else:
3356
+ s4 = self._compute_S4(
3357
+ j1,
3358
+ j2,
3359
+ vmask,
3360
+ M1convPsi_dic,
3361
+ M2convPsi_dic=M2convPsi_dic,
3362
+ return_data=return_data,
3363
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3364
+
3365
+ if return_data:
3366
+ if S4[j3][j2] is None:
3367
+ S4[j3][j2] = {}
3368
+ if out_nside is not None and out_nside < nside_j1:
3369
+ s4 = self.backend.bk_reduce_mean(
3370
+ self.backend.bk_reshape(
3371
+ s4,
3372
+ [
3373
+ s4.shape[0],
3374
+ 12 * out_nside**2,
3375
+ (nside_j1 // out_nside) ** 2,
3376
+ s4.shape[2],
3377
+ s4.shape[3],
3378
+ s4.shape[4],
3379
+ ],
3380
+ ),
3381
+ 2,
3382
+ )
3383
+ S4[j3][j2][j1] = s4
3384
+ else:
3385
+ ### Normalize S4 with S2_j [Nbatch, Nmask, Norient_j]
3386
+ if norm is not None:
3387
+ self.div_norm(
3388
+ s4,
3389
+ (
3390
+ self.backend.bk_expand_dims(
3391
+ self.backend.bk_expand_dims(
3392
+ P1_dic[j1], off_S2
3393
+ ),
3394
+ off_S2,
3395
+ )
3396
+ * self.backend.bk_expand_dims(
3397
+ self.backend.bk_expand_dims(
3398
+ P2_dic[j2], off_S2
3399
+ ),
3400
+ -1,
3401
+ )
3402
+ )
3403
+ ** 0.5,
3404
+ ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3405
+ ### Store S4 as a complex [Nbatch, Nmask, NS4, Norient3, Norient2, Norient1]
3406
+ # S4.append(self.backend.bk_reshape(s4,[s4.shape[0],s4.shape[1],
3407
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3408
+ S4.append(
3409
+ self.backend.bk_expand_dims(s4, off_S4)
3410
+ ) # Add a dimension for NS4
3411
+ if calc_var:
3412
+
3413
+ # VS4.append(self.backend.bk_reshape(vs4,[s4.shape[0],s4.shape[1],
3414
+ # s4.shape[2]*s4.shape[3]*s4.shape[4]]))
3415
+ VS4.append(
3416
+ self.backend.bk_expand_dims(vs4, off_S4)
3417
+ ) # Add a dimension for NS4
3418
+
3419
+ nside_j1 = nside_j1 // 2
3420
+ nside_j2 = nside_j2 // 2
3421
+
3422
+ ###### Reshape for next iteration on j3
3423
+ ### Image I1,
3424
+ # downscale the I1 [Nbatch, Npix_j3]
3425
+ if j3 != Jmax - 1:
3426
+ I1 = self.smooth(I1, axis=1)
3427
+ I1 = self.ud_grade_2(I1, axis=1)
3428
+
3429
+ ### Image I2
3430
+ if cross:
3431
+ I2 = self.smooth(I2, axis=1)
3432
+ I2 = self.ud_grade_2(I2, axis=1)
3433
+
3434
+ ### Modules
3435
+ for j2 in range(0, j3 + 1): # j2 =< j3
3436
+ ### Dictionary M1_dic[j2]
3437
+ M1_smooth = self.smooth(
3438
+ M1_dic[j2], axis=1
3439
+ ) # [Nbatch, Npix_j3, Norient3]
3440
+ M1_dic[j2] = self.ud_grade_2(
3441
+ M1_smooth, axis=1
3442
+ ) # [Nbatch, Npix_j3, Norient3]
3443
+
3444
+ ### Dictionary M2_dic[j2]
3445
+ if cross:
3446
+ M2_smooth = self.smooth(
3447
+ M2_dic[j2], axis=1
3448
+ ) # [Nbatch, Npix_j3, Norient3]
3449
+ M2_dic[j2] = self.ud_grade_2(
3450
+ M2, axis=1
3451
+ ) # [Nbatch, Npix_j3, Norient3]
3452
+
3453
+ ### Mask
3454
+ vmask = self.ud_grade_2(vmask, axis=1)
3455
+
3456
+ if self.mask_thres is not None:
3457
+ vmask = self.backend.bk_threshold(vmask, self.mask_thres)
3458
+
3459
+ ### NSIDE_j3
3460
+ nside_j3 = nside_j3 // 2
3461
+
3462
+ ### Store P1_dic and P2_dic in self
3463
+ if (norm == "auto") and (self.P1_dic is None):
3464
+ self.P1_dic = P1_dic
3465
+ if cross:
3466
+ self.P2_dic = P2_dic
3467
+ """
3468
+ Sout=[s0]+S1+S2+S3+S4
3469
+
3470
+ if cross:
3471
+ Sout=Sout+S3P
3472
+ if calc_var:
3473
+ SVout=[vs0]+VS1+VS2+VS3+VS4
3474
+ if cross:
3475
+ VSout=VSout+VS3P
3476
+ return self.backend.bk_concat(Sout, 2),self.backend.bk_concat(VSout, 2)
3477
+
3478
+ return self.backend.bk_concat(Sout, 2)
3479
+ """
3480
+ if not return_data:
3481
+ S1 = self.backend.bk_concat(S1, 2)
3482
+ S2 = self.backend.bk_concat(S2, 2)
3483
+ S3 = self.backend.bk_concat(S3, 2)
3484
+ S4 = self.backend.bk_concat(S4, 2)
3485
+ if cross:
3486
+ S3P = self.backend.bk_concat(S3P, 2)
3487
+ if calc_var:
3488
+ VS1 = self.backend.bk_concat(VS1, 2)
3489
+ VS2 = self.backend.bk_concat(VS2, 2)
3490
+ VS3 = self.backend.bk_concat(VS3, 2)
3491
+ VS4 = self.backend.bk_concat(VS4, 2)
3492
+ if cross:
3493
+ VS3P = self.backend.bk_concat(VS3P, 2)
3494
+ if calc_var:
3495
+ if not cross:
3496
+ return scat_cov(
3497
+ s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
3498
+ ), scat_cov(
3499
+ vs0,
3500
+ VS2,
3501
+ VS3,
3502
+ VS4,
3503
+ s1=VS1,
3504
+ backend=self.backend,
3505
+ use_1D=self.use_1D,
3506
+ )
3507
+ else:
3508
+ return scat_cov(
3509
+ s0,
3510
+ S2,
3511
+ S3,
3512
+ S4,
3513
+ s1=S1,
3514
+ s3p=S3P,
3515
+ backend=self.backend,
3516
+ use_1D=self.use_1D,
3517
+ ), scat_cov(
3518
+ vs0,
3519
+ VS2,
3520
+ VS3,
3521
+ VS4,
3522
+ s1=VS1,
3523
+ s3p=VS3P,
3524
+ backend=self.backend,
3525
+ use_1D=self.use_1D,
3526
+ )
3527
+ else:
3528
+ if not cross:
3529
+ return scat_cov(
3530
+ s0, S2, S3, S4, s1=S1, backend=self.backend, use_1D=self.use_1D
3531
+ )
3532
+ else:
3533
+ return scat_cov(
3534
+ s0,
3535
+ S2,
3536
+ S3,
3537
+ S4,
3538
+ s1=S1,
3539
+ s3p=S3P,
3540
+ backend=self.backend,
3541
+ use_1D=self.use_1D,
3542
+ )
3543
+
3544
+ def eval_new(
3545
+ self,
3546
+ image1,
3547
+ image2=None,
3548
+ mask=None,
3549
+ norm=None,
3550
+ calc_var=False,
3551
+ cmat=None,
3552
+ cmat2=None,
3553
+ Jmax=None,
3554
+ out_nside=None,
3555
+ edge=True
3556
+ ):
3557
+ """
3558
+ Calculates the scattering correlations for a batch of images. Means are computed over pixels.
3559
+ mean of modulus:
3560
+ S1 = <|I * Psi_j3|>
3561
+ Normalization : take the log
3562
+ power spectrum:
3563
+ S2 = <|I * Psi_j3|^2>
3564
+ Normalization : take the log
3565
+ orig. x modulus:
3566
+ S3 = < (I * Psi)_j3 x (|I * Psi_j2| * Psi_j3)^* >
3567
+ Normalization : divide by (S2_j2 * S2_j3)^0.5
3568
+ modulus x modulus:
3569
+ S4 = <(|I * psi1| * psi3)(|I * psi2| * psi3)^*>
3570
+ Normalization : divide by (S2_j1 * S2_j2)^0.5
3571
+ Parameters
3572
+ ----------
3573
+ image1: tensor
3574
+ Image on which we compute the scattering coefficients [Nbatch, Npix, 1, 1]
3575
+ image2: tensor
3576
+ Second image. If not None, we compute cross-scattering covariance coefficients.
3577
+ mask:
3578
+ norm: None or str
3579
+ If None no normalization is applied, if 'auto' normalize by the reference S2,
3580
+ if 'self' normalize by the current S2.
3581
+ Returns
3582
+ -------
3583
+ S1, S2, S3, S4 normalized
3584
+ """
2543
3585
  return_data = self.return_data
3586
+ NORIENT=self.NORIENT
2544
3587
  # Check input consistency
2545
3588
  if image2 is not None:
2546
3589
  if list(image1.shape) != list(image2.shape):
@@ -2699,13 +3742,19 @@ class funct(FOC.FoCUS):
2699
3742
  S3P = {}
2700
3743
  S4 = {}
2701
3744
  else:
2702
- S1 = []
2703
- S2 = []
3745
+ result=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3746
+ dtype=self.backend.backend.float32,
3747
+ device=self.backend.torch_device)
3748
+ vresult=self.backend.backend.zeros([I1.shape[0],vmask.shape[0],2+2*Jmax*self.NORIENT],
3749
+ dtype=self.backend.backend.float32,
3750
+ device=self.backend.torch_device)
3751
+ S1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3752
+ S2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
2704
3753
  S3 = []
2705
3754
  S4 = []
2706
3755
  S3P = []
2707
- VS1 = []
2708
- VS2 = []
3756
+ VS1 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
3757
+ VS2 = self.backend.backend.zeros([1,Jmax*self.NORIENT],dtype=self.backend.backend.float32,device=self.backend.torch_device)
2709
3758
  VS3 = []
2710
3759
  VS3P = []
2711
3760
  VS4 = []
@@ -2749,13 +3798,18 @@ class funct(FOC.FoCUS):
2749
3798
  s0, l_vs0 = self.masked_mean(
2750
3799
  self.backend.bk_L1(I1 * I2), vmask, axis=1, calc_var=True
2751
3800
  )
2752
- vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
2753
- s0 = self.backend.bk_concat([s0, l_vs0], 1)
3801
+ #vs0 = self.backend.bk_concat([l_vs0, l_vs0], 1)
3802
+ #s0 = self.backend.bk_concat([s0, l_vs0], 1)
3803
+ result[:,:,0]=s0
3804
+ result[:,:,1]=l_vs0
3805
+ vresult[:,:,0]=l_vs0
3806
+ vresult[:,:,1]=l_vs0
2754
3807
  #### COMPUTE S1, S2, S3 and S4
2755
3808
  nside_j3 = nside # NSIDE start (nside_j3 = nside / 2^j3)
2756
3809
 
2757
3810
  # TODO: restore as before
2758
- M1_dic={}
3811
+ M1_dic={}
3812
+ M2_dic={}
2759
3813
 
2760
3814
  for j3 in range(Jmax):
2761
3815
 
@@ -2805,7 +3859,9 @@ class funct(FOC.FoCUS):
2805
3859
  M1_square = conv1 * self.backend.bk_conjugate(
2806
3860
  conv1
2807
3861
  ) # [Nbatch, Npix_j3, Norient3]
3862
+
2808
3863
  M1 = self.backend.bk_L1(M1_square) # [Nbatch, Npix_j3, Norient3]
3864
+
2809
3865
  # Store M1_j3 in a dictionary
2810
3866
  M1_dic[j3] = M1
2811
3867
 
@@ -2821,13 +3877,15 @@ class funct(FOC.FoCUS):
2821
3877
  s2, vs2 = self.masked_mean(
2822
3878
  M1_square, vmask, axis=1, rank=j3, calc_var=True
2823
3879
  )
3880
+ #s2=self.backend.bk_flatten(self.backend.bk_real(s2))
3881
+ #vs2=self.backend.bk_flatten(vs2)
2824
3882
  else:
2825
3883
  s2 = self.masked_mean(M1_square, vmask, axis=1, rank=j3)
2826
3884
 
2827
3885
  if cond_init_P1_dic:
2828
3886
  # We fill P1_dic with S2 for normalisation of S3 and S4
2829
- P1_dic[j3] = self.backend.bk_real(s2) # [Nbatch, Nmask, Norient3]
2830
-
3887
+ P1_dic[j3] = self.backend.bk_real(self.backend.bk_real(s2)) # [Nbatch, Nmask, Norient3]
3888
+
2831
3889
  # We store S2_auto to return it [Nbatch, Nmask, NS2, Norient3]
2832
3890
  if return_data:
2833
3891
  if S2 is None:
@@ -2849,7 +3907,7 @@ class funct(FOC.FoCUS):
2849
3907
  else:
2850
3908
  if norm == "auto": # Normalize S2
2851
3909
  s2 /= P1_dic[j3]
2852
-
3910
+ """
2853
3911
  S2.append(
2854
3912
  self.backend.bk_expand_dims(s2, off_S2)
2855
3913
  ) # Add a dimension for NS2
@@ -2857,7 +3915,11 @@ class funct(FOC.FoCUS):
2857
3915
  VS2.append(
2858
3916
  self.backend.bk_expand_dims(vs2, off_S2)
2859
3917
  ) # Add a dimension for NS2
2860
-
3918
+ """
3919
+ #print(s2.shape,result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT].shape,result.shape,2+j3*NORIENT*2)
3920
+ result[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=s2
3921
+ if calc_var:
3922
+ vresult[:,:,2+j3*NORIENT*2:2+j3*NORIENT*2+NORIENT]=vs2
2861
3923
  #### S1_auto computation
2862
3924
  ### Image 1 : S1 = < M1 >_pix
2863
3925
  # Apply the mask [Nmask, Npix_j3] and average over pixels
@@ -2868,10 +3930,13 @@ class funct(FOC.FoCUS):
2868
3930
  s1, vs1 = self.masked_mean(
2869
3931
  M1, vmask, axis=1, rank=j3, calc_var=True
2870
3932
  ) # [Nbatch, Nmask, Norient3]
3933
+ #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
3934
+ #vs1=self.backend.bk_flatten(vs1)
2871
3935
  else:
2872
3936
  s1 = self.masked_mean(
2873
3937
  M1, vmask, axis=1, rank=j3
2874
3938
  ) # [Nbatch, Nmask, Norient3]
3939
+ #s1=self.backend.bk_flatten(self.backend.bk_real(s1))
2875
3940
 
2876
3941
  if return_data:
2877
3942
  if out_nside is not None and out_nside < nside_j3:
@@ -2892,6 +3957,10 @@ class funct(FOC.FoCUS):
2892
3957
  ### Normalize S1
2893
3958
  if norm is not None:
2894
3959
  self.div_norm(s1, (P1_dic[j3]) ** 0.5)
3960
+ result[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=s1
3961
+ if calc_var:
3962
+ vresult[:,:,2+j3*NORIENT*2+NORIENT:2+j3*NORIENT*2+2*NORIENT]=vs1
3963
+ """
2895
3964
  ### We store S1 for image1 [Nbatch, Nmask, NS1, Norient3]
2896
3965
  S1.append(
2897
3966
  self.backend.bk_expand_dims(s1, off_S2)
@@ -2900,6 +3969,7 @@ class funct(FOC.FoCUS):
2900
3969
  VS1.append(
2901
3970
  self.backend.bk_expand_dims(vs1, off_S2)
2902
3971
  ) # Add a dimension for NS1
3972
+ """
2903
3973
 
2904
3974
  else: # Cross
2905
3975
  ### Make the convolution I2 * Psi_j3
@@ -3048,7 +4118,7 @@ class funct(FOC.FoCUS):
3048
4118
 
3049
4119
  ###### S3
3050
4120
  nside_j2 = nside_j3
3051
- for j2 in range(0, j3 + 1): # j2 <= j3
4121
+ for j2 in range(0,-1): # j3 + 1): # j2 <= j3
3052
4122
  if return_data:
3053
4123
  if S4[j3] is None:
3054
4124
  S4[j3] = {}
@@ -3463,6 +4533,20 @@ class funct(FOC.FoCUS):
3463
4533
 
3464
4534
  return self.backend.bk_concat(Sout, 2)
3465
4535
  """
4536
+ if calc_var:
4537
+ return result,vresult
4538
+ else:
4539
+ return result
4540
+ if calc_var:
4541
+ for k in S1:
4542
+ print(k.shape,k.dtype)
4543
+ for k in S2:
4544
+ print(k.shape,k.dtype)
4545
+ print(s0.shape,s0.dtype)
4546
+ return self.backend.bk_concat([s0]+S1+S2,axis=1),self.backend.bk_concat([vs0]+VS1+VS2,axis=1)
4547
+ else:
4548
+ return self.backend.bk_concat([s0]+S1+S2,axis=1)
4549
+
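eval_new packs everything into one flat vector: slots 0-1 hold s0 and its variance, then each scale j contributes NORIENT S2 values followed by NORIENT S1 values. A sketch of that indexing with hypothetical sizes:

    NORIENT, Jmax = 4, 6

    def s2_slice(j):
        return slice(2 + j * NORIENT * 2, 2 + j * NORIENT * 2 + NORIENT)

    def s1_slice(j):
        return slice(2 + j * NORIENT * 2 + NORIENT, 2 + j * NORIENT * 2 + 2 * NORIENT)

    total = 2 + 2 * Jmax * NORIENT       # matches the zeros([...]) allocations above
    assert s1_slice(Jmax - 1).stop == total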
3466
4550
  if not return_data:
3467
4551
  S1 = self.backend.bk_concat(S1, 2)
3468
4552
  S2 = self.backend.bk_concat(S2, 2)
@@ -3526,7 +4610,6 @@ class funct(FOC.FoCUS):
3526
4610
  backend=self.backend,
3527
4611
  use_1D=self.use_1D,
3528
4612
  )
3529
-
3530
4613
  def clean_norm(self):
3531
4614
  self.P1_dic = None
3532
4615
  self.P2_dic = None
@@ -3644,7 +4727,733 @@ class funct(FOC.FoCUS):
3644
4727
  s4, vmask, axis=1, rank=j2
3645
4728
  ) # [Nbatch, Nmask, Norient3, Norient2, Norient1]
3646
4729
  return s4
4730
+
4731
+ def computer_filter(self,M,N,J,L):
4732
+ '''
4733
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4734
+ written by Sihao Cheng and Rudy Morel.
4735
+ '''
4736
+
4737
+ filter = np.zeros([J, L, M, N],dtype='complex64')
4738
+
4739
+ slant=4.0 / L
4740
+
4741
+ for j in range(J):
4742
+
4743
+ for l in range(L):
4744
+
4745
+ theta = (int(L-L/2-1)-l) * np.pi / L
4746
+ sigma = 0.8 * 2**j
4747
+ xi = 3.0 / 4.0 * np.pi /2**j
4748
+
4749
+ R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]], np.float64)
4750
+ R_inv = np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]], np.float64)
4751
+ D = np.array([[1, 0], [0, slant * slant]])
4752
+ curv = np.matmul(R, np.matmul(D, R_inv)) / ( 2 * sigma * sigma)
4753
+
4754
+ gab = np.zeros((M, N), np.complex128)
4755
+ xx = np.empty((2,2, M, N))
4756
+ yy = np.empty((2,2, M, N))
4757
+
4758
+ for ii, ex in enumerate([-1, 0]):
4759
+ for jj, ey in enumerate([-1, 0]):
4760
+ xx[ii,jj], yy[ii,jj] = np.mgrid[
4761
+ ex * M : M + ex * M,
4762
+ ey * N : N + ey * N]
4763
+
4764
+ arg = -(curv[0, 0] * xx * xx + (curv[0, 1] + curv[1, 0]) * xx * yy + curv[1, 1] * yy * yy)
4765
+ argi = arg + 1.j * (xx * xi * np.cos(theta) + yy * xi * np.sin(theta))
4766
+
4767
+ gabi = np.exp(argi).sum((0,1))
4768
+ gab = np.exp(arg).sum((0,1))
4769
+
4770
+ norm_factor = 2 * np.pi * sigma * sigma / slant
4771
+
4772
+ gab = gab / norm_factor
4773
+
4774
+ gabi = gabi / norm_factor
4775
+
4776
+ K = gabi.sum() / gab.sum()
4777
+
4778
+ # Apply the Gaussian
4779
+ filter[j, l] = np.fft.fft2(gabi-K*gab)
4780
+ filter[j,l,0,0]=0.0
4781
+
4782
+ return self.backend.bk_cast(filter)
4783
+
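computer_filter builds one FFT-domain Gabor-like filter per (scale, orientation) and zeroes its DC bin so each filter has zero mean. A reduced sketch for a single (j, l), assuming NumPy only (the periodized envelope sum and the full rotation matrix are omitted for brevity):

    import numpy as np

    M = N = 32
    j, l, L = 1, 0, 4
    theta = (int(L - L / 2 - 1) - l) * np.pi / L
    sigma, xi = 0.8 * 2 ** j, 3.0 / 4.0 * np.pi / 2 ** j
    slant = 4.0 / L

    xx, yy = np.mgrid[0:M, 0:N]
    u, v = xx - M / 2, yy - N / 2
    arg = -(u ** 2 + (slant * v) ** 2) / (2 * sigma ** 2)     # simplified envelope
    gab = np.exp(arg + 1j * xi * (u * np.cos(theta) + v * np.sin(theta)))
    f = np.fft.fft2(gab)
    f[0, 0] = 0.0                      # kill the DC response, as in computer_filter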
4784
+ # ------------------------------------------------------------------------------------------
4785
+ #
4786
+ # utility functions
4787
+ #
4788
+ # ------------------------------------------------------------------------------------------
4789
+ def cut_high_k_off(self,data_f, dx, dy):
4790
+ '''
4791
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4792
+ written by Sihao Cheng and Rudy Morel.
4793
+ '''
4794
+ if_xodd = (data_f.shape[-2]%2==1)
4795
+ if_yodd = (data_f.shape[-1]%2==1)
4796
+ result = self.backend.backend.cat(
4797
+ (self.backend.backend.cat(
4798
+ ( data_f[...,:dx+if_xodd, :dy+if_yodd] , data_f[...,-dx:, :dy+if_yodd]
4799
+ ), -2),
4800
+ self.backend.backend.cat(
4801
+ ( data_f[...,:dx+if_xodd, -dy:] , data_f[...,-dx:, -dy:]
4802
+ ), -2)
4803
+ ),-1)
4804
+ return result
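cut_high_k_off keeps only the four low-frequency corners of an FFT-ordered array. An equivalent NumPy form, assuming even side lengths (the odd-size fixups are dropped):

    import numpy as np

    def cut_high_k_off_np(data_f, dx, dy):
        # Keep the low-k block: dx rows/dy columns from each end of the last two axes.
        rows = np.r_[0:dx, data_f.shape[-2] - dx:data_f.shape[-2]]
        cols = np.r_[0:dy, data_f.shape[-1] - dy:data_f.shape[-1]]
        return data_f[..., rows, :][..., :, cols]

    small = cut_high_k_off_np(np.fft.fft2(np.random.randn(64, 64)), 8, 8)
    assert small.shape == (16, 16)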
4805
+ # ---------------------------------------------------------------------------
4806
+ #
4807
+ # utility functions for computing scattering coef and covariance
4808
+ #
4809
+ # ---------------------------------------------------------------------------
4810
+
4811
+ def get_dxdy(self, j,M,N):
4812
+ '''
4813
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4814
+ written by Sihao Cheng and Rudy Morel.
4815
+ '''
4816
+ dx = int(max( 8, min( np.ceil(M/2**j), M//2 ) ))
4817
+ dy = int(max( 8, min( np.ceil(N/2**j), N//2 ) ))
4818
+ return dx, dy
4819
+
4820
+
4821
+
4822
+ def get_edge_masks(self,M, N, J, d0=1):
4823
+ '''
4824
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4825
+ written by Sihao Cheng and Rudy Morel.
4826
+ '''
4827
+ edge_masks = self.backend.backend.empty((J, M, N))
4828
+ X, Y = self.backend.backend.meshgrid(self.backend.backend.arange(M), self.backend.backend.arange(N), indexing='ij')
4829
+ for j in range(J):
4830
+ edge_dx = min(M//4, 2**j*d0)
4831
+ edge_dy = min(N//4, 2**j*d0)
4832
+ edge_masks[j] = (X>=edge_dx) * (X<=M-edge_dx) * (Y>=edge_dy) * (Y<=N-edge_dy)
4833
+ return edge_masks.to(self.backend.torch_device)
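Each edge mask zeroes a border whose width doubles with the scale. The same construction in NumPy, for a quick check (device handling omitted):

    import numpy as np

    def edge_mask_np(M, N, j, d0=1):
        dx, dy = min(M // 4, 2 ** j * d0), min(N // 4, 2 ** j * d0)
        X, Y = np.mgrid[0:M, 0:N]
        return ((X >= dx) & (X <= M - dx) & (Y >= dy) & (Y <= N - dy)).astype(float)

    print(edge_mask_np(8, 8, 1).astype(int))   # ones in the interior, zeros on the rim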
4834
+
4835
+ # ---------------------------------------------------------------------------
4836
+ #
4837
+ # scattering cov
4838
+ #
4839
+ # ---------------------------------------------------------------------------
4840
+ def scattering_cov(
4841
+ self, data,
4842
+ data2=None,
4843
+ Jmax=None,
4844
+ if_large_batch=False,
4845
+ S4_criteria=None,
4846
+ use_ref=False,
4847
+ normalization='S2',
4848
+ edge=False,
4849
+ pseudo_coef=1,
4850
+ get_variance=False,
4851
+ ref_sigma=None,
4852
+ iso_ang=False
4853
+ ):
4854
+ '''
4855
+ Calculates the scattering correlations for a batch of images, including:
4856
+
4857
+ This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform,
4858
+ written by Sihao Cheng and Rudy Morel.
4859
+
4860
+ orig. x orig.:
4861
+ P00 = <(I * psi)(I * psi)*> = L2(I * psi)^2
4862
+ orig. x modulus:
4863
+ C01 = <(I * psi2)(|I * psi1| * psi2)*> / factor
4864
+ when normalization == 'P00', factor = L2(I * psi2) * L2(I * psi1)
4865
+ when normalization == 'P11', factor = L2(I * psi2) * L2(|I * psi1| * psi2)
4866
+ modulus x modulus:
4867
+ C11_pre_norm = <(|I * psi1| * psi3)(|I * psi2| * psi3)>
4868
+ C11 = C11_pre_norm / factor
4869
+ when normalization == 'P00', factor = L2(I * psi1) * L2(I * psi2)
4870
+ when normalization == 'P11', factor = L2(|I * psi1| * psi3) * L2(|I * psi2| * psi3)
4871
+ modulus x modulus (auto):
4872
+ P11 = <(|I * psi1| * psi2)(|I * psi1| * psi2)*>
4873
+ Parameters
4874
+ ----------
4875
+ data : numpy array or torch tensor
4876
+ image set, with size [N_image, x-sidelength, y-sidelength]
4877
+ if_large_batch : Bool (=False)
4878
+ It is recommended to use "False" unless one meets a memory issue
4879
+ S4_criteria : str or None (=None)
4880
+ Only S4 coefficients that satisfy this criterion will be computed.
4881
+ Any expression of j1, j2, and j3 that can be evaluated as a Bool
4882
+ is accepted. The default "None" corresponds to "j2 >= j1".
4883
+ use_ref : Bool (=False)
4884
+ When normalizing, whether or not to use the normalization factor
4885
+ computed from a reference field. For just computing the statistics,
4886
+ the default is False. However, for synthesis, setting it to "True" will
4887
+ stabilize the optimization process.
4888
+ normalization : str 'P00' or 'P11' (='P00')
4889
+ Whether 'P00' or 'P11' is used as the normalization factor for C01
4890
+ and C11.
4891
+ edge : Bool (=False)
4892
+ If True, the edge region with a width of roughly the size of the largest
4893
+ wavelet involved is excluded when taking the global average to obtain
4894
+ the scattering coefficients.
4895
+
4896
+ Returns
4897
+ -------
4898
+ 'P00' : torch tensor with size [N_image, J, L] (# image, j1, l1)
4899
+ the power in each wavelet band (the orig. x orig. term)
4900
+ 'S1' : torch tensor with size [N_image, J, L] (# image, j1, l1)
4901
+ the 1st-order scattering coefficients, i.e., the mean of wavelet modulus fields
4902
+ 'C01' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
4903
+ the orig. x modulus terms. Elements with j1 < j2 are all set to np.nan and not computed.
4904
+ 'C11' : torch tensor with size [N_image, J, J, J, L, L, L] (# image, j1, j2, j3, l1, l2, l3)
4905
+ the modulus x modulus terms. Elements not satisfying j1 <= j2 <= j3 and the conditions
4906
+ defined in 'C11_criteria' are all set to np.nan and not computed.
4907
+ 'C11_pre_norm' and 'C11_pre_norm_iso': pre-normalized modulus x modulus terms.
4908
+ 'P11' : torch tensor with size [N_image, J, J, L, L] (# image, j1, j2, l1, l2)
4909
+ the modulus x modulus terms where the two wavelets inside the modulus are the same. Elements not satisfying
4910
+ j1 <= j3 are set to np.nan and not computed.
4911
+ 'P11_iso' : torch tensor with size [N_image, J, J, L] (# image, j1, j2, l2-l1)
4912
+ 'P11' averaged over l1 while keeping l2-l1 constant.
4913
+ '''
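A hypothetical call, for orientation only; the constructor line is an assumption about how a 2D-configured funct instance is obtained, not the documented API:

    import numpy as np

    # op = funct(NORIENT=4)            # assumed 2D-configured instance (use_2D backend)
    data = np.random.randn(2, 64, 64).astype('float32')   # [N_image, M, N]
    # stats = op.scattering_cov(data, normalization='S2', edge=True, iso_ang=True)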
4914
+ if S4_criteria is None:
4915
+ S4_criteria = 'j2>=j1'
4916
+
4917
+ # determine jmax and nside corresponding to the input map
4918
+ im_shape = data.shape
4919
+ if self.use_2D:
4920
+ if len(data.shape) == 2:
4921
+ nside = np.min([im_shape[0], im_shape[1]])
4922
+ M,N = im_shape[0],im_shape[1]
4923
+ N_image = 1
4924
+ else:
4925
+ nside = np.min([im_shape[1], im_shape[2]])
4926
+ M,N = im_shape[1],im_shape[2]
4927
+ N_image = data.shape[0]
4928
+ J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4929
+ elif self.use_1D:
4930
+ if len(data.shape) == 2:
4931
+ npix = int(im_shape[1]) # Number of pixels
4932
+ N_image = 1
4933
+ else:
4934
+ npix = int(im_shape[0]) # Number of pixels
4935
+ N_image = data.shape[0]
4936
+
4937
+ nside = int(npix)
4938
+
4939
+ J = int(np.log(nside) / np.log(2))-1 # Number of j scales
4940
+ else:
4941
+ if len(data.shape) == 2:
4942
+ npix = int(im_shape[1]) # Number of pixels
4943
+ N_image = 1
4944
+ else:
4945
+ npix = int(im_shape[0]) # Number of pixels
4946
+ N_image = data.shape[0]
4947
+
4948
+ nside = int(np.sqrt(npix // 12))
4949
+
4950
+ J = int(np.log(nside) / np.log(2)) # Number of j scales
4951
+
4952
+ if Jmax is None:
4953
+ Jmax = J # Number of steps for the loop on scales
4954
+ if Jmax>J:
4955
+ print('==========\n\n')
4956
+ print('The requested Jmax is larger than the data size allows, which may cause problems while computing the scattering transform.')
4957
+ print('\n\n==========')
4958
+
4959
+ L=self.NORIENT
4960
+
4961
+ if (M,N,J,L) not in self.filters_set:
4962
+ self.filters_set[(M,N,J,L)] = self.computer_filter(M,N,J,L) #self.computer_filter(M,N,J,L)
4963
+
4964
+ filters_set = self.filters_set[(M,N,J,L)]
4965
+
4966
+ #weight = self.weight
4967
+ if use_ref:
4968
+ if normalization=='S2':
4969
+ ref_S2 = self.ref_scattering_cov_S2
4970
+ else:
4971
+ ref_P11 = self.ref_scattering_cov['P11']
4972
+
4973
+ # convert numpy array input into self.backend.bk_ tensors
4974
+ data = self.backend.bk_cast(data)
4975
+ data_f = self.backend.bk_fftn(data, dim=(-2,-1))
4976
+ if data2 is not None:
4977
+ data2 = self.backend.bk_cast(data2)
4978
+ data2_f = self.backend.bk_fftn(data2, dim=(-2,-1))
4979
+
4980
+ # initialize tensors for scattering coefficients
4981
+ S2 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4982
+ S1 = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4983
+
4984
+ Ndata_S3 = J*(J+1)//2
4985
+ Ndata_S4 = J*(J+1)*(J+2)//6
4986
+ J_S4={}
4987
+
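The two closed forms above count the (j2, j3) pairs with j2 <= j3 < J and the (j1, j2, j3) triples with j1 <= j2 <= j3 < J. A brute-force check:

    J = 6
    n_s3 = sum(1 for j3 in range(J) for j2 in range(j3 + 1))
    n_s4 = sum(1 for j3 in range(J) for j2 in range(j3 + 1) for j1 in range(j2 + 1))
    assert n_s3 == J * (J + 1) // 2              # 21
    assert n_s4 == J * (J + 1) * (J + 2) // 6    # 56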
4988
+ S3 = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4989
+ if data2 is not None:
4990
+ S3p = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4991
+ S4_pre_norm = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4992
+ S4 = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
4993
+
4994
+ # variance
4995
+ if get_variance:
4996
+ S2_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4997
+ S1_sigma = self.backend.bk_zeros((N_image,J,L), dtype=data.dtype)
4998
+ S3_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
4999
+ if data2 is not None:
5000
+ S3p_sigma = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5001
+ S4_sigma = self.backend.bk_zeros((N_image,Ndata_S4,L,L,L), dtype=data_f.dtype)
5002
+
5003
+ if iso_ang:
5004
+ S3_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5005
+ if data2 is not None:
5006
+ S3p_iso = self.backend.bk_zeros((N_image,Ndata_S3,L,L), dtype=data_f.dtype)
5007
+
5008
+ S4_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5009
+ if get_variance:
5010
+ S3_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5011
+ if data2 is not None:
5012
+ S3p_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S3,L), dtype=data_f.dtype)
5013
+ S4_sigma_iso = self.backend.bk_zeros((N_image,Ndata_S4,L,L), dtype=data_f.dtype)
5014
+
5015
+ #
5016
+ if edge:
5017
+ if (M,N,J) not in self.edge_masks:
5018
+ self.edge_masks[(M,N,J)] = self.get_edge_masks(M,N,J)
5019
+ edge_mask = self.edge_masks[(M,N,J)][:,None,:,:]
5020
+ edge_mask = edge_mask / edge_mask.mean((-2,-1))[:,:,None,None]
5021
+ else:
5022
+ edge_mask = 1
5023
+
5024
+ # calculate scattering fields
5025
+ if data2 is None:
5026
+ if self.use_2D:
5027
+ if len(data.shape) == 2:
5028
+ I1 = self.backend.bk_ifftn(
5029
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5030
+ ).abs()
5031
+ else:
5032
+ I1 = self.backend.bk_ifftn(
5033
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5034
+ ).abs()
5035
+ elif self.use_1D:
5036
+ if len(data.shape) == 1:
5037
+ I1 = self.backend.bk_ifftn(
5038
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5039
+ ).abs()
5040
+ else:
5041
+ I1 = self.backend.bk_ifftn(
5042
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5043
+ ).abs()
5044
+ else:
5045
+ print('todo')
5046
+
5047
+ S2 = (I1**2 * edge_mask).mean((-2,-1))
5048
+ S1 = (I1 * edge_mask).mean((-2,-1))
5049
+
5050
+ if get_variance:
5051
+ S2_sigma = (I1**2 * edge_mask).std((-2,-1))
5052
+ S1_sigma = (I1 * edge_mask).std((-2,-1))
5053
+
5054
+ I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5055
+
5056
+ else:
5057
+ if self.use_2D:
5058
+ if len(data.shape) == 2:
5059
+ I1 = self.backend.bk_ifftn(
5060
+ data_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5061
+ )
5062
+ I2 = self.backend.bk_ifftn(
5063
+ data2_f[None,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5064
+ )
5065
+ else:
5066
+ I1 = self.backend.bk_ifftn(
5067
+ data_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5068
+ )
5069
+ I2 = self.backend.bk_ifftn(
5070
+ data2_f[:,None,None,:,:] * filters_set[None,:J,:,:,:], dim=(-2,-1)
5071
+ )
5072
+ elif self.use_1D:
5073
+ if len(data.shape) == 1:
5074
+ I1 = self.backend.bk_ifftn(
5075
+ data_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5076
+ )
5077
+ I2 = self.backend.bk_ifftn(
5078
+ data2_f[None,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5079
+ )
5080
+ else:
5081
+ I1 = self.backend.bk_ifftn(
5082
+ data_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5083
+ )
5084
+ I2 = self.backend.bk_ifftn(
5085
+ data2_f[:,None,None,:] * filters_set[None,:J,:,:], dim=(-1)
5086
+ )
5087
+ else:
5088
+ print('todo')
5089
+
5090
+ I1=self.backend.bk_real(I1*self.backend.bk_conjugate(I2))
5091
+
5092
+ S2 = (I1 * edge_mask).mean((-2,-1))
5093
+ if get_variance:
5094
+ S2_sigma = (I1 * edge_mask).std((-2,-1))
5095
+
5096
+ I1=self.backend.bk_L1(I1)
5097
+
5098
+ S1 = (I1 * edge_mask).mean((-2,-1))
3647
5099
 
5100
+ if get_variance:
5101
+ S1_sigma = (I1 * edge_mask).std((-2,-1))
5102
+
5103
+ I1_f= self.backend.bk_fftn(I1, dim=(-2,-1))
5104
+
5105
+ if pseudo_coef != 1:
5106
+ I1 = I1**pseudo_coef
5107
+
5108
+ Ndata_S3=0
5109
+ Ndata_S4=0
5110
+
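The first-order fields above come from a Fourier-domain product followed by a modulus. A minimal NumPy sketch for one image and one filter (no batching, no edge mask):

    import numpy as np

    rng = np.random.default_rng(1)
    data = rng.normal(size=(64, 64))
    psi_f = rng.normal(size=(64, 64))    # stand-in for one filters_set[j, l]

    I1 = np.abs(np.fft.ifft2(np.fft.fft2(data) * psi_f))   # |I * psi|
    S2, S1 = (I1 ** 2).mean(), I1.mean()                   # band power, mean modulus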
5111
+ # calculate the covariance and correlations of the scattering fields
5112
+ # only use the low-k Fourier coefs when calculating large-j scattering coefs.
5113
+ for j3 in range(0,J):
5114
+ J_S4[j3]=Ndata_S4
5115
+
5116
+ dx3, dy3 = self.get_dxdy(j3,M,N)
5117
+ I1_f_small = self.cut_high_k_off(I1_f[:,:j3+1], dx3, dy3) # Nimage, J, L, x, y
5118
+ data_f_small = self.cut_high_k_off(data_f, dx3, dy3)
5119
+ if data2 is not None:
5120
+ data2_f_small = self.cut_high_k_off(data2_f, dx3, dy3)
5121
+ if edge:
5122
+ I1_small = self.backend.bk_ifftn(I1_f_small, dim=(-2,-1), norm='ortho')
5123
+ data_small = self.backend.bk_ifftn(data_f_small, dim=(-2,-1), norm='ortho')
5124
+ if data2 is not None:
5125
+ data2_small = self.backend.bk_ifftn(data2_f_small, dim=(-2,-1), norm='ortho')
5126
+ wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
5127
+ _, M3, N3 = wavelet_f3.shape
5128
+ wavelet_f3_squared = wavelet_f3**2
5129
+ edge_dx = min(4, int(2**j3*dx3*2/M))
5130
+ edge_dy = min(4, int(2**j3*dy3*2/N))
5131
+ # a normalization change due to the cutoff of frequency space
5132
+ fft_factor = 1 /(M3*N3) * (M3*N3/M/N)**2
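The factor compensates for the k-space crop: a mean over the reduced grid, rescaled by the squared fraction of Fourier modes kept. A worked number with hypothetical sizes:

    M = N = 256
    M3 = N3 = 64
    fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
    assert abs(fft_factor - 2 ** -20) < 1e-26    # (1/4096) * (1/16)**2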
5133
+ for j2 in range(0,j3+1):
5134
+ I1_f2_wf3_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3.view(1,1,L,M3,N3)
5135
+ I1_f2_wf3_2_small = I1_f_small[:,j2].view(N_image,L,1,M3,N3) * wavelet_f3_squared.view(1,1,L,M3,N3)
5136
+ if edge:
5137
+ I12_w3_small = self.backend.bk_ifftn(I1_f2_wf3_small, dim=(-2,-1), norm='ortho')
5138
+ I12_w3_2_small = self.backend.bk_ifftn(I1_f2_wf3_2_small, dim=(-2,-1), norm='ortho')
5139
+ if use_ref:
5140
+ if normalization=='P11':
5141
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_P11[:,j2,j3,:,:]**pseudo_coef)**0.5
5142
+ elif normalization=='S2':
5143
+ norm_factor_S3 = (ref_S2[:,None,j3,:] * ref_S2[:,j2,:,None]**pseudo_coef)**0.5
5144
+ norm_factor_S3 = 1.0
5145
+ else:
5146
+ if normalization=='P11':
5147
+ # [N_image,l2,l3,x,y]
5148
+ P11_temp = (I1_f2_wf3_small.abs()**2).mean((-2,-1)) * fft_factor
5149
+ norm_factor_S3 = (S2[:,None,j3,:] * P11_temp**pseudo_coef)**0.5
5150
+ elif normalization=='S2':
5151
+ norm_factor_S3 = (S2[:,None,j3,:] * S2[:,j2,:,None]**pseudo_coef)**0.5
5152
+ norm_factor_S3 = 1.0
5153
+
5154
+ if not edge:
5155
+ S3[:,Ndata_S3,:,:] = (
5156
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5157
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5158
+
5159
+ if get_variance:
5160
+ S3_sigma[:,Ndata_S3,:,:] = (
5161
+ data_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5162
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5163
+ else:
5164
+
5165
+ S3[:,Ndata_S3,:,:] = (
5166
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5167
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5168
+ if get_variance:
5169
+ S3_sigma[:,Ndata_S3,:,:] = (
5170
+ data_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5171
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5172
+ if data2 is not None:
5173
+ if not edge:
5174
+ S3p[:,Ndata_S3,:,:] = (
5175
+ data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5176
+ ).mean((-2,-1)) * fft_factor / norm_factor_S3
5177
+
5178
+ if get_variance:
5179
+ S3p_sigma[:,Ndata_S3,:,:] = (
5180
+ data2_f_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I1_f2_wf3_small)
5181
+ ).std((-2,-1)) * fft_factor / norm_factor_S3
5182
+ else:
5183
+
5184
+ S3p[:,Ndata_S3,:,:] = (
5185
+ data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5186
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].mean((-2,-1)) * fft_factor / norm_factor_S3
5187
+ if get_variance:
5188
+ S3p_sigma[:,Ndata_S3,:,:] = (
5189
+ data2_small.view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(I12_w3_small)
5190
+ )[...,edge_dx:M3-edge_dx, edge_dy:N3-edge_dy].std((-2,-1)) * fft_factor / norm_factor_S3
5191
+
5192
+ Ndata_S3+=1
5193
+                if j2 <= j3:
+                    beg_n=Ndata_S4
+                    for j1 in range(0, j2+1):
+                        if eval(S4_criteria):
+                            if not edge:
+                                if not if_large_batch:
+                                    # [N_image,l1,l2,l3,x,y]
+                                    S4_pre_norm[:,Ndata_S4,:,:,:] = (
+                                        I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
+                                        self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
+                                    ).mean((-2,-1)) * fft_factor
+                                    if get_variance:
+                                        S4_sigma[:,Ndata_S4,:,:,:] = (
+                                            I1_f_small[:,j1].view(N_image,L,1,1,M3,N3) *
+                                            self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,1,L,L,M3,N3))
+                                        ).std((-2,-1)) * fft_factor
+                                else:
+                                    for l1 in range(L):
+                                        # [N_image,l2,l3,x,y]
+                                        S4_pre_norm[:,Ndata_S4,l1,:,:] = (
+                                            I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
+                                            self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
+                                        ).mean((-2,-1)) * fft_factor
+                                        if get_variance:
+                                            S4_sigma[:,Ndata_S4,l1,:,:] = (
+                                                I1_f_small[:,j1,l1].view(N_image,1,1,M3,N3) *
+                                                self.backend.bk_conjugate(I1_f2_wf3_2_small.view(N_image,L,L,M3,N3))
+                                            ).std((-2,-1)) * fft_factor
+                            else:
+                                if not if_large_batch:
+                                    # [N_image,l1,l2,l3,x,y]
+                                    S4_pre_norm[:,Ndata_S4,:,:,:] = (
+                                        I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
+                                            I12_w3_2_small.view(N_image,1,L,L,M3,N3)
+                                        )
+                                    )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
+                                    if get_variance:
+                                        S4_sigma[:,Ndata_S4,:,:,:] = (
+                                            I1_small[:,j1].view(N_image,L,1,1,M3,N3) * self.backend.bk_conjugate(
+                                                I12_w3_2_small.view(N_image,1,L,L,M3,N3)
+                                            )
+                                        )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
+                                else:
+                                    for l1 in range(L):
+                                        # [N_image,l2,l3,x,y]
+                                        S4_pre_norm[:,Ndata_S4,l1,:,:] = (
+                                            I1_small[:,j1,l1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
+                                                I12_w3_2_small.view(N_image,L,L,M3,N3)
+                                            )
+                                        )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].mean((-2,-1)) * fft_factor
+                                        if get_variance:
+                                            S4_sigma[:,Ndata_S4,l1,:,:] = (
+                                                I1_small[:,j1,l1].view(N_image,1,1,M3,N3) * self.backend.bk_conjugate(
+                                                    I12_w3_2_small.view(N_image,L,L,M3,N3)
+                                                )
+                                            )[...,edge_dx:-edge_dx, edge_dy:-edge_dy].std((-2,-1)) * fft_factor
+
+                            Ndata_S4+=1
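The `if_large_batch` switch above trades compute structure for memory: the one-shot broadcast materializes an `[N_image, L, L, L, M3, N3]` intermediate, while looping over `l1` never holds more than one orientation slice. A self-contained PyTorch sketch with generic tensors (illustrative shapes, not the foscat internals) showing the two forms agree:

```python
# Broadcasting over all three orientation axes at once vs. a memory-frugal
# loop over l1: both produce the same [N, l1, l2, l3] covariance table.
import torch

N, L, M, Nx = 2, 4, 8, 8
a = torch.randn(N, L, M, Nx, dtype=torch.complex64)     # plays I1_f_small[:, j1]
b = torch.randn(N, L, L, M, Nx, dtype=torch.complex64)  # plays I1_f2_wf3_2_small
bc = b.conj()

# one-shot broadcast: peak intermediate is [N, L, L, L, M, Nx]
full = (a.view(N, L, 1, 1, M, Nx) * bc.view(N, 1, L, L, M, Nx)).mean((-2, -1))

# loop over l1: peak intermediate is only [N, L, L, M, Nx]
frugal = torch.empty_like(full)
for l1 in range(L):
    frugal[:, l1] = (a[:, l1, None, None, :, :] * bc).mean((-2, -1))

print((full - frugal).abs().max())  # ~0 up to float32 round-off
```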
+
+                    if normalization=='S2':
+                        if use_ref:
+                            P = (ref_S2[:,j3,:,None,None] * ref_S2[:,j2,None,:,None] )**(0.5*pseudo_coef)
+                        else:
+                            P = ((S2[:,j3,:,None,None] * S2[:,j2,None,:,None] )**(0.5*pseudo_coef))
+                        S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()/(P.clone())
+
+                        if get_variance:
+                            S4_sigma[:,beg_n:Ndata_S4,:,:,:] = S4_sigma[:,beg_n:Ndata_S4,:,:,:] / (P)
+                    else:
+                        S4[:,beg_n:Ndata_S4,:,:,:]=S4_pre_norm[:,beg_n:Ndata_S4,:,:,:].clone()
+
+ """
5268
+ # define P11 from diagonals of S4
5269
+ for j1 in range(J):
5270
+ for l1 in range(L):
5271
+ P11[:,j1,:,l1,:] = S4_pre_norm[:,j1,j1,:,l1,l1,:].real
5272
+
5273
+
5274
+ if normalization=='S4':
5275
+ if use_ref:
5276
+ P = ref_P11
5277
+ else:
5278
+ P = P11
5279
+ #.view(N_image,J,1,J,L,1,L) * .view(N_image,1,J,J,1,L,L)
5280
+ S4 = S4_pre_norm / (
5281
+ P[:,:,None,:,:,None,:] * P[:,None,:,:,None,:,:]
5282
+ )**(0.5*pseudo_coef)
5283
+
5284
+
5285
+
5286
+
5287
+ # get a single, flattened data vector for_synthesis
5288
+ select_and_index = self.get_scattering_index(J, L, normalization, S4_criteria)
5289
+ index_for_synthesis = select_and_index['index_for_synthesis']
5290
+ index_for_synthesis_iso = select_and_index['index_for_synthesis_iso']
5291
+ """
5292
+        # average over l1 to obtain simple isotropic statistics
+        if iso_ang:
+            S2_iso = S2.mean(-1)
+            S1_iso = S1.mean(-1)
+            for l1 in range(L):
+                for l2 in range(L):
+                    S3_iso[...,(l2-l1)%L] += S3[...,l1,l2]
+                    if data2 is not None:
+                        S3p_iso[...,(l2-l1)%L] += S3p[...,l1,l2]
+                    for l3 in range(L):
+                        S4_iso[...,(l2-l1)%L,(l3-l1)%L] += S4[...,l1,l2,l3]
+            S3_iso /= L; S4_iso /= L
+            if data2 is not None:
+                S3p_iso /= L
+
+            if get_variance:
+                S2_sigma_iso = S2_sigma.mean(-1)
+                S1_sigma_iso = S1_sigma.mean(-1)
+                for l1 in range(L):
+                    for l2 in range(L):
+                        S3_sigma_iso[...,(l2-l1)%L] += S3_sigma[...,l1,l2]
+                        if data2 is not None:
+                            S3p_sigma_iso[...,(l2-l1)%L] += S3p_sigma[...,l1,l2]
+                        for l3 in range(L):
+                            S4_sigma_iso[...,(l2-l1)%L,(l3-l1)%L] += S4_sigma[...,l1,l2,l3]
+                S3_sigma_iso /= L; S4_sigma_iso /= L
+                if data2 is not None:
+                    S3p_sigma_iso /= L
+
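The reduction above bins coefficients by relative orientation `(l2 - l1) % L`, which makes the statistics invariant under a global rotation of the orientation indices. A numpy sketch (illustrative shapes only) verifying that property:

```python
# Summing S3[..., l1, l2] into bin (l2 - l1) % L and dividing by L averages
# every orientation pair with the same relative angle; rolling both axes
# (a global rotation) leaves the reduced vector unchanged.
import numpy as np

L = 4
S3 = np.random.rand(L, L) + 1j * np.random.rand(L, L)  # toy [l1, l2] table

def iso_reduce(t):
    out = np.zeros(L, dtype=complex)
    for l1 in range(L):
        for l2 in range(L):
            out[(l2 - l1) % L] += t[l1, l2]
    return out / L

S3_rot = np.roll(np.roll(S3, 1, axis=0), 1, axis=1)  # rotate both axes by one step
print(np.allclose(iso_reduce(S3), iso_reduce(S3_rot)))  # True
```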
+        mean_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
+        std_data=self.backend.bk_zeros((N_image,1), dtype=data.dtype)
+        mean_data[:,0]=data.mean((-2,-1))
+        std_data[:,0]=data.std((-2,-1))
+
+        if get_variance:
+            ref_sigma={}
+            if iso_ang:
+                ref_sigma['std_data']=std_data
+                ref_sigma['S1_sigma']=S1_sigma_iso
+                ref_sigma['S2_sigma']=S2_sigma_iso
+                ref_sigma['S3_sigma']=S3_sigma_iso
+                ref_sigma['S4_sigma']=S4_sigma_iso
+                if data2 is not None:
+                    ref_sigma['S3p_sigma']=S3p_sigma_iso
+            else:
+                ref_sigma['std_data']=std_data
+                ref_sigma['S1_sigma']=S1_sigma
+                ref_sigma['S2_sigma']=S2_sigma
+                ref_sigma['S3_sigma']=S3_sigma
+                ref_sigma['S4_sigma']=S4_sigma
+                if data2 is not None:
+                    ref_sigma['S3p_sigma']=S3p_sigma
+
+        if data2 is None:
+            if iso_ang:
+                if ref_sigma is not None:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/ref_sigma['std_data'],
+                        std_data/ref_sigma['std_data'],
+                        (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
+                        (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
+                        (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
+                        (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
+                        (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
+                        (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
+                    ),dim=-1)
+                else:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/std_data,
+                        std_data,
+                        S2_iso.reshape((N_image, -1)).log(),
+                        S1_iso.reshape((N_image, -1)).log(),
+                        S3_iso.reshape((N_image, -1)).real,
+                        S3_iso.reshape((N_image, -1)).imag,
+                        S4_iso.reshape((N_image, -1)).real,
+                        S4_iso.reshape((N_image, -1)).imag,
+                    ),dim=-1)
+            else:
+                if ref_sigma is not None:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/ref_sigma['std_data'],
+                        std_data/ref_sigma['std_data'],
+                        (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)).log(),
+                        (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)).log(),
+                        (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
+                        (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
+                        (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
+                        (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
+                    ),dim=-1)
+                else:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/std_data,
+                        std_data,
+                        S2.reshape((N_image, -1)).log(),
+                        S1.reshape((N_image, -1)).log(),
+                        S3.reshape((N_image, -1)).real,
+                        S3.reshape((N_image, -1)).imag,
+                        S4.reshape((N_image, -1)).real,
+                        S4.reshape((N_image, -1)).imag,
+                    ),dim=-1)
+        else:
+            if iso_ang:
+                if ref_sigma is not None:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/ref_sigma['std_data'],
+                        std_data/ref_sigma['std_data'],
+                        (S2_iso/ref_sigma['S2_sigma']).reshape((N_image, -1)),
+                        (S1_iso/ref_sigma['S1_sigma']).reshape((N_image, -1)),
+                        (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
+                        (S3_iso/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
+                        (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
+                        (S3p_iso/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
+                        (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
+                        (S4_iso/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
+                    ),dim=-1)
+                else:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/std_data,
+                        std_data,
+                        S2_iso.reshape((N_image, -1)),
+                        S1_iso.reshape((N_image, -1)),
+                        S3_iso.reshape((N_image, -1)).real,
+                        S3_iso.reshape((N_image, -1)).imag,
+                        S3p_iso.reshape((N_image, -1)).real,
+                        S3p_iso.reshape((N_image, -1)).imag,
+                        S4_iso.reshape((N_image, -1)).real,
+                        S4_iso.reshape((N_image, -1)).imag,
+                    ),dim=-1)
+            else:
+                if ref_sigma is not None:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/ref_sigma['std_data'],
+                        std_data/ref_sigma['std_data'],
+                        (S2/ref_sigma['S2_sigma']).reshape((N_image, -1)),
+                        (S1/ref_sigma['S1_sigma']).reshape((N_image, -1)),
+                        (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).real,
+                        (S3/ref_sigma['S3_sigma']).reshape((N_image, -1)).imag,
+                        (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).real,
+                        (S3p/ref_sigma['S3p_sigma']).reshape((N_image, -1)).imag,
+                        (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).real,
+                        (S4/ref_sigma['S4_sigma']).reshape((N_image, -1)).imag,
+                    ),dim=-1)
+                else:
+                    for_synthesis = self.backend.backend.cat((
+                        mean_data/std_data,
+                        std_data,
+                        S2.reshape((N_image, -1)),
+                        S1.reshape((N_image, -1)),
+                        S3.reshape((N_image, -1)).real,
+                        S3.reshape((N_image, -1)).imag,
+                        S3p.reshape((N_image, -1)).real,
+                        S3p.reshape((N_image, -1)).imag,
+                        S4.reshape((N_image, -1)).real,
+                        S4.reshape((N_image, -1)).imag,
+                    ),dim=-1)
+
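The concatenation above assembles one flat descriptor per image: mean and std of the map, log of the power-type terms S2 and S1, then real and imaginary parts of the complex covariances S3 (and S3p for cross statistics) and S4. A back-of-envelope sketch of the resulting length, under assumed coefficient counts (the actual `n_S3`/`n_S4` depend on `Jmax` and `S4_criteria`):

```python
# Expected length of the flattened vector in the auto-correlation case,
# with hypothetical counts of retained scale pairs/triplets.
def descriptor_length(J, L, n_S3, n_S4):
    n = 2                      # mean + std
    n += 2 * J * L             # log S2 and log S1, one value per (scale, orientation)
    n += 2 * n_S3 * L * L      # S3: complex [n_S3, L, L], split into real/imag
    n += 2 * n_S4 * L * L * L  # S4: complex [n_S4, L, L, L], split into real/imag
    return n

# e.g. J=5 scales, L=4 orientations, 15 S3 pairs, 20 S4 triplets
print(descriptor_length(5, 4, 15, 20))  # 2 + 40 + 480 + 2560 = 3082
```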
+        if not use_ref:
+            self.ref_scattering_cov_S2=S2
+
+        if get_variance:
+            return for_synthesis,ref_sigma
+
+        return for_synthesis
+
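A hypothetical call sketch for the method as extended above; the constructor arguments (`NORIENT`, `KERNELSZ`, `use_2D`) are assumptions based on the rest of the package and may need adjusting, while `get_variance`, `edge`, `Jmax`, and `iso_ang` appear in this diff:

```python
# Usage sketch (constructor arguments are assumptions, not taken from this diff).
import numpy as np
import foscat.scat_cov as sc

op = sc.funct(NORIENT=4, KERNELSZ=3, use_2D=True)
image = np.random.randn(1, 64, 64)  # one 64x64 test map

# flat coefficient vector, plus per-coefficient scatter when requested
coeffs, sigma = op.scattering_cov(image, get_variance=True,
                                  edge=False, iso_ang=True)
```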
     def to_gaussian(self,x):
         from scipy.stats import norm
         from scipy.interpolate import interp1d
@@ -4021,8 +5830,12 @@ class funct(FOC.FoCUS):
                   image_target,
                   nstep=4,
                   seed=1234,
-                  edge=True,
+                  Jmax=None,
+                  edge=False,
                   to_gaussian=True,
+                  use_variance=False,
+                  synthesised_N=1,
+                  iso_ang=False,
                   EVAL_FREQUENCY=100,
                   NUM_EPOCHS = 300):
 
@@ -4032,13 +5845,16 @@ class funct(FOC.FoCUS):
         def The_loss(u,scat_operator,args):
             ref = args[0]
             sref = args[1]
+            use_v= args[2]
 
             # compute the scattering covariance of the current synthesized map u
-            learn=scat_operator.reduce_mean_batch(scat_operator.eval(u,edge=edge))
+            if use_v:
+                learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,ref_sigma=sref,use_ref=True,iso_ang=iso_ang))
+            else:
+                learn=scat_operator.reduce_mean_batch(scat_operator.scattering_cov(u,edge=edge,Jmax=Jmax,use_ref=True,iso_ang=iso_ang))
 
             # compute the difference with the reference coordinates
-            loss=scat_operator.reduce_distance(learn,ref,sigma=sref)
-
+            loss=scat_operator.backend.bk_reduce_mean(scat_operator.backend.bk_square((learn-ref)))
             return loss
 
         if to_gaussian:
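The updated `The_loss` above drops the package's `reduce_distance` in favor of a plain mean-squared difference between the synthesized map's coefficient vector and the reference. A stand-alone PyTorch sketch of that objective (torch stand-ins, not the foscat backend wrappers):

```python
# Equivalent of bk_reduce_mean(bk_square(learn - ref)) on the flat
# coefficient vectors; gradients flow back to the synthesized map.
import torch

def mse_coefficient_loss(learn: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    return torch.mean(torch.square(learn - ref))

learn = torch.randn(3082, requires_grad=True)  # synthetic coefficient vector
ref = torch.randn(3082)                        # reference coefficients
loss = mse_coefficient_loss(learn, ref)
loss.backward()
```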
@@ -4078,33 +5894,35 @@ class funct(FOC.FoCUS):
             if k==0:
                 np.random.seed(seed)
                 if self.use_2D:
-                    imap=np.random.randn(tmp[k].shape[0],
+                    imap=np.random.randn(synthesised_N,
                                          tmp[k].shape[1],
                                          tmp[k].shape[2])
                 else:
-                    imap=np.random.randn(tmp[k].shape[0],
+                    imap=np.random.randn(synthesised_N,
                                          tmp[k].shape[1])
             else:
-                axis=1
-                # if the kernel size is bigger than 3 increase the binning before smoothing
+                # Increase the resolution between each step
                 if self.use_2D:
                     imap = self.up_grade(
-                        omap, imap.shape[axis] * 2, axis=1, nouty=imap.shape[axis + 1] * 2
+                        omap, imap.shape[1] * 2, axis=1, nouty=imap.shape[2] * 2
                     )
                 elif self.use_1D:
-                    imap = self.up_grade(omap, imap.shape[axis] * 2, axis=1)
+                    imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
                 else:
-                    imap = self.up_grade(omap, l_nside, axis=axis)
+                    imap = self.up_grade(omap, l_nside, axis=1)
 
             # compute the coefficients for the target image
-            ref,sref=self.eval(tmp[k],calc_var=True,edge=edge)
-
+            if use_variance:
+                ref,sref=self.scattering_cov(tmp[k],get_variance=True,edge=edge,Jmax=Jmax,iso_ang=iso_ang)
+            else:
+                ref=self.scattering_cov(tmp[k],edge=edge,Jmax=Jmax,iso_ang=iso_ang)
+                sref=ref
+
             # compute the mean of the population; does nothing if only one map is given
             ref=self.reduce_mean_batch(ref)
-            sref=self.reduce_mean_batch(sref)
-
+
             # define a loss to minimize
-            loss=synthe.Loss(The_loss,self,ref,sref)
+            loss=synthe.Loss(The_loss,self,ref,sref,use_variance)
 
             sy = synthe.Synthesis([loss])
 
@@ -4121,6 +5939,8 @@ class funct(FOC.FoCUS):
             omap=sy.run(imap,
                         EVAL_FREQUENCY=EVAL_FREQUENCY,
                         NUM_EPOCHS = NUM_EPOCHS)
+
+
 
             t2=time.time()
             print('Total computation %.2fs'%(t2-t1))
@@ -4128,7 +5948,7 @@ class funct(FOC.FoCUS):
         if to_gaussian:
             omap=self.from_gaussian(omap)
 
-        if axis==0:
+        if axis==0 and synthesised_N==1:
             return omap[0]
         else:
             return omap
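The hunks above extend the coarse-to-fine synthesis entry point to draw `synthesised_N` maps at once and to fit scattering-covariance statistics, optionally weighted by their measured scatter. Its name and leading parameters fall outside this diff, so the call below is a hypothetical sketch (`synthesis` and the operator construction are assumptions):

```python
# Hypothetical end-to-end sketch of the extended synthesis loop.
import numpy as np
import foscat.scat_cov as sc

op = sc.funct(NORIENT=4, KERNELSZ=3, use_2D=True)  # assumed constructor
target = np.random.randn(128, 128)                 # stand-in for image_target

maps = op.synthesis(target,                        # method name assumed
                    nstep=4,            # coarse-to-fine refinement steps
                    Jmax=None,          # use every available scale
                    use_variance=True,  # weight the loss by measured sigma
                    synthesised_N=4,    # draw four independent realizations
                    iso_ang=True)       # fit rotation-averaged statistics
```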