foscat 3.6.1__tar.gz → 3.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {foscat-3.6.1/src/foscat.egg-info → foscat-3.7.0}/PKG-INFO +9 -2
- {foscat-3.6.1 → foscat-3.7.0}/pyproject.toml +1 -1
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/CircSpline.py +19 -22
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/FoCUS.py +58 -51
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/Spline1D.py +12 -13
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/Synthesis.py +16 -13
- foscat-3.7.0/src/foscat/alm.py +937 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/backend.py +98 -29
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/scat_cov.py +597 -371
- foscat-3.7.0/src/foscat/scat_cov2D.py +78 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/scat_cov_map.py +15 -2
- {foscat-3.6.1 → foscat-3.7.0/src/foscat.egg-info}/PKG-INFO +9 -2
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat.egg-info/SOURCES.txt +0 -1
- foscat-3.6.1/src/foscat/alm.py +0 -799
- foscat-3.6.1/src/foscat/alm_tools.py +0 -11
- foscat-3.6.1/src/foscat/scat_cov2D.py +0 -18
- {foscat-3.6.1 → foscat-3.7.0}/LICENCE +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/README.md +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/setup.cfg +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/CNN.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/GCNN.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/Softmax.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/__init__.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/backend_tens.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/loss_backend_tens.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/loss_backend_torch.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/scat.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/scat1D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/scat2D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/scat_cov1D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat/scat_cov_map2D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat.egg-info/dependency_links.txt +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat.egg-info/requires.txt +0 -0
- {foscat-3.6.1 → foscat-3.7.0}/src/foscat.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: foscat
|
|
3
|
-
Version: 3.6.1
|
|
3
|
+
Version: 3.7.0
|
|
4
4
|
Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
|
|
5
5
|
Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
|
|
6
6
|
Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
|
|
@@ -19,6 +19,13 @@ Classifier: Programming Language :: Python :: 3.12
|
|
|
19
19
|
Requires-Python: >=3.9
|
|
20
20
|
Description-Content-Type: text/markdown
|
|
21
21
|
License-File: LICENCE
|
|
22
|
+
Requires-Dist: imageio
|
|
23
|
+
Requires-Dist: imagecodecs
|
|
24
|
+
Requires-Dist: matplotlib
|
|
25
|
+
Requires-Dist: numpy
|
|
26
|
+
Requires-Dist: tensorflow
|
|
27
|
+
Requires-Dist: healpy
|
|
28
|
+
Requires-Dist: spherical
|
|
22
29
|
|
|
23
30
|
# foscat
|
|
24
31
|
|
|
@@ -12,25 +12,23 @@ class CircSpline:
|
|
|
12
12
|
"""
|
|
13
13
|
self.degree = degree
|
|
14
14
|
self.nodes = nodes
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def cubic_spline_function(self,x):
|
|
15
|
+
|
|
16
|
+
def cubic_spline_function(self, x):
|
|
18
17
|
"""
|
|
19
18
|
Evaluate the cubic spline basis function.
|
|
20
|
-
|
|
19
|
+
|
|
21
20
|
Args:
|
|
22
21
|
x (float or array): Input value(s) to evaluate the spline basis function.
|
|
23
|
-
|
|
22
|
+
|
|
24
23
|
Returns:
|
|
25
24
|
float or array: Result of the cubic spline basis function.
|
|
26
25
|
"""
|
|
27
26
|
return -2 * x**3 + 3 * x**2
|
|
28
27
|
|
|
29
|
-
|
|
30
|
-
def eval(self,x):
|
|
28
|
+
def eval(self, x):
|
|
31
29
|
"""
|
|
32
30
|
Compute a 3rd-degree cubic spline with 4-point support.
|
|
33
|
-
|
|
31
|
+
|
|
34
32
|
Args:
|
|
35
33
|
x (float or array): Input value(s) to compute the spline.
|
|
36
34
|
|
|
@@ -38,8 +36,8 @@ class CircSpline:
|
|
|
38
36
|
indices (array): Indices of the spline support points.
|
|
39
37
|
coefficients (array): Normalized spline coefficients.
|
|
40
38
|
"""
|
|
41
|
-
N=self.nodes
|
|
42
|
-
|
|
39
|
+
N = self.nodes
|
|
40
|
+
|
|
43
41
|
if isinstance(x, float):
|
|
44
42
|
# Single scalar input
|
|
45
43
|
base_idx = int(x * (N))
|
|
@@ -61,10 +59,10 @@ class CircSpline:
|
|
|
61
59
|
coefficients[0] = self.cubic_spline_function(0.5 - fractional_part / 2) / 2
|
|
62
60
|
|
|
63
61
|
# Assign indices for the support points
|
|
64
|
-
indices[3] = (base_idx + 2+N)%N
|
|
65
|
-
indices[2] = (base_idx + 1+N)%N
|
|
66
|
-
indices[1] = (base_idx + N)%N
|
|
67
|
-
indices[0] = (base_idx + N-1)%N
|
|
62
|
+
indices[3] = (base_idx + 2 + N) % N
|
|
63
|
+
indices[2] = (base_idx + 1 + N) % N
|
|
64
|
+
indices[1] = (base_idx + N) % N
|
|
65
|
+
indices[0] = (base_idx + N - 1) % N
|
|
68
66
|
|
|
69
67
|
# Square coefficients and normalize
|
|
70
68
|
coefficients = coefficients * coefficients
|
|
@@ -72,11 +70,10 @@ class CircSpline:
|
|
|
72
70
|
|
|
73
71
|
return indices, coefficients
|
|
74
72
|
|
|
75
|
-
|
|
76
|
-
def eval_N(self,x,N):
|
|
73
|
+
def eval_N(self, x, N):
|
|
77
74
|
"""
|
|
78
75
|
Compute a 3rd-degree cubic spline with 4-point support.
|
|
79
|
-
|
|
76
|
+
|
|
80
77
|
Args:
|
|
81
78
|
x (float or array): Input value(s) to compute the spline.
|
|
82
79
|
|
|
@@ -84,7 +81,7 @@ class CircSpline:
|
|
|
84
81
|
indices (array): Indices of the spline support points.
|
|
85
82
|
coefficients (array): Normalized spline coefficients.
|
|
86
83
|
"""
|
|
87
|
-
|
|
84
|
+
|
|
88
85
|
if isinstance(x, float):
|
|
89
86
|
# Single scalar input
|
|
90
87
|
base_idx = int(x * (N))
|
|
@@ -106,10 +103,10 @@ class CircSpline:
|
|
|
106
103
|
coefficients[0] = self.cubic_spline_function(0.5 - fractional_part / 2) / 2
|
|
107
104
|
|
|
108
105
|
# Assign indices for the support points
|
|
109
|
-
indices[3] = (base_idx + 2+N)%N
|
|
110
|
-
indices[2] = (base_idx + 1+N)%N
|
|
111
|
-
indices[1] = (base_idx + N)%N
|
|
112
|
-
indices[0] =(base_idx + N-1)%N
|
|
106
|
+
indices[3] = (base_idx + 2 + N) % N
|
|
107
|
+
indices[2] = (base_idx + 1 + N) % N
|
|
108
|
+
indices[1] = (base_idx + N) % N
|
|
109
|
+
indices[0] = (base_idx + N - 1) % N
|
|
113
110
|
|
|
114
111
|
# Adjust indices to start from 0
|
|
115
112
|
# Square coefficients and normalize
|
|
@@ -12,33 +12,32 @@ TMPFILE_VERSION = "V4_0"
|
|
|
12
12
|
|
|
13
13
|
class FoCUS:
|
|
14
14
|
def __init__(
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
mpi_rank=0,
|
|
15
|
+
self,
|
|
16
|
+
NORIENT=4,
|
|
17
|
+
LAMBDA=1.2,
|
|
18
|
+
KERNELSZ=3,
|
|
19
|
+
slope=1.0,
|
|
20
|
+
all_type="float32",
|
|
21
|
+
nstep_max=16,
|
|
22
|
+
padding="SAME",
|
|
23
|
+
gpupos=0,
|
|
24
|
+
mask_thres=None,
|
|
25
|
+
mask_norm=False,
|
|
26
|
+
isMPI=False,
|
|
27
|
+
TEMPLATE_PATH="data",
|
|
28
|
+
BACKEND="tensorflow",
|
|
29
|
+
use_2D=False,
|
|
30
|
+
use_1D=False,
|
|
31
|
+
return_data=False,
|
|
32
|
+
JmaxDelta=0,
|
|
33
|
+
DODIV=False,
|
|
34
|
+
InitWave=None,
|
|
35
|
+
silent=True,
|
|
36
|
+
mpi_size=1,
|
|
37
|
+
mpi_rank=0,
|
|
39
38
|
):
|
|
40
39
|
|
|
41
|
-
self.__version__ = "3.6.1"
|
|
40
|
+
self.__version__ = "3.7.0"
|
|
42
41
|
# P00 coeff for normalization for scat_cov
|
|
43
42
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
44
43
|
self.P1_dic = None
|
|
@@ -47,7 +46,7 @@ class FoCUS:
|
|
|
47
46
|
self.mask_thres = mask_thres
|
|
48
47
|
self.mask_norm = mask_norm
|
|
49
48
|
self.InitWave = InitWave
|
|
50
|
-
|
|
49
|
+
self.mask_mask=None
|
|
51
50
|
self.mpi_size = mpi_size
|
|
52
51
|
self.mpi_rank = mpi_rank
|
|
53
52
|
self.return_data = return_data
|
|
@@ -83,18 +82,11 @@ class FoCUS:
|
|
|
83
82
|
self.nlog = 0
|
|
84
83
|
self.padding = padding
|
|
85
84
|
|
|
86
|
-
if
|
|
85
|
+
if JmaxDelta != 0:
|
|
87
86
|
if not self.silent:
|
|
88
87
|
print(
|
|
89
|
-
"OPTION
|
|
88
|
+
"OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
|
|
90
89
|
)
|
|
91
|
-
JmaxDelta = OSTEP
|
|
92
|
-
else:
|
|
93
|
-
OSTEP = JmaxDelta
|
|
94
|
-
|
|
95
|
-
if JmaxDelta < -1:
|
|
96
|
-
if not self.silent:
|
|
97
|
-
print("Warning : Jmax can not be smaller than -1")
|
|
98
90
|
return None
|
|
99
91
|
|
|
100
92
|
self.OSTEP = JmaxDelta
|
|
@@ -197,13 +189,15 @@ class FoCUS:
|
|
|
197
189
|
w_smooth = w_smooth.flatten()
|
|
198
190
|
else:
|
|
199
191
|
for i in range(NORIENT):
|
|
200
|
-
a = i / float(NORIENT) * np.pi
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
192
|
+
a = (NORIENT-1-i) / float(NORIENT) * np.pi # get the same angle number than scattering lib
|
|
193
|
+
if KERNELSZ<5:
|
|
194
|
+
xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
|
|
195
|
+
yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
|
|
196
|
+
else:
|
|
197
|
+
xx = (3 /5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
|
|
198
|
+
yy = (3 /5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
|
|
204
199
|
if KERNELSZ == 5:
|
|
205
|
-
|
|
206
|
-
w_smooth = np.exp(-(xx**2 + yy**2))
|
|
200
|
+
w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
|
|
207
201
|
else:
|
|
208
202
|
w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
|
|
209
203
|
tmp1 = np.cos(yy * np.pi) * w_smooth
|
|
@@ -211,7 +205,8 @@ class FoCUS:
|
|
|
211
205
|
|
|
212
206
|
wwc[:, i] = tmp1.flatten() - tmp1.mean()
|
|
213
207
|
wws[:, i] = tmp2.flatten() - tmp2.mean()
|
|
214
|
-
sigma = np.sqrt((wwc[:, i] ** 2).mean())
|
|
208
|
+
#sigma = np.sqrt((wwc[:, i] ** 2).mean())
|
|
209
|
+
sigma = np.mean(w_smooth)
|
|
215
210
|
wwc[:, i] /= sigma
|
|
216
211
|
wws[:, i] /= sigma
|
|
217
212
|
|
|
@@ -224,7 +219,8 @@ class FoCUS:
|
|
|
224
219
|
|
|
225
220
|
wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
|
|
226
221
|
wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
|
|
227
|
-
sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
|
|
222
|
+
#sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
|
|
223
|
+
sigma = np.mean(w_smooth)
|
|
228
224
|
|
|
229
225
|
wwc[:, NORIENT] /= sigma
|
|
230
226
|
wws[:, NORIENT] /= sigma
|
|
@@ -233,11 +229,13 @@ class FoCUS:
|
|
|
233
229
|
|
|
234
230
|
wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
|
|
235
231
|
wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
|
|
236
|
-
sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
|
|
232
|
+
#sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
|
|
233
|
+
sigma = np.mean(w_smooth)
|
|
237
234
|
wwc[:, NORIENT + 1] /= sigma
|
|
238
235
|
wws[:, NORIENT + 1] /= sigma
|
|
239
236
|
|
|
240
237
|
w_smooth = w_smooth.flatten()
|
|
238
|
+
|
|
241
239
|
if self.use_1D:
|
|
242
240
|
KERNELSZ = 5
|
|
243
241
|
|
|
@@ -1776,14 +1774,14 @@ class FoCUS:
|
|
|
1776
1774
|
if self.padding == "VALID":
|
|
1777
1775
|
l_mask = l_mask[
|
|
1778
1776
|
:,
|
|
1779
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1780
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1777
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1778
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1781
1779
|
]
|
|
1782
1780
|
if shape[axis] != l_mask.shape[1]:
|
|
1783
1781
|
l_mask = l_mask[
|
|
1784
1782
|
:,
|
|
1785
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1786
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1783
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1784
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1787
1785
|
]
|
|
1788
1786
|
|
|
1789
1787
|
ichannel = 1
|
|
@@ -1850,8 +1848,12 @@ class FoCUS:
|
|
|
1850
1848
|
l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
|
|
1851
1849
|
|
|
1852
1850
|
if self.use_2D:
|
|
1851
|
+
#if self.padding == "VALID":
|
|
1853
1852
|
mtmp = l_mask
|
|
1854
1853
|
vtmp = l_x
|
|
1854
|
+
#else:
|
|
1855
|
+
# mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
1856
|
+
# vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
1855
1857
|
|
|
1856
1858
|
v1 = self.backend.bk_reduce_sum(
|
|
1857
1859
|
self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
|
|
@@ -2684,7 +2686,12 @@ class FoCUS:
|
|
|
2684
2686
|
|
|
2685
2687
|
# ---------------------------------------------−---------
|
|
2686
2688
|
def get_ww(self, nside=1):
|
|
2687
|
-
|
|
2689
|
+
if self.use_2D:
|
|
2690
|
+
|
|
2691
|
+
return (self.ww_RealT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT),
|
|
2692
|
+
self.ww_ImagT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT))
|
|
2693
|
+
else:
|
|
2694
|
+
return (self.ww_Real[nside], self.ww_Imag[nside])
|
|
2688
2695
|
|
|
2689
2696
|
# ---------------------------------------------−---------
|
|
2690
2697
|
def plot_ww(self):
|
|
@@ -2696,11 +2703,11 @@ class FoCUS:
|
|
|
2696
2703
|
for i in range(c.shape[1]):
|
|
2697
2704
|
plt.subplot(2, c.shape[1], 1 + i)
|
|
2698
2705
|
plt.imshow(
|
|
2699
|
-
c[:, i].reshape(npt, npt), cmap="
|
|
2706
|
+
c[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max()
|
|
2700
2707
|
)
|
|
2701
2708
|
plt.subplot(2, c.shape[1], 1 + i + c.shape[1])
|
|
2702
2709
|
plt.imshow(
|
|
2703
|
-
s[:, i].reshape(npt, npt), cmap="
|
|
2710
|
+
s[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max()
|
|
2704
2711
|
)
|
|
2705
2712
|
sys.stdout.flush()
|
|
2706
2713
|
plt.show()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
|
|
3
|
+
|
|
3
4
|
class Spline1D:
|
|
4
5
|
def __init__(self, nodes, degree=3):
|
|
5
6
|
"""
|
|
@@ -11,25 +12,23 @@ class Spline1D:
|
|
|
11
12
|
"""
|
|
12
13
|
self.degree = degree
|
|
13
14
|
self.nodes = nodes
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def cubic_spline_function(self,x):
|
|
15
|
+
|
|
16
|
+
def cubic_spline_function(self, x):
|
|
17
17
|
"""
|
|
18
18
|
Evaluate the cubic spline basis function.
|
|
19
|
-
|
|
19
|
+
|
|
20
20
|
Args:
|
|
21
21
|
x (float or array): Input value(s) to evaluate the spline basis function.
|
|
22
|
-
|
|
22
|
+
|
|
23
23
|
Returns:
|
|
24
24
|
float or array: Result of the cubic spline basis function.
|
|
25
25
|
"""
|
|
26
26
|
return -2 * x**3 + 3 * x**2
|
|
27
27
|
|
|
28
|
-
|
|
29
|
-
def eval(self,x):
|
|
28
|
+
def eval(self, x):
|
|
30
29
|
"""
|
|
31
30
|
Compute a 3rd-degree cubic spline with 4-point support.
|
|
32
|
-
|
|
31
|
+
|
|
33
32
|
Args:
|
|
34
33
|
x (float or array): Input value(s) to compute the spline.
|
|
35
34
|
|
|
@@ -37,21 +36,21 @@ class Spline1D:
|
|
|
37
36
|
indices (array): Indices of the spline support points.
|
|
38
37
|
coefficients (array): Normalized spline coefficients.
|
|
39
38
|
"""
|
|
40
|
-
N=self.nodes
|
|
41
|
-
|
|
39
|
+
N = self.nodes
|
|
40
|
+
|
|
42
41
|
if isinstance(x, float):
|
|
43
42
|
# Single scalar input
|
|
44
|
-
base_idx = int(x * (N-1))
|
|
43
|
+
base_idx = int(x * (N - 1))
|
|
45
44
|
indices = np.zeros([4], dtype="int")
|
|
46
45
|
coefficients = np.zeros([4])
|
|
47
46
|
else:
|
|
48
47
|
# Array input
|
|
49
|
-
base_idx = (x * (N-1)).astype("int")
|
|
48
|
+
base_idx = (x * (N - 1)).astype("int")
|
|
50
49
|
indices = np.zeros([4, x.shape[0]], dtype="int")
|
|
51
50
|
coefficients = np.zeros([4, x.shape[0]])
|
|
52
51
|
|
|
53
52
|
# Compute the fractional part of the input
|
|
54
|
-
fractional_part = x * (N-1) - base_idx
|
|
53
|
+
fractional_part = x * (N - 1) - base_idx
|
|
55
54
|
|
|
56
55
|
# Compute spline coefficients for 4 support points
|
|
57
56
|
coefficients[3] = self.cubic_spline_function(fractional_part / 2) / 2
|
|
@@ -23,6 +23,7 @@ class Loss:
|
|
|
23
23
|
|
|
24
24
|
self.loss_function = function
|
|
25
25
|
self.scat_operator = scat_operator
|
|
26
|
+
self.to_numpy = scat_operator.backend.to_numpy
|
|
26
27
|
self.args = param
|
|
27
28
|
self.name = name
|
|
28
29
|
self.batch = batch
|
|
@@ -47,12 +48,13 @@ class Loss:
|
|
|
47
48
|
else:
|
|
48
49
|
return self.loss_function(x, batch, self.scat_operator, self.args)
|
|
49
50
|
|
|
50
|
-
def set_id_loss(self,id_loss):
|
|
51
|
+
def set_id_loss(self, id_loss):
|
|
51
52
|
self.id_loss = id_loss
|
|
52
|
-
|
|
53
|
-
def get_id_loss(self,id_loss):
|
|
53
|
+
|
|
54
|
+
def get_id_loss(self, id_loss):
|
|
54
55
|
return self.id_loss
|
|
55
|
-
|
|
56
|
+
|
|
57
|
+
|
|
56
58
|
class Synthesis:
|
|
57
59
|
def __init__(
|
|
58
60
|
self,
|
|
@@ -66,10 +68,10 @@ class Synthesis:
|
|
|
66
68
|
|
|
67
69
|
self.loss_class = loss_list
|
|
68
70
|
self.number_of_loss = len(loss_list)
|
|
69
|
-
|
|
71
|
+
|
|
70
72
|
for k in range(self.number_of_loss):
|
|
71
73
|
self.loss_class[k].set_id_loss(k)
|
|
72
|
-
|
|
74
|
+
|
|
73
75
|
self.__iteration__ = 1234
|
|
74
76
|
self.nlog = 0
|
|
75
77
|
self.m_dw, self.v_dw = 0.0, 0.0
|
|
@@ -83,6 +85,7 @@ class Synthesis:
|
|
|
83
85
|
self.curr_gpu = 0
|
|
84
86
|
self.event = Event()
|
|
85
87
|
self.operation = loss_list[0].scat_operator
|
|
88
|
+
self.to_numpy = self.operation.backend.to_numpy
|
|
86
89
|
self.mpi_size = self.operation.mpi_size
|
|
87
90
|
self.mpi_rank = self.operation.mpi_rank
|
|
88
91
|
self.KEEP_TRACK = None
|
|
@@ -222,24 +225,24 @@ class Synthesis:
|
|
|
222
225
|
else:
|
|
223
226
|
g_tot = g_tot + g
|
|
224
227
|
|
|
225
|
-
l_tot = l_tot +
|
|
228
|
+
l_tot = l_tot + self.to_numpy(l_loss)
|
|
226
229
|
|
|
227
230
|
if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
|
|
228
231
|
self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
|
|
229
|
-
|
|
232
|
+
self.to_numpy(l_loss) / nstep
|
|
230
233
|
)
|
|
231
234
|
else:
|
|
232
235
|
self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
|
|
233
236
|
self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
|
|
234
|
-
+
|
|
237
|
+
+ self.to_numpy(l_loss) / nstep
|
|
235
238
|
)
|
|
236
239
|
|
|
237
240
|
grd_mask = self.grd_mask
|
|
238
241
|
|
|
239
242
|
if grd_mask is not None:
|
|
240
|
-
g_tot = grd_mask *
|
|
243
|
+
g_tot = grd_mask * self.to_numpy(g_tot)
|
|
241
244
|
else:
|
|
242
|
-
g_tot =
|
|
245
|
+
g_tot = self.to_numpy(g_tot)
|
|
243
246
|
|
|
244
247
|
g_tot[np.isnan(g_tot)] = 0.0
|
|
245
248
|
|
|
@@ -389,7 +392,7 @@ class Synthesis:
|
|
|
389
392
|
self.oshape = list(x.shape)
|
|
390
393
|
|
|
391
394
|
if not isinstance(x, np.ndarray):
|
|
392
|
-
x =
|
|
395
|
+
x = self.to_numpy(x)
|
|
393
396
|
|
|
394
397
|
x = x.flatten()
|
|
395
398
|
|
|
@@ -423,7 +426,7 @@ class Synthesis:
|
|
|
423
426
|
factr=factr,
|
|
424
427
|
maxiter=maxitt,
|
|
425
428
|
)
|
|
426
|
-
|
|
429
|
+
print('Final Loss ',loss)
|
|
427
430
|
# update bias input data
|
|
428
431
|
if iteration < NUM_STEP_BIAS - 1:
|
|
429
432
|
# if self.mpi_rank==0:
|