deepliif 1.1.9__py3-none-any.whl → 1.1.11__py3-none-any.whl

This diff is generated from the contents of publicly available package versions that have been released to one of the supported registries. The information it contains is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -133,6 +133,10 @@ def load_eager_models(opt, devices):
133
133
  net.eval()
134
134
  net = disable_batchnorm_tracking_stats(net)
135
135
 
136
+ # SDG models when loaded are still DP.. not sure why
137
+ if isinstance(net, torch.nn.DataParallel):
138
+ net = net.module
139
+
136
140
  nets[name] = net
137
141
  nets[name].to(devices[name])
138
142
 
@@ -161,7 +165,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
161
165
  ('G4', 'G55'),
162
166
  ('G51',)
163
167
  ]
164
- elif opt.model == 'DeepLIIFExt':
168
+ elif opt.model in ['DeepLIIFExt','SDG']:
165
169
  if opt.seg_gen:
166
170
  net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
167
171
  else:
@@ -172,6 +176,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
172
176
  number_of_gpus_all = torch.cuda.device_count()
173
177
  number_of_gpus = len(opt.gpu_ids)
174
178
  #print(number_of_gpus)
179
+
175
180
  if number_of_gpus > 0:
176
181
  mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
177
182
  chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
@@ -206,10 +211,7 @@ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
206
211
  opt: same as eager_mode
207
212
  """
208
213
  buffer = BytesIO()
209
- if opt.model == 'DeepLIIFExt':
210
- torch.save(transform(img.resize((1024, 1024))), buffer)
211
- else:
212
- torch.save(transform(img.resize((512, 512))), buffer)
214
+ torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
213
215
 
214
216
  torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
215
217
  res = requests.post(
@@ -229,10 +231,12 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
229
231
  model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
230
232
  nets = init_nets(model_dir, eager_mode, opt)
231
233
 
232
- if opt.model == 'DeepLIIFExt':
233
- ts = transform(img.resize((1024, 1024)))
234
+ if opt.input_no > 1 or opt.model == 'SDG':
235
+ l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
236
+ ts = torch.cat(l_ts, dim=1)
234
237
  else:
235
- ts = transform(img.resize((512, 512)))
238
+ ts = transform(img.resize((opt.scale_size, opt.scale_size)))
239
+
236
240
 
237
241
  @delayed
238
242
  def forward(input, model):
@@ -249,14 +253,20 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
249
253
  lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
250
254
  segs = compute(lazy_segs)[0]
251
255
 
252
- seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
253
- seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
256
+ weights = {
257
+ 'G51': 0.25, # IHC
258
+ 'G52': 0.25, # Hema
259
+ 'G53': 0.25, # DAPI
260
+ 'G54': 0.00, # Lap2
261
+ 'G55': 0.25, # Marker
262
+ }
263
+ seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
254
264
 
255
265
  res = {k: tensor_to_pil(v) for k, v in gens.items()}
256
266
  res['G5'] = tensor_to_pil(seg)
257
267
 
258
268
  return res
259
- elif opt.model == 'DeepLIIFExt':
269
+ elif opt.model in ['DeepLIIFExt','SDG']:
260
270
  seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
261
271
 
262
272
  lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
@@ -271,14 +281,23 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
271
281
 
272
282
  return res
273
283
  else:
274
- raise Exception(f'run_dask() not implemented for {opt.model}')
284
+ raise Exception(f'run_dask() not fully implemented for {opt.model}')
275
285
 
276
-
286
+
287
+ def is_empty_old(tile):
288
+ # return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
289
+ if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
290
+ return all([True if calculate_background_area(t) > 98 else False for t in tile])
291
+ else:
292
+ return True if calculate_background_area(tile) > 98 else False
277
293
 
278
294
 
279
295
  def is_empty(tile):
280
- # return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
281
- return True if calculate_background_area(tile) > 98 else False
296
+ thresh = 15
297
+ if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
298
+ return all([True if np.max(image_variance_rgb(t)) < thresh else False for t in tile])
299
+ else:
300
+ return True if np.max(image_variance_rgb(tile)) < thresh else False
282
301
 
283
302
 
284
303
  def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
@@ -293,7 +312,7 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
293
312
  }
294
313
  else:
295
314
  return run_fn(tile, model_path, eager_mode, opt)
296
- elif opt.model == 'DeepLIIFExt':
315
+ elif opt.model in ['DeepLIIFExt', 'SDG']:
297
316
  if is_empty(tile):
298
317
  res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
299
318
  res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
@@ -350,8 +369,8 @@ def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False
350
369
  return images
351
370
 
352
371
 
353
- def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
354
- color_dapi=False, color_marker=False, opt=None):
372
+ def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
373
+ color_dapi=False, color_marker=False, opt=None):
355
374
  if not opt:
356
375
  opt = get_opt(model_path)
357
376
  #print_options(opt)
@@ -362,28 +381,25 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, ea
362
381
  run_fn = run_torchserve if use_torchserve else run_dask
363
382
 
364
383
  images = {}
365
- images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
366
- images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
367
- images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
368
- images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
369
- images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
384
+ d_modality2net = {'Hema':'G1',
385
+ 'DAPI':'G2',
386
+ 'Lap2':'G3',
387
+ 'Marker':'G4',
388
+ 'Seg':'G5'}
389
+
390
+ for k in d_modality2net.keys():
391
+ images[k] = create_image_for_stitching(tile_size, rows, cols)
370
392
 
371
393
  for i in range(cols):
372
394
  for j in range(rows):
373
395
  tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
374
396
  res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
375
-
376
- stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
377
- stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
378
- stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
379
- stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
380
- stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
381
-
382
- images['Hema'] = images['Hema'].resize(img.size)
383
- images['DAPI'] = images['DAPI'].resize(img.size)
384
- images['Lap2'] = images['Lap2'].resize(img.size)
385
- images['Marker'] = images['Marker'].resize(img.size)
386
- images['Seg'] = images['Seg'].resize(img.size)
397
+
398
+ for modality_name, net_name in d_modality2net.items():
399
+ stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
400
+
401
+ for modality_name, output_img in images.items():
402
+ images[modality_name] = output_img.resize(img.size)
387
403
 
388
404
  if color_dapi:
389
405
  matrix = ( 0, 0, 0, 0,
@@ -403,30 +419,133 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, ea
403
419
  #param_dict = read_train_options(model_path)
404
420
  #modalities_no = int(param_dict['modalities_no']) if param_dict else 4
405
421
  #seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
406
-
407
- tiles = list(generate_tiles(img, tile_size, overlap_size))
408
-
422
+
423
+
424
+ rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
409
425
  run_fn = run_torchserve if use_torchserve else run_dask
410
- res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode, opt)) for t in tiles]
411
426
 
412
427
  def get_net_tiles(n):
413
428
  return [Tile(t.i, t.j, t.img[n]) for t in res]
414
429
 
415
430
  images = {}
416
-
417
- for i in range(1, opt.modalities_no + 1):
418
- images['mod' + str(i)] = stitch(get_net_tiles('G_' + str(i)), tile_size, overlap_size).resize(img.size)
419
-
431
+ d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
420
432
  if opt.seg_gen:
421
- for i in range(1, opt.modalities_no + 1):
422
- images['Seg' + str(i)] = stitch(get_net_tiles('GS_' + str(i)), tile_size, overlap_size).resize(img.size)
433
+ d_modality2net.update({f'Seg{i}':f'GS_{i}' for i in range(1, opt.modalities_no + 1)})
434
+
435
+ for k in d_modality2net.keys():
436
+ images[k] = create_image_for_stitching(tile_size, rows, cols)
423
437
 
438
+ for i in range(cols):
439
+ for j in range(rows):
440
+ tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
441
+ res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
442
+
443
+ for modality_name, net_name in d_modality2net.items():
444
+ stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
445
+
446
+ for modality_name, output_img in images.items():
447
+ images[modality_name] = output_img.resize(img.size)
448
+
449
+ return images
450
+
451
+ elif opt.model == 'SDG':
452
+ # SDG could have multiple input images / modalities
453
+ # the input hence could be a rectangle
454
+ # we split the input to get each modality image one by one
455
+ # then create tiles for each of the modality images
456
+ # tile_pair is a list that contains the tiles at the given location for each modality image
457
+ # l_tile_pair is a list of tile_pair that covers all locations
458
+ # for inference, each tile_pair is used to get the output at the given location
459
+ w, h = img.size
460
+ w2 = int(w / opt.input_no)
461
+
462
+ l_img = []
463
+ for i in range(opt.input_no):
464
+ img_i = img.crop((w2 * i, 0, w2 * (i+1), h))
465
+ rescaled_img_i, rows, cols = format_image_for_tiling(img_i, tile_size, overlap_size)
466
+ l_img.append(rescaled_img_i)
467
+
468
+ run_fn = run_torchserve if use_torchserve else run_dask
469
+
470
+ images = {}
471
+ d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
472
+ for k in d_modality2net.keys():
473
+ images[k] = create_image_for_stitching(tile_size, rows, cols)
474
+
475
+ for i in range(cols):
476
+ for j in range(rows):
477
+ tile_pair = [extract_tile(rescaled, tile_size, overlap_size, i, j) for rescaled in l_img]
478
+ res = run_wrapper(tile_pair, run_fn, model_path, eager_mode, opt)
479
+
480
+ for modality_name, net_name in d_modality2net.items():
481
+ stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
482
+
483
+ for modality_name, output_img in images.items():
484
+ images[modality_name] = output_img.resize((w2,w2))
485
+
424
486
  return images
425
487
 
426
488
  else:
427
489
  raise Exception(f'inference() not implemented for model {opt.model}')
428
490
 
429
491
 
492
+ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
493
+ eager_mode=False, color_dapi=False, color_marker=False, opt=None):
494
+ if not opt:
495
+ opt = get_opt(model_path)
496
+ #print_options(opt)
497
+
498
+ run_fn = run_torchserve if use_torchserve else run_dask
499
+
500
+ if opt.model == 'SDG':
501
+ # SDG could have multiple input images/modalities, hence the input could be a rectangle.
502
+ # We split the input to get each modality image then create tiles for each set of input images.
503
+ w, h = int(img.width / opt.input_no), img.height
504
+ orig = [img.crop((w * i, 0, w * (i+1), h)) for i in range(opt.input_no)]
505
+ else:
506
+ # Otherwise expect a single input image, which is used directly.
507
+ orig = img
508
+
509
+ tiler = InferenceTiler(orig, tile_size, overlap_size)
510
+ for tile in tiler:
511
+ tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt))
512
+ results = tiler.results()
513
+
514
+ if opt.model == 'DeepLIIF':
515
+ images = {
516
+ 'Hema': results['G1'],
517
+ 'DAPI': results['G2'],
518
+ 'Lap2': results['G3'],
519
+ 'Marker': results['G4'],
520
+ 'Seg': results['G5'],
521
+ }
522
+ if color_dapi:
523
+ matrix = ( 0, 0, 0, 0,
524
+ 299/1000, 587/1000, 114/1000, 0,
525
+ 299/1000, 587/1000, 114/1000, 0)
526
+ images['DAPI'] = images['DAPI'].convert('RGB', matrix)
527
+ if color_marker:
528
+ matrix = (299/1000, 587/1000, 114/1000, 0,
529
+ 299/1000, 587/1000, 114/1000, 0,
530
+ 0, 0, 0, 0)
531
+ images['Marker'] = images['Marker'].convert('RGB', matrix)
532
+ return images
533
+
534
+ elif opt.model == 'DeepLIIFExt':
535
+ images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
536
+ if opt.seg_gen:
537
+ images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
538
+ return images
539
+
540
+ elif opt.model == 'SDG':
541
+ images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
542
+ return images
543
+
544
+ else:
545
+ #raise Exception(f'inference() not implemented for model {opt.model}')
546
+ return results # return result images with default key names (i.e., net names)
547
+
548
+
430
549
  def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
431
550
  if model == 'DeepLIIF':
432
551
  resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
@@ -476,11 +595,16 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
476
595
  tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
477
596
  img.convert('L'))
478
597
  tile_size = int(tile_size)
598
+
599
+ # for those with multiple input modalities, find the correct size to calculate overlap_size
600
+ input_no = opt.input_no if hasattr(opt, 'input_no') else 1
601
+ img_size = (img.size[0] / input_no, img.size[1]) # (width, height)
479
602
 
480
603
  images = inference(
481
604
  img,
482
605
  tile_size=tile_size,
483
- overlap_size=compute_overlap(img.size, tile_size),
606
+ #overlap_size=compute_overlap(img_size, tile_size),
607
+ overlap_size=tile_size//16,
484
608
  model_path=model_dir,
485
609
  eager_mode=eager_mode,
486
610
  color_dapi=color_dapi,
@@ -13,7 +13,6 @@ def read_model_params(file_addr):
13
13
  key = line.split(':')[0].strip()
14
14
  val = line.split(':')[1].split('[')[0].strip()
15
15
  param_dict[key] = val
16
- print(param_dict)
17
16
  return param_dict
18
17
 
19
18
  class Options:
@@ -32,19 +31,7 @@ class Options:
32
31
  else:
33
32
  setattr(self,k,v)
34
33
  except:
35
- setattr(self,k,v)
36
-
37
- if mode != 'train':
38
- # to account for old settings where gpu_ids value is an integer, not a tuple
39
- if isinstance(self.gpu_ids,int):
40
- self.gpu_ids = (self.gpu_ids,)
41
-
42
- # to account for old settings before modalities_no was introduced
43
- if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
44
- self.modalities_no = self.targets_no - 1
45
- del self.targets_no
46
-
47
-
34
+ setattr(self,k,v)
48
35
 
49
36
  if mode == 'train':
50
37
  self.is_train = True
@@ -71,32 +58,78 @@ class Options:
71
58
  self.checkpoints_dir = str(model_dir.parent)
72
59
  self.name = str(model_dir.name)
73
60
 
74
- self.gpu_ids = [] # gpu_ids is only used by eager mode, set to empty / cpu to be the same as the old settings; non-eager mode will use all gpus
61
+ #self.gpu_ids = [] # gpu_ids is only used by eager mode, set to empty / cpu to be the same as the old settings; non-eager mode will use all gpus
62
+
63
+ # to account for old settings where gpu_ids value is an integer, not a tuple
64
+ if isinstance(self.gpu_ids,int):
65
+ self.gpu_ids = (self.gpu_ids,)
66
+
67
+ # to account for old settings before modalities_no was introduced
68
+ if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
69
+ self.modalities_no = self.targets_no - 1
70
+ del self.targets_no
71
+
72
+ # to account for old settings: same as in cli.py train
73
+ if not hasattr(self,'seg_no'):
74
+ if self.model == 'DeepLIIF':
75
+ self.seg_no = 1
76
+ elif self.model == 'DeepLIIFExt':
77
+ if self.seg_gen:
78
+ self.seg_no = self.modalities_no
79
+ else:
80
+ self.seg_no = 0
81
+ elif self.model == 'SDG':
82
+ self.seg_no = 0
83
+ self.seg_gen = False
84
+ else:
85
+ raise Exception(f'seg_gen cannot be automatically determined for {opt.model}')
86
+
87
+ # to account for old settings: prior to SDG, our models only have 1 input image
88
+ if not hasattr(self,'input_no'):
89
+ self.input_no = 1
90
+
91
+ # to account for old settings: before adding scale_size
92
+ if not hasattr(self, 'scale_size'):
93
+ if self.model in ['DeepLIIF','SDG']:
94
+ self.scale_size = 512
95
+ elif self.model == 'DeepLIIFExt':
96
+ self.scale_size = 1024
97
+ else:
98
+ raise Exception(f'scale_size cannot be automatically determined for {opt.model}')
99
+
100
+
75
101
 
76
102
  def _get_kwargs(self):
77
103
  common_attr = ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__']
78
104
  l_args = [x for x in dir(self) if x not in common_attr]
79
105
  return {k:getattr(self,k) for k in l_args}
80
-
81
- def print_options(opt):
82
- """Print and save options
83
106
 
