deepliif 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cli.py CHANGED
@@ -11,7 +11,7 @@ from PIL import Image
 
 from deepliif.data import create_dataset, transform
 from deepliif.models import inference, postprocess, compute_overlap, init_nets, DeepLIIFModel, infer_modalities, infer_results_for_wsi
-from deepliif.util import allowed_file, Visualizer, get_information
+from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
 from deepliif.util.util import mkdirs, check_multi_scale
 # from deepliif.util import infer_results_for_wsi
 
@@ -461,21 +461,51 @@ def trainlaunch(**kwargs):
 @cli.command()
 @click.option('--models-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here')
 @click.option('--output-dir', help='saves results here.')
-def serialize(models_dir, output_dir):
+@click.option('--device', default='cpu', type=str, help='device to load model, either cpu or gpu')
+@click.option('--verbose', default=0, type=int, help='verbosity level; values > 0 print the serialization test output')
+def serialize(models_dir, output_dir, device, verbose):
     """Serialize DeepLIIF models using Torchscript
     """
     output_dir = output_dir or models_dir
+    ensure_exists(output_dir)
 
     sample = transform(Image.new('RGB', (512, 512)))
-
+
     with click.progressbar(
             init_nets(models_dir, eager_mode=True).items(),
             label='Tracing nets',
             item_show_func=lambda n: n[0] if n else n
     ) as bar:
         for name, net in bar:
-            traced_net = torch.jit.trace(net, sample)
+            # the model should be in eval mode so that there is no randomness in tracing brought by dropout etc. layers
+            # https://github.com/pytorch/pytorch/issues/23999#issuecomment-747832122
+            net = net.eval()
+            net = disable_batchnorm_tracking_stats(net)
+            net = net.cpu()
+            if name.startswith('GS'):
+                traced_net = torch.jit.trace(net, torch.cat([sample, sample, sample], 1))
+            else:
+                traced_net = torch.jit.trace(net, sample)
+            # traced_net = torch.jit.script(net)
             traced_net.save(f'{output_dir}/{name}.pt')
+
+    # test: whether the original and the serialized models produce highly similar predictions
+    print('testing similarity between predictions from original vs serialized models...')
+    models_original = init_nets(models_dir, eager_mode=True)
+    models_serialized = init_nets(output_dir, eager_mode=False)
+    if device == 'gpu':
+        sample = sample.cuda()
+    else:
+        sample = sample.cpu()
+    for name in models_serialized.keys():
+        print(name, ':')
+        model_original = models_original[name].cuda().eval() if device == 'gpu' else models_original[name].cpu().eval()
+        model_serialized = models_serialized[name].cuda().eval() if device == 'gpu' else models_serialized[name].cpu().eval()
+        if name.startswith('GS'):
+            test_diff_original_serialized(model_original, model_serialized, torch.cat([sample, sample, sample], 1), verbose)
+        else:
+            test_diff_original_serialized(model_original, model_serialized, sample, verbose)
+        print('PASS')
 
 
 @cli.command()
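
Note for readers: the tracing recipe added above (eval mode, BatchNorm running statistics disabled, CPU tensors, a tripled input for the GS* nets) condenses to the following standalone sketch. The toy net stands in for a DeepLIIF generator and is purely illustrative:

    import torch

    # toy stand-in for a generator: conv + batchnorm + dropout
    net = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.Dropout(0.5),
    ).eval()                                   # eval mode: dropout becomes a no-op

    for child in net.children():               # same effect as disable_batchnorm_tracking_stats
        if isinstance(child, torch.nn.BatchNorm2d):
            child.track_running_stats = False
            child.running_mean = None
            child.running_var = None

    sample = torch.rand(1, 3, 512, 512)
    traced = torch.jit.trace(net.cpu(), sample)
    traced.save('toy_net.pt')

    # the serialized net should reproduce the original almost exactly
    reloaded = torch.jit.load('toy_net.pt')
    assert torch.sum(torch.abs(net(sample) - reloaded(sample))) <= 10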
@@ -486,7 +516,11 @@ def serialize(models_dir, output_dir):
 @click.option('--region-size', default=20000, help='Due to limits in the resources, the whole slide image cannot be processed in whole. '
                                                    'So the WSI image is read region by region. '
                                                    'This parameter specifies the size of each region to be read into GPU for inference.')
-def test(input_dir, output_dir, tile_size, model_dir, region_size):
+@click.option('--eager-mode', is_flag=True, help='use eager mode (loading original models, otherwise serialized ones)')
+@click.option('--color-dapi', is_flag=True, help='color the DAPI image to produce the same coloring as in the paper')
+@click.option('--color-marker', is_flag=True, help='color the marker image to produce the same coloring as in the paper')
+def test(input_dir, output_dir, tile_size, model_dir, region_size, eager_mode,
+         color_dapi, color_marker):
 
     """Test trained models
     """
@@ -507,7 +541,7 @@ def test(input_dir, output_dir, tile_size, model_dir, region_size):
         print(time.time() - start_time)
     else:
         img = Image.open(os.path.join(input_dir, filename)).convert('RGB')
-        images, scoring = infer_modalities(img, tile_size, model_dir)
+        images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker)
 
         for name, i in images.items():
             i.save(os.path.join(
@@ -589,6 +623,15 @@ def prepare_testing_data(input_dir, dataset_dir):
             cv2.imwrite(os.path.join(test_dir, img), np.concatenate([image, image, image, image, image, image], 1))
 
 
+# to load a pickle file saved from gpu in a cpu environment: https://github.com/pytorch/pytorch/issues/16797#issuecomment-633423219
+from io import BytesIO
+class CPU_Unpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if module == 'torch.storage' and name == '_load_from_bytes':
+            return lambda b: torch.load(BytesIO(b), map_location='cpu')
+        else: return super().find_class(module, name)
+
+
 @cli.command()
 @click.option('--pickle-dir', required=True, help='directory where the pickled snapshots are stored')
 def visualize(pickle_dir):
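
The intended use of this unpickler, sketched with a placeholder path (a plain pickle.load() would fail on a CPU-only machine when the snapshot contains CUDA tensors):

    # 'snapshot.pickle' is a hypothetical file written during a GPU training run
    with open('snapshot.pickle', 'rb') as f:
        params = CPU_Unpickler(f).load()   # CUDA tensors are remapped to CPU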
@@ -599,8 +642,8 @@ def visualize(pickle_dir):
         time.sleep(1)
 
     params_opt = pickle.load(open(path_init, 'rb'))
-    params_opt['remote'] = False
-    visualizer = Visualizer(**params_opt)  # create a visualizer that displays/saves images and plots
+    params_opt.remote = False
+    visualizer = Visualizer(params_opt)  # create a visualizer that displays/saves images and plots
 
     paths_plot = {'display_current_results': os.path.join(pickle_dir, 'display_current_results.pickle'),
                   'plot_current_losses': os.path.join(pickle_dir, 'plot_current_losses.pickle')}
@@ -612,7 +655,7 @@ def visualize(pickle_dir):
             try:
                 last_modified_time_plot = os.path.getmtime(path_plot)
                 if last_modified_time_plot > last_modified_time[method]:
-                    params_plot = pickle.load(open(path_plot, 'rb'))
+                    params_plot = CPU_Unpickler(open(path_plot, 'rb')).load()
                     last_modified_time[method] = last_modified_time_plot
                     getattr(visualizer, method)(**params_plot)
                     print(f'{method} refreshed, last modified time {time.ctime(last_modified_time[method])}')
deepliif/models/__init__.py CHANGED
@@ -88,7 +88,10 @@ def create_model(opt):
 
 
 def load_torchscript_model(model_pt_path, device):
-    return torch.jit.load(model_pt_path, map_location=device)
+    net = torch.jit.load(model_pt_path, map_location=device)
+    net = disable_batchnorm_tracking_stats(net)
+    net.eval()
+    return net
 
 
 def read_model_params(file_addr):
@@ -132,7 +135,8 @@ def load_eager_models(model_dir, devices):
             os.path.join(model_dir, f'latest_net_{n}.pth'),
             map_location=devices[n]
         ))
-        nets[n] = net
+        nets[n] = disable_batchnorm_tracking_stats(net)
+        nets[n].eval()
 
     for n in ['G51', 'G52', 'G53', 'G54', 'G55']:
         net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
@@ -140,7 +144,8 @@ def load_eager_models(model_dir, devices):
             os.path.join(model_dir, f'latest_net_{n}.pth'),
             map_location=devices[n]
         ))
-        nets[n] = net
+        nets[n] = disable_batchnorm_tracking_stats(net)
+        nets[n].eval()
 
     return nets
 
@@ -185,7 +190,12 @@ def compute_overlap(img_size, tile_size):
     return tile_size // 4
 
 
-def run_torchserve(img, model_path=None):
+def run_torchserve(img, model_path=None, eager_mode=False):
+    """
+    eager_mode: not used in this function; put in place to stay consistent with run_dask,
+    so that run_wrapper() can call either this function or run_dask with the
+    same syntax
+    """
     buffer = BytesIO()
     torch.save(transform(img.resize((512, 512))), buffer)
 
@@ -203,9 +213,9 @@ def run_torchserve(img, model_path=None):
     return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
 
 
-def run_dask(img, model_path):
+def run_dask(img, model_path, eager_mode=False):
     model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
-    nets = init_nets(model_dir)
+    nets = init_nets(model_dir, eager_mode)
 
     ts = transform(img.resize((512, 512)))
 
@@ -237,7 +247,7 @@ def is_empty(tile):
     return True if calculate_background_area(tile) > 98 else False
 
 
-def run_wrapper(tile, run_fn, model_path):
+def run_wrapper(tile, run_fn, model_path, eager_mode=False):
     if is_empty(tile):
         return {
             'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
@@ -247,17 +257,17 @@ def run_wrapper(tile, run_fn, model_path):
             'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
         }
     else:
-        return run_fn(tile, model_path)
-
+        return run_fn(tile, model_path, eager_mode)
 
 
-def inference(img, tile_size, overlap_size, model_path, use_torchserve=False):
+def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
+                  color_dapi=False, color_marker=False):
 
     tiles = list(generate_tiles(img, tile_size, overlap_size))
 
     run_fn = run_torchserve if use_torchserve else run_dask
     # res = [Tile(t.i, t.j, run_fn(t.img, model_path)) for t in tiles]
-    res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path)) for t in tiles]
+    res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode)) for t in tiles]
 
     def get_net_tiles(n):
         return [Tile(t.i, t.j, t.img[n]) for t in res]
@@ -276,12 +286,14 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False):
 
     images['DAPI'] = stitch(get_net_tiles('G2'), tile_size, overlap_size).resize(img.size)
     dapi_pix = np.array(images['DAPI'].convert('L').convert('RGB'))
-    dapi_pix[:, :, 0] = 0
+    if color_dapi:
+        dapi_pix[:, :, 0] = 0
     images['DAPI'] = Image.fromarray(dapi_pix)
     images['Lap2'] = stitch(get_net_tiles('G3'), tile_size, overlap_size).resize(img.size)
     images['Marker'] = stitch(get_net_tiles('G4'), tile_size, overlap_size).resize(img.size)
     marker_pix = np.array(images['Marker'].convert('L').convert('RGB'))
-    marker_pix[:, :, 2] = 0
+    if color_marker:
+        marker_pix[:, :, 2] = 0
     images['Marker'] = Image.fromarray(marker_pix)
 
     # images['Marker'] = stitch(
@@ -294,6 +306,52 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False):
     return images
 
 
+def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
+              color_dapi=False, color_marker=False):
+
+    rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
+
+    run_fn = run_torchserve if use_torchserve else run_dask
+
+    images = {}
+    images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
+    images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
+    images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
+    images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
+    images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
+
+    for i in range(cols):
+        for j in range(rows):
+            tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
+            res = run_wrapper(tile, run_fn, model_path, eager_mode)
+
+            stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
+            stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
+            stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
+            stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
+            stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
+
+    images['Hema'] = images['Hema'].resize(img.size)
+    images['DAPI'] = images['DAPI'].resize(img.size)
+    images['Lap2'] = images['Lap2'].resize(img.size)
+    images['Marker'] = images['Marker'].resize(img.size)
+    images['Seg'] = images['Seg'].resize(img.size)
+
+    if color_dapi:
+        matrix = (       0,        0,        0, 0,
+                  299/1000, 587/1000, 114/1000, 0,
+                  299/1000, 587/1000, 114/1000, 0)
+        images['DAPI'] = images['DAPI'].convert('RGB', matrix)
+
+    if color_marker:
+        matrix = (299/1000, 587/1000, 114/1000, 0,
+                  299/1000, 587/1000, 114/1000, 0,
+                         0,        0,        0, 0)
+        images['Marker'] = images['Marker'].convert('RGB', matrix)
+
+    return images
+
+
 def postprocess(img, seg_img, thresh=80, noise_objects_size=20, small_object_size=50):
     mask_image = create_basic_segmentation_mask(np.array(img), np.array(seg_img),
                                                 thresh, noise_objects_size, small_object_size)
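
The 12-tuple passed to Image.convert('RGB', matrix) above defines one (r, g, b, bias) row per output channel, so the DAPI matrix zeroes red and writes the ITU-R 601 luma into green and blue, producing the cyan tint used in the paper. A small sketch of the mechanism with illustrative values:

    from PIL import Image

    img = Image.new('RGB', (2, 2), (200, 100, 50))            # toy pixel value
    luma = (299/1000, 587/1000, 114/1000, 0)                  # ITU-R 601 weights
    cyan_matrix = (0, 0, 0, 0) + luma + luma                  # R -> 0, G = B = luma
    print(img.convert('RGB', cyan_matrix).getpixel((0, 0)))   # approx. (0, 124, 124)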
@@ -312,7 +370,8 @@ def postprocess(img, seg_img, thresh=80, noise_objects_size=20, small_object_siz
     return images, scoring
 
 
-def infer_modalities(img, tile_size, model_dir):
+def infer_modalities(img, tile_size, model_dir, eager_mode=False,
+                     color_dapi=False, color_marker=False):
     """
     This function is used to infer modalities for the given image using a trained model.
    :param img: The input image.
@@ -329,7 +388,10 @@ def infer_modalities(img, tile_size, model_dir):
         img,
         tile_size=tile_size,
         overlap_size=compute_overlap(img.size, tile_size),
-        model_path=model_dir
+        model_path=model_dir,
+        eager_mode=eager_mode,
+        color_dapi=color_dapi,
+        color_marker=color_marker
     )
 
     post_images, scoring = postprocess(img, images['Seg'], small_object_size=20)
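
Taken together, calling the extended API from Python could look like this sketch (paths and tile size are placeholders):

    from PIL import Image
    from deepliif.models import infer_modalities

    img = Image.open('example_ihc.png').convert('RGB')   # placeholder input
    images, scoring = infer_modalities(
        img, tile_size=512, model_dir='./DeepLIIF_Latest_Model',
        eager_mode=False,     # False: use the serialized (TorchScript) models
        color_dapi=True,      # apply the paper's DAPI coloring
        color_marker=True,    # apply the paper's marker coloring
    )
    for name, im in images.items():
        im.save(f'{name}.png')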
deepliif/models/base_model.py CHANGED
@@ -3,6 +3,7 @@ import torch
 from collections import OrderedDict
 from abc import ABC, abstractmethod
 from . import networks
+from ..util import disable_batchnorm_tracking_stats
 
 
 class BaseModel(ABC):
@@ -90,6 +91,7 @@ class BaseModel(ABC):
         if isinstance(name, str):
             net = getattr(self, 'net' + name)
             net.eval()
+            net = disable_batchnorm_tracking_stats(net)
 
     def test(self):
         """Forward function used in test time.
deepliif/util/__init__.py CHANGED
@@ -88,6 +88,36 @@ def stitch(tiles, tile_size, overlap_size):
     return new_im
 
 
+def format_image_for_tiling(img, tile_size, overlap_size):
+    mean_background_val = calculate_background_mean_value(img)
+    img = img.resize(output_size(img, tile_size))
+    # Adding borders with size of given overlap around the whole slide image
+    img = ImageOps.expand(img, border=overlap_size, fill=tuple(mean_background_val))
+    rows = int(img.height / tile_size)
+    cols = int(img.width / tile_size)
+    return img, rows, cols
+
+
+def extract_tile(img, tile_size, overlap_size, i, j):
+    return img.crop((
+        i * tile_size, j * tile_size,
+        i * tile_size + tile_size + 2 * overlap_size,
+        j * tile_size + tile_size + 2 * overlap_size
+    ))
+
+
+def create_image_for_stitching(tile_size, rows, cols):
+    width = tile_size * cols
+    height = tile_size * rows
+    return Image.new('RGB', (width, height))
+
+
+def stitch_tile(img, tile, tile_size, overlap_size, i, j):
+    tile = tile.resize((tile_size + 2 * overlap_size, tile_size + 2 * overlap_size))
+    tile = tile.crop((overlap_size, overlap_size, overlap_size + tile_size, overlap_size + tile_size))
+    img.paste(tile, (i * tile_size, j * tile_size))
+
+
 def calculate_background_mean_value(img):
     img = cv2.fastNlMeansDenoisingColored(np.array(img), None, 10, 10, 7, 21)
     img = np.array(img, dtype=float)
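
These four helpers replace the generate_tiles/stitch pair with an incremental loop that never holds all tiles in memory at once. A hedged round-trip sketch, assuming the helpers (and their internal dependencies such as output_size) are importable from deepliif.util:

    from PIL import Image

    tile_size, overlap = 512, 128
    src = Image.new('RGB', (1024, 1024), (180, 170, 190))            # toy slide

    rescaled, rows, cols = format_image_for_tiling(src, tile_size, overlap)
    out = create_image_for_stitching(tile_size, rows, cols)
    for i in range(cols):
        for j in range(rows):
            tile = extract_tile(rescaled, tile_size, overlap, i, j)  # tile plus overlap borders
            stitch_tile(out, tile, tile_size, overlap, i, j)         # paste back the center crop
    # 'out' now reproduces 'rescaled' without its padded border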
@@ -349,3 +379,39 @@ def read_results_from_pickle_file(input_addr):
     pickle_obj.close()
     return results
 
+
+def test_diff_original_serialized(model_original, model_serialized, example, verbose=0):
+    threshold = 10
+
+    orig_res = model_original(example)
+    if verbose > 0:
+        print('Original:')
+        print(orig_res.shape)
+        print(orig_res[0, 0:10])
+        print('min abs value: {}'.format(torch.min(torch.abs(orig_res))))
+
+    ts_res = model_serialized(example)
+    if verbose > 0:
+        print('Torchscript:')
+        print(ts_res.shape)
+        print(ts_res[0, 0:10])
+        print('min abs value: {}'.format(torch.min(torch.abs(ts_res))))
+
+    abs_diff = torch.abs(orig_res - ts_res)
+    if verbose > 0:
+        print('Diff sum:')
+        print(torch.sum(abs_diff))
+        print('max diff: {}'.format(torch.max(abs_diff)))
+
+    assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}"
+
+def disable_batchnorm_tracking_stats(model):
+    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
+    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
+    # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
+    for m in model.modules():
+        for child in m.children():
+            if type(child) == torch.nn.BatchNorm2d:
+                child.track_running_stats = False
+                child.running_mean = None
+                child.running_var = None
+    return model
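
Why dropping the running statistics matters: in eval mode a BatchNorm layer normalizes with its stored running_mean/running_var, and tracing freezes that behavior; with the statistics removed it falls back to per-batch statistics, matching what the generators saw during training. A toy demonstration of the difference (the wrapping Sequential matters, since the helper only touches children of a module):

    import torch

    net = torch.nn.Sequential(torch.nn.BatchNorm2d(3)).eval()   # fresh stats: mean 0, var 1
    x = torch.rand(1, 3, 8, 8) * 5 + 10                          # far from the stored stats

    with_stats = net(x)                       # normalized with running stats
    disable_batchnorm_tracking_stats(net)     # helper defined above
    without_stats = net(x)                    # normalized with per-batch stats
    print(torch.max(torch.abs(with_stats - without_stats)))     # clearly non-zero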