foscat 3.9.0__py3-none-any.whl → 2025.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -177,9 +177,21 @@ class scat_cov:
177
177
  ],
178
178
  )
179
179
  ),
180
- self.conv2complex(
180
+ self.backend.bk_reshape(
181
+ self.S3,
182
+ [
183
+ self.S3.shape[0],
184
+ self.S3.shape[1]
185
+ * self.S3.shape[2]
186
+ * self.S3.shape[3]
187
+ * self.S3.shape[4],
188
+ ],
189
+ ),
190
+ ]
191
+ if self.S3P is not None:
192
+ tmp = tmp + [
181
193
  self.backend.bk_reshape(
182
- self.S3,
194
+ self.S3P,
183
195
  [
184
196
  self.S3.shape[0],
185
197
  self.S3.shape[1]
@@ -188,37 +200,19 @@ class scat_cov:
188
200
  * self.S3.shape[4],
189
201
  ],
190
202
  )
191
- ),
192
- ]
193
- if self.S3P is not None:
194
- tmp = tmp + [
195
- self.conv2complex(
196
- self.backend.bk_reshape(
197
- self.S3P,
198
- [
199
- self.S3.shape[0],
200
- self.S3.shape[1]
201
- * self.S3.shape[2]
202
- * self.S3.shape[3]
203
- * self.S3.shape[4],
204
- ],
205
- )
206
- )
207
203
  ]
208
204
 
