foscat 3.8.2__py3-none-any.whl → 2025.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/FoCUS.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import os
2
- import os
3
2
  import sys
4
3
 
5
4
  import healpy as hp
@@ -11,32 +10,32 @@ TMPFILE_VERSION = "V4_0"
11
10
 
12
11
  class FoCUS:
13
12
  def __init__(
14
- self,
15
- NORIENT=4,
16
- LAMBDA=1.2,
17
- KERNELSZ=3,
18
- slope=1.0,
19
- all_type="float32",
20
- nstep_max=16,
21
- padding="SAME",
22
- gpupos=0,
23
- mask_thres=None,
24
- mask_norm=False,
25
- isMPI=False,
26
- TEMPLATE_PATH="data",
27
- BACKEND="tensorflow",
28
- use_2D=False,
29
- use_1D=False,
30
- return_data=False,
31
- JmaxDelta=0,
32
- DODIV=False,
33
- InitWave=None,
34
- silent=True,
35
- mpi_size=1,
36
- mpi_rank=0,
13
+ self,
14
+ NORIENT=4,
15
+ LAMBDA=1.2,
16
+ KERNELSZ=3,
17
+ slope=1.0,
18
+ all_type="float32",
19
+ nstep_max=16,
20
+ padding="SAME",
21
+ gpupos=0,
22
+ mask_thres=None,
23
+ mask_norm=False,
24
+ isMPI=False,
25
+ TEMPLATE_PATH="data",
26
+ BACKEND="tensorflow",
27
+ use_2D=False,
28
+ use_1D=False,
29
+ return_data=False,
30
+ JmaxDelta=0,
31
+ DODIV=False,
32
+ InitWave=None,
33
+ silent=True,
34
+ mpi_size=1,
35
+ mpi_rank=0,
37
36
  ):
38
37
 
39
- self.__version__ = "3.8.2"
38
+ self.__version__ = "2025.03.0"
40
39
  # P00 coeff for normalization for scat_cov
41
40
  self.TMPFILE_VERSION = TMPFILE_VERSION
42
41
  self.P1_dic = None
@@ -45,12 +44,18 @@ class FoCUS:
45
44
  self.mask_thres = mask_thres
46
45
  self.mask_norm = mask_norm
47
46
  self.InitWave = InitWave
48
- self.mask_mask=None
47
+ self.mask_mask = None
49
48
  self.mpi_size = mpi_size
50
49
  self.mpi_rank = mpi_rank
51
50
  self.return_data = return_data
52
51
  self.silent = silent
53
52
 
53
+ self.kernel_smooth = {}
54
+ self.padding_smooth = {}
55
+ self.kernelR_conv = {}
56
+ self.kernelI_conv = {}
57
+ self.padding_conv = {}
58
+
54
59
  if not self.silent:
55
60
  print("================================================")
56
61
  print(" START FOSCAT CONFIGURATION")
@@ -69,10 +74,7 @@ class FoCUS:
69
74
  if not self.silent:
70
75
  print("The directory %s is created")
71
76
  except:
72
- if not self.silent:
73
- print(
74
- "Impossible to create the directory %s" % (self.TEMPLATE_PATH)
75
- )
77
+ print("Impossible to create the directory %s" % (self.TEMPLATE_PATH))
76
78
  return None
77
79
 
78
80
  self.number_of_loss = 0
@@ -82,10 +84,9 @@ class FoCUS:
82
84
  self.padding = padding
83
85
 
84
86
  if JmaxDelta != 0:
85
- if not self.silent:
86
- print(
87
- "OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
88
- )
87
+ print(
88
+ "OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
89
+ )
89
90
  return None
90
91
 
91
92
  self.OSTEP = JmaxDelta
@@ -105,31 +106,34 @@ class FoCUS:
105
106
 
106
107
  self.all_type = all_type
107
108
  self.BACKEND = BACKEND
108
-
109
- if BACKEND=='torch':
109
+
110
+ if BACKEND == "torch":
110
111
  from foscat.BkTorch import BkTorch
112
+
111
113
  self.backend = BkTorch(
112
114
  all_type=all_type,
113
115
  mpi_rank=mpi_rank,
114
116
  gpupos=gpupos,
115
117
  silent=self.silent,
116
- )
117
- elif BACKEND=='tensorflow':
118
+ )
119
+ elif BACKEND == "tensorflow":
118
120
  from foscat.BkTensorflow import BkTensorflow
