foscat 3.0.9__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/Synthesis.py CHANGED
@@ -1,55 +1,77 @@
-import numpy as np
-import time
-import sys
 import os
-from datetime import datetime
-from packaging import version
-from threading import Thread
-from threading import Event
+import sys
+import time
+from threading import Event, Thread
+
+import numpy as np
 import scipy.optimize as opt
 
+
 class Loss:
-
-    def __init__(self,function,scat_operator,*param,
-                 name='',
-                 batch=None,
-                 batch_data=None,
-                 batch_update=None,
-                 info_callback=False):
-
-        self.loss_function=function
-        self.scat_operator=scat_operator
-        self.args=param
-        self.name=name
-        self.batch=batch
-        self.batch_data=batch_data
-        self.batch_update=batch_update
-        self.info=info_callback
-
-    def eval(self,x,batch,return_all=False):
+
+    def __init__(
+        self,
+        function,
+        scat_operator,
+        *param,
+        name="",
+        batch=None,
+        batch_data=None,
+        batch_update=None,
+        info_callback=False,
+    ):
+
+        self.loss_function = function
+        self.scat_operator = scat_operator
+        self.args = param
+        self.name = name
+        self.batch = batch
+        self.batch_data = batch_data
+        self.batch_update = batch_update
+        self.info = info_callback
+        self.id_loss = 0
+
+    def eval(self, x, batch, return_all=False):
         if self.batch is None:
             if self.info:
-                return self.loss_function(x,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,self.scat_operator,self.args)
+                return self.loss_function(x, self.scat_operator, self.args)
         else:
             if self.info:
-                return self.loss_function(x,batch,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, batch, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,batch,self.scat_operator,self.args)
+                return self.loss_function(x, batch, self.scat_operator, self.args)
+
+    def set_id_loss(self,id_loss):
+        self.id_loss = id_loss
 
+    def get_id_loss(self,id_loss):
+        return self.id_loss
+
 class Synthesis:
-    def __init__(self,
-                 loss_list,
-                 eta=0.03,
-                 beta1=0.9,
-                 beta2=0.999,
-                 epsilon=1e-7,
-                 decay_rate = 0.999):
-
-        self.loss_class=loss_list
-        self.number_of_loss=len(loss_list)
-        self.nlog=0
+    def __init__(
+        self,
+        loss_list,
+        eta=0.03,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=1e-7,
+        decay_rate=0.999,
+    ):
+
+        self.loss_class = loss_list
+        self.number_of_loss = len(loss_list)
+
+        for k in range(self.number_of_loss):
+            self.loss_class[k].set_id_loss(k)
+
+        self.__iteration__ = 1234
+        self.nlog = 0
         self.m_dw, self.v_dw = 0.0, 0.0
         self.beta1 = beta1
         self.beta2 = beta2
@@ -57,326 +79,377 @@ class Synthesis:
         self.pbeta2 = beta2
         self.epsilon = epsilon
         self.eta = eta
-        self.history=np.zeros([10])
-        self.curr_gpu=0
+        self.history = np.zeros([10])
+        self.curr_gpu = 0
         self.event = Event()
-        self.operation=loss_list[0].scat_operator
-        self.mpi_size=self.operation.mpi_size
-        self.mpi_rank=self.operation.mpi_rank
-        self.KEEP_TRACK=None
-        self.MAXNUMLOSS=len(loss_list)
-
-        if self.operation.BACKEND=='tensorflow':
+        self.operation = loss_list[0].scat_operator
+        self.mpi_size = self.operation.mpi_size
+        self.mpi_rank = self.operation.mpi_rank
+        self.KEEP_TRACK = None
+        self.MAXNUMLOSS = len(loss_list)
+
+        if self.operation.BACKEND == "tensorflow":
             import foscat.loss_backend_tens as fbk
-            self.bk=fbk.loss_backend(self.operation,self.curr_gpu,self.mpi_rank)
-
-        if self.operation.BACKEND=='torch':
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "torch":
             import foscat.loss_backend_torch as fbk
-            self.bk=fbk.loss_backend(self.operation,self.curr_gpu,self.mpi_rank)
-
-        if self.operation.BACKEND=='numpy':
-            print('Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend')
-            exit(0)
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "numpy":
+            print(
+                "Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend"
+            )
+            return None
 
     # ---------------------------------------------−---------
-    def get_gpu(self,event,delay):
+    def get_gpu(self, event, delay):
 
-        isnvidia=os.system('which nvidia-smi &> /dev/null')
+        isnvidia = os.system("which nvidia-smi &> /dev/null")
 
-        while (1):
+        while 1:
             if event.is_set():
                 break
             time.sleep(delay)
-            if isnvidia==0:
+            if isnvidia == 0:
                 try:
-                    os.system("nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt")
+                    os.system(
+                        "nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt"
+                    )
                 except:
-                    nogpu=1
-
+                    print("No nvidia GPU: Impossible to trace")
+                    self.nogpu = 1
+
    def stop_synthesis(self):
        # stop thread that catch GPU information
        self.event.set()
-
+
        try:
            self.gpu_thrd.join()
        except:
-            print('No thread to stop, everything is ok')
+            print("No thread to stop, everything is ok")
            sys.stdout.flush()
-
+
    # ---------------------------------------------−---------
    def getgpumem(self):
        try:
-            return np.loadtxt('smi_tmp.txt')
+            return np.loadtxt("smi_tmp.txt")
        except:
-            return(np.zeros([1,3]))
-
+            return np.zeros([1, 3])
+
    # ---------------------------------------------−---------
-    def info_back(self,x):
-
-        self.nlog=self.nlog+1
-        self.itt2=0
-
-        if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
+    def info_back(self, x):
+
+        self.nlog = self.nlog + 1
+        self.itt2 = 0
+
+        if self.itt % self.EVAL_FREQUENCY == 0 and self.mpi_rank == 0:
            end = time.time()
-            cur_loss='%10.3g ('%(self.ltot[self.ltot!=-1].mean())
-            for k in self.ltot[self.ltot!=-1]:
-                cur_loss=cur_loss+'%10.3g '%(k)
-
-            cur_loss=cur_loss+')'
-
-            mess=''
-
+            cur_loss = "%10.3g (" % (self.ltot[self.ltot != -1].mean())
+            for k in self.ltot[self.ltot != -1]:
+                cur_loss = cur_loss + "%10.3g " % (k)
+
+            cur_loss = cur_loss + ")"
+
+            mess = ""
+
            if self.SHOWGPU:
-                info_gpu=self.getgpumem()
+                info_gpu = self.getgpumem()
                for k in range(info_gpu.shape[0]):
-                    mess=mess+'[GPU%d %.0f/%.0f MB %.0f%%]'%(k,info_gpu[k,0],info_gpu[k,1],info_gpu[k,2])
-
-            print('%sItt %6d L=%s %.3fs %s'%(self.MESSAGE,self.itt,cur_loss,(end-self.start),mess))
+                    mess = mess + "[GPU%d %.0f/%.0f MB %.0f%%]" % (
+                        k,
+                        info_gpu[k, 0],
+                        info_gpu[k, 1],
+                        info_gpu[k, 2],
+                    )
+
+            print(
+                "%sItt %6d L=%s %.3fs %s"
+                % (self.MESSAGE, self.itt, cur_loss, (end - self.start), mess)
+            )
            sys.stdout.flush()
            if self.KEEP_TRACK is not None:
                print(self.last_info)
                sys.stdout.flush()
-
+
            self.start = time.time()
-
-        self.itt=self.itt+1
-
+
+        self.itt = self.itt + 1
+
    # ---------------------------------------------−---------
-    def calc_grad(self,in_x):
-
-        g_tot=None
-        l_tot=0.0
+    def calc_grad(self, in_x):
+
+        g_tot = None
+        l_tot = 0.0
 
-        if self.do_all_noise and self.totalsz>self.batchsz:
-            nstep=self.totalsz//self.batchsz
+        if self.do_all_noise and self.totalsz > self.batchsz:
+            nstep = self.totalsz // self.batchsz
        else:
-            nstep=1
+            nstep = 1
+
+        x = self.operation.backend.bk_reshape(
+            self.operation.backend.bk_cast(in_x), self.oshape
+        )
+
+        self.l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
 
-        x=self.operation.backend.bk_reshape(self.operation.backend.bk_cast(in_x),self.oshape)
-
-        self.l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-
        for istep in range(nstep):
-
+
            for k in range(self.number_of_loss):
                if self.loss_class[k].batch is None:
-                    l_batch=None
+                    l_batch = None
                else:
-                    l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,istep)
+                    l_batch = self.loss_class[k].batch(
+                        self.loss_class[k].batch_data, istep
+                    )
 
                if self.KEEP_TRACK is not None:
-                    l,g,linfo=self.bk.loss(x,l_batch,self.loss_class[k],self.KEEP_TRACK)
-                    self.last_info=self.KEEP_TRACK(linfo,self.mpi_rank,add=True)
+                    l_loss, g, linfo = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
+                    self.last_info = self.KEEP_TRACK(linfo, self.mpi_rank, add=True)
                else:
-                    l,g=self.bk.loss(x,l_batch,self.loss_class[k],self.KEEP_TRACK)
+                    l_loss, g = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
 
                if g_tot is None:
-                    g_tot=g
+                    g_tot = g
                else:
-                    g_tot=g_tot+g
+                    g_tot = g_tot + g
 
-                l_tot=l_tot+l.numpy()
+                l_tot = l_tot + l_loss.numpy()
 
-                if self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]==-1:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=l.numpy()/nstep
+                if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        l_loss.numpy() / nstep
+                    )
                else:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]+l.numpy()/nstep
-
-        grd_mask=self.grd_mask
-
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
+                        + l_loss.numpy() / nstep
+                    )
+
+        grd_mask = self.grd_mask
+
        if grd_mask is not None:
-            g_tot=grd_mask*g_tot.numpy()
+            g_tot = grd_mask * g_tot.numpy()
        else:
-            g_tot=g_tot.numpy()
-
-        g_tot[np.isnan(g_tot)]=0.0
+            g_tot = g_tot.numpy()
+
+        g_tot[np.isnan(g_tot)] = 0.0
 
-        self.imin=self.imin+self.batchsz
+        self.imin = self.imin + self.batchsz
 
-        if self.mpi_size==1:
-            self.ltot=self.l_log
+        if self.mpi_size == 1:
+            self.ltot = self.l_log
        else:
-            local_log=(self.l_log).astype('float64')
-            self.ltot=np.zeros(self.l_log.shape,dtype='float64')
-            self.comm.Allreduce((local_log,self.MPI.DOUBLE),(self.ltot,self.MPI.DOUBLE))
-
-        if self.mpi_size==1:
-            grad=g_tot
+            local_log = (self.l_log).astype("float64")
+            self.ltot = np.zeros(self.l_log.shape, dtype="float64")
+            self.comm.Allreduce(
+                (local_log, self.MPI.DOUBLE), (self.ltot, self.MPI.DOUBLE)
+            )
+
+        if self.mpi_size == 1:
+            grad = g_tot
        else:
-            if self.operation.backend.bk_is_complex( g_tot):
-                grad=np.zeros(self.oshape,dtype=gtot.dtype)
+            if self.operation.backend.bk_is_complex(g_tot):
+                grad = np.zeros(self.oshape, dtype=g_tot.dtype)
 
-                self.comm.Allreduce((g_tot),(grad))
+                self.comm.Allreduce((g_tot), (grad))
            else:
-                grad=np.zeros(self.oshape,dtype='float64')
+                grad = np.zeros(self.oshape, dtype="float64")
 
-                self.comm.Allreduce((g_tot.astype('float64'),self.MPI.DOUBLE),
-                                    (grad,self.MPI.DOUBLE))
-
-        if self.nlog==self.history.shape[0]:
-            new_log=np.zeros([self.history.shape[0]*2])
-            new_log[0:self.nlog]=self.history
-            self.history=new_log
+                self.comm.Allreduce(
+                    (g_tot.astype("float64"), self.MPI.DOUBLE), (grad, self.MPI.DOUBLE)
+                )
 
-        l_tot=self.ltot[self.ltot!=-1].mean()
-
-        self.history[self.nlog]=l_tot
+        if self.nlog == self.history.shape[0]:
+            new_log = np.zeros([self.history.shape[0] * 2])
+            new_log[0 : self.nlog] = self.history
+            self.history = new_log
 
-        g_tot=grad.flatten()
+        l_tot = self.ltot[self.ltot != -1].mean()
 
-        if self.operation.backend.bk_is_complex( g_tot):
-            return l_tot.astype('float64'),g_tot
-
-        return l_tot.astype('float64'),g_tot.astype('float64')
+        self.history[self.nlog] = l_tot
+
+        g_tot = grad.flatten()
+
+        if self.operation.backend.bk_is_complex(g_tot):
+            return l_tot.astype("float64"), g_tot
+
+        return l_tot.astype("float64"), g_tot.astype("float64")
 
    # ---------------------------------------------−---------
-    def xtractmap(self,x,axis):
-        x=self.operation.backend.bk_reshape(x,self.oshape)
-
+    def xtractmap(self, x, axis):
+        x = self.operation.backend.bk_reshape(x, self.oshape)
+
        return x
 
    # ---------------------------------------------−---------
-    def run(self,
-            in_x,
-            NUM_EPOCHS = 100,
-            DECAY_RATE=0.95,
-            EVAL_FREQUENCY = 100,
-            DEVAL_STAT_FREQUENCY = 1000,
-            NUM_STEP_BIAS = 1,
-            LEARNING_RATE = 0.03,
-            EPSILON = 1E-7,
-            KEEP_TRACK=None,
-            grd_mask=None,
-            SHOWGPU=False,
-            MESSAGE='',
-            factr=10.0,
-            batchsz=1,
-            totalsz=1,
-            do_lbfgs=True,
-            axis=0):
-
-        self.KEEP_TRACK=KEEP_TRACK
-        self.track={}
-        self.ntrack=0
-        self.eta=LEARNING_RATE
-        self.epsilon=EPSILON
+    def run(
+        self,
+        in_x,
+        NUM_EPOCHS=100,
+        DECAY_RATE=0.95,
+        EVAL_FREQUENCY=100,
+        DEVAL_STAT_FREQUENCY=1000,
+        NUM_STEP_BIAS=1,
+        LEARNING_RATE=0.03,
+        EPSILON=1e-7,
+        KEEP_TRACK=None,
+        grd_mask=None,
+        SHOWGPU=False,
+        MESSAGE="",
+        factr=10.0,
+        batchsz=1,
+        totalsz=1,
+        do_lbfgs=True,
+        axis=0,
+    ):
+
+        self.KEEP_TRACK = KEEP_TRACK
+        self.track = {}
+        self.ntrack = 0
+        self.eta = LEARNING_RATE
+        self.epsilon = EPSILON
        self.decay_rate = DECAY_RATE
-        self.nlog=0
-        self.itt2=0
-        self.batchsz=batchsz
-        self.totalsz=totalsz
-        self.grd_mask=grd_mask
-        self.EVAL_FREQUENCY=EVAL_FREQUENCY
-        self.MESSAGE=MESSAGE
-        self.SHOWGPU=SHOWGPU
-        self.axis=axis
-        self.in_x_nshape=in_x.shape[0]
-
-        np.random.seed(self.mpi_rank*7+1234)
-
-        x=in_x
-
-        self.curr_gpu=self.curr_gpu+self.mpi_rank
-
-        if self.mpi_size>1:
+        self.nlog = 0
+        self.itt2 = 0
+        self.batchsz = batchsz
+        self.totalsz = totalsz
+        self.grd_mask = grd_mask
+        self.EVAL_FREQUENCY = EVAL_FREQUENCY
+        self.MESSAGE = MESSAGE
+        self.SHOWGPU = SHOWGPU
+        self.axis = axis
+        self.in_x_nshape = in_x.shape[0]
+        self.seed = 1234
+
+        np.random.seed(self.mpi_rank * 7 + 1234)
+
+        x = in_x
+
+        self.curr_gpu = self.curr_gpu + self.mpi_rank
+
+        if self.mpi_size > 1:
            from mpi4py import MPI
-
 
            comm = MPI.COMM_WORLD
-            self.comm=comm
-            self.MPI=MPI
-            if self.mpi_rank==0:
-                print('Work with MPI')
+            self.comm = comm
+            self.MPI = MPI
+            if self.mpi_rank == 0:
+                print("Work with MPI")
                sys.stdout.flush()
-
-        if self.mpi_rank==0 and SHOWGPU:
+
+        if self.mpi_rank == 0 and SHOWGPU:
            # start thread that catch GPU information
            try:
-                self.gpu_thrd = Thread(target=self.get_gpu, args=(self.event,1,))
+                self.gpu_thrd = Thread(
+                    target=self.get_gpu,
+                    args=(
+                        self.event,
+                        1,
+                    ),
+                )
                self.gpu_thrd.start()
            except:
                print("Error: unable to start thread for GPU survey")
-
-        start = time.time()
-
-        if self.mpi_size>1:
-            num_loss=np.zeros([1],dtype='int32')
-            total_num_loss=np.zeros([1],dtype='int32')
-            num_loss[0]=self.number_of_loss
-            comm.Allreduce((num_loss,MPI.INT),(total_num_loss,MPI.INT))
-            total_num_loss=total_num_loss[0]
+
+        # start = time.time()
+
+        if self.mpi_size > 1:
+            num_loss = np.zeros([1], dtype="int32")
+            total_num_loss = np.zeros([1], dtype="int32")
+            num_loss[0] = self.number_of_loss
+            comm.Allreduce((num_loss, MPI.INT), (total_num_loss, MPI.INT))
+            total_num_loss = total_num_loss[0]
        else:
-            total_num_loss=self.number_of_loss
-
-        if self.mpi_rank==0:
-            print('Total number of loss ',total_num_loss)
+            total_num_loss = self.number_of_loss
+
+        if self.mpi_rank == 0:
+            print("Total number of loss ", total_num_loss)
            sys.stdout.flush()
-
-        l_log=np.zeros([self.mpi_size*self.MAXNUMLOSS],dtype='float32')
-        l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-        self.ltot=l_log.copy()
-        self.l_log=l_log
-
-        self.imin=0
-        self.start=time.time()
-        self.itt=0
-
-        self.oshape=list(x.shape)
-
-        if not isinstance(x,np.ndarray):
-            x=x.numpy()
-
-        x=x.flatten()
 
-        self.do_all_noise=False
+        l_log = np.zeros([self.mpi_size * self.MAXNUMLOSS], dtype="float32")
+        l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
+        self.ltot = l_log.copy()
+        self.l_log = l_log
 
-        self.do_all_noise=True
+        self.imin = 0
+        self.start = time.time()
+        self.itt = 0
 
-        self.noise_idx=None
+        self.oshape = list(x.shape)
 
-        for k in range(self.number_of_loss):
-            if self.loss_class[k].batch is not None:
-                l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-
-        l_tot,g_tot=self.calc_grad(x)
+        if not isinstance(x, np.ndarray):
+            x = x.numpy()
+
+        x = x.flatten()
+
+        self.do_all_noise = False
+
+        self.do_all_noise = True
+
+        self.noise_idx = None
+
+        # for k in range(self.number_of_loss):
+        #     if self.loss_class[k].batch is not None:
+        #         l_batch = self.loss_class[k].batch(
+        #             self.loss_class[k].batch_data, 0, init=True
+        #         )
+
+        l_tot, g_tot = self.calc_grad(x)
 
        self.info_back(x)
 
-        maxitt=NUM_EPOCHS
+        maxitt = NUM_EPOCHS
 
-        start_x=x.copy()
+        # start_x = x.copy()
 
        for iteration in range(NUM_STEP_BIAS):
 
-            x,l,i=opt.fmin_l_bfgs_b(self.calc_grad,
-                                    x.astype('float64'),
-                                    callback=self.info_back,
-                                    pgtol=1E-32,
-                                    factr=factr,
-                                    maxiter=maxitt)
+            x, loss, i = opt.fmin_l_bfgs_b(
+                self.calc_grad,
+                x.astype("float64"),
+                callback=self.info_back,
+                pgtol=1e-32,
+                factr=factr,
+                maxiter=maxitt,
+            )
 
            # update bias input data
-            if iteration<NUM_STEP_BIAS-1:
-                if self.mpi_rank==0:
-                    print('%s Hessian restart'%(self.MESSAGE))
+            if iteration < NUM_STEP_BIAS - 1:
+                # if self.mpi_rank==0:
+                #     print('%s Hessian restart'%(self.MESSAGE))
 
-                omap=self.xtractmap(x,axis)
+                omap = self.xtractmap(x, axis)
 
                for k in range(self.number_of_loss):
                    if self.loss_class[k].batch_update is not None:
-                        self.loss_class[k].batch_update(self.loss_class[k].batch_data,omap)
-                        l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-            #x=start_x.copy()
-
-
-
-        if self.mpi_rank==0 and SHOWGPU:
+                        self.loss_class[k].batch_update(
+                            self.loss_class[k].batch_data, omap
+                        )
+                        # if self.loss_class[k].batch is not None:
+                        #     l_batch = self.loss_class[k].batch(
+                        #         self.loss_class[k].batch_data, 0, init=True
+                        #     )
+            # x=start_x.copy()
+
+        if self.mpi_rank == 0 and SHOWGPU:
            self.stop_synthesis()
 
        if self.KEEP_TRACK is not None:
-            self.last_info=self.KEEP_TRACK(None,self.mpi_rank,add=False)
+            self.last_info = self.KEEP_TRACK(None, self.mpi_rank, add=False)
 
-        x=self.xtractmap(x,axis)
-        return(x)
+        x = self.xtractmap(x, axis)
+        return x
 
    def get_history(self):
-        return(self.history[0:self.nlog])
+        return self.history[0 : self.nlog]
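For orientation, the sketch below shows one plausible way to drive the reworked Loss/Synthesis API from this diff. Only the Loss(...) constructor, Synthesis([...]), the run(...) keyword arguments and get_history() are taken from the file above; the scattering operator built with foscat.scat_cov.funct and the loss written with eval/reduce_sum/square follow the usual foscat demo pattern and are assumptions, not something defined in Synthesis.py itself.

# Hedged usage sketch (not part of the diff): exercises the Loss/Synthesis
# classes changed above. foscat.scat_cov.funct and the eval/reduce_sum/square
# calls are assumed from typical foscat examples.
import numpy as np

import foscat.scat_cov as sc
import foscat.Synthesis as synthe


def scat_loss(x, scat_operator, args):
    # args receives the extra positional arguments given to Loss(...)
    ref = args[0]
    # distance between the scattering coefficients of x and the reference
    return scat_operator.reduce_sum(scat_operator.square(scat_operator.eval(x) - ref))


scat_op = sc.funct(NORIENT=4)             # needs a tensorflow or torch backend
target = np.random.randn(12 * 32**2)      # toy Healpix-sized map (nside=32)
ref = scat_op.eval(target)

loss = synthe.Loss(scat_loss, scat_op, ref, name="scat")
syn = synthe.Synthesis([loss])            # each loss now receives an id via set_id_loss

x0 = np.random.randn(12 * 32**2)
result = syn.run(x0, NUM_EPOCHS=100, EVAL_FREQUENCY=10)   # L-BFGS via scipy.optimize
history = syn.get_history()               # per-iteration mean loss values

The intent is only to show where the new pieces of the 3.6.0 API fit together: set_id_loss is called for you by the Synthesis constructor, run() drives scipy's fmin_l_bfgs_b on calc_grad, and get_history() returns the logged loss curve.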