deepliif 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,980 +0,0 @@
1
- """This package contains modules related to objective functions, optimizations, and network architectures.
2
-
3
- To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4
- You need to implement the following five functions:
5
- -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6
- -- <set_input>: unpack data from dataset and apply preprocessing.
7
- -- <forward>: produce intermediate results.
8
- -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9
- -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10
-
11
- In the function <__init__>, you need to define four lists:
12
- -- self.loss_names (str list): specify the training losses that you want to plot and save.
13
- -- self.model_names (str list): define networks used in our training.
14
- -- self.visual_names (str list): specify the images that you want to display and save.
15
- -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for an usage.
16
-
17
- Now you can use the model class by specifying flag '--model dummy'.
18
- See our template model class 'template_model.py' for more details.
19
- """
20
- import base64
21
- import os
22
- import itertools
23
- import importlib
24
- from functools import lru_cache
25
- from io import BytesIO
26
- import json
27
- import math
28
- import importlib.metadata
29
- import pathlib
30
- #from multiprocessing import Pool
31
- from torch.multiprocessing import Pool
32
-
33
- import requests
34
- import torch
35
- from PIL import Image
36
- Image.MAX_IMAGE_PIXELS = None
37
-
38
- import numpy as np
39
- from dask import delayed, compute
40
-
41
- from deepliif.util import *
42
- from deepliif.util.util import tensor_to_pil
43
- from deepliif.data import transform
44
- from deepliif.postprocessing import compute_final_results, compute_cell_results, to_array
45
- from deepliif.postprocessing import encode_cell_data_v4, decode_cell_data_v4
46
- from deepliif.options import Options, print_options
47
-
48
- from .base_model import BaseModel
49
-
50
- # import for init purpose, not used in this script
51
- from .DeepLIIF_model import DeepLIIFModel
52
- from .DeepLIIFExt_model import DeepLIIFExtModel
53
-
54
-
55
- @lru_cache
56
- def get_opt(model_dir, mode='test'):
57
- """
58
- mode: test or train, currently only functions used for inference utilize get_opt so it
59
- defaults to test
60
- """
61
- if mode == 'train':
62
- opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
63
- elif mode == 'test':
64
- try:
65
- opt = Options(path_file=os.path.join(model_dir,'test_opt.txt'), mode=mode)
66
- except:
67
- opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
68
- opt.use_dp = False
69
- opt.gpu_ids = list(range(torch.cuda.device_count()))
70
- return opt
71
-
72
-
73
- def find_model_using_name(model_name):
74
- """Import the module "models/[model_name]_model.py".
75
-
76
- In the file, the class called DatasetNameModel() will
77
- be instantiated. It has to be a subclass of BaseModel,
78
- and it is case-insensitive.
79
- """
80
- model_filename = "deepliif.models." + model_name + "_model"
81
- modellib = importlib.import_module(model_filename)
82
- model = None
83
- target_model_name = model_name.replace('_', '') + 'model'
84
- for name, cls in modellib.__dict__.items():
85
- if name.lower() == target_model_name.lower() \
86
- and issubclass(cls, BaseModel):
87
- model = cls
88
-
89
- if model is None:
90
- print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (
91
- model_filename, target_model_name))
92
- exit(0)
93
-
94
- return model
95
-
96
-
97
- def get_option_setter(model_name):
98
- """Return the static method <modify_commandline_options> of the model class."""
99
- model_class = find_model_using_name(model_name)
100
- return model_class.modify_commandline_options
101
-
102
-
103
- def create_model(opt):
104
- """Create a model given the option.
105
-
106
- This function warps the class CustomDatasetDataLoader.
107
- This is the main interface between this package and 'train.py'/'test.py'
108
-
109
- Example:
110
- >>> from deepliif.models import create_model
111
- >>> model = create_model(opt)
112
- """
113
- model = find_model_using_name(opt.model)
114
- instance = model(opt)
115
- print("model [%s] was created" % type(instance).__name__)
116
- return instance
117
-
118
-
119
- def load_torchscript_model(model_pt_path, device):
120
- net = torch.jit.load(model_pt_path, map_location=device)
121
- net = disable_batchnorm_tracking_stats(net)
122
- net.eval()
123
- return net
124
-
125
-
126
-
127
- def load_eager_models(opt, devices=None):
128
- # create a model given model and other options
129
- model = create_model(opt)
130
- # regular setup: load and print networks; create schedulers
131
- model.setup(opt)
132
-
133
- nets = {}
134
- if devices:
135
- model_names = list(devices.keys())
136
- else:
137
- model_names = model.model_names
138
-
139
- for name in model_names:#model.model_names:
140
- if isinstance(name, str):
141
- if '_' in name:
142
- net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
143
- else:
144
- net = getattr(model, 'net' + name)
145
-
146
- if opt.phase != 'train':
147
- net.eval()
148
- net = disable_batchnorm_tracking_stats(net)
149
-
150
- # SDG models when loaded are still DP.. not sure why
151
- if isinstance(net, torch.nn.DataParallel):
152
- net = net.module
153
-
154
- nets[name] = net
155
- if devices:
156
- nets[name].to(devices[name])
157
-
158
- return nets
159
-
160
-
161
- @lru_cache
162
- def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
163
- """
164
- Init DeepLIIF networks so that every net in
165
- the same group is deployed on the same GPU
166
-
167
- opt_args: to overwrite opt arguments in train_opt.txt, typically used in inference stage
168
- for example, opt_args={'phase':'test'}
169
- """
170
- if opt is None:
171
- opt = get_opt(model_dir, mode=phase)
172
- opt.use_dp = False
173
-
174
- if opt.model in ['DeepLIIF','DeepLIIFKD']:
175
- net_groups = [
176
- ('G1', 'G52'),
177
- ('G2', 'G53'),
178
- ('G3', 'G54'),
179
- ('G4', 'G55'),
180
- ('G51',)
181
- ]
182
- elif opt.model in ['DeepLIIFExt','SDG']:
183
- if opt.seg_gen:
184
- net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
185
- else:
186
- net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
187
- elif opt.model == 'CycleGAN':
188
- if opt.BtoA:
189
- net_groups = [(f'GB_{i+1}',) for i in range(opt.modalities_no)]
190
- else:
191
- net_groups = [(f'GA_{i+1}',) for i in range(opt.modalities_no)]
192
- else:
193
- raise Exception(f'init_nets() not implemented for model {opt.model}')
194
-
195
- number_of_gpus_all = torch.cuda.device_count()
196
- number_of_gpus = min(len(opt.gpu_ids),number_of_gpus_all)
197
-
198
- if number_of_gpus > 0:
199
- mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
200
- chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
201
- # chunks = chunks[1:]
202
- devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
203
- # devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
204
- else:
205
- devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
206
-
207
- if eager_mode:
208
- return load_eager_models(opt, devices)
209
-
210
- return {
211
- n: load_torchscript_model(os.path.join(model_dir, f'{n}.pt'), device=d)
212
- for n, d in devices.items()
213
- }
214
-
215
-
216
- def compute_overlap(img_size, tile_size):
217
- w, h = img_size
218
- if round(w / tile_size) == 1 and round(h / tile_size) == 1:
219
- return 0
220
-
221
- return tile_size // 4
222
-
223
-
224
- def run_torchserve(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
225
- """
226
- eager_mode: not used in this function; put in place to be consistent with run_dask
227
- so that run_wrapper() could call either this function or run_dask with
228
- same syntax
229
- opt: same as eager_mode
230
- seg_only: same as eager_mode
231
- nets: same as eager_mode
232
- """
233
- buffer = BytesIO()
234
- torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
235
-
236
- torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
237
- res = requests.post(
238
- f'{torchserve_host}/wfpredict/deepliif',
239
- json={'img': base64.b64encode(buffer.getvalue()).decode('utf-8')}
240
- )
241
-
242
- res.raise_for_status()
243
-
244
- def deserialize_tensor(bs):
245
- return torch.load(BytesIO(base64.b64decode(bs.encode())), map_location=torch.device('cpu'))
246
-
247
- return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
248
-
249
-
250
- def run_dask(img, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
251
- """
252
- Provide either the model path or the networks object.
253
-
254
- `eager_mode` is only applicable if model_path is provided.
255
- """
256
- assert model_path is not None or nets is not None, 'Provide either the model path or the networks object.'
257
- if nets is None:
258
- model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
259
- nets = init_nets(model_dir, eager_mode, opt)
260
-
261
- if use_dask: # check if use_dask should be overwritten
262
- use_dask = True if opt.norm != 'spectral' else False
263
-
264
- if isinstance(img,torch.Tensor): # if img input is already a tensor, pass
265
- ts = img
266
- else:
267
- if opt.input_no > 1 or opt.model == 'SDG':
268
- l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
269
- ts = torch.cat(l_ts, dim=1)
270
- else:
271
- ts = transform(img.resize((opt.scale_size, opt.scale_size)))
272
-
273
-
274
- if use_dask:
275
- @delayed
276
- def forward(input, model):
277
- with torch.no_grad():
278
- return model(input.to(next(model.parameters()).device))
279
- else: # some train settings like spectral norm some how in inference mode is not compatible with dask
280
- def forward(input, model):
281
- with torch.no_grad():
282
- return model(input.to(next(model.parameters()).device))
283
-
284
- if opt.model in ['DeepLIIF','DeepLIIFKD']:
285
- #weights = {
286
- # 'G51': 0.25, # IHC
287
- # 'G52': 0.25, # Hema
288
- # 'G53': 0.25, # DAPI
289
- # 'G54': 0.00, # Lap2
290
- # 'G55': 0.25, # Marker
291
- #}
292
- weights = {
293
- 'G51': 0.5, # IHC
294
- 'G52': 0.0, # Hema
295
- 'G53': 0.0, # DAPI
296
- 'G54': 0.0, # Lap2
297
- 'G55': 0.5, # Marker
298
- }
299
-
300
- seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
301
- if seg_only:
302
- seg_map = {k: v for k, v in seg_map.items() if weights[v] != 0}
303
-
304
- lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
305
- if 'G4' not in seg_map:
306
- lazy_gens['G4'] = forward(ts, nets['G4'])
307
- gens = compute(lazy_gens)[0]
308
-
309
- lazy_segs = {v: forward(gens[k], nets[v]) for k, v in seg_map.items()}
310
- if not seg_only or weights['G51'] != 0:
311
- lazy_segs['G51'] = forward(ts, nets['G51'])
312
- segs = compute(lazy_segs)[0]
313
-
314
- device = next(nets['G1'].parameters()).device # take the device of the first net and move all outputs there for seg aggregation
315
- seg = torch.stack([torch.mul(segs[k].to(device), weights[k]) for k in segs.keys()]).sum(dim=0)
316
-
317
- if output_tensor:
318
- if seg_only:
319
- res = {'G4': gens['G4']} if 'G4' in gens else {}
320
- else:
321
- res = {**gens, **segs}
322
- res['G5'] = seg
323
- else:
324
- if seg_only:
325
- res = {'G4': tensor_to_pil(gens['G4'].to(torch.device('cpu')))} if 'G4' in gens else {}
326
- else:
327
- res = {k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in gens.items()}
328
- res.update({k: tensor_to_pil(v.to(torch.device('cpu'))) for k, v in segs.items()})
329
- res['G5'] = tensor_to_pil(seg.to(torch.device('cpu')))
330
-
331
- return res
332
- elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
333
- if opt.model == 'CycleGAN':
334
- seg_map = {f'GB_{i+1}':None for i in range(opt.modalities_no)} if opt.BtoA else {f'GA_{i+1}':None for i in range(opt.modalities_no)}
335
- else:
336
- seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
337
-
338
- if use_dask:
339
- lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
340
- gens = compute(lazy_gens)[0]
341
- else:
342
- gens = {k: forward(ts, nets[k]) for k in seg_map}
343
-
344
- res = {k: tensor_to_pil(v) for k, v in gens.items()}
345
-
346
- if opt.seg_gen:
347
- if use_dask:
348
- lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
349
- segs = compute(lazy_segs)[0]
350
- else:
351
- segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
352
- res.update({k: tensor_to_pil(v) for k, v in segs.items()})
353
-
354
- return res
355
- else:
356
- raise Exception(f'run_dask() not fully implemented for {opt.model}')
357
-
358
-
359
- def is_empty(tile):
360
- thresh = 15
361
- if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
362
- return all([True if image_variance_gray(t) < thresh else False for t in tile])
363
- else:
364
- return True if image_variance_gray(tile) < thresh else False
365
-
366
-
367
- def run_wrapper(tile, run_fn, model_path=None, nets=None, eager_mode=False, opt=None, seg_only=False, use_dask=True, output_tensor=False):
368
- if opt.model in ['DeepLIIF','DeepLIIFKD']:
369
- if is_empty(tile):
370
- if seg_only:
371
- return {
372
- 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
373
- 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
374
- }
375
- else :
376
- return {
377
- 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
378
- 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
379
- 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
380
- 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
381
- 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
382
- 'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
383
- 'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
384
- 'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
385
- 'G54': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
386
- 'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
387
- }
388
- else:
389
- return run_fn(tile, model_path, None, eager_mode, opt, seg_only)
390
- elif opt.model in ['DeepLIIFExt', 'SDG']:
391
- if is_empty(tile):
392
- res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
393
- res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
394
- return res
395
- else:
396
- return run_fn(tile, model_path, None, eager_mode, opt)
397
- elif opt.model in ['CycleGAN']:
398
- if is_empty(tile):
399
- net_names = ['GB_{i+1}' for i in range(opt.modalities_no)] if opt.BtoA else [f'GA_{i+1}' for i in range(opt.modalities_no)]
400
- res = {net_name: Image.new(mode='RGB', size=(512, 512)) for net_name in net_names}
401
- return res
402
- else:
403
- return run_fn(tile, model_path, None, eager_mode, opt)
404
- else:
405
- raise Exception(f'run_wrapper() not implemented for model {opt.model}')
406
-
407
-
408
- def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
409
- eager_mode=False, color_dapi=False, color_marker=False, opt=None,
410
- return_seg_intermediate=False, seg_only=False, opt_args={}):
411
- """
412
- opt_args: a dictionary of key and values to add/overwrite to opt
413
- """
414
- if not opt:
415
- opt = get_opt(model_path)
416
- #print_options(opt)
417
-
418
- for k,v in opt_args.items():
419
- setattr(opt,k,v)
420
- #print_options(opt)
421
-
422
- run_fn = run_torchserve if use_torchserve else run_dask
423
-
424
- if opt.model == 'SDG':
425
- # SDG could have multiple input images/modalities, hence the input could be a rectangle.
426
- # We split the input to get each modality image then create tiles for each set of input images.
427
- w, h = int(img.width / opt.input_no), img.height
428
- orig = [img.crop((w * i, 0, w * (i+1), h)) for i in range(opt.input_no)]
429
- else:
430
- # Otherwise expect a single input image, which is used directly.
431
- orig = img
432
-
433
- tiler = InferenceTiler(orig, tile_size, overlap_size)
434
- for tile in tiler:
435
- tiler.stitch(run_wrapper(tile, run_fn, model_path, None, eager_mode, opt, seg_only))
436
-
437
- results = tiler.results()
438
-
439
- if opt.model in ['DeepLIIF','DeepLIIFKD']:
440
- if seg_only:
441
- images = {'Seg': results['G5']}
442
- if 'G4' in results:
443
- images.update({'Marker': results['G4']})
444
- else:
445
- images = {
446
- 'Hema': results['G1'],
447
- 'DAPI': results['G2'],
448
- 'Lap2': results['G3'],
449
- 'Marker': results['G4'],
450
- 'Seg': results['G5'],
451
- }
452
-
453
- if return_seg_intermediate and not seg_only:
454
- images.update({'IHC_s':results['G51'],
455
- 'Hema_s':results['G52'],
456
- 'DAPI_s':results['G53'],
457
- 'Lap2_s':results['G54'],
458
- 'Marker_s':results['G55'],})
459
-
460
- if color_dapi and not seg_only:
461
- matrix = ( 0, 0, 0, 0,
462
- 299/1000, 587/1000, 114/1000, 0,
463
- 299/1000, 587/1000, 114/1000, 0)
464
- images['DAPI'] = images['DAPI'].convert('RGB', matrix)
465
- if color_marker and not seg_only:
466
- matrix = (299/1000, 587/1000, 114/1000, 0,
467
- 299/1000, 587/1000, 114/1000, 0,
468
- 0, 0, 0, 0)
469
- images['Marker'] = images['Marker'].convert('RGB', matrix)
470
- return images
471
-
472
- elif opt.model == 'DeepLIIFExt':
473
- images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
474
- if opt.seg_gen:
475
- images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
476
- return images
477
-
478
- elif opt.model == 'SDG':
479
- images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
480
- return images
481
-
482
- else:
483
- #raise Exception(f'inference() not implemented for model {opt.model}')
484
- return results # return result images with default key names (i.e., net names)
485
-
486
-
487
- def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='default', marker_thresh=None, size_thresh_upper=None):
488
- if model in ['DeepLIIF','DeepLIIFKD']:
489
- resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
490
- overlay, refined, scoring = compute_final_results(
491
- orig, images['Seg'], images.get('Marker'), resolution,
492
- size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
493
- processed_images = {}
494
- processed_images['SegOverlaid'] = Image.fromarray(overlay)
495
- processed_images['SegRefined'] = Image.fromarray(refined)
496
- return processed_images, scoring
497
-
498
- elif model in ['DeepLIIFExt','SDG']:
499
- resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
500
- processed_images = {}
501
- scoring = {}
502
- for img_name in list(images.keys()):
503
- if 'Seg' in img_name:
504
- seg_img = images[img_name]
505
- overlay, refined, score = compute_final_results(
506
- orig, images[img_name], None, resolution,
507
- size_thresh, marker_thresh, size_thresh_upper, seg_thresh)
508
-
509
- processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
510
- processed_images[img_name + '_Refined'] = Image.fromarray(refined)
511
- scoring[img_name] = score
512
- return processed_images, scoring
513
-
514
- else:
515
- raise Exception(f'postprocess() not implemented for model {model}')
516
-
517
-
518
- def infer_modalities(img, tile_size, model_dir, eager_mode=False,
519
- color_dapi=False, color_marker=False, opt=None,
520
- return_seg_intermediate=False, seg_only=False):
521
- """
522
- This function is used to infer modalities for the given image using a trained model.
523
- :param img: The input image.
524
- :param tile_size: The tile size.
525
- :param model_dir: The directory containing serialized model files.
526
- :return: The inferred modalities and the segmentation mask.
527
- """
528
- if opt is None:
529
- opt = get_opt(model_dir)
530
- opt.use_dp = False
531
- #print_options(opt)
532
-
533
- # for those with multiple input modalities, find the correct size to calculate overlap_size
534
- input_no = opt.input_no if hasattr(opt, 'input_no') else 1
535
- img_size = (img.size[0] / input_no, img.size[1]) # (width, height)
536
-
537
- images = inference(
538
- img,
539
- tile_size=tile_size,
540
- #overlap_size=compute_overlap(img_size, tile_size),
541
- overlap_size=tile_size//16,
542
- model_path=model_dir,
543
- eager_mode=eager_mode,
544
- color_dapi=color_dapi,
545
- color_marker=color_marker,
546
- opt=opt,
547
- return_seg_intermediate=return_seg_intermediate,
548
- seg_only=seg_only
549
- )
550
-
551
- if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
552
- post_images, scoring = postprocess(img, images, tile_size, opt.model)
553
- images = {**images, **post_images}
554
- if seg_only:
555
- delete_keys = [k for k in images.keys() if 'Seg' not in k]
556
- for name in delete_keys:
557
- del images[name]
558
- return images, scoring
559
- else:
560
- return images, None
561
-
562
-
563
- def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000, color_dapi=False, color_marker=False, seg_intermediate=False, seg_only=False):
564
- """
565
- This function infers modalities and segmentation mask for the given WSI image. It
566
-
567
- :param input_dir: The directory containing the WSI.
568
- :param filename: The WSI name.
569
- :param output_dir: The directory for saving the inferred modalities.
570
- :param model_dir: The directory containing the serialized model files.
571
- :param tile_size: The tile size.
572
- :param region_size: The size of each individual region to be processed at once.
573
- :return:
574
- """
575
- basename, _ = os.path.splitext(filename)
576
- results_dir = os.path.join(output_dir, basename)
577
- if not os.path.exists(results_dir):
578
- os.makedirs(results_dir)
579
- size_x, size_y, size_z, size_c, size_t, pixel_type = get_information(os.path.join(input_dir, filename))
580
- rescale = (pixel_type != 'uint8')
581
- print(filename, size_x, size_y, size_z, size_c, size_t, pixel_type, flush=True)
582
-
583
- results = {}
584
- scoring = None
585
-
586
- # javabridge already set up from previous call to get_information()
587
- with bioformats.ImageReader(os.path.join(input_dir, filename)) as reader:
588
- start_x, start_y = 0, 0
589
-
590
- while start_x < size_x:
591
- while start_y < size_y:
592
- print(start_x, start_y, flush=True)
593
- region_XYWH = (start_x, start_y, min(region_size, size_x - start_x), min(region_size, size_y - start_y))
594
- region = reader.read(XYWH=region_XYWH, rescale=rescale)
595
- img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
596
-
597
- region_modalities, region_scoring = infer_modalities(img, tile_size, model_dir, color_dapi=color_dapi, color_marker=color_marker, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
598
- if region_scoring is not None:
599
- if scoring is None:
600
- scoring = {
601
- 'num_pos': region_scoring['num_pos'],
602
- 'num_neg': region_scoring['num_neg'],
603
- }
604
- else:
605
- scoring['num_pos'] += region_scoring['num_pos']
606
- scoring['num_neg'] += region_scoring['num_neg']
607
-
608
- for name, img in region_modalities.items():
609
- if name not in results:
610
- results[name] = np.zeros((size_y, size_x, 3), dtype=np.uint8)
611
- results[name][region_XYWH[1]: region_XYWH[1] + region_XYWH[3],
612
- region_XYWH[0]: region_XYWH[0] + region_XYWH[2]] = np.array(img)
613
- start_y += region_size
614
- start_y = 0
615
- start_x += region_size
616
-
617
- # write_results_to_pickle_file(os.path.join(results_dir, "results.pickle"), results)
618
- # read_results_from_pickle_file(os.path.join(results_dir, "results.pickle"))
619
-
620
- for name, img in results.items():
621
- write_big_tiff_file(os.path.join(results_dir, f'{basename}_{name}.ome.tiff'), img, tile_size)
622
-
623
- if scoring is not None:
624
- scoring['num_total'] = scoring['num_pos'] + scoring['num_neg']
625
- scoring['percent_pos'] = round(scoring['num_pos'] / scoring['num_total'] * 100, 1) if scoring['num_pos'] > 0 else 0
626
- with open(os.path.join(results_dir, f'{basename}.json'), 'w') as f:
627
- json.dump(scoring, f, indent=2)
628
-
629
-
630
- def get_wsi_resolution(filename):
631
- """
632
- Try to get the resolution (magnification) of the slide and
633
- the corresponding tile size to use by default for DeepLIIF.
634
- If it cannot be found, return (None, None) instead.
635
-
636
- Parameters
637
- ----------
638
- filename : str
639
- Full path to the file.
640
-
641
- Returns
642
- -------
643
- str :
644
- Magnification (objective power) from image metadata.
645
- int :
646
- Corresponding tile size for DeepLIIF.
647
- """
648
-
649
- init_javabridge_bioformats()
650
- metadata = bioformats.get_omexml_metadata(filename)
651
-
652
- mag = None
653
- try:
654
- omexml = bioformats.OMEXML(metadata)
655
- mag = omexml.instrument().Objective.NominalMagnification
656
- except Exception as e:
657
- fields = ['AppMag', 'NominalMagnification']
658
- try:
659
- for field in fields:
660
- idx = metadata.find(field)
661
- if idx >= 0:
662
- for i in range(idx, len(metadata)):
663
- if metadata[i].isdigit() or metadata[i] == '.':
664
- break
665
- for j in range(i, len(metadata)):
666
- if not metadata[j].isdigit() and metadata[j] != '.':
667
- break
668
- if i == j:
669
- continue
670
- mag = metadata[i:j]
671
- break
672
- except Exception as e:
673
- pass
674
-
675
- if mag is None:
676
- return None, None
677
-
678
- try:
679
- tile_size = round((float(mag) / 40) * 512)
680
- return mag, tile_size
681
- except Exception as e:
682
- return None, None
683
-
684
-
685
- def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
686
- """
687
- Perform inference on a slide and get the results individual cell data.
688
-
689
- Parameters
690
- ----------
691
- filename : str
692
- Full path to the file.
693
- model_dir : str
694
- Full path to the directory with the DeepLIIF model files.
695
- tile_size : int
696
- Size of tiles to extract and perform inference on.
697
- region_size : int
698
- Maximum size to split the slide for processing.
699
- version : int
700
- Version of cell data to return (3 or 4).
701
- print_log : bool
702
- Whether or not to print updates while processing.
703
-
704
- Returns
705
- -------
706
- dict :
707
- Individual cell data and associated values.
708
- """
709
-
710
- def print_info(*args):
711
- if print_log:
712
- print(*args, flush=True)
713
-
714
- resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
715
-
716
- data = None
717
- default_marker_thresh, count_marker_thresh = 0, 0
718
- default_size_thresh, count_size_thresh = 0, 0
719
-
720
- with WSIReader(filename) as reader:
721
- size_x = reader.width
722
- size_y = reader.height
723
- print_info('Info:', size_x, size_y)
724
-
725
- num_regions_x = math.ceil(size_x / region_size)
726
- num_regions_y = math.ceil(size_y / region_size)
727
- stride_x = math.ceil(size_x / num_regions_x)
728
- stride_y = math.ceil(size_y / num_regions_y)
729
- print_info('Strides:', stride_x, stride_y)
730
-
731
- start_x, start_y = 0, 0
732
-
733
- while start_y < size_y:
734
- while start_x < size_x:
735
- region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
736
- print_info('Region:', region_XYWH)
737
-
738
- region = reader.read(region_XYWH)
739
- print_info(region.shape, region.dtype)
740
- img = Image.fromarray(region)
741
- print_info(img.size, img.mode)
742
- del region
743
-
744
- images = inference(
745
- img,
746
- tile_size=tile_size,
747
- overlap_size=tile_size//16,
748
- model_path=model_dir,
749
- eager_mode=False,
750
- color_dapi=False,
751
- color_marker=False,
752
- opt=None,
753
- return_seg_intermediate=False,
754
- seg_only=True,
755
- )
756
- del img
757
-
758
- seg = to_array(images['Seg'])
759
- del images['Seg']
760
- marker = to_array(images['Marker'], True) if 'Marker' in images else None
761
- del images
762
- region_data = compute_cell_results(seg, marker, resolution, version=version)
763
- del seg
764
- del marker
765
-
766
- if start_x != 0 or start_y != 0:
767
- for i in range(len(region_data['cells'])):
768
- cell = decode_cell_data_v4(region_data['cells'][i]) if version == 4 else region_data['cells'][i]
769
- for j in range(2):
770
- cell['bbox'][j] = (cell['bbox'][j][0] + start_x, cell['bbox'][j][1] + start_y)
771
- cell['centroid'] = (cell['centroid'][0] + start_x, cell['centroid'][1] + start_y)
772
- for j in range(len(cell['boundary'])):
773
- cell['boundary'][j] = (cell['boundary'][j][0] + start_x, cell['boundary'][j][1] + start_y)
774
- region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell
775
-
776
- if data is None:
777
- data = region_data
778
- else:
779
- data['cells'] += region_data['cells']
780
-
781
- if region_data['settings']['default_marker_thresh'] is not None and region_data['settings']['default_marker_thresh'] != 0:
782
- default_marker_thresh += region_data['settings']['default_marker_thresh']
783
- count_marker_thresh += 1
784
- if region_data['settings']['default_size_thresh'] != 0:
785
- default_size_thresh += region_data['settings']['default_size_thresh']
786
- count_size_thresh += 1
787
-
788
- start_x += stride_x
789
-
790
- start_x = 0
791
- start_y += stride_y
792
-
793
- if count_marker_thresh == 0:
794
- count_marker_thresh = 1
795
- if count_size_thresh == 0:
796
- count_size_thresh = 1
797
- data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
798
- data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)
799
-
800
- try:
801
- data['deepliifVersion'] = importlib.metadata.version('deepliif')
802
- except Exception as e:
803
- data['deepliifVersion'] = 'unknown'
804
-
805
- try:
806
- data['modelVersion'] = pathlib.PurePath(model_dir).name
807
- except Exception as e:
808
- data['modelVersion'] = 'unknown'
809
-
810
- return data
811
-
812
-
813
def infer_cells_for_wsi_process_region(filename, xywh, model_dir, tile_size, version, print_log):
    """
    Run inference on one rectangular region of a slide and return its cell data.

    Intended to run inside a worker process (see infer_cells_for_wsi_process);
    each worker opens its own reader on the slide file.

    Parameters
    ----------
    filename : str
        Full path to the slide file.
    xywh : tuple
        (x, y, width, height) of the region within the whole slide.
    model_dir : str
        Full path to the directory with the DeepLIIF model files.
    tile_size : int
        Size of tiles to extract and perform inference on.
    version : int
        Version of cell data to return (3 or 4).
    print_log : bool
        Whether or not to print updates while processing.

    Returns
    -------
    dict :
        Cell data for the region, with cell coordinates shifted from
        region-local space into whole-slide space.
    """
    def print_info(*args):
        if print_log:
            print(*args, flush=True)

    # Scanning resolution inferred from tile size (same heuristic used
    # elsewhere in this module).
    resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')

    print_info(os.getpid(), 'Region:', xywh)
    offset_x = xywh[0]
    offset_y = xywh[1]

    with WSIReader(filename) as reader:
        region = reader.read(xywh)
        print_info(region.shape, region.dtype)
        img = Image.fromarray(region)
        del region  # free the raw pixel buffer as soon as possible
    print_info(img.size, img.mode)

    images = inference(
        img,
        tile_size=tile_size,
        overlap_size=tile_size // 16,
        model_path=model_dir,
        eager_mode=False,
        color_dapi=False,
        color_marker=False,
        opt=None,
        return_seg_intermediate=False,
        seg_only=True,
    )
    del img

    seg = to_array(images['Seg'])
    del images['Seg']
    marker = to_array(images['Marker'], True) if 'Marker' in images else None
    del images
    region_data = compute_cell_results(seg, marker, resolution, version=version)
    # Large arrays are deleted eagerly to keep worker peak memory down.
    del seg
    del marker

    # Shift region-local coordinates into whole-slide coordinates.
    if offset_x != 0 or offset_y != 0:
        for i, raw_cell in enumerate(region_data['cells']):
            # v4 cell data is stored encoded; decode, shift, re-encode.
            cell = decode_cell_data_v4(raw_cell) if version == 4 else raw_cell
            for j in range(2):
                cell['bbox'][j] = (cell['bbox'][j][0] + offset_x, cell['bbox'][j][1] + offset_y)
            cell['centroid'] = (cell['centroid'][0] + offset_x, cell['centroid'][1] + offset_y)
            for j in range(len(cell['boundary'])):
                cell['boundary'][j] = (cell['boundary'][j][0] + offset_x, cell['boundary'][j][1] + offset_y)
            region_data['cells'][i] = encode_cell_data_v4(cell) if version == 4 else cell

    return region_data