121
+
119
122
  self.backend = BkTensorflow(
120
123
  all_type=all_type,
121
124
  mpi_rank=mpi_rank,
122
125
  gpupos=gpupos,
123
126
  silent=self.silent,
124
- )
127
+ )
125
128
  else:
126
129
  from foscat.BkNumpy import BkNumpy
130
+
127
131
  self.backend = BkNumpy(
128
132
  all_type=all_type,
129
133
  mpi_rank=mpi_rank,
130
134
  gpupos=gpupos,
131
135
  silent=self.silent,
132
- )
136
+ )
133
137
 
134
138
  self.all_bk_type = self.backend.all_bk_type
135
139
  self.all_cbk_type = self.backend.all_cbk_type
@@ -172,9 +176,9 @@ class FoCUS:
172
176
  self.Y_CNN = {}
173
177
  self.Z_CNN = {}
174
178
 
175
- self.filters_set={}
176
- self.edge_masks={}
177
-
179
+ self.filters_set = {}
180
+ self.edge_masks = {}
181
+
178
182
  wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
179
183
  wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
180
184
 
@@ -209,15 +213,27 @@ class FoCUS:
209
213
  w_smooth = w_smooth.flatten()
210
214
  else:
211
215
  for i in range(NORIENT):
212
- a = (NORIENT-1-i) / float(NORIENT) * np.pi # get the same angle number than scattering lib
213
- if KERNELSZ<5:
214
- xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
215
- yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
216
+ a = (
217
+ (NORIENT - 1 - i) / float(NORIENT) * np.pi
218
+ ) # get the same angle number than scattering lib
219
+ if KERNELSZ < 5:
220
+ xx = (
221
+ (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
222
+ )
223
+ yy = (
224
+ (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
225
+ )
216
226
  else:
217
- xx = (3 /5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
218
- yy = (3 /5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
227
+ xx = (3 / 5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
228
+ yy = (3 / 5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
219
229
  if KERNELSZ == 5:
220
- w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
230
+ w_smooth = np.exp(
231
+ -2
232
+ * (
233
+ (3.0 / float(KERNELSZ) * xx) ** 2
234
+ + (3.0 / float(KERNELSZ) * yy) ** 2
235
+ )
236
+ )
221
237
  else:
222
238
  w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
223
239
  tmp1 = np.cos(yy * np.pi) * w_smooth
@@ -225,7 +241,7 @@ class FoCUS:
225
241
 
226
242
  wwc[:, i] = tmp1.flatten() - tmp1.mean()
227
243
  wws[:, i] = tmp2.flatten() - tmp2.mean()
228
- #sigma = np.sqrt((wwc[:, i] ** 2).mean())
244
+ # sigma = np.sqrt((wwc[:, i] ** 2).mean())
229
245
  sigma = np.mean(w_smooth)
230
246
  wwc[:, i] /= sigma
231
247
  wws[:, i] /= sigma
@@ -239,7 +255,7 @@ class FoCUS:
239
255
 
240
256
  wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
241
257
  wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
242
- #sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
258
+ # sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
243
259
  sigma = np.mean(w_smooth)
244
260
 
245
261
  wwc[:, NORIENT] /= sigma
@@ -249,13 +265,13 @@ class FoCUS:
249
265
 
250
266
  wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
251
267
  wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
252
- #sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
268
+ # sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
253
269
  sigma = np.mean(w_smooth)
254
270
  wwc[:, NORIENT + 1] /= sigma
255
271
  wws[:, NORIENT + 1] /= sigma
256
272
 
257
273
  w_smooth = w_smooth.flatten()
258
-
274
+
259
275
  if self.use_1D:
260
276
  KERNELSZ = 5
261
277
 
@@ -723,19 +739,19 @@ class FoCUS:
723
739
  def ud_grade(self, im, j, axis=0):
724
740
  rim = im
725
741
  for k in range(j):
726
- #rim = self.smooth(rim, axis=axis)
742
+ # rim = self.smooth(rim, axis=axis)
727
743
  rim = self.ud_grade_2(rim, axis=axis)
728
744
  return rim
729
745
 
730
746
  # --------------------------------------------------------
731
- def ud_grade_2(self, im, axis=0):
747
+ def ud_grade_2(self, im, axis=0, cell_ids=None, nside=None):
732
748
 
733
749
  if self.use_2D:
734
750
  ishape = list(im.shape)
735
751
  if len(ishape) < axis + 2:
736
752
  if not self.silent:
737
753
  print("Use of 2D scat with data that has less than 2D")
738
- return None
754
+ return None, None
739
755
 
740
756
  npix = im.shape[axis]
741
757
  npiy = im.shape[axis + 1]
@@ -760,29 +776,40 @@ class FoCUS:
760
776
 
761
777
  if axis == 0:
762
778
  if len(ishape) == 2:
763
- return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
779
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
764
780
  else:
765
- return self.backend.bk_reshape(
766
- res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
781
+ return (
782
+ self.backend.bk_reshape(
783
+ res, [npix // 2, npiy // 2] + ishape[axis + 2 :]
784
+ ),
785
+ None,
767
786
  )
768
787
  else:
769
788
  if len(ishape) == axis + 2:
770
- return self.backend.bk_reshape(
771
- res, ishape[0:axis] + [npix // 2, npiy // 2]
789
+ return (
790
+ self.backend.bk_reshape(
791
+ res, ishape[0:axis] + [npix // 2, npiy // 2]
792
+ ),
793
+ None,
772
794
  )
773
795
  else:
774
- return self.backend.bk_reshape(
775
- res,
776
- ishape[0:axis] + [npix // 2, npiy // 2] + ishape[axis + 2 :],
796
+ return (
797
+ self.backend.bk_reshape(
798
+ res,
799
+ ishape[0:axis]
800
+ + [npix // 2, npiy // 2]
801
+ + ishape[axis + 2 :],
802
+ ),
803
+ None,
777
804
  )
778
805
 
779
- return self.backend.bk_reshape(res, [npix // 2, npiy // 2])
806
+ return self.backend.bk_reshape(res, [npix // 2, npiy // 2]), None
780
807
  elif self.use_1D:
781
808
  ishape = list(im.shape)
782
809
  if len(ishape) < axis + 1:
783
810
  if not self.silent:
784
811
  print("Use of 1D scat with data that has less than 1D")
785
- return None
812
+ return None, None
786
813
 
787
814
  npix = im.shape[axis]
788
815
  odata = 1
@@ -805,23 +832,33 @@ class FoCUS:
805
832
 
806
833
  if axis == 0:
807
834
  if len(ishape) == 1:
808
- return self.backend.bk_reshape(res, [npix // 2])
835
+ return self.backend.bk_reshape(res, [npix // 2]), None
809
836
  else:
810
- return self.backend.bk_reshape(
811
- res, [npix // 2] + ishape[axis + 1 :]
837
+ return (
838
+ self.backend.bk_reshape(res, [npix // 2] + ishape[axis + 1 :]),
839
+ None,
812
840
  )
813
841
  else:
814
842
  if len(ishape) == axis + 1:
815
- return self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2])
843
+ return (
844
+ self.backend.bk_reshape(res, ishape[0:axis] + [npix // 2]),
845
+ None,
846
+ )
816
847
  else:
817
- return self.backend.bk_reshape(
818
- res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
848
+ return (
849
+ self.backend.bk_reshape(
850
+ res, ishape[0:axis] + [npix // 2] + ishape[axis + 1 :]
851
+ ),
852
+ None,
819
853
  )
820
854
 
821
- return self.backend.bk_reshape(res, [npix // 2])
855
+ return self.backend.bk_reshape(res, [npix // 2]), None
822
856
 
823
857
  else:
824
858
  shape = list(im.shape)
859
+ if cell_ids is not None:
860
+ sim, new_cell_ids = self.backend.binned_mean(im, cell_ids)
861
+ return sim, new_cell_ids
825
862
 
826
863
  lout = int(np.sqrt(shape[axis] // 12))
827
864
  if im.__class__ == np.zeros([0]).__class__:
@@ -840,8 +877,11 @@ class FoCUS:
840
877
  if len(shape) > axis:
841
878
  oshape = oshape + shape[axis + 1 :]
842
879
 
843
- return self.backend.bk_reduce_mean(
844
- self.backend.bk_reshape(im, oshape), axis=axis + 1
880
+ return (
881
+ self.backend.bk_reduce_mean(
882
+ self.backend.bk_reshape(im, oshape), axis=axis + 1
883
+ ),
884
+ None,
845
885
  )
846
886
 
847
887
  # --------------------------------------------------------
@@ -1794,14 +1834,14 @@ class FoCUS:
1794
1834
  if self.padding == "VALID":
1795
1835
  l_mask = l_mask[
1796
1836
  :,
1797
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1798
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1837
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1838
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1799
1839
  ]
1800
1840
  if shape[axis] != l_mask.shape[1]:
1801
1841
  l_mask = l_mask[
1802
1842
  :,
1803
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1804
- self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1843
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1844
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1805
1845
  ]
1806
1846
 
1807
1847
  ichannel = 1
@@ -1868,10 +1908,10 @@ class FoCUS:
1868
1908
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
1869
1909
 
1870
1910
  if self.use_2D:
1871
- #if self.padding == "VALID":
1911
+ # if self.padding == "VALID":
1872
1912
  mtmp = l_mask
1873
1913
  vtmp = l_x
1874
- #else:
1914
+ # else:
1875
1915
  # mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1876
1916
  # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1877
1917
 
@@ -2125,7 +2165,7 @@ class FoCUS:
2125
2165
  return self.backend.bk_reduce_sum(r)
2126
2166
 
2127
2167
  # ---------------------------------------------−---------
2128
- def convol(self, in_image, axis=0):
2168
+ def convol(self, in_image, axis=0, cell_ids=None, nside=None):
2129
2169
 
2130
2170
  image = self.backend.bk_cast(in_image)
2131
2171
 
@@ -2290,6 +2330,61 @@ class FoCUS:
2290
2330
  return self.backend.bk_reshape(res, in_image.shape + [self.NORIENT])
2291
2331
 
2292
2332
  else:
2333
+ ishape = list(image.shape)
2334
+
2335
+ if cell_ids is not None:
2336
+ if cell_ids.shape[0] not in self.padding_conv:
2337
+ import healpix_convolution as hc
2338
+ from xdggs.healpix import HealpixInfo
2339
+
2340
+ res = self.backend.bk_zeros(
2341
+ ishape + [self.NORIENT], dtype=self.backend.all_cbk_type
2342
+ )
2343
+
2344
+ grid_info = HealpixInfo(
2345
+ level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2346
+ )
2347
+
2348
+ for k in range(self.NORIENT):
2349
+ kernelR, kernelI = hc.kernels.wavelet_kernel(
2350
+ cell_ids, grid_info=grid_info, orientation=k, is_torch=True
2351
+ )
2352
+ self.kernelR_conv[(cell_ids.shape[0], k)] = kernelR.to(
2353
+ self.backend.all_bk_type
2354
+ ).to(image.device)
2355
+ self.kernelI_conv[(cell_ids.shape[0], k)] = kernelI.to(
2356
+ self.backend.all_bk_type
2357
+ ).to(image.device)
2358
+ self.padding_conv[(cell_ids.shape[0], k)] = hc.pad(
2359
+ cell_ids,
2360
+ grid_info=grid_info,
2361
+ ring=5 // 2, # wavelet kernel_size=5 is hard coded
2362
+ mode="mean",
2363
+ constant_value=0,
2364
+ )
2365
+
2366
+ for k in range(self.NORIENT):
2367
+
2368
+ kernelR = self.kernelR_conv[(cell_ids.shape[0], k)]
2369
+ kernelI = self.kernelI_conv[(cell_ids.shape[0], k)]
2370
+ padding = self.padding_conv[(cell_ids.shape[0], k)]
2371
+ if len(ishape) == 2:
2372
+ for l in range(ishape[0]):
2373
+ padded_data = padding.apply(image[l], is_torch=True)
2374
+ res[l, :, k] = kernelR.matmul(
2375
+ padded_data
2376
+ ) + 1j * kernelI.matmul(padded_data)
2377
+ else:
2378
+ for l in range(ishape[0]):
2379
+ for k2 in range(ishape[2]):
2380
+ padded_data = padding.apply(
2381
+ image[l, :, k2], is_torch=True
2382
+ )
2383
+ res[l, :, k2, k] = kernelR.matmul(
2384
+ padded_data
2385
+ ) + 1j * kernelI.matmul(padded_data)
2386
+ return res
2387
+
2293
2388
  nside = int(np.sqrt(image.shape[axis] // 12))
2294
2389
 
2295
2390
  if self.Idx_Neighbours[nside] is None:
@@ -2306,7 +2401,6 @@ class FoCUS:
2306
2401
  l_ww_real = self.ww_Real[nside]
2307
2402
  l_ww_imag = self.ww_Imag[nside]
2308
2403
 
2309
- ishape = list(image.shape)
2310
2404
  odata = 1
2311
2405
  for k in range(axis + 1, len(ishape)):
2312
2406
  odata = odata * ishape[k]
@@ -2460,7 +2554,7 @@ class FoCUS:
2460
2554
  return res
2461
2555
 
2462
2556
  # ---------------------------------------------−---------
2463
- def smooth(self, in_image, axis=0):
2557
+ def smooth(self, in_image, axis=0, cell_ids=None, nside=None):
2464
2558
 
2465
2559
  image = self.backend.bk_cast(in_image)
2466
2560
 
@@ -2589,6 +2683,50 @@ class FoCUS:
2589
2683
  return self.backend.bk_reshape(res, in_image.shape)
2590
2684
 
2591
2685
  else:
2686
+
2687
+ ishape = list(image.shape)
2688
+
2689
+ if cell_ids is not None:
2690
+ if cell_ids.shape[0] not in self.padding_smooth:
2691
+ import healpix_convolution as hc
2692
+ from xdggs.healpix import HealpixInfo
2693
+
2694
+ grid_info = HealpixInfo(
2695
+ level=int(np.log(nside) / np.log(2)), indexing_scheme="nested"
2696
+ )
2697
+
2698
+ kernel = hc.kernels.wavelet_smooth_kernel(
2699
+ cell_ids, grid_info=grid_info, is_torch=True
2700
+ )
2701
+
2702
+ self.kernel_smooth[cell_ids.shape[0]] = kernel.to(
2703
+ self.backend.all_bk_type
2704
+ ).to(image.device)
2705
+
2706
+ self.padding_smooth[cell_ids.shape[0]] = hc.pad(
2707
+ cell_ids,
2708
+ grid_info=grid_info,
2709
+ ring=5 // 2, # wavelet kernel_size=5 is hard coded
2710
+ mode="mean",
2711
+ constant_value=0,
2712
+ )
2713
+
2714
+ kernel = self.kernel_smooth[cell_ids.shape[0]]
2715
+ padding = self.padding_smooth[cell_ids.shape[0]]
2716
+
2717
+ res = self.backend.bk_zeros(ishape, dtype=self.backend.all_cbk_type)
2718
+
2719
+ if len(ishape) == 2:
2720
+ for l in range(ishape[0]):
2721
+ padded_data = padding.apply(image[l], is_torch=True)
2722
+ res[l] = kernel.matmul(padded_data)
2723
+ else:
2724
+ for l in range(ishape[0]):
2725
+ for k2 in range(ishape[2]):
2726
+ padded_data = padding.apply(image[l, :, k2], is_torch=True)
2727
+ res[l, :, k2] = kernel.matmul(padded_data)
2728
+ return res
2729
+
2592
2730
  nside = int(np.sqrt(image.shape[axis] // 12))
2593
2731
 
2594
2732
  if self.Idx_Neighbours[nside] is None:
@@ -2604,7 +2742,6 @@ class FoCUS:
2604
2742
  self.w_smooth[nside] = ws
2605
2743
 
2606
2744
  l_w_smooth = self.w_smooth[nside]
2607
- ishape = list(image.shape)
2608
2745
 
2609
2746
  odata = 1
2610
2747
  for k in range(axis + 1, len(ishape)):
@@ -2707,9 +2844,11 @@ class FoCUS:
2707
2844
  # ---------------------------------------------−---------
2708
2845
  def get_ww(self, nside=1):
2709
2846
  if self.use_2D:
2710
-
2711
- return (self.ww_RealT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT),
2712
- self.ww_ImagT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT))
2847
+
2848
+ return (
2849
+ self.ww_RealT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT),
2850
+ self.ww_ImagT[1].reshape(self.KERNELSZ * self.KERNELSZ, self.NORIENT),
2851
+ )
2713
2852
  else:
2714
2853
  return (self.ww_Real[nside], self.ww_Imag[nside])
2715
2854
 
foscat/Synthesis.py CHANGED
@@ -240,9 +240,9 @@ class Synthesis:
240
240
  grd_mask = self.grd_mask
241
241
 
242
242
  if grd_mask is not None:
243
- g_tot = grd_mask * self.to_numpy(g_tot)
243
+ g_tot = self.operation.backend.to_numpy(g_tot * grd_mask)
244
244
  else:
245
- g_tot = self.to_numpy(g_tot)
245
+ g_tot = self.operation.backend.to_numpy(g_tot)
246
246
 
247
247
  g_tot[np.isnan(g_tot)] = 0.0
248
248
 
@@ -426,7 +426,7 @@ class Synthesis:
426
426
  factr=factr,
427
427
  maxiter=maxitt,
428
428
  )
429
- print('Final Loss ',loss)
429
+ print("Final Loss ", loss)
430
430
  # update bias input data
431
431
  if iteration < NUM_STEP_BIAS - 1:
432
432
  # if self.mpi_rank==0: