deepliif 1.2.3__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -133,7 +133,7 @@ def load_eager_models(opt, devices=None):
133
133
  model_names = list(devices.keys())
134
134
  else:
135
135
  model_names = model.model_names
136
-
136
+
137
137
  for name in model_names:#model.model_names:
138
138
  if isinstance(name, str):
139
139
  if '_' in name:
@@ -155,7 +155,6 @@ def load_eager_models(opt, devices=None):
155
155
 
156
156
  return nets
157
157
 
158
-
159
158
  @lru_cache
160
159
  def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
161
160
  """
@@ -169,14 +168,20 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
169
168
  opt = get_opt(model_dir, mode=phase)
170
169
  opt.use_dp = False
171
170
 
171
+ # get l_net_groups: a list of lists, each sublist is a candidate, ordered by priority (try 1st, if fail try 2nd, etc.)
172
172
  if opt.model in ['DeepLIIF','DeepLIIFKD']:
173
- net_groups = [
174
- ('G1', 'G52'),
175
- ('G2', 'G53'),
176
- ('G3', 'G54'),
177
- ('G4', 'G55'),
178
- ('G51',)
179
- ]
173
+ # net_groups = [
174
+ # ('G1', 'G52'),
175
+ # ('G2', 'G53'),
176
+ # ('G3', 'G54'),
177
+ # ('G4', 'G55'),
178
+ # ('G51',)
179
+ # ]
180
+ if opt.modalities_no == 0:
181
+ net_groups = [(f'G{opt.mod_id_seg}{opt.input_id}',)]
182
+ else:
183
+ net_groups = [(f'G{i+1}', f'G{opt.mod_id_seg}{int(opt.input_id)+i+1}') for i in range(opt.modalities_no)]
184
+ net_groups += [(f'G{opt.mod_id_seg}{opt.input_id}',)] # this is the generator for the input base mod
180
185
  elif opt.model in ['DeepLIIFExt','SDG']:
181
186
  if opt.seg_gen:
182
187
  net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
@@ -197,8 +202,8 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
197
202
  mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
198
203
  chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
199
204
  # chunks = chunks[1:]
200
- devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
201
- # devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
205
+ #l_devices = [{n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g} for chunks in l_chunks]
206
+ devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
202
207
  else:
203
208
  devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
204
209
 
@@ -219,13 +224,14 @@ def compute_overlap(img_size, tile_size):
219
224
  return tile_size // 4
220
225
 
221
226
 
222
- def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
227
+ def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, mod_only=False, seg_weights=None, use_dask=True, output_tensor=False):
223
228
  """
224
229
  eager_mode: not used in this function; put in place to be consistent with run_dask
225
230
  so that run_wrapper() could call either this function or run_dask with
226
231
  same syntax
227
232
  opt: same as eager_mode
228
233
  seg_only: same as eager_mode
234
+ mod_only: same as eager_mode
229
235
  seg_weights: same as eager_mode
230
236
  nets: same as eager_mode
231
237
  """
@@ -246,7 +252,8 @@ def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None,
246
252
  return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
247
253
 
248
254
 
249
- def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
255
+ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, mod_only=False,
256
+ seg_weights=None, use_dask=True, output_tensor=False):
250
257
  """
251
258
  Provide either the model path or the networks object.
252
259
 
@@ -269,65 +276,81 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
269
276
  else:
270
277
  ts = transform(img.resize((opt.scale_size, opt.scale_size)))
271
278
 
272
-
279
+
273
280
  if use_dask:
274
281
  @delayed
275
282
  def forward(input, model):
276
283
  with torch.no_grad():
277
284
  return model(input.to(next(model.parameters()).device))
278
- else: # some train settings like spectral norm some how in inference mode is not compatible with dask
285
+ else: # some train settings like spectral norm somehow in inference mode is not compatible with dask
279
286
  def forward(input, model):