209
205
  tmp = tmp + [
210
- self.conv2complex(
211
- self.backend.bk_reshape(
212
- self.S4,
213
- [
214
- self.S3.shape[0],
215
- self.S4.shape[1]
216
- * self.S4.shape[2]
217
- * self.S4.shape[3]
218
- * self.S4.shape[4]
219
- * self.S4.shape[5],
220
- ],
221
- )
206
+ self.backend.bk_reshape(
207
+ self.S4,
208
+ [
209
+ self.S3.shape[0],
210
+ self.S4.shape[1]
211
+ * self.S4.shape[2]
212
+ * self.S4.shape[3]
213
+ * self.S4.shape[4]
214
+ * self.S4.shape[5],
215
+ ],
222
216
  )
223
217
  ]
224
218
 
@@ -504,7 +498,7 @@ class scat_cov:
504
498
  if other.S1 is None:
505
499
  s1 = None
506
500
  else:
507
- s1 = self.S1 + other.S1
501
+ s1 = self.doadd(self.S1, other.S1)
508
502
  else:
509
503
  s1 = self.S1 + other
510
504
 
@@ -534,7 +528,7 @@ class scat_cov:
534
528
  return scat_cov(
535
529
  self.doadd(self.S0, other.S0),
536
530
  self.doadd(self.S2, other.S2),
537
- (self.S3 + other.S3),
531
+ self.doadd(self.S3, other.S3),
538
532
  s4,
539
533
  s1=s1,
540
534
  s3p=s3p,
@@ -662,7 +656,7 @@ class scat_cov:
662
656
  s1 = None
663
657
  else:
664
658
  if isinstance(other, scat_cov):
665
- s1 = other.S1 / self.S1
659
+ s1 = self.dodiv(other.S1, self.S1)
666
660
  else:
667
661
  s1 = other / self.S1
668
662
 
@@ -689,7 +683,7 @@ class scat_cov:
689
683
  return scat_cov(
690
684
  self.dodiv(other.S0, self.S0),
691
685
  self.dodiv(other.S2, self.S2),
692
- (other.S3 / self.S3),
686
+ self.dodiv(other.S3, self.S3),
693
687
  s4,
694
688
  s1=s1,
695
689
  s3p=s3p,
@@ -725,7 +719,7 @@ class scat_cov:
725
719
  if other.S1 is None:
726
720
  s1 = None
727
721
  else:
728
- s1 = other.S1 - self.S1
722
+ s1 = self.domin(other.S1, self.S1)
729
723
  else:
730
724
  s1 = other - self.S1
731
725
 
@@ -755,7 +749,7 @@ class scat_cov:
755
749
  return scat_cov(
756
750
  self.domin(other.S0, self.S0),
757
751
  self.domin(other.S2, self.S2),
758
- (other.S3 - self.S3),
752
+ self.domin(other.S3, self.S3),
759
753
  s4,
760
754
  s1=s1,
761
755
  s3p=s3p,
@@ -790,7 +784,7 @@ class scat_cov:
790
784
  if other.S1 is None:
791
785
  s1 = None
792
786
  else:
793
- s1 = self.S1 - other.S1
787
+ s1 = self.domin(self.S1, other.S1)
794
788
  else:
795
789
  s1 = self.S1 - other
796
790
 
@@ -820,7 +814,7 @@ class scat_cov:
820
814
  return scat_cov(
821
815
  self.domin(self.S0, other.S0),
822
816
  self.domin(self.S2, other.S2),
823
- (self.S3 - other.S3),
817
+ self.domin(self.S3, other.S3),
824
818
  s4,
825
819
  s1=s1,
826
820
  s3p=s3p,
@@ -920,7 +914,7 @@ class scat_cov:
920
914
  if other.S1 is None:
921
915
  s1 = None
922
916
  else:
923
- s1 = self.S1 * other.S1
917
+ s1 = self.domult(self.S1, other.S1)
924
918
  else:
925
919
  s1 = self.S1 * other
926
920
 
@@ -2215,7 +2209,7 @@ class scat_cov:
2215
2209
 
2216
2210
 
2217
2211
  class funct(FOC.FoCUS):
2218
-
2212
+
2219
2213
  def fill(self, im, nullval=hp.UNSEEN):
2220
2214
  if self.use_2D:
2221
2215
  return self.fill_2d(im, nullval=nullval)
@@ -2402,8 +2396,9 @@ class funct(FOC.FoCUS):
2402
2396
  if smooth_scale > 0:
2403
2397
  for m in range(smooth_scale):
2404
2398
  if cc.shape[0] > 12:
2405
- cc = self.ud_grade_2(self.smooth(cc))
2406
- ss = self.ud_grade_2(self.smooth(ss))
2399
+ cc, _ = self.ud_grade_2(self.smooth(cc))
2400
+ ss, _ = self.ud_grade_2(self.smooth(ss))
2401
+
2407
2402
  if cc.shape[0] != tmp.shape[0]:
2408
2403
  ll_nside = int(np.sqrt(tmp.shape[1] // 12))
2409
2404
  cc = self.up_grade(cc, ll_nside)
@@ -2473,8 +2468,8 @@ class funct(FOC.FoCUS):
2473
2468
  if smooth_scale > 0:
2474
2469
  for m in range(smooth_scale):
2475
2470
  if cc2.shape[0] > 12:
2476
- cc2 = self.ud_grade_2(self.smooth(cc2))
2477
- ss2 = self.ud_grade_2(self.smooth(ss2))
2471
+ cc2, _ = self.ud_grade_2(self.smooth(cc2))
2472
+ ss2, _ = self.ud_grade_2(self.smooth(ss2))
2478
2473
 
2479
2474
  if cc2.shape[0] != sim.shape[1]:
2480
2475
  ll_nside = int(np.sqrt(sim.shape[1] // 12))
@@ -2510,9 +2505,9 @@ class funct(FOC.FoCUS):
2510
2505
  cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2511
2506
 
2512
2507
  if k < l_nside - 1:
2513
- tmp = self.ud_grade_2(tmp, axis=1)
2508
+ tmp, _ = self.ud_grade_2(tmp, axis=1)
2514
2509
  if image2 is not None:
2515
- tmpi2 = self.ud_grade_2(tmpi2, axis=1)
2510
+ tmpi2, _ = self.ud_grade_2(tmpi2, axis=1)
2516
2511
  return cmat, cmat2
2517
2512
 
2518
2513
  def div_norm(self, complex_value, float_value):
@@ -2533,6 +2528,8 @@ class funct(FOC.FoCUS):
2533
2528
  Jmax=None,
2534
2529
  out_nside=None,
2535
2530
  edge=True,
2531
+ nside=None,
2532
+ cell_ids=None,
2536
2533
  ):
2537
2534
  """
2538
2535
  Calculates the scattering correlations for a batch of images. Mean are done over pixels.
@@ -2606,7 +2603,7 @@ class funct(FOC.FoCUS):
2606
2603
  cross = False
2607
2604
  if image2 is not None:
2608
2605
  cross = True
2609
-
2606
+ l_nside = 2**32 # not initialize if 1D or 2D
2610
2607
  ### PARAMETERS
2611
2608
  axis = 1
2612
2609
  # determine jmax and nside corresponding to the input map
@@ -2638,7 +2635,8 @@ class funct(FOC.FoCUS):
2638
2635
  else:
2639
2636
  npix = int(im_shape[0]) # Number of pixels
2640
2637
 
2641
- nside = int(np.sqrt(npix // 12))
2638
+ if nside is None:
2639
+ nside = int(np.sqrt(npix // 12))
2642
2640
 
2643
2641
  J = int(np.log(nside) / np.log(2)) # Number of j scales
2644
2642
 
@@ -2675,7 +2673,7 @@ class funct(FOC.FoCUS):
2675
2673
  else:
2676
2674
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
2677
2675
 
2678
- if self.KERNELSZ > 3 and not self.use_2D:
2676
+ if self.KERNELSZ > 3 and not self.use_2D and cell_ids is None:
2679
2677
  # if the kernel size is bigger than 3 increase the binning before smoothing
2680
2678
  if self.use_2D:
2681
2679
  vmask = self.up_grade(
@@ -2693,12 +2691,15 @@ class funct(FOC.FoCUS):
2693
2691
  I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2694
2692
  if cross:
2695
2693
  I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2694
+ nside = nside * 2
2696
2695
  else:
2697
2696
  I1 = self.up_grade(I1, nside * 2, axis=axis)
2698
2697
  vmask = self.up_grade(vmask, nside * 2, axis=1)
2699
2698
  if cross:
2700
2699
  I2 = self.up_grade(I2, nside * 2, axis=axis)
2701
2700
 
2701
+ nside = nside * 2
2702
+
2702
2703
  if self.KERNELSZ > 5 and not self.use_2D:
2703
2704
  # if the kernel size is bigger than 3 increase the binning before smoothing
2704
2705
  if self.use_2D:
@@ -2716,15 +2717,17 @@ class funct(FOC.FoCUS):
2716
2717
  nouty=I2.shape[axis + 1] * 2,
2717
2718
  )
2718
2719
  elif self.use_1D:
2719
- vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
2720
- I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
2720
+ vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2721
+ I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2721
2722
  if cross:
2722
- I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
2723
+ I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2724
+ nside = nside * 2
2723
2725
  else:
2724
- I1 = self.up_grade(I1, nside * 4, axis=axis)
2725
- vmask = self.up_grade(vmask, nside * 4, axis=1)
2726
+ I1 = self.up_grade(I1, nside * 2, axis=axis)
2727
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
2726
2728
  if cross:
2727
- I2 = self.up_grade(I2, nside * 4, axis=axis)
2729
+ I2 = self.up_grade(I2, nside * 2, axis=axis)
2730
+ nside = nside * 2
2728
2731
 
2729
2732
  # Normalize the masks because they have different pixel numbers
2730
2733
  # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
@@ -2797,6 +2800,8 @@ class funct(FOC.FoCUS):
2797
2800
  M1_dic = {}
2798
2801
  M2_dic = {}
2799
2802
 
2803
+ cell_ids_j3 = cell_ids
2804
+
2800
2805
  for j3 in range(Jmax):
2801
2806
 
2802
2807
  if edge:
@@ -2835,7 +2840,9 @@ class funct(FOC.FoCUS):
2835
2840
 
2836
2841
  ####### S1 and S2
2837
2842
  ### Make the convolution I1 * Psi_j3
2838
- conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
2843
+ conv1 = self.convol(
2844
+ I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2845
+ ) # [Nbatch, Npix_j3, Norient3]
2839
2846
  if cmat is not None:
2840
2847
  tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2841
2848
  conv1 = self.backend.bk_reduce_sum(
@@ -2948,7 +2955,9 @@ class funct(FOC.FoCUS):
2948
2955
 
2949
2956
  else: # Cross
2950
2957
  ### Make the convolution I2 * Psi_j3
2951
- conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
2958
+ conv2 = self.convol(
2959
+ I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2960
+ ) # [Nbatch, Npix_j3, Norient3]
2952
2961
  if cmat is not None:
2953
2962
  tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2954
2963
  conv2 = self.backend.bk_reduce_sum(
@@ -3028,12 +3037,13 @@ class funct(FOC.FoCUS):
3028
3037
  )
3029
3038
  S2[j3] = s2
3030
3039
  else:
3031
- ### Normalize S2_cross
3032
- if norm == "auto":
3033
- s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
3034
3040
 
3035
3041
  ### Store S2_cross as complex [Nbatch, Nmask, NS2, Norient3]
3036
3042
  s2 = self.backend.bk_real(s2)
3043
+
3044
+ ### Normalize S2_cross
3045
+ if norm == "auto":
3046
+ s2 /= (P1_dic[j3] * P2_dic[j3]) ** 0.5
3037
3047
 
3038
3048
  S2.append(
3039
3049
  self.backend.bk_expand_dims(s2, off_S2)
@@ -3092,7 +3102,6 @@ class funct(FOC.FoCUS):
3092
3102
  M2convPsi_dic = {}
3093
3103
 
3094
3104
  ###### S3
3095
- nside_j2 = nside_j3
3096
3105
  for j2 in range(0, j3 + 1): # j2 <= j3
3097
3106
  if return_data:
3098
3107
  if S4[j3] is None:
@@ -3111,6 +3120,8 @@ class funct(FOC.FoCUS):
3111
3120
  M1convPsi_dic,
3112
3121
  calc_var=True,
3113
3122
  cmat2=cmat2,
3123
+ cell_ids=cell_ids_j3,
3124
+ nside_j2=nside_j3,
3114
3125
  ) # [Nbatch, Nmask, Norient3, Norient2]
3115
3126
  else:
3116
3127
  s3 = self._compute_S3(
@@ -3122,19 +3133,21 @@ class funct(FOC.FoCUS):
3122
3133
  M1convPsi_dic,
3123
3134
  return_data=return_data,
3124
3135
  cmat2=cmat2,
3136
+ cell_ids=cell_ids_j3,
3137
+ nside_j2=nside_j3,
3125
3138
  ) # [Nbatch, Nmask, Norient3, Norient2]
3126
3139
 
3127
3140
  if return_data:
3128
3141
  if S3[j3] is None:
3129
3142
  S3[j3] = {}
3130
- if out_nside is not None and out_nside < nside_j2:
3143
+ if out_nside is not None and out_nside < nside_j3:
3131
3144
  s3 = self.backend.bk_reduce_mean(
3132
3145
  self.backend.bk_reshape(
3133
3146
  s3,
3134
3147
  [
3135
3148
  s3.shape[0],
3136
3149
  12 * out_nside**2,
3137
- (nside_j2 // out_nside) ** 2,
3150
+ (nside_j3 // out_nside) ** 2,
3138
3151
  s3.shape[2],
3139
3152
  s3.shape[3],
3140
3153
  ],
@@ -3181,6 +3194,8 @@ class funct(FOC.FoCUS):
3181
3194
  M2convPsi_dic,
3182
3195
  calc_var=True,
3183
3196
  cmat2=cmat2,
3197
+ cell_ids=cell_ids_j3,
3198
+ nside_j2=nside_j3,
3184
3199
  )
3185
3200
  s3p, vs3p = self._compute_S3(
3186
3201
  j2,
@@ -3191,41 +3206,47 @@ class funct(FOC.FoCUS):
3191
3206
  M1convPsi_dic,
3192
3207
  calc_var=True,
3193
3208
  cmat2=cmat2,
3209
+ cell_ids=cell_ids_j3,
3210
+ nside_j2=nside_j3,
3194
3211
  )
3195
3212
  else:
3196
- s3 = self._compute_S3(
3213
+ s3p = self._compute_S3(
3197
3214
  j2,
3198
3215
  j3,
3199
- conv1,
3216
+ conv2,
3200
3217
  vmask,
3201
- M2_dic,
3202
- M2convPsi_dic,
3218
+ M1_dic,
3219
+ M1convPsi_dic,
3203
3220
  return_data=return_data,
3204
3221
  cmat2=cmat2,
3222
+ cell_ids=cell_ids_j3,
3223
+ nside_j2=nside_j3,
3205
3224
  )
3206
- s3p = self._compute_S3(
3225
+ s3 = self._compute_S3(
3207
3226
  j2,
3208
3227
  j3,
3209
- conv2,
3228
+ conv1,
3210
3229
  vmask,
3211
- M1_dic,
3212
- M1convPsi_dic,
3230
+ M2_dic,
3231
+ M2convPsi_dic,
3213
3232
  return_data=return_data,
3214
3233
  cmat2=cmat2,
3234
+ cell_ids=cell_ids_j3,
3235
+ nside_j2=nside_j3,
3215
3236
  )
3216
3237
 
3217
3238
  if return_data:
3218
3239
  if S3[j3] is None:
3219
3240
  S3[j3] = {}
3220
3241
  S3P[j3] = {}
3221
- if out_nside is not None and out_nside < nside_j2:
3242
+ if out_nside is not None and out_nside < nside_j3:
3222
3243
  s3 = self.backend.bk_reduce_mean(
3223
3244
  self.backend.bk_reshape(
3224
3245
  s3,
3225
3246
  [
3226
3247
  s3.shape[0],
3227
3248
  12 * out_nside**2,
3228
- (nside_j2 // out_nside) ** 2,
3249
+ (nside_j3 // out_nside) ** 2,
3229
3250
  s3.shape[2],
3230
3251
  s3.shape[3],
3231
3252
  ],
@@ -3238,7 +3259,7 @@ class funct(FOC.FoCUS):
3238
3259
  [
3239
3260
  s3.shape[0],
3240
3261
  12 * out_nside**2,
3241
- (nside_j2 // out_nside) ** 2,
3262
+ (nside_j3 // out_nside) ** 2,
3242
3263
  s3.shape[2],
3243
3264
  s3.shape[3],
3244
3265
  ],
@@ -3295,7 +3316,6 @@ class funct(FOC.FoCUS):
3295
3316
  # s3.shape[2]*s3.shape[3]]))
3296
3317
 
3297
3318
  ##### S4
3298
- nside_j1 = nside_j2
3299
3319
  for j1 in range(0, j2 + 1): # j1 <= j2
3300
3320
  ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
3301
3321
  if not cross:
@@ -3321,14 +3341,14 @@ class funct(FOC.FoCUS):
3321
3341
  if return_data:
3322
3342
  if S4[j3][j2] is None:
3323
3343
  S4[j3][j2] = {}
3324
- if out_nside is not None and out_nside < nside_j1:
3344
+ if out_nside is not None and out_nside < nside_j3:
3325
3345
  s4 = self.backend.bk_reduce_mean(
3326
3346
  self.backend.bk_reshape(
3327
3347
  s4,
3328
3348
  [
3329
3349
  s4.shape[0],
3330
3350
  12 * out_nside**2,
3331
- (nside_j1 // out_nside) ** 2,
3351
+ (nside_j3 // out_nside) ** 2,
3332
3352
  s4.shape[2],
3333
3353
  s4.shape[3],
3334
3354
  s4.shape[4],
@@ -3396,14 +3416,14 @@ class funct(FOC.FoCUS):
3396
3416
  if return_data:
3397
3417
  if S4[j3][j2] is None:
3398
3418
  S4[j3][j2] = {}
3399
- if out_nside is not None and out_nside < nside_j1:
3419
+ if out_nside is not None and out_nside < nside_j3:
3400
3420
  s4 = self.backend.bk_reduce_mean(
3401
3421
  self.backend.bk_reshape(
3402
3422
  s4,
3403
3423
  [
3404
3424
  s4.shape[0],
3405
3425
  12 * out_nside**2,
3406
- (nside_j1 // out_nside) ** 2,
3426
+ (nside_j3 // out_nside) ** 2,
3407
3427
  s4.shape[2],
3408
3428
  s4.shape[3],
3409
3429
  s4.shape[4],
@@ -3447,48 +3467,51 @@ class funct(FOC.FoCUS):
3447
3467
  self.backend.bk_expand_dims(vs4, off_S4)
3448
3468
  ) # Add a dimension for NS4
3449
3469
 
3450
- nside_j1 = nside_j1 // 2
3451
- nside_j2 = nside_j2 // 2
3452
-
3453
3470
  ###### Reshape for next iteration on j3
3454
3471
  ### Image I1,
3455
3472
  # downscale the I1 [Nbatch, Npix_j3]
3456
3473
  if j3 != Jmax - 1:
3457
- I1 = self.smooth(I1, axis=1)
3458
- I1 = self.ud_grade_2(I1, axis=1)
3474
+ I1 = self.smooth(I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3475
+ I1, new_cell_ids_j3 = self.ud_grade_2(
3476
+ I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3477
+ )
3459
3478
 
3460
3479
  ### Image I2
3461
3480
  if cross:
3462
- I2 = self.smooth(I2, axis=1)
3463
- I2 = self.ud_grade_2(I2, axis=1)
3481
+ I2 = self.smooth(I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3482
+ I2, new_cell_ids_j3 = self.ud_grade_2(
3483
+ I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3484
+ )
3464
3485
 
3465
3486
  ### Modules
3466
3487
  for j2 in range(0, j3 + 1): # j2 =< j3
3467
3488
  ### Dictionary M1_dic[j2]
3468
3489
  M1_smooth = self.smooth(
3469
- M1_dic[j2], axis=1
3490
+ M1_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3470
3491
  ) # [Nbatch, Npix_j3, Norient3]
3471
- M1_dic[j2] = self.ud_grade_2(
3472
- M1_smooth, axis=1
3492
+ M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3493
+ M1_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3473
3494
  ) # [Nbatch, Npix_j3, Norient3]
3474
3495
 
3475
3496
  ### Dictionary M2_dic[j2]
3476
3497
  if cross:
3477
3498
  M2_smooth = self.smooth(
3478
- M2_dic[j2], axis=1
3499
+ M2_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3479
3500
  ) # [Nbatch, Npix_j3, Norient3]
3480
- M2_dic[j2] = self.ud_grade_2(
3481
- M2_smooth, axis=1
3501
+ M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3502
+ M2_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3482
3503
  ) # [Nbatch, Npix_j3, Norient3]
3483
-
3484
3504
  ### Mask
3485
- vmask = self.ud_grade_2(vmask, axis=1)
3505
+ vmask, new_cell_ids_j3 = self.ud_grade_2(
3506
+ vmask, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3507
+ )
3486
3508
 
3487
3509
  if self.mask_thres is not None:
3488
3510
  vmask = self.backend.bk_threshold(vmask, self.mask_thres)
3489
3511
 
3490
3512
  ### NSIDE_j3
3491
3513
  nside_j3 = nside_j3 // 2
3514
+ cell_ids_j3 = new_cell_ids_j3
3492
3515
 
3493
3516
  ### Store P1_dic and P2_dic in self
3494
3517
  if (norm == "auto") and (self.P1_dic is None):
@@ -3588,6 +3611,8 @@ class funct(FOC.FoCUS):
3588
3611
  calc_var=False,
3589
3612
  return_data=False,
3590
3613
  cmat2=None,
3614
+ cell_ids=None,
3615
+ nside_j2=None,
3591
3616
  ):
3592
3617
  """
3593
3618
  Compute the S3 coefficients (auto or cross)
@@ -3601,7 +3626,7 @@ class funct(FOC.FoCUS):
3601
3626
  ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3602
3627
  # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
3603
3628
  MconvPsi = self.convol(
3604
- M_dic[j2], axis=1
3629
+ M_dic[j2], axis=1, cell_ids=cell_ids, nside=nside_j2
3605
3630
  ) # [Nbatch, Npix_j3, Norient3, Norient2]
3606
3631
  if cmat2 is not None:
3607
3632
  tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
@@ -3822,24 +3847,62 @@ class funct(FOC.FoCUS):
3822
3847
  dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
3823
3848
  return dx, dy
3824
3849
 
3825
- def get_edge_masks(self, M, N, J, d0=1):
3850
+ def get_edge_masks(self, M, N, J, d0=1, in_mask=None, edge_dx=None, edge_dy=None):
3826
3851
  """
3827
3852
  This function is strongly inspire by the package https://github.com/SihaoCheng/scattering_transform
3828
3853
  Done by Sihao Cheng and Rudy Morel.
3829
3854
  """
3830
3855
  edge_masks = np.empty((J, M, N))
3856
+
3831
3857
  X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
3832
- for j in range(J):
3833
- edge_dx = min(M // 4, 2**j * d0)
3834
- edge_dy = min(N // 4, 2**j * d0)
3835
- edge_masks[j] = (
3836
- (X >= edge_dx)
3837
- * (X <= M - edge_dx)
3838
- * (Y >= edge_dy)
3839
- * (Y <= N - edge_dy)
3858
+ if in_mask is not None:
3859
+ from scipy.ndimage import binary_erosion
3860
+
3861
+ if in_mask is not None:
3862
+ if in_mask.shape[0] != M or in_mask.shape[0] != N:
3863
+ l_mask = in_mask.reshape(
3864
+ M, in_mask.shape[0] // M, N, in_mask.shape[1] // N
3865
+ )
3866
+ l_mask = (
3867
+ np.sum(np.sum(l_mask, 1), 2)
3868
+ * (M * N)
3869
+ / (in_mask.shape[0] * in_mask.shape[1])
3870
+ )
3871
+ else:
3872
+ l_mask = in_mask
3873
+
3874
+ if edge_dx is None:
3875
+ for j in range(J):
3876
+ edge_dx = min(M // 4, 2**j * d0)
3877
+ edge_dy = min(N // 4, 2**j * d0)
3878
+
3879
+ edge_masks[j] = (
3880
+ (X >= edge_dx)
3881
+ * (X < M - edge_dx)
3882
+ * (Y >= edge_dy)
3883
+ * (Y < N - edge_dy)
3884
+ )
3885
+ if in_mask is not None:
3886
+ l_mask = binary_erosion(
3887
+ l_mask, iterations=1 + np.max([edge_dx, edge_dy])
3888
+ )
3889
+ edge_masks[j] *= l_mask
3890
+
3891
+ edge_masks = edge_masks[:, None, :, :]
3892
+
3893
+ edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
3894
+ else:
3895
+ edge_masks = (
3896
+ (X >= edge_dx) * (X < M - edge_dx) * (Y >= edge_dy) * (Y < N - edge_dy)
3840
3897
  )
3841
- edge_masks = edge_masks[:, None, :, :]
3842
- edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
3898
+ if in_mask is not None:
3899
+ l_mask = binary_erosion(
3900
+ l_mask, iterations=1 + np.max([edge_dx, edge_dy])
3901
+ )
3902
+ edge_masks *= l_mask
3903
+
3904
+ edge_masks = edge_masks / edge_masks.mean((-2, -1))
3905
+
3843
3906
  return self.backend.bk_cast(edge_masks)
3844
3907
 
3845
3908
  # ---------------------------------------------------------------------------
@@ -3857,6 +3920,7 @@ class funct(FOC.FoCUS):
3857
3920
  use_ref=False,
3858
3921
  normalization="S2",
3859
3922
  edge=False,
3923
+ in_mask=None,
3860
3924
  pseudo_coef=1,
3861
3925
  get_variance=False,
3862
3926
  ref_sigma=None,
@@ -3925,6 +3989,9 @@ class funct(FOC.FoCUS):
3925
3989
  if S4_criteria is None:
3926
3990
  S4_criteria = "j2>=j1"
3927
3991
 
3992
+ if not edge and in_mask is not None:
3993
+ edge = True
3994
+
3928
3995
  if self.all_bk_type == "float32":
3929
3996
  C_ONE = np.complex64(1.0)
3930
3997
  else:
@@ -3974,14 +4041,15 @@ class funct(FOC.FoCUS):
3974
4041
 
3975
4042
  J = int(np.log(nside) / np.log(2)) # Number of j scales
3976
4043
 
3977
- if Jmax is None:
3978
- Jmax = J # Number of steps for the loop on scales
3979
- if Jmax > J:
3980
- print("==========\n\n")
3981
- print(
3982
- "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
3983
- )
3984
- print("\n\n==========")
4044
+ if Jmax is not None:
4045
+
4046
+ if Jmax > J:
4047
+ print("==========\n\n")
4048
+ print(
4049
+ "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
4050
+ )
4051
+ print("\n\n==========")
4052
+ J = Jmax # Number of steps for the loop on scales
3985
4053
 
3986
4054
  L = self.NORIENT
3987
4055
  norm_factor_S3 = 1.0
@@ -4067,7 +4135,10 @@ class funct(FOC.FoCUS):
4067
4135
  #
4068
4136
  if edge:
4069
4137
  if (M, N, J) not in self.edge_masks:
4070
- self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4138
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(
4139
+ M, N, J, in_mask=in_mask
4140
+ )
4141
+
4071
4142
  edge_mask = self.edge_masks[(M, N, J)]
4072
4143
  else:
4073
4144
  edge_mask = 1
@@ -4206,8 +4277,19 @@ class funct(FOC.FoCUS):
4206
4277
  wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
4207
4278
  _, M3, N3 = wavelet_f3.shape
4208
4279
  wavelet_f3_squared = wavelet_f3**2
4209
- edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4210
- edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
4280
+ if edge is True:
4281
+ if (M3, N3, J, j3) not in self.edge_masks:
4282
+
4283
+ edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4284
+ edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
4285
+
4286
+ self.edge_masks[(M3, N3, J, j3)] = self.get_edge_masks(
4287
+ M3, N3, J, in_mask=in_mask, edge_dx=edge_dx, edge_dy=edge_dy
4288
+ )
4289
+
4290
+ edge_mask = self.edge_masks[(M3, N3, J, j3)]
4291
+ else:
4292
+ edge_mask = 1
4211
4293
 
4212
4294
  # a normalization change due to the cutoff of frequency space
4213
4295
  fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
@@ -4274,7 +4356,8 @@ class funct(FOC.FoCUS):
4274
4356
  (
4275
4357
  data_small.view(N_image, 1, 1, M3, N3)
4276
4358
  * self.backend.bk_conjugate(I12_w3_small)
4277
- )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy].mean(
4359
+ * edge_mask[None, None, None, :, :]
4360
+ ).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
4278
4361
  (-2, -1)
4279
4362
  )
4280
4363
  * fft_factor
@@ -4285,11 +4368,8 @@ class funct(FOC.FoCUS):
4285
4368
  (
4286
4369
  data_small.view(N_image, 1, 1, M3, N3)
4287
4370
  * self.backend.bk_conjugate(I12_w3_small)
4288
- )[
4289
- ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4290
- ].std(
4291
- (-2, -1)
4292
- )
4371
+ * edge_mask[None, None, None, :, :]
4372
+ ).std((-2, -1))
4293
4373
  * fft_factor
4294
4374
  / norm_factor_S3
4295
4375
  )
@@ -4318,11 +4398,8 @@ class funct(FOC.FoCUS):
4318
4398
  (
4319
4399
  data2_small.view(N_image2, 1, 1, M3, N3)
4320
4400
  * self.backend.bk_conjugate(I12_w3_small)
4321
- )[
4322
- ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4323
- ].mean(
4324
- (-2, -1)
4325
- )
4401
+ * edge_mask[None, None, None, :, :]
4402
+ ).mean((-2, -1))
4326
4403
  * fft_factor
4327
4404
  / norm_factor_S3
4328
4405
  )
@@ -4331,13 +4408,8 @@ class funct(FOC.FoCUS):
4331
4408
  (
4332
4409
  data2_small.view(N_image2, 1, 1, M3, N3)
4333
4410
  * self.backend.bk_conjugate(I12_w3_small)
4334
- )[
4335
- ...,
4336
- edge_dx : M3 - edge_dx,
4337
- edge_dy : N3 - edge_dy,
4338
- ].std(
4339
- (-2, -1)
4340
- )
4411
+ * edge_mask[None, None, None, :, :]
4412
+ ).std((-2, -1))
4341
4413
  * fft_factor
4342
4414
  / norm_factor_S3
4343
4415
  )
@@ -4406,9 +4478,8 @@ class funct(FOC.FoCUS):
4406
4478
  N_image, 1, L, L, M3, N3
4407
4479
  )
4408
4480
  )
4409
- )[..., edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
4410
- (-2, -1)
4411
- ) * fft_factor
4481
+ * edge_mask[None, None, None, None, :, :]
4482
+ ).mean((-2, -1)) * fft_factor
4412
4483
  if get_variance:
4413
4484
  S4_sigma[:, Ndata_S4, :, :, :] = (
4414
4485
  I1_small[:, j1].view(
@@ -4419,11 +4490,10 @@ class funct(FOC.FoCUS):
4419
4490
  N_image, 1, L, L, M3, N3
4420
4491
  )
4421
4492
  )
4422
- )[
4423
- ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4424
- ].std(
4425
- (-2, -1)
4426
- ) * fft_factor
4493
+ * edge_mask[
4494
+ None, None, None, None, :, :
4495
+ ]
4496
+ ).std((-2, -1)) * fft_factor
4427
4497
  else:
4428
4498
  for l1 in range(L):
4429
4499
  # [N_image,l2,l3,x,y]
@@ -4436,11 +4506,10 @@ class funct(FOC.FoCUS):
4436
4506
  N_image, L, L, M3, N3
4437
4507
  )
4438
4508
  )
4439
- )[
4440
- ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4441
- ].mean(
4442
- (-2, -1)
4443
- ) * fft_factor
4509
+ * edge_mask[
4510
+ None, None, None, None, :, :
4511
+ ]
4512
+ ).mean((-2, -1)) * fft_factor
4444
4513
  if get_variance:
4445
4514
  S4_sigma[:, Ndata_S4, l1, :, :] = (
4446
4515
  I1_small[:, j1].view(
@@ -4451,13 +4520,10 @@ class funct(FOC.FoCUS):
4451
4520
  N_image, L, L, M3, N3
4452
4521
  )
4453
4522
  )
4454
- )[
4455
- ...,
4456
- edge_dx:-edge_dx,
4457
- edge_dy:-edge_dy,
4458
- ].mean(
4459
- (-2, -1)
4460
- ) * fft_factor
4523
+ * edge_mask[
4524
+ None, None, None, None, :, :
4525
+ ]
4526
+ ).std((-2, -1)) * fft_factor
4461
4527
 
