foscat 3.9.0__py3-none-any.whl → 2025.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/scat_cov.py CHANGED
@@ -177,9 +177,21 @@ class scat_cov:
177
177
  ],
178
178
  )
179
179
  ),
180
- self.conv2complex(
180
+ self.backend.bk_reshape(
181
+ self.S3,
182
+ [
183
+ self.S3.shape[0],
184
+ self.S3.shape[1]
185
+ * self.S3.shape[2]
186
+ * self.S3.shape[3]
187
+ * self.S3.shape[4],
188
+ ],
189
+ ),
190
+ ]
191
+ if self.S3P is not None:
192
+ tmp = tmp + [
181
193
  self.backend.bk_reshape(
182
- self.S3,
194
+ self.S3P,
183
195
  [
184
196
  self.S3.shape[0],
185
197
  self.S3.shape[1]
@@ -188,37 +200,19 @@ class scat_cov:
188
200
  * self.S3.shape[4],
189
201
  ],
190
202
  )
191
- ),
192
- ]
193
- if self.S3P is not None:
194
- tmp = tmp + [
195
- self.conv2complex(
196
- self.backend.bk_reshape(
197
- self.S3P,
198
- [
199
- self.S3.shape[0],
200
- self.S3.shape[1]
201
- * self.S3.shape[2]
202
- * self.S3.shape[3]
203
- * self.S3.shape[4],
204
- ],
205
- )
206
- )
207
203
  ]
208
204
 
