foscat 3.0.8__py3-none-any.whl → 3.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
foscat/Synthesis.py CHANGED
@@ -1,69 +1,77 @@
-#import tensorflow as tf
-import numpy as np
-import time
-import sys
 import os
-from datetime import datetime
-from packaging import version
-from threading import Thread
-from threading import Event
+import sys
+import time
+from threading import Event, Thread
+
+import numpy as np
 import scipy.optimize as opt
 
+
 class Loss:
-
-    def __init__(self,function,scat_operator,*param,
-                 name='',
-                 batch=None,
-                 batch_data=None,
-                 batch_update=None,
-                 info_callback=False):
-
-        self.loss_function=function
-        self.scat_operator=scat_operator
-        self.args=param
-        self.name=name
-        self.batch=batch
-        self.batch_data=batch_data
-        self.batch_update=batch_update
-        self.info=info_callback
-
-        if scat_operator.BACKEND=='tensorflow':
-            import loss_backend_tens as fbk
-            self.bk=fbk.loss_backend(scat_operator)
-
-        if scat_operator.BACKEND=='torch':
-            import loss_backend_torch as fbk
-            self.bk=fbk.loss_backend(scat_operator)
-
-        if scat_operator.BACKEND=='numpy':
-            print('Synthesis does not work with numpy. Please use Torch or Tensorflow')
-            exit(0)
 
-    def eval(self,x,batch,return_all=False):
+    def __init__(
+        self,
+        function,
+        scat_operator,
+        *param,
+        name="",
+        batch=None,
+        batch_data=None,
+        batch_update=None,
+        info_callback=False,
+    ):
+
+        self.loss_function = function
+        self.scat_operator = scat_operator
+        self.args = param
+        self.name = name
+        self.batch = batch
+        self.batch_data = batch_data
+        self.batch_update = batch_update
+        self.info = info_callback
+        self.id_loss = 0
+
+    def eval(self, x, batch, return_all=False):
         if self.batch is None:
             if self.info:
-                return self.loss_function(x,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, self.scat_operator, self.args, return_all=return_all
+                )
             else:
-                return self.loss_function(x,self.scat_operator,self.args)
+                return self.loss_function(x, self.scat_operator, self.args)
         else:
             if self.info:
-                return self.loss_function(x,batch,self.scat_operator,self.args,return_all=return_all)
+                return self.loss_function(
+                    x, batch, self.scat_operator, self.args, return_all=return_all
+                )
            else:
-                return self.loss_function(x,batch,self.scat_operator,self.args)
+                return self.loss_function(x, batch, self.scat_operator, self.args)
+
+    def set_id_loss(self,id_loss):
+        self.id_loss = id_loss
 
+    def get_id_loss(self,id_loss):
+        return self.id_loss
+
 class Synthesis:
-    def __init__(self,
-                 loss_list,
-                 eta=0.03,
-                 beta1=0.9,
-                 beta2=0.999,
-                 epsilon=1e-7,
-                 decay_rate = 0.999,
-                 MAXNUMLOSS=10):
-
-        self.loss_class=loss_list
-        self.number_of_loss=len(loss_list)
-        self.nlog=0
+    def __init__(
+        self,
+        loss_list,
+        eta=0.03,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=1e-7,
+        decay_rate=0.999,
+    ):
+
+        self.loss_class = loss_list
+        self.number_of_loss = len(loss_list)
+
+        for k in range(self.number_of_loss):
+            self.loss_class[k].set_id_loss(k)
+
+        self.__iteration__ = 1234
+        self.nlog = 0
         self.m_dw, self.v_dw = 0.0, 0.0
         self.beta1 = beta1
         self.beta2 = beta2
@@ -71,320 +79,377 @@ class Synthesis:
         self.pbeta2 = beta2
         self.epsilon = epsilon
         self.eta = eta
-        self.history=np.zeros([10])
-        self.curr_gpu=0
+        self.history = np.zeros([10])
+        self.curr_gpu = 0
         self.event = Event()
-        self.operation=loss_list[0].scat_operator
-        self.mpi_size=self.operation.mpi_size
-        self.mpi_rank=self.operation.mpi_rank
-        self.KEEP_TRACK=None
-        self.MAXNUMLOSS=MAXNUMLOSS
-
+        self.operation = loss_list[0].scat_operator
+        self.mpi_size = self.operation.mpi_size
+        self.mpi_rank = self.operation.mpi_rank
+        self.KEEP_TRACK = None
+        self.MAXNUMLOSS = len(loss_list)
+
+        if self.operation.BACKEND == "tensorflow":
+            import foscat.loss_backend_tens as fbk
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "torch":
+            import foscat.loss_backend_torch as fbk
+
+            self.bk = fbk.loss_backend(self.operation, self.curr_gpu, self.mpi_rank)
+
+        if self.operation.BACKEND == "numpy":
+            print(
+                "Synthesis does not work with numpy. Please select Torch or Tensorflow FOSCAT backend"
+            )
+            return None
+
     # ---------------------------------------------−---------
-    def get_gpu(self,event,delay):
+    def get_gpu(self, event, delay):
 
-        isnvidia=os.system('which nvidia-smi &> /dev/null')
+        isnvidia = os.system("which nvidia-smi &> /dev/null")
 
-        while (1):
+        while 1:
             if event.is_set():
                 break
             time.sleep(delay)
-            if isnvidia==0:
+            if isnvidia == 0:
                 try:
-                    os.system("nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt")
+                    os.system(
+                        "nvidia-smi | awk '$2==\"N/A\"{print substr($9,1,length($9)-3),substr($11,1,length($11)-3),substr($13,1,length($13)-1)}' > smi_tmp.txt"
+                    )
                 except:
-                    nogpu=1
-
+                    print("No nvidia GPU: Impossible to trace")
+                    self.nogpu = 1
+
     def stop_synthesis(self):
         # stop thread that catch GPU information
         self.event.set()
-
+
         try:
             self.gpu_thrd.join()
         except:
-            print('No thread to stop, everything is ok')
+            print("No thread to stop, everything is ok")
             sys.stdout.flush()
-
+
     # ---------------------------------------------−---------
     def getgpumem(self):
         try:
-            return np.loadtxt('smi_tmp.txt')
+            return np.loadtxt("smi_tmp.txt")
         except:
-            return(np.zeros([1,3]))
-
+            return np.zeros([1, 3])
+
     # ---------------------------------------------−---------
-    def info_back(self,x):
-
-        self.nlog=self.nlog+1
-        self.itt2=0
-
-        if self.itt%self.EVAL_FREQUENCY==0 and self.mpi_rank==0:
+    def info_back(self, x):
+
+        self.nlog = self.nlog + 1
+        self.itt2 = 0
+
+        if self.itt % self.EVAL_FREQUENCY == 0 and self.mpi_rank == 0:
             end = time.time()
-            cur_loss='%10.3g ('%(self.ltot[self.ltot!=-1].mean())
-            for k in self.ltot[self.ltot!=-1]:
-                cur_loss=cur_loss+'%10.3g '%(k)
-
-            cur_loss=cur_loss+')'
-
-            mess=''
-
+            cur_loss = "%10.3g (" % (self.ltot[self.ltot != -1].mean())
+            for k in self.ltot[self.ltot != -1]:
+                cur_loss = cur_loss + "%10.3g " % (k)
+
+            cur_loss = cur_loss + ")"
+
+            mess = ""
+
             if self.SHOWGPU:
-                info_gpu=self.getgpumem()
+                info_gpu = self.getgpumem()
                 for k in range(info_gpu.shape[0]):
-                    mess=mess+'[GPU%d %.0f/%.0f MB %.0f%%]'%(k,info_gpu[k,0],info_gpu[k,1],info_gpu[k,2])
-
-            print('%sItt %6d L=%s %.3fs %s'%(self.MESSAGE,self.itt,cur_loss,(end-self.start),mess))
+                    mess = mess + "[GPU%d %.0f/%.0f MB %.0f%%]" % (
+                        k,
+                        info_gpu[k, 0],
+                        info_gpu[k, 1],
+                        info_gpu[k, 2],
+                    )
+
+            print(
+                "%sItt %6d L=%s %.3fs %s"
+                % (self.MESSAGE, self.itt, cur_loss, (end - self.start), mess)
+            )
             sys.stdout.flush()
             if self.KEEP_TRACK is not None:
                 print(self.last_info)
                 sys.stdout.flush()
-
+
             self.start = time.time()
-
-        self.itt=self.itt+1
-
+
+        self.itt = self.itt + 1
+
    # ---------------------------------------------−---------
-    def calc_grad(self,in_x):
-
-        g_tot=None
-        l_tot=0.0
+    def calc_grad(self, in_x):
+
+        g_tot = None
+        l_tot = 0.0
 
-        if self.do_all_noise and self.totalsz>self.batchsz:
-            nstep=self.totalsz//self.batchsz
+        if self.do_all_noise and self.totalsz > self.batchsz:
+            nstep = self.totalsz // self.batchsz
         else:
-            nstep=1
+            nstep = 1
+
+        x = self.operation.backend.bk_reshape(
+            self.operation.backend.bk_cast(in_x), self.oshape
+        )
+
+        self.l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
 
-        x=self.operation.backend.bk_cast(self.operation.backend.bk_reshape(in_x,self.oshape))
-
-        self.l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-
        for istep in range(nstep):
-
+
            for k in range(self.number_of_loss):
                if self.loss_class[k].batch is None:
-                    l_batch=None
+                    l_batch = None
                else:
-                    l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,istep)
+                    l_batch = self.loss_class[k].batch(
+                        self.loss_class[k].batch_data, istep
+                    )
 
                if self.KEEP_TRACK is not None:
-                    l,g,linfo=self.bk.loss(x,l_batch,self.loss_class[k])
-                    self.last_info=self.KEEP_TRACK(linfo,self.mpi_rank,add=True)
+                    l_loss, g, linfo = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
+                    self.last_info = self.KEEP_TRACK(linfo, self.mpi_rank, add=True)
                else:
-                    l,g=self.bk.loss(x,l_batch,self.loss_class[k])
+                    l_loss, g = self.bk.loss(
+                        x, l_batch, self.loss_class[k], self.KEEP_TRACK
+                    )
 
                if g_tot is None:
-                    g_tot=g
+                    g_tot = g
                else:
-                    g_tot=g_tot+g
+                    g_tot = g_tot + g
 
-                l_tot=l_tot+l.numpy()
+                l_tot = l_tot + l_loss.numpy()
 
-                if self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]==-1:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=l.numpy()/nstep
+                if self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] == -1:
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        l_loss.numpy() / nstep
+                    )
                else:
-                    self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]=self.l_log[self.mpi_rank*self.MAXNUMLOSS+k]+l.numpy()/nstep
-
-        grd_mask=self.grd_mask
-
+                    self.l_log[self.mpi_rank * self.MAXNUMLOSS + k] = (
+                        self.l_log[self.mpi_rank * self.MAXNUMLOSS + k]
+                        + l_loss.numpy() / nstep
+                    )
+
+        grd_mask = self.grd_mask
+
        if grd_mask is not None:
-            g_tot=grd_mask*g_tot.numpy()
+            g_tot = grd_mask * g_tot.numpy()
        else:
-            g_tot=g_tot.numpy()
-
-        g_tot[np.isnan(g_tot)]=0.0
+            g_tot = g_tot.numpy()
+
+        g_tot[np.isnan(g_tot)] = 0.0
 
-        self.imin=self.imin+self.batchsz
+        self.imin = self.imin + self.batchsz
 
-        if self.mpi_size==1:
-            self.ltot=self.l_log
+        if self.mpi_size == 1:
+            self.ltot = self.l_log
        else:
-            local_log=(self.l_log).astype('float64')
-            self.ltot=np.zeros(self.l_log.shape,dtype='float64')
-            self.comm.Allreduce((local_log,self.MPI.DOUBLE),(self.ltot,self.MPI.DOUBLE))
-
-        if self.mpi_size==1:
-            grad=g_tot
+            local_log = (self.l_log).astype("float64")
+            self.ltot = np.zeros(self.l_log.shape, dtype="float64")
+            self.comm.Allreduce(
+                (local_log, self.MPI.DOUBLE), (self.ltot, self.MPI.DOUBLE)
+            )
+
+        if self.mpi_size == 1:
+            grad = g_tot
        else:
-
-            if g_tot.dtype=='complex64' or g_tot.dtype=='complex128':
-                grad=np.zeros(self.oshape,dtype=gtot.dtype)
+            if self.operation.backend.bk_is_complex(g_tot):
+                grad = np.zeros(self.oshape, dtype=g_tot.dtype)
 
-                self.comm.Allreduce((g_tot),(grad))
+                self.comm.Allreduce((g_tot), (grad))
            else:
-                grad=np.zeros(self.oshape,dtype='float64')
+                grad = np.zeros(self.oshape, dtype="float64")
 
-                self.comm.Allreduce((g_tot.astype('float64'),self.MPI.DOUBLE),
-                                    (grad,self.MPI.DOUBLE))
-
-        if self.nlog==self.history.shape[0]:
-            new_log=np.zeros([self.history.shape[0]*2])
-            new_log[0:self.nlog]=self.history
-            self.history=new_log
+                self.comm.Allreduce(
+                    (g_tot.astype("float64"), self.MPI.DOUBLE), (grad, self.MPI.DOUBLE)
+                )
 
-        l_tot=self.ltot[self.ltot!=-1].mean()
-
-        self.history[self.nlog]=l_tot
+        if self.nlog == self.history.shape[0]:
+            new_log = np.zeros([self.history.shape[0] * 2])
+            new_log[0 : self.nlog] = self.history
+            self.history = new_log
 
-        g_tot=grad.flatten()
+        l_tot = self.ltot[self.ltot != -1].mean()
 
-        if g_tot.dtype=='complex64' or g_tot.dtype=='complex128':
-            return l_tot.astype('float64'),g_tot
-
-        return l_tot.astype('float64'),g_tot.astype('float64')
+        self.history[self.nlog] = l_tot
+
+        g_tot = grad.flatten()
+
+        if self.operation.backend.bk_is_complex(g_tot):
+            return l_tot.astype("float64"), g_tot
+
+        return l_tot.astype("float64"), g_tot.astype("float64")
 
    # ---------------------------------------------−---------
-    def xtractmap(self,x,axis):
-        x=self.operation.backend.bk_reshape(x,self.oshape)
-
+    def xtractmap(self, x, axis):
+        x = self.operation.backend.bk_reshape(x, self.oshape)
+
        return x
 
    # ---------------------------------------------−---------
-    def run(self,
-            in_x,
-            NUM_EPOCHS = 100,
-            DECAY_RATE=0.95,
-            EVAL_FREQUENCY = 100,
-            DEVAL_STAT_FREQUENCY = 1000,
-            NUM_STEP_BIAS = 1,
-            LEARNING_RATE = 0.03,
-            EPSILON = 1E-7,
-            KEEP_TRACK=None,
-            grd_mask=None,
-            SHOWGPU=False,
-            MESSAGE='',
-            factr=10.0,
-            batchsz=1,
-            totalsz=1,
-            do_lbfgs=True,
-            axis=0):
-
-        self.KEEP_TRACK=KEEP_TRACK
-        self.track={}
-        self.ntrack=0
-        self.eta=LEARNING_RATE
-        self.epsilon=EPSILON
+    def run(
+        self,
+        in_x,
+        NUM_EPOCHS=100,
+        DECAY_RATE=0.95,
+        EVAL_FREQUENCY=100,
+        DEVAL_STAT_FREQUENCY=1000,
+        NUM_STEP_BIAS=1,
+        LEARNING_RATE=0.03,
+        EPSILON=1e-7,
+        KEEP_TRACK=None,
+        grd_mask=None,
+        SHOWGPU=False,
+        MESSAGE="",
+        factr=10.0,
+        batchsz=1,
+        totalsz=1,
+        do_lbfgs=True,
+        axis=0,
+    ):
+
+        self.KEEP_TRACK = KEEP_TRACK
+        self.track = {}
+        self.ntrack = 0
+        self.eta = LEARNING_RATE
+        self.epsilon = EPSILON
        self.decay_rate = DECAY_RATE
-        self.nlog=0
-        self.itt2=0
-        self.batchsz=batchsz
-        self.totalsz=totalsz
-        self.grd_mask=grd_mask
-        self.EVAL_FREQUENCY=EVAL_FREQUENCY
-        self.MESSAGE=MESSAGE
-        self.SHOWGPU=SHOWGPU
-        self.axis=axis
-        self.in_x_nshape=in_x.shape[0]
-
-        """
-        if do_lbfgs and (in_x.dtype=='complex64' or in_x.dtype=='complex128'):
-            print('L_BFGS minimisation not yet implemented for acomplex array, use default FOSCAT minimizer or convert your problem to float32 or float64')
-            exit(0)
-        """
-        np.random.seed(self.mpi_rank*7+1234)
-
-        x=in_x
-
-        self.curr_gpu=self.curr_gpu+self.mpi_rank
-
-        if self.mpi_size>1:
+        self.nlog = 0
+        self.itt2 = 0
+        self.batchsz = batchsz
+        self.totalsz = totalsz
+        self.grd_mask = grd_mask
+        self.EVAL_FREQUENCY = EVAL_FREQUENCY
+        self.MESSAGE = MESSAGE
+        self.SHOWGPU = SHOWGPU
+        self.axis = axis
+        self.in_x_nshape = in_x.shape[0]
+        self.seed = 1234
+
+        np.random.seed(self.mpi_rank * 7 + 1234)
+
+        x = in_x
+
+        self.curr_gpu = self.curr_gpu + self.mpi_rank
+
+        if self.mpi_size > 1:
            from mpi4py import MPI
-
 
            comm = MPI.COMM_WORLD
-            self.comm=comm
-            self.MPI=MPI
-            if self.mpi_rank==0:
-                print('Work with MPI')
+            self.comm = comm
+            self.MPI = MPI
+            if self.mpi_rank == 0:
+                print("Work with MPI")
                sys.stdout.flush()
-
-        if self.mpi_rank==0 and SHOWGPU:
+
+        if self.mpi_rank == 0 and SHOWGPU:
            # start thread that catch GPU information
            try:
-                self.gpu_thrd = Thread(target=self.get_gpu, args=(self.event,1,))
+                self.gpu_thrd = Thread(
+                    target=self.get_gpu,
+                    args=(
+                        self.event,
+                        1,
+                    ),
+                )
                self.gpu_thrd.start()
            except:
                print("Error: unable to start thread for GPU survey")
-
-        start = time.time()
-
-        if self.mpi_size>1:
-            num_loss=np.zeros([1],dtype='int32')
-            total_num_loss=np.zeros([1],dtype='int32')
-            num_loss[0]=self.number_of_loss
-            comm.Allreduce((num_loss,MPI.INT),(total_num_loss,MPI.INT))
-            total_num_loss=total_num_loss[0]
+
+        # start = time.time()
+
+        if self.mpi_size > 1:
+            num_loss = np.zeros([1], dtype="int32")
+            total_num_loss = np.zeros([1], dtype="int32")
+            num_loss[0] = self.number_of_loss
+            comm.Allreduce((num_loss, MPI.INT), (total_num_loss, MPI.INT))
+            total_num_loss = total_num_loss[0]
        else:
-            total_num_loss=self.number_of_loss
-
-        if self.mpi_rank==0:
-            print('Total number of loss ',total_num_loss)
+            total_num_loss = self.number_of_loss
+
+        if self.mpi_rank == 0:
+            print("Total number of loss ", total_num_loss)
            sys.stdout.flush()
-
-        l_log=np.zeros([self.mpi_size*self.MAXNUMLOSS],dtype='float32')
-        l_log[self.mpi_rank*self.MAXNUMLOSS:(self.mpi_rank+1)*self.MAXNUMLOSS]=-1.0
-        self.ltot=l_log.copy()
-        self.l_log=l_log
-
-        self.imin=0
-        self.start=time.time()
-        self.itt=0
-
-        self.oshape=list(x.shape)
-
-        if not isinstance(x,np.ndarray):
-            x=x.numpy()
-
-        x=x.flatten()
 
-        self.do_all_noise=False
+        l_log = np.zeros([self.mpi_size * self.MAXNUMLOSS], dtype="float32")
+        l_log[
+            self.mpi_rank * self.MAXNUMLOSS : (self.mpi_rank + 1) * self.MAXNUMLOSS
+        ] = -1.0
+        self.ltot = l_log.copy()
+        self.l_log = l_log
 
-        self.do_all_noise=True
+        self.imin = 0
+        self.start = time.time()
+        self.itt = 0
 
-        self.noise_idx=None
+        self.oshape = list(x.shape)
 
-        for k in range(self.number_of_loss):
-            if self.loss_class[k].batch is not None:
-                l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-
-        l_tot,g_tot=self.calc_grad(x)
+        if not isinstance(x, np.ndarray):
+            x = x.numpy()
+
+        x = x.flatten()
+
+        self.do_all_noise = False
+
+        self.do_all_noise = True
+
+        self.noise_idx = None
+
+        # for k in range(self.number_of_loss):
+        #     if self.loss_class[k].batch is not None:
+        #         l_batch = self.loss_class[k].batch(
+        #             self.loss_class[k].batch_data, 0, init=True
+        #         )
+
+        l_tot, g_tot = self.calc_grad(x)
 
        self.info_back(x)
 
-        maxitt=NUM_EPOCHS
+        maxitt = NUM_EPOCHS
 
-        start_x=x.copy()
+        # start_x = x.copy()
 
        for iteration in range(NUM_STEP_BIAS):
 
-            x,l,i=opt.fmin_l_bfgs_b(self.calc_grad,
-                                    x.astype('float64'),
-                                    callback=self.info_back,
-                                    pgtol=1E-32,
-                                    factr=factr,
-                                    maxiter=maxitt)
+            x, loss, i = opt.fmin_l_bfgs_b(
+                self.calc_grad,
+                x.astype("float64"),
+                callback=self.info_back,
+                pgtol=1e-32,
+                factr=factr,
+                maxiter=maxitt,
+            )
 
            # update bias input data
-            if iteration<NUM_STEP_BIAS-1:
-                if self.mpi_rank==0:
-                    print('%s Hessian restart'%(self.MESSAGE))
+            if iteration < NUM_STEP_BIAS - 1:
+                # if self.mpi_rank==0:
+                #     print('%s Hessian restart'%(self.MESSAGE))
 
-                omap=self.xtractmap(x,axis)
+                omap = self.xtractmap(x, axis)
 
                for k in range(self.number_of_loss):
                    if self.loss_class[k].batch_update is not None:
-                        self.loss_class[k].batch_update(self.loss_class[k].batch_data,omap)
-                        l_batch=self.loss_class[k].batch(self.loss_class[k].batch_data,0,init=True)
-                        #x=start_x.copy()
-
-
-        if self.mpi_rank==0 and SHOWGPU:
+                        self.loss_class[k].batch_update(
+                            self.loss_class[k].batch_data, omap
+                        )
+                        # if self.loss_class[k].batch is not None:
+                        #     l_batch = self.loss_class[k].batch(
+                        #         self.loss_class[k].batch_data, 0, init=True
+                        #     )
+                        # x=start_x.copy()
+
+        if self.mpi_rank == 0 and SHOWGPU:
            self.stop_synthesis()
 
        if self.KEEP_TRACK is not None:
-            self.last_info=self.KEEP_TRACK(None,self.mpi_rank,add=False)
+            self.last_info = self.KEEP_TRACK(None, self.mpi_rank, add=False)
 
-        x=self.xtractmap(x,axis)
-        return(x)
+        x = self.xtractmap(x, axis)
+        return x
 
    def get_history(self):
-        return(self.history[0:self.nlog])
+        return self.history[0 : self.nlog]
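
For readers who only use the library, the refactor above keeps the public workflow intact: a Loss wraps a user loss function plus a scattering operator, a Synthesis drives the L-BFGS fit, each Loss now receives an identifier through set_id_loss(), and the backend helper (foscat.loss_backend_tens or foscat.loss_backend_torch) is selected from the operator's BACKEND. The sketch below is illustrative only: the loss body, the scat_cov operator construction, and the toy map are assumptions rather than part of this diff; only the Loss(...), Synthesis(...), run(...) and get_history() signatures come from the code above.

```python
# Minimal usage sketch (assumed operator frontend and loss body; API names from the diff).
import numpy as np

import foscat.scat_cov as sc        # assumed scattering-operator frontend
import foscat.Synthesis as synthe


def my_loss(x, scat_operator, args):
    # args is the tuple of extra *param given to Loss(); here a reference statistic
    ref = args[0]
    learn = scat_operator.eval(x)                     # assumed operator method
    return scat_operator.reduce_sum(scat_operator.square(learn - ref))


scat_op = sc.funct(KERNELSZ=3)                        # assumed constructor/options
imap = np.random.randn(12 * 16**2).astype("float32")  # toy HEALPix-sized map (nside=16)
ref = scat_op.eval(imap)

loss = synthe.Loss(my_loss, scat_op, ref, name="toy")
sy = synthe.Synthesis([loss])       # each Loss gets its index via set_id_loss()

omap = sy.run(imap, NUM_EPOCHS=300, EVAL_FREQUENCY=50)
history = sy.get_history()          # mean loss recorded at each evaluation
```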