4462
4528
  Ndata_S4 += 1
4463
4529
 
@@ -4788,7 +4854,9 @@ class funct(FOC.FoCUS):
4788
4854
  #
4789
4855
  if edge:
4790
4856
  if (M, N, J) not in self.edge_masks:
4791
- self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4857
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(
4858
+ M, N, J, in_mask=in_mask
4859
+ )
4792
4860
  edge_mask = self.edge_masks[(M, N, J)]
4793
4861
  else:
4794
4862
  edge_mask = 1
@@ -5662,19 +5730,42 @@ class funct(FOC.FoCUS):
5662
5730
 
5663
5731
  return for_synthesis
5664
5732
 
5665
- def to_gaussian(self, x):
5733
+ def purge_edge_mask(self):
5734
+
5735
+ list_edge = []
5736
+ for k in self.edge_masks:
5737
+ list_edge.append(k)
5738
+ for k in list_edge:
5739
+ del self.edge_masks[k]
5740
+
5741
+ self.edge_masks = {}
5742
+
5743
+ def to_gaussian(self, x, in_mask=None):
5666
5744
  from scipy.interpolate import interp1d
5667
5745
  from scipy.stats import norm
5668
5746
 
5669
- idx = np.argsort(x.flatten())
5670
- p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
5671
- im_target = x.flatten()
5672
- im_target[idx] = norm.ppf(p)
5747
+ if in_mask is not None:
5748
+ m_idx = np.where(in_mask.flatten() > 0)[0]
5749
+ idx = np.argsort(x.flatten()[m_idx])
5750
+ p = norm.ppf((np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0])
5751
+ im_target = x.flatten()
5752
+ im_target[m_idx[idx]] = p
5673
5753
 