209
205
  tmp = tmp + [
210
- self.conv2complex(
211
- self.backend.bk_reshape(
212
- self.S4,
213
- [
214
- self.S3.shape[0],
215
- self.S4.shape[1]
216
- * self.S4.shape[2]
217
- * self.S4.shape[3]
218
- * self.S4.shape[4]
219
- * self.S4.shape[5],
220
- ],
221
- )
206
+ self.backend.bk_reshape(
207
+ self.S4,
208
+ [
209
+ self.S3.shape[0],
210
+ self.S4.shape[1]
211
+ * self.S4.shape[2]
212
+ * self.S4.shape[3]
213
+ * self.S4.shape[4]
214
+ * self.S4.shape[5],
215
+ ],
222
216
  )
223
217
  ]
224
218
 
@@ -504,7 +498,7 @@ class scat_cov:
504
498
  if other.S1 is None:
505
499
  s1 = None
506
500
  else:
507
- s1 = self.S1 + other.S1
501
+ s1 = self.doadd(self.S1, other.S1)
508
502
  else:
509
503
  s1 = self.S1 + other
510
504
 
@@ -534,7 +528,7 @@ class scat_cov:
534
528
  return scat_cov(
535
529
  self.doadd(self.S0, other.S0),
536
530
  self.doadd(self.S2, other.S2),
537
- (self.S3 + other.S3),
531
+ self.doadd(self.S3, other.S3),
538
532
  s4,
539
533
  s1=s1,
540
534
  s3p=s3p,
@@ -662,7 +656,7 @@ class scat_cov:
662
656
  s1 = None
663
657
  else:
664
658
  if isinstance(other, scat_cov):
665
- s1 = other.S1 / self.S1
659
+ s1 = self.dodiv(other.S1, self.S1)
666
660
  else:
667
661
  s1 = other / self.S1
668
662
 
@@ -689,7 +683,7 @@ class scat_cov:
689
683
  return scat_cov(
690
684
  self.dodiv(other.S0, self.S0),
691
685
  self.dodiv(other.S2, self.S2),
692
- (other.S3 / self.S3),
686
+ self.dodiv(other.S3, self.S3),
693
687
  s4,
694
688
  s1=s1,
695
689
  s3p=s3p,
@@ -725,7 +719,7 @@ class scat_cov:
725
719
  if other.S1 is None:
726
720
  s1 = None
727
721
  else:
728
- s1 = other.S1 - self.S1
722
+ s1 = self.domin(other.S1, self.S1)
729
723
  else:
730
724
  s1 = other - self.S1
731
725
 
@@ -755,7 +749,7 @@ class scat_cov:
755
749
  return scat_cov(
756
750
  self.domin(other.S0, self.S0),
757
751
  self.domin(other.S2, self.S2),
758
- (other.S3 - self.S3),
752
+ self.domin(other.S3, self.S3),
759
753
  s4,
760
754
  s1=s1,
761
755
  s3p=s3p,
@@ -790,7 +784,7 @@ class scat_cov:
790
784
  if other.S1 is None:
791
785
  s1 = None
792
786
  else:
793
- s1 = self.S1 - other.S1
787
+ s1 = self.domin(self.S1, other.S1)
794
788
  else:
795
789
  s1 = self.S1 - other
796
790
 
@@ -820,7 +814,7 @@ class scat_cov:
820
814
  return scat_cov(
821
815
  self.domin(self.S0, other.S0),
822
816
  self.domin(self.S2, other.S2),
823
- (self.S3 - other.S3),
817
+ self.domin(self.S3, other.S3),
824
818
  s4,
825
819
  s1=s1,
826
820
  s3p=s3p,
@@ -920,7 +914,7 @@ class scat_cov:
920
914
  if other.S1 is None:
921
915
  s1 = None
922
916
  else:
923
- s1 = self.S1 * other.S1
917
+ s1 = self.domult(self.S1, other.S1)
924
918
  else:
925
919
  s1 = self.S1 * other
926
920
 
@@ -2215,7 +2209,7 @@ class scat_cov:
2215
2209
 
2216
2210
 
2217
2211
  class funct(FOC.FoCUS):
2218
-
2212
+
2219
2213
  def fill(self, im, nullval=hp.UNSEEN):
2220
2214
  if self.use_2D:
2221
2215
  return self.fill_2d(im, nullval=nullval)
@@ -2402,8 +2396,9 @@ class funct(FOC.FoCUS):
2402
2396
  if smooth_scale > 0:
2403
2397
  for m in range(smooth_scale):
2404
2398
  if cc.shape[0] > 12:
2405
- cc = self.ud_grade_2(self.smooth(cc))
2406
- ss = self.ud_grade_2(self.smooth(ss))
2399
+ cc, _ = self.ud_grade_2(self.smooth(cc))
2400
+ ss, _ = self.ud_grade_2(self.smooth(ss))
2401
+
2407
2402
  if cc.shape[0] != tmp.shape[0]:
2408
2403
  ll_nside = int(np.sqrt(tmp.shape[1] // 12))
2409
2404
  cc = self.up_grade(cc, ll_nside)
@@ -2473,8 +2468,8 @@ class funct(FOC.FoCUS):
2473
2468
  if smooth_scale > 0:
2474
2469
  for m in range(smooth_scale):
2475
2470
  if cc2.shape[0] > 12:
2476
- cc2 = self.ud_grade_2(self.smooth(cc2))
2477
- ss2 = self.ud_grade_2(self.smooth(ss2))
2471
+ cc2, _ = self.ud_grade_2(self.smooth(cc2))
2472
+ ss2, _ = self.ud_grade_2(self.smooth(ss2))
2478
2473
 
2479
2474
  if cc2.shape[0] != sim.shape[1]:
2480
2475
  ll_nside = int(np.sqrt(sim.shape[1] // 12))
@@ -2510,9 +2505,9 @@ class funct(FOC.FoCUS):
2510
2505
  cmat2[k] = self.backend.bk_cast(mat2.astype("complex64"))
2511
2506
 
2512
2507
  if k < l_nside - 1:
2513
- tmp = self.ud_grade_2(tmp, axis=1)
2508
+ tmp, _ = self.ud_grade_2(tmp, axis=1)
2514
2509
  if image2 is not None:
2515
- tmpi2 = self.ud_grade_2(tmpi2, axis=1)
2510
+ tmpi2, _ = self.ud_grade_2(tmpi2, axis=1)
2516
2511
  return cmat, cmat2
2517
2512
 
2518
2513
  def div_norm(self, complex_value, float_value):
@@ -2533,6 +2528,8 @@ class funct(FOC.FoCUS):
2533
2528
  Jmax=None,
2534
2529
  out_nside=None,
2535
2530
  edge=True,
2531
+ nside=None,
2532
+ cell_ids=None,
2536
2533
  ):
2537
2534
  """
2538
2535
  Calculates the scattering correlations for a batch of images. Means are computed over pixels.
@@ -2606,7 +2603,7 @@ class funct(FOC.FoCUS):
2606
2603
  cross = False
2607
2604
  if image2 is not None:
2608
2605
  cross = True
2609
-
2606
+ l_nside = 2**32 # not initialize if 1D or 2D
2610
2607
  ### PARAMETERS
2611
2608
  axis = 1
2612
2609
  # determine jmax and nside corresponding to the input map
@@ -2638,7 +2635,8 @@ class funct(FOC.FoCUS):
2638
2635
  else:
2639
2636
  npix = int(im_shape[0]) # Number of pixels
2640
2637
 
2641
- nside = int(np.sqrt(npix // 12))
2638
+ if nside is None:
2639
+ nside = int(np.sqrt(npix // 12))
2642
2640
 
2643
2641
  J = int(np.log(nside) / np.log(2)) # Number of j scales
2644
2642
 
@@ -2675,7 +2673,7 @@ class funct(FOC.FoCUS):
2675
2673
  else:
2676
2674
  vmask = self.backend.bk_cast(mask) # [Nmask, Npix]
2677
2675
 
2678
- if self.KERNELSZ > 3 and not self.use_2D:
2676
+ if self.KERNELSZ > 3 and not self.use_2D and cell_ids is None:
2679
2677
  # if the kernel size is bigger than 3 increase the binning before smoothing
2680
2678
  if self.use_2D:
2681
2679
  vmask = self.up_grade(
@@ -2693,12 +2691,15 @@ class funct(FOC.FoCUS):
2693
2691
  I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2694
2692
  if cross:
2695
2693
  I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2694
+ nside = nside * 2
2696
2695
  else:
2697
2696
  I1 = self.up_grade(I1, nside * 2, axis=axis)
2698
2697
  vmask = self.up_grade(vmask, nside * 2, axis=1)
2699
2698
  if cross:
2700
2699
  I2 = self.up_grade(I2, nside * 2, axis=axis)
2701
2700
 
2701
+ nside = nside * 2
2702
+
2702
2703
  if self.KERNELSZ > 5 and not self.use_2D:
2703
2704
  # if the kernel size is bigger than 5 increase the binning before smoothing
2704
2705
  if self.use_2D:
@@ -2716,15 +2717,17 @@ class funct(FOC.FoCUS):
2716
2717
  nouty=I2.shape[axis + 1] * 2,
2717
2718
  )
2718
2719
  elif self.use_1D:
2719
- vmask = self.up_grade(vmask, I1.shape[axis] * 4, axis=1)
2720
- I1 = self.up_grade(I1, I1.shape[axis] * 4, axis=axis)
2720
+ vmask = self.up_grade(vmask, I1.shape[axis] * 2, axis=1)
2721
+ I1 = self.up_grade(I1, I1.shape[axis] * 2, axis=axis)
2721
2722
  if cross:
2722
- I2 = self.up_grade(I2, I2.shape[axis] * 4, axis=axis)
2723
+ I2 = self.up_grade(I2, I2.shape[axis] * 2, axis=axis)
2724
+ nside = nside * 2
2723
2725
  else:
2724
- I1 = self.up_grade(I1, nside * 4, axis=axis)
2725
- vmask = self.up_grade(vmask, nside * 4, axis=1)
2726
+ I1 = self.up_grade(I1, nside * 2, axis=axis)
2727
+ vmask = self.up_grade(vmask, nside * 2, axis=1)
2726
2728
  if cross:
2727
- I2 = self.up_grade(I2, nside * 4, axis=axis)
2729
+ I2 = self.up_grade(I2, nside * 2, axis=axis)
2730
+ nside = nside * 2
2728
2731
 
2729
2732
  # Normalize the masks because they have different pixel numbers
2730
2733
  # vmask /= self.backend.bk_reduce_sum(vmask, axis=1)[:, None] # [Nmask, Npix]
@@ -2797,6 +2800,8 @@ class funct(FOC.FoCUS):
2797
2800
  M1_dic = {}
2798
2801
  M2_dic = {}
2799
2802
 
2803
+ cell_ids_j3 = cell_ids
2804
+
2800
2805
  for j3 in range(Jmax):
2801
2806
 
2802
2807
  if edge:
@@ -2835,7 +2840,9 @@ class funct(FOC.FoCUS):
2835
2840
 
2836
2841
  ####### S1 and S2
2837
2842
  ### Make the convolution I1 * Psi_j3
2838
- conv1 = self.convol(I1, axis=1) # [Nbatch, Npix_j3, Norient3]
2843
+ conv1 = self.convol(
2844
+ I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2845
+ ) # [Nbatch, Npix_j3, Norient3]
2839
2846
  if cmat is not None:
2840
2847
  tmp2 = self.backend.bk_repeat(conv1, self.NORIENT, axis=-1)
2841
2848
  conv1 = self.backend.bk_reduce_sum(
@@ -2948,7 +2955,9 @@ class funct(FOC.FoCUS):
2948
2955
 
2949
2956
  else: # Cross
2950
2957
  ### Make the convolution I2 * Psi_j3
2951
- conv2 = self.convol(I2, axis=1) # [Nbatch, Npix_j3, Norient3]
2958
+ conv2 = self.convol(
2959
+ I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
2960
+ ) # [Nbatch, Npix_j3, Norient3]
2952
2961
  if cmat is not None:
2953
2962
  tmp2 = self.backend.bk_repeat(conv2, self.NORIENT, axis=-1)
2954
2963
  conv2 = self.backend.bk_reduce_sum(
@@ -3092,7 +3101,6 @@ class funct(FOC.FoCUS):
3092
3101
  M2convPsi_dic = {}
3093
3102
 
3094
3103
  ###### S3
3095
- nside_j2 = nside_j3
3096
3104
  for j2 in range(0, j3 + 1): # j2 <= j3
3097
3105
  if return_data:
3098
3106
  if S4[j3] is None:
@@ -3111,6 +3119,8 @@ class funct(FOC.FoCUS):
3111
3119
  M1convPsi_dic,
3112
3120
  calc_var=True,
3113
3121
  cmat2=cmat2,
3122
+ cell_ids=cell_ids_j3,
3123
+ nside_j2=nside_j3,
3114
3124
  ) # [Nbatch, Nmask, Norient3, Norient2]
3115
3125
  else:
3116
3126
  s3 = self._compute_S3(
@@ -3122,19 +3132,21 @@ class funct(FOC.FoCUS):
3122
3132
  M1convPsi_dic,
3123
3133
  return_data=return_data,
3124
3134
  cmat2=cmat2,
3135
+ cell_ids=cell_ids_j3,
3136
+ nside_j2=nside_j3,
3125
3137
  ) # [Nbatch, Nmask, Norient3, Norient2]
3126
3138
 
3127
3139
  if return_data:
3128
3140
  if S3[j3] is None:
3129
3141
  S3[j3] = {}
3130
- if out_nside is not None and out_nside < nside_j2:
3142
+ if out_nside is not None and out_nside < nside_j3:
3131
3143
  s3 = self.backend.bk_reduce_mean(
3132
3144
  self.backend.bk_reshape(
3133
3145
  s3,
3134
3146
  [
3135
3147
  s3.shape[0],
3136
3148
  12 * out_nside**2,
3137
- (nside_j2 // out_nside) ** 2,
3149
+ (nside_j3 // out_nside) ** 2,
3138
3150
  s3.shape[2],
3139
3151
  s3.shape[3],
3140
3152
  ],
@@ -3181,6 +3193,8 @@ class funct(FOC.FoCUS):
3181
3193
  M2convPsi_dic,
3182
3194
  calc_var=True,
3183
3195
  cmat2=cmat2,
3196
+ cell_ids=cell_ids_j3,
3197
+ nside_j2=nside_j3,
3184
3198
  )
3185
3199
  s3p, vs3p = self._compute_S3(
3186
3200
  j2,
@@ -3191,41 +3205,47 @@ class funct(FOC.FoCUS):
3191
3205
  M1convPsi_dic,
3192
3206
  calc_var=True,
3193
3207
  cmat2=cmat2,
3208
+ cell_ids=cell_ids_j3,
3209
+ nside_j2=nside_j3,
3194
3210
  )
3195
3211
  else:
3196
- s3 = self._compute_S3(
3212
+ s3p = self._compute_S3(
3197
3213
  j2,
3198
3214
  j3,
3199
- conv1,
3215
+ conv2,
3200
3216
  vmask,
3201
- M2_dic,
3202
- M2convPsi_dic,
3217
+ M1_dic,
3218
+ M1convPsi_dic,
3203
3219
  return_data=return_data,
3204
3220
  cmat2=cmat2,
3221
+ cell_ids=cell_ids_j3,
3222
+ nside_j2=nside_j3,
3205
3223
  )
3206
- s3p = self._compute_S3(
3224
+ s3 = self._compute_S3(
3207
3225
  j2,
3208
3226
  j3,
3209
- conv2,
3227
+ conv1,
3210
3228
  vmask,
3211
- M1_dic,
3212
- M1convPsi_dic,
3229
+ M2_dic,
3230
+ M2convPsi_dic,
3213
3231
  return_data=return_data,
3214
3232
  cmat2=cmat2,
3233
+ cell_ids=cell_ids_j3,
3234
+ nside_j2=nside_j3,
3215
3235
  )
3216
3236
 
3217
3237
  if return_data:
3218
3238
  if S3[j3] is None:
3219
3239
  S3[j3] = {}
3220
3240
  S3P[j3] = {}
3221
- if out_nside is not None and out_nside < nside_j2:
3241
+ if out_nside is not None and out_nside < nside_j3:
3222
3242
  s3 = self.backend.bk_reduce_mean(
3223
3243
  self.backend.bk_reshape(
3224
3244
  s3,
3225
3245
  [
3226
3246
  s3.shape[0],
3227
3247
  12 * out_nside**2,
3228
- (nside_j2 // out_nside) ** 2,
3248
+ (nside_j3 // out_nside) ** 2,
3229
3249
  s3.shape[2],
3230
3250
  s3.shape[3],
3231
3251
  ],
@@ -3238,7 +3258,7 @@ class funct(FOC.FoCUS):
3238
3258
  [
3239
3259
  s3.shape[0],
3240
3260
  12 * out_nside**2,
3241
- (nside_j2 // out_nside) ** 2,
3261
+ (nside_j3 // out_nside) ** 2,
3242
3262
  s3.shape[2],
3243
3263
  s3.shape[3],
3244
3264
  ],
@@ -3295,7 +3315,6 @@ class funct(FOC.FoCUS):
3295
3315
  # s3.shape[2]*s3.shape[3]]))
3296
3316
 
3297
3317
  ##### S4
3298
- nside_j1 = nside_j2
3299
3318
  for j1 in range(0, j2 + 1): # j1 <= j2
3300
3319
  ### S4_auto = <(|I1 * psi1| * psi3)(|I1 * psi2| * psi3)^*>
3301
3320
  if not cross:
@@ -3321,14 +3340,14 @@ class funct(FOC.FoCUS):
3321
3340
  if return_data:
3322
3341
  if S4[j3][j2] is None:
3323
3342
  S4[j3][j2] = {}
3324
- if out_nside is not None and out_nside < nside_j1:
3343
+ if out_nside is not None and out_nside < nside_j3:
3325
3344
  s4 = self.backend.bk_reduce_mean(
3326
3345
  self.backend.bk_reshape(
3327
3346
  s4,
3328
3347
  [
3329
3348
  s4.shape[0],
3330
3349
  12 * out_nside**2,
3331
- (nside_j1 // out_nside) ** 2,
3350
+ (nside_j3 // out_nside) ** 2,
3332
3351
  s4.shape[2],
3333
3352
  s4.shape[3],
3334
3353
  s4.shape[4],
@@ -3396,14 +3415,14 @@ class funct(FOC.FoCUS):
3396
3415
  if return_data:
3397
3416
  if S4[j3][j2] is None:
3398
3417
  S4[j3][j2] = {}
3399
- if out_nside is not None and out_nside < nside_j1:
3418
+ if out_nside is not None and out_nside < nside_j3:
3400
3419
  s4 = self.backend.bk_reduce_mean(
3401
3420
  self.backend.bk_reshape(
3402
3421
  s4,
3403
3422
  [
3404
3423
  s4.shape[0],
3405
3424
  12 * out_nside**2,
3406
- (nside_j1 // out_nside) ** 2,
3425
+ (nside_j3 // out_nside) ** 2,
3407
3426
  s4.shape[2],
3408
3427
  s4.shape[3],
3409
3428
  s4.shape[4],
@@ -3447,48 +3466,51 @@ class funct(FOC.FoCUS):
3447
3466
  self.backend.bk_expand_dims(vs4, off_S4)
3448
3467
  ) # Add a dimension for NS4
3449
3468
 
3450
- nside_j1 = nside_j1 // 2
3451
- nside_j2 = nside_j2 // 2
3452
-
3453
3469
  ###### Reshape for next iteration on j3
3454
3470
  ### Image I1,
3455
3471
  # downscale the I1 [Nbatch, Npix_j3]
3456
3472
  if j3 != Jmax - 1:
3457
- I1 = self.smooth(I1, axis=1)
3458
- I1 = self.ud_grade_2(I1, axis=1)
3473
+ I1 = self.smooth(I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3474
+ I1, new_cell_ids_j3 = self.ud_grade_2(
3475
+ I1, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3476
+ )
3459
3477
 
3460
3478
  ### Image I2
3461
3479
  if cross:
3462
- I2 = self.smooth(I2, axis=1)
3463
- I2 = self.ud_grade_2(I2, axis=1)
3480
+ I2 = self.smooth(I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3)
3481
+ I2, new_cell_ids_j3 = self.ud_grade_2(
3482
+ I2, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3483
+ )
3464
3484
 
3465
3485
  ### Modules
3466
3486
  for j2 in range(0, j3 + 1): # j2 =< j3
3467
3487
  ### Dictionary M1_dic[j2]
3468
3488
  M1_smooth = self.smooth(
3469
- M1_dic[j2], axis=1
3489
+ M1_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3470
3490
  ) # [Nbatch, Npix_j3, Norient3]
3471
- M1_dic[j2] = self.ud_grade_2(
3472
- M1_smooth, axis=1
3491
+ M1_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3492
+ M1_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3473
3493
  ) # [Nbatch, Npix_j3, Norient3]
3474
3494
 
3475
3495
  ### Dictionary M2_dic[j2]
3476
3496
  if cross:
3477
3497
  M2_smooth = self.smooth(
3478
- M2_dic[j2], axis=1
3498
+ M2_dic[j2], axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3479
3499
  ) # [Nbatch, Npix_j3, Norient3]
3480
- M2_dic[j2] = self.ud_grade_2(
3481
- M2_smooth, axis=1
3500
+ M2_dic[j2], new_cell_ids_j2 = self.ud_grade_2(
3501
+ M2_smooth, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3482
3502
  ) # [Nbatch, Npix_j3, Norient3]
3483
-
3484
3503
  ### Mask
3485
- vmask = self.ud_grade_2(vmask, axis=1)
3504
+ vmask, new_cell_ids_j3 = self.ud_grade_2(
3505
+ vmask, axis=1, cell_ids=cell_ids_j3, nside=nside_j3
3506
+ )
3486
3507
 
3487
3508
  if self.mask_thres is not None:
3488
3509
  vmask = self.backend.bk_threshold(vmask, self.mask_thres)
3489
3510
 
3490
3511
  ### NSIDE_j3
3491
3512
  nside_j3 = nside_j3 // 2
3513
+ cell_ids_j3 = new_cell_ids_j3
3492
3514
 
3493
3515
  ### Store P1_dic and P2_dic in self
3494
3516
  if (norm == "auto") and (self.P1_dic is None):
@@ -3588,6 +3610,8 @@ class funct(FOC.FoCUS):
3588
3610
  calc_var=False,
3589
3611
  return_data=False,
3590
3612
  cmat2=None,
3613
+ cell_ids=None,
3614
+ nside_j2=None,
3591
3615
  ):
3592
3616
  """
3593
3617
  Compute the S3 coefficients (auto or cross)
@@ -3601,7 +3625,7 @@ class funct(FOC.FoCUS):
3601
3625
  ### Compute |I1 * Psi_j2| * Psi_j3 = M1_j2 * Psi_j3
3602
3626
  # Warning: M1_dic[j2] is already at j3 resolution [Nbatch, Npix_j3, Norient3]
3603
3627
  MconvPsi = self.convol(
3604
- M_dic[j2], axis=1
3628
+ M_dic[j2], axis=1, cell_ids=cell_ids, nside=nside_j2
3605
3629
  ) # [Nbatch, Npix_j3, Norient3, Norient2]
3606
3630
  if cmat2 is not None:
3607
3631
  tmp2 = self.backend.bk_repeat(MconvPsi, self.NORIENT, axis=-1)
@@ -3822,24 +3846,62 @@ class funct(FOC.FoCUS):
3822
3846
  dy = int(max(8, min(np.ceil(N / 2**j), N // 2)))
3823
3847
  return dx, dy
3824
3848
 
3825
- def get_edge_masks(self, M, N, J, d0=1):
3849
+ def get_edge_masks(self, M, N, J, d0=1, in_mask=None, edge_dx=None, edge_dy=None):
3826
3850
  """
3827
3851
  This function is strongly inspired by the package https://github.com/SihaoCheng/scattering_transform
3828
3852
  Done by Sihao Cheng and Rudy Morel.
3829
3853
  """
3830
3854
  edge_masks = np.empty((J, M, N))
3855
+
3831
3856
  X, Y = np.meshgrid(np.arange(M), np.arange(N), indexing="ij")
3832
- for j in range(J):
3833
- edge_dx = min(M // 4, 2**j * d0)
3834
- edge_dy = min(N // 4, 2**j * d0)
3835
- edge_masks[j] = (
3836
- (X >= edge_dx)
3837
- * (X <= M - edge_dx)
3838
- * (Y >= edge_dy)
3839
- * (Y <= N - edge_dy)
3857
+ if in_mask is not None:
3858
+ from scipy.ndimage import binary_erosion
3859
+
3860
+ if in_mask is not None:
3861
+ if in_mask.shape[0] != M or in_mask.shape[0] != N:
3862
+ l_mask = in_mask.reshape(
3863
+ M, in_mask.shape[0] // M, N, in_mask.shape[1] // N
3864
+ )
3865
+ l_mask = (
3866
+ np.sum(np.sum(l_mask, 1), 2)
3867
+ * (M * N)
3868
+ / (in_mask.shape[0] * in_mask.shape[1])
3869
+ )
3870
+ else:
3871
+ l_mask = in_mask
3872
+
3873
+ if edge_dx is None:
3874
+ for j in range(J):
3875
+ edge_dx = min(M // 4, 2**j * d0)
3876
+ edge_dy = min(N // 4, 2**j * d0)
3877
+
3878
+ edge_masks[j] = (
3879
+ (X >= edge_dx)
3880
+ * (X < M - edge_dx)
3881
+ * (Y >= edge_dy)
3882
+ * (Y < N - edge_dy)
3883
+ )
3884
+ if in_mask is not None:
3885
+ l_mask = binary_erosion(
3886
+ l_mask, iterations=1 + np.max([edge_dx, edge_dy])
3887
+ )
3888
+ edge_masks[j] *= l_mask
3889
+
3890
+ edge_masks = edge_masks[:, None, :, :]
3891
+
3892
+ edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
3893
+ else:
3894
+ edge_masks = (
3895
+ (X >= edge_dx) * (X < M - edge_dx) * (Y >= edge_dy) * (Y < N - edge_dy)
3840
3896
  )
3841
- edge_masks = edge_masks[:, None, :, :]
3842
- edge_masks = edge_masks / edge_masks.mean((-2, -1))[:, :, None, None]
3897
+ if in_mask is not None:
3898
+ l_mask = binary_erosion(
3899
+ l_mask, iterations=1 + np.max([edge_dx, edge_dy])
3900
+ )
3901
+ edge_masks *= l_mask
3902
+
3903
+ edge_masks = edge_masks / edge_masks.mean((-2, -1))
3904
+
3843
3905
  return self.backend.bk_cast(edge_masks)
3844
3906
 
3845
3907
  # ---------------------------------------------------------------------------
@@ -3857,6 +3919,7 @@ class funct(FOC.FoCUS):
3857
3919
  use_ref=False,
3858
3920
  normalization="S2",
3859
3921
  edge=False,
3922
+ in_mask=None,
3860
3923
  pseudo_coef=1,
3861
3924
  get_variance=False,
3862
3925
  ref_sigma=None,
@@ -3925,6 +3988,9 @@ class funct(FOC.FoCUS):
3925
3988
  if S4_criteria is None:
3926
3989
  S4_criteria = "j2>=j1"
3927
3990
 
3991
+ if not edge and in_mask is not None:
3992
+ edge = True
3993
+
3928
3994
  if self.all_bk_type == "float32":
3929
3995
  C_ONE = np.complex64(1.0)
3930
3996
  else:
@@ -3974,14 +4040,15 @@ class funct(FOC.FoCUS):
3974
4040
 
3975
4041
  J = int(np.log(nside) / np.log(2)) # Number of j scales
3976
4042
 
3977
- if Jmax is None:
3978
- Jmax = J # Number of steps for the loop on scales
3979
- if Jmax > J:
3980
- print("==========\n\n")
3981
- print(
3982
- "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
3983
- )
3984
- print("\n\n==========")
4043
+ if Jmax is not None:
4044
+
4045
+ if Jmax > J:
4046
+ print("==========\n\n")
4047
+ print(
4048
+ "The Jmax you requested is larger than the data size, which may cause problems while computing the scattering transform."
4049
+ )
4050
+ print("\n\n==========")
4051
+ J = Jmax # Number of steps for the loop on scales
3985
4052
 
3986
4053
  L = self.NORIENT
3987
4054
  norm_factor_S3 = 1.0
@@ -4067,7 +4134,10 @@ class funct(FOC.FoCUS):
4067
4134
  #
4068
4135
  if edge:
4069
4136
  if (M, N, J) not in self.edge_masks:
4070
- self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4137
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(
4138
+ M, N, J, in_mask=in_mask
4139
+ )
4140
+
4071
4141
  edge_mask = self.edge_masks[(M, N, J)]
4072
4142
  else:
4073
4143
  edge_mask = 1
@@ -4206,8 +4276,19 @@ class funct(FOC.FoCUS):
4206
4276
  wavelet_f3 = self.cut_high_k_off(filters_set[j3], dx3, dy3) # L,x,y
4207
4277
  _, M3, N3 = wavelet_f3.shape
4208
4278
  wavelet_f3_squared = wavelet_f3**2
4209
- edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4210
- edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
4279
+ if edge is True:
4280
+ if (M3, N3, J, j3) not in self.edge_masks:
4281
+
4282
+ edge_dx = min(4, int(2**j3 * dx3 * 2 / M))
4283
+ edge_dy = min(4, int(2**j3 * dy3 * 2 / N))
4284
+
4285
+ self.edge_masks[(M3, N3, J, j3)] = self.get_edge_masks(
4286
+ M3, N3, J, in_mask=in_mask, edge_dx=edge_dx, edge_dy=edge_dy
4287
+ )
4288
+
4289
+ edge_mask = self.edge_masks[(M3, N3, J, j3)]
4290
+ else:
4291
+ edge_mask = 1
4211
4292
 
4212
4293
  # a normalization change due to the cutoff of frequency space
4213
4294
  fft_factor = 1 / (M3 * N3) * (M3 * N3 / M / N) ** 2
@@ -4274,7 +4355,8 @@ class funct(FOC.FoCUS):
4274
4355
  (
4275
4356
  data_small.view(N_image, 1, 1, M3, N3)
4276
4357
  * self.backend.bk_conjugate(I12_w3_small)
4277
- )[..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy].mean(
4358
+ * edge_mask[None, None, None, :, :]
4359
+ ).mean( # [..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy]
4278
4360
  (-2, -1)
4279
4361
  )
4280
4362
  * fft_factor
@@ -4285,11 +4367,8 @@ class funct(FOC.FoCUS):
4285
4367
  (
4286
4368
  data_small.view(N_image, 1, 1, M3, N3)
4287
4369
  * self.backend.bk_conjugate(I12_w3_small)
4288
- )[
4289
- ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4290
- ].std(
4291
- (-2, -1)
4292
- )
4370
+ * edge_mask[None, None, None, :, :]
4371
+ ).std((-2, -1))
4293
4372
  * fft_factor
4294
4373
  / norm_factor_S3
4295
4374
  )
@@ -4318,11 +4397,8 @@ class funct(FOC.FoCUS):
4318
4397
  (
4319
4398
  data2_small.view(N_image2, 1, 1, M3, N3)
4320
4399
  * self.backend.bk_conjugate(I12_w3_small)
4321
- )[
4322
- ..., edge_dx : M3 - edge_dx, edge_dy : N3 - edge_dy
4323
- ].mean(
4324
- (-2, -1)
4325
- )
4400
+ * edge_mask[None, None, None, :, :]
4401
+ ).mean((-2, -1))
4326
4402
  * fft_factor
4327
4403
  / norm_factor_S3
4328
4404
  )
@@ -4331,13 +4407,8 @@ class funct(FOC.FoCUS):
4331
4407
  (
4332
4408
  data2_small.view(N_image2, 1, 1, M3, N3)
4333
4409
  * self.backend.bk_conjugate(I12_w3_small)
4334
- )[
4335
- ...,
4336
- edge_dx : M3 - edge_dx,
4337
- edge_dy : N3 - edge_dy,
4338
- ].std(
4339
- (-2, -1)
4340
- )
4410
+ * edge_mask[None, None, None, :, :]
4411
+ ).std((-2, -1))
4341
4412
  * fft_factor
4342
4413
  / norm_factor_S3
4343
4414
  )
@@ -4406,9 +4477,8 @@ class funct(FOC.FoCUS):
4406
4477
  N_image, 1, L, L, M3, N3
4407
4478
  )
4408
4479
  )
4409
- )[..., edge_dx:-edge_dx, edge_dy:-edge_dy].mean(
4410
- (-2, -1)
4411
- ) * fft_factor
4480
+ * edge_mask[None, None, None, None, :, :]
4481
+ ).mean((-2, -1)) * fft_factor
4412
4482
  if get_variance:
4413
4483
  S4_sigma[:, Ndata_S4, :, :, :] = (
4414
4484
  I1_small[:, j1].view(
@@ -4419,11 +4489,10 @@ class funct(FOC.FoCUS):
4419
4489
  N_image, 1, L, L, M3, N3
4420
4490
  )
4421
4491
  )
4422
- )[
4423
- ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4424
- ].std(
4425
- (-2, -1)
4426
- ) * fft_factor
4492
+ * edge_mask[
4493
+ None, None, None, None, :, :
4494
+ ]
4495
+ ).std((-2, -1)) * fft_factor
4427
4496
  else:
4428
4497
  for l1 in range(L):
4429
4498
  # [N_image,l2,l3,x,y]
@@ -4436,11 +4505,10 @@ class funct(FOC.FoCUS):
4436
4505
  N_image, L, L, M3, N3
4437
4506
  )
4438
4507
  )
4439
- )[
4440
- ..., edge_dx:-edge_dx, edge_dy:-edge_dy
4441
- ].mean(
4442
- (-2, -1)
4443
- ) * fft_factor
4508
+ * edge_mask[
4509
+ None, None, None, None, :, :
4510
+ ]
4511
+ ).mean((-2, -1)) * fft_factor
4444
4512
  if get_variance:
