deepliif 1.2.3__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +79 -24
- deepliif/data/base_dataset.py +2 -0
- deepliif/models/DeepLIIFKD_model.py +243 -255
- deepliif/models/DeepLIIF_model.py +344 -235
- deepliif/models/__init__.py +194 -103
- deepliif/models/base_model.py +7 -2
- deepliif/options/__init__.py +40 -8
- deepliif/postprocessing.py +1 -1
- deepliif/util/__init__.py +98 -1
- deepliif/util/util.py +85 -0
- deepliif/util/visualizer.py +2 -2
- {deepliif-1.2.3.dist-info → deepliif-1.2.4.dist-info}/METADATA +2 -2
- {deepliif-1.2.3.dist-info → deepliif-1.2.4.dist-info}/RECORD +17 -17
- {deepliif-1.2.3.dist-info → deepliif-1.2.4.dist-info}/LICENSE.md +0 -0
- {deepliif-1.2.3.dist-info → deepliif-1.2.4.dist-info}/WHEEL +0 -0
- {deepliif-1.2.3.dist-info → deepliif-1.2.4.dist-info}/entry_points.txt +0 -0
- {deepliif-1.2.3.dist-info → deepliif-1.2.4.dist-info}/top_level.txt +0 -0
deepliif/models/__init__.py
CHANGED
|
@@ -133,7 +133,7 @@ def load_eager_models(opt, devices=None):
|
|
|
133
133
|
model_names = list(devices.keys())
|
|
134
134
|
else:
|
|
135
135
|
model_names = model.model_names
|
|
136
|
-
|
|
136
|
+
|
|
137
137
|
for name in model_names:#model.model_names:
|
|
138
138
|
if isinstance(name, str):
|
|
139
139
|
if '_' in name:
|
|
@@ -155,7 +155,6 @@ def load_eager_models(opt, devices=None):
|
|
|
155
155
|
|
|
156
156
|
return nets
|
|
157
157
|
|
|
158
|
-
|
|
159
158
|
@lru_cache
|
|
160
159
|
def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
161
160
|
"""
|
|
@@ -169,14 +168,20 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
|
169
168
|
opt = get_opt(model_dir, mode=phase)
|
|
170
169
|
opt.use_dp = False
|
|
171
170
|
|
|
171
|
+
# get l_net_groups: a list of lists, each sublist is a candidate, ordered by priority (try 1st, if fail try 2nd, etc.)
|
|
172
172
|
if opt.model in ['DeepLIIF','DeepLIIFKD']:
|
|
173
|
-
net_groups = [
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
]
|
|
173
|
+
# net_groups = [
|
|
174
|
+
# ('G1', 'G52'),
|
|
175
|
+
# ('G2', 'G53'),
|
|
176
|
+
# ('G3', 'G54'),
|
|
177
|
+
# ('G4', 'G55'),
|
|
178
|
+
# ('G51',)
|
|
179
|
+
# ]
|
|
180
|
+
if opt.modalities_no == 0:
|
|
181
|
+
net_groups = [(f'G{opt.mod_id_seg}{opt.input_id}',)]
|
|
182
|
+
else:
|
|
183
|
+
net_groups = [(f'G{i+1}', f'G{opt.mod_id_seg}{int(opt.input_id)+i+1}') for i in range(opt.modalities_no)]
|
|
184
|
+
net_groups += [(f'G{opt.mod_id_seg}{opt.input_id}',)] # this is the generator for the input base mod
|
|
180
185
|
elif opt.model in ['DeepLIIFExt','SDG']:
|
|
181
186
|
if opt.seg_gen:
|
|
182
187
|
net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
|
|
@@ -197,8 +202,8 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
|
197
202
|
mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
|
|
198
203
|
chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
|
|
199
204
|
# chunks = chunks[1:]
|
|
200
|
-
|
|
201
|
-
|
|
205
|
+
#l_devices = [{n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g} for chunks in l_chunks]
|
|
206
|
+
devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
|
|
202
207
|
else:
|
|
203
208
|
devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
|
|
204
209
|
|
|
@@ -219,13 +224,14 @@ def compute_overlap(img_size, tile_size):
|
|
|
219
224
|
return tile_size // 4
|
|
220
225
|
|
|
221
226
|
|
|
222
|
-
def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
|
|
227
|
+
def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, mod_only=False, seg_weights=None, use_dask=True, output_tensor=False):
|
|
223
228
|
"""
|
|
224
229
|
eager_mode: not used in this function; put in place to be consistent with run_dask
|
|
225
230
|
so that run_wrapper() could call either this function or run_dask with
|
|
226
231
|
same syntax
|
|
227
232
|
opt: same as eager_mode
|
|
228
233
|
seg_only: same as eager_mode
|
|
234
|
+
mod_only: same as eager_mode
|
|
229
235
|
seg_weights: same as eager_mode
|
|
230
236
|
nets: same as eager_mode
|
|
231
237
|
"""
|
|
@@ -246,7 +252,8 @@ def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None,
|
|
|
246
252
|
return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
|
|
247
253
|
|
|
248
254
|
|
|
249
|
-
def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False,
|
|
255
|
+
def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, mod_only=False,
|
|
256
|
+
seg_weights=None, use_dask=True, output_tensor=False):
|
|
250
257
|
"""
|
|
251
258
|
Provide either the model path or the networks object.
|
|
252
259
|
|
|
@@ -269,65 +276,81 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
|
|
|
269
276
|
else:
|
|
270
277
|
ts = transform(img.resize((opt.scale_size, opt.scale_size)))
|
|
271
278
|
|
|
272
|
-
|
|
279
|
+
|
|
273
280
|
if use_dask:
|
|
274
281
|
@delayed
|
|
275
282
|
def forward(input, model):
|
|
276
283
|
with torch.no_grad():
|
|
277
284
|
return model(input.to(next(model.parameters()).device))
|
|
278
|
-
else: # some train settings like spectral norm
|
|
285
|
+
else: # some train settings like spectral norm somehow in inference mode is not compatible with dask
|
|
279
286
|
def forward(input, model):
|
|
280
287
|
with torch.no_grad():
|
|
281
288
|
return model(input.to(next(model.parameters()).device))
|
|
282
289
|
|
|
283
290
|
if opt.model in ['DeepLIIF','DeepLIIFKD']:
|
|
284
291
|
if seg_weights is None:
|
|
285
|
-
weights = {
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
}
|
|
292
|
+
# weights = {
|
|
293
|
+
# 'G51': 0.5, # IHC
|
|
294
|
+
# 'G52': 0.0, # Hema
|
|
295
|
+
# 'G53': 0.0, # DAPI
|
|
296
|
+
# 'G54': 0.0, # Lap2
|
|
297
|
+
# 'G55': 0.5, # Marker
|
|
298
|
+
# }
|
|
299
|
+
weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': 1/(opt.modalities_no+1) for i in range(opt.modalities_no+1)}
|
|
292
300
|
else:
|
|
293
|
-
weights = {
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
}
|
|
300
|
-
|
|
301
|
-
|
|
301
|
+
# weights = {
|
|
302
|
+
# 'G51': seg_weights[0], # IHC
|
|
303
|
+
# 'G52': seg_weights[1], # Hema
|
|
304
|
+
# 'G53': seg_weights[2], # DAPI
|
|
305
|
+
# 'G54': seg_weights[3], # Lap2
|
|
306
|
+
# 'G55': seg_weights[4], # Marker
|
|
307
|
+
# }
|
|
308
|
+
weights = {f'G{opt.mod_id_seg}{int(opt.input_id)+i}': seg_weight for i,seg_weight in enumerate(seg_weights)}
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
seg_map = {f'G{i+1}': f'G{opt.mod_id_seg}{int(opt.input_id)+i+1}' for i in range(opt.modalities_no)}
|
|
302
312
|
if seg_only:
|
|
303
313
|
seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
|
|
304
314
|
|
|
305
315
|
lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
306
|
-
if '
|
|
307
|
-
|
|
316
|
+
if 'Marker' in opt.modalities_names:
|
|
317
|
+
mod_id_marker = opt.modalities_names.index("Marker")
|
|
318
|
+
if f'G{mod_id_marker}' not in seg_map:
|
|
319
|
+
lazy_gens[f'G{mod_id_marker}'] = forward(ts, nets[f'G{mod_id_marker}'])
|
|
320
|
+
|
|
308
321
|
gens = compute(lazy_gens)[0]
|
|
309
322
|
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
323
|
+
if not mod_only:
|
|
324
|
+
lazy_segs = {v: forward(gens[k], nets[v]) for k, v in seg_map.items()}
|
|
325
|
+
# run seg generator for the base input
|
|
326
|
+
if weights[f'G{opt.mod_id_seg}{opt.input_id}'] != 0:
|
|
327
|
+
lazy_segs[f'G{opt.mod_id_seg}{opt.input_id}'] = forward(ts, nets[f'G{opt.mod_id_seg}{opt.input_id}'])
|
|
328
|
+
segs = compute(lazy_segs)[0]
|
|
314
329
|
|
|
315
|
-
|
|
316
|
-
|
|
330
|
+
model_name_first = list(nets.keys())[0]
|
|
331
|
+
device = next(nets[model_name_first].parameters()).device # take the device of the first net and move all outputs there for seg aggregation
|
|
332
|
+
seg = torch.stack([torch.mul(segs[k].to(device), weights[k]) for k in segs.keys()]).sum(dim=0)
|
|
317
333
|
|
|
318
334
|
if output_tensor:
|
|
319
|
-
if
|
|
320
|
-
res =
|
|
335
|
+
if mod_only:
|
|
336
|
+
res = gens
|
|
337
|
+
elif seg_only and opt.modalities_no > 0:
|
|
338
|
+
res = {f'G{opt.modalities_no}': gens[f'G{opt.modalities_no}']} if f'G{opt.modalities_no}' in gens else {}
|
|
339
|
+
res[f'G{opt.mod_id_seg}'] = seg
|
|
321
340
|
else:
|
|
322
341
|
res = {**gens, **segs}
|
|
323
|
-
|
|
342
|
+
res[f'G{opt.mod_id_seg}'] = seg
|
|
343
|
+
|
|
324
344
|
else:
|
|
325
|
-
if
|
|
326
|
-
res = {
|
|
345
|
+
if mod_only:
|
|
346
|
+
res = {k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in gens.items()}
|
|
347
|
+
elif seg_only and opt.modalities_no > 0:
|
|
348
|
+
res = {f'G{opt.modalities_no}': tensor_to_pil(gens[f'G{opt.modalities_no}'].to(torch.device('cpu')))} if f'G{opt.modalities_no}' in gens else {}
|
|
349
|
+
res[f'G{opt.mod_id_seg}'] = tensor_to_pil(seg.to(torch.device('cpu')))
|
|
327
350
|
else:
|
|
328
351
|
res = {k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in gens.items()}
|
|
329
352
|
res.update({k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in segs.items()})
|
|
330
|
-
|
|
353
|
+
res[f'G{opt.mod_id_seg}'] = tensor_to_pil(seg.to(torch.device('cpu')))
|
|
331
354
|
|
|
332
355
|
return res
|
|
333
356
|
elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
|
|
@@ -343,6 +366,8 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
|
|
|
343
366
|
gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
344
367
|
|
|
345
368
|
res = {k: tensor_to_pil(v) for k, v in gens.items()}
|
|
369
|
+
if mod_only:
|
|
370
|
+
return res
|
|
346
371
|
|
|
347
372
|
if opt.seg_gen:
|
|
348
373
|
if use_dask:
|
|
@@ -358,36 +383,55 @@ def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_on
|
|
|
358
383
|
|
|
359
384
|
|
|
360
385
|
def is_empty(tile):
|
|
361
|
-
thresh =
|
|
386
|
+
thresh = 9
|
|
362
387
|
if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
|
|
363
388
|
return all([True if image_variance_gray(t) < thresh else False for t in tile])
|
|
364
389
|
else:
|
|
365
390
|
return True if image_variance_gray(tile) < thresh else False
|
|
366
391
|
|
|
367
392
|
|
|
368
|
-
def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, seg_weights=None, use_dask=True, output_tensor=False):
|
|
393
|
+
def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, mod_only=False, seg_weights=None, use_dask=True, output_tensor=False):
|
|
369
394
|
if opt.model in ['DeepLIIF','DeepLIIFKD']:
|
|
370
395
|
if is_empty(tile):
|
|
371
|
-
if seg_only:
|
|
372
|
-
|
|
373
|
-
'
|
|
374
|
-
'
|
|
396
|
+
if seg_only: # return seg image and the last translated modality
|
|
397
|
+
res = {
|
|
398
|
+
#f'G{opt.modalities_no}': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
399
|
+
f'G{opt.modalities_no}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[-1]),
|
|
400
|
+
f'G{opt.mod_id_seg}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
375
401
|
}
|
|
402
|
+
elif mod_only:
|
|
403
|
+
res = {f'G{i+1}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[i]) for i in range(opt.modalities_no)}
|
|
404
|
+
|
|
376
405
|
else :
|
|
377
|
-
return {
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
}
|
|
406
|
+
# return {
|
|
407
|
+
# 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
|
|
408
|
+
# 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
409
|
+
# 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
410
|
+
# 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
|
|
411
|
+
# 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
412
|
+
# 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
413
|
+
# 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
414
|
+
# 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
415
|
+
# 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
416
|
+
# 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
417
|
+
# }
|
|
418
|
+
res = {**{f'G{i+1}': Image.new(mode='RGB', size=(512, 512), color=opt.background_colors[i]) for i in range(opt.modalities_no)},
|
|
419
|
+
**{f'G{opt.mod_id_seg}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))}}
|
|
420
|
+
|
|
421
|
+
# assign mod-wise seg output to corresponding keys, again currently input_id only supports 0 or 1
|
|
422
|
+
if opt.input_id == 1:
|
|
423
|
+
res = {**res,
|
|
424
|
+
**{f'G{opt.mod_id_seg}{i+1}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)) for i in range(opt.modalities_no+1)}}
|
|
425
|
+
else:
|
|
426
|
+
res = {**res,
|
|
427
|
+
**{f'G{opt.mod_id_seg}{i}': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)) for i in range(opt.modalities_no+1)}}
|
|
428
|
+
|
|
429
|
+
# when modalities_no = 0... we do not need to generate G0 - output should be just a seg image
|
|
430
|
+
if 'G0' in res:
|
|
431
|
+
del res['G0']
|
|
432
|
+
return res
|
|
389
433
|
else:
|
|
390
|
-
return run_fn(tile, model_path, None, eager_mode, opt, seg_only, seg_weights)
|
|
434
|
+
return run_fn(tile, model_path, None, eager_mode, opt, seg_only, mod_only, seg_weights)
|
|
391
435
|
elif opt.model in ['DeepLIIFExt', 'SDG']:
|
|
392
436
|
if is_empty(tile):
|
|
393
437
|
res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
|
|
@@ -408,7 +452,8 @@ def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=
|
|
|
408
452
|
|
|
409
453
|
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
410
454
|
eager_mode=False, color_dapi=False, color_marker=False, opt=None,
|
|
411
|
-
return_seg_intermediate=False, seg_only=False,
|
|
455
|
+
return_seg_intermediate=False, seg_only=False, mod_only=False,
|
|
456
|
+
seg_weights=None, opt_args={}):
|
|
412
457
|
"""
|
|
413
458
|
opt_args: a dictionary of key and values to add/overwrite to opt
|
|
414
459
|
"""
|
|
@@ -433,41 +478,75 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
|
433
478
|
|
|
434
479
|
tiler = InferenceTiler(orig, tile_size, overlap_size)
|
|
435
480
|
for tile in tiler:
|
|
436
|
-
tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only, seg_weights))
|
|
481
|
+
tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only, mod_only, seg_weights))
|
|
437
482
|
|
|
438
483
|
results = tiler.results()
|
|
439
|
-
|
|
484
|
+
|
|
440
485
|
if opt.model in ['DeepLIIF','DeepLIIFKD']:
|
|
486
|
+
# check if both the elements and the order are exactly the same
|
|
487
|
+
l_modname = [f'mod{i+1}' for i in range(opt.modalities_no)]
|
|
488
|
+
if l_modname != opt.modalities_names[1:]:
|
|
489
|
+
# if not, append modalities_names to mod names
|
|
490
|
+
l_modname = [f'mod{i+1}-{mod_name}' for i,mod_name in enumerate(opt.modalities_names[1:])]
|
|
491
|
+
d_modname2id = {mod_name:f'G{i+1}' for i,mod_name in enumerate(l_modname)}
|
|
492
|
+
|
|
493
|
+
if opt.seg_gen:
|
|
494
|
+
l_modname_seg = [f'mod{i}' for i in range(opt.modalities_no+1)]
|
|
495
|
+
if l_modname_seg != opt.modalities_names:
|
|
496
|
+
# if not, append modalities_names to mod names
|
|
497
|
+
l_modname_seg = [f'mod{i}-{mod_name}' for i,mod_name in enumerate(opt.modalities_names)]
|
|
498
|
+
if f'G{opt.mod_id_seg}0' in results.keys():
|
|
499
|
+
d_modname2id_seg = {mod_name:f'G{opt.mod_id_seg}{i}' for i,mod_name in enumerate(l_modname_seg)}
|
|
500
|
+
else:
|
|
501
|
+
d_modname2id_seg = {mod_name:f'G{opt.mod_id_seg}{i+1}' for i,mod_name in enumerate(l_modname_seg)}
|
|
502
|
+
|
|
503
|
+
if not mod_only:
|
|
504
|
+
d_modname2id['Seg'] = f'G{opt.mod_id_seg}'
|
|
505
|
+
|
|
506
|
+
#print('d_modname2id:',d_modname2id)
|
|
507
|
+
|
|
441
508
|
if seg_only:
|
|
442
|
-
images = {'Seg': results['
|
|
443
|
-
|
|
444
|
-
|
|
509
|
+
images = {'Seg': results[d_modname2id['Seg']]}
|
|
510
|
+
marker_key = find_marker_key(d_modname2id)
|
|
511
|
+
if marker_key is not None:
|
|
512
|
+
images[marker_key] = results[d_modname2id[marker_key]]
|
|
513
|
+
# if f'G{opt.modalities_no}' in results:
|
|
514
|
+
# images.update({'Marker': results['G{opt.modalities_no}']})
|
|
515
|
+
# if 'Marker' in d_modname2id:
|
|
516
|
+
# images.update({'Marker': results[d_modname2id['Marker']]})
|
|
445
517
|
else:
|
|
446
|
-
images = {
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
}
|
|
518
|
+
# images = {
|
|
519
|
+
# 'Hema': results['G1'],
|
|
520
|
+
# 'DAPI': results['G2'],
|
|
521
|
+
# 'Lap2': results['G3'],
|
|
522
|
+
# 'Marker': results['G4'],
|
|
523
|
+
# 'Seg': results['G5'],
|
|
524
|
+
# }
|
|
525
|
+
# images = {f'mod{i+1}': results[f'G{i+1}'] for i in range(opt.modalities_no)}
|
|
526
|
+
# images['Seg'] = results[f'G{opt.modalities_no+1}']
|
|
527
|
+
images = {mod_name: results[mod_id] for mod_name,mod_id in d_modname2id.items()}
|
|
453
528
|
|
|
454
529
|
if return_seg_intermediate and not seg_only:
|
|
455
|
-
images.update({'IHC_s':results['G51'],
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
530
|
+
# images.update({'IHC_s':results['G51'],
|
|
531
|
+
# 'Hema_s':results['G52'],
|
|
532
|
+
# 'DAPI_s':results['G53'],
|
|
533
|
+
# 'Lap2_s':results['G54'],
|
|
534
|
+
# 'Marker_s':results['G55'],})
|
|
535
|
+
|
|
536
|
+
#images.update({f'mod{i+1}_s':results[f'G{opt.modalities_no+1}{i+1}'] for i in range(opt.modalities_no+1)})
|
|
537
|
+
#images.update({f'{mod_name}_s':results[f'G{opt.modalities_no+1}']})
|
|
538
|
+
images.update({f'{mod_name}_s':results[d_modname2id_seg[mod_name]] for mod_name in d_modname2id_seg.keys()})
|
|
460
539
|
|
|
461
|
-
if color_dapi and not seg_only:
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
if color_marker and not seg_only:
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
540
|
+
# if color_dapi and not seg_only:
|
|
541
|
+
# matrix = ( 0, 0, 0, 0,
|
|
542
|
+
# 299/1000, 587/1000, 114/1000, 0,
|
|
543
|
+
# 299/1000, 587/1000, 114/1000, 0)
|
|
544
|
+
# images['DAPI'] = images['DAPI'].convert('RGB', matrix)
|
|
545
|
+
# if color_marker and not seg_only:
|
|
546
|
+
# matrix = (299/1000, 587/1000, 114/1000, 0,
|
|
547
|
+
# 299/1000, 587/1000, 114/1000, 0,
|
|
548
|
+
# 0, 0, 0, 0)
|
|
549
|
+
# images['Marker'] = images['Marker'].convert('RGB', matrix)
|
|
471
550
|
return images
|
|
472
551
|
|
|
473
552
|
elif opt.model == 'DeepLIIFExt':
|
|
@@ -485,11 +564,11 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
|
485
564
|
return results # return result images with default key names (i.e., net names)
|
|
486
565
|
|
|
487
566
|
|
|
488
|
-
def postprocess(orig, images, tile_size, model, seg_thresh=
|
|
567
|
+
def postprocess(orig, images, tile_size, model, seg_thresh=120, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
|
|
489
568
|
if model in ['DeepLIIF','DeepLIIFKD']:
|
|
490
569
|
resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
|
|
491
570
|
overlay, refined, scoring = compute_final_results(
|
|
492
|
-
orig, images['Seg'], images.get(
|
|
571
|
+
orig, images['Seg'], images.get(find_marker_key(images)), resolution,
|
|
493
572
|
size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
|
|
494
573
|
processed_images = {}
|
|
495
574
|
processed_images['SegOverlaid'] = Image.fromarray(overlay)
|
|
@@ -518,7 +597,7 @@ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='def
|
|
|
518
597
|
|
|
519
598
|
def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
520
599
|
color_dapi=False, color_marker=False, opt=None,
|
|
521
|
-
return_seg_intermediate=False, seg_only=False, seg_weights=None):
|
|
600
|
+
return_seg_intermediate=False, seg_only=False, mod_only=False, seg_weights=None):
|
|
522
601
|
"""
|
|
523
602
|
This function is used to infer modalities for the given image using a trained model.
|
|
524
603
|
:param img: The input image.
|
|
@@ -547,17 +626,21 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
547
626
|
opt=opt,
|
|
548
627
|
return_seg_intermediate=return_seg_intermediate,
|
|
549
628
|
seg_only=seg_only,
|
|
629
|
+
mod_only=mod_only,
|
|
550
630
|
seg_weights=seg_weights,
|
|
551
631
|
)
|
|
552
632
|
|
|
553
633
|
if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
634
|
+
if not mod_only:
|
|
635
|
+
post_images, scoring = postprocess(img, images, tile_size, opt.model)
|
|
636
|
+
images = {**images, **post_images}
|
|
637
|
+
if seg_only:
|
|
638
|
+
delete_keys = [k for k in images.keys() if 'Seg' not in k]
|
|
639
|
+
for name in delete_keys:
|
|
640
|
+
del images[name]
|
|
641
|
+
return images, scoring
|
|
642
|
+
else:
|
|
643
|
+
return images, None
|
|
561
644
|
else:
|
|
562
645
|
return images, None
|
|
563
646
|
|
|
@@ -763,7 +846,8 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
763
846
|
|
|
764
847
|
seg = to_array(images['Seg'])
|
|
765
848
|
del images['Seg']
|
|
766
|
-
|
|
849
|
+
marker_key = find_marker_key(images)
|
|
850
|
+
marker = to_array(images[marker_key], True) if marker_key is not None else None
|
|
767
851
|
del images
|
|
768
852
|
region_data = compute_cell_results(seg, marker, resolution, version=version)
|
|
769
853
|
del seg
|
|
@@ -818,3 +902,10 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
818
902
|
data['modelVersion'] = 'unknown'
|
|
819
903
|
|
|
820
904
|
return data
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
def find_marker_key(dictionary):
|
|
908
|
+
for key in dictionary:
|
|
909
|
+
if key.endswith('Marker'):
|
|
910
|
+
return key
|
|
911
|
+
return None
|
deepliif/models/base_model.py
CHANGED
|
@@ -148,8 +148,13 @@ class BaseModel(ABC):
|
|
|
148
148
|
if not hasattr(self, name):
|
|
149
149
|
if len(name.split('_')) != 2:
|
|
150
150
|
if self.opt.model in ['DeepLIIF','DeepLIIFKD']:
|
|
151
|
-
|
|
152
|
-
|
|
151
|
+
if name.endswith('_teacher'):
|
|
152
|
+
suffix = '_teacher'
|
|
153
|
+
name = name[:-8]
|
|
154
|
+
else:
|
|
155
|
+
suffix = ''
|
|
156
|
+
img_name = name[:-1] + '_' + name[-1] + suffix
|
|
157
|
+
visual_ret[name+suffix] = getattr(self, img_name)
|
|
153
158
|
else:
|
|
154
159
|
if self.opt.model == 'CycleGAN':
|
|
155
160
|
l_output = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])
|
deepliif/options/__init__.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
import os
|
|
5
|
-
from ..util.util import mkdirs
|
|
5
|
+
from ..util.util import mkdirs, init_input_and_mod_id
|
|
6
6
|
import re
|
|
7
7
|
|
|
8
8
|
def read_model_params(file_addr):
|
|
@@ -32,7 +32,7 @@ def read_model_params(file_addr):
|
|
|
32
32
|
|
|
33
33
|
# if isinstance(param_dict[key],list):
|
|
34
34
|
# param_dict[key] = param_dict[key][0]
|
|
35
|
-
|
|
35
|
+
|
|
36
36
|
return param_dict
|
|
37
37
|
|
|
38
38
|
class Options:
|
|
@@ -43,7 +43,7 @@ class Options:
|
|
|
43
43
|
|
|
44
44
|
if path_file:
|
|
45
45
|
d_params = read_model_params(path_file)
|
|
46
|
-
|
|
46
|
+
|
|
47
47
|
for k,v in d_params.items():
|
|
48
48
|
try:
|
|
49
49
|
if k not in ['phase']: # e.g., k = 'phase', v = 'train', eval(v) is a function rather than a string
|
|
@@ -69,6 +69,7 @@ class Options:
|
|
|
69
69
|
else:
|
|
70
70
|
self.phase = 'test'
|
|
71
71
|
self.is_train = False
|
|
72
|
+
self.continue_train = False
|
|
72
73
|
self.input_nc = 3
|
|
73
74
|
self.output_nc = 3
|
|
74
75
|
self.ngf = 64
|
|
@@ -83,6 +84,41 @@ class Options:
|
|
|
83
84
|
# and can be configured in the inference function
|
|
84
85
|
self.BtoA = False if not hasattr(self,'BtoA') else self.BtoA
|
|
85
86
|
|
|
87
|
+
# to account for old settings before modalities_no was introduced
|
|
88
|
+
if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
|
|
89
|
+
self.modalities_no = self.targets_no - 1
|
|
90
|
+
del self.targets_no
|
|
91
|
+
|
|
92
|
+
if self.model in ['DeepLIIF','DeepLIIFKD']:
|
|
93
|
+
self.mod_id_seg, self.input_id = init_input_and_mod_id(self, os.path.dirname(path_file))
|
|
94
|
+
self.input_id = int(self.input_id)
|
|
95
|
+
print('mod id seg:', self.mod_id_seg, '; input id:', self.input_id)
|
|
96
|
+
|
|
97
|
+
print('Determining modalities names for test-mode model...')
|
|
98
|
+
if self.modalities_no == 4:
|
|
99
|
+
if not hasattr(self,'modalities_names'):
|
|
100
|
+
self.modalities_names = ['IHC','Hema','DAPI','Lap2','Marker']
|
|
101
|
+
self.seg_weights = [0.5,0,0,0,0.5]
|
|
102
|
+
elif not hasattr(self,'modalities_names') or len(self.modalities_names)==0:
|
|
103
|
+
# if self.model == 'DeepLIIFKD':
|
|
104
|
+
# # try find the modalities names from the teacher model
|
|
105
|
+
# d_params_teacher = read_model_params(os.path.join(self.model_dir_teacher,'train_opt.txt'))
|
|
106
|
+
# if 'modalities_names' in d_params_teacher:
|
|
107
|
+
# self.modalities_names = d_params_teacher['modalities_names']
|
|
108
|
+
# # check again
|
|
109
|
+
# if not hasattr(self,'modalities_names') or len(self.modalities_names)==0:
|
|
110
|
+
self.modalities_names = [f'mod{i}' for i in range(self.modalities_no+1)]
|
|
111
|
+
else:
|
|
112
|
+
self.modalities_names = [f'mod{i}' for i in range(self.modalities_no+1)]
|
|
113
|
+
|
|
114
|
+
print('modalities names:', self.modalities_names)
|
|
115
|
+
|
|
116
|
+
if not hasattr(self, 'background_colors'):
|
|
117
|
+
if self.model in ['DeepLIIF','DeepLIIFKD']:
|
|
118
|
+
self.background_colors = [(201, 211, 208),(10, 10, 10), (0, 0, 0), (10, 10, 10)]
|
|
119
|
+
else:
|
|
120
|
+
self.background_colors = [(10, 10, 10)] * self.modalities_no
|
|
121
|
+
|
|
86
122
|
# reset checkpoints_dir and name based on the model directory
|
|
87
123
|
# when base model is initialized: self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
|
88
124
|
model_dir = Path(path_file).parent
|
|
@@ -95,15 +131,11 @@ class Options:
|
|
|
95
131
|
if isinstance(self.gpu_ids,int):
|
|
96
132
|
self.gpu_ids = (self.gpu_ids,)
|
|
97
133
|
|
|
98
|
-
# to account for old settings before modalities_no was introduced
|
|
99
|
-
if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
|
|
100
|
-
self.modalities_no = self.targets_no - 1
|
|
101
|
-
del self.targets_no
|
|
102
|
-
|
|
103
134
|
# to account for old settings: same as in cli.py train
|
|
104
135
|
if not hasattr(self,'seg_no'):
|
|
105
136
|
if self.model == 'DeepLIIF':
|
|
106
137
|
self.seg_no = 1
|
|
138
|
+
self.seg_gen = True
|
|
107
139
|
elif self.model == 'DeepLIIFExt':
|
|
108
140
|
if self.seg_gen:
|
|
109
141
|
self.seg_no = self.modalities_no
|