5674
- # Interpolation cubique
5675
- self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
5676
- self.val_min = im_target[idx[0]]
5677
- self.val_max = im_target[idx[-1]]
5754
+ self.f_gaussian = interp1d(
5755
+ im_target[m_idx[idx]], x.flatten()[m_idx[idx]], kind="cubic"
5756
+ )
5757
+ self.val_min = im_target[m_idx][idx[0]]
5758
+ self.val_max = im_target[m_idx][idx[-1]]
5759
+ else:
5760
+ idx = np.argsort(x.flatten())
5761
+ p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
5762
+ im_target = x.flatten()
5763
+ im_target[idx] = norm.ppf(p)
5764
+
5765
+ # Interpolation cubique
5766
+ self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
5767
+ self.val_min = im_target[idx[0]]
5768
+ self.val_max = im_target[idx[-1]]
5678
5769
  return im_target.reshape(x.shape)
5679
5770
 
5680
5771
  def from_gaussian(self, x):
@@ -5948,57 +6039,6 @@ class funct(FOC.FoCUS):
5948
6039
 
5949
6040
  return result
5950
6041
 
5951
- # # ---------------------------------------------−---------
5952
- # def std(self, list_of_sc):
5953
- # n = len(list_of_sc)
5954
- # res = list_of_sc[0]
5955
- # res2 = list_of_sc[0] * list_of_sc[0]
5956
- # for k in range(1, n):
5957
- # res = res + list_of_sc[k]
5958
- # res2 = res2 + list_of_sc[k] * list_of_sc[k]
5959
- #
5960
- # if res.S1 is None:
5961
- # if res.S3P is not None:
5962
- # return scat_cov(
5963
- # res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
5964
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5965
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5966
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5967
- # S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
5968
- # backend=self.backend,
5969
- # use_1D=self.use_1D,
5970
- # )
5971
- # else:
5972
- # return scat_cov(
5973
- # res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
5974
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5975
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5976
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5977
- # backend=self.backend,
5978
- # use_1D=self.use_1D,
5979
- # )
5980
- # else:
5981
- # if res.S3P is None:
5982
- # return scat_cov(
5983
- # res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
5984
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5985
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5986
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5987
- # S1=res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
5988
- # S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
5989
- # backend=self.backend,
5990
- # )
5991
- # else:
5992
- # return scat_cov(
5993
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5994
- # res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
5995
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5996
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5997
- # backend=self.backend,
5998
- # use_1D=self.use_1D,
5999
- # )
6000
- # return self.NORIENT
6001
-
6002
6042
  @tf_function
6003
6043
  def eval_comp_fast(
6004
6044
  self,
@@ -6006,13 +6046,12 @@ class funct(FOC.FoCUS):
6006
6046
  image2=None,
6007
6047
  mask=None,
6008
6048
  norm=None,
6009
- Auto=True,
6010
6049
  cmat=None,
6011
6050
  cmat2=None,
6012
6051
  ):
6013
6052
 
6014
6053
  res = self.eval(
6015
- image1, image2=image2, mask=mask, Auto=Auto, cmat=cmat, cmat2=cmat2
6054
+ image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
6016
6055
  )
6017
6056
  return res.S0, res.S2, res.S1, res.S3, res.S4, res.S3P
6018
6057
 
@@ -6022,12 +6061,11 @@ class funct(FOC.FoCUS):
6022
6061
  image2=None,
6023
6062
  mask=None,
6024
6063
  norm=None,
6025
- Auto=True,
6026
6064
  cmat=None,
6027
6065
  cmat2=None,
6028
6066
  ):
6029
6067
  s0, s2, s1, s3, s4, s3p = self.eval_comp_fast(
6030
- image1, image2=image2, mask=mask, Auto=Auto, cmat=cmat, cmat2=cmat2
6068
+ image1, image2=image2, mask=mask, cmat=cmat, cmat2=cmat2
6031
6069
  )