4445
4513
  S4_sigma[:, Ndata_S4, l1, :, :] = (
4446
4514
  I1_small[:, j1].view(
@@ -4451,13 +4519,10 @@ class funct(FOC.FoCUS):
4451
4519
  N_image, L, L, M3, N3
4452
4520
  )
4453
4521
  )
4454
- )[
4455
- ...,
4456
- edge_dx:-edge_dx,
4457
- edge_dy:-edge_dy,
4458
- ].mean(
4459
- (-2, -1)
4460
- ) * fft_factor
4522
+ * edge_mask[
4523
+ None, None, None, None, :, :
4524
+ ]
4525
+ ).std((-2, -1)) * fft_factor
4461
4526
 
4462
4527
  Ndata_S4 += 1
4463
4528
 
@@ -4788,7 +4853,9 @@ class funct(FOC.FoCUS):
4788
4853
  #
4789
4854
  if edge:
4790
4855
  if (M, N, J) not in self.edge_masks:
4791
- self.edge_masks[(M, N, J)] = self.get_edge_masks(M, N, J)
4856
+ self.edge_masks[(M, N, J)] = self.get_edge_masks(
4857
+ M, N, J, in_mask=in_mask
4858
+ )
4792
4859
  edge_mask = self.edge_masks[(M, N, J)]
