foscat 3.6.1__tar.gz → 3.7.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {foscat-3.6.1/src/foscat.egg-info → foscat-3.7.1}/PKG-INFO +10 -3
- {foscat-3.6.1 → foscat-3.7.1}/pyproject.toml +1 -1
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/CircSpline.py +19 -22
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/FoCUS.py +61 -51
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/Spline1D.py +12 -13
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/Synthesis.py +16 -13
- foscat-3.7.1/src/foscat/alm.py +937 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/backend.py +141 -37
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov.py +2311 -430
- foscat-3.7.1/src/foscat/scat_cov2D.py +118 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov_map.py +15 -2
- {foscat-3.6.1 → foscat-3.7.1/src/foscat.egg-info}/PKG-INFO +10 -3
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/SOURCES.txt +1 -2
- foscat-3.6.1/src/foscat/alm.py +0 -799
- foscat-3.6.1/src/foscat/alm_tools.py +0 -11
- foscat-3.6.1/src/foscat/scat_cov2D.py +0 -18
- /foscat-3.6.1/LICENCE → /foscat-3.7.1/LICENSE +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/README.md +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/setup.cfg +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/CNN.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/GCNN.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/Softmax.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/__init__.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/backend_tens.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/loss_backend_tens.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/loss_backend_torch.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat1D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat2D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov1D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat/scat_cov_map2D.py +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/dependency_links.txt +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/requires.txt +0 -0
- {foscat-3.6.1 → foscat-3.7.1}/src/foscat.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: foscat
|
|
3
|
-
Version: 3.6.1
|
|
3
|
+
Version: 3.7.1
|
|
4
4
|
Summary: Generate synthetic Healpix or 2D data using Cross Scattering Transform
|
|
5
5
|
Author-email: Jean-Marc DELOUIS <jean.marc.delouis@ifremer.fr>
|
|
6
6
|
Maintainer-email: Theo Foulquier <theo.foulquier@ifremer.fr>
|
|
@@ -18,7 +18,14 @@ Classifier: Programming Language :: Python :: 3.11
|
|
|
18
18
|
Classifier: Programming Language :: Python :: 3.12
|
|
19
19
|
Requires-Python: >=3.9
|
|
20
20
|
Description-Content-Type: text/markdown
|
|
21
|
-
License-File: LICENCE
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: imageio
|
|
23
|
+
Requires-Dist: imagecodecs
|
|
24
|
+
Requires-Dist: matplotlib
|
|
25
|
+
Requires-Dist: numpy
|
|
26
|
+
Requires-Dist: tensorflow
|
|
27
|
+
Requires-Dist: healpy
|
|
28
|
+
Requires-Dist: spherical
|
|
22
29
|
|
|
23
30
|
# foscat
|
|
24
31
|
|
|
@@ -12,25 +12,23 @@ class CircSpline:
|
|
|
12
12
|
"""
|
|
13
13
|
self.degree = degree
|
|
14
14
|
self.nodes = nodes
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def cubic_spline_function(self,x):
|
|
15
|
+
|
|
16
|
+
def cubic_spline_function(self, x):
|
|
18
17
|
"""
|
|
19
18
|
Evaluate the cubic spline basis function.
|
|
20
|
-
|
|
19
|
+
|
|
21
20
|
Args:
|
|
22
21
|
x (float or array): Input value(s) to evaluate the spline basis function.
|
|
23
|
-
|
|
22
|
+
|
|
24
23
|
Returns:
|
|
25
24
|
float or array: Result of the cubic spline basis function.
|
|
26
25
|
"""
|
|
27
26
|
return -2 * x**3 + 3 * x**2
|
|
28
27
|
|
|
29
|
-
|
|
30
|
-
def eval(self,x):
|
|
28
|
+
def eval(self, x):
|
|
31
29
|
"""
|
|
32
30
|
Compute a 3rd-degree cubic spline with 4-point support.
|
|
33
|
-
|
|
31
|
+
|
|
34
32
|
Args:
|
|
35
33
|
x (float or array): Input value(s) to compute the spline.
|
|
36
34
|
|
|
@@ -38,8 +36,8 @@ class CircSpline:
|
|
|
38
36
|
indices (array): Indices of the spline support points.
|
|
39
37
|
coefficients (array): Normalized spline coefficients.
|
|
40
38
|
"""
|
|
41
|
-
N=self.nodes
|
|
42
|
-
|
|
39
|
+
N = self.nodes
|
|
40
|
+
|
|
43
41
|
if isinstance(x, float):
|
|
44
42
|
# Single scalar input
|
|
45
43
|
base_idx = int(x * (N))
|
|
@@ -61,10 +59,10 @@ class CircSpline:
|
|
|
61
59
|
coefficients[0] = self.cubic_spline_function(0.5 - fractional_part / 2) / 2
|
|
62
60
|
|
|
63
61
|
# Assign indices for the support points
|
|
64
|
-
indices[3] = (base_idx + 2+N)%N
|
|
65
|
-
indices[2] = (base_idx + 1+N)%N
|
|
66
|
-
indices[1] = (base_idx + N
|
|
67
|
-
indices[0] = (base_idx + N-1)%N
|
|
62
|
+
indices[3] = (base_idx + 2 + N) % N
|
|
63
|
+
indices[2] = (base_idx + 1 + N) % N
|
|
64
|
+
indices[1] = (base_idx + N) % N
|
|
65
|
+
indices[0] = (base_idx + N - 1) % N
|
|
68
66
|
|
|
69
67
|
# Square coefficients and normalize
|
|
70
68
|
coefficients = coefficients * coefficients
|
|
@@ -72,11 +70,10 @@ class CircSpline:
|
|
|
72
70
|
|
|
73
71
|
return indices, coefficients
|
|
74
72
|
|
|
75
|
-
|
|
76
|
-
def eval_N(self,x,N):
|
|
73
|
+
def eval_N(self, x, N):
|
|
77
74
|
"""
|
|
78
75
|
Compute a 3rd-degree cubic spline with 4-point support.
|
|
79
|
-
|
|
76
|
+
|
|
80
77
|
Args:
|
|
81
78
|
x (float or array): Input value(s) to compute the spline.
|
|
82
79
|
|
|
@@ -84,7 +81,7 @@ class CircSpline:
|
|
|
84
81
|
indices (array): Indices of the spline support points.
|
|
85
82
|
coefficients (array): Normalized spline coefficients.
|
|
86
83
|
"""
|
|
87
|
-
|
|
84
|
+
|
|
88
85
|
if isinstance(x, float):
|
|
89
86
|
# Single scalar input
|
|
90
87
|
base_idx = int(x * (N))
|
|
@@ -106,10 +103,10 @@ class CircSpline:
|
|
|
106
103
|
coefficients[0] = self.cubic_spline_function(0.5 - fractional_part / 2) / 2
|
|
107
104
|
|
|
108
105
|
# Assign indices for the support points
|
|
109
|
-
indices[3] = (base_idx + 2+N)%N
|
|
110
|
-
indices[2] = (base_idx + 1+N)%N
|
|
111
|
-
indices[1] = (base_idx + N
|
|
112
|
-
indices[0] =(
|
|
106
|
+
indices[3] = (base_idx + 2 + N) % N
|
|
107
|
+
indices[2] = (base_idx + 1 + N) % N
|
|
108
|
+
indices[1] = (base_idx + N) % N
|
|
109
|
+
indices[0] = (base_idx + N - 1) % N
|
|
113
110
|
|
|
114
111
|
# Adjust indices to start from 0
|
|
115
112
|
# Square coefficients and normalize
|
|
@@ -12,33 +12,32 @@ TMPFILE_VERSION = "V4_0"
|
|
|
12
12
|
|
|
13
13
|
class FoCUS:
|
|
14
14
|
def __init__(
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
mpi_rank=0,
|
|
15
|
+
self,
|
|
16
|
+
NORIENT=4,
|
|
17
|
+
LAMBDA=1.2,
|
|
18
|
+
KERNELSZ=3,
|
|
19
|
+
slope=1.0,
|
|
20
|
+
all_type="float32",
|
|
21
|
+
nstep_max=16,
|
|
22
|
+
padding="SAME",
|
|
23
|
+
gpupos=0,
|
|
24
|
+
mask_thres=None,
|
|
25
|
+
mask_norm=False,
|
|
26
|
+
isMPI=False,
|
|
27
|
+
TEMPLATE_PATH="data",
|
|
28
|
+
BACKEND="tensorflow",
|
|
29
|
+
use_2D=False,
|
|
30
|
+
use_1D=False,
|
|
31
|
+
return_data=False,
|
|
32
|
+
JmaxDelta=0,
|
|
33
|
+
DODIV=False,
|
|
34
|
+
InitWave=None,
|
|
35
|
+
silent=True,
|
|
36
|
+
mpi_size=1,
|
|
37
|
+
mpi_rank=0,
|
|
39
38
|
):
|
|
40
39
|
|
|
41
|
-
self.__version__ = "3.
|
|
40
|
+
self.__version__ = "3.7.1"
|
|
42
41
|
# P00 coeff for normalization for scat_cov
|
|
43
42
|
self.TMPFILE_VERSION = TMPFILE_VERSION
|
|
44
43
|
self.P1_dic = None
|
|
@@ -47,7 +46,7 @@ class FoCUS:
|
|
|
47
46
|
self.mask_thres = mask_thres
|
|
48
47
|
self.mask_norm = mask_norm
|
|
49
48
|
self.InitWave = InitWave
|
|
50
|
-
|
|
49
|
+
self.mask_mask=None
|
|
51
50
|
self.mpi_size = mpi_size
|
|
52
51
|
self.mpi_rank = mpi_rank
|
|
53
52
|
self.return_data = return_data
|
|
@@ -83,18 +82,11 @@ class FoCUS:
|
|
|
83
82
|
self.nlog = 0
|
|
84
83
|
self.padding = padding
|
|
85
84
|
|
|
86
|
-
if
|
|
85
|
+
if JmaxDelta != 0:
|
|
87
86
|
if not self.silent:
|
|
88
87
|
print(
|
|
89
|
-
"OPTION
|
|
88
|
+
"OPTION JmaxDelta is not avialable anymore after version 3.6.2. Please use Jmax option in eval function"
|
|
90
89
|
)
|
|
91
|
-
JmaxDelta = OSTEP
|
|
92
|
-
else:
|
|
93
|
-
OSTEP = JmaxDelta
|
|
94
|
-
|
|
95
|
-
if JmaxDelta < -1:
|
|
96
|
-
if not self.silent:
|
|
97
|
-
print("Warning : Jmax can not be smaller than -1")
|
|
98
90
|
return None
|
|
99
91
|
|
|
100
92
|
self.OSTEP = JmaxDelta
|
|
@@ -163,6 +155,9 @@ class FoCUS:
|
|
|
163
155
|
self.Y_CNN = {}
|
|
164
156
|
self.Z_CNN = {}
|
|
165
157
|
|
|
158
|
+
self.filters_set={}
|
|
159
|
+
self.edge_masks={}
|
|
160
|
+
|
|
166
161
|
wwc = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
|
|
167
162
|
wws = np.zeros([KERNELSZ**2, l_NORIENT]).astype(all_type)
|
|
168
163
|
|
|
@@ -197,13 +192,15 @@ class FoCUS:
|
|
|
197
192
|
w_smooth = w_smooth.flatten()
|
|
198
193
|
else:
|
|
199
194
|
for i in range(NORIENT):
|
|
200
|
-
a = i / float(NORIENT) * np.pi
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
195
|
+
a = (NORIENT-1-i) / float(NORIENT) * np.pi # get the same angle number than scattering lib
|
|
196
|
+
if KERNELSZ<5:
|
|
197
|
+
xx = (3 / float(KERNELSZ)) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
|
|
198
|
+
yy = (3 / float(KERNELSZ)) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
|
|
199
|
+
else:
|
|
200
|
+
xx = (3 /5) * LAMBDA * (x * np.cos(a) + y * np.sin(a))
|
|
201
|
+
yy = (3 /5) * LAMBDA * (x * np.sin(a) - y * np.cos(a))
|
|
204
202
|
if KERNELSZ == 5:
|
|
205
|
-
|
|
206
|
-
w_smooth = np.exp(-(xx**2 + yy**2))
|
|
203
|
+
w_smooth=np.exp(-2*((3.0/float(KERNELSZ)*xx)**2+(3.0/float(KERNELSZ)*yy)**2))
|
|
207
204
|
else:
|
|
208
205
|
w_smooth = np.exp(-0.5 * (xx**2 + yy**2))
|
|
209
206
|
tmp1 = np.cos(yy * np.pi) * w_smooth
|
|
@@ -211,7 +208,8 @@ class FoCUS:
|
|
|
211
208
|
|
|
212
209
|
wwc[:, i] = tmp1.flatten() - tmp1.mean()
|
|
213
210
|
wws[:, i] = tmp2.flatten() - tmp2.mean()
|
|
214
|
-
sigma = np.sqrt((wwc[:, i] ** 2).mean())
|
|
211
|
+
#sigma = np.sqrt((wwc[:, i] ** 2).mean())
|
|
212
|
+
sigma = np.mean(w_smooth)
|
|
215
213
|
wwc[:, i] /= sigma
|
|
216
214
|
wws[:, i] /= sigma
|
|
217
215
|
|
|
@@ -224,7 +222,8 @@ class FoCUS:
|
|
|
224
222
|
|
|
225
223
|
wwc[:, NORIENT] = tmp1.flatten() - tmp1.mean()
|
|
226
224
|
wws[:, NORIENT] = tmp2.flatten() - tmp2.mean()
|
|
227
|
-
sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
|
|
225
|
+
#sigma = np.sqrt((wwc[:, NORIENT] ** 2).mean())
|
|
226
|
+
sigma = np.mean(w_smooth)
|
|
228
227
|
|
|
229
228
|
wwc[:, NORIENT] /= sigma
|
|
230
229
|
wws[:, NORIENT] /= sigma
|
|
@@ -233,11 +232,13 @@ class FoCUS:
|
|
|
233
232
|
|
|
234
233
|
wwc[:, NORIENT + 1] = tmp1.flatten() - tmp1.mean()
|
|
235
234
|
wws[:, NORIENT + 1] = tmp2.flatten() - tmp2.mean()
|
|
236
|
-
sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
|
|
235
|
+
#sigma = np.sqrt((wwc[:, NORIENT + 1] ** 2).mean())
|
|
236
|
+
sigma = np.mean(w_smooth)
|
|
237
237
|
wwc[:, NORIENT + 1] /= sigma
|
|
238
238
|
wws[:, NORIENT + 1] /= sigma
|
|
239
239
|
|
|
240
240
|
w_smooth = w_smooth.flatten()
|
|
241
|
+
|
|
241
242
|
if self.use_1D:
|
|
242
243
|
KERNELSZ = 5
|
|
243
244
|
|
|
@@ -1776,14 +1777,14 @@ class FoCUS:
|
|
|
1776
1777
|
if self.padding == "VALID":
|
|
1777
1778
|
l_mask = l_mask[
|
|
1778
1779
|
:,
|
|
1779
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1780
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1780
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1781
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1781
1782
|
]
|
|
1782
1783
|
if shape[axis] != l_mask.shape[1]:
|
|
1783
1784
|
l_mask = l_mask[
|
|
1784
1785
|
:,
|
|
1785
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1786
|
-
self.KERNELSZ // 2 : -self.KERNELSZ // 2
|
|
1786
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1787
|
+
self.KERNELSZ // 2 : -self.KERNELSZ // 2+1,
|
|
1787
1788
|
]
|
|
1788
1789
|
|
|
1789
1790
|
ichannel = 1
|
|
@@ -1850,8 +1851,12 @@ class FoCUS:
|
|
|
1850
1851
|
l_mask = self.backend.bk_complex(l_mask, self.backend.bk_cast(0.0 * l_mask))
|
|
1851
1852
|
|
|
1852
1853
|
if self.use_2D:
|
|
1854
|
+
#if self.padding == "VALID":
|
|
1853
1855
|
mtmp = l_mask
|
|
1854
1856
|
vtmp = l_x
|
|
1857
|
+
#else:
|
|
1858
|
+
# mtmp = l_mask[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
1859
|
+
# vtmp = l_x[:,self.KERNELSZ // 2 : -self.KERNELSZ // 2,self.KERNELSZ // 2 : -self.KERNELSZ // 2,:]
|
|
1855
1860
|
|
|
1856
1861
|
v1 = self.backend.bk_reduce_sum(
|
|
1857
1862
|
self.backend.bk_reduce_sum(mtmp * vtmp, axis=2), 2
|
|
@@ -2684,7 +2689,12 @@ class FoCUS:
|
|
|
2684
2689
|
|
|
2685
2690
|
# ---------------------------------------------−---------
|
|
2686
2691
|
def get_ww(self, nside=1):
|
|
2687
|
-
|
|
2692
|
+
if self.use_2D:
|
|
2693
|
+
|
|
2694
|
+
return (self.ww_RealT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT),
|
|
2695
|
+
self.ww_ImagT[1].reshape(self.KERNELSZ*self.KERNELSZ,self.NORIENT))
|
|
2696
|
+
else:
|
|
2697
|
+
return (self.ww_Real[nside], self.ww_Imag[nside])
|
|
2688
2698
|
|
|
2689
2699
|
# ---------------------------------------------−---------
|
|
2690
2700
|
def plot_ww(self):
|
|
@@ -2696,11 +2706,11 @@ class FoCUS:
|
|
|
2696
2706
|
for i in range(c.shape[1]):
|
|
2697
2707
|
plt.subplot(2, c.shape[1], 1 + i)
|
|
2698
2708
|
plt.imshow(
|
|
2699
|
-
c[:, i].reshape(npt, npt), cmap="
|
|
2709
|
+
c[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max()
|
|
2700
2710
|
)
|
|
2701
2711
|
plt.subplot(2, c.shape[1], 1 + i + c.shape[1])
|
|
2702
2712
|
plt.imshow(
|
|
2703
|
-
s[:, i].reshape(npt, npt), cmap="
|
|
2713
|
+
s[:, i].reshape(npt, npt), cmap="viridis", vmin=-c.max(), vmax=c.max()
|
|
2704
2714
|
)
|
|
2705
2715
|
sys.stdout.flush()
|
|
2706
2716
|
plt.show()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
|
|
3
|
+
|
|
3
4
|
class Spline1D:
|
|
4
5
|
def __init__(self, nodes, degree=3):
|
|
5
6
|
"""
|
|
@@ -11,25 +12,23 @@ class Spline1D:
|
|
|
11
12
|
"""
|
|
12
13
|
self.degree = degree
|
|
13
14
|
self.nodes = nodes
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def cubic_spline_function(self,x):
|
|
15
|
+
|
|
16
|
+
def cubic_spline_function(self, x):
|
|
17
17
|
"""
|
|
18
18
|
Evaluate the cubic spline basis function.
|
|
19
|
-
|
|
19
|
+
|
|
20
20
|
Args:
|
|
21
21
|
x (float or array): Input value(s) to evaluate the spline basis function.
|
|
22
|
-
|
|
22
|
+
|
|
23
23
|
Returns:
|
|
24
24
|
float or array: Result of the cubic spline basis function.
|
|
25
25
|
"""
|
|
26
26
|
return -2 * x**3 + 3 * x**2
|
|
27
27
|
|
|
28
|
-
|
|
29
|
-
def eval(self,x):
|
|
28
|
+
def eval(self, x):
|
|
30
29
|
"""
|
|
31
30
|
Compute a 3rd-degree cubic spline with 4-point support.
|
|
32
|
-
|
|
31
|
+
|
|
33
32
|
Args:
|
|
34
33
|
x (float or array): Input value(s) to compute the spline.
|
|
35
34
|
|
|
@@ -37,21 +36,21 @@ class Spline1D:
|
|
|
37
36
|
indices (array): Indices of the spline support points.
|
|
38
37
|
coefficients (array): Normalized spline coefficients.
|
|
39
38
|
"""
|
|
40
|
-
N=self.nodes
|
|
41
|
-
|
|
39
|
+
N = self.nodes
|
|
40
|
+
|
|
42
41
|
if isinstance(x, float):
|
|
43
42
|
# Single scalar input
|
|
44
|
-
base_idx = int(x * (N-1))
|
|
43
|
+
base_idx = int(x * (N - 1))
|
|
45
44
|
indices = np.zeros([4], dtype="int")
|
|
46
45
|
coefficients = np.zeros([4])
|
|
47
46
|
else:
|
|
48
47
|
# Array input
|
|
49
|
-
base_idx = (x * (N-1)).astype("int")
|
|
48
|
+
base_idx = (x * (N - 1)).astype("int")
|
|
50
49
|
indices = np.zeros([4, x.shape[0]], dtype="int")
|
|
51
50
|
coefficients = np.zeros([4, x.shape[0]])
|
|
52
51
|
|
|
53
52
|
# Compute the fractional part of the input
|
|
54
|
-
fractional_part = x * (N-1) - base_idx
|
|
53
|
+
fractional_part = x * (N - 1) - base_idx
|
|
55
54
|
|
|
56
55
|
# Compute spline coefficients for 4 support points
|
|
57
56
|
coefficients[3] = self.cubic_spline_function(fractional_part / 2) / 2
|
|
@@ -23,6 +23,7 @@ class Loss:
|
|
|
23
23
|
|
|
24
24
|
self.loss_function = function
|
|
25
25
|
self.scat_operator = scat_operator
|
|
26
|
+
self.to_numpy = scat_operator.backend.to_numpy
|
|
26
27
|
self.args = param
|
|
27
28
|
self.name = name
|
|
28
29
|
self.batch = batch
|
|
@@ -47,12 +48,13 @@ class Loss:
|
|
|
47
48
|
else:
|
|
48
49
|
return self.loss_function(x, batch, self.scat_operator, self.args)
|
|
49
50
|
|
|
50
|
-
def set_id_loss(self,id_loss):
|
|
51
|
+
def set_id_loss(self, id_loss):
|
|
51
52
|
self.id_loss = id_loss
|
|
52
|
-
|
|
53
|
-
def get_id_loss(self,id_loss):
|
|
53
|
+
|
|
54
|
+
def get_id_loss(self, id_loss):
|
|
54
55
|
return self.id_loss
|
|
55
|
-
|
|
56
|
+
|
|
57
|
+
|
|
56
58
|
class Synthesis:
|
|
57
59
|
def __init__(
|
|
58
60
|
self,
|
|
@@ -66,10 +68,10 @@ class Synthesis:
|
|
|
66
68
|
|
|
67
69
|
self.loss_class = loss_list
|
|
68
70
|
self.number_of_loss = len(loss_list)
|
|
69
|
-
|
|
71
|
+
|
|
70
72
|
for k in range(self.number_of_loss):
|
|
71
73
|
self.loss_class[k].set_id_loss(k)
|
|
72
|
-
|
|
74
|
+
|
|
73
75
|
self.__iteration__ = 1234
|
|
74
76
|
self.nlog = 0
|
|
75
77
|
self.m_dw, self.v_dw = 0.0, 0.0
|
|
@@ -83,6 +85,7 @@ class Synthesis:
|
|
|
83
85
|
self.curr_gpu = 0
|
|
84
86
|
self.event = Event()
|
|
85
87
|
self.operation = loss_list[0].scat_operator
|
|
88
|
+
self.to_numpy = self.operation.backend.to_numpy
|
|
86
89
|
self.mpi_size = self.operation.mpi_size
|
|
87
90
|
self.mpi_rank = self.operation.mpi_rank
|
|
88
91
|
self.KEEP_TRACK = None
|
|
@@ -222,24 +225,24 @@ class Synthesis:
|
|
|
222
225
|
else:
|
|
223
226
|
g_tot = g_tot + g
|
|
224
227
|
|
|
225
|
-
l_tot = l_tot +
|
|
228
|
+
l_tot = l_tot + self.to_numpy(l_loss)
|
|
226
229
|
|
|
227
230
|
if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
|
|
228
231
|
self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
|
|
229
|
-
|
|
232
|
+
self.to_numpy(l_loss) / nstep
|
|
230
233
|
)
|
|
231
234
|
else:
|
|
232
235
|
self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
|
|
233
236
|
self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
|
|
234
|
-
+
|
|
237
|
+
+ self.to_numpy(l_loss) / nstep
|
|
235
238
|
)
|
|
236
239
|
|
|
237
240
|
grd_mask = self.grd_mask
|
|
238
241
|
|
|
239
242
|
if grd_mask is not None:
|
|
240
|
-
g_tot = grd_mask *
|
|
243
|
+
g_tot = grd_mask * self.to_numpy(g_tot)
|
|
241
244
|
else:
|
|
242
|
-
g_tot =
|
|
245
|
+
g_tot = self.to_numpy(g_tot)
|
|
243
246
|
|
|
244
247
|
g_tot[np.isnan(g_tot)] = 0.0
|
|
245
248
|
|
|
@@ -389,7 +392,7 @@ class Synthesis:
|
|
|
389
392
|
self.oshape = list(x.shape)
|
|
390
393
|
|
|
391
394
|
if not isinstance(x, np.ndarray):
|
|
392
|
-
x =
|
|
395
|
+
x = self.to_numpy(x)
|
|
393
396
|
|
|
394
397
|
x = x.flatten()
|
|
395
398
|
|
|
@@ -423,7 +426,7 @@ class Synthesis:
|
|
|
423
426
|
factr=factr,
|
|
424
427
|
maxiter=maxitt,
|
|
425
428
|
)
|
|
426
|
-
|
|
429
|
+
print('Final Loss ',loss)
|
|
427
430
|
# update bias input data
|
|
428
431
|
if iteration < NUM_STEP_BIAS - 1:
|
|
429
432
|
# if self.mpi_rank==0:
|