deepliif-1.1.10-py3-none-any.whl → deepliif-1.1.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deepliif/util/__init__.py CHANGED
@@ -14,17 +14,18 @@ from .visualizer import Visualizer
 from ..postprocessing import imadjust
 import cv2
 
+import pickle
+import sys
+
 import bioformats
 import javabridge
 import bioformats.omexml as ome
 import tifffile as tf
 
-import pickle
-import sys
 
 excluding_names = ['Hema', 'DAPI', 'DAPILap2', 'Ki67', 'Seg', 'Marked', 'SegRefined', 'SegOverlaid', 'Marker', 'Lap2']
 
 # Image extensions to consider
-image_extensions = ['.png', '.jpg', '.tif', '.jpeg', '.svs']
+image_extensions = ['.png', '.jpg', '.tif', '.jpeg']
 
 
 def allowed_file(filename):
@@ -118,6 +119,211 @@ def stitch_tile(img, tile, tile_size, overlap_size, i, j):
     img.paste(tile, (i * tile_size, j * tile_size))
 
 
+class InferenceTiler:
+    """
+    Iterable class to tile image(s) and stitch result tiles together.
+
+    To perform inference on a large image, that image will need to be
+    tiled into smaller tiles that can be run individually and then
+    stitched back together. This class wraps the functionality as an
+    iterable object that can accept a single image or list of images
+    if multiple images are taken as input for inference.
+
+    An overlap size can be specified so that neighboring tiles will
+    overlap at the edges, helping to reduce seams or other artifacts
+    near the edge of a tile. Padding of a solid color around the
+    perimeter of the tile is also possible, if needed. The specified
+    tile size includes this overlap and pad sizes, so a tile size of
+    512 with an overlap size of 32 and pad size of 16 would have a
+    central area of 416 pixels that are stitched into the result image.
+
+    Example Usage
+    -------------
+    tiler = InferenceTiler(img, 512, 32)
+    for tile in tiler:
+        result_tiles = infer(tile)
+        tiler.stitch(result_tiles)
+    images = tiler.results()
+    """
+
+    def __init__(self, orig, tile_size, overlap_size=0, pad_size=0, pad_color=(255, 255, 255)):
+        """
+        Initialize for tiling an image or list of images.
+
+        Parameters
+        ----------
+        orig : Image | list(Image)
+            Original image or list of images to be tiled.
+        tile_size: int
+            Size (width and height) of the tiles to be generated.
+        overlap_size: int [default: 0]
+            Amount of overlap on each side of the tile.
+        pad_size: int [default: 0]
+            Amount of solid color padding around perimeter of tile.
+        pad_color: tuple(int, int, int) [default: (255,255,255)]
+            RGB color to use for padding.
+        """
+
+        if tile_size <= 0:
+            raise ValueError('InferenceTiler input tile_size must be positive and non-zero')
+        if overlap_size < 0:
+            raise ValueError('InferenceTiler input overlap_size must be positive or zero')
+        if pad_size < 0:
+            raise ValueError('InferenceTiler input pad_size must be positive or zero')
+
+        self.single_orig = not type(orig) is list
+        if self.single_orig:
+            orig = [orig]
+
+        for i in range(1, len(orig)):
+            if orig[i].size != orig[0].size:
+                raise ValueError('InferenceTiler input images do not have the same size.')
+        self.orig_width = orig[0].width
+        self.orig_height = orig[0].height
+
+        # patch size to extract from input image, which is then padded to tile size
+        patch_size = tile_size - (2 * pad_size)
+
+        # make sure width and height are both at least patch_size
+        if orig[0].width < patch_size:
+            for i in range(len(orig)):
+                while orig[i].width < patch_size:
+                    mirrored = ImageOps.mirror(orig[i])
+                    orig[i] = ImageOps.expand(orig[i], (0, 0, orig[i].width, 0))
+                    orig[i].paste(mirrored, (mirrored.width, 0))
+                orig[i] = orig[i].crop((0, 0, patch_size, orig[i].height))
+        if orig[0].height < patch_size:
+            for i in range(len(orig)):
+                while orig[i].height < patch_size:
+                    flipped = ImageOps.flip(orig[i])
+                    orig[i] = ImageOps.expand(orig[i], (0, 0, 0, orig[i].height))
+                    orig[i].paste(flipped, (0, flipped.height))
+                orig[i] = orig[i].crop((0, 0, orig[i].width, patch_size))
+        self.image_width = orig[0].width
+        self.image_height = orig[0].height
+
+        overlap_width = 0 if patch_size >= self.image_width else overlap_size
+        overlap_height = 0 if patch_size >= self.image_height else overlap_size
+        center_width = patch_size - (2 * overlap_width)
+        center_height = patch_size - (2 * overlap_height)
+        if center_width <= 0 or center_height <= 0:
+            raise ValueError('InferenceTiler combined overlap_size and pad_size are too large')
+
+        self.c0x = pad_size  # crop offset for left of non-pad content in result tile
+        self.c0y = pad_size  # crop offset for top of non-pad content in result tile
+        self.c1x = overlap_width + pad_size  # crop offset for left of center region in result tile
+        self.c1y = overlap_height + pad_size  # crop offset for top of center region in result tile
+        self.c2x = patch_size - overlap_width + pad_size  # crop offset for right of center region in result tile
+        self.c2y = patch_size - overlap_height + pad_size  # crop offset for bottom of center region in result tile
+        self.c3x = patch_size + pad_size  # crop offset for right of non-pad content in result tile
+        self.c3y = patch_size + pad_size  # crop offset for bottom of non-pad content in result tile
+        self.p1x = overlap_width  # paste offset for left of center region w.r.t (x,y) coord
+        self.p1y = overlap_height  # paste offset for top of center region w.r.t (x,y) coord
+        self.p2x = patch_size - overlap_width  # paste offset for right of center region w.r.t (x,y) coord
+        self.p2y = patch_size - overlap_height  # paste offset for bottom of center region w.r.t (x,y) coord
+
+        self.overlap_width = overlap_width
+        self.overlap_height = overlap_height
+        self.patch_size = patch_size
+        self.center_width = center_width
+        self.center_height = center_height
+
+        self.orig = orig
+        self.tile_size = tile_size
+        self.pad_size = pad_size
+        self.pad_color = pad_color
+        self.res = {}
+
+    def __iter__(self):
+        """
+        Generate the tiles as an iterable.
+
+        Tiles are created and iterated over from top left to bottom
+        right, going across the rows. The yielded tile(s) match the
+        type of the original input when initialized (either a single
+        image or a list of images in the same order as initialized).
+        The (x, y) coordinate of the current tile is maintained
+        internally for use in the stitch function.
+        """
+
+        for y in range(0, self.image_height, self.center_height):
+            for x in range(0, self.image_width, self.center_width):
+                if x + self.patch_size > self.image_width:
+                    x = self.image_width - self.patch_size
+                if y + self.patch_size > self.image_height:
+                    y = self.image_height - self.patch_size
+                self.x = x
+                self.y = y
+                tiles = [im.crop((x, y, x + self.patch_size, y + self.patch_size)) for im in self.orig]
+                if self.pad_size != 0:
+                    tiles = [ImageOps.expand(t, self.pad_size, self.pad_color) for t in tiles]
+                yield tiles[0] if self.single_orig else tiles
+
+    def stitch(self, result_tiles):
+        """
+        Stitch result tiles into the result images.
+
+        The key names for the dictionary of result tiles are used to
+        stitch each tile into its corresponding final image in the
+        results attribute. If a result image does not exist for a
+        result tile key name, then it will be created. The result tiles
+        are stitched at the location from which the last iterated tile
+        was extracted.
+
+        Parameters
+        ----------
+        result_tiles : dict(str: Image)
+            Dictionary of result tiles from the inference.
+        """
+
+        for k, tile in result_tiles.items():
+            if k not in self.res:
+                self.res[k] = Image.new('RGB', (self.image_width, self.image_height))
+            if tile.size != (self.tile_size, self.tile_size):
+                tile = tile.resize((self.tile_size, self.tile_size))
+            self.res[k].paste(tile.crop((self.c1x, self.c1y, self.c2x, self.c2y)), (self.x + self.p1x, self.y + self.p1y))
+
+            # top left corner
+            if self.x == 0 and self.y == 0:
+                self.res[k].paste(tile.crop((self.c0x, self.c0y, self.c1x, self.c1y)), (self.x, self.y))
+            # top row
+            if self.y == 0:
+                self.res[k].paste(tile.crop((self.c1x, self.c0y, self.c2x, self.c1y)), (self.x + self.p1x, self.y))
+            # top right corner
+            if self.x == self.image_width - self.patch_size and self.y == 0:
+                self.res[k].paste(tile.crop((self.c2x, self.c0y, self.c3x, self.c1y)), (self.x + self.p2x, self.y))
+            # left column
+            if self.x == 0:
+                self.res[k].paste(tile.crop((self.c0x, self.c1y, self.c1x, self.c2y)), (self.x, self.y + self.p1y))
+            # right column
+            if self.x == self.image_width - self.patch_size:
+                self.res[k].paste(tile.crop((self.c2x, self.c1y, self.c3x, self.c2y)), (self.x + self.p2x, self.y + self.p1y))
+            # bottom left corner
+            if self.x == 0 and self.y == self.image_height - self.patch_size:
+                self.res[k].paste(tile.crop((self.c0x, self.c2y, self.c1x, self.c3y)), (self.x, self.y + self.p2y))
+            # bottom row
+            if self.y == self.image_height - self.patch_size:
+                self.res[k].paste(tile.crop((self.c1x, self.c2y, self.c2x, self.c3y)), (self.x + self.p1x, self.y + self.p2y))
+            # bottom right corner
+            if self.x == self.image_width - self.patch_size and self.y == self.image_height - self.patch_size:
+                self.res[k].paste(tile.crop((self.c2x, self.c2y, self.c3x, self.c3y)), (self.x + self.p2x, self.y + self.p2y))
+
+    def results(self):
+        """
+        Return a dictionary of result images.
+
+        The keys for the result images are the same as those used for
+        the result tiles in the stitch function. This function should
+        only be called once, since the stitched images will be cropped
+        if the original image size was less than the patch size.
+        """
+
+        if self.orig_width != self.image_width or self.orig_height != self.image_height:
+            return {k: im.crop((0, 0, self.orig_width, self.orig_height)) for k, im in self.res.items()}
+        else:
+            return {k: im for k, im in self.res.items()}
+
+
 def calculate_background_mean_value(img):
     img = cv2.fastNlMeansDenoisingColored(np.array(img), None, 10, 10, 7, 21)
     img = np.array(img, dtype=float)
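
A minimal usage sketch of the new InferenceTiler, following the Example Usage in its docstring; `run_model` is a hypothetical stand-in for whatever inference callable returns a dict of result tiles (e.g. {'Seg': Image, 'Marker': Image}), and the file paths are placeholders:

    from PIL import Image
    from deepliif.util import InferenceTiler

    img = Image.open('input.png').convert('RGB')  # placeholder input path

    # 512x512 tiles with a 32-pixel overlap on each side to reduce seam artifacts
    tiler = InferenceTiler(img, tile_size=512, overlap_size=32)
    for tile in tiler:
        result_tiles = run_model(tile)  # hypothetical: returns {modality name: PIL Image}
        tiler.stitch(result_tiles)

    # one stitched image per modality key, cropped back to the original size
    for name, im in tiler.results().items():
        im.save(f'{name}.png')
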
@@ -235,6 +441,87 @@ def get_information(filename):
     return size_x, size_y, size_z, size_c, size_t, pixel_type
 
 
+
+
+def write_results_to_pickle_file(output_addr, results):
+    """
+    This function writes data into the pickle file.
+    :param output_addr: The address of the pickle file to write data into.
+    :param results: The data to be written into the pickle file.
+    :return:
+    """
+    pickle_obj = open(output_addr, "wb")
+    pickle.dump(results, pickle_obj)
+    pickle_obj.close()
+
+
+def read_results_from_pickle_file(input_addr):
+    """
+    This function reads data from a pickle file and returns it.
+    :param input_addr: The address to the pickle file.
+    :return: The data inside pickle file.
+    """
+    pickle_obj = open(input_addr, "rb")
+    results = pickle.load(pickle_obj)
+    pickle_obj.close()
+    return results
+
+def test_diff_original_serialized(model_original,model_serialized,example,verbose=0):
+    threshold = 10
+
+    orig_res = model_original(example)
+    if verbose > 0:
+        print('Original:')
+        print(orig_res.shape)
+        print(orig_res[0, 0:10])
+        print('min abs value:{}'.format(torch.min(torch.abs(orig_res))))
+
+    ts_res = model_serialized(example)
+    if verbose > 0:
+        print('Torchscript:')
+        print(ts_res.shape)
+        print(ts_res[0, 0:10])
+        print('min abs value:{}'.format(torch.min(torch.abs(ts_res))))
+
+    abs_diff = torch.abs(orig_res-ts_res)
+    if verbose > 0:
+        print('Dif sum:')
+        print(torch.sum(abs_diff))
+        print('max dif:{}'.format(torch.max(abs_diff)))
+
+    assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}"
+
+def disable_batchnorm_tracking_stats(model):
+    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
+    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
+    # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
+    for m in model.modules():
+        for child in m.children():
+            if type(child) == torch.nn.BatchNorm2d:
+                child.track_running_stats = False
+                child.running_mean_backup = child.running_mean
+                child.running_mean = None
+                child.running_var_backup = child.running_var
+                child.running_var = None
+    return model
+
+def enable_batchnorm_tracking_stats(model):
+    """
+    This is needed during training when val set loss/metrics calculation is enabled.
+    In this case, we need to switch to eval mode for inference, which triggers
+    disable_batchnorm_tracking_stats(). After the evaluation, the model should be
+    set back to train mode, where running stats are restored for batchnorm layers.
+    """
+    for m in model.modules():
+        for child in m.children():
+            if type(child) == torch.nn.BatchNorm2d:
+                child.track_running_stats = True
+                assert hasattr(child, 'running_mean_backup') and hasattr(child, 'running_var_backup'), 'enable_batchnorm_tracking_stats() is supposed to be executed after disable_batchnorm_tracking_stats() is applied'
+                child.running_mean = child.running_mean_backup
+                child.running_var = child.running_var_backup
+    return model
+
+
 def write_big_tiff_file(output_addr, img, tile_size):
     """
     This function write the image into a big tiff file using the tiling and compression.
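
The two batchnorm helpers above are meant to be used as a matched pair around a validation pass, as described in the enable_batchnorm_tracking_stats docstring. A sketch of that pattern, assuming a `model` containing BatchNorm2d layers and a placeholder `compute_val_metrics` validation loop (neither name is deepliif API):

    import torch
    from deepliif.util import disable_batchnorm_tracking_stats, enable_batchnorm_tracking_stats

    # switch to eval mode for validation, but keep BatchNorm2d layers from
    # relying on running statistics; the helper backs the stats up first
    model.eval()
    model = disable_batchnorm_tracking_stats(model)

    with torch.no_grad():
        val_metrics = compute_val_metrics(model)  # placeholder validation loop

    # restore the backed-up running stats before resuming training
    model = enable_batchnorm_tracking_stats(model)
    model.train()
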
@@ -376,64 +663,3 @@ def write_ome_tiff_file_array(results_array, output_addr, size_t, size_z, size_c
         output_addr,
         SizeT=size_t, SizeZ=size_z, SizeC=len(channel_names), SizeX=size_x, SizeY=size_y,
         channel_names=channel_names)
-
-
-def write_results_to_pickle_file(output_addr, results):
-    """
-    This function writes data into the pickle file.
-    :param output_addr: The address of the pickle file to write data into.
-    :param results: The data to be written into the pickle file.
-    :return:
-    """
-    pickle_obj = open(output_addr, "wb")
-    pickle.dump(results, pickle_obj)
-    pickle_obj.close()
-
-
-def read_results_from_pickle_file(input_addr):
-    """
-    This function reads data from a pickle file and returns it.
-    :param input_addr: The address to the pickle file.
-    :return: The data inside pickle file.
-    """
-    pickle_obj = open(input_addr, "rb")
-    results = pickle.load(pickle_obj)
-    pickle_obj.close()
-    return results
-
-def test_diff_original_serialized(model_original,model_serialized,example,verbose=0):
-    threshold = 10
-
-    orig_res = model_original(example)
-    if verbose > 0:
-        print('Original:')
-        print(orig_res.shape)
-        print(orig_res[0, 0:10])
-        print('min abs value:{}'.format(torch.min(torch.abs(orig_res))))
-
-    ts_res = model_serialized(example)
-    if verbose > 0:
-        print('Torchscript:')
-        print(ts_res.shape)
-        print(ts_res[0, 0:10])
-        print('min abs value:{}'.format(torch.min(torch.abs(ts_res))))
-
-    abs_diff = torch.abs(orig_res-ts_res)
-    if verbose > 0:
-        print('Dif sum:')
-        print(torch.sum(abs_diff))
-        print('max dif:{}'.format(torch.max(abs_diff)))
-
-    assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}"
-
-def disable_batchnorm_tracking_stats(model):
-    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
-    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
-    # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
-    for m in model.modules():
-        for child in m.children():
-            if type(child) == torch.nn.BatchNorm2d:
-                child.track_running_stats = False
-                child.running_mean = None
-                child.running_var = None
-    return model
deepliif/util/visualizer.py CHANGED
@@ -165,6 +165,7 @@ class Visualizer():
         if ncols > 0:  # show all the images in one visdom panel
             ncols = min(ncols, len(visuals))
             h, w = next(iter(visuals.values())).shape[:2]
+
             table_css = """<style>
                 table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                 table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
@@ -176,13 +177,16 @@ class Visualizer():
             images = []
             idx = 0
             for label, image in visuals.items():
-                image_numpy = util.tensor2im(image)
-                label_html_row += '<td>%s</td>' % label
-                images.append(image_numpy.transpose([2, 0, 1]))
-                idx += 1
-                if idx % ncols == 0:
-                    label_html += '<tr>%s</tr>' % label_html_row
-                    label_html_row = ''
+                if image.shape[1] != 3:
+                    pass
+                else:
+                    image_numpy = util.tensor2im(image)
+                    label_html_row += '<td>%s</td>' % label
+                    images.append(image_numpy.transpose([2, 0, 1]))
+                    idx += 1
+                    if idx % ncols == 0:
+                        label_html += '<tr>%s</tr>' % label_html_row
+                        label_html_row = ''
 
             white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
             while idx % ncols != 0:
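
The new guard skips any visuals entry that is not a 3-channel tensor, presumably because util.tensor2im and the visdom grid expect RGB input; the tensors here are batched NCHW, so index 1 is the channel count. A standalone illustration of the check (not deepliif code):

    import torch

    rgb = torch.zeros(1, 3, 512, 512)  # N, C, H, W -> 3 channels, kept
    seg = torch.zeros(1, 1, 512, 512)  # single-channel map, skipped

    print(rgb.shape[1] != 3)  # False -> rendered in the grid
    print(seg.shape[1] != 3)  # True  -> passed over by the new check
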
@@ -191,6 +195,7 @@ class Visualizer():
                     idx += 1
                 if label_html_row != '':
                     label_html += '<tr>%s</tr>' % label_html_row
+
 
                 try:
                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
@@ -248,6 +253,10 @@ class Visualizer():
         # if having 2 processes, each process obtains 50% of the data (effective dataset_size divided by half), the effective counter ratio shall multiply by 2 to compensate that
         n_proc = int(os.getenv('WORLD_SIZE',1))
         counter_ratio = counter_ratio * n_proc
+
+        self.plot_data_update_train = False
+        self.plot_data_update_val = False
+        self.plot_data_update_metrics = False
 
         if self.remote:
             fn = 'plot_current_losses.pickle'
@@ -263,20 +272,98 @@ class Visualizer():
                 exec(f'{self.remote_transfer_cmd_function}("{path_source}")')
         else:
             if not hasattr(self, 'plot_data'):
-                self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
-            self.plot_data['X'].append(epoch + counter_ratio)
-            self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+                self.plot_data = {'X': [], 'X_val':[], 'X_metrics':[],
+                                  'Y': [], 'Y_val':[], 'Y_metrics':[],
+                                  'legend': [], 'legend_val': [], 'legend_metrics':[]}
+                for k in list(losses.keys()):
+                    if k.endswith('_val'):
+                        self.plot_data['legend_val'].append(k)
+                    elif k.startswith(('G_','D_')):
+                        self.plot_data['legend'].append(k)
+                    else:
+                        self.plot_data['legend_metrics'].append(k)
+
+            # check if all names in losses dict have been seen
+            # currently we assume the three types of metrics (train loss, val loss, other metrics) can come into the losses dict
+            # at any step, but each type will join or leave the dict as a whole (i.e., train loss metrics will either all appear or all be missing)
+            for k in list(losses.keys()):
+                if k.endswith('_val'):
+                    if k not in self.plot_data['legend_val']:
+                        self.plot_data['legend_val'].append(k)
+                elif k.startswith(('G_','D_')):
+                    if k not in self.plot_data['legend']:
+                        self.plot_data['legend'].append(k)
+                else:
+                    if k not in self.plot_data['legend_metrics']:
+                        self.plot_data['legend_metrics'].append(k)
+
+            # update training loss
+            print('update training loss')
+            if len(self.plot_data['legend']) > 0:
+                if self.plot_data['legend'][0] in losses:
+                    self.plot_data['X'].append(epoch + counter_ratio)
+                    self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
+                    self.plot_data_update_train = True
+
+            # update validation loss
+            print('update validation loss')
+            if len(self.plot_data['legend_val']) > 0:
+                if self.plot_data['legend_val'][0] in losses:
+                    self.plot_data['X_val'].append(epoch + counter_ratio)
+                    self.plot_data['Y_val'].append([losses[k] for k in self.plot_data['legend_val']])
+                    self.plot_data_update_val = True
+
+            # update other calculated metrics
+            print('update other metrics')
+            if len(self.plot_data['legend_metrics']) > 0:
+                if self.plot_data['legend_metrics'][0] in losses:
+                    self.plot_data['X_metrics'].append(epoch + counter_ratio)
+                    self.plot_data['Y_metrics'].append([losses[k] for k in self.plot_data['legend_metrics']])
+                    self.plot_data_update_metrics = True
 
             try:
-                self.vis.line(
-                    X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
-                    Y=np.array(self.plot_data['Y']),
-                    opts={
-                        'title': self.name + ' loss over time',
-                        'legend': self.plot_data['legend'],
-                        'xlabel': 'epoch',
-                        'ylabel': 'loss'},
-                    win=self.display_id)
+                if self.plot_data_update_train:
+                    print('plotting train loss')
+                    self.vis.line(
+                        X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
+                        Y=np.array(self.plot_data['Y']),
+                        opts={
+                            'title': self.name + ' train loss over time',
+                            'legend': self.plot_data['legend'],
+                            'xlabel': 'epoch',
+                            'ylabel': 'loss'},
+                        win = 'train',
+                        #env=self.display_id
+                    )
+
+                if self.plot_data_update_val:
+                    print('plotting val loss')
+                    self.vis.line(
+                        X=np.stack([np.array(self.plot_data['X_val'])] * len(self.plot_data['legend_val']), 1),
+                        Y=np.array(self.plot_data['Y_val']),
+                        opts={
+                            'title': self.name + ' val loss over time',
+                            'legend': self.plot_data['legend_val'],
+                            'xlabel': 'epoch',
+                            'ylabel': 'loss'},
+                        win = 'val',
+                        #env=self.display_id
+                    )
+
+                if self.plot_data_update_metrics:
+                    print('plotting other metrics')
+                    self.vis.line(
+                        X=np.stack([np.array(self.plot_data['X_metrics'])] * len(self.plot_data['legend_metrics']), 1),
+                        Y=np.array(self.plot_data['Y_metrics']),
+                        opts={
+                            'title': self.name + ' metrics over time',
+                            'legend': self.plot_data['legend_metrics'],
+                            'xlabel': 'epoch',
+                            'ylabel': 'metrics'},
+                        win = 'metrics',
+                        #env=self.display_id
+                    )
+
             except VisdomExceptionBase:
                 self.create_visdom_connections()
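
The new plotting logic groups entries in the losses dict purely by key name: keys ending in '_val' go to the validation window, keys starting with 'G_' or 'D_' go to the train-loss window, and everything else is plotted as a metric. A small standalone illustration of that convention (the key names below are examples, not an exhaustive list of deepliif losses):

    losses = {'G_GAN': 0.7, 'D_real': 0.3, 'G_GAN_val': 0.9, 'cell_count_acc': 0.85}

    legend, legend_val, legend_metrics = [], [], []
    for k in losses:
        if k.endswith('_val'):            # validation losses -> 'val' window
            legend_val.append(k)
        elif k.startswith(('G_', 'D_')):  # generator/discriminator losses -> 'train' window
            legend.append(k)
        else:                             # everything else -> 'metrics' window
            legend_metrics.append(k)

    print(legend)          # ['G_GAN', 'D_real']
    print(legend_val)      # ['G_GAN_val']
    print(legend_metrics)  # ['cell_count_acc']
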