4793
4860
  else:
4794
4861
  edge_mask = 1
@@ -5662,19 +5729,42 @@ class funct(FOC.FoCUS):
5662
5729
 
5663
5730
  return for_synthesis
5664
5731
 
5665
- def to_gaussian(self, x):
5732
+ def purge_edge_mask(self):
5733
+
5734
+ list_edge = []
5735
+ for k in self.edge_masks:
5736
+ list_edge.append(k)
5737
+ for k in list_edge:
5738
+ del self.edge_masks[k]
5739
+
5740
+ self.edge_masks = {}
5741
+
5742
+ def to_gaussian(self, x, in_mask=None):
5666
5743
  from scipy.interpolate import interp1d
5667
5744
  from scipy.stats import norm
5668
5745
 
5669
- idx = np.argsort(x.flatten())
5670
- p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
5671
- im_target = x.flatten()
5672
- im_target[idx] = norm.ppf(p)
5746
+ if in_mask is not None:
5747
+ m_idx = np.where(in_mask.flatten() > 0)[0]
5748
+ idx = np.argsort(x.flatten()[m_idx])
5749
+ p = norm.ppf((np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0])
5750
+ im_target = x.flatten()
5751
+ im_target[m_idx[idx]] = p
5673
5752
 
5674
- # Interpolation cubique
5675
- self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
5676
- self.val_min = im_target[idx[0]]
5677
- self.val_max = im_target[idx[-1]]
5753
+ self.f_gaussian = interp1d(
5754
+ im_target[m_idx[idx]], x.flatten()[m_idx[idx]], kind="cubic"
5755
+ )
5756
+ self.val_min = im_target[m_idx][idx[0]]
5757
+ self.val_max = im_target[m_idx][idx[-1]]
5758
+ else:
5759
+ idx = np.argsort(x.flatten())
5760
+ p = (np.arange(1, idx.shape[0] + 1) - 0.5) / idx.shape[0]
5761
+ im_target = x.flatten()
5762
+ im_target[idx] = norm.ppf(p)
5763
+
5764
+ # Interpolation cubique
5765
+ self.f_gaussian = interp1d(im_target[idx], x.flatten()[idx], kind="cubic")
5766
+ self.val_min = im_target[idx[0]]
5767
+ self.val_max = im_target[idx[-1]]
5678
5768
  return im_target.reshape(x.shape)