280
287
  with torch.no_grad():
281
288
  return model(input.to(next(model.parameters()).device))
282
289
 
283
290
  if opt.model in ['DeepLIIF','DeepLIIFKD']:
284
291
  if seg_weights is None:
285
- weights = {
286
- 'G51': 0.5, # IHC
287
- 'G52': 0.0, # Hema
288
- 'G53': 0.0, # DAPI
289
- 'G54': 0.0, # Lap2
290
- 'G55': 0.5, # Marker
291
- }
292
+ # weights = {
293
+ # 'G51': 0.5, # IHC
294
+ # 'G52': 0.0, # Hema
295
+ # 'G53': 0.0, # DAPI
296
+ # 'G54': 0.0, # Lap2
297
+ # 'G55': 0.5, # Marker
298
+ # }
299
+ weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': 1/(opt.modalities_no+1) for i in range(opt.modalities_no+1)}
292
300
  else:
293
- weights = {
294
- 'G51': seg_weights[0], # IHC
295
- 'G52': seg_weights[1], # Hema
296
- 'G53': seg_weights[2], # DAPI
297
- 'G54': seg_weights[3], # Lap2
298
- 'G55': seg_weights[4], # Marker
299
- }
300
-
301
- seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
301
+ # weights = {
302
+ # 'G51': seg_weights[0], # IHC
303
+ # 'G52': seg_weights[1], # Hema
304
+ # 'G53': seg_weights[2], # DAPI
305
+ # 'G54': seg_weights[3], # Lap2
306
+ # 'G55': seg_weights[4], # Marker
307
+ # }
308
+ weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': seg_weight for i,seg_weight in enumerate(seg_weights)}
309
+
310
+
311
+ seg_map = {f'G{i+1}': f'G{opt.mod_id_seg}{int(opt.input_id)+i+1}' for i in range(opt.modalities_no)}
302
312
  if seg_only:
303
313
  seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
304
314
 
305
315
  lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
306
- if 'G4' not in seg_map:
307
- lazy_gens['G4'] = forward(ts, nets['G4'])
316
+ if 'Marker' in opt.modalities_names:
317
+ mod_id_marker = opt.modalities_names.index("Marker")
318
+ if f'G{mod_id_marker}' not in seg_map:
319
+ lazy_gens[f'G{mod_id_marker}'] = forward(ts, nets[f'G{mod_id_marker}'])
320
+
308
321
  gens = compute(lazy_gens)[0]
309
322
 
310
- lazy_segs = {v: forward(gens[k], nets[v]) for k, v in seg_map.items()}
311
- if not seg_only or weights['G51'] != 0:
312
- lazy_segs['G51'] = forward(ts, nets['G51'])
313
- segs = compute(lazy_segs)[0]
323
+ if not mod_only:
324
+ lazy_segs = {v: forward(gens[k], nets[v]) for k, v in seg_map.items()}
325
+ # run seg generator for the base input
326
+ if weights[f'G{opt.mod_id_seg}{opt.input_id}'] != 0:
327
+ lazy_segs[f'G{opt.mod_id_seg}{opt.input_id}'] = forward(ts, nets[f'G{opt.mod_id_seg}{opt.input_id}'])
328
+ segs = compute(lazy_segs)[0]
314
329
 
315
- device = next(nets['G1'].parameters()).device # take the device of the first net and move all outputs there for seg aggregation
316
- seg = torch.stack([torch.mul(segs[k].to(device), weights[k]) for k in segs.keys()]).sum(dim=0)
330
+ model_name_first = list(nets.keys())[0]
331
+ device = next(nets[model_name_first].parameters()).device # take the device of the first net and move all outputs there for seg aggregation
332
+ seg = torch.stack([torch.mul(segs[k].to(device), weights[k]) for k in segs.keys()]).sum(dim=0)
317
333
 
318
334
  if output_tensor:
319
- if seg_only:
320
- res = {'G4': gens['G4']} if 'G4' in gens else {}
335
+ if mod_only:
336
+ res = gens
337
+ elif seg_only and opt.modalities_no > 0:
338
+ res = {f'G{opt.modalities_no}': gens[f'G{opt.modalities_no}']} if f'G{opt.modalities_no}' in gens else {}
339
+ res[f'G{opt.mod_id_seg}'] = seg
321
340
  else:
322
341
  res = {**gens, **segs}
323
- res['G5'] = seg
342
+ res[f'G{opt.mod_id_seg}'] = seg
343
+
324
344
  else:
325
- if seg_only:
326
- res = {'G4': tensor_to_pil(gens['G4'].to(torch.device('cpu')))} if 'G4' in gens else {}
345
+ if mod_only:
346
+ res = {k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in gens.items()}
347
+ elif seg_only and opt.modalities_no > 0:
348
+ res = {f'G{opt.modalities_no}': tensor_to_pil(gens[f'G{opt.modalities_no}'].to(torch.device('cpu')))} if f'G{opt.modalities_no}' in gens else {}
349
+ res[f'G{opt.mod_id_seg}'] = tensor_to_pil(seg.to(torch.device('cpu')))
327
350
  else:
328
351
  res = {k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in gens.items()}
329
352
  res.update({k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in segs.items()})
330
- res['G5'] = tensor_to_pil(seg.to(torch.device('cpu')))
353
+ res[f'G{opt.mod_id_seg}'] = tensor_to_pil(seg.to(torch.device('cpu')))
331
354
 
332
355
  return res
333
356
  elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
@@ -343,6 +366,8 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
343
366
  gens = {k: forward(ts, nets[k]) for k in seg_map}
344
367
 
345
368
  res = {k: tensor_to_pil(v) for k, v in gens.items()}
369
+ if mod_only:
370
+ return res
346
371
 
347
372
  if opt.seg_gen:
348
373
  if use_dask:
@@ -358,36 +383,55 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
358
383
 
359
384
 
360
385
  def is_empty(tile):
361
- thresh = 15
386
+ thresh = 9
362
387
  if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
363
388
  return all([True if image_variance_gray(t) < thresh else False for t in tile])
364
389
  else:
365
390
  return True if image_variance_gray(tile) < thresh else False
366
391
 
367
392
 
368
- def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
393
+ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, mod_only=False, seg_weights=None, use_dask=True, output_tensor=False):
369
394
  if opt.model in ['DeepLIIF','DeepLIIFKD']:
370
395
  if is_empty(tile):
371
- if seg_only:
372
- return {
373
- 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
374
- 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
396
+ if seg_only: # return seg image and the last translated modality
397
+ res = {
398
+ #f'G{opt.modalities_no}': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
399
+ f'G{opt.modalities_no}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[-1]),
400
+ f'G{opt.mod_id_seg}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
375
401
  }
402
+ elif mod_only:
403
+ res = {f'G{i+1}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[i]) for i in range(opt.modalities_no)}
404
+
376
405
  else :
377
- return {
378
- 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
379
- 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
380
- 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
381
- 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
382
- 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
383
- 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
384
- 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
385
- 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
386
- 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
387
- 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
388
- }
406
+ # return {
407
+ # 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
408
+ # 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
409
+ # 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
410
+ # 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
411
+ # 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
412
+ # 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
413
+ # 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
414
+ # 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
415
+ # 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
416
+ # 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
417
+ # }
418
+ res = {**{f'G{i+1}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[i]) for i in range(opt.modalities_no)},
419
+ **{f'G{opt.mod_id_seg}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))}}
420
+
421
+ # assign mod-wise seg output to corresponding keys, again currently input_id only supports 0 or 1
422
+ if opt.input_id == 1:
423
+ res = {**res,
424
+ **{f'G{opt.mod_id_seg}{i+1}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)) for i in range(opt.modalities_no+1)}}
425
+ else:
426
+ res = {**res,
427
+ **{f'G{opt.mod_id_seg}{i}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)) for i in range(opt.modalities_no+1)}}
428
+
429
+ # when modalities_no = 0... we do not need to generate G0 - output should be just a seg image
430
+ if 'G0' in res:
431
+ del res['G0']
432
+ return res
389
433
  else:
390
- return run_fn(tile, model_path, None, eager_mode, opt, seg_only, seg_weights)
434
+ return run_fn(tile, model_path, None, eager_mode, opt, seg_only, mod_only, seg_weights)
391
435
  elif opt.model in ['DeepLIIFExt', 'SDG']:
392
436
  if is_empty(tile):
393
437
  res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
@@ -408,7 +452,8 @@ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=
408
452
 
409
453
  def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
410
454
  eager_mode=False, color_dapi=False, color_marker=False, opt=None,
411
- return_seg_intermediate=False, seg_only=False, seg_weights=None, opt_args={}):
455
+ return_seg_intermediate=False, seg_only=False, mod_only=False,
456
+ seg_weights=None, opt_args={}):
412
457
  """
413
458
  opt_args: a dictionary of key and values to add/overwrite to opt
414
459
  """
@@ -433,41 +478,75 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
433
478
 
434
479
  tiler = InferenceTiler(orig, tile_size, overlap_size)
435
480
  for tile in tiler:
436
- tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only, seg_weights))
481
+ tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only, mod_only, seg_weights))
437
482
 
438
483
  results = tiler.results()
439
-
484
+
440
485
  if opt.model in ['DeepLIIF','DeepLIIFKD']:
486
+ # check if both the elements and the order are exactly the same
487
+ l_modname = [f'mod{i+1}' for i in range(opt.modalities_no)]
488
+ if l_modname != opt.modalities_names[1:]:
489
+ # if not, append modalities_names to mod names
490
+ l_modname = [f'mod{i+1}-{mod_name}' for i,mod_name in enumerate(opt.modalities_names[1:])]
491
+ d_modname2id = {mod_name:f'G{i+1}' for i,mod_name in enumerate(l_modname)}
492
+
493
+ if opt.seg_gen:
494
+ l_modname_seg = [f'mod{i}' for i in range(opt.modalities_no+1)]
495
+ if l_modname_seg != opt.modalities_names:
496
+ # if not, append modalities_names to mod names
497
+ l_modname_seg = [f'mod{i}-{mod_name}' for i,mod_name in enumerate(opt.modalities_names)]
498
+ if f'G{opt.mod_id_seg}0' in results.keys():
499
+ d_modname2id_seg = {mod_name:f'G{opt.mod_id_seg}{i}' for i,mod_name in enumerate(l_modname_seg)}
500
+ else:
501
+ d_modname2id_seg = {mod_name:f'G{opt.mod_id_seg}{i+1}' for i,mod_name in enumerate(l_modname_seg)}
502
+
503
+ if not mod_only:
504
+ d_modname2id['Seg'] = f'G{opt.mod_id_seg}'
505
+
506
+ #print('d_modname2id:',d_modname2id)
507
+
441
508
  if seg_only:
442
- images = {'Seg': results['G5']}
443
- if 'G4' in results:
444
- images.update({'Marker': results['G4']})
509
+ images = {'Seg': results[d_modname2id['Seg']]}
510
+ marker_key = find_marker_key(d_modname2id)
511
+ if marker_key is not None:
512
+ images[marker_key] = results[d_modname2id[marker_key]]
513
+ # if f'G{opt.modalities_no}' in results:
514
+ # images.update({'Marker': results['G{opt.modalities_no}']})
515
+ # if 'Marker' in d_modname2id:
516
+ # images.update({'Marker': results[d_modname2id['Marker']]})
445
517
  else:
446
- images = {
447
- 'Hema': results['G1'],
448
- 'DAPI': results['G2'],
449
- 'Lap2': results['G3'],
450
- 'Marker': results['G4'],
451
- 'Seg': results['G5'],
452
- }
518
+ # images = {
519
+ # 'Hema': results['G1'],
520
+ # 'DAPI': results['G2'],
521
+ # 'Lap2': results['G3'],
522
+ # 'Marker': results['G4'],
523
+ # 'Seg': results['G5'],
524
+ # }
525
+ # images = {f'mod{i+1}': results[f'G{i+1}'] for i in range(opt.modalities_no)}
526
+ # images['Seg'] = results[f'G{opt.modalities_no+1}']
527
+ images = {mod_name: results[mod_id] for mod_name,mod_id in d_modname2id.items()}
453
528
 
454
529
  if return_seg_intermediate and not seg_only:
455
- images.update({'IHC_s':results['G51'],
456
- 'Hema_s':results['G52'],
457
- 'DAPI_s':results['G53'],
458
- 'Lap2_s':results['G54'],
459
- 'Marker_s':results['G55'],})
530
+ # images.update({'IHC_s':results['G51'],
531
+ # 'Hema_s':results['G52'],
532
+ # 'DAPI_s':results['G53'],
533
+ # 'Lap2_s':results['G54'],
534
+ # 'Marker_s':results['G55'],})
535
+
536
+ #images.update({f'mod{i+1}_s':results[f'G{opt.modalities_no+1}{i+1}'] for i in range(opt.modalities_no+1)})
537
+ #images.update({f'{mod_name}_s':results[f'G{opt.modalities_no+1}']})
538
+ images.update({f'{mod_name}_s':results[d_modname2id_seg[mod_name]] for mod_name in d_modname2id_seg.keys()})
460
539
 
461
- if color_dapi and not seg_only:
462
- matrix = ( 0, 0, 0, 0,
463
- 299/1000, 587/1000, 114/1000, 0,
464
- 299/1000, 587/1000, 114/1000, 0)
465
- images['DAPI'] = images['DAPI'].convert('RGB', matrix)
466
- if color_marker and not seg_only:
467
- matrix = (299/1000, 587/1000, 114/1000, 0,
468
- 299/1000, 587/1000, 114/1000, 0,
469
- 0, 0, 0, 0)
470
- images['Marker'] = images['Marker'].convert('RGB', matrix)
540
+ # if color_dapi and not seg_only:
541
+ # matrix = ( 0, 0, 0, 0,
542
+ # 299/1000, 587/1000, 114/1000, 0,
543
+ # 299/1000, 587/1000, 114/1000, 0)
544
+ # images['DAPI'] = images['DAPI'].convert('RGB', matrix)
545
+ # if color_marker and not seg_only:
546
+ # matrix = (299/1000, 587/1000, 114/1000, 0,
547
+ # 299/1000, 587/1000, 114/1000, 0,
548
+ # 0, 0, 0, 0)
549
+ # images['Marker'] = images['Marker'].convert('RGB', matrix)
471
550
  return images
472
551
 
473
552
  elif opt.model == 'DeepLIIFExt':
@@ -485,11 +564,11 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
485
564
  return results # return result images with default key names (i.e., net names)
486
565
 
487
566
 
488
- def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
567
+ def postprocess(orig, images, tile_size, model, seg_thresh=120, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
489
568
  if model in ['DeepLIIF','DeepLIIFKD']:
490
569
  resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
491
570
  overlay, refined, scoring = compute_final_results(
492
- orig, images['Seg'], images.get('Marker'), resolution,
571
+ orig, images['Seg'], images.get(find_marker_key(images)), resolution,
493
572
  size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
494
573
  processed_images = {}
495
574
  processed_images['SegOverlaid'] = Image.fromarray(overlay)
@@ -518,7 +597,7 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='def
518
597
 
519
598
  def infer_modalities(img, tile_size, model_dir, eager_mode=False,
520
599
  color_dapi=False, color_marker=False, opt=None,
521
- return_seg_intermediate=False, seg_only=False, seg_weights=None):
600
+ return_seg_intermediate=False, seg_only=False, mod_only=False, seg_weights=None):
522
601
  """
523
602
  This function is used to infer modalities for the given image using a trained model.
524
603
  :param img: The input image.
@@ -547,17 +626,21 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
547
626
  opt=opt,
548
627
  return_seg_intermediate=return_seg_intermediate,
549
628
  seg_only=seg_only,
629
+ mod_only=mod_only,
550
630
  seg_weights=seg_weights,
551
631
  )
552
632
 
553
633
  if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
554
- post_images, scoring = postprocess(img, images, tile_size, opt.model)
555
- images = {**images, **post_images}
556
- if seg_only:
557
- delete_keys = [k for k in images.keys() if 'Seg' not in k]
558
- for name in delete_keys:
559
- del images[name]
560
- return images, scoring
634
+ if not mod_only:
635
+ post_images, scoring = postprocess(img, images, tile_size, opt.model)
636
+ images = {**images, **post_images}
637
+ if seg_only:
638
+ delete_keys = [k for k in images.keys() if 'Seg' not in k]
639
+ for name in delete_keys:
640
+ del images[name]
641
+ return images, scoring
642
+ else:
643
+ return images, None
561
644
  else:
562
645
  return images, None
563
646
 
@@ -763,7 +846,8 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
763
846
 
764
847
  seg = to_array(images['Seg'])
765
848
  del images['Seg']
766
- marker = to_array(images['Marker'], True) if 'Marker' in images else None
849
+ marker_key = find_marker_key(images)
850
+ marker = to_array(images[marker_key], True) if marker_key is not None else None
767
851
  del images
768
852
  region_data = compute_cell_results(seg, marker, resolution, version=version)
769
853
  del seg
@@ -818,3 +902,10 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
818
902
  data['modelVersion'] = 'unknown'
819
903
 
820
904
  return data
905
+
906
+
907
+ def find_marker_key(dictionary):
908
+ for key in dictionary:
909
+ if key.endswith('Marker'):
910
+ return key
911
+ return None
@@ -148,8 +148,13 @@ class BaseModel(ABC):
148
148
  if not hasattr(self, name):
149
149
  if len(name.split('_')) != 2:
150
150
  if self.opt.model in ['DeepLIIF','DeepLIIFKD']:
151
- img_name = name[:-1] + '_' + name[-1]
152
- visual_ret[name] = getattr(self, img_name)
151
+ if name.endswith('_teacher'):
152
+ suffix = '_teacher'
153
+ name = name[:-8]
154
+ else:
155
+ suffix = ''
156
+ img_name = name[:-1] + '_' + name[-1] + suffix
157
+ visual_ret[name+suffix] = getattr(self, img_name)
153
158
  else:
154
159
  if self.opt.model == 'CycleGAN':
155
160
  l_output = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])
@@ -2,7 +2,7 @@
2
2
 
3
3
  from pathlib import Path
4
4
  import os
5
- from ..util.util import mkdirs
5
+ from ..util.util import mkdirs, init_input_and_mod_id
6
6
  import re
7
7
 
8
8
  def read_model_params(file_addr):
@@ -32,7 +32,7 @@ def read_model_params(file_addr):
32
32
 
33
33
  # if isinstance(param_dict[key],list):
34
34
  # param_dict[key] = param_dict[key][0]
35
-
35
+
36
36
  return param_dict
37
37
 
38
38
  class Options:
@@ -43,7 +43,7 @@ class Options:
43
43
 
44
44
  if path_file:
45
45
  d_params = read_model_params(path_file)
46
-
46
+
47
47
  for k,v in d_params.items():
48
48
  try:
49
49
  if k not in ['phase']: # e.g., k = 'phase', v = 'train', eval(v) is a function rather than a string
@@ -69,6 +69,7 @@ class Options:
69
69
  else:
70
70
  self.phase = 'test'
71
71
  self.is_train = False
72
+ self.continue_train = False
72
73
  self.input_nc = 3
73
74
  self.output_nc = 3
74
75
  self.ngf = 64
@@ -83,6 +84,41 @@ class Options:
83
84
  # and can be configured in the inference function
84
85
  self.BtoA = False if not hasattr(self,'BtoA') else self.BtoA
85
86
 
87
+ # to account for old settings before modalities_no was introduced
88
+ if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
89
+ self.modalities_no = self.targets_no - 1
90
+ del self.targets_no
91
+
92
+ if self.model in ['DeepLIIF','DeepLIIFKD']:
93
+ self.mod_id_seg, self.input_id = init_input_and_mod_id(self, os.path.dirname(path_file))
94
+ self.input_id = int(self.input_id)
95
+ print('mod id seg:', self.mod_id_seg, '; input id:', self.input_id)
96
+
97
+ print('Determining modalities names for test-mode model...')
98
+ if self.modalities_no == 4:
99
+ if not hasattr(self,'modalities_names'):
100
+ self.modalities_names = ['IHC','Hema','DAPI','Lap2','Marker']
101
+ self.seg_weights = [0.5,0,0,0,0.5]
102
+ elif not hasattr(self,'modalities_names') or len(self.modalities_names)==0:
103
+ # if self.model == 'DeepLIIFKD':
104
+ # # try find the modalities names from the teacher model
105
+ # d_params_teacher = read_model_params(os.path.join(self.model_dir_teacher,'train_opt.txt'))
106
+ # if 'modalities_names' in d_params_teacher:
107
+ # self.modalities_names = d_params_teacher['modalities_names']
108
+ # # check again
109
+ # if not hasattr(self,'modalities_names') or len(self.modalities_names)==0:
110
+ self.modalities_names = [f'mod{i}' for i in range(self.modalities_no+1)]
111
+ else:
112
+ self.modalities_names = [f'mod{i}' for i in range(self.modalities_no+1)]
113
+
114
+ print('modalities names:', self.modalities_names)
115
+
116
+ if not hasattr(self, 'background_colors'):
117
+ if self.model in ['DeepLIIF','DeepLIIFKD']:
118
+ self.background_colors = [(201, 211, 208),(10, 10, 10), (0, 0, 0), (10, 10, 10)]
119
+ else:
120
+ self.background_colors = [(10, 10, 10)] * self.modalities_no
121
+
86
122
  # reset checkpoints_dir and name based on the model directory
87
123
  # when base model is initialized: self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
88
124
  model_dir = Path(path_file).parent
@@ -95,15 +131,11 @@ class Options:
95
131
  if isinstance(self.gpu_ids,int):
96
132
  self.gpu_ids = (self.gpu_ids,)
97
133
 
98
- # to account for old settings before modalities_no was introduced
99
- if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
100
- self.modalities_no = self.targets_no - 1
101
- del self.targets_no
102
-
103
134
  # to account for old settings: same as in cli.py train
104
135
  if not hasattr(self,'seg_no'):
105
136
  if self.model == 'DeepLIIF':
106
137
  self.seg_no = 1
138
+ self.seg_gen = True
107
139
  elif self.model == 'DeepLIIFExt':
108
140
  if self.seg_gen:
109
141
  self.seg_no = self.modalities_no
@@ -80,7 +80,7 @@ def adjust_marker(inferred_tile, orig_tile):
80
80
 
81
81
 
82
82
  # Default postprocessing values
83
- DEFAULT_SEG_THRESH = 150
83
+ DEFAULT_SEG_THRESH = 120
84
84
  DEFAULT_NOISE_THRESH = 4
85
85
 
86
86
  # Values for uint8 label masks