foscat 3.0.9-py3-none-any.whl → 3.6.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foscat/CNN.py +151 -0
- foscat/CircSpline.py +102 -34
- foscat/FoCUS.py +2363 -1052
- foscat/GCNN.py +239 -0
- foscat/Softmax.py +29 -20
- foscat/Spline1D.py +86 -36
- foscat/Synthesis.py +335 -262
- foscat/alm.py +690 -0
- foscat/alm_tools.py +11 -0
- foscat/backend.py +933 -588
- foscat/backend_tens.py +63 -0
- foscat/loss_backend_tens.py +48 -38
- foscat/loss_backend_torch.py +35 -41
- foscat/scat.py +1639 -1015
- foscat/scat1D.py +1256 -774
- foscat/scat2D.py +9 -7
- foscat/scat_cov.py +3067 -1541
- foscat/scat_cov1D.py +11 -1467
- foscat/scat_cov2D.py +9 -7
- foscat/scat_cov_map.py +77 -51
- foscat/scat_cov_map2D.py +79 -49
- foscat-3.6.0.dist-info/LICENCE +13 -0
- foscat-3.6.0.dist-info/METADATA +184 -0
- foscat-3.6.0.dist-info/RECORD +27 -0
- {foscat-3.0.9.dist-info → foscat-3.6.0.dist-info}/WHEEL +1 -1
- foscat/GetGPUinfo.py +0 -36
- foscat-3.0.9.dist-info/METADATA +0 -23
- foscat-3.0.9.dist-info/RECORD +0 -22
- {foscat-3.0.9.dist-info → foscat-3.6.0.dist-info}/top_level.txt +0 -0
foscat/Synthesis.py
CHANGED
```diff
@@ -1,55 +1,77 @@
-import numpy as np
-import time
-import sys
 import os
-
-
-from threading import Thread
-
+import sys
+import time
+from threading import Event, Thread
+
+import numpy as np
 import scipy.optimize as opt

+
 class Loss:
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-        self.
-        self.
-        self.
-
-
+
+    def __init__(
+        self,
+        function,
+        scat_operator,
+        *param,
+        name="",
+        batch=None,
+        batch_data=None,
+        batch_update=None,
+        info_callback=False,
+    ):
+
+        self.loss_function = function
+        self.scat_operator = scat_operator
+        self.args = param
+        self.name = name
+        self.batch = batch
+        self.batch_data = batch_data
+        self.batch_update = batch_update
+        self.info = info_callback
+        self.id_loss = 0
+
+    def eval(self, x, batch, return_all=False):
         if self.batch is None:
             if self.info:
-                return self.loss_function(
+                return self.loss_function(
+                    x, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,self.scat_operator,self.args)
+                return self.loss_function(x, self.scat_operator, self.args)
         else:
             if self.info:
-                return self.loss_function(
+                return self.loss_function(
+                    x, batch, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,batch,self.scat_operator,self.args)
+                return self.loss_function(x, batch, self.scat_operator, self.args)
+
+    def set_id_loss(self,id_loss):
+        self.id_loss = id_loss

+    def get_id_loss(self,id_loss):
+        return self.id_loss
+
 class Synthesis:
-    def __init__(
-
-
-
-
-
-
-
-
-
-        self.
+    def __init__(
+        self,
+        loss_list,
+        eta=0.03,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=1e-7,
+        decay_rate=0.999,
+    ):
+
+        self.loss_class = loss_list
+        self.number_of_loss = len(loss_list)
+
+        for k in range(self.number_of_loss):
+            self.loss_class[k].set_id_loss(k)
+
+        self.__iteration__ = 1234
+        self.nlog = 0
         self.m_dw, self.v_dw = 0.0, 0.0
         self.beta1 = beta1
         self.beta2 = beta2
```
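The first hunk rewrites the `Loss` wrapper: `__init__` now records the batching hooks (`batch`, `batch_data`, `batch_update`), an `info_callback` flag, and an `id_loss` index assigned by `Synthesis`, and `eval` gained a `return_all` keyword. As a sketch only (not part of this diff), a user loss compatible with the new calling convention could look like the following; the operator methods it uses (`eval`, `square`, `reduce_mean`) are assumed from foscat's published examples:

```python
# Sketch of a loss callable matching the new Loss.eval above. With
# batch=None and info_callback=False it is invoked as
# function(x, scat_operator, args), where args is the *param tuple
# captured by Loss(...).
def scat_loss(x, scat_operator, args, return_all=False):
    ref = args[0]                  # reference statistics, i.e. args = (ref,)
    learn = scat_operator.eval(x)  # statistics of the current map x
    return scat_operator.reduce_mean(scat_operator.square(ref - learn))
```

The second hunk, below, reworks the `Synthesis` driver around this interface.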
```diff
@@ -57,326 +79,377 @@ class Synthesis:
         self.pbeta2 = beta2
         self.epsilon = epsilon
         self.eta = eta
-        self.history=np.zeros([10])
-        self.curr_gpu=0
+        self.history = np.zeros([10])
+        self.curr_gpu = 0
         self.event = Event()
-        self.operation=loss_list[0].scat_operator
-        self.mpi_size=self.operation.mpi_size
-        self.mpi_rank=self.operation.mpi_rank
-        self.KEEP_TRACK=None
-        self.MAXNUMLOSS=len(loss_list)
-
-        if self.operation.BACKEND==
+        self.operation = loss_list[0].scat_operator
+        self.mpi_size = self.operation.mpi_size
+        self.mpi_rank = self.operation.mpi_rank
+        self.KEEP_TRACK = None
+        self.MAXNUMLOSS = len(loss_list)
+
+        if self.operation.BACKEND == "tensorflow":
             import foscat.loss_backend_tens as fbk
-
-
-
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "torch":
             import foscat.loss_backend_torch as fbk
-
-
-
-
-
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "numpy":
+            print(
+                "Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend"
+            )
+            return None

     # ---------------------------------------------−---------
-    def get_gpu(self,event,delay):
+    def get_gpu(self, event, delay):

-        isnvidia=os.system(
+        isnvidia = os.system("which nvidia-smi &> /dev/null")

-        while
+        while 1:
             if event.is_set():
                 break
             time.sleep(delay)
-            if isnvidia==0:
+            if isnvidia == 0:
                 try:
-                    os.system(
+                    os.system(
+                        "nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt"
+                    )
                 except:
-
-
+                    print("No nvidia GPU: Impossible to trace")
+                    self.nogpu = 1
+
     def stop_synthesis(self):
         # stop thread that catch GPU information
         self.event.set()
-
+
         try:
             self.gpu_thrd.join()
         except:
-            print(
+            print("No thread to stop, everything is ok")
         sys.stdout.flush()
-
+
     # ---------------------------------------------−---------
     def getgpumem(self):
         try:
-            return np.loadtxt(
+            return np.loadtxt("smi_tmp.txt")
         except:
-            return
-
+            return np.zeros([1, 3])
+
     # ---------------------------------------------−---------
-    def info_back(self,x):
-
-        self.nlog=self.nlog+1
-        self.itt2=0
-
-        if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
+    def info_back(self, x):
+
+        self.nlog = self.nlog + 1
+        self.itt2 = 0
+
+        if self.itt % self.EVAL_FREQUENCY == 0 and self.mpi_rank == 0:
             end = time.time()
-            cur_loss=
-            for k in self.ltot[self.ltot
-                cur_loss=cur_loss+
-
-            cur_loss=cur_loss+
-
-            mess=
-
+            cur_loss = "%10.3g (" % (self.ltot[self.ltot != -1].mean())
+            for k in self.ltot[self.ltot != -1]:
+                cur_loss = cur_loss + "%10.3g " % (k)
+
+            cur_loss = cur_loss + ")"
+
+            mess = ""
+
             if self.SHOWGPU:
-                info_gpu=self.getgpumem()
+                info_gpu = self.getgpumem()
                 for k in range(info_gpu.shape[0]):
-                    mess=mess+
-
-
+                    mess = mess + "[GPU%d %.0f/%.0f MB %.0f%%]" % (
+                        k,
+                        info_gpu[k, 0],
+                        info_gpu[k, 1],
+                        info_gpu[k, 2],
+                    )
+
+            print(
+                "%sItt %6d L=%s %.3fs %s"
+                % (self.MESSAGE, self.itt, cur_loss, (end - self.start), mess)
+            )
             sys.stdout.flush()
             if self.KEEP_TRACK is not None:
                 print(self.last_info)
                 sys.stdout.flush()
-
+
             self.start = time.time()
-
-        self.itt=self.itt+1
-
+
+        self.itt = self.itt + 1
+
     # ---------------------------------------------−---------
-    def calc_grad(self,in_x):
-
-        g_tot=None
-        l_tot=0.0
+    def calc_grad(self, in_x):
+
+        g_tot = None
+        l_tot = 0.0

-        if self.do_all_noise and self.totalsz>self.batchsz:
-            nstep=self.totalsz//self.batchsz
+        if self.do_all_noise and self.totalsz > self.batchsz:
+            nstep = self.totalsz // self.batchsz
         else:
-            nstep=1
+            nstep = 1
+
+        x = self.operation.backend.bk_reshape(
+            self.operation.backend.bk_cast(in_x), self.oshape
+        )
+
+        self.l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0

-        x=self.operation.backend.bk_reshape(self.operation.backend.bk_cast(in_x),self.oshape)
-
-        self.l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-
         for istep in range(nstep):
-
+
             for k in range(self.number_of_loss):
                 if self.loss_class[k].batch is None:
-                    l_batch=None
+                    l_batch = None
                 else:
-                    l_batch=self.loss_class[k].batch(
+                    l_batch = self.loss_class[k].batch(
+                        self.loss_class[k].batch_data, istep
+                    )

                 if self.KEEP_TRACK is not None:
-
-
+                    l_loss, g, linfo = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
+                    self.last_info = self.KEEP_TRACK(linfo, self.mpi_rank, add=True)
                 else:
-
+                    l_loss, g = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )

                 if g_tot is None:
-                    g_tot=g
+                    g_tot = g
                 else:
-                    g_tot=g_tot+g
+                    g_tot = g_tot + g

-                l_tot=l_tot+
+                l_tot = l_tot + l_loss.numpy()

-                if self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=
+                if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        l_loss.numpy() / nstep
+                    )
                 else:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=
-
-
-
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
+                        + l_loss.numpy() / nstep
+                    )
+
+        grd_mask = self.grd_mask
+
         if grd_mask is not None:
-            g_tot=grd_mask*g_tot.numpy()
+            g_tot = grd_mask * g_tot.numpy()
         else:
-            g_tot=g_tot.numpy()
-
-        g_tot[np.isnan(g_tot)]=0.0
+            g_tot = g_tot.numpy()
+
+        g_tot[np.isnan(g_tot)] = 0.0

-        self.imin=self.imin+self.batchsz
+        self.imin = self.imin + self.batchsz

-        if self.mpi_size==1:
-            self.ltot=self.l_log
+        if self.mpi_size == 1:
+            self.ltot = self.l_log
         else:
-            local_log=(self.l_log).astype(
-            self.ltot=np.zeros(self.l_log.shape,dtype=
-            self.comm.Allreduce(
-
-
-
+            local_log = (self.l_log).astype("float64")
+            self.ltot = np.zeros(self.l_log.shape, dtype="float64")
+            self.comm.Allreduce(
+                (local_log, self.MPI.DOUBLE), (self.ltot, self.MPI.DOUBLE)
+            )
+
+        if self.mpi_size == 1:
+            grad = g_tot
         else:
-            if self.operation.backend.bk_is_complex(
-                grad=np.zeros(self.oshape,dtype=
+            if self.operation.backend.bk_is_complex(g_tot):
+                grad = np.zeros(self.oshape, dtype=g_tot.dtype)

-            self.comm.Allreduce((g_tot),(grad))
+                self.comm.Allreduce((g_tot), (grad))
             else:
-                grad=np.zeros(self.oshape,dtype=
+                grad = np.zeros(self.oshape, dtype="float64")

-            self.comm.Allreduce(
-
-
-        if self.nlog==self.history.shape[0]:
-            new_log=np.zeros([self.history.shape[0]*2])
-            new_log[0:self.nlog]=self.history
-            self.history=new_log
+                self.comm.Allreduce(
+                    (g_tot.astype("float64"), self.MPI.DOUBLE), (grad, self.MPI.DOUBLE)
+                )

-
-
-
+        if self.nlog == self.history.shape[0]:
+            new_log = np.zeros([self.history.shape[0] * 2])
+            new_log[0 : self.nlog] = self.history
+            self.history = new_log

-
+        l_tot = self.ltot[self.ltot != -1].mean()

-
-
-
-
+        self.history[self.nlog] = l_tot
+
+        g_tot = grad.flatten()
+
+        if self.operation.backend.bk_is_complex(g_tot):
+            return l_tot.astype("float64"), g_tot
+
+        return l_tot.astype("float64"), g_tot.astype("float64")

     # ---------------------------------------------−---------
-    def xtractmap(self,x,axis):
-        x=self.operation.backend.bk_reshape(x,self.oshape)
-
+    def xtractmap(self, x, axis):
+        x = self.operation.backend.bk_reshape(x, self.oshape)
+
         return x

     # ---------------------------------------------−---------
-    def run(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.
-        self.
-        self.
+    def run(
+        self,
+        in_x,
+        NUM_EPOCHS=100,
+        DECAY_RATE=0.95,
+        EVAL_FREQUENCY=100,
+        DEVAL_STAT_FREQUENCY=1000,
+        NUM_STEP_BIAS=1,
+        LEARNING_RATE=0.03,
+        EPSILON=1e-7,
+        KEEP_TRACK=None,
+        grd_mask=None,
+        SHOWGPU=False,
+        MESSAGE="",
+        factr=10.0,
+        batchsz=1,
+        totalsz=1,
+        do_lbfgs=True,
+        axis=0,
+    ):
+
+        self.KEEP_TRACK = KEEP_TRACK
+        self.track = {}
+        self.ntrack = 0
+        self.eta = LEARNING_RATE
+        self.epsilon = EPSILON
         self.decay_rate = DECAY_RATE
-        self.nlog=0
-        self.itt2=0
-        self.batchsz=batchsz
-        self.totalsz=totalsz
-        self.grd_mask=grd_mask
-        self.EVAL_FREQUENCY=EVAL_FREQUENCY
-        self.MESSAGE=MESSAGE
-        self.SHOWGPU=SHOWGPU
-        self.axis=axis
-        self.in_x_nshape=in_x.shape[0]
-
-
-
-
-
-
-
-
+        self.nlog = 0
+        self.itt2 = 0
+        self.batchsz = batchsz
+        self.totalsz = totalsz
+        self.grd_mask = grd_mask
+        self.EVAL_FREQUENCY = EVAL_FREQUENCY
+        self.MESSAGE = MESSAGE
+        self.SHOWGPU = SHOWGPU
+        self.axis = axis
+        self.in_x_nshape = in_x.shape[0]
+        self.seed = 1234
+
+        np.random.seed(self.mpi_rank * 7 + 1234)
+
+        x = in_x
+
+        self.curr_gpu = self.curr_gpu + self.mpi_rank
+
+        if self.mpi_size > 1:
             from mpi4py import MPI
-

             comm = MPI.COMM_WORLD
-            self.comm=comm
-            self.MPI=MPI
-            if self.mpi_rank==0:
-                print(
+            self.comm = comm
+            self.MPI = MPI
+            if self.mpi_rank == 0:
+                print("Work with MPI")
                 sys.stdout.flush()
-
-        if self.mpi_rank==0 and SHOWGPU:
+
+        if self.mpi_rank == 0 and SHOWGPU:
             # start thread that catch GPU information
             try:
-                self.gpu_thrd = Thread(
+                self.gpu_thrd = Thread(
+                    target=self.get_gpu,
+                    args=(
+                        self.event,
+                        1,
+                    ),
+                )
                 self.gpu_thrd.start()
             except:
                 print("Error: unable to start thread for GPU survey")
-
-        start = time.time()
-
-        if self.mpi_size>1:
-            num_loss=np.zeros([1],dtype=
-            total_num_loss=np.zeros([1],dtype=
-            num_loss[0]=self.number_of_loss
-            comm.Allreduce((num_loss,MPI.INT),(total_num_loss,MPI.INT))
-            total_num_loss=total_num_loss[0]
+
+        # start = time.time()
+
+        if self.mpi_size > 1:
+            num_loss = np.zeros([1], dtype="int32")
+            total_num_loss = np.zeros([1], dtype="int32")
+            num_loss[0] = self.number_of_loss
+            comm.Allreduce((num_loss, MPI.INT), (total_num_loss, MPI.INT))
+            total_num_loss = total_num_loss[0]
         else:
-            total_num_loss=self.number_of_loss
-
-        if self.mpi_rank==0:
-            print(
+            total_num_loss = self.number_of_loss
+
+        if self.mpi_rank == 0:
+            print("Total number of loss ", total_num_loss)
             sys.stdout.flush()
-
-        l_log=np.zeros([self.mpi_size*self.MAXNUMLOSS],dtype='float32')
-        l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-        self.ltot=l_log.copy()
-        self.l_log=l_log
-
-        self.imin=0
-        self.start=time.time()
-        self.itt=0
-
-        self.oshape=list(x.shape)
-
-        if not isinstance(x,np.ndarray):
-            x=x.numpy()
-
-        x=x.flatten()

-        self.
+        l_log = np.zeros([self.mpi_size * self.MAXNUMLOSS], dtype="float32")
+        l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
+        self.ltot = l_log.copy()
+        self.l_log = l_log

-        self.
+        self.imin = 0
+        self.start = time.time()
+        self.itt = 0

-        self.
+        self.oshape = list(x.shape)

-
-
-
-
-
+        if not isinstance(x, np.ndarray):
+            x = x.numpy()
+
+        x = x.flatten()
+
+        self.do_all_noise = False
+
+        self.do_all_noise = True
+
+        self.noise_idx = None
+
+        # for k in range(self.number_of_loss):
+        #     if self.loss_class[k].batch is not None:
+        #         l_batch = self.loss_class[k].batch(
+        #             self.loss_class[k].batch_data, 0, init=True
+        #         )
+
+        l_tot, g_tot = self.calc_grad(x)

         self.info_back(x)

-        maxitt=NUM_EPOCHS
+        maxitt = NUM_EPOCHS

-        start_x=x.copy()
+        # start_x = x.copy()

         for iteration in range(NUM_STEP_BIAS):

-            x,
-
-
-
-
-
+            x, loss, i = opt.fmin_l_bfgs_b(
+                self.calc_grad,
+                x.astype("float64"),
+                callback=self.info_back,
+                pgtol=1e-32,
+                factr=factr,
+                maxiter=maxitt,
+            )

             # update bias input data
-            if iteration<NUM_STEP_BIAS-1:
-                if self.mpi_rank==0:
-
+            if iteration < NUM_STEP_BIAS - 1:
+                # if self.mpi_rank==0:
+                #     print('%s Hessian restart'%(self.MESSAGE))

-                omap=self.xtractmap(x,axis)
+                omap = self.xtractmap(x, axis)

                 for k in range(self.number_of_loss):
                     if self.loss_class[k].batch_update is not None:
-                        self.loss_class[k].batch_update(
-
-
-
-
-
+                        self.loss_class[k].batch_update(
+                            self.loss_class[k].batch_data, omap
+                        )
+                    # if self.loss_class[k].batch is not None:
+                    #     l_batch = self.loss_class[k].batch(
+                    #         self.loss_class[k].batch_data, 0, init=True
+                    #     )
+                # x=start_x.copy()
+
+        if self.mpi_rank == 0 and SHOWGPU:
             self.stop_synthesis()

         if self.KEEP_TRACK is not None:
-            self.last_info=self.KEEP_TRACK(None,self.mpi_rank,add=False)
+            self.last_info = self.KEEP_TRACK(None, self.mpi_rank, add=False)

-        x=self.xtractmap(x,axis)
-        return
+        x = self.xtractmap(x, axis)
+        return x

     def get_history(self):
-        return
+        return self.history[0 : self.nlog]
```