foscat 3.1.5__py3-none-any.whl → 3.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/Synthesis.py CHANGED
@@ -1,56 +1,67 @@
-import numpy as np
-import time
-import sys
 import os
-from datetime import datetime
-from packaging import version
-from threading import Thread
-from threading import Event
+import sys
+import time
+from threading import Event, Thread
+
+import numpy as np
 import scipy.optimize as opt
 
+
 class Loss:
-
-    def __init__(self,function,scat_operator,*param,
-                 name='',
-                 batch=None,
-                 batch_data=None,
-                 batch_update=None,
-                 info_callback=False):
-
-        self.loss_function=function
-        self.scat_operator=scat_operator
-        self.args=param
-        self.name=name
-        self.batch=batch
-        self.batch_data=batch_data
-        self.batch_update=batch_update
-        self.info=info_callback
-
-    def eval(self,x,batch,return_all=False):
+
+    def __init__(
+        self,
+        function,
+        scat_operator,
+        *param,
+        name="",
+        batch=None,
+        batch_data=None,
+        batch_update=None,
+        info_callback=False
+    ):
+
+        self.loss_function = function
+        self.scat_operator = scat_operator
+        self.args = param
+        self.name = name
+        self.batch = batch
+        self.batch_data = batch_data
+        self.batch_update = batch_update
+        self.info = info_callback
+
+    def eval(self, x, batch, return_all=False):
         if self.batch is None:
             if self.info:
-                return self.loss_function(x,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,self.scat_operator,self.args)
+                return self.loss_function(x, self.scat_operator, self.args)
         else:
             if self.info:
-                return self.loss_function(x,batch,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, batch, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,batch,self.scat_operator,self.args)
-
+                return self.loss_function(x, batch, self.scat_operator, self.args)
+
+
 class Synthesis:
-    def __init__(self,
-                 loss_list,
-                 eta=0.03,
-                 beta1=0.9,
-                 beta2=0.999,
-                 epsilon=1e-7,
-                 decay_rate = 0.999):
-
-        self.loss_class=loss_list
-        self.number_of_loss=len(loss_list)
-        self.__iteration__=1234
-        self.nlog=0
+    def __init__(
+        self,
+        loss_list,
+        eta=0.03,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=1e-7,
+        decay_rate=0.999,
+    ):
+
+        self.loss_class = loss_list
+        self.number_of_loss = len(loss_list)
+        self.__iteration__ = 1234
+        self.nlog = 0
         self.m_dw, self.v_dw = 0.0, 0.0
         self.beta1 = beta1
         self.beta2 = beta2
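The first hunk is an import cleanup plus a black-style reformat of the Loss wrapper: the no-longer-used datetime and packaging.version imports are dropped, the remaining imports are regrouped, and Loss.__init__/Loss.eval are reflowed without any change in behaviour. Loss still captures a user loss function together with its extra positional parameters, and eval forwards them as (x, scat_operator, args), or as (x, batch, scat_operator, args) when a batch function was registered. A minimal sketch of that calling convention, using a made-up quadratic loss (the quadratic_loss function, the None operator, and the "demo" name are illustrative, not foscat code):

    import numpy as np

    from foscat.Synthesis import Loss

    def quadratic_loss(x, scat_operator, args):
        # Extra positional arguments passed to Loss(...) arrive bundled in
        # `args`; a real foscat loss would use `scat_operator` to compute
        # scattering statistics, while this toy loss ignores it.
        (target,) = args
        return ((x - target) ** 2).mean()

    # No batch function registered, so Loss.eval calls
    # quadratic_loss(x, scat_operator, args).
    loss = Loss(quadratic_loss, None, np.zeros(16), name="demo")
    value = loss.eval(np.ones(16), batch=None)  # -> 1.0

Note that a standalone Loss with scat_operator=None only works for direct eval calls as above; Synthesis reads mpi_size, mpi_rank and BACKEND from the first loss's operator.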
@@ -58,329 +69,373 @@ class Synthesis:
         self.pbeta2 = beta2
         self.epsilon = epsilon
         self.eta = eta
-        self.history=np.zeros([10])
-        self.curr_gpu=0
+        self.history = np.zeros([10])
+        self.curr_gpu = 0
         self.event = Event()
-        self.operation=loss_list[0].scat_operator
-        self.mpi_size=self.operation.mpi_size
-        self.mpi_rank=self.operation.mpi_rank
-        self.KEEP_TRACK=None
-        self.MAXNUMLOSS=len(loss_list)
-
-        if self.operation.BACKEND=='tensorflow':
+        self.operation = loss_list[0].scat_operator
+        self.mpi_size = self.operation.mpi_size
+        self.mpi_rank = self.operation.mpi_rank
+        self.KEEP_TRACK = None
+        self.MAXNUMLOSS = len(loss_list)
+
+        if self.operation.BACKEND == "tensorflow":
             import foscat.loss_backend_tens as fbk
-            self.bk=fbk.loss_backend(self.operation,self.curr_gpu,self.mpi_rank)
-
-        if self.operation.BACKEND=='torch':
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "torch":
             import foscat.loss_backend_torch as fbk
-            self.bk=fbk.loss_backend(self.operation,self.curr_gpu,self.mpi_rank)
-
-        if self.operation.BACKEND=='numpy':
-            print('Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend')
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "numpy":
+            print(
+                "Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend"
+            )
             return None
 
     # ---------------------------------------------−---------
-    def get_gpu(self,event,delay):
+    def get_gpu(self, event, delay):
 
-        isnvidia=os.system('which nvidia-smi &> /dev/null')
+        isnvidia = os.system("which nvidia-smi &> /dev/null")
 
-        while (1):
+        while 1:
             if event.is_set():
                 break
             time.sleep(delay)
-            if isnvidia==0:
+            if isnvidia == 0:
                 try:
-                    os.system("nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt")
+                    os.system(
+                        "nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt"
+                    )
                 except:
-                    nogpu=1
-
+                    print('No nvidia GPU: Impossible to trace')
+                    self.nogpu = 1
+
     def stop_synthesis(self):
         # stop thread that catch GPU information
         self.event.set()
-
+
         try:
             self.gpu_thrd.join()
         except:
-            print('No thread to stop, everything is ok')
+            print("No thread to stop, everything is ok")
         sys.stdout.flush()
-
+
     # ---------------------------------------------−---------
     def getgpumem(self):
         try:
-            return np.loadtxt('smi_tmp.txt')
+            return np.loadtxt("smi_tmp.txt")
         except:
-            return(np.zeros([1,3]))
-
+            return np.zeros([1, 3])
+
     # ---------------------------------------------−---------
-    def info_back(self,x):
-
-        self.nlog=self.nlog+1
-        self.itt2=0
-
-        if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
+    def info_back(self, x):
+
+        self.nlog = self.nlog + 1
+        self.itt2 = 0
+
+        if self.itt % self.EVAL_FREQUENCY == 0 and self.mpi_rank == 0:
             end = time.time()
-            cur_loss='%10.3g ('%(self.ltot[self.ltot!=-1].mean())
-            for k in self.ltot[self.ltot!=-1]:
-                cur_loss=cur_loss+'%10.3g '%(k)
-
-            cur_loss=cur_loss+')'
-
-            mess=''
-
+            cur_loss = "%10.3g (" % (self.ltot[self.ltot != -1].mean())
+            for k in self.ltot[self.ltot != -1]:
+                cur_loss = cur_loss + "%10.3g " % (k)
+
+            cur_loss = cur_loss + ")"
+
+            mess = ""
+
             if self.SHOWGPU:
-                info_gpu=self.getgpumem()
+                info_gpu = self.getgpumem()
                 for k in range(info_gpu.shape[0]):
-                    mess=mess+'[GPU%d %.0f/%.0f MB %.0f%%]'%(k,info_gpu[k,0],info_gpu[k,1],info_gpu[k,2])
-
-
-            print('%sItt %6d L=%s %.3fs %s'%(self.MESSAGE,self.itt,cur_loss,(end-self.start),mess))
+                    mess = mess + "[GPU%d %.0f/%.0f MB %.0f%%]" % (
+                        k,
+                        info_gpu[k, 0],
+                        info_gpu[k, 1],
+                        info_gpu[k, 2],
+                    )
+
+            print(
+                "%sItt %6d L=%s %.3fs %s"
+                % (self.MESSAGE, self.itt, cur_loss, (end - self.start), mess)
+            )
             sys.stdout.flush()
             if self.KEEP_TRACK is not None:
                 print(self.last_info)
                 sys.stdout.flush()
-
+
             self.start = time.time()
-
-        self.itt=self.itt+1
+
+        self.itt = self.itt + 1
 
     # ---------------------------------------------−---------
-    def calc_grad(self,in_x):
-
-        g_tot=None
-        l_tot=0.0
-
-
-        if self.do_all_noise and self.totalsz>self.batchsz:
-            nstep=self.totalsz//self.batchsz
+    def calc_grad(self, in_x):
+
+        g_tot = None
+        l_tot = 0.0
+
+        if self.do_all_noise and self.totalsz > self.batchsz:
+            nstep = self.totalsz // self.batchsz
         else:
-            nstep=1
+            nstep = 1
+
+        x = self.operation.backend.bk_reshape(
+            self.operation.backend.bk_cast(in_x), self.oshape
+        )
+
+        self.l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
 
-        x=self.operation.backend.bk_reshape(self.operation.backend.bk_cast(in_x),self.oshape)
-
-        self.l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-
         for istep in range(nstep):
-
+
             for k in range(self.number_of_loss):
                 if self.loss_class[k].batch is None:
-                    l_batch=None
+                    l_batch = None
                 else:
-                    l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,istep)
+                    l_batch = self.loss_class[k].batch(
+                        self.loss_class[k].batch_data, istep
+                    )
 
                 if self.KEEP_TRACK is not None:
-                    l,g,linfo=self.bk.loss(x,l_batch,self.loss_class[k],self.KEEP_TRACK)
-                    self.last_info=self.KEEP_TRACK(linfo,self.mpi_rank,add=True)
+                    l_loss, g, linfo = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
+                    self.last_info = self.KEEP_TRACK(linfo, self.mpi_rank, add=True)
                 else:
-                    l,g=self.bk.loss(x,l_batch,self.loss_class[k],self.KEEP_TRACK)
+                    l_loss, g = self.bk.loss(x, l_batch, self.loss_class[k], self.KEEP_TRACK)
 
                 if g_tot is None:
-                    g_tot=g
+                    g_tot = g
                 else:
-                    g_tot=g_tot+g
+                    g_tot = g_tot + g
 
-                l_tot=l_tot+l.numpy()
+                l_tot = l_tot + l_loss.numpy()
 
-                if self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]==-1:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=l.numpy()/nstep
+                if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = l_loss.numpy() / nstep
                 else:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]+l.numpy()/nstep
-
-            grd_mask=self.grd_mask
-
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
+                        + l_loss.numpy() / nstep
+                    )
+
+            grd_mask = self.grd_mask
+
             if grd_mask is not None:
-                g_tot=grd_mask*g_tot.numpy()
+                g_tot = grd_mask * g_tot.numpy()
             else:
-                g_tot=g_tot.numpy()
-
-            g_tot[np.isnan(g_tot)]=0.0
+                g_tot = g_tot.numpy()
 
-            self.imin=self.imin+self.batchsz
+            g_tot[np.isnan(g_tot)] = 0.0
 
-        if self.mpi_size==1:
-            self.ltot=self.l_log
+            self.imin = self.imin + self.batchsz
+
+        if self.mpi_size == 1:
+            self.ltot = self.l_log
         else:
-            local_log=(self.l_log).astype('float64')
-            self.ltot=np.zeros(self.l_log.shape,dtype='float64')
-            self.comm.Allreduce((local_log,self.MPI.DOUBLE),(self.ltot,self.MPI.DOUBLE))
-
-        if self.mpi_size==1:
-            grad=g_tot
+            local_log = (self.l_log).astype("float64")
+            self.ltot = np.zeros(self.l_log.shape, dtype="float64")
+            self.comm.Allreduce(
+                (local_log, self.MPI.DOUBLE), (self.ltot, self.MPI.DOUBLE)
+            )
+
+        if self.mpi_size == 1:
+            grad = g_tot
         else:
-            if self.operation.backend.bk_is_complex( g_tot):
-                grad=np.zeros(self.oshape,dtype=gtot.dtype)
+            if self.operation.backend.bk_is_complex(g_tot):
+                grad = np.zeros(self.oshape, dtype=g_tot.dtype)
 
-                self.comm.Allreduce((g_tot),(grad))
+                self.comm.Allreduce((g_tot), (grad))
             else:
-                grad=np.zeros(self.oshape,dtype='float64')
+                grad = np.zeros(self.oshape, dtype="float64")
+
+                self.comm.Allreduce(
+                    (g_tot.astype("float64"), self.MPI.DOUBLE), (grad, self.MPI.DOUBLE)
+                )
+
+        if self.nlog == self.history.shape[0]:
+            new_log = np.zeros([self.history.shape[0] * 2])
+            new_log[0 : self.nlog] = self.history
+            self.history = new_log
+
+        l_tot = self.ltot[self.ltot != -1].mean()
 
-                self.comm.Allreduce((g_tot.astype('float64'),self.MPI.DOUBLE),
-                                    (grad,self.MPI.DOUBLE))
-
-        if self.nlog==self.history.shape[0]:
-            new_log=np.zeros([self.history.shape[0]*2])
-            new_log[0:self.nlog]=self.history
-            self.history=new_log
+        self.history[self.nlog] = l_tot
 
-        l_tot=self.ltot[self.ltot!=-1].mean()
-
-        self.history[self.nlog]=l_tot
+        g_tot = grad.flatten()
 
-        g_tot=grad.flatten()
+        if self.operation.backend.bk_is_complex(g_tot):
+            return l_tot.astype("float64"), g_tot
 
-        if self.operation.backend.bk_is_complex( g_tot):
-            return l_tot.astype('float64'),g_tot
-
-        return l_tot.astype('float64'),g_tot.astype('float64')
+        return l_tot.astype("float64"), g_tot.astype("float64")
 
     # ---------------------------------------------−---------
-    def xtractmap(self,x,axis):
-        x=self.operation.backend.bk_reshape(x,self.oshape)
-
+    def xtractmap(self, x, axis):
+        x = self.operation.backend.bk_reshape(x, self.oshape)
+
         return x
 
     # ---------------------------------------------−---------
-    def run(self,
-            in_x,
-            NUM_EPOCHS = 100,
-            DECAY_RATE=0.95,
-            EVAL_FREQUENCY = 100,
-            DEVAL_STAT_FREQUENCY = 1000,
-            NUM_STEP_BIAS = 1,
-            LEARNING_RATE = 0.03,
-            EPSILON = 1E-7,
-            KEEP_TRACK=None,
-            grd_mask=None,
-            SHOWGPU=False,
-            MESSAGE='',
-            factr=10.0,
-            batchsz=1,
-            totalsz=1,
-            do_lbfgs=True,
-            axis=0):
-
-        self.KEEP_TRACK=KEEP_TRACK
-        self.track={}
-        self.ntrack=0
-        self.eta=LEARNING_RATE
-        self.epsilon=EPSILON
+    def run(
+        self,
+        in_x,
+        NUM_EPOCHS=100,
+        DECAY_RATE=0.95,
+        EVAL_FREQUENCY=100,
+        DEVAL_STAT_FREQUENCY=1000,
+        NUM_STEP_BIAS=1,
+        LEARNING_RATE=0.03,
+        EPSILON=1e-7,
+        KEEP_TRACK=None,
+        grd_mask=None,
+        SHOWGPU=False,
+        MESSAGE="",
+        factr=10.0,
+        batchsz=1,
+        totalsz=1,
+        do_lbfgs=True,
+        axis=0,
+    ):
+
+        self.KEEP_TRACK = KEEP_TRACK
+        self.track = {}
+        self.ntrack = 0
+        self.eta = LEARNING_RATE
+        self.epsilon = EPSILON
         self.decay_rate = DECAY_RATE
-        self.nlog=0
-        self.itt2=0
-        self.batchsz=batchsz
-        self.totalsz=totalsz
-        self.grd_mask=grd_mask
-        self.EVAL_FREQUENCY=EVAL_FREQUENCY
-        self.MESSAGE=MESSAGE
-        self.SHOWGPU=SHOWGPU
-        self.axis=axis
-        self.in_x_nshape=in_x.shape[0]
-        self.seed=1234
-
-        np.random.seed(self.mpi_rank*7+1234)
-
-        x=in_x
-
-        self.curr_gpu=self.curr_gpu+self.mpi_rank
-
-        if self.mpi_size>1:
+        self.nlog = 0
+        self.itt2 = 0
+        self.batchsz = batchsz
+        self.totalsz = totalsz
+        self.grd_mask = grd_mask
+        self.EVAL_FREQUENCY = EVAL_FREQUENCY
+        self.MESSAGE = MESSAGE
+        self.SHOWGPU = SHOWGPU
+        self.axis = axis
+        self.in_x_nshape = in_x.shape[0]
+        self.seed = 1234
+
+        np.random.seed(self.mpi_rank * 7 + 1234)
+
+        x = in_x
+
+        self.curr_gpu = self.curr_gpu + self.mpi_rank
+
+        if self.mpi_size > 1:
             from mpi4py import MPI
-
 
             comm = MPI.COMM_WORLD
-            self.comm=comm
-            self.MPI=MPI
-            if self.mpi_rank==0:
-                print('Work with MPI')
+            self.comm = comm
+            self.MPI = MPI
+            if self.mpi_rank == 0:
+                print("Work with MPI")
                 sys.stdout.flush()
-
-        if self.mpi_rank==0 and SHOWGPU:
+
+        if self.mpi_rank == 0 and SHOWGPU:
             # start thread that catch GPU information
             try:
-                self.gpu_thrd = Thread(target=self.get_gpu, args=(self.event,1,))
+                self.gpu_thrd = Thread(
+                    target=self.get_gpu,
+                    args=(
+                        self.event,
+                        1,
+                    ),
+                )
                 self.gpu_thrd.start()
             except:
                 print("Error: unable to start thread for GPU survey")
-
-        start = time.time()
-
-        if self.mpi_size>1:
-            num_loss=np.zeros([1],dtype='int32')
-            total_num_loss=np.zeros([1],dtype='int32')
-            num_loss[0]=self.number_of_loss
-            comm.Allreduce((num_loss,MPI.INT),(total_num_loss,MPI.INT))
-            total_num_loss=total_num_loss[0]
+
+        #start = time.time()
+
+        if self.mpi_size > 1:
+            num_loss = np.zeros([1], dtype="int32")
+            total_num_loss = np.zeros([1], dtype="int32")
+            num_loss[0] = self.number_of_loss
+            comm.Allreduce((num_loss, MPI.INT), (total_num_loss, MPI.INT))
+            total_num_loss = total_num_loss[0]
         else:
-            total_num_loss=self.number_of_loss
-
-        if self.mpi_rank==0:
-            print('Total number of loss ',total_num_loss)
+            total_num_loss = self.number_of_loss
+
+        if self.mpi_rank == 0:
+            print("Total number of loss ", total_num_loss)
             sys.stdout.flush()
-
-        l_log=np.zeros([self.mpi_size*self.MAXNUMLOSS],dtype='float32')
-        l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-        self.ltot=l_log.copy()
-        self.l_log=l_log
-
-        self.imin=0
-        self.start=time.time()
-        self.itt=0
-
-        self.oshape=list(x.shape)
-
-        if not isinstance(x,np.ndarray):
-            x=x.numpy()
-
-        x=x.flatten()
-
-        self.do_all_noise=False
-
-        self.do_all_noise=True
-
-        self.noise_idx=None
-
-        for k in range(self.number_of_loss):
-            if self.loss_class[k].batch is not None:
-                l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-
-        l_tot,g_tot=self.calc_grad(x)
+
+        l_log = np.zeros([self.mpi_size * self.MAXNUMLOSS], dtype="float32")
+        l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
+        self.ltot = l_log.copy()
+        self.l_log = l_log
+
+        self.imin = 0
+        self.start = time.time()
+        self.itt = 0
+
+        self.oshape = list(x.shape)
+
+        if not isinstance(x, np.ndarray):
+            x = x.numpy()
+
+        x = x.flatten()
+
+        self.do_all_noise = False
+
+        self.do_all_noise = True
+
+        self.noise_idx = None
+
+        #for k in range(self.number_of_loss):
+        #    if self.loss_class[k].batch is not None:
+        #        l_batch = self.loss_class[k].batch(
+        #            self.loss_class[k].batch_data, 0, init=True
+        #        )
+
+        l_tot, g_tot = self.calc_grad(x)
 
         self.info_back(x)
 
-        maxitt=NUM_EPOCHS
+        maxitt = NUM_EPOCHS
 
-        start_x=x.copy()
+        # start_x = x.copy()
 
         for iteration in range(NUM_STEP_BIAS):
-
-            x,l,i=opt.fmin_l_bfgs_b(self.calc_grad,
-                                    x.astype('float64'),
-                                    callback=self.info_back,
-                                    pgtol=1E-32,
-                                    factr=factr,
-                                    maxiter=maxitt)
+
+            x, loss, i = opt.fmin_l_bfgs_b(
+                self.calc_grad,
+                x.astype("float64"),
+                callback=self.info_back,
+                pgtol=1e-32,
+                factr=factr,
+                maxiter=maxitt,
+            )
 
             # update bias input data
-            if iteration<NUM_STEP_BIAS-1:
-                #if self.mpi_rank==0:
+            if iteration < NUM_STEP_BIAS - 1:
+                # if self.mpi_rank==0:
                 #    print('%s Hessian restart'%(self.MESSAGE))
 
-                omap=self.xtractmap(x,axis)
+                omap = self.xtractmap(x, axis)
 
                 for k in range(self.number_of_loss):
                     if self.loss_class[k].batch_update is not None:
-                        self.loss_class[k].batch_update(self.loss_class[k].batch_data,omap)
-                    if self.loss_class[k].batch is not None:
-                        l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-            #x=start_x.copy()
-
-        if self.mpi_rank==0 and SHOWGPU:
+                        self.loss_class[k].batch_update(
+                            self.loss_class[k].batch_data, omap
+                        )
+                    #if self.loss_class[k].batch is not None:
+                    #    l_batch = self.loss_class[k].batch(
+                    #        self.loss_class[k].batch_data, 0, init=True
+                    #    )
+            # x=start_x.copy()
+
+        if self.mpi_rank == 0 and SHOWGPU:
             self.stop_synthesis()
 
         if self.KEEP_TRACK is not None:
-            self.last_info=self.KEEP_TRACK(None,self.mpi_rank,add=False)
+            self.last_info = self.KEEP_TRACK(None, self.mpi_rank, add=False)
 
-        x=self.xtractmap(x,axis)
-        return(x)
+        x = self.xtractmap(x, axis)
+        return x
 
     def get_history(self):
-        return(self.history[0:self.nlog])
+        return self.history[0 : self.nlog]
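Beyond the black-style reflow, the behavioural changes in this second hunk are small: the undefined gtot in the complex-gradient branch of calc_grad is fixed to g_tot, the ambiguous single-letter loss variable l becomes l_loss (loss in run), the silent nogpu=1 fallback in get_gpu now prints a warning and sets self.nogpu, and the unused start = time.time(), start_x = x.copy() and init=True batch warm-up calls are commented out. The optimisation driver itself is untouched: run still hands calc_grad, which returns a float64 (loss, gradient) pair, to scipy.optimize.fmin_l_bfgs_b, with info_back as the per-iteration callback. A self-contained sketch of that pattern on a toy quadratic objective (the objective and the progress callback are illustrative, not foscat code):

    import numpy as np
    import scipy.optimize as opt

    target = np.linspace(0.0, 1.0, 32)

    def loss_and_grad(x):
        # With no separate fprime, fmin_l_bfgs_b expects f(x) -> (loss, grad),
        # which is exactly the contract of Synthesis.calc_grad.
        r = x - target
        return (r**2).sum(), 2.0 * r

    def progress(xk):
        # Called once per L-BFGS iteration, like Synthesis.info_back.
        pass

    x, final_loss, info = opt.fmin_l_bfgs_b(
        loss_and_grad,
        np.zeros(32),  # flattened float64 starting point, as in run()
        callback=progress,
        pgtol=1e-32,   # same value as run(): the gradient-norm stop test
                       # is effectively disabled
        factr=10.0,    # run()'s default factr: a very tight relative tolerance
        maxiter=100,   # run()'s NUM_EPOCHS default
    )

info["warnflag"] reports why the solver stopped; run ignores this third return value and simply reshapes the optimised vector back to the input shape via xtractmap.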