foscat 3.1.6-py3-none-any.whl → 3.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/Synthesis.py CHANGED
@@ -1,56 +1,77 @@
-import numpy as np
-import time
-import sys
 import os
-from datetime import datetime
-from packaging import version
-from threading import Thread
-from threading import Event
+import sys
+import time
+from threading import Event, Thread
+
+import numpy as np
 import scipy.optimize as opt

+
 class Loss:
-
-    def __init__(self,function,scat_operator,*param,
-                 name='',
-                 batch=None,
-                 batch_data=None,
-                 batch_update=None,
-                 info_callback=False):
-
-        self.loss_function=function
-        self.scat_operator=scat_operator
-        self.args=param
-        self.name=name
-        self.batch=batch
-        self.batch_data=batch_data
-        self.batch_update=batch_update
-        self.info=info_callback
-
-    def eval(self,x,batch,return_all=False):
+
+    def __init__(
+        self,
+        function,
+        scat_operator,
+        *param,
+        name="",
+        batch=None,
+        batch_data=None,
+        batch_update=None,
+        info_callback=False,
+    ):
+
+        self.loss_function = function
+        self.scat_operator = scat_operator
+        self.args = param
+        self.name = name
+        self.batch = batch
+        self.batch_data = batch_data
+        self.batch_update = batch_update
+        self.info = info_callback
+        self.id_loss = 0
+
+    def eval(self, x, batch, return_all=False):
         if self.batch is None:
             if self.info:
-                return self.loss_function(x,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,self.scat_operator,self.args)
+                return self.loss_function(x, self.scat_operator, self.args)
         else:
             if self.info:
-                return self.loss_function(x,batch,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, batch, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,batch,self.scat_operator,self.args)
+                return self.loss_function(x, batch, self.scat_operator, self.args)
+
+    def set_id_loss(self,id_loss):
+        self.id_loss = id_loss

+    def get_id_loss(self,id_loss):
+        return self.id_loss
+
 class Synthesis:
-    def __init__(self,
-                 loss_list,
-                 eta=0.03,
-                 beta1=0.9,
-                 beta2=0.999,
-                 epsilon=1e-7,
-                 decay_rate = 0.999):
-
-        self.loss_class=loss_list
-        self.number_of_loss=len(loss_list)
-        self.__iteration__=1234
-        self.nlog=0
+    def __init__(
+        self,
+        loss_list,
+        eta=0.03,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=1e-7,
+        decay_rate=0.999,
+    ):
+
+        self.loss_class = loss_list
+        self.number_of_loss = len(loss_list)
+
+        for k in range(self.number_of_loss):
+            self.loss_class[k].set_id_loss(k)
+
+        self.__iteration__ = 1234
+        self.nlog = 0
         self.m_dw, self.v_dw = 0.0, 0.0
         self.beta1 = beta1
         self.beta2 = beta2
@@ -58,329 +79,377 @@ class Synthesis:
         self.pbeta2 = beta2
         self.epsilon = epsilon
         self.eta = eta
-        self.history=np.zeros([10])
-        self.curr_gpu=0
+        self.history = np.zeros([10])
+        self.curr_gpu = 0
         self.event = Event()
-        self.operation=loss_list[0].scat_operator
-        self.mpi_size=self.operation.mpi_size
-        self.mpi_rank=self.operation.mpi_rank
-        self.KEEP_TRACK=None
-        self.MAXNUMLOSS=len(loss_list)
-
-        if self.operation.BACKEND=='tensorflow':
+        self.operation = loss_list[0].scat_operator
+        self.mpi_size = self.operation.mpi_size
+        self.mpi_rank = self.operation.mpi_rank
+        self.KEEP_TRACK = None
+        self.MAXNUMLOSS = len(loss_list)
+
+        if self.operation.BACKEND == "tensorflow":
             import foscat.loss_backend_tens as fbk
-            self.bk=fbk.loss_backend(self.operation,self.curr_gpu,self.mpi_rank)
-
-        if self.operation.BACKEND=='torch':
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "torch":
             import foscat.loss_backend_torch as fbk
-            self.bk=fbk.loss_backend(self.operation,self.curr_gpu,self.mpi_rank)
-
-        if self.operation.BACKEND=='numpy':
-            print('Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend')
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "numpy":
+            print(
+                "Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend"
+            )
             return None

     # ---------------------------------------------−---------
-    def get_gpu(self,event,delay):
+    def get_gpu(self, event, delay):

-        isnvidia=os.system('which nvidia-smi &> /dev/null')
+        isnvidia = os.system("which nvidia-smi &> /dev/null")

-        while (1):
+        while 1:
             if event.is_set():
                 break
             time.sleep(delay)
-            if isnvidia==0:
+            if isnvidia == 0:
                 try:
-                    os.system("nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt")
+                    os.system(
+                        "nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt"
+                    )
                 except:
-                    nogpu=1
-
+                    print("No nvidia GPU: Impossible to trace")
+                    self.nogpu = 1
+
     def stop_synthesis(self):
         # stop thread that catch GPU information
         self.event.set()
-
+
         try:
             self.gpu_thrd.join()
         except:
-            print('No thread to stop, everything is ok')
+            print("No thread to stop, everything is ok")
         sys.stdout.flush()
-
+
     # ---------------------------------------------−---------
     def getgpumem(self):
         try:
-            return np.loadtxt('smi_tmp.txt')
+            return np.loadtxt("smi_tmp.txt")
         except:
-            return(np.zeros([1,3]))
-
+            return np.zeros([1, 3])
+
     # ---------------------------------------------−---------
-    def info_back(self,x):
-
-        self.nlog=self.nlog+1
-        self.itt2=0
-
-        if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
+    def info_back(self, x):
+
+        self.nlog = self.nlog + 1
+        self.itt2 = 0
+
+        if self.itt % self.EVAL_FREQUENCY == 0 and self.mpi_rank == 0:
             end = time.time()
-            cur_loss='%10.3g ('%(self.ltot[self.ltot!=-1].mean())
-            for k in self.ltot[self.ltot!=-1]:
-                cur_loss=cur_loss+'%10.3g '%(k)
-
-            cur_loss=cur_loss+')'
-
-            mess=''
-
+            cur_loss = "%10.3g (" % (self.ltot[self.ltot != -1].mean())
+            for k in self.ltot[self.ltot != -1]:
+                cur_loss = cur_loss + "%10.3g " % (k)
+
+            cur_loss = cur_loss + ")"
+
+            mess = ""
+
             if self.SHOWGPU:
-                info_gpu=self.getgpumem()
+                info_gpu = self.getgpumem()
                 for k in range(info_gpu.shape[0]):
-                    mess=mess+'[GPU%d %.0f/%.0f MB %.0f%%]'%(k,info_gpu[k,0],info_gpu[k,1],info_gpu[k,2])
-
-
-            print('%sItt %6d L=%s %.3fs %s'%(self.MESSAGE,self.itt,cur_loss,(end-self.start),mess))
+                    mess = mess + "[GPU%d %.0f/%.0f MB %.0f%%]" % (
+                        k,
+                        info_gpu[k, 0],
+                        info_gpu[k, 1],
+                        info_gpu[k, 2],
+                    )
+
+            print(
+                "%sItt %6d L=%s %.3fs %s"
+                % (self.MESSAGE, self.itt, cur_loss, (end - self.start), mess)
+            )
             sys.stdout.flush()
             if self.KEEP_TRACK is not None:
                 print(self.last_info)
                 sys.stdout.flush()
-
+
             self.start = time.time()
-
-        self.itt=self.itt+1
+
+        self.itt = self.itt + 1

     # ---------------------------------------------−---------
-    def calc_grad(self,in_x):
-
-        g_tot=None
-        l_tot=0.0
+    def calc_grad(self, in_x):

-
-        if self.do_all_noise and self.totalsz>self.batchsz:
-            nstep=self.totalsz//self.batchsz
+        g_tot = None
+        l_tot = 0.0
+
+        if self.do_all_noise and self.totalsz > self.batchsz:
+            nstep = self.totalsz // self.batchsz
         else:
-            nstep=1
+            nstep = 1
+
+        x = self.operation.backend.bk_reshape(
+            self.operation.backend.bk_cast(in_x), self.oshape
+        )
+
+        self.l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0

-        x=self.operation.backend.bk_reshape(self.operation.backend.bk_cast(in_x),self.oshape)
-
-        self.l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-
         for istep in range(nstep):
-
+
             for k in range(self.number_of_loss):
                 if self.loss_class[k].batch is None:
-                    l_batch=None
+                    l_batch = None
                 else:
-                    l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,istep)
+                    l_batch = self.loss_class[k].batch(
+                        self.loss_class[k].batch_data, istep
+                    )

                 if self.KEEP_TRACK is not None:
-                    l,g,linfo=self.bk.loss(x,l_batch,self.loss_class[k],self.KEEP_TRACK)
-                    self.last_info=self.KEEP_TRACK(linfo,self.mpi_rank,add=True)
+                    l_loss, g, linfo = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
+                    self.last_info = self.KEEP_TRACK(linfo, self.mpi_rank, add=True)
                 else:
-                    l,g=self.bk.loss(x,l_batch,self.loss_class[k],self.KEEP_TRACK)
+                    l_loss, g = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )

                 if g_tot is None:
-                    g_tot=g
+                    g_tot = g
                 else:
-                    g_tot=g_tot+g
+                    g_tot = g_tot + g

-                l_tot=l_tot+l.numpy()
+                l_tot = l_tot + l_loss.numpy()

-                if self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]==-1:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=l.numpy()/nstep
+                if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        l_loss.numpy() / nstep
+                    )
                 else:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]+l.numpy()/nstep
-
-        grd_mask=self.grd_mask
-
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
+                        + l_loss.numpy() / nstep
+                    )
+
+        grd_mask = self.grd_mask
+
         if grd_mask is not None:
-            g_tot=grd_mask*g_tot.numpy()
+            g_tot = grd_mask * g_tot.numpy()
         else:
-            g_tot=g_tot.numpy()
-
-        g_tot[np.isnan(g_tot)]=0.0
+            g_tot = g_tot.numpy()

-        self.imin=self.imin+self.batchsz
+        g_tot[np.isnan(g_tot)] = 0.0

-        if self.mpi_size==1:
-            self.ltot=self.l_log
+        self.imin = self.imin + self.batchsz
+
+        if self.mpi_size == 1:
+            self.ltot = self.l_log
         else:
-            local_log=(self.l_log).astype('float64')
-            self.ltot=np.zeros(self.l_log.shape,dtype='float64')
-            self.comm.Allreduce((local_log,self.MPI.DOUBLE),(self.ltot,self.MPI.DOUBLE))
-
-        if self.mpi_size==1:
-            grad=g_tot
+            local_log = (self.l_log).astype("float64")
+            self.ltot = np.zeros(self.l_log.shape, dtype="float64")
+            self.comm.Allreduce(
+                (local_log, self.MPI.DOUBLE), (self.ltot, self.MPI.DOUBLE)
+            )
+
+        if self.mpi_size == 1:
+            grad = g_tot
         else:
-            if self.operation.backend.bk_is_complex( g_tot):
-                grad=np.zeros(self.oshape,dtype=gtot.dtype)
+            if self.operation.backend.bk_is_complex(g_tot):
+                grad = np.zeros(self.oshape, dtype=g_tot.dtype)

-                self.comm.Allreduce((g_tot),(grad))
+                self.comm.Allreduce((g_tot), (grad))
             else:
-                grad=np.zeros(self.oshape,dtype='float64')
+                grad = np.zeros(self.oshape, dtype="float64")

-                self.comm.Allreduce((g_tot.astype('float64'),self.MPI.DOUBLE),
-                                    (grad,self.MPI.DOUBLE))
-
-        if self.nlog==self.history.shape[0]:
-            new_log=np.zeros([self.history.shape[0]*2])
-            new_log[0:self.nlog]=self.history
-            self.history=new_log
+                self.comm.Allreduce(
+                    (g_tot.astype("float64"), self.MPI.DOUBLE), (grad, self.MPI.DOUBLE)
+                )

-        l_tot=self.ltot[self.ltot!=-1].mean()
-
-        self.history[self.nlog]=l_tot
+        if self.nlog == self.history.shape[0]:
+            new_log = np.zeros([self.history.shape[0] * 2])
+            new_log[0 : self.nlog] = self.history
+            self.history = new_log

-        g_tot=grad.flatten()
+        l_tot = self.ltot[self.ltot != -1].mean()

-        if self.operation.backend.bk_is_complex( g_tot):
-            return l_tot.astype('float64'),g_tot
-
-        return l_tot.astype('float64'),g_tot.astype('float64')
+        self.history[self.nlog] = l_tot
+
+        g_tot = grad.flatten()
+
+        if self.operation.backend.bk_is_complex(g_tot):
+            return l_tot.astype("float64"), g_tot
+
+        return l_tot.astype("float64"), g_tot.astype("float64")

     # ---------------------------------------------−---------
-    def xtractmap(self,x,axis):
-        x=self.operation.backend.bk_reshape(x,self.oshape)
-
+    def xtractmap(self, x, axis):
+        x = self.operation.backend.bk_reshape(x, self.oshape)
+
         return x

     # ---------------------------------------------−---------
-    def run(self,
-            in_x,
-            NUM_EPOCHS = 100,
-            DECAY_RATE=0.95,
-            EVAL_FREQUENCY = 100,
-            DEVAL_STAT_FREQUENCY = 1000,
-            NUM_STEP_BIAS = 1,
-            LEARNING_RATE = 0.03,
-            EPSILON = 1E-7,
-            KEEP_TRACK=None,
-            grd_mask=None,
-            SHOWGPU=False,
-            MESSAGE='',
-            factr=10.0,
-            batchsz=1,
-            totalsz=1,
-            do_lbfgs=True,
-            axis=0):
-
-        self.KEEP_TRACK=KEEP_TRACK
-        self.track={}
-        self.ntrack=0
-        self.eta=LEARNING_RATE
-        self.epsilon=EPSILON
+    def run(
+        self,
+        in_x,
+        NUM_EPOCHS=100,
+        DECAY_RATE=0.95,
+        EVAL_FREQUENCY=100,
+        DEVAL_STAT_FREQUENCY=1000,
+        NUM_STEP_BIAS=1,
+        LEARNING_RATE=0.03,
+        EPSILON=1e-7,
+        KEEP_TRACK=None,
+        grd_mask=None,
+        SHOWGPU=False,
+        MESSAGE="",
+        factr=10.0,
+        batchsz=1,
+        totalsz=1,
+        do_lbfgs=True,
+        axis=0,
+    ):
+
+        self.KEEP_TRACK = KEEP_TRACK
+        self.track = {}
+        self.ntrack = 0
+        self.eta = LEARNING_RATE
+        self.epsilon = EPSILON
         self.decay_rate = DECAY_RATE
-        self.nlog=0
-        self.itt2=0
-        self.batchsz=batchsz
-        self.totalsz=totalsz
-        self.grd_mask=grd_mask
-        self.EVAL_FREQUENCY=EVAL_FREQUENCY
-        self.MESSAGE=MESSAGE
-        self.SHOWGPU=SHOWGPU
-        self.axis=axis
-        self.in_x_nshape=in_x.shape[0]
-        self.seed=1234
-
-        np.random.seed(self.mpi_rank*7+1234)
-
-        x=in_x
-
-        self.curr_gpu=self.curr_gpu+self.mpi_rank
-
-        if self.mpi_size>1:
+        self.nlog = 0
+        self.itt2 = 0
+        self.batchsz = batchsz
+        self.totalsz = totalsz
+        self.grd_mask = grd_mask
+        self.EVAL_FREQUENCY = EVAL_FREQUENCY
+        self.MESSAGE = MESSAGE
+        self.SHOWGPU = SHOWGPU
+        self.axis = axis
+        self.in_x_nshape = in_x.shape[0]
+        self.seed = 1234
+
+        np.random.seed(self.mpi_rank * 7 + 1234)
+
+        x = in_x
+
+        self.curr_gpu = self.curr_gpu + self.mpi_rank
+
+        if self.mpi_size > 1:
             from mpi4py import MPI
-

             comm = MPI.COMM_WORLD
-            self.comm=comm
-            self.MPI=MPI
-            if self.mpi_rank==0:
-                print('Work with MPI')
+            self.comm = comm
+            self.MPI = MPI
+            if self.mpi_rank == 0:
+                print("Work with MPI")
                 sys.stdout.flush()
-
-        if self.mpi_rank==0 and SHOWGPU:
+
+        if self.mpi_rank == 0 and SHOWGPU:
             # start thread that catch GPU information
             try:
-                self.gpu_thrd = Thread(target=self.get_gpu, args=(self.event,1,))
+                self.gpu_thrd = Thread(
+                    target=self.get_gpu,
+                    args=(
+                        self.event,
+                        1,
+                    ),
+                )
                 self.gpu_thrd.start()
             except:
                 print("Error: unable to start thread for GPU survey")
-
-        start = time.time()
-
-        if self.mpi_size>1:
-            num_loss=np.zeros([1],dtype='int32')
-            total_num_loss=np.zeros([1],dtype='int32')
-            num_loss[0]=self.number_of_loss
-            comm.Allreduce((num_loss,MPI.INT),(total_num_loss,MPI.INT))
-            total_num_loss=total_num_loss[0]
+
+        # start = time.time()
+
+        if self.mpi_size > 1:
+            num_loss = np.zeros([1], dtype="int32")
+            total_num_loss = np.zeros([1], dtype="int32")
+            num_loss[0] = self.number_of_loss
+            comm.Allreduce((num_loss, MPI.INT), (total_num_loss, MPI.INT))
+            total_num_loss = total_num_loss[0]
         else:
-            total_num_loss=self.number_of_loss
-
-        if self.mpi_rank==0:
-            print('Total number of loss ',total_num_loss)
+            total_num_loss = self.number_of_loss
+
+        if self.mpi_rank == 0:
+            print("Total number of loss ", total_num_loss)
             sys.stdout.flush()
-
-        l_log=np.zeros([self.mpi_size*self.MAXNUMLOSS],dtype='float32')
-        l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-        self.ltot=l_log.copy()
-        self.l_log=l_log
-
-        self.imin=0
-        self.start=time.time()
-        self.itt=0
-
-        self.oshape=list(x.shape)
-
-        if not isinstance(x,np.ndarray):
-            x=x.numpy()
-
-        x=x.flatten()

-        self.do_all_noise=False
+        l_log = np.zeros([self.mpi_size * self.MAXNUMLOSS], dtype="float32")
+        l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
+        self.ltot = l_log.copy()
+        self.l_log = l_log

-        self.do_all_noise=True
+        self.imin = 0
+        self.start = time.time()
+        self.itt = 0

-        self.noise_idx=None
+        self.oshape = list(x.shape)

-        for k in range(self.number_of_loss):
-            if self.loss_class[k].batch is not None:
-                l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-
-        l_tot,g_tot=self.calc_grad(x)
+        if not isinstance(x, np.ndarray):
+            x = x.numpy()
+
+        x = x.flatten()
+
+        self.do_all_noise = False
+
+        self.do_all_noise = True
+
+        self.noise_idx = None
+
+        # for k in range(self.number_of_loss):
+        # if self.loss_class[k].batch is not None:
+        # l_batch = self.loss_class[k].batch(
+        # self.loss_class[k].batch_data, 0, init=True
+        # )
+
+        l_tot, g_tot = self.calc_grad(x)

         self.info_back(x)

-        maxitt=NUM_EPOCHS
+        maxitt = NUM_EPOCHS

-        start_x=x.copy()
+        # start_x = x.copy()

         for iteration in range(NUM_STEP_BIAS):
-
-            x,l,i=opt.fmin_l_bfgs_b(self.calc_grad,
-                                    x.astype('float64'),
-                                    callback=self.info_back,
-                                    pgtol=1E-32,
-                                    factr=factr,
-                                    maxiter=maxitt)
+
+            x, loss, i = opt.fmin_l_bfgs_b(
+                self.calc_grad,
+                x.astype("float64"),
+                callback=self.info_back,
+                pgtol=1e-32,
+                factr=factr,
+                maxiter=maxitt,
+            )

             # update bias input data
-            if iteration<NUM_STEP_BIAS-1:
-            #if self.mpi_rank==0:
+            if iteration < NUM_STEP_BIAS - 1:
+                # if self.mpi_rank==0:
                 # print('%s Hessian restart'%(self.MESSAGE))

-                omap=self.xtractmap(x,axis)
+                omap = self.xtractmap(x, axis)

                 for k in range(self.number_of_loss):
                     if self.loss_class[k].batch_update is not None:
-                        self.loss_class[k].batch_update(self.loss_class[k].batch_data,omap)
-                    if self.loss_class[k].batch is not None:
-                        l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-            #x=start_x.copy()
-
-        if self.mpi_rank==0 and SHOWGPU:
+                        self.loss_class[k].batch_update(
+                            self.loss_class[k].batch_data, omap
+                        )
+                    # if self.loss_class[k].batch is not None:
+                    # l_batch = self.loss_class[k].batch(
+                    # self.loss_class[k].batch_data, 0, init=True
+                    # )
+            # x=start_x.copy()
+
+        if self.mpi_rank == 0 and SHOWGPU:
             self.stop_synthesis()

         if self.KEEP_TRACK is not None:
-            self.last_info=self.KEEP_TRACK(None,self.mpi_rank,add=False)
+            self.last_info = self.KEEP_TRACK(None, self.mpi_rank, add=False)

-        x=self.xtractmap(x,axis)
-        return(x)
+        x = self.xtractmap(x, axis)
+        return x

     def get_history(self):
-        return(self.history[0:self.nlog])
+        return self.history[0 : self.nlog]
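
Beyond the black/isort-style reformatting, the functional change in this file is the new loss-id bookkeeping: Loss now stores an id_loss attribute with set_id_loss()/get_id_loss() accessors, and Synthesis.__init__ numbers its losses via set_id_loss(k). Below is a minimal sketch of that behaviour under stated assumptions: dummy_loss is a placeholder (a real loss would compare scattering statistics of x against reference coefficients), and the manual set_id_loss loop reproduces by hand what Synthesis.__init__ does, since constructing a real Synthesis additionally requires a scattering operator with a TensorFlow or Torch backend.

import foscat.Synthesis as synthe

def dummy_loss(x, scat_operator, args):
    # Placeholder following the non-batched call convention visible in Loss.eval():
    # loss_function(x, scat_operator, args).
    return 0.0

# Loss.__init__ only stores its arguments, so None stands in for the operator here.
loss_a = synthe.Loss(dummy_loss, None, name="a")
loss_b = synthe.Loss(dummy_loss, None, name="b")

# Synthesis.__init__ does this automatically for its loss_list:
for k, loss in enumerate([loss_a, loss_b]):
    loss.set_id_loss(k)

print(loss_a.get_id_loss(None), loss_b.get_id_loss(None))  # -> 0 1 (argument is unused)

In real use each Loss receives a foscat scattering operator and reference data, and Synthesis([...]).run(x0, NUM_EPOCHS=..., EVAL_FREQUENCY=...) drives the scipy fmin_l_bfgs_b optimisation shown in the diff, with get_history() returning the recorded loss values.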