865
-
866
-
867
def infer_cells_for_wsi_process(params):
    """
    Worker entry point: expand a region-parameter dict into keyword arguments.

    Kept as a module-level function so that multiprocessing's
    ``Pool.imap_unordered`` can pickle it and ship it to worker processes.

    Parameters
    ----------
    params : dict
        Keyword arguments for infer_cells_for_wsi_process_region.

    Returns
    -------
    dict :
        The region's cell data, as produced by
        infer_cells_for_wsi_process_region.
    """
    region_kwargs = params
    return infer_cells_for_wsi_process_region(**region_kwargs)
869
-
870
-
871
def infer_cells_for_wsi_multi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False, processes=1):
    """
    Perform inference on a slide and get the results individual cell data.

    The slide is split into regions of at most `region_size` pixels per side;
    regions are processed by a pool of worker processes and the per-region
    results are merged into a single cell-data dict.

    Parameters
    ----------
    filename : str
        Full path to the file.
    model_dir : str
        Full path to the directory with the DeepLIIF model files.
    tile_size : int
        Size of tiles to extract and perform inference on.
    region_size : int
        Maximum size to split the slide for processing.
    version : int
        Version of cell data to return (3 or 4).
    print_log : bool
        Whether or not to print updates while processing.
    processes : int
        Number of worker processes used to process regions in parallel.

    Returns
    -------
    dict :
        Individual cell data and associated values.
    """

    def print_info(*args):
        if print_log:
            print(*args, flush=True)

    with WSIReader(filename) as reader:
        size_x = reader.width
        size_y = reader.height
    print_info('Info:', size_x, size_y)

    # Choose strides so regions are evenly sized and no side exceeds region_size.
    num_regions_x = math.ceil(size_x / region_size)
    num_regions_y = math.ceil(size_y / region_size)
    stride_x = math.ceil(size_x / num_regions_x)
    stride_y = math.ceil(size_y / num_regions_y)
    print_info('Strides:', stride_x, stride_y)

    # Build one parameter dict per region; edge regions are clamped to the
    # slide bounds. (Logging goes through print_info so it honors print_log.)
    print_info('')
    region_params = []
    for y in range(0, size_y, stride_y):
        for x in range(0, size_x, stride_x):
            region_params.append({
                'xywh': (x, y, min(stride_x, size_x - x), min(stride_y, size_y - y)),
                'filename': filename,
                'model_dir': model_dir,
                'tile_size': tile_size,
                'version': version,
                'print_log': print_log,
            })
            print_info(region_params[-1])
    print_info('')

    data = None
    default_marker_thresh, count_marker_thresh = 0, 0
    default_size_thresh, count_size_thresh = 0, 0

    with Pool(processes) as pool:
        # imap_unordered: merge results as they finish; cell order is not
        # significant, so out-of-order arrival is fine.
        for region_data in pool.imap_unordered(infer_cells_for_wsi_process, region_params):
            if data is None:
                data = region_data
            else:
                data['cells'] += region_data['cells']

            # Average per-region suggested thresholds, skipping regions that
            # produced no suggestion (None or 0).
            marker_thresh = region_data['settings']['default_marker_thresh']
            if marker_thresh is not None and marker_thresh != 0:
                default_marker_thresh += marker_thresh
                count_marker_thresh += 1
            size_thresh = region_data['settings']['default_size_thresh']
            if size_thresh != 0:
                default_size_thresh += size_thresh
                count_size_thresh += 1

    # Guard against division by zero when no region produced a threshold.
    count_marker_thresh = max(count_marker_thresh, 1)
    count_size_thresh = max(count_size_thresh, 1)
    data['settings']['default_marker_thresh'] = round(default_marker_thresh / count_marker_thresh)
    data['settings']['default_size_thresh'] = round(default_size_thresh / count_size_thresh)

    try:
        # `import importlib` alone does not guarantee the metadata submodule
        # is loaded; import it explicitly so the version lookup can succeed.
        import importlib.metadata
        data['deepliifVersion'] = importlib.metadata.version('deepliif')
    except Exception:
        data['deepliifVersion'] = 'unknown'

    try:
        data['modelVersion'] = pathlib.PurePath(model_dir).name
    except Exception:
        data['modelVersion'] = 'unknown'

    return data