6032
6070
  return scat_cov(
6033
6071
  s0, s2, s3, s4, s1=s1, s3p=s3p, backend=self.backend, use_1D=self.use_1D
@@ -6036,6 +6074,7 @@ class funct(FOC.FoCUS):
6036
6074
  def synthesis(
6037
6075
  self,
6038
6076
  image_target,
6077
+ reference=None,
6039
6078
  nstep=4,
6040
6079
  seed=1234,
6041
6080
  Jmax=None,
@@ -6045,6 +6084,7 @@ class funct(FOC.FoCUS):
6045
6084
  synthesised_N=1,
6046
6085
  input_image=None,
6047
6086
  grd_mask=None,
6087
+ in_mask=None,
6048
6088
  iso_ang=False,
6049
6089
  EVAL_FREQUENCY=100,
6050
6090
  NUM_EPOCHS=300,
@@ -6054,18 +6094,68 @@ class funct(FOC.FoCUS):
6054
6094
 
6055
6095
  import foscat.Synthesis as synthe
6056
6096
 
6097
+ l_edge = edge
6098
+ if in_mask is not None:
6099
+ l_edge = True
6100
+
6101
+ if edge:
6102
+ self.purge_edge_mask()
6103
+
6104
+ def The_loss_ref_image(u, scat_operator, args):
6105
+ input_image = args[0]
6106
+ mask = args[1]
6107
+
6108
+ loss = 1E-3*scat_operator.backend.bk_reduce_mean(
6109
+ scat_operator.backend.bk_square(mask*(input_image - u))
6110
+ )
6111
+ return loss
6112
+
6057
6113
  def The_loss(u, scat_operator, args):
6058
6114
  ref = args[0]
6059
6115
  sref = args[1]
6060
6116
  use_v = args[2]
6117
+ ljmax = args[3]
6118
+
6119
+ # compute scattering covariance of the current synthetised map called u
6120
+ if use_v:
6121
+ learn = scat_operator.reduce_mean_batch(
6122
+ scat_operator.scattering_cov(
6123
+ u,
6124
+ edge=l_edge,
6125
+ Jmax=ljmax,
6126
+ ref_sigma=sref,
6127
+ use_ref=True,
6128
+ iso_ang=iso_ang,
6129
+ )
6130
+ )
6131
+ else:
6132
+ learn = scat_operator.reduce_mean_batch(
6133
+ scat_operator.scattering_cov(
6134
+ u, edge=l_edge, Jmax=ljmax, use_ref=True, iso_ang=iso_ang
6135
+ )
6136
+ )
6137
+
6138
+ # make the difference withe the reference coordinates
6139
+ loss = scat_operator.backend.bk_reduce_mean(
6140
+ scat_operator.backend.bk_square(learn - ref)
6141
+ )
6142
+ return loss
6143
+
6144
+ def The_lossX(u, scat_operator, args):
6145
+ ref = args[0]
6146
+ sref = args[1]
6147
+ use_v = args[2]
6148
+ im2 = args[3]
6149
+ ljmax = args[4]
6061
6150
 
6062
6151
  # compute scattering covariance of the current synthetised map called u
6063
6152
  if use_v:
6064
6153
  learn = scat_operator.reduce_mean_batch(
6065
6154
  scat_operator.scattering_cov(
6066
6155
  u,
6067
- edge=edge,
6068
- Jmax=Jmax,
6156
+ data2=im2,
6157
+ edge=l_edge,
6158
+ Jmax=ljmax,
6069
6159
  ref_sigma=sref,
6070
6160
  use_ref=True,
6071
6161
  iso_ang=iso_ang,
@@ -6074,7 +6164,12 @@ class funct(FOC.FoCUS):
6074
6164
  else:
6075
6165
  learn = scat_operator.reduce_mean_batch(
6076
6166
  scat_operator.scattering_cov(
6077
- u, edge=edge, Jmax=Jmax, use_ref=True, iso_ang=iso_ang
6167
+ u,
6168
+ data2=im2,
6169
+ edge=l_edge,
6170
+ Jmax=ljmax,
6171
+ use_ref=True,
6172
+ iso_ang=iso_ang,
6078
6173
  )
6079
6174
  )
6080
6175
 
@@ -6086,7 +6181,7 @@ class funct(FOC.FoCUS):
6086
6181
 
6087
6182
  if to_gaussian:
6088
6183
  # Change the data histogram to gaussian distribution
6089
- im_target = self.to_gaussian(image_target)
6184
+ im_target = self.to_gaussian(image_target, in_mask=in_mask)
6090
6185
  else:
6091
6186
  im_target = image_target
6092
6187
 
@@ -6113,22 +6208,59 @@ class funct(FOC.FoCUS):
6113
6208
 
6114
6209
  t1 = time.time()
6115
6210
  tmp = {}
6116
-
6117
- l_grd_mask={}
6118
-
6211
+
6212
+ l_grd_mask = {}
6213
+ l_in_mask = {}
6214
+ l_input_image = {}
6215
+ l_ref = {}
6216
+ l_jmax = {}
6217
+
6119
6218
  tmp[nstep - 1] = self.backend.bk_cast(im_target)
6219
+ l_jmax[nstep - 1] = Jmax
6220
+
6221
+ if reference is not None:
6222
+ l_ref[nstep - 1] = self.backend.bk_cast(reference)
6223
+ else:
6224
+ l_ref[nstep - 1] = None
6225
+
6120
6226
  if grd_mask is not None:
6121
6227
  l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
6122
6228
  else:
6123
6229
  l_grd_mask[nstep - 1] = None
6124
-
6230
+ if in_mask is not None:
6231
+ l_in_mask[nstep - 1] = in_mask
6232
+ else:
6233
+ l_in_mask[nstep - 1] = None
6234
+
6235
+ if input_image is not None:
6236
+ l_input_image[nstep - 1] = input_image
6237
+
6125
6238
  for ell in range(nstep - 2, -1, -1):
6126
- tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
6239
+ tmp[ell], _ = self.ud_grade_2(tmp[ell + 1], axis=1)
6240
+
6127
6241
  if grd_mask is not None:
6128
- l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
6242
+ l_grd_mask[ell], _ = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
6129
6243
  else:
6130
6244
  l_grd_mask[ell] = None
6131
-
6245
+
6246
+ if in_mask is not None:
6247
+ l_in_mask[ell], _ = self.ud_grade_2(l_in_mask[ell + 1])
6248
+ l_in_mask[ell] = self.backend.to_numpy(l_in_mask[ell])
6249
+ else:
6250
+ l_in_mask[ell] = None
6251
+
6252
+ if input_image is not None:
6253
+ l_input_image[ell], _ = self.ud_grade_2(l_input_image[ell + 1], axis=1)
6254
+
6255
+ if reference is not None:
6256
+ l_ref[ell], _ = self.ud_grade_2(l_ref[ell + 1], axis=1)
6257
+ else:
6258
+ l_ref[ell] = None
6259
+
6260
+ if l_jmax[ell + 1] is None:
6261
+ l_jmax[ell] = None
6262
+ else:
6263
+ l_jmax[ell] = l_jmax[ell + 1] - 1
6132
6264
 
6133
6265
  if not self.use_2D and not self.use_1D:
6134
6266
  l_nside = nside // (2 ** (nstep - 1))
@@ -6138,16 +6270,20 @@ class funct(FOC.FoCUS):
6138
6270
  if input_image is None:
6139
6271
  np.random.seed(seed)
6140
6272
  if self.use_2D:
6141
- imap = self.backend.bk_cast(np.random.randn(
6142
- synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
6143
- ))
6273
+ imap = self.backend.bk_cast(
6274
+ np.random.randn(
6275
+ synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
6276
+ )
6277
+ )
6144
6278
  else:
6145
- imap = self.backend.bk_cast(np.random.randn(synthesised_N, tmp[k].shape[1]))
6279
+ imap = self.backend.bk_cast(
6280
+ np.random.randn(synthesised_N, tmp[k].shape[1])
6281
+ )
6146
6282
  else:
6147
6283
  if self.use_2D:
6148
6284
  imap = self.backend.bk_reshape(
6149
6285
  self.backend.bk_tile(
6150
- self.backend.bk_cast(input_image.flatten()),
6286
+ self.backend.bk_cast(l_input_image[k].flatten()),
6151
6287
  synthesised_N,
6152
6288
  ),
6153
6289
  [synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
@@ -6155,7 +6291,7 @@ class funct(FOC.FoCUS):
6155
6291
  else:
6156
6292
  imap = self.backend.bk_reshape(
6157
6293
  self.backend.bk_tile(
6158
- self.backend.bk_cast(input_image.flatten()),
6294
+ self.backend.bk_cast(l_input_image[k].flatten()),
6159
6295
  synthesised_N,
6160
6296
  ),
6161
6297
  [synthesised_N, tmp[k].shape[1]],
@@ -6170,26 +6306,56 @@ class funct(FOC.FoCUS):
6170
6306
  imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
6171
6307
  else:
6172
6308
  imap = self.up_grade(omap, l_nside, axis=1)
6173
-
6309
+
6174
6310
  if grd_mask is not None:
6175
- imap=imap*l_grd_mask[k]+tmp[k]*(1-l_grd_mask[k])
6176
-
6311
+ imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
6312
+
6177
6313
  # compute the coefficients for the target image
6178
6314
  if use_variance:
6179
6315
  ref, sref = self.scattering_cov(
6180
- tmp[k], get_variance=True, edge=edge, Jmax=Jmax, iso_ang=iso_ang
6316
+ tmp[k],
6317
+ data2=l_ref[k],
6318
+ get_variance=True,
6319
+ edge=l_edge,
6320
+ Jmax=l_jmax[k],
6321
+ in_mask=l_in_mask[k],
6322
+ iso_ang=iso_ang,
6181
6323
  )
6182
6324
  else:
6183
- ref = self.scattering_cov(tmp[k], edge=edge, Jmax=Jmax, iso_ang=iso_ang)
6325
+ ref = self.scattering_cov(
6326
+ tmp[k],
6327
+ data2=l_ref[k],
6328
+ in_mask=l_in_mask[k],
6329
+ edge=l_edge,
6330
+ Jmax=l_jmax[k],
6331
+ iso_ang=iso_ang,
6332
+ )
6184
6333
  sref = ref
6185
6334
 
6186
6335
  # compute the mean of the population does nothing if only one map is given
6187
6336
  ref = self.reduce_mean_batch(ref)
6188
6337
 
6189
- # define a loss to minimize
6190
- loss = synthe.Loss(The_loss, self, ref, sref, use_variance)
6338
+ if l_in_mask[k] is not None:
6339
+ self.purge_edge_mask()
6191
6340
 
6192
- sy = synthe.Synthesis([loss])
6341
+ if l_ref[k] is None:
6342
+ # define a loss to minimize
6343
+ loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
6344
+ else:
6345
+ # define a loss to minimize
6346
+ loss = synthe.Loss(
6347
+ The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6348
+ )
6349
+
6350
+ if input_image is not None:
6351
+ # define a loss to minimize
6352
+ loss_input = synthe.Loss(The_loss_ref_image, self,
6353
+ self.backend.bk_cast(l_input_image[k]),
6354
+ self.backend.bk_cast(l_in_mask[k]))
6355
+
6356
+ sy = synthe.Synthesis([loss]) #,loss_input])
6357
+ else:
6358
+ sy = synthe.Synthesis([loss])
6193
6359
 
6194
6360
  # initialize the synthesised map
6195
6361
  if self.use_2D:
@@ -6201,7 +6367,12 @@ class funct(FOC.FoCUS):
6201
6367
  l_nside *= 2
6202
6368
 
6203
6369
  # do the minimization
6204
- omap = sy.run(imap, EVAL_FREQUENCY=EVAL_FREQUENCY, NUM_EPOCHS=NUM_EPOCHS,grd_mask=l_grd_mask[k])
6370
+ omap = sy.run(
6371
+ imap,
6372
+ EVAL_FREQUENCY=EVAL_FREQUENCY,
6373
+ NUM_EPOCHS=NUM_EPOCHS,
6374
+ grd_mask=l_grd_mask[k],
6375
+ )
6205
6376
 
6206
6377
  t2 = time.time()
6207
6378
  print("Total computation %.2fs" % (t2 - t1))