5679
5769
 
5680
5770
  def from_gaussian(self, x):
@@ -5948,57 +6038,6 @@ class funct(FOC.FoCUS):
5948
6038
 
5949
6039
  return result
5950
6040
 
5951
- # # ---------------------------------------------−---------
5952
- # def std(self, list_of_sc):
5953
- # n = len(list_of_sc)
5954
- # res = list_of_sc[0]
5955
- # res2 = list_of_sc[0] * list_of_sc[0]
5956
- # for k in range(1, n):
5957
- # res = res + list_of_sc[k]
5958
- # res2 = res2 + list_of_sc[k] * list_of_sc[k]
5959
- #
5960
- # if res.S1 is None:
5961
- # if res.S3P is not None:
5962
- # return scat_cov(
5963
- # res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
5964
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5965
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5966
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5967
- # S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
5968
- # backend=self.backend,
5969
- # use_1D=self.use_1D,
5970
- # )
5971
- # else:
5972
- # return scat_cov(
5973
- # res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
5974
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5975
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5976
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5977
- # backend=self.backend,
5978
- # use_1D=self.use_1D,
5979
- # )
5980
- # else:
5981
- # if res.S3P is None:
5982
- # return scat_cov(
5983
- # res.domult(sig.S0, res.S0) * res.domult(sig.S0, res.S0),
5984
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5985
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5986
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5987
- # S1=res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
5988
- # S3P=res.domult(sig.S3P, res.S3P) * res.domult(sig.S3P, res.S3P),
5989
- # backend=self.backend,
5990
- # )
5991
- # else:
5992
- # return scat_cov(
5993
- # res.domult(sig.S2, res.S2) * res.domult(sig.S2, res.S2),
5994
- # res.domult(sig.S1, res.S1) * res.domult(sig.S1, res.S1),
5995
- # res.domult(sig.S3, res.S3) * res.domult(sig.S3, res.S3),
5996
- # res.domult(sig.S4, res.S4) * res.domult(sig.S4, res.S4),
5997
- # backend=self.backend,
5998
- # use_1D=self.use_1D,
5999
- # )
6000
- # return self.NORIENT
6001
-
6002
6041
  @tf_function
6003
6042
  def eval_comp_fast(
6004
6043
  self,
@@ -6036,6 +6075,7 @@ class funct(FOC.FoCUS):
6036
6075
  def synthesis(
6037
6076
  self,
6038
6077
  image_target,
6078
+ reference=None,
6039
6079
  nstep=4,
6040
6080
  seed=1234,
6041
6081
  Jmax=None,
@@ -6045,6 +6085,7 @@ class funct(FOC.FoCUS):
6045
6085
  synthesised_N=1,
6046
6086
  input_image=None,
6047
6087
  grd_mask=None,
6088
+ in_mask=None,
6048
6089
  iso_ang=False,
6049
6090
  EVAL_FREQUENCY=100,
6050
6091
  NUM_EPOCHS=300,
@@ -6054,18 +6095,59 @@ class funct(FOC.FoCUS):
6054
6095
 
6055
6096
  import foscat.Synthesis as synthe
6056
6097
 
6098
+ l_edge = edge
6099
+ if in_mask is not None:
6100
+ l_edge = True
6101
+
6102
+ if edge:
6103
+ self.purge_edge_mask()
6104
+
6057
6105
  def The_loss(u, scat_operator, args):
6058
6106
  ref = args[0]
6059
6107
  sref = args[1]
6060
6108
  use_v = args[2]
6109
+ ljmax = args[3]
6110
+
6111
+ # compute the scattering covariance of the current synthesised map, called u
6112
+ if use_v:
6113
+ learn = scat_operator.reduce_mean_batch(
6114
+ scat_operator.scattering_cov(
6115
+ u,
6116
+ edge=l_edge,
6117
+ Jmax=ljmax,
6118
+ ref_sigma=sref,
6119
+ use_ref=True,
6120
+ iso_ang=iso_ang,
6121
+ )
6122
+ )
6123
+ else:
6124
+ learn = scat_operator.reduce_mean_batch(
6125
+ scat_operator.scattering_cov(
6126
+ u, edge=l_edge, Jmax=ljmax, use_ref=True, iso_ang=iso_ang
6127
+ )
6128
+ )
6129
+
6130
+ # compute the difference with the reference coordinates
6131
+ loss = scat_operator.backend.bk_reduce_mean(
6132
+ scat_operator.backend.bk_square(learn - ref)
6133
+ )
6134
+ return loss
6135
+
6136
+ def The_lossX(u, scat_operator, args):
6137
+ ref = args[0]
6138
+ sref = args[1]
6139
+ use_v = args[2]
6140
+ im2 = args[3]
6141
+ ljmax = args[4]
6061
6142
 
6062
6143
  # compute the scattering covariance of the current synthesised map, called u
6063
6144
  if use_v:
6064
6145
  learn = scat_operator.reduce_mean_batch(
6065
6146
  scat_operator.scattering_cov(
6066
6147
  u,
6067
- edge=edge,
6068
- Jmax=Jmax,
6148
+ data2=im2,
6149
+ edge=l_edge,
6150
+ Jmax=ljmax,
6069
6151
  ref_sigma=sref,
6070
6152
  use_ref=True,
6071
6153
  iso_ang=iso_ang,
@@ -6074,7 +6156,12 @@ class funct(FOC.FoCUS):
6074
6156
  else:
6075
6157
  learn = scat_operator.reduce_mean_batch(
6076
6158
  scat_operator.scattering_cov(
6077
- u, edge=edge, Jmax=Jmax, use_ref=True, iso_ang=iso_ang
6159
+ u,
6160
+ data2=im2,
6161
+ edge=l_edge,
6162
+ Jmax=ljmax,
6163
+ use_ref=True,
6164
+ iso_ang=iso_ang,
6078
6165
  )
6079
6166
  )
6080
6167
 
@@ -6086,7 +6173,7 @@ class funct(FOC.FoCUS):
6086
6173
 
6087
6174
  if to_gaussian:
6088
6175
  # Change the data histogram to gaussian distribution
6089
- im_target = self.to_gaussian(image_target)
6176
+ im_target = self.to_gaussian(image_target, in_mask=in_mask)
6090
6177
  else:
6091
6178
  im_target = image_target
6092
6179
 
@@ -6113,22 +6200,59 @@ class funct(FOC.FoCUS):
6113
6200
 
6114
6201
  t1 = time.time()
6115
6202
  tmp = {}
6116
-
6117
- l_grd_mask={}
6118
-
6203
+
6204
+ l_grd_mask = {}
6205
+ l_in_mask = {}
6206
+ l_input_image = {}
6207
+ l_ref = {}
6208
+ l_jmax = {}
6209
+
6119
6210
  tmp[nstep - 1] = self.backend.bk_cast(im_target)
6211
+ l_jmax[nstep - 1] = Jmax
6212
+
6213
+ if reference is not None:
6214
+ l_ref[nstep - 1] = self.backend.bk_cast(reference)
6215
+ else:
6216
+ l_ref[nstep - 1] = None
6217
+
6120
6218
  if grd_mask is not None:
6121
6219
  l_grd_mask[nstep - 1] = self.backend.bk_cast(grd_mask)
6122
6220
  else:
6123
6221
  l_grd_mask[nstep - 1] = None
6124
-
6222
+ if in_mask is not None:
6223
+ l_in_mask[nstep - 1] = in_mask
6224
+ else:
6225
+ l_in_mask[nstep - 1] = None
6226
+
6227
+ if input_image is not None:
6228
+ l_input_image[nstep - 1] = input_image
6229
+
6125
6230
  for ell in range(nstep - 2, -1, -1):
6126
- tmp[ell] = self.ud_grade_2(tmp[ell + 1], axis=1)
6231
+ tmp[ell], _ = self.ud_grade_2(tmp[ell + 1], axis=1)
6232
+
6127
6233
  if grd_mask is not None:
6128
- l_grd_mask[ell] = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
6234
+ l_grd_mask[ell], _ = self.ud_grade_2(l_grd_mask[ell + 1], axis=1)
6129
6235
  else:
6130
6236
  l_grd_mask[ell] = None
6131
-
6237
+
6238
+ if in_mask is not None:
6239
+ l_in_mask[ell], _ = self.ud_grade_2(l_in_mask[ell + 1])
6240
+ l_in_mask[ell] = self.backend.to_numpy(l_in_mask[ell])
6241
+ else:
6242
+ l_in_mask[ell] = None
6243
+
6244
+ if input_image is not None:
6245
+ l_input_image[ell], _ = self.ud_grade_2(l_input_image[ell + 1], axis=1)
6246
+
6247
+ if reference is not None:
6248
+ l_ref[ell], _ = self.ud_grade_2(l_ref[ell + 1], axis=1)
6249
+ else:
6250
+ l_ref[ell] = None
6251
+
6252
+ if l_jmax[ell + 1] is None:
6253
+ l_jmax[ell] = None
6254
+ else:
6255
+ l_jmax[ell] = l_jmax[ell + 1] - 1
6132
6256
 
6133
6257
  if not self.use_2D and not self.use_1D:
6134
6258
  l_nside = nside // (2 ** (nstep - 1))
@@ -6138,16 +6262,20 @@ class funct(FOC.FoCUS):
6138
6262
  if input_image is None:
6139
6263
  np.random.seed(seed)
6140
6264
  if self.use_2D:
6141
- imap = self.backend.bk_cast(np.random.randn(
6142
- synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
6143
- ))
6265
+ imap = self.backend.bk_cast(
6266
+ np.random.randn(
6267
+ synthesised_N, tmp[k].shape[1], tmp[k].shape[2]
6268
+ )
6269
+ )
6144
6270
  else:
6145
- imap = self.backend.bk_cast(np.random.randn(synthesised_N, tmp[k].shape[1]))
6271
+ imap = self.backend.bk_cast(
6272
+ np.random.randn(synthesised_N, tmp[k].shape[1])
6273
+ )
6146
6274
  else:
6147
6275
  if self.use_2D:
6148
6276
  imap = self.backend.bk_reshape(
6149
6277
  self.backend.bk_tile(
6150
- self.backend.bk_cast(input_image.flatten()),
6278
+ self.backend.bk_cast(l_input_image[k].flatten()),
6151
6279
  synthesised_N,
6152
6280
  ),
6153
6281
  [synthesised_N, tmp[k].shape[1], tmp[k].shape[2]],
@@ -6155,7 +6283,7 @@ class funct(FOC.FoCUS):
6155
6283
  else:
6156
6284
  imap = self.backend.bk_reshape(
6157
6285
  self.backend.bk_tile(
6158
- self.backend.bk_cast(input_image.flatten()),
6286
+ self.backend.bk_cast(l_input_image[k].flatten()),
6159
6287
  synthesised_N,
6160
6288
  ),
6161
6289
  [synthesised_N, tmp[k].shape[1]],
@@ -6170,24 +6298,46 @@ class funct(FOC.FoCUS):
6170
6298
  imap = self.up_grade(omap, imap.shape[1] * 2, axis=1)
6171
6299
  else:
6172
6300
  imap = self.up_grade(omap, l_nside, axis=1)
6173
-
6301
+
6174
6302
  if grd_mask is not None:
6175
- imap=imap*l_grd_mask[k]+tmp[k]*(1-l_grd_mask[k])
6176
-
6303
+ imap = imap * l_grd_mask[k] + tmp[k] * (1 - l_grd_mask[k])
6304
+
6177
6305
  # compute the coefficients for the target image
6178
6306
  if use_variance:
6179
6307
  ref, sref = self.scattering_cov(
6180
- tmp[k], get_variance=True, edge=edge, Jmax=Jmax, iso_ang=iso_ang
6308
+ tmp[k],
6309
+ data2=l_ref[k],
6310
+ get_variance=True,
6311
+ edge=l_edge,
6312
+ Jmax=l_jmax[k],
6313
+ in_mask=l_in_mask[k],
6314
+ iso_ang=iso_ang,
6181
6315
  )
6182
6316
  else:
6183
- ref = self.scattering_cov(tmp[k], edge=edge, Jmax=Jmax, iso_ang=iso_ang)
6317
+ ref = self.scattering_cov(
6318
+ tmp[k],
6319
+ data2=l_ref[k],
6320
+ in_mask=l_in_mask[k],
6321
+ edge=l_edge,
6322
+ Jmax=l_jmax[k],
6323
+ iso_ang=iso_ang,
6324
+ )
6184
6325
  sref = ref
6185
6326
 
6186
6327
  # compute the mean of the population (does nothing if only one map is given)
6187
6328
  ref = self.reduce_mean_batch(ref)
6188
6329
 
6189
- # define a loss to minimize
6190
- loss = synthe.Loss(The_loss, self, ref, sref, use_variance)
6330
+ if l_in_mask[k] is not None:
6331
+ self.purge_edge_mask()
6332
+
6333
+ if l_ref[k] is None:
6334
+ # define a loss to minimize
6335
+ loss = synthe.Loss(The_loss, self, ref, sref, use_variance, l_jmax[k])
6336
+ else:
6337
+ # define a loss to minimize
6338
+ loss = synthe.Loss(
6339
+ The_lossX, self, ref, sref, use_variance, l_ref[k], l_jmax[k]
6340
+ )
6191
6341
 
6192
6342
  sy = synthe.Synthesis([loss])
6193
6343
 
@@ -6201,7 +6351,12 @@ class funct(FOC.FoCUS):
6201
6351
  l_nside *= 2
6202
6352
 
6203
6353
  # do the minimization
6204
- omap = sy.run(imap, EVAL_FREQUENCY=EVAL_FREQUENCY, NUM_EPOCHS=NUM_EPOCHS,grd_mask=l_grd_mask[k])
6354
+ omap = sy.run(
6355
+ imap,
6356
+ EVAL_FREQUENCY=EVAL_FREQUENCY,
6357
+ NUM_EPOCHS=NUM_EPOCHS,
6358
+ grd_mask=l_grd_mask[k],
6359
+ )
6205
6360
 
6206
6361
  t2 = time.time()
6207
6362
  print("Total computation %.2fs" % (t2 - t1))