foscat-3.1.6-py3-none-any.whl → foscat-3.3.0-py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- foscat/CNN.py +129 -90
- foscat/CircSpline.py +60 -36
- foscat/FoCUS.py +2216 -1500
- foscat/GCNN.py +201 -134
- foscat/Softmax.py +27 -22
- foscat/Spline1D.py +86 -36
- foscat/Synthesis.py +333 -264
- foscat/alm.py +134 -0
- foscat/backend.py +856 -683
- foscat/backend_tens.py +44 -30
- foscat/loss_backend_tens.py +48 -38
- foscat/loss_backend_torch.py +32 -58
- foscat/scat.py +1600 -1020
- foscat/scat1D.py +1230 -814
- foscat/scat2D.py +9 -7
- foscat/scat_cov.py +2867 -1766
- foscat/scat_cov1D.py +9 -7
- foscat/scat_cov2D.py +9 -7
- foscat/scat_cov_map.py +77 -51
- foscat/scat_cov_map2D.py +79 -49
- foscat-3.3.0.dist-info/LICENCE +13 -0
- foscat-3.3.0.dist-info/METADATA +183 -0
- foscat-3.3.0.dist-info/RECORD +26 -0
- {foscat-3.1.6.dist-info → foscat-3.3.0.dist-info}/WHEEL +1 -1
- foscat/GetGPUinfo.py +0 -36
- foscat/scat_cov1D.old.py +0 -1547
- foscat-3.1.6.dist-info/METADATA +0 -23
- foscat-3.1.6.dist-info/RECORD +0 -26
- {foscat-3.1.6.dist-info → foscat-3.3.0.dist-info}/top_level.txt +0 -0
foscat/Synthesis.py
CHANGED
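The hunks below show `foscat/Synthesis.py` as it changed between 3.1.6 and 3.3.0. Note that the registry viewer truncated the text of many removed (3.1.6) lines; those survive below only as bare `-` markers or fragments such as `-        self.`. The recoverable changes: the module is reformatted throughout (sorted imports, spaced operators, wrapped call arguments); `Loss.__init__` gains `batch`, `batch_data`, `batch_update`, and `info_callback` parameters plus `set_id_loss`/`get_id_loss` accessors; `Synthesis.__init__` tags each loss with its index; `calc_grad` can now loop over noise batches; and `Synthesis.run()` adds `factr`, `batchsz`, `totalsz`, and `do_lbfgs` options around the `scipy.optimize.fmin_l_bfgs_b` driver. A usage sketch follows the diff.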
```diff
@@ -1,56 +1,77 @@
-import numpy as np
-import time
-import sys
 import os
-
-
-from threading import Thread
-
+import sys
+import time
+from threading import Event, Thread
+
+import numpy as np
 import scipy.optimize as opt
 
+
 class Loss:
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-        self.
-        self.
-        self.
-
-
+
+    def __init__(
+        self,
+        function,
+        scat_operator,
+        *param,
+        name="",
+        batch=None,
+        batch_data=None,
+        batch_update=None,
+        info_callback=False,
+    ):
+
+        self.loss_function = function
+        self.scat_operator = scat_operator
+        self.args = param
+        self.name = name
+        self.batch = batch
+        self.batch_data = batch_data
+        self.batch_update = batch_update
+        self.info = info_callback
+        self.id_loss = 0
+
+    def eval(self, x, batch, return_all=False):
         if self.batch is None:
             if self.info:
-                return self.loss_function(
+                return self.loss_function(
+                    x, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,self.scat_operator,self.args)
+                return self.loss_function(x, self.scat_operator, self.args)
         else:
             if self.info:
-                return self.loss_function(
+                return self.loss_function(
+                    x, batch, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,batch,self.scat_operator,self.args)
+                return self.loss_function(x, batch, self.scat_operator, self.args)
+
+    def set_id_loss(self,id_loss):
+        self.id_loss = id_loss
 
+    def get_id_loss(self,id_loss):
+        return self.id_loss
+
 class Synthesis:
-    def __init__(
-
-
-
-
-
-
-
-
-
-        self.
-        self.
+    def __init__(
+        self,
+        loss_list,
+        eta=0.03,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=1e-7,
+        decay_rate=0.999,
+    ):
+
+        self.loss_class = loss_list
+        self.number_of_loss = len(loss_list)
+
+        for k in range(self.number_of_loss):
+            self.loss_class[k].set_id_loss(k)
+
+        self.__iteration__ = 1234
+        self.nlog = 0
         self.m_dw, self.v_dw = 0.0, 0.0
         self.beta1 = beta1
         self.beta2 = beta2
@@ -58,329 +79,377 @@ class Synthesis:
         self.pbeta2 = beta2
         self.epsilon = epsilon
         self.eta = eta
-        self.history=np.zeros([10])
-        self.curr_gpu=0
+        self.history = np.zeros([10])
+        self.curr_gpu = 0
         self.event = Event()
-        self.operation=loss_list[0].scat_operator
-        self.mpi_size=self.operation.mpi_size
-        self.mpi_rank=self.operation.mpi_rank
-        self.KEEP_TRACK=None
-        self.MAXNUMLOSS=len(loss_list)
-
-        if self.operation.BACKEND==
+        self.operation = loss_list[0].scat_operator
+        self.mpi_size = self.operation.mpi_size
+        self.mpi_rank = self.operation.mpi_rank
+        self.KEEP_TRACK = None
+        self.MAXNUMLOSS = len(loss_list)
+
+        if self.operation.BACKEND == "tensorflow":
             import foscat.loss_backend_tens as fbk
-
-
-
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "torch":
             import foscat.loss_backend_torch as fbk
-
-
-
-
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "numpy":
+            print(
+                "Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend"
+            )
             return None
 
     # ---------------------------------------------−---------
-    def get_gpu(self,event,delay):
+    def get_gpu(self, event, delay):
 
-        isnvidia=os.system(
+        isnvidia = os.system("which nvidia-smi &> /dev/null")
 
-        while
+        while 1:
             if event.is_set():
                 break
             time.sleep(delay)
-            if isnvidia==0:
+            if isnvidia == 0:
                 try:
-                    os.system(
+                    os.system(
+                        "nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt"
+                    )
                 except:
-
-
+                    print("No nvidia GPU: Impossible to trace")
+                    self.nogpu = 1
+
     def stop_synthesis(self):
         # stop thread that catch GPU information
         self.event.set()
-
+
         try:
             self.gpu_thrd.join()
         except:
-            print(
+            print("No thread to stop, everything is ok")
         sys.stdout.flush()
-
+
     # ---------------------------------------------−---------
     def getgpumem(self):
         try:
-            return np.loadtxt(
+            return np.loadtxt("smi_tmp.txt")
         except:
-            return
-
+            return np.zeros([1, 3])
+
     # ---------------------------------------------−---------
-    def info_back(self,x):
-
-        self.nlog=self.nlog+1
-        self.itt2=0
-
-        if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
+    def info_back(self, x):
+
+        self.nlog = self.nlog + 1
+        self.itt2 = 0
+
+        if self.itt % self.EVAL_FREQUENCY == 0 and self.mpi_rank == 0:
             end = time.time()
-            cur_loss=
-            for k in self.ltot[self.ltot
-                cur_loss=cur_loss+
-
-            cur_loss=cur_loss+
-
-            mess=
-
+            cur_loss = "%10.3g (" % (self.ltot[self.ltot != -1].mean())
+            for k in self.ltot[self.ltot != -1]:
+                cur_loss = cur_loss + "%10.3g " % (k)
+
+            cur_loss = cur_loss + ")"
+
+            mess = ""
+
             if self.SHOWGPU:
-                info_gpu=self.getgpumem()
+                info_gpu = self.getgpumem()
                 for k in range(info_gpu.shape[0]):
-                    mess=mess+
-
-
-
+                    mess = mess + "[GPU%d %.0f/%.0f MB %.0f%%]" % (
+                        k,
+                        info_gpu[k, 0],
+                        info_gpu[k, 1],
+                        info_gpu[k, 2],
+                    )
+
+            print(
+                "%sItt %6d L=%s %.3fs %s"
+                % (self.MESSAGE, self.itt, cur_loss, (end - self.start), mess)
+            )
             sys.stdout.flush()
             if self.KEEP_TRACK is not None:
                 print(self.last_info)
                 sys.stdout.flush()
-
+
             self.start = time.time()
-
-        self.itt=self.itt+1
+
+        self.itt = self.itt + 1
 
     # ---------------------------------------------−---------
-    def calc_grad(self,in_x):
-
-        g_tot=None
-        l_tot=0.0
+    def calc_grad(self, in_x):
 
-
-
-
+        g_tot = None
+        l_tot = 0.0
+
+        if self.do_all_noise and self.totalsz > self.batchsz:
+            nstep = self.totalsz // self.batchsz
         else:
-            nstep=1
+            nstep = 1
+
+        x = self.operation.backend.bk_reshape(
+            self.operation.backend.bk_cast(in_x), self.oshape
+        )
+
+        self.l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
 
-        x=self.operation.backend.bk_reshape(self.operation.backend.bk_cast(in_x),self.oshape)
-
-        self.l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-
         for istep in range(nstep):
-
+
             for k in range(self.number_of_loss):
                 if self.loss_class[k].batch is None:
-                    l_batch=None
+                    l_batch = None
                 else:
-                    l_batch=self.loss_class[k].batch(
+                    l_batch = self.loss_class[k].batch(
+                        self.loss_class[k].batch_data, istep
+                    )
 
                 if self.KEEP_TRACK is not None:
-
-
+                    l_loss, g, linfo = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
+                    self.last_info = self.KEEP_TRACK(linfo, self.mpi_rank, add=True)
                 else:
-
+                    l_loss, g = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
 
                 if g_tot is None:
-                    g_tot=g
+                    g_tot = g
                 else:
-                    g_tot=g_tot+g
+                    g_tot = g_tot + g
 
-                l_tot=l_tot+
+                l_tot = l_tot + l_loss.numpy()
 
-                if self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=
+                if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        l_loss.numpy() / nstep
+                    )
                 else:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=
-
-
-
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
+                        + l_loss.numpy() / nstep
+                    )
+
+        grd_mask = self.grd_mask
+
         if grd_mask is not None:
-            g_tot=grd_mask*g_tot.numpy()
+            g_tot = grd_mask * g_tot.numpy()
         else:
-            g_tot=g_tot.numpy()
-
-        g_tot[np.isnan(g_tot)]=0.0
+            g_tot = g_tot.numpy()
 
-
+        g_tot[np.isnan(g_tot)] = 0.0
 
-
-
+        self.imin = self.imin + self.batchsz
+
+        if self.mpi_size == 1:
+            self.ltot = self.l_log
         else:
-            local_log=(self.l_log).astype(
-            self.ltot=np.zeros(self.l_log.shape,dtype=
-            self.comm.Allreduce(
-
-
-
+            local_log = (self.l_log).astype("float64")
+            self.ltot = np.zeros(self.l_log.shape, dtype="float64")
+            self.comm.Allreduce(
+                (local_log, self.MPI.DOUBLE), (self.ltot, self.MPI.DOUBLE)
+            )
+
+        if self.mpi_size == 1:
+            grad = g_tot
         else:
-            if self.operation.backend.bk_is_complex(
-                grad=np.zeros(self.oshape,dtype=
+            if self.operation.backend.bk_is_complex(g_tot):
+                grad = np.zeros(self.oshape, dtype=g_tot.dtype)
 
-                self.comm.Allreduce((g_tot),(grad))
+                self.comm.Allreduce((g_tot), (grad))
             else:
-                grad=np.zeros(self.oshape,dtype=
+                grad = np.zeros(self.oshape, dtype="float64")
 
-                self.comm.Allreduce(
-
-
-        if self.nlog==self.history.shape[0]:
-            new_log=np.zeros([self.history.shape[0]*2])
-            new_log[0:self.nlog]=self.history
-            self.history=new_log
+                self.comm.Allreduce(
+                    (g_tot.astype("float64"), self.MPI.DOUBLE), (grad, self.MPI.DOUBLE)
+                )
 
-
-
-
+        if self.nlog == self.history.shape[0]:
+            new_log = np.zeros([self.history.shape[0] * 2])
+            new_log[0 : self.nlog] = self.history
+            self.history = new_log
 
-
+        l_tot = self.ltot[self.ltot != -1].mean()
 
-
-
-
-
+        self.history[self.nlog] = l_tot
+
+        g_tot = grad.flatten()
+
+        if self.operation.backend.bk_is_complex(g_tot):
+            return l_tot.astype("float64"), g_tot
+
+        return l_tot.astype("float64"), g_tot.astype("float64")
 
     # ---------------------------------------------−---------
-    def xtractmap(self,x,axis):
-        x=self.operation.backend.bk_reshape(x,self.oshape)
-
+    def xtractmap(self, x, axis):
+        x = self.operation.backend.bk_reshape(x, self.oshape)
+
         return x
 
     # ---------------------------------------------−---------
-    def run(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        self.
-        self.
-        self.
+    def run(
+        self,
+        in_x,
+        NUM_EPOCHS=100,
+        DECAY_RATE=0.95,
+        EVAL_FREQUENCY=100,
+        DEVAL_STAT_FREQUENCY=1000,
+        NUM_STEP_BIAS=1,
+        LEARNING_RATE=0.03,
+        EPSILON=1e-7,
+        KEEP_TRACK=None,
+        grd_mask=None,
+        SHOWGPU=False,
+        MESSAGE="",
+        factr=10.0,
+        batchsz=1,
+        totalsz=1,
+        do_lbfgs=True,
+        axis=0,
+    ):
+
+        self.KEEP_TRACK = KEEP_TRACK
+        self.track = {}
+        self.ntrack = 0
+        self.eta = LEARNING_RATE
+        self.epsilon = EPSILON
         self.decay_rate = DECAY_RATE
-        self.nlog=0
-        self.itt2=0
-        self.batchsz=batchsz
-        self.totalsz=totalsz
-        self.grd_mask=grd_mask
-        self.EVAL_FREQUENCY=EVAL_FREQUENCY
-        self.MESSAGE=MESSAGE
-        self.SHOWGPU=SHOWGPU
-        self.axis=axis
-        self.in_x_nshape=in_x.shape[0]
-        self.seed=1234
-
-        np.random.seed(self.mpi_rank*7+1234)
-
-        x=in_x
-
-        self.curr_gpu=self.curr_gpu+self.mpi_rank
-
-        if self.mpi_size>1:
+        self.nlog = 0
+        self.itt2 = 0
+        self.batchsz = batchsz
+        self.totalsz = totalsz
+        self.grd_mask = grd_mask
+        self.EVAL_FREQUENCY = EVAL_FREQUENCY
+        self.MESSAGE = MESSAGE
+        self.SHOWGPU = SHOWGPU
+        self.axis = axis
+        self.in_x_nshape = in_x.shape[0]
+        self.seed = 1234
+
+        np.random.seed(self.mpi_rank * 7 + 1234)
+
+        x = in_x
+
+        self.curr_gpu = self.curr_gpu + self.mpi_rank
+
+        if self.mpi_size > 1:
             from mpi4py import MPI
-
 
             comm = MPI.COMM_WORLD
-            self.comm=comm
-            self.MPI=MPI
-            if self.mpi_rank==0:
-                print(
+            self.comm = comm
+            self.MPI = MPI
+            if self.mpi_rank == 0:
+                print("Work with MPI")
                 sys.stdout.flush()
-
-        if self.mpi_rank==0 and SHOWGPU:
+
+        if self.mpi_rank == 0 and SHOWGPU:
             # start thread that catch GPU information
             try:
-                self.gpu_thrd = Thread(
+                self.gpu_thrd = Thread(
+                    target=self.get_gpu,
+                    args=(
+                        self.event,
+                        1,
+                    ),
+                )
                 self.gpu_thrd.start()
             except:
                 print("Error: unable to start thread for GPU survey")
-
-        start = time.time()
-
-        if self.mpi_size>1:
-            num_loss=np.zeros([1],dtype=
-            total_num_loss=np.zeros([1],dtype=
-            num_loss[0]=self.number_of_loss
-            comm.Allreduce((num_loss,MPI.INT),(total_num_loss,MPI.INT))
-            total_num_loss=total_num_loss[0]
+
+        # start = time.time()
+
+        if self.mpi_size > 1:
+            num_loss = np.zeros([1], dtype="int32")
+            total_num_loss = np.zeros([1], dtype="int32")
+            num_loss[0] = self.number_of_loss
+            comm.Allreduce((num_loss, MPI.INT), (total_num_loss, MPI.INT))
+            total_num_loss = total_num_loss[0]
         else:
-            total_num_loss=self.number_of_loss
-
-        if self.mpi_rank==0:
-            print(
+            total_num_loss = self.number_of_loss
+
+        if self.mpi_rank == 0:
+            print("Total number of loss ", total_num_loss)
             sys.stdout.flush()
-
-        l_log=np.zeros([self.mpi_size*self.MAXNUMLOSS],dtype='float32')
-        l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-        self.ltot=l_log.copy()
-        self.l_log=l_log
-
-        self.imin=0
-        self.start=time.time()
-        self.itt=0
-
-        self.oshape=list(x.shape)
-
-        if not isinstance(x,np.ndarray):
-            x=x.numpy()
-
-        x=x.flatten()
 
-        self.
+        l_log = np.zeros([self.mpi_size * self.MAXNUMLOSS], dtype="float32")
+        l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
+        self.ltot = l_log.copy()
+        self.l_log = l_log
 
-        self.
+        self.imin = 0
+        self.start = time.time()
+        self.itt = 0
 
-        self.
+        self.oshape = list(x.shape)
 
-
-
-
-
-
+        if not isinstance(x, np.ndarray):
+            x = x.numpy()
+
+        x = x.flatten()
+
+        self.do_all_noise = False
+
+        self.do_all_noise = True
+
+        self.noise_idx = None
+
+        # for k in range(self.number_of_loss):
+        #     if self.loss_class[k].batch is not None:
+        #         l_batch = self.loss_class[k].batch(
+        #             self.loss_class[k].batch_data, 0, init=True
+        #         )
+
+        l_tot, g_tot = self.calc_grad(x)
 
         self.info_back(x)
 
-        maxitt=NUM_EPOCHS
+        maxitt = NUM_EPOCHS
 
-        start_x=x.copy()
+        # start_x = x.copy()
 
         for iteration in range(NUM_STEP_BIAS):
-
-            x,
-
-
-
-
-
+
+            x, loss, i = opt.fmin_l_bfgs_b(
+                self.calc_grad,
+                x.astype("float64"),
+                callback=self.info_back,
+                pgtol=1e-32,
+                factr=factr,
+                maxiter=maxitt,
+            )
 
             # update bias input data
-            if iteration<NUM_STEP_BIAS-1:
-                #if self.mpi_rank==0:
+            if iteration < NUM_STEP_BIAS - 1:
+                # if self.mpi_rank==0:
                 # print('%s Hessian restart'%(self.MESSAGE))
 
-                omap=self.xtractmap(x,axis)
+                omap = self.xtractmap(x, axis)
 
                 for k in range(self.number_of_loss):
                     if self.loss_class[k].batch_update is not None:
-                        self.loss_class[k].batch_update(
-
-
-
-
-
+                        self.loss_class[k].batch_update(
+                            self.loss_class[k].batch_data, omap
+                        )
+                # if self.loss_class[k].batch is not None:
+                #     l_batch = self.loss_class[k].batch(
+                #         self.loss_class[k].batch_data, 0, init=True
+                #     )
+                # x=start_x.copy()
+
+        if self.mpi_rank == 0 and SHOWGPU:
             self.stop_synthesis()
 
         if self.KEEP_TRACK is not None:
-            self.last_info=self.KEEP_TRACK(None,self.mpi_rank,add=False)
+            self.last_info = self.KEEP_TRACK(None, self.mpi_rank, add=False)
 
-        x=self.xtractmap(x,axis)
-        return
+        x = self.xtractmap(x, axis)
+        return x
 
     def get_history(self):
-        return
+        return self.history[0 : self.nlog]
```