84
- It will print both current options and default values(if different).
85
- It will save options into a text file / [checkpoints_dir] / opt.txt
86
- """
107
+ def format_options(opt):
87
108
  message = ''
88
109
  message += '----------------- Options ---------------\n'
89
110
  for k, v in sorted(vars(opt).items()):
90
111
  comment = ''
91
112
  message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
92
113
  message += '----------------- End -------------------'
93
- print(message)
114
+ return message
115
+
116
+ def print_options(opt, save=False):
117
+ """Print (and save) options
94
118
 
119
+ It will print both current options and default values(if different).
120
+ If save=True, it will save options into a text file / [checkpoints_dir] / opt.txt
121
+ """
122
+ message = format_options(opt)
123
+ print(message)
124
+
95
125
  # save to the disk
96
- if opt.phase == 'train':
97
- expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
98
- mkdirs(expr_dir)
99
- file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
100
- with open(file_name, 'wt') as opt_file:
101
- opt_file.write(message)
102
- opt_file.write('\n')
126
+ if save:
127
+ save_options(opt)
128
+
129
+ def save_options(opt):
130
+ message = format_options(opt)
131
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
132
+ mkdirs(expr_dir)
133
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
134
+ with open(file_name, 'wt') as opt_file:
135
+ opt_file.write(message+'\n')