foscat 3.6.1__tar.gz → 3.7.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. {foscat-3.6.1/src/foscat.egg-info → foscat-3.7.1}/PKG-INFO +10 -3
  2. {foscat-3.6.1 → foscat-3.7.1}/pyproject.toml +1 -1
  3. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/CircSpline.py +19 -22
  4. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/FoCUS.py +61 -51
  5. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/Spline1D.py +12 -13
  6. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/Synthesis.py +16 -13
  7. foscat-3.7.1/src/foscat/alm.py +937 -0
  8. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/backend.py +141 -37
  9. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov.py +2311 -430
  10. foscat-3.7.1/src/foscat/scat_cov2D.py +118 -0
  11. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov_map.py +15 -2
  12. {foscat-3.6.1 → foscat-3.7.1/src/foscat.egg-info}/PKG-INFO +10 -3
  13. {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/SOURCES.txt +1 -2
  14. foscat-3.6.1/src/foscat/alm.py +0 -799
  15. foscat-3.6.1/src/foscat/alm_tools.py +0 -11
  16. foscat-3.6.1/src/foscat/scat_cov2D.py +0 -18
  17. /foscat-3.6.1/LICENCE → /foscat-3.7.1/LICENSE +0 -0
  18. {foscat-3.6.1 → foscat-3.7.1}/README.md +0 -0
  19. {foscat-3.6.1 → foscat-3.7.1}/setup.cfg +0 -0
  20. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/CNN.py +0 -0
  21. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/GCNN.py +0 -0
  22. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/Softmax.py +0 -0
  23. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/__init__.py +0 -0
  24. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/backend_tens.py +0 -0
  25. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/loss_backend_tens.py +0 -0
  26. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/loss_backend_torch.py +0 -0
  27. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat.py +0 -0
  28. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat1D.py +0 -0
  29. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat2D.py +0 -0
  30. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov1D.py +0 -0
  31. {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov_map2D.py +0 -0
  32. {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/dependency_links.txt +0 -0
  33. {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/requires.txt +0 -0
  34. {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: foscat
3
- Version: 3.6.1
3
+ Version: 3.7.1
4
4
  Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
5
5
  Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
6
6
  Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
@@ -18,7 +18,14 @@ Classifier: Programming Language :: Python :: 3.11
18
18
  Classifier: Programming Language :: Python :: 3.12
19
19
  Requires-Python: >=3.9
20
20
  Description-Content-Type: text/markdown
21
- License-File: LICENCE
21
+ License-File: LICENSE
22
+ Requires-Dist: imageio
23
+ Requires-Dist: imagecodecs
24
+ Requires-Dist: matplotlib
25
+ Requires-Dist: numpy
26
+ Requires-Dist: tensorflow
27
+ Requires-Dist: healpy
28
+ Requires-Dist: spherical
22
29
 
23
30
  # foscat
24
31
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "foscat"
3
- version = "3.6.1"
3
+ version = "3.7.1"
4
4
  description = "Generate synthetic Healpix or 2D data using Cross Scattering Transform"
5
5
  readme = "README.md"
6
6
  license = { text = "BSD-3-Clause" }
@@ -12,25 +12,23 @@ class CircSpline:
12
12
  """
13
13
  self.degree = degree
14
14
  self.nodes = nodes
15
-
16
-
17
- def cubic_spline_function(self,x):
15
+
16
+ def cubic_spline_function(self, x):
18
17
  """
19
18
  Evaluate the cubic spline basis function.
20
-
19
+
21
20
  Args:
22
21
  x (float or array): Input value(s) to evaluate the spline basis function.
23
-
22
+
24
23
  Returns:
25
24
  float or array: Result of the cubic spline basis function.
26
25
  """
27
26
  return -2 * x**3 + 3 * x**2
28
27
 
29
-
30
- def eval(self,x):
28
+ def eval(self, x):
31
29
  """
32
30
  Compute a 3rd-degree cubic spline with 4-point support.
33
-
31
+
34
32
  Args:
35
33
  x (float or array): Input value(s) to compute the spline.
36
34
 
@@ -38,8 +36,8 @@ class CircSpline:
38
36
  indices (array): Indices of the spline support points.
39
37
  coefficients (array): Normalized spline coefficients.
40
38
  """
41
- N=self.nodes
42
-
39
+ N = self.nodes
40
+
43
41
  if isinstance(x, float):
44
42
  # Single scalar input
45
43
  base_idx = int(x * (N))
@@ -61,10 +59,10 @@ class CircSpline:
61
59
  coefficients[0] = self.cubic_spline_function(0.5 - fractional_part / 2) / 2
62
60
 
63
61
  # Assign indices for the support points
64
- indices[3] = (base_idx + 2+N)%N
65
- indices[2] = (base_idx + 1+N)%N
66
- indices[1] = (base_idx + N )%N
67
- indices[0] = (base_idx + N-1)%N
62
+ indices[3] = (base_idx + 2 + N) % N
63
+ indices[2] = (base_idx + 1 + N) % N
64
+ indices[1] = (base_idx + N) % N
65
+ indices[0] = (base_idx + N - 1) % N
68
66
 
69
67
  # Square coefficients and normalize
70
68
  coefficients = coefficients * coefficients
@@ -72,11 +70,10 @@ class CircSpline:
72
70
 
73
71
  return indices, coefficients
74
72
 
75
-
76
- def eval_N(self,x,N):
73
+ def eval_N(self, x, N):
77
74
  """
78
75
  Compute a 3rd-degree cubic spline with 4-point support.
79
-
76
+
80
77
  Args:
81
78
  x (float or array): Input value(s) to compute the spline.
82
79
 
@@ -84,7 +81,7 @@ class CircSpline:
84
81
  indices (array): Indices of the spline support points.
85
82
  coefficients (array): Normalized spline coefficients.
86
83
  """
87
-
84
+
88
85
  if isinstance(x, float):
89
86
  # Single scalar input
90
87
  base_idx = int(x * (N))
@@ -106,10 +103,10 @@ class CircSpline:
106
103
  coefficients[0] = self.cubic_spline_function(0.5 - fractional_part / 2) / 2
107
104
 
108
105
  # Assign indices for the support points
109
- indices[3] = (base_idx + 2+N)%N
110
- indices[2] = (base_idx + 1+N)%N
111
- indices[1] = (base_idx + N )%N
112
- indices[0] =( base_idx + N-1)%N
106
+ indices[3] = (base_idx + 2 + N) % N
107
+ indices[2] = (base_idx + 1 + N) % N
108
+ indices[1] = (base_idx + N) % N
109
+ indices[0] = (base_idx + N - 1) % N
113
110
 
114
111
  # Adjust indices to start from 0
115
112
  # Square coefficients and normalize
@@ -12,33 +12,32 @@ TMPFILE_VERSION = "V4_0"
12
12
 
13
13
  class FoCUS:
14
14
  def __init__(
15
- self,
16
- NORIENT=4,
17
- LAMBDA=1.2,
18
- KERNELSZ=3,
19
- slope=1.0,
20
- all_type="float64",
21
- nstep_max=16,
22
- padding="SAME",
23
- gpupos=0,
24
- mask_thres=None,
25
- mask_norm=False,
26
- OSTEP=0,
27
- isMPI=False,
28
- TEMPLATE_PATH="data",
29
- BACKEND="tensorflow",
30
- use_2D=False,
31
- use_1D=False,
32
- return_data=False,
33
- JmaxDelta=0,
34
- DODIV=False,
35
- InitWave=None,
36
- silent=False,
37
- mpi_size=1,
38
- mpi_rank=0,
15
+ self,
16
+ NORIENT=4,
17
+ LAMBDA=1.2,
18
+ KERNELSZ=3,
19
+ slope=1.0,
20
+ all_type="float32",
21
+ nstep_max=16,
22
+ padding="SAME",
23
+ gpupos=0,
24
+ mask_thres=None,
25
+ mask_norm=False,
26
+ isMPI=False,
27
+ TEMPLATE_PATH="data",
28
+ BACKEND="tensorflow",
29
+ use_2D=False,
30
+ use_1D=False,
31
+ return_data=False,
32
+ JmaxDelta=0,
33
+ DODIV=False,
34
+ InitWave=None,
35
+ silent=True,
36
+ mpi_size=1,
37
+ mpi_rank=0,
39
38
  ):
40
39
 
41
- self.__version__ = "3.6.1"
40
+ self.__version__ = "3.7.1"
42
41
  # P00 coeff for normalization for scat_cov
43
42
  self.TMPFILE_VERSION = TMPFILE_VERSION
44
43
  self.P1_dic = None
@@ -47,7 +46,7 @@ class FoCUS:
47
46
  self.mask_thres = mask_thres
48
47
  self.mask_norm = mask_norm
49
48
  self.InitWave = InitWave
50
-
49
+ self.mask_mask=None
51
50
  self.mpi_size = mpi_size
52
51
  self.mpi_rank = mpi_rank
53
52
  self.return_data = return_data
@@ -83,18 +82,11 @@ class FoCUS:
83
82
  self.nlog = 0
84
83
  self.padding = padding
85
84
 
86
- if OSTEP != 0:
85
+ if JmaxDelta != 0:
87
86
  if not self.silent:
88
87
  print(
89
- "OPTION option is deprecated after version 2.0.6. Please use Jmax option"
88
+ "OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
90
89
  )
91
- JmaxDelta = OSTEP
92
- else:
93
- OSTEP = JmaxDelta
94
-
95
- if JmaxDelta < -1:
96
- if not self.silent:
97
- print("Warning : Jmax can not be smaller than -1")
98
90
  return None
99
91
 
100
92
  self.OSTEP = JmaxDelta
@@ -163,6 +155,9 @@ class FoCUS:
163
155
  self.Y_CNN = {}
164
156
  self.Z_CNN = {}
165
157
 
158
+ self.filters_set={}
159
+ self.edge_masks={}
160
+
166
161
  wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
167
162
  wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
168
163
 
@@ -197,13 +192,15 @@ class FoCUS:
197
192
  w_smooth = w_smooth.flatten()
198
193
  else:
199
194
  for i in range(NORIENT):
200
- a = i / float(NORIENT) * np.pi
201
- xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
202
- yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
203
-
195
+ a = (NORIENT-1-i) / float(NORIENT) * np.pi # get the same angle number than scattering lib
196
+ if KERNELSZ<5:
197
+ xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
198
+ yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
199
+ else:
200
+ xx = (3 /5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
201
+ yy = (3 /5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
204
202
  if KERNELSZ == 5:
205
- # w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
206
- w_smooth = np.exp(-(xx**2 + yy**2))
203
+ w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
207
204
  else:
208
205
  w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
209
206
  tmp1 = np.cos(yy * np.pi) * w_smooth
@@ -211,7 +208,8 @@ class FoCUS:
211
208
 
212
209
  wwc[:, i] = tmp1.flatten() - tmp1.mean()
213
210
  wws[:, i] = tmp2.flatten() - tmp2.mean()
214
- sigma = np.sqrt((wwc[:, i] ** 2).mean())
211
+ #sigma = np.sqrt((wwc[:, i] ** 2).mean())
212
+ sigma = np.mean(w_smooth)
215
213
  wwc[:, i] /= sigma
216
214
  wws[:, i] /= sigma
217
215
 
@@ -224,7 +222,8 @@ class FoCUS:
224
222
 
225
223
  wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
226
224
  wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
227
- sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
225
+ #sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
226
+ sigma = np.mean(w_smooth)
228
227
 
229
228
  wwc[:, NORIENT] /= sigma
230
229
  wws[:, NORIENT] /= sigma
@@ -233,11 +232,13 @@ class FoCUS:
233
232
 
234
233
  wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
235
234
  wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
236
- sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
235
+ #sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
236
+ sigma = np.mean(w_smooth)
237
237
  wwc[:, NORIENT + 1] /= sigma
238
238
  wws[:, NORIENT + 1] /= sigma
239
239
 
240
240
  w_smooth = w_smooth.flatten()
241
+
241
242
  if self.use_1D:
242
243
  KERNELSZ = 5
243
244
 
@@ -1776,14 +1777,14 @@ class FoCUS:
1776
1777
  if self.padding == "VALID":
1777
1778
  l_mask = l_mask[
1778
1779
  :,
1779
- self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1780
- self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1780
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1781
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1781
1782
  ]
1782
1783
  if shape[axis] != l_mask.shape[1]:
1783
1784
  l_mask = l_mask[
1784
1785
  :,
1785
- self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1786
- self.KERNELSZ // 2 : -self.KERNELSZ // 2 + 1,
1786
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1787
+ self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
1787
1788
  ]
1788
1789
 
1789
1790
  ichannel = 1
@@ -1850,8 +1851,12 @@ class FoCUS:
1850
1851
  l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
1851
1852
 
1852
1853
  if self.use_2D:
1854
+ #if self.padding == "VALID":
1853
1855
  mtmp = l_mask
1854
1856
  vtmp = l_x
1857
+ #else:
1858
+ # mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1859
+ # vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
1855
1860
 
1856
1861
  v1 = self.backend.bk_reduce_sum(
1857
1862
  self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
@@ -2684,7 +2689,12 @@ class FoCUS:
2684
2689
 
2685
2690
  # ---------------------------------------------−---------
2686
2691
  def get_ww(self, nside=1):
2687
- return (self.ww_Real[nside], self.ww_Imag[nside])
2692
+ if self.use_2D:
2693
+
2694
+ return (self.ww_RealT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT),
2695
+ self.ww_ImagT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT))
2696
+ else:
2697
+ return (self.ww_Real[nside], self.ww_Imag[nside])
2688
2698
 
2689
2699
  # ---------------------------------------------−---------
2690
2700
  def plot_ww(self):
@@ -2696,11 +2706,11 @@ class FoCUS:
2696
2706
  for i in range(c.shape[1]):
2697
2707
  plt.subplot(2, c.shape[1], 1 + i)
2698
2708
  plt.imshow(
2699
- c[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2709
+ c[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max()
2700
2710
  )
2701
2711
  plt.subplot(2, c.shape[1], 1 + i + c.shape[1])
2702
2712
  plt.imshow(
2703
- s[:, i].reshape(npt, npt), cmap="jet", vmin=-c.max(), vmax=c.max()
2713
+ s[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max()
2704
2714
  )
2705
2715
  sys.stdout.flush()
2706
2716
  plt.show()
@@ -1,5 +1,6 @@
1
1
  import numpy as np
2
2
 
3
+
3
4
  class Spline1D:
4
5
  def __init__(self, nodes, degree=3):
5
6
  """
@@ -11,25 +12,23 @@ class Spline1D:
11
12
  """
12
13
  self.degree = degree
13
14
  self.nodes = nodes
14
-
15
-
16
- def cubic_spline_function(self,x):
15
+
16
+ def cubic_spline_function(self, x):
17
17
  """
18
18
  Evaluate the cubic spline basis function.
19
-
19
+
20
20
  Args:
21
21
  x (float or array): Input value(s) to evaluate the spline basis function.
22
-
22
+
23
23
  Returns:
24
24
  float or array: Result of the cubic spline basis function.
25
25
  """
26
26
  return -2 * x**3 + 3 * x**2
27
27
 
28
-
29
- def eval(self,x):
28
+ def eval(self, x):
30
29
  """
31
30
  Compute a 3rd-degree cubic spline with 4-point support.
32
-
31
+
33
32
  Args:
34
33
  x (float or array): Input value(s) to compute the spline.
35
34
 
@@ -37,21 +36,21 @@ class Spline1D:
37
36
  indices (array): Indices of the spline support points.
38
37
  coefficients (array): Normalized spline coefficients.
39
38
  """
40
- N=self.nodes
41
-
39
+ N = self.nodes
40
+
42
41
  if isinstance(x, float):
43
42
  # Single scalar input
44
- base_idx = int(x * (N-1))
43
+ base_idx = int(x * (N - 1))
45
44
  indices = np.zeros([4], dtype="int")
46
45
  coefficients = np.zeros([4])
47
46
  else:
48
47
  # Array input
49
- base_idx = (x * (N-1)).astype("int")
48
+ base_idx = (x * (N - 1)).astype("int")
50
49
  indices = np.zeros([4, x.shape[0]], dtype="int")
51
50
  coefficients = np.zeros([4, x.shape[0]])
52
51
 
53
52
  # Compute the fractional part of the input
54
- fractional_part = x * (N-1) - base_idx
53
+ fractional_part = x * (N - 1) - base_idx
55
54
 
56
55
  # Compute spline coefficients for 4 support points
57
56
  coefficients[3] = self.cubic_spline_function(fractional_part / 2) / 2
@@ -23,6 +23,7 @@ class Loss:
23
23
 
24
24
  self.loss_function = function
25
25
  self.scat_operator = scat_operator
26
+ self.to_numpy = scat_operator.backend.to_numpy
26
27
  self.args = param
27
28
  self.name = name
28
29
  self.batch = batch
@@ -47,12 +48,13 @@ class Loss:
47
48
  else:
48
49
  return self.loss_function(x, batch, self.scat_operator, self.args)
49
50
 
50
- def set_id_loss(self,id_loss):
51
+ def set_id_loss(self, id_loss):
51
52
  self.id_loss = id_loss
52
-
53
- def get_id_loss(self,id_loss):
53
+
54
+ def get_id_loss(self, id_loss):
54
55
  return self.id_loss
55
-
56
+
57
+
56
58
  class Synthesis:
57
59
  def __init__(
58
60
  self,
@@ -66,10 +68,10 @@ class Synthesis:
66
68
 
67
69
  self.loss_class = loss_list
68
70
  self.number_of_loss = len(loss_list)
69
-
71
+
70
72
  for k in range(self.number_of_loss):
71
73
  self.loss_class[k].set_id_loss(k)
72
-
74
+
73
75
  self.__iteration__ = 1234
74
76
  self.nlog = 0
75
77
  self.m_dw, self.v_dw = 0.0, 0.0
@@ -83,6 +85,7 @@ class Synthesis:
83
85
  self.curr_gpu = 0
84
86
  self.event = Event()
85
87
  self.operation = loss_list[0].scat_operator
88
+ self.to_numpy = self.operation.backend.to_numpy
86
89
  self.mpi_size = self.operation.mpi_size
87
90
  self.mpi_rank = self.operation.mpi_rank
88
91
  self.KEEP_TRACK = None
@@ -222,24 +225,24 @@ class Synthesis:
222
225
  else:
223
226
  g_tot = g_tot + g
224
227
 
225
- l_tot = l_tot + l_loss.numpy()
228
+ l_tot = l_tot + self.to_numpy(l_loss)
226
229
 
227
230
  if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
228
231
  self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
229
- l_loss.numpy() / nstep
232
+ self.to_numpy(l_loss) / nstep
230
233
  )
231
234
  else:
232
235
  self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
233
236
  self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
234
- + l_loss.numpy() / nstep
237
+ + self.to_numpy(l_loss) / nstep
235
238
  )
236
239
 
237
240
  grd_mask = self.grd_mask
238
241
 
239
242
  if grd_mask is not None:
240
- g_tot = grd_mask * g_tot.numpy()
243
+ g_tot = grd_mask * self.to_numpy(g_tot)
241
244
  else:
242
- g_tot = g_tot.numpy()
245
+ g_tot = self.to_numpy(g_tot)
243
246
 
244
247
  g_tot[np.isnan(g_tot)] = 0.0
245
248
 
@@ -389,7 +392,7 @@ class Synthesis:
389
392
  self.oshape = list(x.shape)
390
393
 
391
394
  if not isinstance(x, np.ndarray):
392
- x = x.numpy()
395
+ x = self.to_numpy(x)
393
396
 
394
397
  x = x.flatten()
395
398
 
@@ -423,7 +426,7 @@ class Synthesis:
423
426
  factr=factr,
424
427
  maxiter=maxitt,
425
428
  )
426
-
429
+ print('Final Loss ',loss)
427
430
  # update bias input data
428
431
  if iteration < NUM_STEP_BIAS - 1:
429
432
  # if self.mpi_rank==0: